evaluation 에서 OOM 에러

AFL·2023년 12월 7일
0

Trainer(...) 에 preprocess_logits_for_metrics=preprocess_logits_for_metrics 를 넘겨주었는지 확인하자.

def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            # Depending on the model and config, logits may contain extra tensors,
            # like past_key_values, but logits always come first
            logits = logits[0]
        return logits.argmax(dim=-1)

nested_concat 가 GPU memory errors 를 일으켰다. np.argmax(logits, axis=-1) 를 통해 output logit vector 차원을 줄여서 문제 해결

https://github.com/huggingface/transformers/issues/8476#issuecomment-738773626

profile
공부해서 남주자

0개의 댓글

관련 채용 정보