mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 20:19:22 +01:00
fix eval RecursionError
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13099203
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import Dict, Iterable, List
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Iterable, List, Tuple
|
||||
|
||||
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
|
||||
from rouge import Rouge
|
||||
@@ -55,7 +57,8 @@ class TextGenerationMetric(Metric):
|
||||
def mean(iter: Iterable) -> float:
|
||||
return sum(iter) / len(self.preds)
|
||||
|
||||
rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts)
|
||||
with extend_recursion_limit(preds, tgts):
|
||||
rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts)
|
||||
rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores))
|
||||
rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores))
|
||||
|
||||
@@ -87,3 +90,14 @@ class TextGenerationMetric(Metric):
|
||||
def __setstate__(self, state):
|
||||
self.__init__()
|
||||
self.preds, self.tgts = state
|
||||
|
||||
|
||||
@contextmanager
|
||||
def extend_recursion_limit(preds: Tuple[str], tgts: Tuple[str]):
|
||||
origin_limit = sys.getrecursionlimit()
|
||||
new_limit = max(len(pred)
|
||||
for pred in preds) * max(len(tgt) for tgt in tgts)
|
||||
if new_limit > origin_limit:
|
||||
sys.setrecursionlimit(new_limit)
|
||||
yield
|
||||
sys.setrecursionlimit(origin_limit)
|
||||
|
||||
Reference in New Issue
Block a user