dataset.load_metric('em'), squad, f1, em

Hyun·2022년 8월 1일
0

기타등등

목록 보기
11/11

datasets 패키지에 metrics도 구현되어있다는 것을 알았다.

!pip install datasets

일단 설치를 한 뒤, 구현된 metric들의 목록을 보았다.

from datasets import list_metrics
list_metrics()

['accuracy', 'bertscore', 'bleu', 'bleurt', 'cer', 'chrf', 'code_eval', 'comet', 'competition_math', 'coval', 'cuad', 'exact_match', 'f1', 'frugalscore', 'glue', 'google_bleu', 'indic_glue', 'mae', 'mahalanobis', 'matthews_correlation', 'mauve', 'mean_iou', 'meteor', 'mse', 'pearsonr', 'perplexity', 'poseval', 'precision', 'recall', 'rl_reliability', 'roc_auc', 'rouge', 'sacrebleu', 'sari', 'seqeval', 'spearmanr', 'squad', 'squad_v2', 'super_glue', 'ter', 'trec_eval', 'wer', 'wiki_split', 'xnli', 'xtreme_s', 'angelina-wang/directional_bias_amplification', 'codeparrot/apps_metric', 'cpllab/syntaxgym', 'daiyizheng/valid', 'erntkn/dice_coefficient', 'gorkaartola/my_metric', 'hack/test_metric', 'jordyvl/ece', 'kaggle/ai4code', 'kaggle/amex', 'loubnabnl/apps_metric2', 'lvwerra/bary_score', 'lvwerra/test', 'mathemakitten/harness_sentiment', 'mathemakitten/sentiment', 'mfumanelli/geometric_mean', 'mgfrantz/roc_auc_macro', 'yzha/ctc_eval']

squad

load_metric로 위의 metric중 하나인 squad 불러보자

from datasets import load_metric
metric_squad = load_metric('squad')

위의 설명처럼 f1(em보다 느슨한 평가지표)과 exact_match(em, 완벽히 알맞는가?)을 구할 수 있다.

predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22'}]
references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
squad_metric = datasets.load_metric("squad")
results = squad_metric.compute(predictions=predictions, references=references)
print(results)	# {'exact_match': 100.0, 'f1': 100.0}
f1, em = results['f1'], results['exact_match']

그런데, predictions과 references를 위의 형태(한 데이터당 dictionary형태를 갖춰야한다)에 맞춰야한다는 단점이 있다.
predictions : id, prediction_text의 key를 갖는 dictionary의 list
references : id, answers(answer_start, text를 갖는 dic)의 key를 갖는 dictionary의 list

가볍게 비교할 데이터만 추출했기 때문에 f1과 exact_match를 직접 불러서 사용하였다.

f1

load_metric('f1')
형태 : Metric(name: "f1", features: {'predictions': Value(dtype='int32', id=None), 'references': Value(dtype='int32', id=None)}

f1_metric = datasets.load_metric("f1")
results = f1_metric.compute(references=[0, 1, 0, 1, 0], predictions=[0, 0, 1, 1, 0])
print(results)
f1_metric = datasets.load_metric("f1")
results = f1_metric.compute(references=[0, 1, 0, 1, 0], predictions=[0, 0, 1, 1, 0])
print(results)	#  {'f1': 0.5}

exact_match

load_metric('exact_match')
형태 : Metric(name: "exact_match", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}

exact_match = datasets.load_metric("exact_match")
refs = ["the cat", "theater", "YELLING", "agent007"]
preds = ["cat?", "theater", "yelling", "agent"]
results = exact_match.compute(references=refs, predictions=preds, regexes_to_ignore=["the ", "yell", "YELL"], ignore_case=True, ignore_punctuation=True)
print(round(results["exact_match"], 1))	# 75.0
    

0개의 댓글