
self-attention은 input 표현을 향상시키기 위해 시퀀스 내의 각 position이 다른 모든 position과 상호작용하여 관련성을 결정할 수 있도록 하는 기술이다.
[질문 토큰들] + [SEP] + [답변 토큰들]훈련되지 않는 가중치를 이용하여 간단한 self-attention mechanism을 구현해보자.
query = inputs[1] # 2nd input token is the query
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
# dot product (transpose not necessary here since they are 1-dim vectors)
attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)
torch.softmax()를 이용한다.attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
query = inputs[1] # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
context_vec_2 += attn_weights_2[i]*x_i
모든 input 토큰에 대한 를 구해보자.
attn_scores = inputs @ inputs.T
attn_weights = torch.softmax(attn_scores, dim=-1)
all_context_vecs = attn_weights @ inputs
기존 트랜스포머 아키텍처에서 사용되는 self-attention mechanism인 scaled dot-product attention을 구현해보자.
requires_grad를 False로 두고, 훈련 동안은 True로 둔다.W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
keys = inputs @ W_key # 3D input 토큰을 2D 임베딩 공간에 투영
values = inputs @ W_value
attn_scores_2 = query_2 @ keys.T # All attention scores for given query
d_k = keys.shape[1] # 일단 x(2)에 대해서만 계산
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
context_vec_2 = attn_weights_2 @ values
self-attention mechanism을 class로 구현하자.
nn.Linear에 bias를 사용하지 않으면 행렬 곱셈과 동일하며, weight 초기화 스키마로 인해 nn.Parameter보다 안정적인 모델 훈련이 가능하다.class SelfAttention_v2(nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
def forward(self, x):
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
context_vec = attn_weights @ values
return context_vec
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))
causal attention은 대각선 위의 attention weights가 마스킹되어, 주어진 입력에 대해 LLM이 context vector를 계산하는 동안 future tokens을 활용할 수 없도록 보장한다.
torch.tril 함수를 통해 위 masked matrix를 만들 수 있다. 하지만 softmax 이후 mask가 적용되면 확률분포에 문제가 생길 수 있으므로, 행의 합이 1이 되도록 정규화를 해야한다. mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
오버피팅을 줄이기 위해 훈련 동안 드롭아웃을 추가적으로 적용해보자.
dropout = torch.nn.Dropout(0.5) # dropout rate of 50%
print(dropout(attn_weights))
여러 개의 input과 causal, dropout masks가 적용된 self-attention을 구현해보자.
class CausalAttention(nn.Module):
def __init__(self, d_in, d_out, context_length,
dropout, qkv_bias=False):
super().__init__()
self.d_out = d_out
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.dropout = nn.Dropout(dropout) # New
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New
def forward(self, x):
b, num_tokens, d_in = x.shape # New batch dimension b
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
attn_scores = queries @ keys.transpose(1, 2) # Changed transpose
attn_scores.masked_fill_( # New, _ ops are in-place
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) # `num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, dim=-1
)
attn_weights = self.dropout(attn_weights) # New 드롭아웃
context_vec = attn_weights @ values
return context_vec
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
torch.triu()를 사용하여 미래 시점의 정보가 사용되지 않도록 하고, register_buffer()를 사용하여 학습되지 않는 고정된 텐서로 등록한다.masked_fill_()을 사용하여 미래 시점의 attn_scores을 -∞로 설정한다.
위 그림처럼 지금까지 구현한 single-head attention을 쌓아서 아래의 multi-head attention을 얻어보자.
class MultiHeadAttentionWrapper(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
self.heads = nn.ModuleList(
[CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
for _ in range(num_heads)]
)
def forward(self, x):
return torch.cat([head(x) for head in self.heads], dim=-1)
torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)
CausalAttention이 2개이므로 output 임베딩 차원은 4이다.
CausalAttention을 여러 개 감싸지 않고MultiHeadAttention이라는 독립형 클래스를 작성할 수 있고, 단일 attention heads를 합치지 않고 단일 , , 가중치 행렬을 생성하여 이를 각 attention head에 대한 개별 행렬로 나눌 수 있다.
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert (d_out % num_heads == 0), \
"d_out must be divisible by num_heads"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length),
diagonal=1)
)
def forward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)
# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)
# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection
return context_vec
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
.view로 각 matrix를 num_heads만큼 쪼개고, context vector를 구할 때 다시 .contiguous()를 사용해 메모리 상의 연속성을 보장한 후 결합한다.self.out_proj)를 추가했지만, 모델링 성능에 영향을 미치지 않으며 없앨 수 있다는 연구가 등장했다.
MultiHeadAttentionWrapper과 MultiHeadAttention의 차이이다.MultiHeadAttentionCombinedQKV 클래스는 단일 가중치 행렬(self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias))을 사용하며, 이를 통해 query, value, key를 계산한다. q, k, v = qkv.unbind(0)를 통해 각 값을 얻는다.torch.einsum을 사용하여 Einstein summation을 구현한다.scaled_dot_product_attention와 self-attention의 메모리 최적화된 버전인 FlashAttention을 사용한다.context_vec = nn.functional.scaled_dot_product_attention( queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)를 통해 context vector를 얻는다. if self.context_length >= num_tokens:
attn_mask = self.mask[:num_tokens, :num_tokens]
else:
attn_mask = self.mask[:self.context_length, :self.context_length]torch.nn.MultiheadAttention을 사용한다.self.multihead_attn = nn.MultiheadAttention(embed_dim=d_out, num_heads=num_heads, dropout=dropout, bias=qkv_bias, add_bias_kv=qkv_bias, batch_first=True,)MHAPyTorchClass 인스턴스를 생성할 때 매개변수 need_weights를 False로 두면, scaled_dot_product_attention를 사용한다.score_mod을 통한 attention score 수정이 가능하다. context_vec = torch.nn.flex_attention(queries, keys, values, block_mask=attn_mask)self.register_buffer를 통해 사용한다.mask tensor가 GPU를 사용하더라도 PyTorch parameter가 아니므로 CPU에서 존재하기 때문에 오류가 발생한다.self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))