[llama3/llama/generation.py][class Llama] def build def __init__

ma-kjh · August 27, 2024

Let's take a look at class Llama.

class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None, # ?
        seed: int = 1,
    ) -> "Llama":

build is the first method defined in the class. It loads the model checkpoint, initializes the model, and builds a Llama instance.

Args:

  • ckpt_dir (str): path to the directory containing the checkpoint files
  • tokenizer_path (str): path to the tokenizer file
  • max_seq_len (int): maximum sequence length of the input text
  • max_batch_size (int): maximum batch size at inference time
  • model_parallel_size (Optional[int], optional): number of model-parallel processes. If not provided, it stays None and is read from the WORLD_SIZE environment variable (default 1).

Returns:

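  • Llama: an instance of the Llama class with the model and tokenizer already loaded.

Raises:

  • AssertionError: if there are no checkpoint files, or if the model parallel size does not match the number of checkpoint files

Note:

  • This method initializes the distributed process group and sets the device to CUDA.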
        # This bound is for Llama 3; not sure yet how 3.1 handles it.
        assert 1 <= max_seq_len <= 8192, f"max_seq_len must be between 1 and 8192, got {max_seq_len}."
        assert os.path.isdir(ckpt_dir), f"Checkpoint directory '{ckpt_dir}' does not exist."
        assert os.path.isfile(tokenizer_path), f"Tokenizer file '{tokenizer_path}' does not exist."

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)  # to look at later

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)

        # seed must be the same in all processes
        torch.manual_seed(seed)

        # only rank 0 prints; the other ranks send stdout to os.devnull
        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        ckpt_path = checkpoints[get_model_parallel_rank()]
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

            
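  • Up to this point, the code appears to handle loading the model in a model-parallel fashion: each process picks the checkpoint file that matches its model-parallel rank.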
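
The shard-selection step can be sketched with the standard library alone. This is only an illustration under assumptions, not the library's code: pick_shard is a hypothetical helper, its rank argument stands in for get_model_parallel_rank(), and the consolidated.NN.pth file names are assumed for the demo.

```python
import tempfile
from pathlib import Path


def pick_shard(ckpt_dir: str, model_parallel_size: int, rank: int) -> Path:
    """Return the checkpoint shard for one rank (hypothetical helper)."""
    # Sort so that every process sees the shards in the same order.
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
    # One shard per model-parallel process, mirroring the assert in build.
    assert model_parallel_size == len(checkpoints), (
        f"Loading a checkpoint for MP={len(checkpoints)} "
        f"but world size is {model_parallel_size}"
    )
    return checkpoints[rank]


# Demo with empty placeholder files (the file names are an assumption).
with tempfile.TemporaryDirectory() as tmp:
    for i in range(2):
        (Path(tmp) / f"consolidated.{i:02d}.pth").touch()
    shard_names = [pick_shard(tmp, 2, rank).name for rank in range(2)]

print(shard_names)
```

Because the list is sorted, rank 0 and rank 1 always agree on which file belongs to whom, which is why the number of files must equal the model parallel size.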
        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
  • The model args are built from class ModelArgs in llama/model.py, with the values from params.json overriding its defaults.
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000

    max_batch_size: int = 32
    max_seq_len: int = 2048
  • max_batch_size -> the maximum batch size usable at inference time, 32 by default.
  • max_seq_len -> the maximum input sequence length, 2048 by default.
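
How the defaults interact with `**params` can be sketched with a reduced dataclass. TinyArgs and the JSON string below are made-up stand-ins for ModelArgs and params.json, used only to show the override behavior.

```python
import json
from dataclasses import dataclass


@dataclass
class TinyArgs:  # hypothetical, reduced stand-in for ModelArgs
    dim: int = 4096
    n_layers: int = 32
    vocab_size: int = -1
    max_batch_size: int = 32
    max_seq_len: int = 2048


# Made-up stand-in for the contents of params.json.
params = json.loads('{"dim": 1024, "n_layers": 8, "vocab_size": 1000}')

# Explicit keyword arguments and **params both override the defaults.
args = TinyArgs(max_seq_len=512, max_batch_size=4, **params)
defaults = TinyArgs()

print(args)
print(defaults.max_batch_size, defaults.max_seq_len)
```

Note that a key in params that duplicates one of the explicit keywords (for example max_seq_len) would raise a TypeError, since Python does not allow the same keyword argument to be passed twice.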
        tokenizer = Tokenizer(model_path=tokenizer_path)
        assert model_args.vocab_size == tokenizer.n_words
        if torch.cuda.is_bf16_supported():
            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
        else:
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
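  • The tokenizer is loaded as well. The number of words the tokenizer holds (presumably its vocab size) must equal the vocab size configured for the model. (This is to be expected: the model predicts over the vocabulary of the tokenizer used during LLM pre-training, so a different vocab size or meaning would make no sense.)

  • If torch.cuda.is_bf16_supported() returns True, the default tensor type is set to BFloat16Tensor; otherwise it falls back to HalfTensor (FP16). BFloat16 generally seems to be preferred for training because of its larger exponent range (8 exponent bits, -126 to 127), while FP16 (half-precision float, 5 exponent bits, -14 to +15) is used at inference time.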

        model = Transformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        return Llama(model, tokenizer)

    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.formatter = ChatFormat(tokenizer)

ChatFormat is covered here.


Explanation of @staticmethod

  • @staticmethod Decorator: The @staticmethod decorator is used to define a method in a class that does not operate on an instance of the class (i.e., it doesn't require access to self or cls). This means the method can be called on the class itself rather than on an instance of the class.
  • No Access to Instance or Class: A static method doesn't have access to the instance (self) or the class (cls). It behaves just like a regular function, but it belongs to the class's namespace and can be called using the class name.
  • Usage: Static methods are used when you need a function that logically belongs to a class but doesn't need to interact with the class or its instances.
class MathOperations:
    @staticmethod
    def add(a, b):
        return a + b

result = MathOperations.add(5, 3)  # Called directly on the class
In class Llama, the two methods split the work as follows:

  • Initialization (__init__): Handles basic setup, like storing the already-constructed model and tokenizer.
  • Building (build): Handles complex setup, like loading a pre-trained model from a checkpoint, configuring the tokenizer, and setting up any optional configurations.
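
The same split can be sketched in a few lines. TinyModel and its JSON config are hypothetical and exist only to illustrate the pattern; they are not part of the llama code.

```python
import json


class TinyModel:  # hypothetical class, for illustration only
    def __init__(self, weights: dict, vocab: list):
        # __init__ only stores objects that are already prepared.
        self.weights = weights
        self.vocab = vocab

    @staticmethod
    def build(config_json: str) -> "TinyModel":
        # build does the heavier setup: parse, validate, construct.
        config = json.loads(config_json)
        vocab = config["vocab"]
        assert config["vocab_size"] == len(vocab), "vocab size mismatch"
        weights = {token: 0.0 for token in vocab}
        return TinyModel(weights, vocab)


# Called on the class itself; no instance exists before build runs.
model = TinyModel.build('{"vocab_size": 2, "vocab": ["a", "b"]}')
print(model.vocab)
```

Keeping __init__ trivial also means an instance can be created directly from objects that are already in memory, without going through the loading path.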

In the Context of BFloat16 and FP16

  • BFloat16 (Brain Float 16-bit):
    • Exponent bits: 8 bits

    • Exponent range: -126 to +127

    • Effective range: approximately $10^{-38}$ to $3.4\times 10^{38}$

      BFloat16 has the same exponent range as FP32 (standard 32-bit floating-point), allowing it to represent a wide range of values, but with less precision in the significand (mantissa).

  • FP16 (Half-Precision Float):
    • Exponent bits: 5 bits

    • Exponent range: -14 to +15

    • Effective range: approximately $6.1\times 10^{-5}$ to $65504$

      FP16 has a narrower exponent range, meaning it can't represent values as large or as small as BFloat16 or FP32. However, it has more precision than BFloat16 within this smaller range.

  • In FP16: With 10 bits in the significand, FP16 approximates 1.1 closely; the nearest representable value is 1.099609375 (error of about 0.00039).

  • In BFloat16: With only 7 bits in the significand, BFloat16 has less precision; the nearest representable value to 1.1 is 1.1015625 (error of about 0.0016, four times larger).
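
How 1.1 rounds in each format can be checked with the standard library. This is a minimal sketch: to_fp16 and to_bf16 are hypothetical helpers, and to_bf16 emulates round-to-nearest-even on the float32 bit pattern, assuming finite, non-NaN inputs that do not overflow.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (finite inputs only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest even on the 16 low bits that bfloat16 drops.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


fp16_value = to_fp16(1.1)
bf16_value = to_bf16(1.1)
print(fp16_value, abs(fp16_value - 1.1))
print(bf16_value, abs(bf16_value - 1.1))
```

The FP16 result lands on a multiple of 1/1024 and the BFloat16 result on a multiple of 1/128, which is where the difference in error comes from.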

Trade-off:

  • BFloat16 is preferred for tasks where the range of values is more critical (e.g., deep learning training with large dynamic ranges).
  • FP16 is better when precision within a smaller range is more important (e.g., certain types of inference where values don't need to span as wide a range).
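
The exponent ranges and effective ranges of the two formats follow from the bit counts alone. The sketch below assumes an IEEE-style layout (all-zero exponent reserved for subnormals, all-one exponent reserved for inf/NaN); normal_range is a hypothetical helper.

```python
def normal_range(exponent_bits: int, significand_bits: int):
    """Return (smallest normal, largest finite) for an IEEE-style format."""
    bias = 2 ** (exponent_bits - 1) - 1
    e_min = 1 - bias  # all-zero exponent field is reserved for subnormals
    e_max = bias      # all-one exponent field is reserved for inf/NaN
    smallest = 2.0 ** e_min
    # Largest significand is 1.111...1 in binary, i.e. 2 - 2^-significand_bits.
    largest = (2 - 2.0 ** -significand_bits) * 2.0 ** e_max
    return smallest, largest


fp16_min, fp16_max = normal_range(5, 10)
bf16_min, bf16_max = normal_range(8, 7)
print(f"FP16:     {fp16_min:.3e} to {fp16_max:.3e}")
print(f"BFloat16: {bf16_min:.3e} to {bf16_max:.3e}")
```

With 5 exponent bits the bias is 15, giving exponents from -14 to +15; with 8 exponent bits the bias is 127, giving -126 to +127.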

The precision in the significand (mantissa) of a floating-point number is related to how many bits are used to represent the fractional part of the number. The exact decimal values that can be represented are determined by the binary fractions $\frac{1}{2^1}, \frac{1}{2^2}, \frac{1}{2^3}$ and so on.

Breakdown of the 10-bit Significand

  • The 10-bit significand in FP16 can represent values as binary fractions.

  • Each bit in the significand represents a power of $\frac{1}{2}$:

    • The first bit (after the implicit leading 1) represents $\frac{1}{2^1} = 0.5$.
    • The second bit represents $\frac{1}{2^2} = 0.25$.
    • The third bit represents $\frac{1}{2^3} = 0.125$.
    • And so on, up to the tenth bit, which represents $\frac{1}{2^{10}}$.

Example: Using the 10-bit Significand

Let’s say you have a 10-bit binary significand like 1.0001100110:

  • The 1. is implicit and represents the leading 1.
  • The 0001100110 part is the 10-bit significand.

In decimal, this binary number represents:

$1 + \frac{0}{2^1} + \frac{0}{2^2} + \frac{0}{2^3} + \frac{1}{2^4} + \frac{1}{2^5} + \frac{0}{2^6} + \frac{0}{2^7} + \frac{1}{2^8} + \frac{1}{2^9} + \frac{0}{2^{10}}$

Which simplifies to:

$1 + 0 + 0 + 0 + 0.0625 + 0.03125 + 0 + 0 + 0.00390625 + 0.001953125 + 0 = 1.099609375$
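
The same sum can be computed for any bit string. A minimal sketch, where significand_value is a hypothetical helper that takes the stored bits as a string and adds the implicit leading 1:

```python
def significand_value(bits: str) -> float:
    """Value of 1.<bits> in binary, with an implicit leading 1."""
    value = 1.0
    for position, bit in enumerate(bits, start=1):
        # The bit at position i contributes bit / 2^i.
        value += int(bit) / 2 ** position
    return value


example = significand_value("0001100110")
print(example)
```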

Bit-Level Precision

  • 1 bit: the smallest step is $\frac{1}{2^1} = 0.5$
  • 2 bits: the smallest step is $\frac{1}{2^2} = 0.25$
  • ...
  • 10 bits: the smallest step is $\frac{1}{2^{10}} = 0.0009765625$

Summary

  • 10 bits in the significand allow for a precision of up to $\frac{1}{2^{10}}$, which is exactly 0.0009765625.
  • The number of bits in the significand determines how finely you can represent fractional values between whole numbers.