Files
modelscope/examples/pytorch/chatglm6b/text_generation_metric.py
tastelikefeet 45cf0035f4 fix chatglm2 evaluation error: hypothesis empty (#348)
* fix evaluation error: hypothesis empty

* fix pipeline

* fix bug
2023-07-03 23:16:38 +08:00

86 lines
2.8 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict, Iterable, List
import jieba
import numpy as np
from nltk.translate.bleu_score import (SmoothingFunction, corpus_bleu,
sentence_bleu)
from rouge import Rouge
from modelscope.metainfo import Metrics
from modelscope.metrics.base import Metric
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.utils.chinese_utils import rebuild_chinese_str
from modelscope.utils.registry import default_group
@METRICS.register_module(group_key=default_group, module_name='chatglm')
class TextGenerationMetric(Metric):
    """Collect predicted/reference text pairs and score them.

    Pairs are accumulated with `add` (or `merge`) and scored in `evaluate`,
    which reports the mean ROUGE-1/2/L F-measure over jieba-segmented text
    and the mean smoothed sentence BLEU over individual characters, each
    scaled to the 0-100 range.

    Args:
        target_text: Key of the reference texts in the `inputs` dict
            passed to `add`.
        pred_text: Key of the generated texts in the `outputs` dict
            passed to `add`.
    """

    def __init__(self, target_text='tgts', pred_text='preds'):
        self.preds: List[str] = []
        self.tgts: List[str] = []
        self.rouge = Rouge()
        self.target_text = target_text
        self.pred_text = pred_text

    def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]):
        """Append one batch of references and predictions.

        Args:
            outputs: Mapping holding the generated texts under
                `self.pred_text`.
            inputs: Mapping holding the reference texts under
                `self.target_text`.

        Raises:
            KeyError: If either mapping lacks the configured key.
        """
        # Resolve both keys before mutating anything so a missing key
        # cannot leave the two lists out of step.
        ground_truths = inputs[self.target_text]
        eval_results = outputs[self.pred_text]
        self.tgts.extend(ground_truths)
        self.preds.extend(eval_results)

    def _check(self, pred: str, tgt: str) -> bool:
        """Return True when both texts contain more than spaces and dots."""

        def remove_useless(string: str) -> str:
            return string.replace(' ', '').replace('.', '')

        return bool(remove_useless(pred)) and bool(remove_useless(tgt))

    def evaluate(self) -> Dict[str, float]:
        """Score every accumulated pair and return the per-metric means.

        A pair whose prediction or reference holds nothing but spaces and
        dots is not passed to ROUGE, which rejects such text with a
        ValueError; it is counted as 0.0 for all three ROUGE metrics
        instead. BLEU is still computed for every pair.

        Returns:
            Dict with keys 'rouge-1', 'rouge-2', 'rouge-l' and 'bleu-4'.
            Each value is the mean of the per-pair scores, which are
            scaled by 100 and rounded to 4 decimals before averaging.
            When no pairs were added the means are taken over empty
            lists, for which numpy yields nan.
        """
        preds, labels = self.preds, self.tgts
        if isinstance(preds, tuple):
            preds = preds[0]

        rouge_keys = ('rouge-1', 'rouge-2', 'rouge-l')
        score_dict = {key: [] for key in rouge_keys + ('bleu-4', )}
        # Loop-invariant: one smoothing callable serves every pair.
        smoothing = SmoothingFunction().method3

        # zip stops at the shorter list, so unmatched trailing items are
        # ignored.
        for pred, label in zip(preds, labels):
            if self._check(pred, label):
                # ROUGE tokenizes on whitespace, so join the jieba tokens
                # with single spaces.
                hypothesis = ' '.join(jieba.cut(pred))
                reference = ' '.join(jieba.cut(label))
                result = self.rouge.get_scores(hypothesis, reference)[0]
                for k, v in result.items():
                    score_dict[k].append(round(v['f'] * 100, 4))
            else:
                for k in rouge_keys:
                    score_dict[k].append(0.0)

            # BLEU is computed over characters, not jieba tokens.
            bleu_score = sentence_bleu([list(label)],
                                       list(pred),
                                       smoothing_function=smoothing)
            score_dict['bleu-4'].append(round(bleu_score * 100, 4))

        return {k: float(np.mean(v)) for k, v in score_dict.items()}

    def merge(self, other: 'TextGenerationMetric'):
        """Append the pairs accumulated by `other` to this instance."""
        self.preds.extend(other.preds)
        self.tgts.extend(other.tgts)

    def __getstate__(self):
        """Return only the accumulated texts as the pickled state."""
        return self.preds, self.tgts

    def __setstate__(self, state):
        """Rebuild the instance, then restore the accumulated texts.

        NOTE: `__init__` is re-run with its defaults, so custom
        `target_text` / `pred_text` values are not restored.
        """
        self.__init__()
        self.preds, self.tgts = state