Mirror of https://github.com/modelscope/modelscope.git
[to #42322933] text-error-correction: support batch inference
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11398895
Commit: f9a01acecf
Parent: 78f812dbb6
Committed by: wenmeng.zwm
@@ -78,18 +78,16 @@ class BartForTextErrorCorrection(TorchModel):
         """
         import fairseq.utils
 
-        if len(input['net_input']['src_tokens'].size()) == 1:
-            input['net_input']['src_tokens'] = input['net_input'][
-                'src_tokens'].view(1, -1)
+        batch_size = input['src_tokens'].size(0)
+        input = {'net_input': input}
 
         if torch.cuda.is_available():
             input = fairseq.utils.move_to_cuda(input, device=self._device)
 
-        sample = input
-
         translations = self.task.inference_step(self.generator, self.models,
-                                                sample)
-
-        # get 1-best List[Tensor]
-        preds = translations[0][0]['tokens']
-        return TextErrorCorrectionOutput(predictions=preds)
+                                                input)
+        batch_preds = []
+        for i in range(batch_size):
+            # get 1-best List[Tensor]
+            batch_preds.append(translations[i][0]['tokens'])
+        return TextErrorCorrectionOutput(predictions=batch_preds)
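fairseq's inference_step returns one n-best hypothesis list per batch element, each hypothesis a dict carrying a 'tokens' tensor, so translations[i][0]['tokens'] picks the best candidate for sentence i. A minimal standalone sketch of that indexing (not ModelScope source; token ids are made up):

    import torch

    batch_size = 2
    # one n-best list per input sentence, best hypothesis first
    translations = [
        [{'tokens': torch.tensor([5, 9, 2])}, {'tokens': torch.tensor([5, 8, 2])}],
        [{'tokens': torch.tensor([7, 4, 2])}],
    ]

    batch_preds = []
    for i in range(batch_size):
        # hypothesis 0 is the highest-scoring (1-best) candidate
        batch_preds.append(translations[i][0]['tokens'])
    print(batch_preds)  # [tensor([5, 9, 2]), tensor([7, 4, 2])]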
@@ -78,9 +78,14 @@ class TextErrorCorrectionPipeline(Pipeline):
         """
 
-        pred_str = self.vocab.string(
-            inputs['predictions'],
-            '@@',
-            extra_symbols_to_ignore={self.vocab.pad()})
-
-        return {OutputKeys.OUTPUT: ''.join(pred_str.split())}
+        sc_sent = []
+        for sent in inputs['predictions']:
+            pred_str = self.vocab.string(
+                sent, '@@', extra_symbols_to_ignore={self.vocab.pad()})
+            sc_sent.append(''.join(pred_str.split()))
+
+        # for consistency with the old single-input version
+        if len(sc_sent) == 1:
+            sc_sent = sc_sent[0]
+
+        return {OutputKeys.OUTPUT: sc_sent}
 
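In postprocess, vocab.string() detokenizes each 1-best id tensor into a space-separated token string and, given the '@@' BPE symbol, merges subword pieces; ''.join(pred_str.split()) then drops the spaces, which carry no meaning between Chinese characters. A self-contained sketch of that string handling (the decoded strings are invented stand-ins for vocab.string output):

    def postprocess_one(pred_str: str) -> str:
        # pred_str stands in for self.vocab.string(sent, '@@', ...)
        return ''.join(pred_str.split())

    # batch of decoded strings, one per input sentence
    decoded = ['随 着 中 国 经 济 突 飞 猛 进', '这 样 的 话']
    sc_sent = [postprocess_one(s) for s in decoded]
    if len(sc_sent) == 1:      # keep the old single-input return shape
        sc_sent = sc_sent[0]
    print(sc_sent)  # ['随着中国经济突飞猛进', '这样的话']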
@@ -3,6 +3,8 @@
 import os.path as osp
 from typing import Any, Dict
 
+import torch
+
 from modelscope.metainfo import Preprocessors
 from modelscope.preprocessors.base import Preprocessor
 from modelscope.preprocessors.builder import PREPROCESSORS
@@ -24,6 +26,8 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
         """
         super().__init__(*args, **kwargs)
         self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt'))
+        self.max_length = 100 + 1  # +1 for the eos token
+        self.padding_value = 2
 
     def __call__(self, data: str) -> Dict[str, Any]:
         """process the raw input data
@@ -44,7 +48,12 @@ class TextErrorCorrectionPreprocessor(Preprocessor):
         text = ' '.join([x for x in data])
         inputs = self.vocab.encode_line(
             text, append_eos=True, add_if_not_exist=False)
-        lengths = inputs.size()
-        sample = dict()
-        sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths}
-        return sample
+        lengths = inputs.size()[0]
+
+        padding = torch.tensor([self.padding_value] *  # noqa: W504
+                               (self.max_length - lengths))
+        inputs = torch.unsqueeze(torch.cat([padding, inputs]), dim=0)
+        lengths = torch.tensor([lengths])
+        out = {'src_tokens': inputs, 'src_lengths': lengths}
+
+        return out
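The preprocessor now left-pads every encoded sentence with the pad id (2 in the fairseq dictionary) up to a fixed max_length, returning a [1, max_length] row that a batch collator can concatenate directly into a [batch_size, max_length] src_tokens tensor. A self-contained sketch of the padding scheme (fake token ids; the real ones come from encode_line):

    import torch

    max_length = 100 + 1   # +1 for the eos token
    padding_value = 2      # pad id in the fairseq dictionary

    def pad_one(inputs: torch.Tensor) -> dict:
        length = inputs.size()[0]
        # left-pad so every sample becomes a fixed-width [1, max_length] row
        padding = torch.tensor([padding_value] * (max_length - length))
        row = torch.unsqueeze(torch.cat([padding, inputs]), dim=0)
        return {'src_tokens': row, 'src_lengths': torch.tensor([length])}

    a = pad_one(torch.tensor([11, 12, 13, 2]))  # fake ids: sentence + eos
    b = pad_one(torch.tensor([21, 22, 2]))
    batch = torch.cat([a['src_tokens'], b['src_tokens']])
    print(batch.shape)  # torch.Size([2, 101])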
@@ -19,6 +19,8 @@ class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck):
         self.model_id = 'damo/nlp_bart_text-error-correction_chinese'
 
         self.input = '随着中国经济突飞猛近,建造工业与日俱增'
+        self.input_2 = '这洋的话,下一年的福气来到自己身上。'
+        self.input_3 = '在拥挤时间,为了让人们尊守交通规律,派至少两个警察或者交通管理者。'
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_direct_download(self):
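The new test inputs are kept verbatim: like self.input, they appear to contain deliberate character-level errors for the model to correct (e.g. 这洋 for 这样 "in that case", 尊守 for 遵守 "obey").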
@@ -34,6 +36,15 @@ class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck):
             f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}'
         )
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_batch(self):
+        run_kwargs = {'batch_size': 2}
+        pipeline_ins = pipeline(
+            task=Tasks.text_error_correction, model=self.model_id)
+        print(
+            'batch: ',
+            pipeline_ins([self.input, self.input_2, self.input_3], run_kwargs))
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
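For reference, a hedged usage sketch mirroring the new batch test (assumes a working ModelScope install; the model id and the run_kwargs calling convention are taken straight from the test above):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    corrector = pipeline(
        task=Tasks.text_error_correction,
        model='damo/nlp_bart_text-error-correction_chinese')
    run_kwargs = {'batch_size': 2}
    sentences = ['随着中国经济突飞猛近,建造工业与日俱增',
                 '这洋的话,下一年的福气来到自己身上。']
    # expect a dict keyed by OutputKeys.OUTPUT with one corrected
    # sentence per input
    print(corrector(sentences, run_kwargs))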