ToBo Code Guide

우병주·2026년 1월 1일

Overview

  • NIPS'25 ToBo의 코드를 쉽게 이해하기 위한 가이드라인
  • ToBo의 코드는 RSP를 기반으로 하여 RSP와 비교해서 봐도 좋음
  • docs/Train.md: 학습 코드 실행 방법 + 하이퍼 파라미터 있음.
  • main_pretrain_tobo.py: 메인 실행 코드. Train.md에서 호출방법 있음
  • engine_tobo.py: 1 epoch 학습 코드
  • models_tobo.py: 젤 중요한 코드. tobo 아키텍쳐 정의, loss 정의가 있음
  • timm/: tobo 구현에 사용되는 베이스모듈 제공 (PatchEmbed 등)
  • util/: 데이터로더, optimizer 등 제공
  • 나머지 부분은 그냥 구성 요소고, ToBo가 어떻게 동작하고 학습되는지 보기 위해서는 models_tobo.py만 봐도 충분함

docs/Train.md

accum_iter=8
aux_path=sm
batch_size=24
epochs=400
local_data_path=/mnt/tmp/kinetics400/videos
mask_ratio=0.9
max_distance=96
max_frames=1
min_frames=1
model_name=vit_small_patch16
num_frames=2
repeated_sampling=2
save_path=/mnt/tmp/checkpoints
tgt_path=csm
w_tobo=1.0
w_mim=1.0
warmup_epochs=40
num_gpus_per_node=$(nvidia-smi -L | wc -l)

python -m torch.distributed.launch --nproc_per_node=${num_gpus_per_node} main_pretrain_semae.py \
     --batch_size ${batch_size} \
     ... 생략 ...
     --norm_pix_loss \

우선 제공된 학습 파라미터부터 파악하자

  • accum_iter 란?
    • GPU 리소스(vram)이 부족해서 batch를 원하는 만큼 키울 수 없을때, gradient값을 n step 누적한 후에 한 번에 parameter를 업데이트 하는 기법 (effective batch)
    • contrastive learning 처럼 batch dependent한 연산이 있을시에는 손해 볼 수 있음. 이 연구는 해당 안됨
  • total batch: RSP를 따라서, total batch가 1536이 되도록 해야한다.
    • total batch = batch_per_gpu * num_gpu * accum_iter
    • batch_size=24로 되어있는 것은 per gpu 값이며, 총 gpu 8개 쓴다
    • accum_iter=8로 설정해서 1536 = 24 x 8 x 8을 맞추었다
  • repeated_sampling: "Augment Your Batch" 논문의 기법. 배치 내에서 동일한 이미지 n번씩 복사해 서로 다른 augmentation 적용 함.
    • 한 이미지에 대해 다양한 시각 변형을 동시에 배워서 정규화 효과를 얻음
    • 즉 1 epoch이 원래는 데이터셋을 한 번씩 보는건데, 여기선 n=2번씩 보는거
    • 따라서 논문엔 400 epoch을 쓰는 대신에 200 epoch & repeated_sampling=2 를 썻다고 나옴
  • w_tobo=1.0,w_mim=1.0: 두 loss의 weight 값. 논문에 안적힌 mim loss가 존재
  • tgt_path=csm, aux_path=sm: decoder block 호출 방법
    • decoder block은 cross-attn + self-attn + FFN의 3단계 구성
    • csm path는 cross -> self -> FFN : 전부 쓰는 경로
    • sm path는 self -> FFN : cross 없이 쓰는 경로
  • max_frames, min_frames, num_frames: 안 쓰임
    • max, min은 진짜로 안 쓰여서 의도 파악 불가
    • num_frames: src 프레임 수. 하지만 여기서 얼마로 설정되든 코드 내부에서 1로 정해놔서 세팅 의미 없음 (ToBo 연구 과정에서 과거 n frames로 부터 미래를 예측하는 코드 설계했음을 암시)
  • norm_pix_loss: pixel prediction target을 raw pixel이 아닌 normalized 된 픽셀로 할 것인지에 대한 세팅. True로 세팅됨 (후술할 예정)

