[llama3/llama/model.py][class Transformer] def __init__ and def forward

ma-kjh · August 27, 2024

LLM series (8/15)

At the end of [llama3/llama/generation.py][class Llama] def build, the model is created with the Transformer class.

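import torch
from torch import nn
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    VocabParallelEmbedding,
)
# ModelArgs, TransformerBlock, RMSNorm and precompute_freqs_cis are defined earlier in llama/model.py

class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size  # ModelArgs defaults to -1; the real value is set in build() from the checkpoint config
        self.n_layers = params.n_layers  # number of stacked decoder layers (32 by default)

        self.tok_embeddings = VocabParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )  # what is this? A token -> embedding lookup whose table is split along the vocab dimension; the identity init_method skips initialization because the weights are loaded from the checkpoint

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))  # stack the TransformerBlock defined earlier n_layers (32) times
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)  # uses something called RMSNorm; look into it later
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )  # ColumnParallelLinear (look into it): projects hidden states to vocabulary logits

        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
        )  # probably the positional embedding step: precomputed rotary (RoPE) factors, one row per position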
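VocabParallelEmbedding and ColumnParallelLinear exist so that the large vocab-sized weights can be split across model-parallel GPUs. The sketch below simulates that in a single CPU process with 2 pretend ranks, using plain tensors rather than fairscale; `column_parallel_linear` and `vocab_parallel_embedding` are hypothetical helpers, not fairscale's API. The column-parallel layer gives each rank a slice of the output features and concatenates the partial outputs. The vocab-parallel embedding gives each rank a slice of the vocabulary rows, zeroes the lookups it does not own, and sums the partial results (the sum plays the role of the all-reduce).

```python
import torch

torch.manual_seed(0)
vocab_size, dim, world_size = 8, 4, 2  # toy sizes; 2 simulated model-parallel ranks

# Column-parallel linear: Y = X A, with A split along its output (column) dimension.
full_weight = torch.randn(vocab_size, dim)     # nn.Linear layout: [out_features, in_features]
shards = full_weight.chunk(world_size, dim=0)  # each rank owns a slice of the output features

def column_parallel_linear(x: torch.Tensor) -> torch.Tensor:
    # Each rank computes its partial output; concatenation stands in for the all-gather.
    partial_outputs = [x @ w.t() for w in shards]
    return torch.cat(partial_outputs, dim=-1)

# Vocab-parallel embedding: the embedding table is split along the vocabulary dimension.
table = torch.randn(vocab_size, dim)
table_shards = table.chunk(world_size, dim=0)
per_rank = vocab_size // world_size

def vocab_parallel_embedding(tokens: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(*tokens.shape, dim)
    for rank, shard in enumerate(table_shards):
        start, end = rank * per_rank, (rank + 1) * per_rank
        in_range = (tokens >= start) & (tokens < end)
        local_ids = (tokens - start).clamp(0, per_rank - 1)   # clamp keeps indexing in bounds
        partial = shard[local_ids] * in_range.unsqueeze(-1)   # zero the rows this rank does not own
        out += partial                                        # summing stands in for the all-reduce
    return out

x = torch.randn(3, dim)
tokens = torch.tensor([[0, 5, 7], [3, 4, 1]])
print(torch.allclose(column_parallel_linear(x), x @ full_weight.t()))  # True
print(torch.allclose(vocab_parallel_embedding(tokens), table[tokens]))  # True
```

With a single GPU (model-parallel size 1), both layers reduce to an ordinary embedding lookup and an ordinary linear projection.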
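  • That's the whole __init__. Nothing in it is really difficult, but a few pieces were new to me.

  • VocabParallelEmbedding: imported with from fairscale.nn.model_parallel.layers import VocabParallelEmbedding. It is an embedding whose table is split along the vocabulary dimension across model-parallel GPUs; the details can wait.

  • self.norm = RMSNorm(*): a normalization layer I hadn't seen before. Llama uses it in place of LayerNorm, so it's worth looking into later.

  • ColumnParallelLinear: same story as VocabParallelEmbedding, but for a linear layer whose weight is split along its output (column) dimension.

  • freqs_cis: this one puzzled me too. It holds the precomputed complex rotation factors (cis θ = cos θ + i sin θ) for rotary position embedding (RoPE), one row per position.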
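RMSNorm and freqs_cis both come down to short formulas. Below is a re-implementation that follows the formulas in Llama's model.py as I understand them; treat it as a sketch, not a verbatim copy. RMSNorm divides each vector by its root mean square, sqrt(mean(x²) + eps), with no mean subtraction and no bias, then multiplies by a learned weight. precompute_freqs_cis builds an (end, dim/2) table of unit complex numbers cis(m·θ_k) = cos(m·θ_k) + i·sin(m·θ_k), where m is the position and θ_k = theta^(-2k/dim). In __init__ it is called with the per-head dimension params.dim // params.n_heads, end = params.max_seq_len * 2 and theta = params.rope_theta.

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-feature scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        xf = x.float()
        # x / sqrt(mean(x^2) + eps): rescale each vector to unit RMS without centering it
        normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.type_as(x) * self.weight

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # One frequency per pair of channels: theta^(-2k/dim) for k = 0 .. dim/2 - 1
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    positions = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(positions, freqs)  # (end, dim // 2): angle m * theta_k
    return torch.polar(torch.ones_like(angles), angles)  # complex cos(angle) + i*sin(angle)

torch.manual_seed(0)
norm = RMSNorm(8)
x = torch.randn(2, 8)
y = norm(x)
print(y.pow(2).mean(-1))  # close to 1 for each row, since the weight starts at 1

freqs_cis = precompute_freqs_cis(dim=8, end=4)
print(freqs_cis.shape)  # torch.Size([4, 4])
```

Row 0 (position 0) is all 1 + 0i, so the first token is not rotated; later rows rotate each channel pair by a position-dependent angle.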

Now let's go through forward.

    @torch.inference_mode()  # How would training work with this? Can it be switched off, or is it simply not needed in forward? (The llama3 repo is an inference-only reference implementation.)
    def forward(self, tokens: torch.Tensor, start_pos: int):  # start_pos: position of the first new token, i.e. how many tokens are already in the KV cache
        _bsz, seqlen = tokens.shape  # tokens arrive as [batch_size, sequence_length], so this just unpacks the shape
        h = self.tok_embeddings(tokens)  # plain token -> embedding lookup: [batch_size, sequence_length, dim]
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]  # keep only the rows for positions start_pos .. start_pos + seqlen - 1
        mask = None  # a single new token (seqlen == 1) may attend to every cached token, so it needs no mask
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)  # first fill a (seqlen, seqlen) matrix with -inf

            mask = torch.triu(mask, diagonal=1)  # diagonal=1 zeroes the main diagonal and everything below it; only the strictly upper triangle stays -inf

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)

        for layer in self.layers:  # run through every decoder block
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)  # final RMSNorm
        output = self.output(h).float()  # logits over the vocabulary, cast to float32
        return output
            

>>> a = torch.randn(3, 3)
>>> a
tensor([[ 0.2309,  0.5207,  2.0049],
        [ 0.2072, -1.0680,  0.6602],
        [ 0.3480, -0.5211, -0.4573]])
>>> torch.triu(a, diagonal=1)
tensor([[ 0.0000,  0.5207,  2.0049],
        [ 0.0000,  0.0000,  0.6602],
        [ 0.0000,  0.0000,  0.0000]])
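To see the whole mask that forward builds when a KV cache is in use, the sketch below repeats the same three steps (torch.full, torch.triu, torch.hstack) for a toy case with seqlen = 3 new tokens and start_pos = 2 cached tokens; `build_mask` is a hypothetical helper name. Adding the mask to the attention scores before softmax turns every -inf entry into a weight of exactly 0.

```python
import torch

def build_mask(seqlen: int, start_pos: int) -> torch.Tensor:
    # Same steps as Transformer.forward: -inf strictly above the diagonal, zeros elsewhere
    mask = torch.full((seqlen, seqlen), float("-inf"))
    mask = torch.triu(mask, diagonal=1)
    # Prepend start_pos zero columns: every new token may attend to all cached tokens
    return torch.hstack([torch.zeros((seqlen, start_pos)), mask])

mask = build_mask(seqlen=3, start_pos=2)
# Row i is the token at absolute position start_pos + i:
# it sees columns 0 .. start_pos + i and gets -inf for every later column.
print(mask)

scores = torch.zeros(3, 5)  # stand-in for q @ k^T / sqrt(head_dim), with all scores equal
weights = torch.softmax(scores + mask, dim=-1)
print(weights[0])  # [1/3, 1/3, 1/3, 0, 0]: position 2 attends only to positions 0..2
```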

The @torch.inference_mode() decorator in PyTorch is used to improve the efficiency of inference (i.e., the process of making predictions with a trained model) by disabling certain features that are only necessary during training. Here's what it does:

Functions of @torch.inference_mode()

  1. Disables Gradient Calculation:

    • Similar to torch.no_grad(), torch.inference_mode() turns off gradient calculation, which is unnecessary during inference and can save memory and computation. This is useful because gradients are only needed when updating model parameters during training, not during inference.
  2. Optimizes Memory Usage:

    • torch.inference_mode() is more aggressive than torch.no_grad() in terms of optimizations. It can lead to further memory savings and performance improvements by skipping bookkeeping that is irrelevant during inference, namely view tracking and version counter bumps on tensors.
  3. Read-Only Operations:

    • Tensors created inside the context become inference tensors: they cannot be saved for backward or updated in place once you leave inference mode. Because of that guarantee, PyTorch can skip the autograd bookkeeping for them.
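Points 1 and 3 can be checked directly on a toy nn.Linear (not Llama). Outputs produced inside an inference_mode function carry no autograd graph and are flagged as inference tensors, while the same layer called outside it still supports backward(). This also bears on the question in the forward code: a training loop needs gradients, so it would have to call a forward that is not wrapped in this decorator. The `mode` argument of torch.inference_mode lets you turn the context off with a flag.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 2)  # toy stand-in for a real model
x = torch.randn(3, 4)

@torch.inference_mode()
def predict(inputs: torch.Tensor) -> torch.Tensor:
    return model(inputs)

out = predict(x)
print(out.requires_grad, out.is_inference())  # False True: no autograd graph was recorded

# Outside the decorator the same layer builds a graph, so backward() works.
train_out = model(x)
train_out.sum().backward()
print(model.weight.grad is not None)  # True

# inference_mode(mode=False) disables the context, which is handy when toggling with a flag.
use_inference = False
with torch.inference_mode(mode=use_inference):
    out2 = model(x)
print(out2.requires_grad)  # True
```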

When to Use @torch.inference_mode()

  • Inference Time: When you're using a model purely for inference and want to maximize performance and minimize memory usage.
  • Deployment: In production environments where models are deployed for making predictions, torch.inference_mode() helps ensure that the system runs as efficiently as possible.

Example Usage

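import torch
from torch import nn

@torch.inference_mode()
def predict(model, inputs):
    return model(inputs)

my_model = nn.Linear(4, 2)     # placeholder model, just for illustration
my_inputs = torch.randn(3, 4)  # placeholder batch of 3 inputs
# This will run the model in inference mode, with optimizations applied.
outputs = predict(my_model, my_inputs)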

In this example, the predict function will execute without tracking gradients, with all the additional optimizations provided by torch.inference_mode(). This is particularly beneficial when deploying models in environments where performance is critical.

Summary

@torch.inference_mode() is a decorator that optimizes the inference process by disabling gradient calculations and applying additional optimizations, making it more efficient than torch.no_grad() for read-only operations during inference.
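One concrete difference from torch.no_grad(): a tensor created under no_grad is an ordinary tensor that autograd can use later, whereas a tensor created under inference_mode is an inference tensor that autograd refuses to save for backward. A minimal check with toy tensors:

```python
import torch

w = torch.ones(3, requires_grad=True)

with torch.no_grad():
    a = torch.arange(3.0)  # ordinary tensor; it was simply created without grad tracking
with torch.inference_mode():
    b = torch.arange(3.0)  # inference tensor

print(a.is_inference(), b.is_inference())  # False True

(w * a).sum().backward()  # fine: a can be saved for the backward pass
print(w.grad)  # tensor([0., 1., 2.])

raised = False
try:
    (w * b).sum()  # w * b must save b to compute w's gradient, which autograd rejects
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# A clone made outside inference mode is a normal tensor, so autograd accepts it.
(w * b.clone()).sum().backward()
print(w.grad)  # tensor([0., 2., 4.]): gradients accumulate over the two successful backward calls
```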
