1. 함수의 역할과 의의
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
...
return embedding
-
입력
timesteps: 주로 시간 축을 나타내는 정수(또는 실수) 시퀀스. 예) ([t_1, t_2, ..., t_N]).
dim: 출력 임베딩 벡터의 차원.
max_period: 최저주파수를 결정하기 위한 하이퍼파라미터(주기). 보통 10,000 같은 값을 사용합니다.
repeat_only: True로 설정 시, 위치 임베딩 없이 단순하게 timestep 값을 특정 차원으로 반복한 결과를 반환합니다.
-
출력
- (\text{embedding}): (\text{shape} = [N, \text{dim}])인 텐서로서, 각
timesteps 원소에 대해 사인/코사인 함수를 적용한 결과를 합쳐서 만든 임베딩 벡터가 반환됩니다.
2. 원리와 이유
- 딥러닝 모델(특히 트랜스포머 계열)에서 시점(
timesteps)을 직접 숫자 그대로 모델에 입력하면, 모델이 “시점이 커질수록 무언가가 달라져야 한다”라는 규칙성을 쉽게 학습하기 어렵습니다.
- 따라서 특정한 함수를 이용해 시점 정보를 벡터 형태(주로 사인/코사인 파)를 이용한 주기적(Position/Time) 임베딩으로 변환해 모델에 입력
- 주기성(Sinusoidal Embedding) 도입
- (\sin), (\cos) 같은 주기 함수를 사용하여, 입력되는 시점 (t)이 커질 때에도 주파수(주기)가 다른 여러 스케일로 변환된 값들을 모델이 이용할 수 있게 해줍니다.
- 즉, 시점이 선형적으로 증가해도 임베딩 공간에서는 다양한 주파수로 시점 차이가 반영되므로, 멀리 떨어진 시점 정보까지 구분하거나 패턴을 쉽게 학습할 수 있습니다.
- 다양한 스케일의 주파수
max_period=10000 등을 사용하면 (sin(t/10k)), (cos(t/10k)) 식으로, 서로 다른 주기로 시점 정보를 반영합니다.
- “가장 큰 주기(=가장 낮은 주파수) ~ 가장 작은 주기(=가장 높은 주파수)” 범위를 커버하는 것이 핵심입니다.
3. 코드 동작 상세
코드에서 핵심적으로 보는 부분은 다음과 같습니다:
if not repeat_only:
half = dim // 2
freqs = torch.exp(-math.log(max_period) *
torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
...
-
half = dim // 2
- 임베딩 차원
dim의 절반을 구합니다. 보통 (\cos) 부분과 (\sin) 부분을 각각 dim/2 차원으로 구성하려고 이렇게 나눕니다.
-
freqs 계산
freqs = torch.exp(-math.log(max_period) *
torch.arange(start=0, end=half, dtype=torch.float32) / half)
- (torch.arange(0,half))는 ([0,1,2,...(half−1)])로 구성된 텐서를 만듭니다.
-math.log(max_period) / half 라는 스칼라에 해당 인덱스들을 곱하고, 그 결과를 (exp)로 감쌉니다.
- 이를 수식으로 쓰면,
[
freqs[k]=exp(−halfln(max_period)⋅k),k=0,1,...,(half−1).
]
- 이 값들은 대략 (1)에서 (max_period1) 사이를 지수적으로 커버하게 됩니다.
- (k=0)일 때 (exp(0)=1),
- (k=half−1)일 때 (≈max_period1).
-
args = timesteps[:, None].float() * freqs[None]
- (timesteps)를 2D 형태로 확장해서 ([N, 1]) 형태로 만들고,
freqs를 ([1,half])로 확장해 곱합니다.
- 수식으로는
[
argsi,k=ti×freqs[k],
]
(i=1,…,N,k=1,…,half).
- 이는 "시점
t × 주파수 f_k"를 구해, 각 시점마다 서로 다른 주파수 대역이 적용되도록 만듭니다.
-
(\cos), (\sin) 계산 및 연결
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- (cos(args))와 (sin(args))를 (dim/2) 차원씩 두 개 붙여서 최종 (dim) 차원 벡터로 만듭니다.
- 최종 임베딩은
[
embeddingi=[cos(ti⋅f0),…,cos(ti⋅fh−1),sin(ti⋅f0),…,sin(ti⋅fh−1)],
]
(h=half.)
-
if dim % 2: 블록
dim이 홀수이면, 사인/코사인 쌍이 아닌 1차원이 남습니다. 이때 torch.zeros_like(embedding[:, :1])로 0벡터를 붙여 차원을 맞춰줍니다.
-
repeat_only=True인 경우
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
- 여기서는 (\sin, \cos)를 사용하지 않고, 단순히
timesteps 값을 원하는 차원만큼 반복해 ([N, \text{dim}]) 형식으로 만드는 역할만 합니다.
4. 수학적 공식 요약
4.1 지수적으로 변하는 주파수 벡터 (freqs)
[
freqs[k]=exp(−halfln(max_period)×k),
k=0,1,…,half−1.
]
- (k=0)에서는 (freqs[0]=1)
- (k=half−1)에서는 (freqs[half−1]≈max_period1)
4.2 최종 임베딩 (\text{embedding}(t))
입력 시점 (t_i)에 대한 임베딩 벡터는
[
embedding(ti)=[cos(ti×freqs[0]),…,cos(ti×freqs[h−1]),sin(ti×freqs[0]),…,sin(ti×freqs[h−1])],
]
여기서 (h=2dim) (또는 dim//2)입니다.
- 이렇게 각 (ti) 마다 여러 주파수 대역의 (\sin)과 (\cos)을 계산해 1차원 벡터로 만듦으로써, 모델이 시간 정보(또는 위치 정보)에 대한 특징을 여러 스케일로 학습하도록 돕습니다.
5. 왜 사인/코사인 함수인가?
- 연속적 시점 매핑
- 사인/코사인은 주기 함수이므로, (시점 차이가 커도) 여러 주파수를 통해 큰 스케일의 변화와 작은 스케일의 변화를 모두 표현할 수 있습니다.
- 위치나 시간 순서 정보 유지
- 트랜스포머 등의 구조에서 순서를 명시적으로 표현하기 위해 고안된 아이디어입니다.
- 임의의 (\sin), (\cos)는 쉽게 미분 가능하고, 특히 (pos(t+Δt)−pos(t))에 대한 위상(phase) 차이가 비선형적이면서도 일정한 주기를 가지고 변하므로, 모델이 (t)와 (Δt)의 관계를 유연하게 포착할 수 있습니다.
- 학습이 필요 없는(Non-learned) 임베딩의 장점
- 파라미터 없이도 다양한 주파수 대역 정보를 이미 코드 차원에서 제공할 수 있으므로, 훈련 과정에서 “시간 순서 이해”에 대한 부담이 줄어듭니다.
- 물론 사인/코사인 대신, 학습 가능한 파라미터를 두어 임베딩을 학습형으로 만들기도 합니다
- (예: RoPE(Rotary Position Embedding), Learned Positional Embeddings 등).