Allow params to be passed to the pipeline's __call__ method

智丞
2022-06-21 20:36:17 +08:00
parent fcc5740238
commit 8ae2e46ad3
4 changed files with 82 additions and 57 deletions
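
Note: the point of this change is that keyword arguments supplied at call time are now routed to all three pipeline stages (preprocess, forward, postprocess) through a new _sanitize_parameters hook, instead of being forwarded to postprocess only. A minimal usage sketch of the new call-time parameters, mirroring the updated tests at the end of this diff (import paths assumed from the modelscope layout):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    classifier = pipeline(
        task=Tasks.zero_shot_classification,
        model='damo/nlp_structbert_zero-shot-classification_chinese-base')
    # labels and template are now per-call parameters, not constructor kwargs
    output = classifier(
        '全新突破 解放军运20版空中加油机曝光',
        candidate_labels=['文化', '体育', '科技', '军事'],
        hypothesis_template='这篇文章的标题是{}')
    print(output['labels'], output['scores'])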


@@ -80,7 +80,7 @@ class Pipeline(ABC):
         self.preprocessor = preprocessor
 
     def __call__(self, input: Union[Input, List[Input]], *args,
-                 **post_kwargs) -> Union[Dict[str, Any], Generator]:
+                 **kwargs) -> Union[Dict[str, Any], Generator]:
         # model provider should leave it as it is
         # modelscope library developer will handle this function
@@ -89,24 +89,41 @@ class Pipeline(ABC):
         if isinstance(input, list):
             output = []
             for ele in input:
-                output.append(self._process_single(ele, *args, **post_kwargs))
+                output.append(self._process_single(ele, *args, **kwargs))
         elif isinstance(input, PyDataset):
-            return self._process_iterator(input, *args, **post_kwargs)
+            return self._process_iterator(input, *args, **kwargs)
         else:
-            output = self._process_single(input, *args, **post_kwargs)
+            output = self._process_single(input, *args, **kwargs)
         return output
 
-    def _process_iterator(self, input: Input, *args, **post_kwargs):
+    def _process_iterator(self, input: Input, *args, **kwargs):
         for ele in input:
-            yield self._process_single(ele, *args, **post_kwargs)
+            yield self._process_single(ele, *args, **kwargs)
 
-    def _process_single(self, input: Input, *args,
-                        **post_kwargs) -> Dict[str, Any]:
-        out = self.preprocess(input)
-        out = self.forward(out)
-        out = self.postprocess(out, **post_kwargs)
+    def _sanitize_parameters(self, **pipeline_parameters):
+        """
+        Sanitize the keyword args passed to '__call__' (and on to
+        '_process_single') into preprocess params, forward params and
+        postprocess params. This is a normal method with a default
+        implementation, so overriding it is optional.
+
+        Returns:
+            Dict[str, Any]: preprocess_params = {}
+            Dict[str, Any]: forward_params = {}
+            Dict[str, Any]: postprocess_params = pipeline_parameters
+        """
+        # raise NotImplementedError("_sanitize_parameters not implemented")
+        return {}, {}, pipeline_parameters
+
+    def _process_single(self, input: Input, *args,
+                        **kwargs) -> Dict[str, Any]:
+        # sanitize the parameters
+        preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
+            **kwargs)
+        out = self.preprocess(input, **preprocess_params)
+        out = self.forward(out, **forward_params)
+        out = self.postprocess(out, **postprocess_params)
         self._check_output(out)
         return out
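
Note: _sanitize_parameters receives the raw call-time kwargs and splits them into three dicts, one per stage; _process_single then threads each dict into its stage. A sketch of how a subclass might override the hook (the subclass and its 'max_length'/'top_k' parameters are hypothetical, for illustration only):

    class MyTaskPipeline(Pipeline):
        # hypothetical subclass, not part of this commit

        def _sanitize_parameters(self, **pipeline_parameters):
            preprocess_params, postprocess_params = {}, {}
            # route a made-up 'max_length' kwarg to the preprocess stage
            if 'max_length' in pipeline_parameters:
                preprocess_params['max_length'] = pipeline_parameters.pop('max_length')
            # route a made-up 'top_k' kwarg to the postprocess stage
            if 'top_k' in pipeline_parameters:
                postprocess_params['top_k'] = pipeline_parameters.pop('top_k')
            # this sketch defines no forward-stage params
            return preprocess_params, {}, postprocess_params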
@@ -126,23 +143,25 @@ class Pipeline(ABC):
             raise ValueError(f'expected output keys are {output_keys}, '
                              f'those {missing_keys} are missing')
 
-    def preprocess(self, inputs: Input) -> Dict[str, Any]:
+    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
         """ Provide default implementation based on preprocess_cfg and user can reimplement it
         """
         assert self.preprocessor is not None, 'preprocess method should be implemented'
         assert not isinstance(self.preprocessor, List),\
             'default implementation does not support using multiple preprocessors.'
-        return self.preprocessor(inputs)
+        return self.preprocessor(inputs, **preprocess_params)
 
-    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
         """ Provide default implementation using self.model and user can reimplement it
         """
         assert self.model is not None, 'forward method should be implemented'
         assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
-        return self.model(inputs)
+        return self.model(inputs, **forward_params)
 
     @abstractmethod
-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def postprocess(self, inputs: Dict[str, Any],
+                    **postprocess_params) -> Dict[str, Any]:
         """ If the current pipeline supports model reuse, common postprocess
         code should be written here.
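
Note: because the default _sanitize_parameters returns ({}, {}, pipeline_parameters), pipelines that do not override it keep the old behavior: every call-time keyword still lands in postprocess, exactly as **post_kwargs did before. The zero-shot classification pipeline is updated below to adopt the new hook.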


@@ -39,18 +39,32 @@ class ZeroShotClassificationPipeline(Pipeline):
         self.entailment_id = 0
         self.contradiction_id = 2
-        self.candidate_labels = kwargs.pop('candidate_labels')
-        self.hypothesis_template = kwargs.pop('hypothesis_template', '{}')
-        self.multi_label = kwargs.pop('multi_label', False)
         if preprocessor is None:
             preprocessor = ZeroShotClassificationPreprocessor(
-                sc_model.model_dir,
-                candidate_labels=self.candidate_labels,
-                hypothesis_template=self.hypothesis_template)
+                sc_model.model_dir)
         super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
 
-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_params = {}
+        postprocess_params = {}
+        if 'candidate_labels' in kwargs:
+            candidate_labels = kwargs.pop('candidate_labels')
+            preprocess_params['candidate_labels'] = candidate_labels
+            postprocess_params['candidate_labels'] = candidate_labels
+        else:
+            raise ValueError('You must include at least one label.')
+        preprocess_params['hypothesis_template'] = kwargs.pop(
+            'hypothesis_template', '{}')
+        postprocess_params['multi_label'] = kwargs.pop('multi_label', False)
+        return preprocess_params, {}, postprocess_params
+
+    def postprocess(self,
+                    inputs: Dict[str, Any],
+                    candidate_labels,
+                    multi_label=False) -> Dict[str, Any]:
         """process the prediction results
 
         Args:
@@ -61,8 +75,7 @@ class ZeroShotClassificationPipeline(Pipeline):
         """
         logits = inputs['logits']
-        if self.multi_label or len(self.candidate_labels) == 1:
+        if multi_label or len(candidate_labels) == 1:
             logits = logits[..., [self.contradiction_id, self.entailment_id]]
             scores = softmax(logits, axis=-1)[..., 1]
         else:
@@ -71,7 +84,7 @@ class ZeroShotClassificationPipeline(Pipeline):
         reversed_index = list(reversed(scores.argsort()))
         result = {
-            'labels': [self.candidate_labels[i] for i in reversed_index],
+            'labels': [candidate_labels[i] for i in reversed_index],
             'scores': [scores[i].item() for i in reversed_index],
         }
         return result
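
Note: moving candidate_labels and multi_label from instance state to method parameters means a single pipeline instance can be reused with a different label set on every call. A sketch, reusing the classifier from the example above:

    # same instance, different labels per call (sketch)
    classifier('全新突破 解放军运20版空中加油机曝光', candidate_labels=['科技', '军事'])
    classifier('全新突破 解放军运20版空中加油机曝光',
               candidate_labels=['体育', '娱乐'],
               multi_label=True)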


@@ -196,12 +196,11 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
         from sofa import SbertTokenizer
         self.model_dir: str = model_dir
         self.sequence_length = kwargs.pop('sequence_length', 512)
-        self.candidate_labels = kwargs.pop('candidate_labels')
-        self.hypothesis_template = kwargs.pop('hypothesis_template', '{}')
         self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
 
     @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
+    def __call__(self, data: str, hypothesis_template: str,
+                 candidate_labels: list) -> Dict[str, Any]:
         """process the raw input data
 
         Args:
@@ -212,8 +211,8 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
         Returns:
             Dict[str, Any]: the preprocessed data
         """
-        pairs = [[data, self.hypothesis_template.format(label)]
-                 for label in self.candidate_labels]
+        pairs = [[data, hypothesis_template.format(label)]
+                 for label in candidate_labels]
         features = self.tokenizer(
             pairs,

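Note: the preprocessor change is symmetric: hypothesis_template and candidate_labels move from constructor kwargs to __call__ arguments. A direct call now looks roughly like this (a sketch; model_dir is assumed to point at a downloaded copy of the model):

    preprocessor = ZeroShotClassificationPreprocessor(model_dir)
    features = preprocessor(
        '全新突破 解放军运20版空中加油机曝光',
        hypothesis_template='这篇文章的标题是{}',
        candidate_labels=['科技', '军事'])

The tests below change accordingly: labels and the template are supplied at call time in every case.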

@@ -13,53 +13,47 @@ from modelscope.utils.constant import Tasks
 class ZeroShotClassificationTest(unittest.TestCase):
     model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
     sentence = '全新突破 解放军运20版空中加油机曝光'
-    candidate_labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
+    labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
+    template = '这篇文章的标题是{}'
 
     def test_run_from_local(self):
         cache_path = snapshot_download(self.model_id)
-        tokenizer = ZeroShotClassificationPreprocessor(
-            cache_path, candidate_labels=self.candidate_labels)
+        tokenizer = ZeroShotClassificationPreprocessor(cache_path)
         model = BertForZeroShotClassification(cache_path, tokenizer=tokenizer)
         pipeline1 = ZeroShotClassificationPipeline(
-            model,
-            preprocessor=tokenizer,
-            candidate_labels=self.candidate_labels,
-        )
+            model, preprocessor=tokenizer)
         pipeline2 = pipeline(
             Tasks.zero_shot_classification,
             model=model,
-            preprocessor=tokenizer,
-            candidate_labels=self.candidate_labels)
-        print(f'sentence: {self.sentence}\n'
-              f'pipeline1:{pipeline1(input=self.sentence)}')
+            preprocessor=tokenizer)
+        print(
+            f'sentence: {self.sentence}\n'
+            f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}'
+        )
         print()
-        print(f'sentence: {self.sentence}\n'
-              f'pipeline2: {pipeline2(input=self.sentence)}')
+        print(
+            f'sentence: {self.sentence}\n'
+            f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}'
+        )
 
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
-        tokenizer = ZeroShotClassificationPreprocessor(
-            model.model_dir, candidate_labels=self.candidate_labels)
+        tokenizer = ZeroShotClassificationPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.zero_shot_classification,
             model=model,
-            preprocessor=tokenizer,
-            candidate_labels=self.candidate_labels)
-        print(pipeline_ins(input=self.sentence))
+            preprocessor=tokenizer)
+        print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))
 
     def test_run_with_model_name(self):
         pipeline_ins = pipeline(
-            task=Tasks.zero_shot_classification,
-            model=self.model_id,
-            candidate_labels=self.candidate_labels)
-        print(pipeline_ins(input=self.sentence))
+            task=Tasks.zero_shot_classification, model=self.model_id)
+        print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))
 
     def test_run_with_default_model(self):
-        pipeline_ins = pipeline(
-            task=Tasks.zero_shot_classification,
-            candidate_labels=self.candidate_labels)
-        print(pipeline_ins(input=self.sentence))
+        pipeline_ins = pipeline(task=Tasks.zero_shot_classification)
+        print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))
 
 if __name__ == '__main__':