Using SAM2 to Obtain Continuous-Frame Masks of Objects in a Video

Last updated: January 8, 2026, 8:35 PM

The full source code is available by cloning my fork of the SAM2 repository; just change the video and image paths and it is ready to use: wty-yy-mini/SAM2

An object in my video needs to be masked across consecutive frames, but masking too much would hurt the video's watchability, so I tried SAM2's box-prompt mask-tracking feature to mask the object throughout the video. My machine configuration:

  1. CPU: R7-5700X
  2. GPU: RTX 4080
  3. RAM: 50GB

Only three frames of the original video need box prompts around the objects to be masked; the segmentation result over the whole video is shown below

Original video (download link)
Result after SAM2 segmentation and masking (download link)

Finally, the masks obtained from SAM2 segmentation are used to create a strobe-photography effect. Footage from an ordinary fixed-position camera is enough; no tripod, shutter release, black background, flash, or other special equipment is needed:
Strobe-photography effect achieved with SAM2

Installing SAM2

I recommend setting up a Conda environment with Miniforge: miniforge, then following the official installation instructions on the SAM2 GitHub page to install PyTorch and SAM2, and downloading the model checkpoints into the sam2/checkpoints folder:

  1. sam2.1_hiera_tiny.pt
  2. sam2.1_hiera_small.pt
  3. sam2.1_hiera_base_plus.pt

If you run out of GPU memory later, try a smaller checkpoint; I use sam2.1_hiera_base_plus here.
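
To quickly verify the installation, here is a minimal sketch that just builds the video predictor (the config and checkpoint paths are the same ones used by the segmentation script later in this post, and it assumes you run it from the SAM2 repo root):

# Minimal install check: build the SAM2 video predictor.
import torch
from sam2.build_sam import build_sam2_video_predictor

device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_b+.yaml",      # model config shipped with the repo
    "./checkpoints/sam2.1_hiera_base_plus.pt",  # downloaded checkpoint
    device=device,
)
print(f"SAM2 video predictor ready on {device}")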

Segmenting the Video

SAM2 requires the whole video to be split into individual frame images, which are then all loaded into memory at once together with all of the prompts. If the clip is too long, GPU memory overflows; with only 16 GB of VRAM I can handle about 20 seconds at a time, so we first preprocess the video data:
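
If even 20-second chunks do not fit on your GPU, one more thing that may be worth trying (an assumption to verify against your SAM2 version, not something I needed here): the video predictor's init_state accepts offload flags that keep the decoded frames and the tracking state in CPU RAM at the cost of speed. Continuing from the install-check sketch above:

# Assumed VRAM-saving kwargs of init_state; check that your SAM2 version supports them.
inference_state = predictor.init_state(
    video_path="path/to/one_segment_folder",  # a folder of extracted frames
    offload_video_to_cpu=True,   # keep decoded frames in CPU RAM instead of VRAM
    offload_state_to_cpu=True,   # keep the tracking state on CPU as well
)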

The extraction command is as follows; it extracts all frames of the mp4 into 20-second chunks, saved in a folder next to the video file, for later processing

python extract_frames_from_video.py --video-file g1_dance_demo.mp4 --folder-duration 20

The source code of extract_frames_from_video.py is as follows

# -*- coding: utf-8 -*-
'''
@File    : extract_frames_from_video.py
@Time    : 2026/01/08 19:49:41
@Author  : wty-yy
@Version : 1.0
@Blog    : https://wty-yy.github.io/
@Desc    : Extract frames from a video at a given time interval and save them as images into per-segment folders.
python extract_frames_from_video.py \
    --video-file g1_dance_demo.mp4 \
    --output-folder g1_dance_demo_frames \
    --start 2 --end 4 --folder-duration 0.2
'''
import os
import cv2
import time
from pathlib import Path
from typing import Optional

def extract_frames(
        video_path: str,
        output_dir: Optional[str] = None,
        start: int = 0,
        end: Optional[int] = None,
        interval_sec: Optional[float] = None,
        folder_duration: Optional[float] = None,
        folder_prefix: str = "frame",
        folder_start_idx: int = 1,
    ):
    """
    Extract frames from a video at a given time interval and save them as images
    into per-segment folders. For example, with folder_duration=20, start=0, end=60,
    three folders are created:

        output_dir/
            1_frame0-19/
            2_frame20-39/
            3_frame40-59/

        Each folder holds the frames of its time span (the numbers in the folder
        names are frame indices, so they depend on the video's fps).

    Args:
        video_path (str): Path to the input video file.
        output_dir (str): Directory to save the images; defaults to a
            "<video name>_frames" folder next to video_path.
        start (int): Start time in seconds, default 0.
        end (Optional[int]): End time in seconds, defaults to the end of the video.
        interval_sec (Optional[float]): Time interval between extracted frames in
            seconds; by default every frame is extracted.
        folder_duration (Optional[float]): Video duration covered by each subfolder
            in seconds; by default no subfolders are created.
        folder_prefix (str): Subfolder name prefix, default "frame".
        folder_start_idx (int): Starting subfolder index, default 1.
    """
    if not os.path.exists(video_path):
        print(f"[ERROR] Video file {video_path} not found.")
        return

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"[ERROR] Can't open video {video_path}.")
        return

    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps == 0:
        print("[ERROR] Can't get FPS.")
        cap.release()
        return
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    print(f"Process video: {video_path}")
    print(f"  - FPS: {fps:.2f}")

    if end is None or end > total_frames / fps:
        end = total_frames / fps
    if output_dir is None:
        output_dir = str(Path(video_path).parent / f"{Path(video_path).stem}_frames")
    output_dir = str(output_dir) + f"_{int(1/interval_sec) if interval_sec is not None else int(fps)}fps"
    output_dir = Path(output_dir)
    output_dirs = []
    output_dir.mkdir(parents=True, exist_ok=True)
    now_save_dir = output_dir / f"{folder_start_idx}_{folder_prefix}{int(start*fps)}-{int(end*fps-1)}"
    if folder_duration is None:
        now_save_dir.mkdir(parents=True, exist_ok=True)
        output_dirs.append(now_save_dir.name)
    now_end = None  # end time (seconds) of the current subfolder


    # Read and save frames in a loop
    frame_count = 0  # total number of frames read
    saved_count = 0  # number of frames saved in the current subfolder (used for file naming)
    next_save_time_sec = 0.0  # next target time (seconds) at which to save a frame

    use_tqdm = True
    try:
        from tqdm import tqdm
        bar = tqdm(range(0, int(end * fps)), desc="Extracting frames", unit=" frames")
    except ImportError:
        bar = range(0, int(end * fps))
        use_tqdm = False
    
    start_time = time.time()
    for _ in bar:
        if not cap.isOpened():
            break
        flag, frame = cap.read()
        if not flag:
            break

        current_time_sec = frame_count / fps
        frame_count += 1
        if end is not None and current_time_sec >= end:
            break
        if current_time_sec < start:
            continue

        # Check whether the next save time has been reached
        if interval_sec is None or current_time_sec >= next_save_time_sec:
            if folder_duration is not None and (now_end is None or current_time_sec >= now_end):
                # Create a new subfolder
                now_save_dir = output_dir / f"{folder_start_idx}_{folder_prefix}{int(current_time_sec*fps)}-{int(min(current_time_sec+folder_duration, end)*fps-1)}"
                output_dirs.append(now_save_dir.name)
                now_save_dir.mkdir(parents=True, exist_ok=True)
                folder_start_idx += 1
                saved_count = 0
                now_end = current_time_sec + folder_duration

            output_filename = now_save_dir / f"{saved_count:05d}.jpg"
            cv2.imwrite(str(output_filename), frame)  # cv2.imwrite expects a str path
            saved_count += 1
            if interval_sec is not None:
                next_save_time_sec += interval_sec
        if not use_tqdm:
            if frame_count % int(fps) == 0:
                time_used = time.time() - start_time
                print(f"  - Processed {frame_count}/{total_frames} frames, remaining {int(time_used)}<{int((total_frames - frame_count) / frame_count * time_used)} sec")

    cap.release()
    print(f"\n处理完成。")
    print(f"  - 总共读取帧数: {frame_count}")
    print(f"  - 图片保存至: '{output_dir}'")
    print(f"  - 包含子文件夹: {output_dirs}")

from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--video-file", type=str, required=True, help="Path to the video file.")
parser.add_argument("--output-folder", type=str, help="Directory to save extracted frames.")
parser.add_argument("--start", type=int, default=0, help="Start time in seconds.")
parser.add_argument("--end", type=int, help="End time in seconds.")
parser.add_argument("--folder-duration", type=float, help="Duration of each folder in seconds.")
parser.add_argument("--interval-sec", type=float, help="Interval between extracted frames in seconds.")
parser.add_argument("--folder-prefix", type=str, default="frame", help="Prefix for folder names.")
parser.add_argument("--start-idx", type=int, default=1, help="Starting index for folder naming.")

args = parser.parse_args()

extract_frames(
    video_path=args.video_file,
    output_dir=args.output_folder,
    start=args.start,
    end=args.end,
    folder_duration=args.folder_duration,
    interval_sec=args.interval_sec,
    folder_prefix=args.folder_prefix,
    folder_start_idx=args.start_idx,
)

Creating Prompts with LabelMe

SAM2 supports two kinds of prompts: points and boxes. A prompt tells the model which object we want to segment: once we mark the target object on frame i, the model automatically tracks and segments it in the following frames. Here we use LabelMe to draw box prompts; I recommend downloading the GitHub - LabelMe - v5.5.0 executable and running it directly.

LabelMe's configuration file lives at ~/.labelmerc (Linux) or <user home>/.labelmerc (Windows). If the settings on startup are not what you want, you can edit this file so the desired configuration is loaded the next time LabelMe starts.

Open LabelMe and configure it first:

  1. File -> Save Automatically: on
  2. File -> Save With Image Data: off
  3. Edit -> Keep Previous Annotation: off

Then click Open Dir and open the data folder produced by extract_frames_from_video.py, e.g. g1_dance_demo_frames_30fps/1_frame0-149. Choose the rectangle tool and draw a box around each target, as shown below. Give the same target the same label across frames so the annotations on different frames are associated with the same object (in my case there are four labels: 0, 1, 2, 3). Press Ctrl+S to save the annotation (since Save Automatically is enabled, annotations are also saved automatically when you switch to another image).

(Screenshots: box prompts on frame 0, frame 14, and frame 29: sam2_mask_prompts_1 / sam2_mask_prompts_2 / sam2_mask_prompts_3)
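
Each annotated frame is saved by LabelMe as a JSON file next to the image. sam2_segment_video.py below converts every rectangle into an [x0, y0, x1, y1] box prompt, and it also accepts point shapes: a point labeled like the object (e.g. 0) is a positive point, while a label containing neg (e.g. 0_neg) is a negative point. A minimal sketch of that conversion, using a hypothetical 00000.json as input:

# Minimal sketch: read a LabelMe JSON and print SAM2-style prompts.
# "00000.json" is just an example file name produced by LabelMe.
import json

with open("00000.json") as f:
    data = json.load(f)

for shape in data["shapes"]:
    if shape["shape_type"] == "rectangle":
        (x0, y0), (x1, y1) = shape["points"]  # two opposite corners
        box = [min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)]
        print(f"object {shape['label']}: box prompt {box}")
    elif shape["shape_type"] == "point":
        obj = shape["label"].split("_")[0]
        point_label = 0 if "neg" in shape["label"] else 1  # 1 = positive, 0 = negative
        print(f"object {obj}: point {shape['points'][0]}, label {point_label}")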

Generating Masks with SAM2

After finishing the prompt annotations on several frames as above, the layout under --video-parent-dir should look like this (an example)

VID_20251210_094125_frames_30fps  # produced by extract_frames_from_video.py, one subfolder per segment
├── 1_frame450-600              # segment 1
│    ├── 00000.jpg              # extracted frame
│    ├── 00000.json             # the first frame must have target boxes
│    ├── *.jpg
│    ├── 00016.jpg
│    ├── 00016.json             # boxes can be added manually on later frames where the mask is lost
│    └── ...
├── 2_frame600-750              # segment 2
└── 3_frame750-900              # segment 3
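
Before running the segmentation, it is worth sanity-checking that every segment folder contains an annotation for its first frame; here is a small hypothetical helper (not part of the repo, just an illustration):

# Hypothetical check: every segment folder should contain a 00000.json prompt file.
import sys
from pathlib import Path

parent = Path(sys.argv[1])  # e.g. .../VID_20251210_094125_frames_30fps
for seg in sorted(p for p in parent.iterdir() if p.is_dir() and not p.name.endswith("_masks")):
    n_frames = len(list(seg.glob("*.jpg")))
    has_prompt = (seg / "00000.json").exists()
    print(f"{seg.name}: {n_frames} frames, first-frame prompt {'OK' if has_prompt else 'MISSING'}")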

Run the segmentation code below with python sam2_segment_video.py ...; it produces a masked video for each segment in the same folder, like the result shown at the top of this post. Detailed usage is documented in the comment block at the top of the script. Saving both the masked videos and the mask frames gives the following file structure

VID_20251210_094125_frames_30fps
├── 1_frame450-600_masked.mp4
├── 1_frame450-600_masks
├── 2_frame600-750_masked.mp4
├── 2_frame600-750_masks
├── 3_frame750-900_masked.mp4
└── 3_frame750-900_masks

The source code of sam2_segment_video.py is as follows
# -*- coding: utf-8 -*-
'''
@File    : sam2_segment_video.py
@Time    : 2026/01/08 19:49:22
@Author  : wty-yy
@Version : 1.0
@Blog    : https://wty-yy.github.io/posts/4856/
@Desc    : Run SAM2 segmentation on video frames and save the results.

After cloning the SAM2 repo and installing it with pip install -e ., you also need: pip install imageio[ffmpeg]

First extract frames from the video with extract_frames_from_video.py (see the blog post), then place this script under
the tools/ directory of the SAM2 repo, download the SAM2 checkpoint sam2.1_hiera_base_plus.pt or sam2.1_hiera_tiny.pt
into the checkpoints folder, and finally run this script to do the segmentation.

python tools/sam2_segment_video.py \
    --video-parent-dir /home/yy/Downloads/VID_20251210_094125_frames_30fps \
    --show-prompts \
    --save-mask-video \
    --save-mask-frames \
    --mask-color avg

--video-parent-dir: parent directory containing the frame folders, e.g. /path/to/video_frames_parent_dir contains several subfolders, each holding the frames of one video segment.
--show-prompts: whether to pop up a window visualizing the prompt points and bounding boxes.
--save-mask-video: whether to save the video with the segmentation masks applied, written to a `*_masked.mp4` file.
--save-mask-frames: whether to save per-frame mask images into a `*_masks` folder; mask pixels are (255, 255, 255) and everything else is (0, 0, 0).
--mask-color: color used for the masked region ('white' for plain white, 'avg' for the region's mean color).
'''
from pathlib import Path

import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import torch

import json
import imageio
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from typing import Literal

def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

class SAM2SegmentVideoProcessor:
    def __init__(self,
            show_prompts=False,
            save_mask_video=False,
            save_mask_frames=False,
            mask_color: Literal['white', 'avg']="white",
        ):
        self.show_prompts = show_prompts
        self.save_mask_video = save_mask_video
        self.save_mask_frames = save_mask_frames
        self.mask_color = mask_color
        # select the device for computation
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        print(f"using device: {device}")

        if device.type == "cuda":
            # use bfloat16 for the entire notebook
            torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

        from sam2.build_sam import build_sam2_video_predictor

        sam2_checkpoint = "./checkpoints/sam2.1_hiera_base_plus.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
        # sam2_checkpoint = "./checkpoints/sam2.1_hiera_tiny.pt"
        # model_cfg = "./configs/sam2.1/sam2.1_hiera_t.yaml"

        self.predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)
        self.label2obj_id = {}
        self.num_prompts = 0

    def init_state(self, video_dir: str):
        self.video_dir = Path(video_dir)
        self.frames = sorted(Path(video_dir).rglob("*.jpg"))
        self.w, self.h = Image.open(self.frames[0]).size
        print(f"found {len(self.frames)} frames")

        # Init model
        self.inference_state = self.predictor.init_state(video_path=video_dir)
        self.predictor.reset_state(self.inference_state)

    def load_frame_prompt(self):
        """ Find all json files for frame prompts (from LabelMe) """
        for frame_json in sorted(self.video_dir.rglob("*.json")):
            frame_idx = int(frame_json.stem)
            with open(frame_json, "r") as f:
                labelme_data = json.load(f)
            shapes = labelme_data["shapes"]
            if len(shapes) == 0: continue
            plt.figure(figsize=(self.w / 100, self.h / 100), dpi=100)
            plt.title(f"Frame {frame_idx} with Box Prompts")
            plt.imshow(Image.open(self.frames[frame_idx]))
            add_infos = {}
            for shape in shapes:
                if shape['shape_type'] not in ['rectangle', 'point']:
                    continue
                if shape['shape_type'] == 'rectangle':
                    label = shape['label']
                elif shape['shape_type'] == 'point':
                    label_with_flag = shape['label']
                    label = label_with_flag.split('_')[0]
                    point_label = 0 if 'neg' in label_with_flag else 1
                if label not in self.label2obj_id:
                    self.label2obj_id[label] = len(self.label2obj_id)
                if self.label2obj_id[label] not in add_infos:
                    add_infos[self.label2obj_id[label]] = {
                        'inference_state': self.inference_state,
                        'frame_idx': frame_idx,
                        'obj_id': self.label2obj_id[label],
                        'box': [],
                        'points': [],
                        'labels': []
                    }
                infos = add_infos[self.label2obj_id[label]]

                if shape['shape_type'] == 'rectangle':
                    box = shape['points']  # [[x0, y0], [x1, y1]]
                    x0, y0 = box[0]
                    x1, y1 = box[1]
                    box = [min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)]
                    infos['box'].append(box)
                    show_box(box, plt.gca())
                if shape['shape_type'] == 'point':
                    point = shape['points']  # [[x, y]]
                    infos['points'].append(point[0])
                    infos['labels'].append(point_label)
                    show_points(np.array(point), np.array([point_label]), plt.gca())

            if len(add_infos) != 0:
                for infos in add_infos.values():
                    _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box(**infos)
                # Visualize the masks returned by the last add_new_points_or_box call
                for i, out_obj_id in enumerate(out_obj_ids):
                    show_mask((out_mask_logits[i] > 0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)
            if self.show_prompts:
                plt.tight_layout()
                plt.axis('off')
                plt.show()
            else:
                plt.close()
            self.num_prompts += 1
    
    def segment_frames(self):
        if self.save_mask_frames:
            output_dir  = self.video_dir.parent / f"{self.video_dir.name}_masks"
            output_dir.mkdir(exist_ok=True, parents=True)
        if self.save_mask_video:
            output_video = self.video_dir.parent / f"{self.video_dir.name}_masked.mp4"
            writer = imageio.get_writer(output_video, fps=30)
        video_segments = {}
        if self.num_prompts > 0:
            for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(self.inference_state):
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                    for i, out_obj_id in enumerate(out_obj_ids)
                }
        for frame_idx in tqdm(range(len(self.frames))):
            # Write video
            img = Image.open(self.frames[frame_idx])
            img = np.array(img)
            if self.save_mask_video and frame_idx not in video_segments:
                writer.append_data(img)
                continue
            for out_obj_id, out_mask in video_segments.get(frame_idx, {}).items():
                mask = out_mask.reshape(self.h, self.w)
                pixels = img[mask].astype(np.float32)
                if len(pixels) > 0:
                    if self.mask_color == 'avg':
                        img[mask] = np.mean(pixels, axis=0).astype(np.uint8)  # color the masked area with mean color
                    elif self.mask_color == 'white':
                        img[mask] = 255  # white out the masked area

            if self.save_mask_video:
                writer.append_data(img)

            # Save segmented frames
            if self.save_mask_frames:
                mask = np.zeros_like(img)
                for out_obj_id, out_mask in video_segments.get(frame_idx, {}).items():
                    mask[out_mask.reshape(self.h, self.w)] = 255
                output_path = output_dir / f"{frame_idx:05d}.jpg"
                Image.fromarray(mask).save(output_path)

            # Matplotlib visualization (optional)
            # plt.figure(figsize=(self.w / 100, self.h / 100), dpi=100)
            # plt.title(f"Frame {frame_idx}")
            # plt.imshow(Image.open(self.frames[frame_idx]))
            # for out_obj_id, out_mask in video_segments[frame_idx].items():
            #     show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
            # plt.axis('off')
            # output_path = output_dir / f"{frame_idx:05d}.jpg"
            # plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
            # plt.close()
        if self.save_mask_video:
            writer.close()
            print(f"Segmented video saved to: {output_video}")
        if self.save_mask_frames:
            print(f"Mask frames saved to: {output_dir}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--video-parent-dir", type=str, default="", help="Path to the parent of video frames directory.")
    parser.add_argument("--show-prompts", action="store_true", help="Whether to show prompts visualization.")
    parser.add_argument("--save-mask-video", action="store_true", help="Whether to save masked video.")
    parser.add_argument("--save-mask-frames", action="store_true", help="Whether to save masked frames.")
    parser.add_argument("--mask-color", type=str, default="white", help="Color to use for masking (e.g., 'white', 'avg').")
    args = parser.parse_args()

    show_prompts = args.show_prompts
    video_parent_dir = args.video_parent_dir
    save_mask_video = args.save_mask_video
    save_mask_frames = args.save_mask_frames
    mask_color = args.mask_color
    kwargs = {
        "show_prompts": show_prompts,
        "save_mask_video": save_mask_video,
        "save_mask_frames": save_mask_frames,
        "mask_color": mask_color,
    }

    video_dirs = [x for x in sorted(Path(video_parent_dir).glob("*")) if x.is_dir()]
    for video_dir in video_dirs:
        idx = int(video_dir.name.split("_")[0])
        if 'masks' == video_dir.name.split('_')[-1]:
            continue
        if idx >= 1:
        # if 2 <= idx <= 9 and idx not in []:
            sam2_segment_video_processor = SAM2SegmentVideoProcessor(**kwargs)
            print(f"Processing video directory: {video_dir}")
            sam2_segment_video_processor.init_state(str(video_dir))
            sam2_segment_video_processor.load_frame_prompt()
            sam2_segment_video_processor.segment_frames()
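
If you want to stitch the per-segment *_masked.mp4 files back into one clip (as in the demo video at the top), here is a minimal sketch using imageio, which is already a dependency of the script above; the output name merge_masked.mp4 and the 30 fps value are assumptions:

# Hypothetical helper: concatenate the per-segment *_masked.mp4 files in order.
from pathlib import Path
import imageio

parent = Path("/home/yy/Downloads/VID_20251210_094125_frames_30fps")  # same as --video-parent-dir
parts = sorted(parent.glob("*_masked.mp4"), key=lambda p: int(p.name.split("_")[0]))

writer = imageio.get_writer(parent / "merge_masked.mp4", fps=30)  # assumed frame rate
for part in parts:
    reader = imageio.get_reader(part)
    for frame in reader:
        writer.append_data(frame)
    reader.close()
writer.close()
print("Merged video written to", parent / "merge_masked.mp4")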

Creating Strobe-Photography Images

With the masks above, suppose we want to compose successive motion ghosts into a single image (strobe photography); the result looks like this

Strobe-photography effect achieved with SAM2

The principle is very simple: pick the frames you want as ghost images, then highlight the key frames among them with color. Just go through the first 25 or so lines of the code below and adjust the configuration to match your data.
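
Concretely, with M the mask of a selected frame, the script below composites ghost frames into the background image and tints key frames toward a highlight color (the 0.4/0.6 weights and the tint strength s = 0.6 are the values used in the code):

$$I_{\text{out}}[M] = 0.4\, I_{\text{frame}}[M] + 0.6\, I_{\text{bg}}[M] \quad \text{(ghost frames)}$$

$$I_{\text{out}}[M] = (1 - s)\, I_{\text{frame}}[M] + s\, c_{\text{key}}, \quad s = 0.6 \quad \text{(key frames)}$$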

# -*- coding: utf-8 -*-
'''
@File    : image_trajectory.py
@Time    : 2026/01/08 20:32:07
@Author  : wty-yy
@Version : 1.0
@Blog    : https://wty-yy.github.io/posts/4856/
@Desc    : Overlay trajectory masks on a video frame sequence to produce a strobe-photography effect image.
'''
from PIL import Image
import numpy as np

base_dir = "./tools/VID_20251210_094125_frames_30fps"   # parent folder of the extracted video segments

origin_image_dir = f"{base_dir}/1_frame450-659"         # folder with the original frames
mask_dir = f"{base_dir}/1_frame450-659_masks"           # folder with the segmentation masks
base_image = f"{base_dir}/base_image.png"               # base background image onto which the masked frames are composited
trajectory_idxs = [54, 72, 84, 94, 96, 109, 125, 133, 140, 152, 178, 192]  # trajectory frame indices (ghosts plus key frames)
key_idxs = [54, 96, 109, 133, 192]  # key-frame indices (mapped to the colors below)

# Color definitions
TAB_BLUE = np.array([31, 119, 180], dtype=np.uint8)
TAB_ORANGE = np.array([255, 127, 14], dtype=np.uint8)
TAB_GREEN = np.array([44, 160, 44], dtype=np.uint8)
TAB_RED = np.array([214, 39, 40], dtype=np.uint8)
TINT_STRENGTH = 0.6  # strength of the color tint applied to key frames

key_idx_color_map = {  # key-frame index -> highlight color
    192: TAB_GREEN,  # Resumed Gait
    133: TAB_RED,    # Impact Absorption
    109: TAB_BLUE,   # Rapid Adaptation
    96:  TAB_ORANGE, # Support Loss
    54:  TAB_GREEN,  # Steady-state Gait
}

def get_img_and_key(idx):
    mask_path = f"{mask_dir}/{idx:05d}.jpg"
    origin_image_path = f"{origin_image_dir}/{idx:05d}.jpg"
    mask = np.array(Image.open(mask_path).convert("L"))
    mask = mask > 200  # mask edges are a bit noisy; raising the threshold removes them
    origin_image = np.array(Image.open(origin_image_path).convert("RGBA"))
    return origin_image, mask

result_image = np.array(Image.open(base_image).convert("RGBA"))

print("Processing trajectory frames (ghosts)...")
for i, traj_idx in enumerate(reversed(trajectory_idxs)):
    origin_image, mask = get_img_and_key(traj_idx)
    if traj_idx not in key_idxs:
        result_image[mask] = origin_image[mask] * 0.4 + result_image[mask] * 0.6

print("Processing key frames with coloring...")
for key_idx in reversed(key_idxs):
    origin_image, mask = get_img_and_key(key_idx)
    target_color = key_idx_color_map.get(key_idx)

    robot_pixels_rgba = origin_image[mask]
    robot_rgb = robot_pixels_rgba[:, :3]
    robot_alpha = robot_pixels_rgba[:, 3:4]
    color_overlay = np.full_like(robot_rgb, target_color)
    tinted_rgb = (robot_rgb * (1 - TINT_STRENGTH) + color_overlay * TINT_STRENGTH).astype(np.uint8)
    tinted_rgba = np.concatenate([tinted_rgb, robot_alpha], axis=1)
    result_image[mask] = tinted_rgba

result_image = Image.fromarray(result_image)
result_image.save(f"{base_dir}/result_image.png")
print(f"Result image saved to {base_dir}/result_image.png")
