[to #42322933] text-error-correction support batch inference

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11398895
Author: klayzhang.zb
Date: 2023-01-12 09:33:56 +08:00
Committed by: wenmeng.zwm
Parent: 78f812dbb6
Commit: f9a01acecf
4 changed files with 42 additions and 19 deletions
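Summary: this change makes the text-error-correction stack batch-aware. The preprocessor now emits fixed-length, left-padded src_tokens plus explicit src_lengths; the model wraps the batch into fairseq's net_input sample, derives the batch size from src_tokens, and collects the 1-best hypothesis per sample; the pipeline postprocess decodes each prediction and returns a list for batched inputs, unwrapping to a plain string when there is only one; and a batch-inference test is added.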

@@ -78,18 +78,16 @@ class BartForTextErrorCorrection(TorchModel):
         """
         import fairseq.utils
-        if len(input['net_input']['src_tokens'].size()) == 1:
-            input['net_input']['src_tokens'] = input['net_input'][
-                'src_tokens'].view(1, -1)
+        batch_size = input['src_tokens'].size(0)
+        input = {'net_input': input}
         if torch.cuda.is_available():
             input = fairseq.utils.move_to_cuda(input, device=self._device)
-        sample = input
         translations = self.task.inference_step(self.generator, self.models,
-                                                sample)
-        # get 1-best List[Tensor]
-        preds = translations[0][0]['tokens']
-        return TextErrorCorrectionOutput(predictions=preds)
+                                                input)
+        batch_preds = []
+        for i in range(batch_size):
+            # get 1-best List[Tensor]
+            batch_preds.append(translations[i][0]['tokens'])
+        return TextErrorCorrectionOutput(predictions=batch_preds)
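For orientation, forward now receives the flat dict built by the preprocessor (third file below) and wraps it into fairseq's sample format itself. A minimal sketch of that input contract; the shapes, lengths, and fill value here are illustrative assumptions, not values from the diff:

import torch

# Hypothetical batched input as produced by the preprocessor:
# 'src_tokens' is [batch_size, max_length] and left-padded,
# 'src_lengths' holds each sequence's unpadded length.
model_input = {
    'src_tokens': torch.full((2, 101), 2, dtype=torch.long),
    'src_lengths': torch.tensor([13, 21]),
}

batch_size = model_input['src_tokens'].size(0)  # 2, as computed in forward
sample = {'net_input': model_input}  # the shape fairseq's inference_step expects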

@@ -78,9 +78,14 @@ class TextErrorCorrectionPipeline(Pipeline):
         """
-        pred_str = self.vocab.string(
-            inputs['predictions'],
-            '@@',
-            extra_symbols_to_ignore={self.vocab.pad()})
-        return {OutputKeys.OUTPUT: ''.join(pred_str.split())}
+        sc_sent = []
+        for sent in inputs['predictions']:
+            pred_str = self.vocab.string(
+                sent, '@@', extra_symbols_to_ignore={self.vocab.pad()})
+            sc_sent.append(''.join(pred_str.split()))
+        # for consistency with the old version, unwrap single-element batches
+        if len(sc_sent) == 1:
+            sc_sent = sc_sent[0]
+        return {OutputKeys.OUTPUT: sc_sent}
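The tail of postprocess keeps backward compatibility: a single sentence still yields a plain string, while a batch yields a list. A tiny sketch of that convention; wrap_output is a hypothetical stand-in for the last three lines, and OutputKeys.OUTPUT is the literal string 'output':

def wrap_output(sc_sent: list) -> dict:
    # Unwrap single-element batches so callers that pass one sentence
    # keep receiving a str rather than a one-item list.
    return {'output': sc_sent[0] if len(sc_sent) == 1 else sc_sent}

print(wrap_output(['这样的话']))                 # {'output': '这样的话'}
print(wrap_output(['这样的话', '遵守交通规则']))  # {'output': ['这样的话', '遵守交通规则']}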

@@ -3,6 +3,8 @@
 import os.path as osp
 from typing import Any, Dict
 
+import torch
+
 from modelscope.metainfo import Preprocessors
 from modelscope.preprocessors.base import Preprocessor
 from modelscope.preprocessors.builder import PREPROCESSORS
@@ -24,6 +26,8 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
         """
         super().__init__(*args, **kwargs)
         self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))
+        self.max_length = 100 + 1  # 1 is eos token
+        self.padding_value = 2
 
     def __call__(self, data: str) -> Dict[str, Any]:
         """process the raw input data
@@ -44,7 +48,12 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
         text = ' '.join([x for x in data])
         inputs = self.vocab.encode_line(
             text, append_eos=True, add_if_not_exist=False)
-        lengths = inputs.size()
-        sample = dict()
-        sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths}
-        return sample
+        lengths = inputs.size()[0]
+        padding = torch.tensor([self.padding_value] *  # noqa: W504
+                               (self.max_length - lengths))
+        inputs = torch.unsqueeze(torch.cat([padding, inputs]), dim=0)
+        lengths = torch.tensor([lengths])
+        out = {'src_tokens': inputs, 'src_lengths': lengths}
+        return out
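Taken together, the preprocessor now produces one fixed-length, left-padded row per sentence, which is what lets rows be stacked into a batch. A minimal standalone sketch of the same logic, assuming a loaded fairseq Dictionary; the constants mirror the diff and preprocess is a hypothetical name:

import torch
from fairseq.data import Dictionary

MAX_LENGTH = 100 + 1   # +1 for the eos token, as in the diff
PADDING_VALUE = 2      # pad id used by this model's dictionary (from the diff)

def preprocess(vocab: Dictionary, data: str) -> dict:
    text = ' '.join(list(data))  # split the sentence into characters
    ids = vocab.encode_line(text, append_eos=True, add_if_not_exist=False)
    length = ids.size(0)
    pad = torch.tensor([PADDING_VALUE] * (MAX_LENGTH - length))
    src_tokens = torch.cat([pad, ids]).unsqueeze(0)  # left-pad -> [1, 101]
    return {'src_tokens': src_tokens, 'src_lengths': torch.tensor([length])}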

@@ -19,6 +19,8 @@ class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck):
         self.model_id = 'damo/nlp_bart_text-error-correction_chinese'
 
     input = '随着中国经济突飞猛近,建造工业与日俱增'
+    input_2 = '这洋的话,下一年的福气来到自己身上。'
+    input_3 = '在拥挤时间,为了让人们尊守交通规律,派至少两个警察或者交通管理者。'
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_direct_download(self):
@@ -34,6 +36,15 @@ class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck):
             f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}'
         )
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_batch(self):
+        run_kwargs = {'batch_size': 2}
+        pipeline_ins = pipeline(
+            task=Tasks.text_error_correction, model=self.model_id)
+        print(
+            'batch: ',
+            pipeline_ins([self.input, self.input_2, self.input_3], **run_kwargs))
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
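A hedged usage sketch of the batch path end to end; it assumes batch_size is accepted as a keyword on the pipeline call, as the test above uses, and that the model can be fetched from the model hub:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the corrector pipeline and run a two-sentence batch.
corrector = pipeline(
    task=Tasks.text_error_correction,
    model='damo/nlp_bart_text-error-correction_chinese')
result = corrector(
    ['随着中国经济突飞猛近,建造工业与日俱增',
     '这洋的话,下一年的福气来到自己身上。'],
    batch_size=2)
print(result['output'])  # two corrected sentences, in input order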