diff --git a/modelscope/pipelines/nlp/text_error_correction_pipeline.py b/modelscope/pipelines/nlp/text_error_correction_pipeline.py index 4abaaca1..39fcdcc1 100644 --- a/modelscope/pipelines/nlp/text_error_correction_pipeline.py +++ b/modelscope/pipelines/nlp/text_error_correction_pipeline.py @@ -75,17 +75,14 @@ class TextErrorCorrectionPipeline(Pipeline): 'output': '随着中国经济突飞猛进,建造工业与日俱增' } - """ - sc_sent = [] - for sent in inputs['predictions']: - pred_str = self.vocab.string( - sent, '@@', extra_symbols_to_ignore={self.vocab.pad()}) - sc_sent.append(''.join(pred_str.split())) + sc_tensor = inputs['predictions'] + if isinstance(sc_tensor, list): + sc_tensor = sc_tensor[0] + sc_sent = self.vocab.string( + sc_tensor, '@@', extra_symbols_to_ignore={self.vocab.pad()}) - # for consistent with old version - if len(sc_sent) == 1: - sc_sent = sc_sent[0] + sc_sent = ''.join(sc_sent.split()) return {OutputKeys.OUTPUT: sc_sent} diff --git a/modelscope/preprocessors/base.py b/modelscope/preprocessors/base.py index d9b2836f..4161c4b1 100644 --- a/modelscope/preprocessors/base.py +++ b/modelscope/preprocessors/base.py @@ -306,7 +306,7 @@ class Preprocessor(ABC): preprocessor.mode = preprocessor_mode sub_cfg.pop('model_dir', None) if not hasattr(preprocessor, 'cfg'): - preprocessor.cfg = cfg + preprocessor.cfg = sub_cfg return preprocessor def save_pretrained(self, diff --git a/modelscope/preprocessors/nlp/text_error_correction.py b/modelscope/preprocessors/nlp/text_error_correction.py index e3a1433d..e13953e3 100644 --- a/modelscope/preprocessors/nlp/text_error_correction.py +++ b/modelscope/preprocessors/nlp/text_error_correction.py @@ -17,7 +17,11 @@ class TextErrorCorrectionPreprocessor(Preprocessor): """The preprocessor used in text correction task. """ - def __init__(self, model_dir: str, *args, **kwargs): + def __init__(self, + model_dir: str, + max_length: int = None, + *args, + **kwargs): from fairseq.data import Dictionary """preprocess the data via the vocab file from the `model_dir` path @@ -26,8 +30,8 @@ class TextErrorCorrectionPreprocessor(Preprocessor): """ super().__init__(*args, **kwargs) self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt')) - self.max_length = 100 + 1 # 1 is eos token - self.padding_value = 2 + self.max_length = max_length + 1 if max_length is not None else 129 # 1 is eos token + self.padding_value = self.vocab.pad() def __call__(self, data: str) -> Dict[str, Any]: """process the raw input data diff --git a/tests/pipelines/test_text_error_correction.py b/tests/pipelines/test_text_error_correction.py index 81d74c8a..332ea2a7 100644 --- a/tests/pipelines/test_text_error_correction.py +++ b/tests/pipelines/test_text_error_correction.py @@ -6,7 +6,8 @@ from modelscope.models import Model from modelscope.models.nlp import BartForTextErrorCorrection from modelscope.pipelines import pipeline from modelscope.pipelines.nlp import TextErrorCorrectionPipeline -from modelscope.preprocessors import TextErrorCorrectionPreprocessor +from modelscope.preprocessors import (Preprocessor, + TextErrorCorrectionPreprocessor) from modelscope.utils.constant import Tasks from modelscope.utils.demo_utils import DemoCompatibilityCheck from modelscope.utils.test_utils import test_level @@ -26,7 +27,7 @@ class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck): def test_run_with_direct_download(self): cache_path = snapshot_download(self.model_id) model = BartForTextErrorCorrection(cache_path) - preprocessor = TextErrorCorrectionPreprocessor(cache_path) + preprocessor = Preprocessor.from_pretrained(cache_path) pipeline1 = TextErrorCorrectionPipeline(model, preprocessor) pipeline2 = pipeline( Tasks.text_error_correction, @@ -48,7 +49,7 @@ class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_from_modelhub(self): model = Model.from_pretrained(self.model_id) - preprocessor = TextErrorCorrectionPreprocessor(model.model_dir) + preprocessor = Preprocessor.from_pretrained(model.model_dir) pipeline_ins = pipeline( task=Tasks.text_error_correction, model=model,