![]() | ![]() |
|---|
docs/Train.md: 학습 코드 실행 방법 + 하이퍼 파라미터 있음. main_pretrain_tobo.py: 메인 실행 코드. Train.md에서 호출방법 있음engine_tobo.py: 1 epoch 학습 코드models_tobo.py: 젤 중요한 코드. tobo 아키텍쳐 정의, loss 정의가 있음timm/: tobo 구현에 사용되는 베이스모듈 제공 (PatchEmbed 등)util/: 데이터로더, optimizer 등 제공models_tobo.py만 봐도 충분함accum_iter=8
aux_path=sm
batch_size=24
epochs=400
local_data_path=/mnt/tmp/kinetics400/videos
mask_ratio=0.9
max_distance=96
max_frames=1
min_frames=1
model_name=vit_small_patch16
num_frames=2
repeated_sampling=2
save_path=/mnt/tmp/checkpoints
tgt_path=csm
w_tobo=1.0
w_mim=1.0
warmup_epochs=40
num_gpus_per_node=$(nvidia-smi -L | wc -l)
python -m torch.distributed.launch --nproc_per_node=${num_gpus_per_node} main_pretrain_semae.py \
--batch_size ${batch_size} \
... 생략 ...
--norm_pix_loss \
우선 제공된 학습 파라미터부터 파악하자
accum_iter 란?total batch: RSP를 따라서, total batch가 1536이 되도록 해야한다. total batch = batch_per_gpu * num_gpu * accum_iterbatch_size=24로 되어있는 것은 per gpu 값이며, 총 gpu 8개 쓴다accum_iter=8로 설정해서 1536 = 24 x 8 x 8을 맞추었다repeated_sampling: "Augment Your Batch" 논문의 기법. 배치 내에서 동일한 이미지 n번씩 복사해 서로 다른 augmentation 적용 함. w_tobo=1.0,w_mim=1.0: 두 loss의 weight 값. 논문에 안적힌 mim loss가 존재tgt_path=csm, aux_path=sm: decoder block 호출 방법cross-attn + self-attn + FFN의 3단계 구성csm path는 cross -> self -> FFN : 전부 쓰는 경로sm path는 self -> FFN : cross 없이 쓰는 경로max_frames, min_frames, num_frames: 안 쓰임norm_pix_loss: pixel prediction target을 raw pixel이 아닌 normalized 된 픽셀로 할 것인지에 대한 세팅. True로 세팅됨 (후술할 예정)다음 순서로 큰 그림을 이해하고 하나씩 보세요.
forward: 이미지 주면 loss 구하는 가장 상위의 최종 함수forward_loss: 이미지(gt)와 pred 주면 L2 loss 구해주는 함수forward_encoder_asym: encoder forwarding. 이미지와 mask_ratio를 주면 masking해서 forward하고, 임베딩과 mask 위치를 returnforward_decoder_tobo: decoder forwarding. src와 tgt의 임베딩과 mask위치를 주면 두 가지 모드로 각각 forwarding해서 mim loss를 위한 pixel pred와 tobo loss를 위한 pixel pred 두 개를 return 함. (1) src 이미지에 대해 forward_encoder_asym 호출
(2) tgt 이미지에 대해 forward_encoder_asym 호출 (with masking)
(3) src와 tgt 임베딩들을 forward_decoder_tobo에 넣어서 pixel pred 값들을 얻기
(4) 두 pixel pred 값에 대해서 forward_loss 각각 호출해서 mim_loss와 tobo_loss 각각 구해서 return
오리지널 코드에 주석을 달음 + minimal 리팩토링. 오리지널 대신 이걸로 이해해도 동일하게 동작함.
def forward_encoder_asym(self, imgs, mask_ratio=0.0, src=False):
x = self.patch_embed(imgs) #(B, H, W, 3)->(B, L, C)
x = x + self.pos_embed[:, 1:, :] #(B, L, C)
if mask_ratio != 0.0:
x, mask, ids_restore, ids_keep = self.random_masking(x, mask_ratio)
# x: (B, L, C) -> (B, l, C), where l = (1-r)*L
else:
mask, ids_restore, ids_keep = None, None, None
cls = self.cls_token + self.pos_embed[:, :1, :] #(1, 1, C)
cls = cls.expand(x.shape[0], -1, -1) #(B, 1, C)
x = torch.cat((cls, x), dim=1) #(B, L'+1, C)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore, ids_keep
self.patch_embed는 16x16 conv임self.pos_embed는 positional embedding으로 (1, L+1, C)로 고정됨 (0번째 자리가 cls의 pos_embed)self.norm은 ViT 블럭 끝나고 나면 하는 layernorm임.random_masking을 하고 나면 x는 L -> l개로 selection 됨: (B, l, C), l = (1-r)*Lmask는 masking 한 영역만 1로 표시된 binary mask: (B, L)의 shape임ids_restore는 random shuffle의 역 매핑임 (복원을 위한): (B, L)ids_keep은 살아있는 놈들의 index만 모은 놈임: (B, l)코드의 쉬운 이해를 위해서 주석 및 리팩토링을 함. 원래 코드랑 동일하게 동작 (호출 부분 뒤에 2개만 빼면)
def forward_decoder_tobo(self, src_cls_x, tgt_cls_x, tgt_ids_restore):
B, L, _ = src_x.shape
B, l, _ = tgt_x.shape
src_cls_x = self.decoder_embed_mae(src_cls_x) #(B, L+1, C) -> (B, L+1, D)
tgt_cls_x = self.decoder_embed_mae(tgt_cls_x) #(B, l+1, C) -> (B, l+1, D)
src_cls, src_x = src_cls_x[:, :1], src_cls_x[:, 1:] #(B, 1, D), (B, L, D)
tgt_cls, tgt_x = tgt_cls_x[:, :1], tgt_cls_x[:, 1:] #(B, 1, D), (B, l, D)
_, _, D = tgt_x.shape
# Restore mask embedding for tgt_x
mask_tokens = self.mask_token.repeat(B, L-l, 1) #(1, 1, D) -> (B, L-l, D)
tgt_x = torch.cat([tgt_x, mask_tokens], dim=1) #(B, L, D)
tgt_x = torch.gather(tgt_x, dim=1,
index=tgt_ids_restore.unsqueeze(-1).repeat(1, 1, D)) # unshuffle
# Add positional embedding (broadcasting)
src_cls = src_cls + self.decoder_pos_embed[:, :1] #(B, 1, D)
tgt_cls = tgt_cls + self.decoder_pos_embed[:, :1] #(B, 1, D)
src_x = src_x + self.decoder_pos_embed[:, 1:] #(B, L, D)
tgt_x = tgt_x + self.decoder_pos_embed[:, 1:] #(B, L, D)
# 1. SiamMAE prediction
x1 = torch.cat([tgt_cls, tgt_x], dim=1) #(B, L+1, D)
kvx = torch.cat([src_cls, src_x], dim=1) #(B, L+1, D)
for blk in self.decoder_blocks:
x1 = blk(x1, kvx=kvx, num_frames=1, path="csm") #(B, L+1, D)
# 2. ToBo prediction
x2 = torch.cat([src_cls, tgt_x], dim=1) #(B, L+1, D)
for blk in self.decoder_blocks:
x2 = blk(x2, kvx=None, num_frames=1, path="sm") #(B, L+1, D)
x = torch.cat([x1, x2], dim=0) # (2B, L+1, D)
x = self.decoder_norm(x) # (2B, L+1, D)
x = self.decoder_pred(x) # (2B, L+1, 16*16*3)
x = x[:, 1:, :] # (2B, L, 16*16*3)
return x
patchify, unpatchify: 그냥 이미지의 pixel 값을 flatten해서 벡터화 하거나 그 역연산. (16x16x3 -> 768). 그대로 접고 펴기만 하는 연산. patch_embed랑 기능이 다름. 그냥 recon loss 주거나 normalize를 편하게 하기 위해서 (B, L, C) 형태로 만들어주기 위함forward_loss: 단순히 (B, L, C)로 patchify한 픽셀 값과 decoder가 pred한 값 주면 L2 loss 리턴하는 함수인데, 중요한 두 가지 옵션이 있음https://arxiv.org/pdf/2311.00961