mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
[to #42322933][FIX] fix batch postprocessor bug
1. 修复batch处理带来的postprocessor的错误 2. 修复SDK中preprocessor base.py中cfg的bug。 3. 新增max_length Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11435946
This commit is contained in:
committed by
wenmeng.zwm
parent
01e41f15d2
commit
d36a04d106
@@ -75,17 +75,14 @@ class TextErrorCorrectionPipeline(Pipeline):
|
||||
'output': '随着中国经济突飞猛进,建造工业与日俱增'
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
|
||||
sc_sent = []
|
||||
for sent in inputs['predictions']:
|
||||
pred_str = self.vocab.string(
|
||||
sent, '@@', extra_symbols_to_ignore={self.vocab.pad()})
|
||||
sc_sent.append(''.join(pred_str.split()))
|
||||
sc_tensor = inputs['predictions']
|
||||
if isinstance(sc_tensor, list):
|
||||
sc_tensor = sc_tensor[0]
|
||||
sc_sent = self.vocab.string(
|
||||
sc_tensor, '@@', extra_symbols_to_ignore={self.vocab.pad()})
|
||||
|
||||
# for consistent with old version
|
||||
if len(sc_sent) == 1:
|
||||
sc_sent = sc_sent[0]
|
||||
sc_sent = ''.join(sc_sent.split())
|
||||
|
||||
return {OutputKeys.OUTPUT: sc_sent}
|
||||
|
||||
@@ -306,7 +306,7 @@ class Preprocessor(ABC):
|
||||
preprocessor.mode = preprocessor_mode
|
||||
sub_cfg.pop('model_dir', None)
|
||||
if not hasattr(preprocessor, 'cfg'):
|
||||
preprocessor.cfg = cfg
|
||||
preprocessor.cfg = sub_cfg
|
||||
return preprocessor
|
||||
|
||||
def save_pretrained(self,
|
||||
|
||||
@@ -17,7 +17,11 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
|
||||
"""The preprocessor used in text correction task.
|
||||
"""
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
def __init__(self,
|
||||
model_dir: str,
|
||||
max_length: int = None,
|
||||
*args,
|
||||
**kwargs):
|
||||
from fairseq.data import Dictionary
|
||||
"""preprocess the data via the vocab file from the `model_dir` path
|
||||
|
||||
@@ -26,8 +30,8 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))
|
||||
self.max_length = 100 + 1 # 1 is eos token
|
||||
self.padding_value = 2
|
||||
self.max_length = max_length + 1 if max_length is not None else 129 # 1 is eos token
|
||||
self.padding_value = self.vocab.pad()
|
||||
|
||||
def __call__(self, data: str) -> Dict[str, Any]:
|
||||
"""process the raw input data
|
||||
|
||||
@@ -6,7 +6,8 @@ from modelscope.models import Model
|
||||
from modelscope.models.nlp import BartForTextErrorCorrection
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.pipelines.nlp import TextErrorCorrectionPipeline
|
||||
from modelscope.preprocessors import TextErrorCorrectionPreprocessor
|
||||
from modelscope.preprocessors import (Preprocessor,
|
||||
TextErrorCorrectionPreprocessor)
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
@@ -26,7 +27,7 @@ class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
def test_run_with_direct_download(self):
|
||||
cache_path = snapshot_download(self.model_id)
|
||||
model = BartForTextErrorCorrection(cache_path)
|
||||
preprocessor = TextErrorCorrectionPreprocessor(cache_path)
|
||||
preprocessor = Preprocessor.from_pretrained(cache_path)
|
||||
pipeline1 = TextErrorCorrectionPipeline(model, preprocessor)
|
||||
pipeline2 = pipeline(
|
||||
Tasks.text_error_correction,
|
||||
@@ -48,7 +49,7 @@ class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_with_model_from_modelhub(self):
|
||||
model = Model.from_pretrained(self.model_id)
|
||||
preprocessor = TextErrorCorrectionPreprocessor(model.model_dir)
|
||||
preprocessor = Preprocessor.from_pretrained(model.model_dir)
|
||||
pipeline_ins = pipeline(
|
||||
task=Tasks.text_error_correction,
|
||||
model=model,
|
||||
|
||||
Reference in New Issue
Block a user