[to #42322933][FIX] fix batch postprocessor bug

1. 修复 batch 处理导致的 postprocessor 错误。
2. 修复 SDK 中 preprocessor base.py 里 cfg 的 bug。
3. 为 TextErrorCorrectionPreprocessor 新增 max_length 参数。

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11435946
This commit is contained in:
klayzhang.zb
2023-01-14 14:30:18 +00:00
committed by wenmeng.zwm
parent 01e41f15d2
commit d36a04d106
4 changed files with 18 additions and 16 deletions

View File

@@ -75,17 +75,14 @@ class TextErrorCorrectionPipeline(Pipeline):
'output': '随着中国经济突飞猛进,建造工业与日俱增'
}
"""
sc_sent = []
for sent in inputs['predictions']:
pred_str = self.vocab.string(
sent, '@@', extra_symbols_to_ignore={self.vocab.pad()})
sc_sent.append(''.join(pred_str.split()))
sc_tensor = inputs['predictions']
if isinstance(sc_tensor, list):
sc_tensor = sc_tensor[0]
sc_sent = self.vocab.string(
sc_tensor, '@@', extra_symbols_to_ignore={self.vocab.pad()})
# for consistency with the old version
if len(sc_sent) == 1:
sc_sent = sc_sent[0]
sc_sent = ''.join(sc_sent.split())
return {OutputKeys.OUTPUT: sc_sent}

View File

@@ -306,7 +306,7 @@ class Preprocessor(ABC):
preprocessor.mode = preprocessor_mode
sub_cfg.pop('model_dir', None)
if not hasattr(preprocessor, 'cfg'):
preprocessor.cfg = cfg
preprocessor.cfg = sub_cfg
return preprocessor
def save_pretrained(self,

View File

@@ -17,7 +17,11 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
"""The preprocessor used in text correction task.
"""
def __init__(self, model_dir: str, *args, **kwargs):
def __init__(self,
model_dir: str,
max_length: int = None,
*args,
**kwargs):
from fairseq.data import Dictionary
"""preprocess the data via the vocab file from the `model_dir` path
@@ -26,8 +30,8 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
"""
super().__init__(*args, **kwargs)
self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))
self.max_length = 100 + 1 # 1 is eos token
self.padding_value = 2
self.max_length = max_length + 1 if max_length is not None else 129 # 1 is eos token
self.padding_value = self.vocab.pad()
def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data

View File

@@ -6,7 +6,8 @@ from modelscope.models import Model
from modelscope.models.nlp import BartForTextErrorCorrection
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TextErrorCorrectionPipeline
from modelscope.preprocessors import TextErrorCorrectionPreprocessor
from modelscope.preprocessors import (Preprocessor,
TextErrorCorrectionPreprocessor)
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
@@ -26,7 +27,7 @@ class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck):
def test_run_with_direct_download(self):
cache_path = snapshot_download(self.model_id)
model = BartForTextErrorCorrection(cache_path)
preprocessor = TextErrorCorrectionPreprocessor(cache_path)
preprocessor = Preprocessor.from_pretrained(cache_path)
pipeline1 = TextErrorCorrectionPipeline(model, preprocessor)
pipeline2 = pipeline(
Tasks.text_error_correction,
@@ -48,7 +49,7 @@ class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
preprocessor = TextErrorCorrectionPreprocessor(model.model_dir)
preprocessor = Preprocessor.from_pretrained(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.text_error_correction,
model=model,