mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
allow params pass to pipeline's __call__ method
This commit is contained in:
@@ -80,7 +80,7 @@ class Pipeline(ABC):
|
||||
self.preprocessor = preprocessor
|
||||
|
||||
def __call__(self, input: Union[Input, List[Input]], *args,
|
||||
**post_kwargs) -> Union[Dict[str, Any], Generator]:
|
||||
**kwargs) -> Union[Dict[str, Any], Generator]:
|
||||
# model provider should leave it as it is
|
||||
# modelscope library developer will handle this function
|
||||
|
||||
@@ -89,24 +89,41 @@ class Pipeline(ABC):
|
||||
if isinstance(input, list):
|
||||
output = []
|
||||
for ele in input:
|
||||
output.append(self._process_single(ele, *args, **post_kwargs))
|
||||
output.append(self._process_single(ele, *args, **kwargs))
|
||||
|
||||
elif isinstance(input, PyDataset):
|
||||
return self._process_iterator(input, *args, **post_kwargs)
|
||||
return self._process_iterator(input, *args, **kwargs)
|
||||
|
||||
else:
|
||||
output = self._process_single(input, *args, **post_kwargs)
|
||||
output = self._process_single(input, *args, **kwargs)
|
||||
return output
|
||||
|
||||
def _process_iterator(self, input: Input, *args, **post_kwargs):
|
||||
def _process_iterator(self, input: Input, *args, **kwargs):
|
||||
for ele in input:
|
||||
yield self._process_single(ele, *args, **post_kwargs)
|
||||
yield self._process_single(ele, *args, **kwargs)
|
||||
|
||||
def _process_single(self, input: Input, *args,
|
||||
**post_kwargs) -> Dict[str, Any]:
|
||||
out = self.preprocess(input)
|
||||
out = self.forward(out)
|
||||
out = self.postprocess(out, **post_kwargs)
|
||||
def _sanitize_parameters(self, **pipeline_parameters):
|
||||
"""
|
||||
this method should sanitize the keyword args to preprocessor params,
|
||||
forward params and postprocess params on '__call__' or '_process_single' method
|
||||
considering to be a normal classmethod with default implementation / output
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: preprocess_params = {}
|
||||
Dict[str, str]: forward_params = {}
|
||||
Dict[str, str]: postprocess_params = pipeline_parameters
|
||||
"""
|
||||
# raise NotImplementedError("_sanitize_parameters not implemented")
|
||||
return {}, {}, pipeline_parameters
|
||||
|
||||
def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
# sanitize the parameters
|
||||
preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
|
||||
**kwargs)
|
||||
out = self.preprocess(input, **preprocess_params)
|
||||
out = self.forward(out, **forward_params)
|
||||
out = self.postprocess(out, **postprocess_params)
|
||||
self._check_output(out)
|
||||
return out
|
||||
|
||||
@@ -126,23 +143,25 @@ class Pipeline(ABC):
|
||||
raise ValueError(f'expected output keys are {output_keys}, '
|
||||
f'those {missing_keys} are missing')
|
||||
|
||||
def preprocess(self, inputs: Input) -> Dict[str, Any]:
|
||||
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
|
||||
""" Provide default implementation based on preprocess_cfg and user can reimplement it
|
||||
"""
|
||||
assert self.preprocessor is not None, 'preprocess method should be implemented'
|
||||
assert not isinstance(self.preprocessor, List),\
|
||||
'default implementation does not support using multiple preprocessors.'
|
||||
return self.preprocessor(inputs)
|
||||
return self.preprocessor(inputs, **preprocess_params)
|
||||
|
||||
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
""" Provide default implementation using self.model and user can reimplement it
|
||||
"""
|
||||
assert self.model is not None, 'forward method should be implemented'
|
||||
assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
|
||||
return self.model(inputs)
|
||||
return self.model(inputs, **forward_params)
|
||||
|
||||
@abstractmethod
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def postprocess(self, inputs: Dict[str, Any],
|
||||
**postprocess_params) -> Dict[str, Any]:
|
||||
""" If current pipeline support model reuse, common postprocess
|
||||
code should be write here.
|
||||
|
||||
|
||||
@@ -39,18 +39,32 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
|
||||
self.entailment_id = 0
|
||||
self.contradiction_id = 2
|
||||
self.candidate_labels = kwargs.pop('candidate_labels')
|
||||
self.hypothesis_template = kwargs.pop('hypothesis_template', '{}')
|
||||
self.multi_label = kwargs.pop('multi_label', False)
|
||||
|
||||
if preprocessor is None:
|
||||
preprocessor = ZeroShotClassificationPreprocessor(
|
||||
sc_model.model_dir,
|
||||
candidate_labels=self.candidate_labels,
|
||||
hypothesis_template=self.hypothesis_template)
|
||||
sc_model.model_dir)
|
||||
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_params = {}
|
||||
postprocess_params = {}
|
||||
|
||||
if 'candidate_labels' in kwargs:
|
||||
candidate_labels = kwargs.pop('candidate_labels')
|
||||
preprocess_params['candidate_labels'] = candidate_labels
|
||||
postprocess_params['candidate_labels'] = candidate_labels
|
||||
else:
|
||||
raise ValueError('You must include at least one label.')
|
||||
preprocess_params['hypothesis_template'] = kwargs.pop(
|
||||
'hypothesis_template', '{}')
|
||||
|
||||
postprocess_params['multi_label'] = kwargs.pop('multi_label', False)
|
||||
return preprocess_params, {}, postprocess_params
|
||||
|
||||
def postprocess(self,
|
||||
inputs: Dict[str, Any],
|
||||
candidate_labels,
|
||||
multi_label=False) -> Dict[str, Any]:
|
||||
"""process the prediction results
|
||||
|
||||
Args:
|
||||
@@ -61,8 +75,7 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
"""
|
||||
|
||||
logits = inputs['logits']
|
||||
|
||||
if self.multi_label or len(self.candidate_labels) == 1:
|
||||
if multi_label or len(candidate_labels) == 1:
|
||||
logits = logits[..., [self.contradiction_id, self.entailment_id]]
|
||||
scores = softmax(logits, axis=-1)[..., 1]
|
||||
else:
|
||||
@@ -71,7 +84,7 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
|
||||
reversed_index = list(reversed(scores.argsort()))
|
||||
result = {
|
||||
'labels': [self.candidate_labels[i] for i in reversed_index],
|
||||
'labels': [candidate_labels[i] for i in reversed_index],
|
||||
'scores': [scores[i].item() for i in reversed_index],
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -196,12 +196,11 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
|
||||
from sofa import SbertTokenizer
|
||||
self.model_dir: str = model_dir
|
||||
self.sequence_length = kwargs.pop('sequence_length', 512)
|
||||
self.candidate_labels = kwargs.pop('candidate_labels')
|
||||
self.hypothesis_template = kwargs.pop('hypothesis_template', '{}')
|
||||
self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
|
||||
|
||||
@type_assert(object, str)
|
||||
def __call__(self, data: str) -> Dict[str, Any]:
|
||||
def __call__(self, data: str, hypothesis_template: str,
|
||||
candidate_labels: list) -> Dict[str, Any]:
|
||||
"""process the raw input data
|
||||
|
||||
Args:
|
||||
@@ -212,8 +211,8 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
|
||||
Returns:
|
||||
Dict[str, Any]: the preprocessed data
|
||||
"""
|
||||
pairs = [[data, self.hypothesis_template.format(label)]
|
||||
for label in self.candidate_labels]
|
||||
pairs = [[data, hypothesis_template.format(label)]
|
||||
for label in candidate_labels]
|
||||
|
||||
features = self.tokenizer(
|
||||
pairs,
|
||||
|
||||
@@ -13,53 +13,47 @@ from modelscope.utils.constant import Tasks
|
||||
class ZeroShotClassificationTest(unittest.TestCase):
|
||||
model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
|
||||
sentence = '全新突破 解放军运20版空中加油机曝光'
|
||||
candidate_labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
|
||||
labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
|
||||
template = '这篇文章的标题是{}'
|
||||
|
||||
def test_run_from_local(self):
|
||||
cache_path = snapshot_download(self.model_id)
|
||||
tokenizer = ZeroShotClassificationPreprocessor(
|
||||
cache_path, candidate_labels=self.candidate_labels)
|
||||
tokenizer = ZeroShotClassificationPreprocessor(cache_path)
|
||||
model = BertForZeroShotClassification(cache_path, tokenizer=tokenizer)
|
||||
pipeline1 = ZeroShotClassificationPipeline(
|
||||
model,
|
||||
preprocessor=tokenizer,
|
||||
candidate_labels=self.candidate_labels,
|
||||
)
|
||||
model, preprocessor=tokenizer)
|
||||
pipeline2 = pipeline(
|
||||
Tasks.zero_shot_classification,
|
||||
model=model,
|
||||
preprocessor=tokenizer,
|
||||
candidate_labels=self.candidate_labels)
|
||||
preprocessor=tokenizer)
|
||||
|
||||
print(f'sentence: {self.sentence}\n'
|
||||
f'pipeline1:{pipeline1(input=self.sentence)}')
|
||||
print(
|
||||
f'sentence: {self.sentence}\n'
|
||||
f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}'
|
||||
)
|
||||
print()
|
||||
print(f'sentence: {self.sentence}\n'
|
||||
f'pipeline2: {pipeline2(input=self.sentence)}')
|
||||
print(
|
||||
f'sentence: {self.sentence}\n'
|
||||
f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}'
|
||||
)
|
||||
|
||||
def test_run_with_model_from_modelhub(self):
|
||||
model = Model.from_pretrained(self.model_id)
|
||||
tokenizer = ZeroShotClassificationPreprocessor(
|
||||
model.model_dir, candidate_labels=self.candidate_labels)
|
||||
tokenizer = ZeroShotClassificationPreprocessor(model.model_dir)
|
||||
pipeline_ins = pipeline(
|
||||
task=Tasks.zero_shot_classification,
|
||||
model=model,
|
||||
preprocessor=tokenizer,
|
||||
candidate_labels=self.candidate_labels)
|
||||
print(pipeline_ins(input=self.sentence))
|
||||
preprocessor=tokenizer)
|
||||
print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))
|
||||
|
||||
def test_run_with_model_name(self):
|
||||
pipeline_ins = pipeline(
|
||||
task=Tasks.zero_shot_classification,
|
||||
model=self.model_id,
|
||||
candidate_labels=self.candidate_labels)
|
||||
print(pipeline_ins(input=self.sentence))
|
||||
task=Tasks.zero_shot_classification, model=self.model_id)
|
||||
print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))
|
||||
|
||||
def test_run_with_default_model(self):
|
||||
pipeline_ins = pipeline(
|
||||
task=Tasks.zero_shot_classification,
|
||||
candidate_labels=self.candidate_labels)
|
||||
print(pipeline_ins(input=self.sentence))
|
||||
pipeline_ins = pipeline(task=Tasks.zero_shot_classification)
|
||||
print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user