언어 모델로 생성을 해 보면, 말이 너무 많다고 느껴진다. 문장이 끝나지 않고 계속 이어지는 경우가 많다. 게다가 이어지는 문장은 내가 원하는 문장이 아니다. 그럴 때 사용하는 것이 Stop Sequence 이다.
Stop Sequence란, 모델이 더 이상 토큰을 생성하지 않고 멈추는 조건이 되는 문자열이다.
예를 들어, 아래의 예시에서 ### 답변:에 해당하는 줄까지만 생성하기를 원한다고 해보자.
### 지시문: 내일 먹을 저녁 메뉴 추천 해줘.
### 답변: 내일 먹을 저녁 메뉴로는 간장 게장을 추천드립니다.
### 지시문: 내일 먹을 저녁은 뭐가 ...
그런데 모델은 그 뒤로도 ### 지시문:과 이어지는 문장들을 만들어버렸다. 이 경우에 우리는 ### 지시문을 stop sequence로 추가해서, 모델이 필요 이상으로 생성하는 것을 막을 수 있다.
ChatGPT API를 살펴보면, stop이라는 매개변수에 stop seqeunce를 전달해서 모델이 생성하는 것을 멈출 수 있다.
stop (string, array, null | Optional | Defaults to null)
-- Up to 4 sequences where the API will stop generating further tokens.
그렇지만, 나는 허깅페이스의 transformers에서도 stop sequence를 사용하고 싶었다.
Transformers Docs를 보면, generate 메소드 사용 시에 stopping_criteria 라는 매개변수를 줄 수 있다. (예를 들면 아래 코드처럼)
outputs = model.generate(
**inputs,
max_new_tokens=5,
num_beams=4,
num_return_sequences=4,
return_dict_in_generate=True,
output_scores=True,
stopping_criteria = stopping_criteria # Like this!
)
문제는.. 저 자리에 들어가는 stopping_criteria를 어떻게 넣어줘야 하는지에 대한 설명이 많지 않다는 것이다.
stopping_criteria (StoppingCriteriaList, optional)
— Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users.
사실상 StoppingCriteriaList가 들어가야 한다는 것만 알 수 있었다.
처음에는 토크나이징해서 input_ids를 때려 넣으려고(?) 했는데, (당연히) 되지 않았다.
여기저기 서치하다가, Class를 상속받아 사용하는 해결 방법을 보게 되었고 약간의 힌트를 얻긴 했지만 명확하게 이해가 되지는 않았다.
그래서 StoppingCriteriaLIst가 정의된 소스코드를 살펴보기로 했다.
class StoppingCriteriaList(list):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(criteria(input_ids, scores) for criteria in self)
@property
def max_length(self) -> Optional[int]:
for stopping_criterium in self:
if isinstance(stopping_criterium, MaxLengthCriteria): # Here!
return stopping_criterium.max_length
elif isinstance(stopping_criterium, MaxNewTokensCriteria):
return stopping_criterium.max_length
return None
max_length 함수 안을 살펴보면, StoppingCriteriaList의 요소가 MaxLengthCriteria의 인스턴스인지 확인하는 부분이 있다.
generate()시에 모델이 생성할 최대 토큰 길이도 지정할 수 있는데, 이 기능이 일종의 StoppingCriteria라는 걸 알게 되었다.
같은 파일 위쪽 라인에서 MaxLengthCriteria를 찾아왔다.
StoppingCriteria를 상속받아서, __call__()에서 (호출이 되면) 초기에 입력한 프롬프트와 지금까지 생성한 토큰들의 input_ids를 바탕으로 생성을 멈추어야하는 조건을 만족했는지에 대해 True/False로 결과값을 return한다.
또, StoppingCriteria는 'subclassed'되어야 한다는 에러 문구도 보인다.
class StoppingCriteria(ABC):
"""Abstract base class for all stopping criteria that can be applied during generation.
If your stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True,
output_scores=True` to `generate`.
"""
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
# Look at the Error Message
raise NotImplementedError("StoppingCriteria needs to be subclassed")
class MaxLengthCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep
in mind for decoder-only type of transformers, this will include the initial prompted tokens.
Args:
max_length (`int`):
The maximum length that the output sequence can have in number of tokens.
max_position_embeddings (`int`, *optional*):
The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
"""
def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
self.max_length = max_length
self.max_position_embeddings = max_position_embeddings
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
cur_len = input_ids.shape[-1]
# 현재 길이와 max_length 비교
is_done = cur_len >= self.max_length
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
logger.warning_once(
"This is a friendly reminder - the current text generation call will exceed the model's predefined "
f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
"exceptions, performance degradation, or nothing at all."
)
return is_done
이를 통해, StoppingCriteria는
StoppingCriteria의 자식 Class를 생성한 뒤__call__함수를 정의해서 사용한다.__call__함수는 토큰을 하나 생성할 때마다 어떠한 조건을 만족했는지에 대한 여부를 boolean으로 반환한다.라는 사실을 알 수 있었다.
또한, transformers/src/transformers/generation/utils.py 파일을 보면 get_stopping_criteria라는 함수가 정의되어있는데,
def _get_stopping_criteria(
self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList]
) -> StoppingCriteriaList:
criteria = StoppingCriteriaList() # instance 생성
if generation_config.max_length is not None:
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
# MaxLengthCriteria 추가
criteria.append(
MaxLengthCriteria(
max_length=generation_config.max_length,
max_position_embeddings=max_position_embeddings,
)
)
if generation_config.max_time is not None:
criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
return criteria
StoppingCriteriaList의 인스턴스인 criteria 변수에MaxLengthCriteria가 append 메소드로 추가되고 있다.(criteria.append(MaxLengthCriteria..))즉, StoppingCriteriaList는 변수명 그대로 StoppingCriteria들을 담는 리스트이다.
알게 된 내용들을 바탕으로 자식클래스인 CustomStoppingCriteria 를 구현했다.
class CustomStoppingCriteria(StoppingCriteria):
def __init__(self, stop_ids:torch.Tensor):
self.stop_ids = stop_ids[0] # 토큰 id 리스트
self.n = stop_ids.shape[1] # stop sequence의 길이
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
'''
길이가 n인 stop seqeunce와, 마지막으로 생성한 n개의 토큰을 비교하여
모두 같으면 True 반환
'''
should_stop =False
if input_ids.shape[1] > self.n-1:
last_n_ids = input_ids[0][-self.n:] # 마지막으로 생성한 n개의 토큰
for i in range(self.n):
if self.stop_ids[i] != last_n_ids[i]: # stop sequence와 비교
should_stop = False
break
else :
should_stop = True
return should_stop
그리고 아래와 같이 사용했더니 잘 적용이 된다!
# Tokenize My Stop Sequence
stop_sequence = self.tokenizer(" ### 지시문", return_tensors = 'pt').input_ids.to(self.device)
# Make Subclass Instance
stopping_criteria = CustomStoppingCriteria(stop_sequence)
stopping_criteria = StoppingCriteriaList([stopping_criteria]))
# Generate
output_ids = self.model.generate(inputs=input.input_ids,
max_length=150,
stopping_criteria = stopping_criteria
** 만약, stop_sequence로 여러 개를 사용하고 싶다면 코드에 수정이 필요하다.
누군가에게 도움이 되길 'ㅅ'