From 612f0ebbc459e3219ae67fcd6c7a6bc13a220bcf Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Thu, 29 Jun 2023 10:47:54 +0800 Subject: [PATCH] fix eval RecursionError Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13099203 --- modelscope/metrics/text_generation_metric.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/modelscope/metrics/text_generation_metric.py b/modelscope/metrics/text_generation_metric.py index 95947d3e..8ad65eaf 100644 --- a/modelscope/metrics/text_generation_metric.py +++ b/modelscope/metrics/text_generation_metric.py @@ -1,6 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Dict, Iterable, List +import sys +from contextlib import contextmanager +from typing import Dict, Iterable, List, Tuple from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu from rouge import Rouge @@ -55,7 +57,8 @@ class TextGenerationMetric(Metric): def mean(iter: Iterable) -> float: return sum(iter) / len(self.preds) - rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts) + with extend_recursion_limit(preds, tgts): + rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts) rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) @@ -87,3 +90,14 @@ class TextGenerationMetric(Metric): def __setstate__(self, state): self.__init__() self.preds, self.tgts = state + + +@contextmanager +def extend_recursion_limit(preds: Tuple[str], tgts: Tuple[str]): + origin_limit = sys.getrecursionlimit() + new_limit = max(len(pred) + for pred in preds) * max(len(tgt) for tgt in tgts) + if new_limit > origin_limit: + sys.setrecursionlimit(new_limit) + yield + sys.setrecursionlimit(origin_limit)