Make sure you are passing preprocess_logits_for_metrics=preprocess_logits_for_metrics to Trainer(...):
def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        # Depending on the model and config, logits may contain extra tensors,
        # like past_key_values, but logits always come first
        logits = logits[0]
    return logits.argmax(dim=-1)
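A sketch of where this hooks into the Trainer. The surrounding objects (model, training_args, datasets, compute_metrics) are assumed to exist and are not defined here; only the preprocess_logits_for_metrics keyword is the point:

```
# Hypothetical wiring; model / args / datasets / compute_metrics are placeholders.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    compute_metrics=compute_metrics,
    # Reduce logits to token ids on-GPU, before the evaluation loop
    # accumulates them across steps.
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
```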
nested_concat was causing GPU memory errors: the evaluation loop accumulates the full logits from every eval step. Reducing each logits tensor to predicted token ids with argmax (np.argmax(logits, axis=-1)) shrinks the accumulated output and fixed the problem.
https://github.com/huggingface/transformers/issues/8476#issuecomment-738773626
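A minimal NumPy sketch of why the argmax reduction helps. The shapes below are made up for illustration (real causal-LM eval batches are far larger, which is what triggers the OOM):

```python
import numpy as np

# Hypothetical shapes: 2 sequences of length 16 over a 1000-token vocabulary.
batch, seq_len, vocab = 2, 16, 1000
logits = np.random.rand(batch, seq_len, vocab).astype(np.float32)

# What preprocess_logits_for_metrics returns: predicted token ids only,
# dropping the vocabulary dimension entirely.
preds = np.argmax(logits, axis=-1)

# The tensor that gets accumulated across eval steps shrinks by roughly
# a factor of vocab (scaled by the dtype sizes).
print(preds.shape, logits.nbytes // preds.nbytes)
```

With a 50k vocabulary and long sequences, that per-step reduction is the difference between fitting in GPU memory and not.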