models_tobo.py: methods overview

다음 순서로 큰 그림을 이해하고 하나씩 보세요.

  • forward: 이미지 주면 loss 구하는 가장 상위의 최종 함수
  • forward_loss: 이미지(gt)와 pred 주면 L2 loss 구해주는 함수
  • forward_encoder_asym: encoder forwarding. 이미지와 mask_ratio를 주면 masking해서 forward하고, 임베딩과 mask 위치를 return
  • forward_decoder_tobo: decoder forwarding. src와 tgt의 임베딩과 mask위치를 주면 두 가지 모드로 각각 forwarding해서 mim loss를 위한 pixel pred와 tobo loss를 위한 pixel pred 두 개를 return 함.

models_tobo.py: forward

(1) src 이미지에 대해 forward_encoder_asym 호출
(2) tgt 이미지에 대해 forward_encoder_asym 호출 (with masking)
(3) src와 tgt 임베딩들을 forward_decoder_tobo에 넣어서 pixel pred 값들을 얻기
(4) 두 pixel pred 값에 대해서 forward_loss 각각 호출해서 mim_loss와 tobo_loss 각각 구해서 return

models_tobo.py: forward_encoder_asym

오리지널 코드에 주석을 달음 + minimal 리팩토링. 오리지널 대신 이걸로 이해해도 동일하게 동작함.

def forward_encoder_asym(self, imgs, mask_ratio=0.0, src=False):
    x = self.patch_embed(imgs) #(B, H, W, 3)->(B, L, C)
    x = x + self.pos_embed[:, 1:, :]   #(B, L, C)

    if mask_ratio != 0.0: 
        x, mask, ids_restore, ids_keep = self.random_masking(x, mask_ratio)
        # x: (B, L, C) -> (B, l, C), where l = (1-r)*L
    else:
        mask, ids_restore, ids_keep = None, None, None
    
    cls = self.cls_token + self.pos_embed[:, :1, :] #(1, 1, C)
    cls = cls.expand(x.shape[0], -1, -1) #(B, 1, C)
    x = torch.cat((cls, x), dim=1) #(B, L'+1, C)

    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)
    return x, mask, ids_restore, ids_keep
  • 그냥 이미지를 ViT encoder에 넣는 평범한 과정 with masking
  • self.patch_embed는 16x16 conv임
  • self.pos_embed는 positional embedding으로 (1, L+1, C)로 고정됨 (0번째 자리가 cls의 pos_embed)
  • self.norm은 ViT 블럭 끝나고 나면 하는 layernorm임.
  • random_masking을 하고 나면
    • x는 L -> l개로 selection 됨: (B, l, C), l = (1-r)*L
    • mask는 masking 한 영역만 1로 표시된 binary mask: (B, L)의 shape임
    • ids_restore는 random shuffle의 역 매핑임 (복원을 위한): (B, L)
    • ids_keep은 살아있는 놈들의 index만 모은 놈임: (B, l)

models_tobo.py: forward_decoder_tobo

코드의 쉬운 이해를 위해서 주석 및 리팩토링을 함. 원래 코드랑 동일하게 동작 (호출 부분 뒤에 2개만 빼면)

