[GitHub 구현] SAM 2: Segment Anything in Images and Videos

방선생·2025년 9월 16일
0

PROJECT

목록 보기
5/5

SAM 2: Segment Anything in Images and Videos - github



SAM2 in Images (Meta GitHub)


  • 콘다 환경 생성

conda create -n sam2 python=3.11
conda activate sam2


  • SAM2 git clone 및 필요한 라이브러리 다운로드

    • The code requires python>=3.10, as well as torch>=2.5.1 and torchvision>=0.20.1
git clone https://github.com/facebookresearch/sam2.git && cd sam2

pip install -e .


  • Download Checkpoints

cd checkpoints && \
./download_ckpts.sh && \
cd ..


  • segment_auto.py 생성 (~/sam2)

    • GitHub 예제를 그대로 실행하면 GPU 메모리 부족으로 런타임 에러가 발생해서, 메모리 사용량을 줄일 수 있는 옵션(--max-side, --pps, --ppb, --crop-layers)을 넣어 직접 코드 작성
# file: segment_auto.py
import argparse
import os
import cv2
import numpy as np
import torch

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

def load_image_rgb(path: str):
    """Read an image file with OpenCV and return it in RGB channel order.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None`` for ``path``
            (missing or unreadable file).
    """
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise FileNotFoundError(f"Image not found: {path}")

def save_overlay(rgb, masks, out_path, alpha=0.5):
    """Blend every mask onto ``rgb`` with a random color and save it to ``out_path``.

    ``masks`` is the list of dicts produced by the automatic mask generator;
    only each dict's ``"segmentation"`` entry is read. Masks are painted in
    list order, so later masks are blended over earlier ones. The result is
    converted back to BGR before being written with ``cv2.imwrite``.
    """
    palette = np.random.randint(0, 255, size=(len(masks), 3), dtype=np.uint8)
    canvas = rgb.copy().astype(np.float32)
    for color, entry in zip(palette, masks):
        region = entry["segmentation"].astype(bool)
        canvas[region] = canvas[region] * (1 - alpha) + color * alpha
    blended = np.clip(canvas, 0, 255).astype(np.uint8)
    cv2.imwrite(out_path, cv2.cvtColor(blended, cv2.COLOR_RGB2BGR))

