num_heads 개수만큼 Q,K,V를 나누는 함수이다. 이 과정은 두 가지 작업으로 분류된다.
def split_heads(self, inputs, batch_size):
inputs = tf.reshape(
inputs, shape=(batch_size, -1, self.num_heads, self.depth))
return tf.transpose(inputs, perm=[0, 2, 1, 3])
1. Reshape 작업
- inputs 의 형상이 (batch_size, seq_len, hid_dim) -> (batch_size, seq_len, n_heads, depth) 바뀜.
- 여기서 inputs 텐서는 선형 레이어를 통과한 후의 텐서이다.
- self.depth 는 헤드의 깊이. 즉, 만들어지는 Q,K,V 벡터의 차원으로 hid_dim / h_heads 로 계산된다.
2. Transpose 작업
- 이 함수는 텐서의 차원을 재배열 한다.
- perms = [0, 2, 1, 3] 을 통해 (batch_size, n_heads, seq_len, depth)로 바뀜.
예시
- batch_size = 32
- seq_len = 10
- hid_dim = 512
- n_heads = 8
- depth = hid_dim / n_heads = 64 라고 가정해보자.
- inputs 형상은 (32,10,512), reshape 한 형상은 (32,10,8,64), 전치하고 난 뒤 형상은 (32,8,10,64)
reshape 함수에서 -1 을 사용하는 이유는 무엇일까????
- reshape함수에서 -1을 사용하는 이유는 TensorFlow가 해당 차원의 크기를 자동으로 계산하도록 하기 위함이다. 예시로 자세히 알아보자.
import tensorflow as tf
batch_size = 32
seq_len = 10
hid_dim = 512
inputs = tf.random.uniform((batch_size, seq_len, hid_dim))
n_heads = 8
depth = hid_dim // n_heads
reshaped = tf.reshape(inputs, (batch_size, seq_len, n_heads, depth))
- (32,10,512) 형상을 갖는 inputs. 이 형상을 (batch_size, seq_len, n_heads, depth) 형상으로 변경하고 싶다고 가정해보자. '-1'을 사용하지 않았지만, 특정 차원의 크기를 지정하지 않으면 직접 계산해야한다. 위에선 'depth'를 직접 계산하여 넣었지만 '-1'을 사용하여 자동으로 계산할 수 있다. '-1'은 자동으로 해당 차원의 크기를 계산한다.
reshaped = tf.reshape(inputs, (batch_size, -1, n_heads, depth))
-실제 계산과정은 다음과 같다.
# 원래 텐서의 총 요소 수 = batch_size * seq_len * hid_dim
total_elements = batch_size * seq_len * hid_dim
#(32,10,512)
# reshape 후의 각 차원의 크기
new_shape = (batch_size, seq_len, n_heads, depth) # (32, 10, 8, 64)
# 총 요소 수는 동일해야 함
assert total_elements == batch_size * seq_len * n_heads * depth # 163840 = 32 * 10 * 8 * 64
- 결론: '-1' 은 전체 요소수를 유지하면서 특정 차원의 크기를 자동으로 계산하도록 한다. 코드의 가독성을 높이기 위함