fix eval RecursionError

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13099203
This commit is contained in:
hemu.zp
2023-06-29 10:47:54 +08:00
parent 88de9f78aa
commit 612f0ebbc4

View File

@@ -1,6 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict, Iterable, List
import sys
from contextlib import contextmanager
from typing import Dict, Iterable, List, Tuple
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
from rouge import Rouge
@@ -55,7 +57,8 @@ class TextGenerationMetric(Metric):
def mean(iter: Iterable) -> float:
return sum(iter) / len(self.preds)
rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts)
with extend_recursion_limit(preds, tgts):
rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts)
rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores))
rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores))
@@ -87,3 +90,14 @@ class TextGenerationMetric(Metric):
def __setstate__(self, state):
self.__init__()
self.preds, self.tgts = state
@contextmanager
def extend_recursion_limit(preds: Tuple[str], tgts: Tuple[str]):
origin_limit = sys.getrecursionlimit()
new_limit = max(len(pred)
for pred in preds) * max(len(tgt) for tgt in tgts)
if new_limit > origin_limit:
sys.setrecursionlimit(new_limit)
yield
sys.setrecursionlimit(origin_limit)