def _process_iterator(self, input: Input, *args, **kwargs):
    """Lazily run the pipeline on each element of an iterable input.

    Args:
        input: Iterable whose elements are processed one at a time.
        *args: Forwarded unchanged to ``_process_single``.
        **kwargs: Forwarded unchanged to ``_process_single``.

    Yields:
        The ``_process_single`` result for each element, in input order.
    """
    for ele in input:
        yield self._process_single(ele, *args, **kwargs)


def _sanitize_parameters(self, **pipeline_parameters):
    """Split call-time keyword arguments into per-stage parameters.

    Subclasses override this hook to route the keyword arguments given to
    ``__call__`` / ``_process_single`` to the stage that consumes them.

    Args:
        **pipeline_parameters: Keyword arguments supplied by the caller.

    Returns:
        A 3-tuple ``(preprocess_params, forward_params, postprocess_params)``
        of dicts. The default implementation sends everything to
        ``postprocess`` and returns ``({}, {}, pipeline_parameters)``.
    """
    return {}, {}, pipeline_parameters


def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
    """Run preprocess, forward and postprocess on a single input.

    Args:
        input: One pipeline input.
        *args: Accepted for signature compatibility; not used here.
        **kwargs: Call-time parameters, split per stage by
            ``_sanitize_parameters``.

    Returns:
        The ``postprocess`` output, after it has been passed to
        ``self._check_output``.
    """
    # Route each keyword argument to the stage that understands it.
    (preprocess_params, forward_params,
     postprocess_params) = self._sanitize_parameters(**kwargs)

    out = self.preprocess(input, **preprocess_params)
    out = self.forward(out, **forward_params)
    out = self.postprocess(out, **postprocess_params)
    self._check_output(out)
    return out


def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
    """Apply the configured preprocessor; subclasses may reimplement this.

    Args:
        inputs: Raw pipeline input.
        **preprocess_params: Extra keyword arguments passed to the
            preprocessor call.

    Returns:
        Whatever ``self.preprocessor(inputs, **preprocess_params)`` returns.

    Raises:
        AssertionError: If ``self.preprocessor`` is ``None`` or is a list
            (the default implementation handles a single preprocessor).
    """
    assert self.preprocessor is not None, 'preprocess method should be implemented'
    assert not isinstance(self.preprocessor, list), \
        'default implementation does not support using multiple preprocessors.'
    return self.preprocessor(inputs, **preprocess_params)


def forward(self, inputs: Dict[str, Any],
            **forward_params) -> Dict[str, Any]:
    """Call ``self.model`` on the inputs; subclasses may reimplement this.

    Args:
        inputs: Output of ``preprocess``.
        **forward_params: Extra keyword arguments passed to the model call.

    Returns:
        Whatever ``self.model(inputs, **forward_params)`` returns.

    Raises:
        AssertionError: If ``self.model`` is ``None`` or
            ``self.has_multiple_models`` is truthy.
    """
    assert self.model is not None, 'forward method should be implemented'
    assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
    return self.model(inputs, **forward_params)
def _sanitize_parameters(self, **kwargs):
    """Route zero-shot call-time arguments to the stages that use them.

    Args:
        **kwargs: Call-time parameters. Recognised keys:
            ``candidate_labels`` (required, non-empty),
            ``hypothesis_template`` (default ``'{}'``) and
            ``multi_label`` (default ``False``). Other keys are ignored.

    Returns:
        A 3-tuple ``(preprocess_params, forward_params, postprocess_params)``:
        the preprocessor receives ``candidate_labels`` and
        ``hypothesis_template``; no forward parameters are produced;
        postprocess receives ``candidate_labels`` and ``multi_label``.

    Raises:
        ValueError: If ``candidate_labels`` is missing, ``None`` or empty.
    """
    candidate_labels = kwargs.get('candidate_labels')
    # Reject an empty collection too: the error message promises at least
    # one label, and later stages index into and take len() of the labels.
    if candidate_labels is None or len(candidate_labels) == 0:
        raise ValueError('You must include at least one label.')

    preprocess_params = {
        'candidate_labels': candidate_labels,
        'hypothesis_template': kwargs.get('hypothesis_template', '{}'),
    }
    postprocess_params = {
        'candidate_labels': candidate_labels,
        'multi_label': kwargs.get('multi_label', False),
    }
    return preprocess_params, {}, postprocess_params
class ZeroShotClassificationTest(unittest.TestCase):
    """Smoke tests for the zero-shot classification pipeline.

    ``candidate_labels`` (and optionally ``hypothesis_template``) are
    supplied on each pipeline call rather than at construction time.
    """
    model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
    sentence = '全新突破 解放军运20版空中加油机曝光'
    labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
    template = '这篇文章的标题是{}'

    def _check_result(self, result):
        """Assert the result has the shape built by the pipeline's postprocess.

        Printing alone can never fail, so verify that labels and scores are
        both present, are paired one-to-one, and that every returned label
        is one of the candidates that were passed in.
        """
        self.assertIn('labels', result)
        self.assertIn('scores', result)
        self.assertEqual(len(result['labels']), len(result['scores']))
        self.assertTrue(set(result['labels']).issubset(self.labels))

    def test_run_from_local(self):
        """Build the pipeline from a downloaded snapshot, two ways."""
        cache_path = snapshot_download(self.model_id)
        tokenizer = ZeroShotClassificationPreprocessor(cache_path)
        model = BertForZeroShotClassification(cache_path, tokenizer=tokenizer)
        pipeline1 = ZeroShotClassificationPipeline(
            model, preprocessor=tokenizer)
        pipeline2 = pipeline(
            Tasks.zero_shot_classification,
            model=model,
            preprocessor=tokenizer)

        result1 = pipeline1(
            input=self.sentence, candidate_labels=self.labels)
        print(f'sentence: {self.sentence}\n' f'pipeline1:{result1}')
        self._check_result(result1)
        print()
        # The second call also exercises a custom hypothesis template.
        result2 = pipeline2(
            self.sentence,
            candidate_labels=self.labels,
            hypothesis_template=self.template)
        print(f'sentence: {self.sentence}\n' f'pipeline2: {result2}')
        self._check_result(result2)

    def test_run_with_model_from_modelhub(self):
        """Build the pipeline from a ``Model`` loaded by model id."""
        model = Model.from_pretrained(self.model_id)
        tokenizer = ZeroShotClassificationPreprocessor(model.model_dir)
        pipeline_ins = pipeline(
            task=Tasks.zero_shot_classification,
            model=model,
            preprocessor=tokenizer)
        result = pipeline_ins(
            input=self.sentence, candidate_labels=self.labels)
        print(result)
        self._check_result(result)

    def test_run_with_model_name(self):
        """Build the pipeline from the model id string alone."""
        pipeline_ins = pipeline(
            task=Tasks.zero_shot_classification, model=self.model_id)
        result = pipeline_ins(
            input=self.sentence, candidate_labels=self.labels)
        print(result)
        self._check_result(result)

    def test_run_with_default_model(self):
        """Build the pipeline from the task name only."""
        pipeline_ins = pipeline(task=Tasks.zero_shot_classification)
        result = pipeline_ins(
            input=self.sentence, candidate_labels=self.labels)
        print(result)
        self._check_result(result)