
최근 제가 리뷰했던 VideoMind (ICLR 2025) 프로젝트는 긴 비디오 시퀀스를 추론하기 위해 에이전트 시스템(Planner, Grounder)과 Chain-of-LoRA를 도입하여 성공적인 성과를 거두었습니다.
하지만 냉정하게 시스템 관점에서 돌아보면, 이는 Transformer가 가지는 KV Cache 폭발 문제를 회피하기 위한 '소프트웨어적 우회' 전략이었습니다. 한 번에 긴 영상을 보면 메모리가 터지니, 영상을 잘게 쪼개서 에이전트가 순차적으로 보도록 만든 것이죠. 이 방식은 메모리는 아낄 수 있지만, 필연적으로 파이썬 루프와 I/O 병목으로 인한 심각한 추론 지연을 야기합니다.
세미나에서 리뷰한 SAMBA (ICLR 2025)는 이 문제를 아키텍처 레벨에서 근본적으로 해결합니다. 선형 순환 모델인 Mamba와 슬라이딩 윈도우 어텐션(SWA)을 결합하여 백본 자체가 복잡도를 가지게 만듭니다.
저는 동일한 Long Video Feature Extraction Task를 가정하고, 기존 VideoMind의 에이전트 우회 방식과 SAMBA의 End-to-End 방식을 직접 PyTorch 코드로 구현하여 성능을 3-Way로 벤치마킹했습니다.
긴 비디오 시퀀스(Token Length 최대 24,000)를 입력으로 주었을 때, 세 가지 접근법의 메모리(VRAM)와 속도(Latency)를 프로파일링했습니다.
아래는 제가 직접 설계하여 Colab 환경에서 검증한 벤치마킹 스크립트입니다. Transformer의 병목과 SAMBA의 병렬 처리 구조를 모사했습니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import gc
import time
# 하드웨어 세팅
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# =====================================================================
# 1. Standard Transformer (VideoMind의 근본적 한계)
# =====================================================================
class FullAttentionLayer(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.qkv = nn.Linear(d_model, d_model * 3)
self.out = nn.Linear(d_model, d_model)
self.n_heads = n_heads
def forward(self, x):
B, N, D = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, D // self.n_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
# PyTorch 최적화를 우회하여 실제 O(N^2) 메모리 병목을 유도
scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
attn = torch.nn.functional.softmax(scores, dim=-1)
attn_out = torch.matmul(attn, v)
return self.out(attn_out.permute(0, 2, 1, 3).reshape(B, N, D))
# =====================================================================
# 2. VideoMind Agent Bypass (기존 연구의 소프트웨어적 우회 전략 모사)
# - 긴 영상을 한 번에 못 보니, Chunk(예: 1000 프레임) 단위로 잘라서
# 순차적으로 Attention을 돌리고 결과를 취합하는 방식 (Chain-of-LoRA 스타일)
# =====================================================================
class VideoMindAgentBypass(nn.Module):
def __init__(self, d_model, n_heads, chunk_size=1000):
super().__init__()
self.transformer_block = FullAttentionLayer(d_model, n_heads)
self.chunk_size = chunk_size
def forward(self, x):
B, N, D = x.shape
out_chunks = []
# 에이전트가 영상을 쪼개서 순차적으로 추론하는 과정 모사 (O(N)번의 순차 루프)
for i in range(0, N, self.chunk_size):
chunk = x[:, i:i+self.chunk_size, :]
# 각 청크별로 Transformer 연산
processed_chunk = self.transformer_block(chunk)
out_chunks.append(processed_chunk)
return torch.cat(out_chunks, dim=1)
# =====================================================================
# 3. Samba Layer Proxy (본 논문의 해결책: End-to-End O(N) 하이브리드)
# =====================================================================
class SambaHybridLayer(nn.Module):
def __init__(self, d_model, n_heads, window_size=1000):
super().__init__()
# 1. Mamba Proxy: O(N) 순환 상태로 전체 문맥을 압축
self.mamba_proxy = nn.Sequential(
nn.Conv1d(d_model, d_model, kernel_size=4, padding=3),
nn.SiLU(),
nn.Linear(d_model, d_model)
)
# 2. SWA Proxy: 병렬 처리 가능한 국소 어텐션
self.window_size = window_size
self.qkv_swa = nn.Linear(d_model, d_model * 3)
self.out_swa = nn.Linear(d_model, d_model)
self.n_heads = n_heads
def forward(self, x):
B, N, D = x.shape
# Temporal Dynamics (Mamba - End to End)
x_mamba = self.mamba_proxy[0](x.transpose(1, 2))[..., :N].transpose(1, 2)
x_mamba = self.mamba_proxy[2](self.mamba_proxy[1](x_mamba))
qkv = self.qkv_swa(x_mamba).reshape(B, N, 3, self.n_heads, D // self.n_heads)
# 병렬 Block Local Attention (우회 없이 한 번에 연산)
pad_len = (self.window_size - N % self.window_size) % self.window_size
if pad_len > 0:
qkv = F.pad(qkv, (0, 0, 0, 0, 0, 0, 0, pad_len))
N_padded = N + pad_len
num_blocks = N_padded // self.window_size
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q_block = q.view(B, self.n_heads, num_blocks, self.window_size, -1).transpose(1, 2).reshape(B * num_blocks, self.n_heads, self.window_size, -1)
k_block = k.view(B, self.n_heads, num_blocks, self.window_size, -1).transpose(1, 2).reshape(B * num_blocks, self.n_heads, self.window_size, -1)
v_block = v.view(B, self.n_heads, num_blocks, self.window_size, -1).transpose(1, 2).reshape(B * num_blocks, self.n_heads, self.window_size, -1)
attn_chunk = F.scaled_dot_product_attention(q_block, k_block, v_block)
out = attn_chunk.view(B, num_blocks, self.n_heads, self.window_size, -1).transpose(1, 2).reshape(B, N_padded, self.n_heads, -1)
out = out[:, :N, :, :]
return self.out_swa(out.reshape(B, N, D))
# =====================================================================
# 4. Profiling System
# =====================================================================
def profile_task(model_class, seq_lengths, model_name, d_model=256, n_heads=8):
peak_mems = []
lats = []
for seq_len in seq_lengths:
try:
model = model_class(d_model, n_heads).to(device)
# Long Video Feature Extraction Task 모사 (1배치, 긴 프레임)
dummy_video_features = torch.randn(1, seq_len, d_model, device=device)
# Warmup
with torch.no_grad():
for _ in range(2):
_ = model(dummy_video_features)
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Run Measurement
with torch.no_grad():
start_event.record()
_ = model(dummy_video_features)
end_event.record()
torch.cuda.synchronize()
peak_mem = torch.cuda.max_memory_allocated() / (1024 ** 2)
latency_ms = start_event.elapsed_time(end_event)
peak_mems.append(peak_mem)
lats.append(latency_ms)
print(f"[{model_name}] Seq {seq_len:<6}: {peak_mem:.1f} MB / {latency_ms:.1f} ms")
except RuntimeError as e:
if "out of memory" in str(e).lower():
print(f"[{model_name}] Seq {seq_len:<6}: OOM 발생 (메모리 폭발)")
peak_mems.append(None)
lats.append(None)
else:
raise e
finally:
if 'model' in locals(): del model
if 'dummy_video_features' in locals(): del dummy_video_features
torch.cuda.empty_cache()
gc.collect()
return peak_mems, lats
# =====================================================================
# 5. Run & Visualization
# =====================================================================
seq_lengths = [1000, 2000, 4000, 8000, 16000, 24000]
print("1. Profiling Native Transformer (OOM 발생 예정)...")
tf_mem, tf_lat = profile_task(FullAttentionLayer, seq_lengths, "Transformer")
print("\n2. Profiling VideoMind Bypass (메모리는 아끼지만 속도 지연 발생)...")
vm_mem, vm_lat = profile_task(VideoMindAgentBypass, seq_lengths, "VideoMind Bypass")
print("\n3. Profiling SAMBA End-to-End (메모리 & 속도 모두 최적화)...")
samba_mem, samba_lat = profile_task(SambaHybridLayer, seq_lengths, "SAMBA")

직접 PyTorch로 구현한 프로파일링 결과는 제 가설을 증명했습니다. 두 지표의 결과를 비교해 보겠습니다.
Metric 1: Peak VRAM
여기까지만 보면 "그냥 VideoMind처럼 영상을 잘라서 주면 되는 거 아냐?"라고 할 수 있습니다. 하지만 진정한 차이는 다음 지표에서 나타납니다.
Metric 2: Inference Latency (소프트웨어 우회 vs 하드웨어 네이티브)
💡 최종 어필 포인트
이 토이 프로젝트를 통해, 소프트웨어(에이전트)로 우회하던 VideoMind의 속도 지연 한계를 명확히 인지하고, SAMBA와 같은 하이브리드 아키텍처를 Video 모델의 네이티브 백본으로 Task Transfer하는 것이 가장 근본적인 해결책임을 데이터로 입증했습니다.
단순히 최신 논문의 지식을 습득하는 것에 그치지 않고, 이를 저의 기존 연구가 가졌던 시스템적 한계를 돌파하는 실제적인 아키텍처 솔루션으로 즉각 치환하여 코드로 증명해 보았다는 점이 이번 세미나의 가장 큰 수확입니다.