From f9a01acecff6145d178e9250d2da2d51f65926dd Mon Sep 17 00:00:00 2001 From: "klayzhang.zb" Date: Thu, 12 Jan 2023 09:33:56 +0800 Subject: [PATCH] [to #42322933] text-error-correction support batch inference Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11398895 --- .../models/nlp/bart/text_error_correction.py | 18 ++++++++---------- .../nlp/text_error_correction_pipeline.py | 15 ++++++++++----- .../preprocessors/nlp/text_error_correction.py | 17 +++++++++++++---- tests/pipelines/test_text_error_correction.py | 11 +++++++++++ 4 files changed, 42 insertions(+), 19 deletions(-) diff --git a/modelscope/models/nlp/bart/text_error_correction.py b/modelscope/models/nlp/bart/text_error_correction.py index ab765190..9ff619f1 100644 --- a/modelscope/models/nlp/bart/text_error_correction.py +++ b/modelscope/models/nlp/bart/text_error_correction.py @@ -78,18 +78,16 @@ class BartForTextErrorCorrection(TorchModel): """ import fairseq.utils - if len(input['net_input']['src_tokens'].size()) == 1: - input['net_input']['src_tokens'] = input['net_input'][ - 'src_tokens'].view(1, -1) + batch_size = input['src_tokens'].size(0) + input = {'net_input': input} if torch.cuda.is_available(): input = fairseq.utils.move_to_cuda(input, device=self._device) - sample = input - translations = self.task.inference_step(self.generator, self.models, - sample) - - # get 1-best List[Tensor] - preds = translations[0][0]['tokens'] - return TextErrorCorrectionOutput(predictions=preds) + input) + batch_preds = [] + for i in range(batch_size): + # get 1-best List[Tensor] + batch_preds.append(translations[i][0]['tokens']) + return TextErrorCorrectionOutput(predictions=batch_preds) diff --git a/modelscope/pipelines/nlp/text_error_correction_pipeline.py b/modelscope/pipelines/nlp/text_error_correction_pipeline.py index 1e6d525a..4abaaca1 100644 --- a/modelscope/pipelines/nlp/text_error_correction_pipeline.py +++ b/modelscope/pipelines/nlp/text_error_correction_pipeline.py @@ -78,9 +78,14 @@ class TextErrorCorrectionPipeline(Pipeline): """ - pred_str = self.vocab.string( - inputs['predictions'], - '@@', - extra_symbols_to_ignore={self.vocab.pad()}) + sc_sent = [] + for sent in inputs['predictions']: + pred_str = self.vocab.string( + sent, '@@', extra_symbols_to_ignore={self.vocab.pad()}) + sc_sent.append(''.join(pred_str.split())) - return {OutputKeys.OUTPUT: ''.join(pred_str.split())} + # for consistent with old version + if len(sc_sent) == 1: + sc_sent = sc_sent[0] + + return {OutputKeys.OUTPUT: sc_sent} diff --git a/modelscope/preprocessors/nlp/text_error_correction.py b/modelscope/preprocessors/nlp/text_error_correction.py index 357a946f..e3a1433d 100644 --- a/modelscope/preprocessors/nlp/text_error_correction.py +++ b/modelscope/preprocessors/nlp/text_error_correction.py @@ -3,6 +3,8 @@ import os.path as osp from typing import Any, Dict +import torch + from modelscope.metainfo import Preprocessors from modelscope.preprocessors.base import Preprocessor from modelscope.preprocessors.builder import PREPROCESSORS @@ -24,6 +26,8 @@ class TextErrorCorrectionPreprocessor(Preprocessor): """ super().__init__(*args, **kwargs) self.vocab = Dictionary.load(osp.join(model_dir, 'dict.src.txt')) + self.max_length = 100 + 1 # 1 is eos token + self.padding_value = 2 def __call__(self, data: str) -> Dict[str, Any]: """process the raw input data @@ -44,7 +48,12 @@ class TextErrorCorrectionPreprocessor(Preprocessor): text = ' '.join([x for x in data]) inputs = self.vocab.encode_line( text, append_eos=True, add_if_not_exist=False) - lengths = inputs.size() - sample = dict() - sample['net_input'] = {'src_tokens': inputs, 'src_lengths': lengths} - return sample + lengths = inputs.size()[0] + + padding = torch.tensor([self.padding_value] * # noqa: W504 + (self.max_length - lengths)) + inputs = torch.unsqueeze(torch.cat([padding, inputs]), dim=0) + lengths = torch.tensor([lengths]) + out = {'src_tokens': inputs, 'src_lengths': lengths} + + return out diff --git a/tests/pipelines/test_text_error_correction.py b/tests/pipelines/test_text_error_correction.py index a714d3d0..81d74c8a 100644 --- a/tests/pipelines/test_text_error_correction.py +++ b/tests/pipelines/test_text_error_correction.py @@ -19,6 +19,8 @@ class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck): self.model_id = 'damo/nlp_bart_text-error-correction_chinese' input = '随着中国经济突飞猛近,建造工业与日俱增' + input_2 = '这洋的话,下一年的福气来到自己身上。' + input_3 = '在拥挤时间,为了让人们尊守交通规律,派至少两个警察或者交通管理者。' @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_direct_download(self): @@ -34,6 +36,15 @@ class TextErrorCorrectionTest(unittest.TestCase, DemoCompatibilityCheck): f'pipeline1: {pipeline1(self.input)}\npipeline2: {pipeline2(self.input)}' ) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_batch(self): + run_kwargs = {'batch_size': 2} + pipeline_ins = pipeline( + task=Tasks.text_error_correction, model=self.model_id) + print( + 'batch: ', + pipeline_ins([self.input, self.input_2, self.input_3], run_kwargs)) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_from_modelhub(self): model = Model.from_pretrained(self.model_id)