
# 예시 코드로 살펴보는 그룹 쿼리 어텐션
def group_query_attention(queries, keys, values):
# queries: [batch_size, num_queries, d_model] # 예: (32, 4, 256)
# keys: [batch_size, seq_length, d_model] # 예: (32, 8, 256)
# values: [batch_size, seq_length, d_model] # 예: (32, 8, 256)
# 1. 어텐션 스코어 계산
attention_scores = torch.matmul(queries, keys.transpose(-2, -1))
attention_scores = attention_scores / math.sqrt(d_model)
# 2. 소프트맥스 적용
attention_weights = F.softmax(attention_scores, dim=-1)
# 3. 가중치 적용 및 출력 계산
output = torch.matmul(attention_weights, values)
return output
그룹 쿼리 어텐션은 특히 긴 시퀀스를 처리할 때 효율적이며, transformer 모델의 성능을 유지하면서도 계산 비용을 크게 줄일 수 있는 장점이 있습니다. multi-head attention의 변형으로서, 실제 적용 시에는 모델의 크기와 태스크의 특성에 따라 적절한 쿼리 수를 선택하는 것이 중요합니다.
예문 "나는 오늘 공원에서 산책을 했다"
입력 토큰들 (8개):
["나는", "오늘", "공원", "에서", "산책", "을", "했다", "<pad>"]
쿼리 토큰들 (4개):
Q1: 주체/시간 관련 쿼리 # "나는", "오늘" 정보 수집
Q2: 장소 관련 쿼리 # "공원", "에서" 정보 수집
Q3: 행동 관련 쿼리 # "산책", "을" 정보 수집
Q4: 문장 완성 쿼리 # "했다" 정보 수집
이렇게 처리된 정보는 다음 층으로 전달되어 추가 처리가 이루어집니다. 실제로는 더 복잡한 계산이 이루어지지만, 이런 방식으로 토큰들의 정보가 쿼리를 통해 효율적으로 처리됩니다.