[llama3/llama/generation.py][class Llama] def sample_top_p

ma-kjh · August 30, 2024
```python
import torch


def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # probs_sum - probs_sort is the total probability of all higher-ranked tokens,
    # excluding the token itself. If that alone already exceeds p (e.g. 0.95), the
    # nucleus is complete without this token, so it is masked out (True).
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # Renormalize the surviving probabilities so each row sums to 1.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # Sample one position in sorted order, then map it back to a vocabulary index.
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
```
    

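An LLM's final output (the next token) is ultimately decided by sampling.
In LLaMA 3.1 the method used for this is sample_top_p. It is used when temperature > 0; with temperature = 0, generate falls back to greedy decoding with torch.argmax.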

The function takes two inputs, probs and p, which mean the following.

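probs (torch.Tensor): the softmax probabilities over the vocabulary. They are computed from the logits output by the final linear layer, taken at the last position (the one that predicts the next token). Temperature is applied in this step (see generate in llama/generation.py):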

probs = torch.softmax(logits[:, -1] / temperature, dim=-1)

p is the threshold. Sampling is restricted to the highest-probability tokens whose combined probability just exceeds p, and one token is drawn from among them.

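-> The higher the temperature, the flatter the distribution. At temperature 1.0 the model's softmax is left unchanged, and values above 1.0 spread the probability mass more evenly across the Llama 3 vocabulary (128,256 tokens, roughly 128K). The lower the temperature (closer to 0.0), the more the largest logit dominates, so sampling almost always selects the single most probable token.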
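The following minimal pure-Python sketch makes the temperature effect concrete. It uses plain lists instead of torch tensors and handles a single sequence only. `softmax_with_temperature` and the toy logits are hypothetical illustrations, not part of the llama code. The final argmax line is assumed to mimic the greedy branch that generate takes when temperature is 0.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax over one row of logits after dividing by temperature (T > 0)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical toy logits for a 4-token vocabulary.
logits = [2.0, 1.0, 0.5, 0.0]

# Lower T sharpens the distribution; higher T flattens it toward uniform.
for t in (2.0, 1.0, 0.5, 0.1):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: " + ", ".join(f"{q:.3f}" for q in probs))

# With temperature == 0 there is nothing to sample: take the argmax instead.
greedy = max(range(len(logits)), key=lambda i: logits[i])
print("greedy token:", greedy)
```

As T approaches 0, the top token's probability approaches 1. That is why the T = 0 case can be handled by argmax directly instead of by sampling.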

In the context of sample_top_p, "cumulative probability mass" refers to the running sum of token probabilities after the tokens are sorted in descending order of probability.

Explanation of Cumulative Probability Mass:

  1. Probability Distribution (probs):

    • This is a tensor representing the probabilities of different tokens in a vocabulary. Each entry in the tensor corresponds to the likelihood that the token appears as the next token in the generated sequence.
  2. Sorting the Probabilities:

    • probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    • The probabilities are sorted in descending order (probs_sort), and probs_idx holds the corresponding indices of the tokens after sorting.
  3. Calculating Cumulative Probability Mass:

    • probs_sum = torch.cumsum(probs_sort, dim=-1)
    • Here, torch.cumsum computes the cumulative sum of the sorted probabilities. Each entry in probs_sum is the sum of all sorted probabilities up to and including that position.

    For example, if probs_sort = [0.4, 0.3, 0.2, 0.1], then probs_sum would be [0.4, 0.7, 0.9, 1.0].

  4. Top-p Sampling:

    • The function aims to select the smallest set of tokens whose cumulative probability mass (sum) exceeds a threshold p.
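    • The mask mask = probs_sum - probs_sort > p marks tokens for which the combined probability of all higher-ranked tokens (the cumulative sum minus the token's own probability) already exceeds p. The nucleus is complete without such tokens, so their probabilities are set to 0 and they are excluded from sampling.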
  5. Renormalization and Sampling:

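    • The probabilities of the tokens that survive the mask are renormalized in place (probs_sort.div_(...)) so they sum to 1.
    • next_token = torch.multinomial(probs_sort, num_samples=1) samples a position from the renormalized distribution.
    • next_token = torch.gather(probs_idx, -1, next_token) maps that position in the sorted order back to the original vocabulary index.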
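The steps above can be traced end to end in a small pure-Python sketch. `sample_top_p_list` is a hypothetical helper that mirrors the torch version for a single distribution:

- `sorted` stands in for torch.sort.
- `itertools.accumulate` stands in for torch.cumsum.
- `random.choices` stands in for torch.multinomial.
- Indexing into `order` stands in for torch.gather.

The probabilities and p = 0.8 are made-up values.

```python
import itertools
import random


def sample_top_p_list(probs, p, rng=random):
    """Pure-Python sketch of top-p (nucleus) sampling for one distribution."""
    # 1. Sort descending, remembering each entry's original index (torch.sort).
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    probs_sort = [probs[i] for i in order]
    # 2. Inclusive cumulative sum (torch.cumsum).
    probs_sum = list(itertools.accumulate(probs_sort))
    # 3. Zero out tokens whose higher-ranked predecessors already exceed p.
    kept = [0.0 if s - q > p else q for s, q in zip(probs_sum, probs_sort)]
    # 4. Renormalize the survivors so they sum to 1.
    total = sum(kept)
    kept = [q / total for q in kept]
    # 5. Draw a position in sorted order (torch.multinomial).
    pos = rng.choices(range(len(kept)), weights=kept, k=1)[0]
    # 6. Map the sorted position back to the vocabulary index (torch.gather).
    return order[pos]


rng = random.Random(0)
probs = [0.1, 0.4, 0.2, 0.3]  # hypothetical, deliberately unsorted
counts = [0] * len(probs)
for _ in range(10_000):
    counts[sample_top_p_list(probs, 0.8, rng)] += 1
print(counts)
```

With these numbers, the lowest-probability token (index 0) falls outside the nucleus and is never drawn. The other three are drawn in roughly a 4:3:2 ratio. Setting p = 0.0 keeps only the top token, which turns the sampler into greedy decoding.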

Cumulative Probability Mass in This Context:

  • The cumulative probability mass is essentially the running total of the sorted probabilities as you iterate over the list.
  • For a given threshold p, top-p sampling includes tokens until this cumulative mass exceeds p. The intuition is that you only sample from the most probable tokens that together cover more than a fraction p of the total probability (e.g. 90% for p = 0.9). This focuses sampling on likely tokens while still allowing some diversity.
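One consequence of this rule is that the number of candidate tokens adapts to the shape of the distribution. With the same p, a confident (peaked) distribution yields a small nucleus and an uncertain (flat) one yields a large nucleus. `nucleus_size` and both distributions below are hypothetical. The keep rule is the same `cumsum - prob > p` mask used in sample_top_p.

```python
import itertools


def nucleus_size(probs, p):
    """Count the tokens that survive the top-p mask `cumsum - prob > p`."""
    probs_sort = sorted(probs, reverse=True)
    probs_sum = itertools.accumulate(probs_sort)
    return sum(1 for s, q in zip(probs_sum, probs_sort) if not (s - q > p))


peaked = [0.85, 0.08, 0.04, 0.02, 0.01]  # hypothetical confident prediction
flat = [0.22, 0.21, 0.20, 0.19, 0.18]    # hypothetical uncertain prediction

for p in (0.5, 0.9):
    # prints "0.5 1 3", then "0.9 2 5"
    print(p, nucleus_size(peaked, p), nucleus_size(flat, p))
```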

In summary, cumulative probability mass in this situation refers to the sum of probabilities up to a certain point in the sorted list, and it is used to decide which tokens are eligible for sampling in the top-p (nucleus) sampling method.

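Continuing the example (probs_sort = [0.4, 0.3, 0.2, 0.1]) with a threshold p between 0.7 and 0.9: probs_sum - probs_sort = [0.0, 0.4, 0.7, 0.9], so only the last token is masked. Renormalizing by the remaining mass of 0.9 gives:

probs_sort = [0.4/0.9, 0.3/0.9, 0.2/0.9, 0.0] ≈ [0.444, 0.333, 0.222, 0.0]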
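The arithmetic of this example can be checked in a few lines of plain Python. Here p = 0.8 is an assumed threshold; any value between 0.7 and 0.9 produces the same mask.

```python
import itertools

probs_sort = [0.4, 0.3, 0.2, 0.1]  # already sorted, as in the example
p = 0.8                             # assumed threshold in (0.7, 0.9)

probs_sum = list(itertools.accumulate(probs_sort))           # ~[0.4, 0.7, 0.9, 1.0]
exclusive = [s - q for s, q in zip(probs_sum, probs_sort)]   # ~[0.0, 0.4, 0.7, 0.9]
mask = [e > p for e in exclusive]                            # [False, False, False, True]
kept = [0.0 if m else q for m, q in zip(mask, probs_sort)]
total = sum(kept)                                            # ~0.9
renorm = [q / total for q in kept]
print([round(q, 3) for q in renorm])  # [0.444, 0.333, 0.222, 0.0]
```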