Recently, attention and transformer architectures have been widely used in both natural language processing and computer vision. This has led to active research on efficient attention operations such as Flash Attention and xformers, and in this post I will introduce how to use xformers.
Flash Attention has follow-up work and has been integrated into PyTorch 2.0, but it is currently supported only on a limited set of GPUs. xformers, on the other hand, is a library implemented on top of PyTorch: depending on the problem or the model design its performance may fall short, but it has the advantage of being usable anywhere PyTorch runs.
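For reference, the PyTorch 2.0 integration mentioned above is exposed as torch.nn.functional.scaled_dot_product_attention. The sketch below is not part of the xformers example that follows; the tensor shapes are arbitrary, and it only illustrates how the fused kernel is called and how the Flash Attention backend can be forced (which fails on GPUs that do not support it).
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim); shapes chosen arbitrarily for illustration
q = k = v = torch.randn(4, 4, 4096, 64, device="cuda", dtype=torch.float16)

# Let PyTorch pick the best available backend (Flash, memory-efficient, or math)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Force the Flash Attention backend only; this raises an error if the GPU/dtype is unsupported
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)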
"""
Forked from
https://github.com/facebookresearch/xformers/blob/main/HOWTO.md#blocksparseattention
"""
import torch
from xformers.components import MultiHeadDispatch
from xformers.components.attention import BlockSparseAttention
BATCH = 4
HEADS = 4
SEQ = 4096
EMB = 512
BLOCK_SIZE = 32
DROPOUT = 0.0
dtype = torch.float16
# Let's try out a causal mask, but really it could be anything "block sparse enough"
causal_mask = torch.tril(torch.ones((SEQ, SEQ), device=torch.device("cuda"), dtype=dtype))
blocks = SEQ // BLOCK_SIZE
causal_layout = torch.tril(torch.ones([HEADS, blocks, blocks], dtype=torch.bool))
# Let's build our blocksparse attention. Please note that the layout can be
# [SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE] or [HEADS, SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE]
# so that _you can pass a different layout per head_
attention = BlockSparseAttention(layout=causal_layout, block_size=BLOCK_SIZE, dropout=DROPOUT, num_heads=HEADS)
# For convenience, let's build our multihead attention now
# "multi_head" will be responsible for the forward
multi_head = (
    MultiHeadDispatch(
        dim_model=EMB,
        residual_dropout=DROPOUT,
        num_heads=HEADS,
        attention=attention,
    )
    .cuda()
    .half()
)
# Now run a forward pass on some random data
# Note that passing a per-coefficient mask makes it possible to remove the extra coefficients
# which were required by the blockification
query = torch.randn((BATCH, SEQ, EMB), requires_grad=True, device=torch.device("cuda"), dtype=dtype)
# Self attention in this particular example, no limitations really
att_val = multi_head(query=query, key=query, value=query)#, att_mask=causal_mask)
#########################################
# Bonus: compare the memory use vs dense:
def mem_use(fn, kwargs, title):
    # bookkeeping
    import time
    start = time.time()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # actually run the function
    fn(**kwargs)
    torch.cuda.synchronize()
    stop = time.time()
    # now report
    max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
    print(f"{title} - Peak memory use: {max_memory}MB - {round((stop-start)*1e6)/1e3}ms")
pytorch_multihead = torch.nn.MultiheadAttention(
    EMB, HEADS, batch_first=True, device=torch.device("cuda"), dtype=torch.float16
)
mem_use(multi_head, {"query": query, "key": query, "value": query}, "Blocksparse")
mem_use(pytorch_multihead, {"query": query, "key": query, "value": query}, "PyTorch")
The results are as follows:
Blocksparse - Peak memory use: 980MB - 9.827ms
PyTorch - Peak memory use: 1504MB - 11.18ms
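With SEQ = 4096 and BLOCK_SIZE = 32 the attention matrix is tiled into 128 x 128 blocks, and the causal layout only materializes roughly half of them, which is where the memory saving above comes from. As the comment in the example notes, the layout also does not have to be per-head: a single [SEQ // BLOCK_SIZE, SEQ // BLOCK_SIZE] boolean matrix is accepted and is then shared across all heads. Below is a minimal sketch reusing the constants from the script above; the sliding-window pattern and its width of two blocks are illustrative assumptions, not part of the original example.
# 2D layout shared across all heads: each query block attends to itself
# and to the two preceding blocks (an arbitrary local-attention pattern)
local_layout = torch.zeros(blocks, blocks, dtype=torch.bool)
for i in range(blocks):
    local_layout[i, max(0, i - 2) : i + 1] = True

local_attention = BlockSparseAttention(
    layout=local_layout, block_size=BLOCK_SIZE, dropout=DROPOUT, num_heads=HEADS
)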
The pieces above can be wrapped into a reusable module, for example:
import torch
from torch import nn
from xformers.components import MultiHeadDispatch
from xformers.components.attention import BlockSparseAttention
class EfficientMHA(nn.Module):
    def __init__(self, embed_dim, num_heads, seq_len, block_size):
        super().__init__()
        # One causal block layout per head
        causal_layout = torch.tril(
            torch.ones(
                [num_heads, seq_len // block_size, seq_len // block_size],
                dtype=torch.bool,
            )
        )
        self.attention_layer = MultiHeadDispatch(
            dim_model=embed_dim,
            num_heads=num_heads,
            attention=BlockSparseAttention(
                layout=causal_layout,
                block_size=block_size,
                num_heads=num_heads,
            ),
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, key, query, value):
        output = self.attention_layer(
            key=key,
            query=query,
            value=value,
        )
        # Residual connection followed by layer normalization
        output = output + query
        output = self.norm(output)
        return output
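A short usage sketch of the module above; the shapes simply reuse the constants from the earlier example, and the variable names are my own.
# BlockSparseAttention runs in fp16 on the GPU, hence .cuda().half()
mha = EfficientMHA(embed_dim=512, num_heads=4, seq_len=4096, block_size=32).cuda().half()
x = torch.randn(4, 4096, 512, device="cuda", dtype=torch.float16)
out = mha(key=x, query=x, value=x)  # self-attention; out has shape (4, 4096, 512)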