def forward_decoder_tobo(self, src_cls_x, tgt_cls_x, tgt_ids_restore):
	B, L, _ = src_x.shape
    B, l, _ = tgt_x.shape
    src_cls_x = self.decoder_embed_mae(src_cls_x) #(B, L+1, C) -> (B, L+1, D)
    tgt_cls_x = self.decoder_embed_mae(tgt_cls_x) #(B, l+1, C) -> (B, l+1, D)
    src_cls, src_x = src_cls_x[:, :1], src_cls_x[:, 1:] #(B, 1, D), (B, L, D)
    tgt_cls, tgt_x = tgt_cls_x[:, :1], tgt_cls_x[:, 1:] #(B, 1, D), (B, l, D)
    _, _, D = tgt_x.shape
    
    # Restore mask embedding for tgt_x
    mask_tokens = self.mask_token.repeat(B, L-l, 1) #(1, 1, D) -> (B, L-l, D)
    tgt_x = torch.cat([tgt_x, mask_tokens], dim=1) #(B, L, D)
    tgt_x = torch.gather(tgt_x, dim=1, 
        index=tgt_ids_restore.unsqueeze(-1).repeat(1, 1, D)) # unshuffle
    
    # Add positional embedding (broadcasting)
    src_cls = src_cls + self.decoder_pos_embed[:, :1] #(B, 1, D)
    tgt_cls = tgt_cls + self.decoder_pos_embed[:, :1] #(B, 1, D)
    src_x = src_x + self.decoder_pos_embed[:, 1:] #(B, L, D)
    tgt_x = tgt_x + self.decoder_pos_embed[:, 1:] #(B, L, D)
    
    # 1. SiamMAE prediction
    x1 = torch.cat([tgt_cls, tgt_x], dim=1) #(B, L+1, D)
    kvx = torch.cat([src_cls, src_x], dim=1) #(B, L+1, D)
    for blk in self.decoder_blocks:
        x1 = blk(x1, kvx=kvx, num_frames=1, path="csm") #(B, L+1, D)
        
    # 2. ToBo prediction
    x2 = torch.cat([src_cls, tgt_x], dim=1) #(B, L+1, D)
    for blk in self.decoder_blocks:
        x2 = blk(x2, kvx=None, num_frames=1, path="sm") #(B, L+1, D)
    
	x = torch.cat([x1, x2], dim=0) # (2B, L+1, D) 
    x = self.decoder_norm(x) # (2B, L+1, D) 
    x = self.decoder_pred(x) # (2B, L+1, 16*16*3) 
	x = x[:, 1:, :] # (2B, L, 16*16*3)
    return x
  • 원래 코드의 h, x라는 표기가 헷갈려서 정확히 src_cls_x와 tgt_cls_x로 표기함.
  • 코드의 순서를 조절함 가독성을 위해서
  • broadcasting을 고려해서 불필요한 repeat, expand를 제거하고 num_frame 등 사용되지 않는 옵션 제거함.

models_tobo.py: util methods

  • patchify, unpatchify: 그냥 이미지의 pixel 값을 flatten해서 벡터화 하거나 그 역연산. (16x16x3 -> 768). 그대로 접고 펴기만 하는 연산. patch_embed랑 기능이 다름. 그냥 recon loss 주거나 normalize를 편하게 하기 위해서 (B, L, C) 형태로 만들어주기 위함
  • forward_loss: 단순히 (B, L, C)로 patchify한 픽셀 값과 decoder가 pred한 값 주면 L2 loss 리턴하는 함수인데, 중요한 두 가지 옵션이 있음
    (1) norm_pix_loss (True): pixel target을 패치별로 normalize함. MAE에서도 decoder이 타겟이 raw pixel보다 normalized raw pixel일 시 성능이 더 올라감. gpt에 의하면 밝은 픽셀일 수록 loss의 크기가 커지고 어두운 픽셀일 수록 loss가 작아지기에 밝은 픽셀 복원에 집중하는걸 방지하는 효과가 있고 학습 안정화에 도움 된다고 함.
    (2) mask: mask token에만 loss를 걸지, visible token에 대해서는 loss를 걸지 않음. 이것이 효과적임 또한 MAE 논문에 나와있음

https://openaccess.thecvf.com/content/CVPR2025/papers/Baade_Self-Supervised_Cross-View_Correspondence_with_Predictive_Cycle_Consistency_CVPR_2025_paper.pdf

https://arxiv.org/pdf/2311.00961

https://arxiv.org/pdf/2512.13684

https://openaccess.thecvf.com/content/CVPR2025/papers/Liu_When_the_Future_Becomes_the_Past_Taming_Temporal_Correspondence_for_CVPR_2025_paper.pdf

0개의 댓글