def save_binary_masks(masks, out_dir, base):
    """Write each mask's ``"segmentation"`` as a 0/255 PNG into ``out_dir``.

    Files are named ``{base}_mask_{index:03d}.png``. ``out_dir`` is created
    first if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)
    for idx, entry in enumerate(masks):
        binary = entry["segmentation"].astype(np.uint8) * 255
        fname = os.path.join(out_dir, f"{base}_mask_{idx:03d}.png")
        cv2.imwrite(fname, binary)

# === [Minimal addition 1] resize utility for low-memory runs ===
def resize_long_side_keep_aspect(rgb, max_side: int):
    """Downscale ``rgb`` so that its longer side is at most ``max_side`` pixels.

    The aspect ratio is preserved and ``cv2.INTER_AREA`` is used for the
    downscale. The input object itself is returned unchanged when
    ``max_side`` is ``None`` or ``<= 0``, or when the image already fits.
    Images are never upscaled.
    """
    if max_side is None or max_side <= 0:
        return rgb
    h, w = rgb.shape[:2]
    long_side = max(h, w)
    if long_side <= max_side:
        return rgb
    scale = max_side / float(long_side)
    # Clamp to >= 1: for extremely elongated images the short side can round
    # to 0, and cv2.resize rejects an empty destination size.
    nh = max(1, int(round(h * scale)))
    nw = max(1, int(round(w * scale)))
    return cv2.resize(rgb, (nw, nh), interpolation=cv2.INTER_AREA)

def main():
    """CLI entry point: run SAM2 automatic mask generation on a single image.

    Loads ``--image``, optionally downscales it (``--max-side``), builds the
    SAM2 model from ``--config``/``--checkpoint`` and generates masks with
    ``SAM2AutomaticMaskGenerator``. One binary PNG per mask is written to
    ``--out``, plus an overlay PNG when ``--overlay`` is given. Outputs are at
    the working (possibly downscaled) resolution.
    """
    from contextlib import nullcontext

    ap = argparse.ArgumentParser()
    ap.add_argument("--image", required=True, help="입력 이미지 경로")
    ap.add_argument("--checkpoint", default="./checkpoints/sam2.1_hiera_small.pt",
                    help="SAM2 체크포인트(.pt) 경로")
    ap.add_argument("--config", default="configs/sam2.1/sam2.1_hiera_s.yaml",
                    help="SAM2 설정 YAML 경로 (패키지 내부 경로 기준: configs/...)")
    ap.add_argument("--out", default="out", help="출력 폴더")
    ap.add_argument("--overlay", action="store_true", help="반투명 오버레이 PNG도 저장")

    ap.add_argument("--max-side", type=int, default=0,
                    help="긴 변을 이 크기로 리사이즈(0이면 원본 유지). 예: 1024")
    ap.add_argument("--pps", type=int, default=8,
                    help="points_per_side (샘플링 밀도, 낮출수록 메모리↓; 예: 4)")
    ap.add_argument("--ppb", type=int, default=32,
                    help="points_per_batch (배치 크기, 낮출수록 피크 메모리↓; 예: 16)")
    ap.add_argument("--crop-layers", type=int, default=0,
                    help="crop_n_layers (멀티크롭 단계 수, 0이면 OFF)")

    args = ap.parse_args()

    os.makedirs(args.out, exist_ok=True)
    rgb = load_image_rgb(args.image)

    # Optional downscale to reduce memory use (see --max-side).
    rgb = resize_long_side_keep_aspect(rgb, args.max_side)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam2_model = build_sam2(
        args.config,
        args.checkpoint,
        device=device,
        apply_postprocessing=False
    )

    mask_generator = SAM2AutomaticMaskGenerator(
        sam2_model,
        points_per_side=args.pps,
        points_per_batch=args.ppb,
        crop_n_layers=args.crop_layers,
        # Extra knobs if you need to cut the work further:
        # crop_overlap_ratio=0.25,
        # pred_iou_thresh=0.90,
        # stability_score_thresh=0.95,
    )

    # bfloat16 autocast on CUDA; default precision on CPU. nullcontext replaces
    # torch.cpu.amp.autocast(enabled=False), which is deprecated in recent
    # PyTorch releases and was a no-op here anyway.
    amp_ctx = (
        torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        if device == "cuda" else nullcontext()
    )
    with torch.inference_mode(), amp_ctx:
        masks = mask_generator.generate(rgb)

    print(f"[INFO] generated {len(masks)} masks")
    stem = os.path.splitext(os.path.basename(args.image))[0]
    save_binary_masks(masks, args.out, stem)

    if args.overlay:
        ov_path = os.path.join(args.out, f"{stem}_overlay.png")
        save_overlay(rgb, masks, ov_path)
        print(f"[INFO] overlay saved at {ov_path}")

if __name__ == "__main__":
    main()


  • 실행

python segment_auto.py \
--image /home/bang/Downloads/IMG_1180.jpeg \
--checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \
--config configs/sam2.1/sam2.1_hiera_b+.yaml \
--out runs/sam2_auto_bplus \
--overlay \
--max-side 2048 --pps 32 --ppb 64 --crop-layers 0


+

SAM2 in Videos (ultralytics SAM2)

ultralytics - sam-2

  • 콘다 환경 생성 및 ultralytics 다운

conda create -n ultralytics python=3.11
conda activate ultralytics
pip install ultralytics


  • sam2_video.py 생성

    • ultralytics에서 SAM2 가중치 파일(sam2_<버전>.pt, 예: sam2_s.pt) 다운로드
import argparse
from pathlib import Path
import time
import sys
import cv2
import numpy as np
from ultralytics import SAM

# tqdm is optional: main() falls back to plain stdout progress lines without it.
# Only a missing package is expected here; other errors should surface.
try:
    from tqdm.auto import tqdm
    HAVE_TQDM = True
except ImportError:
    HAVE_TQDM = False


def human_time(seconds: float) -> str:
    """Format a duration in seconds as ``M:SS`` or ``H:MM:SS``.

    Returns ``"N/A"`` for ``None``, NaN and +/-infinity. Fractional seconds
    are truncated. Negative durations are clamped to ``0:00``: an ETA computed
    from an underestimated container frame count can go negative, and
    ``divmod`` on negatives would otherwise yield strings like ``-1:59:57``.
    """
    # NaN is the only value not equal to itself.
    if seconds is None or seconds != seconds or abs(seconds) == float("inf"):
        return "N/A"
    total = max(0, int(seconds))
    minutes, secs = divmod(total, 60)
    hours, minutes = divmod(minutes, 60)
    if hours:
        return f"{hours:d}:{minutes:02d}:{secs:02d}"
    return f"{minutes:d}:{secs:02d}"


def iou_matrix(masks_a, masks_b):
    """Pairwise intersection-over-union between two lists of 2-D masks.

    All masks must share the same (H, W) shape; any non-zero pixel counts as
    foreground. Returns a float32 array of shape
    ``(len(masks_a), len(masks_b))``. Pairs whose union is empty get 0.0.

    Intersections are computed one row at a time rather than by broadcasting
    to an ``(Na, Nb, H, W)`` temporary, which for full video frames can need
    gigabytes. Unions come from the per-mask areas:
    ``|a ∪ b| = |a| + |b| - |a ∩ b|``.
    """
    if len(masks_a) == 0 or len(masks_b) == 0:
        return np.zeros((len(masks_a), len(masks_b)), dtype=np.float32)
    A = np.stack(masks_a, axis=0).astype(bool)
    B = np.stack(masks_b, axis=0).astype(bool)
    area_a = A.sum(axis=(1, 2))
    area_b = B.sum(axis=(1, 2))
    # Integer counts (exact); peak temporary is O(Nb * H * W) per row.
    inter = np.stack([np.logical_and(B, a).sum(axis=(1, 2)) for a in A])
    union = area_a[:, None] + area_b[None, :] - inter
    inter_f = inter.astype(np.float32)
    union_f = union.astype(np.float32)
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = np.where(union_f > 0, inter_f / union_f, np.float32(0.0))
    return iou.astype(np.float32, copy=False)


def color_from_id(i):
    """Return a deterministic color tuple (3 ints) for track id ``i``.

    Each channel is drawn from ``[50, 205)``, which avoids the darkest and
    brightest values. A private ``np.random.RandomState`` seeded from the id
    yields the same colors as the former ``np.random.seed`` + ``randint``
    combination, but without resetting the process-wide NumPy RNG on every
    call. The seed is reduced modulo 2**32 because ``RandomState`` only accepts
    seeds in ``[0, 2**32)``; large track ids would otherwise raise.
    """
    rng = np.random.RandomState((i * 9973 + 12345) % (2 ** 32))
    c = rng.randint(50, 205, size=3, dtype=np.uint8)
    return (int(c[0]), int(c[1]), int(c[2]))


class MaskTracker:
    """Keep mask identities stable across video frames with greedy IoU matching.

    Each :meth:`update` call takes one frame's binary masks. It matches them
    one-to-one against the existing tracks by IoU and blends each matched
    mask into its track's stored mask with an exponential moving average
    (EMA). Every unmatched mask starts a new track. A track that goes
    unmatched for more than ``memory`` consecutive frames is dropped.

    Args:
        iou_thr: minimum IoU for a (track, mask) pair to count as a match.
        ema_alpha: weight of the current frame's mask in the EMA update.
        min_area: connected components smaller than this (pixels) are removed
            in post-processing. Tracks whose thresholded mask is smaller are
            not returned.
        smooth_kernel: size of the elliptical morphological-close kernel
            applied to incoming masks; 0 disables closing.
        memory: number of consecutive unmatched frames a track survives.
    """

    def __init__(self, iou_thr=0.5, ema_alpha=0.5, min_area=400, smooth_kernel=3, memory=3):
        self.iou_thr = iou_thr
        self.ema_alpha = ema_alpha
        self.min_area = min_area
        self.smooth_kernel = smooth_kernel
        self.memory = memory
        # track id -> {"mask": float32 EMA mask, "age": consecutive unmatched frames}
        self.tracks = {}
        # Ids start at 1 and are never reused.
        self.next_id = 1

    def _postprocess(self, m):
        """Apply morphological closing, then drop components below ``min_area``.

        Returns a uint8 0/1 mask with the same shape as ``m``.
        """
        if self.smooth_kernel > 0:
            k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (self.smooth_kernel, self.smooth_kernel))
            m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, k)
        num, labels, stats, _ = cv2.connectedComponentsWithStats(m.astype(np.uint8), 4)
        # Lookup table label -> keep (0/1). Label 0 is the background and stays 0.
        # One indexing pass replaces a full-image comparison per component.
        keep = np.zeros(num, dtype=np.uint8)
        keep[1:] = stats[1:, cv2.CC_STAT_AREA] >= self.min_area
        return keep[labels]

    def _match(self, prev_bin, curr):
        """Greedy one-to-one matching of previous tracks to current masks.

        Candidate pairs with IoU >= ``iou_thr`` are taken in descending IoU
        order, skipping any pair whose track or mask is already used.
        Returns a list of ``(prev_index, curr_index)`` tuples.
        """
        iou = iou_matrix(prev_bin, curr)
        if iou.size == 0:
            return []
        rows, cols = np.nonzero(iou >= self.iou_thr)
        candidates = sorted(
            zip(iou[rows, cols].tolist(), rows.tolist(), cols.tolist()),
            reverse=True,
        )
        used_prev, used_curr, pairs = set(), set(), []
        for _, i, j in candidates:
            if i in used_prev or j in used_curr:
                continue
            used_prev.add(i)
            used_curr.add(j)
            pairs.append((i, j))
        return pairs

    def update(self, curr_masks_bin):
        """Advance the tracker by one frame.

        Args:
            curr_masks_bin: list of 2-D binary masks detected in this frame.

        Returns:
            List of ``(track_id, mask)`` tuples sorted by track id. ``mask``
            is the track's EMA mask thresholded at 0.5 (uint8 0/1). Tracks
            whose thresholded mask has fewer than ``min_area`` pixels are
            omitted. Tracks unmatched in this frame are still returned, with
            their last mask, until they expire.
        """
        curr = [self._postprocess(m) for m in curr_masks_bin]
        prev_ids = list(self.tracks)
        prev_bin = [(self.tracks[tid]["mask"] >= 0.5).astype(np.uint8) for tid in prev_ids]

        pairs = self._match(prev_bin, curr)
        matched_curr = {j for _, j in pairs}
        # Tracks that must not be aged this frame: matched ones and newborn ones.
        # (Previously newborn tracks were aged immediately, losing one frame of memory.)
        refreshed = set()

        for i, j in pairs:
            tid = prev_ids[i]
            track = self.tracks[tid]
            track["mask"] = (
                self.ema_alpha * curr[j].astype(np.float32)
                + (1.0 - self.ema_alpha) * track["mask"]
            )
            track["age"] = 0
            refreshed.add(tid)

        for j, m in enumerate(curr):
            if j in matched_curr:
                continue
            tid = self.next_id
            self.next_id += 1
            self.tracks[tid] = {"mask": m.astype(np.float32), "age": 0}
            refreshed.add(tid)

        # Iterate over a snapshot of the ids because expired tracks are deleted.
        for tid in list(self.tracks):
            if tid in refreshed:
                continue
            track = self.tracks[tid]
            track["age"] += 1
            if track["age"] > self.memory:
                del self.tracks[tid]

        out = []
        for tid in sorted(self.tracks):
            m = (self.tracks[tid]["mask"] >= 0.5).astype(np.uint8)
            if np.count_nonzero(m) >= self.min_area:
                out.append((tid, m))
        return out


def parse_args(argv=None):
    """Parse the video script's command-line options.

    Args:
        argv: argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as before.

    The default ``--source``/``--model`` paths are machine-specific
    (``/home/bang/...``) and are meant to be edited or overridden.
    """
    p = argparse.ArgumentParser(description="SAM2 video (stable render: ID/color persist + EMA smoothing)")
    p.add_argument("--source", type=str, default="/home/bang/sam/fish.mp4", help="입력 비디오 경로")
    # Help text now matches the default main() actually uses (<stem>_sam2_stable.mp4).
    p.add_argument("--out", type=str, default=None, help="출력 MP4 경로(기본: 입력파일명_sam2_stable.mp4)")
    p.add_argument("--model", type=str, default="/home/bang/sam/sam2_s.pt", help="SAM2 가중치 경로/이름")
    p.add_argument("--imgsz", type=int, default=1024, help="SAM 입력 정사각 크기")
    p.add_argument("--device", type=str, default="0", help="CUDA 장치 인덱스 또는 'cpu'")
    p.add_argument("--verbose", action="store_true", help="자세한 로그")
    p.add_argument("--log_every", type=int, default=30, help="tqdm 미사용 시 로그 간격")
    p.add_argument("--iou_thr", type=float, default=0.5, help="트랙 매칭 IoU 임계값")
    p.add_argument("--ema_alpha", type=float, default=0.5, help="EMA 가중치(현재 프레임 비중)")
    p.add_argument("--min_area", type=int, default=400, help="최소 마스크 면적(픽셀)")
    p.add_argument("--smooth_kernel", type=int, default=3, help="모폴로지 닫기 커널(0이면 비활성)")
    p.add_argument("--memory", type=int, default=3, help="매칭 실패 허용 프레임(트랙 유지 길이)")
    p.add_argument("--alpha_overlay", type=float, default=0.45, help="마스크 오버레이 투명도(0~1)")
    return p.parse_args(argv)


def _masks_to_binary(res):
    """Return one ultralytics result's masks as a list of uint8 0/1 arrays.

    Returns an empty list when the result carries no masks. Values are
    thresholded at 0.5.
    """
    if res.masks is None or getattr(res.masks, "data", None) is None:
        return []
    m = res.masks.data
    if hasattr(m, "cpu"):  # e.g. a torch tensor -> move to host as numpy
        m = m.cpu().numpy()
    return [(m[i] > 0.5).astype(np.uint8) for i in range(m.shape[0])]


def _render_tracks(frame, tracked, alpha):
    """Alpha-blend each tracked mask onto a copy of ``frame`` and outline it.

    Each masked pixel becomes ``alpha * color + (1 - alpha) * pixel``, using
    the track's color from ``color_from_id``. The mask's external contours are
    then drawn with thickness 2. ``frame`` itself is not modified.
    """
    overlay = frame.copy()
    for tid, mb in tracked:
        color = color_from_id(tid)
        sel = mb.astype(bool)
        # Per-pixel blend (the old code filled the whole mask with one flat mean
        # color, and crashed on int(nan) for an empty mask).
        if sel.any():
            px = overlay[sel].astype(np.float32)
            blended = alpha * np.asarray(color, dtype=np.float32) + (1.0 - alpha) * px
            overlay[sel] = np.clip(blended, 0, 255).astype(np.uint8)
        contours, _ = cv2.findContours(mb, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(overlay, contours, -1, color, thickness=2)
    return overlay


def main():
    """Run SAM2 (via ultralytics) on a video and write a stabilized overlay MP4.

    Container metadata (FPS, size, frame count) comes from OpenCV. Frames and
    masks are streamed from ``SAM.predict(stream=True)``. ``MaskTracker`` keeps
    per-object ids/colors stable. Each rendered frame is written with
    ``cv2.VideoWriter`` (mp4v). Progress is shown with tqdm when available,
    otherwise as periodic stdout lines.
    """
    args = parse_args()

    print("[1/7] 입력 비디오 열기:", args.source)
    cap = cv2.VideoCapture(args.source)
    if not cap.isOpened():
        raise FileNotFoundError(f"비디오를 열 수 없습니다: {args.source}")

    # Only metadata is read here; model.predict() decodes the frames itself,
    # so the capture is released right away instead of being held open.
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames_prop = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    finally:
        cap.release()
    total_frames = int(total_frames_prop) if total_frames_prop and total_frames_prop > 0 else 0

    if args.verbose:
        print(f"    - 해상도: {w}x{h}, FPS: {fps:.2f}, 총 프레임: {total_frames or '미상'}")

    print("[2/7] 출력 경로 설정")
    in_stem = Path(args.source).stem
    out_path = Path(args.out) if args.out else Path(f"{in_stem}_sam2_stable.mp4")

    print("[3/7] SAM2 로드:", args.model)
    model = SAM(args.model)

    print(f"[4/7] 추론 시작 (imgsz={args.imgsz}, device={args.device})")
    start_time = time.time()
    results_gen = model.predict(
        source=args.source,
        imgsz=args.imgsz,
        device=args.device,
        stream=True,
        verbose=args.verbose
    )

    print("[5/7] VideoWriter 초기화:", out_path)
    writer = cv2.VideoWriter(
        str(out_path),
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (w, h)
    )
    if not writer.isOpened():
        raise RuntimeError("VideoWriter 초기화 실패: 코덱 확인 요망")

    print("[6/7] 안정화 렌더링 시작… (ID/색상 고정 + EMA)")
    tracker = MaskTracker(
        iou_thr=args.iou_thr,
        ema_alpha=args.ema_alpha,
        min_area=args.min_area,
        smooth_kernel=args.smooth_kernel,
        memory=args.memory
    )

    frame_count = 0
    pbar = tqdm(total=total_frames if total_frames > 0 else None,
                unit="f", dynamic_ncols=True) if HAVE_TQDM else None
    last_log = 0

    # try/finally so the output file is finalized even if inference fails midway.
    try:
        for res in results_gen:
            tracked = tracker.update(_masks_to_binary(res))
            writer.write(_render_tracks(res.orig_img, tracked, args.alpha_overlay))
            frame_count += 1

            elapsed = time.time() - start_time
            fps_inst = frame_count / elapsed if elapsed > 0 else 0.0
            # CAP_PROP_FRAME_COUNT may be an estimate; clamp so the ETA never goes negative.
            if total_frames > 0 and fps_inst > 0:
                remain = max(0.0, (total_frames - frame_count) / fps_inst)
            else:
                remain = None

            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix({
                    "fps": f"{fps_inst:.1f}",
                    "elapsed": human_time(elapsed),
                    "eta": human_time(remain) if remain is not None else "N/A"
                })
            elif (frame_count - last_log) >= max(1, args.log_every):
                last_log = frame_count
                if remain is not None:
                    pct = 100.0 * frame_count / total_frames
                    sys.stdout.write(
                        f"\r프레임 {frame_count}/{total_frames} ({pct:5.1f}%) | {fps_inst:5.1f} fps | 경과 {human_time(elapsed)} | 남은예상 {human_time(remain)}   "
                    )
                else:
                    sys.stdout.write(
                        f"\r프레임 {frame_count} | {fps_inst:5.1f} fps | 경과 {human_time(elapsed)}   "
                    )
                sys.stdout.flush()
    finally:
        if pbar is not None:
            pbar.close()
        else:
            print()
        writer.release()

    total_time = time.time() - start_time
    avg_fps = frame_count / total_time if total_time > 0 else 0.0
    print(f"\n[완료] {frame_count}프레임 처리 → 저장: {out_path.resolve()}")
    print(f"총 소요시간: {human_time(total_time)} | 평균 처리속도: {avg_fps:.2f} fps")


if __name__ == "__main__":
    main()
  • parse_args의 기본 경로(--source, --model)를 본인 환경에 맞게 변경 (또는 실행 시 인자로 지정)


  • 실행

python sam2_video.py \
--source /home/bang/sam/fish.mp4 \
--model /home/bang/sam/sam2_s.pt \
--out fish_sam2_stable.mp4 \
--imgsz 1024 --device 0








참고자료

SAM2 - Meta

SAM2 - Paper

pixabay - 무료 이미지 및 동영상

profile
AI & Robotics

0개의 댓글