mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
update sequence classfication outpus
This commit is contained in:
@@ -32,24 +32,25 @@ class NLIPipeline(Pipeline):
|
||||
"""
|
||||
assert isinstance(model, str) or isinstance(model, SbertForNLI), \
|
||||
'model must be a single str or SbertForNLI'
|
||||
sc_model = model if isinstance(
|
||||
model = model if isinstance(
|
||||
model, SbertForNLI) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
preprocessor = NLIPreprocessor(
|
||||
sc_model.model_dir,
|
||||
model.model_dir,
|
||||
first_sequence=first_sequence,
|
||||
second_sequence=second_sequence)
|
||||
sc_model.eval()
|
||||
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
|
||||
assert len(sc_model.id2label) > 0
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
assert len(model.id2label) > 0
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
with torch.no_grad():
|
||||
return super().forward(inputs, **forward_params)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any],
|
||||
**postprocess_params) -> Dict[str, str]:
|
||||
def postprocess(self,
|
||||
inputs: Dict[str, Any],
|
||||
topk: int = 5) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
Args:
|
||||
@@ -59,30 +60,13 @@ class NLIPipeline(Pipeline):
|
||||
Dict[str, str]: the prediction results
|
||||
"""
|
||||
|
||||
probs = inputs['probabilities']
|
||||
logits = inputs['logits']
|
||||
predictions = np.argsort(-probs, axis=-1)
|
||||
preds = predictions[0]
|
||||
b = 0
|
||||
new_result = list()
|
||||
for pred in preds:
|
||||
new_result.append({
|
||||
'pred': self.model.id2label[pred],
|
||||
'prob': float(probs[b][pred]),
|
||||
'logit': float(logits[b][pred])
|
||||
})
|
||||
new_results = list()
|
||||
new_results.append({
|
||||
'id':
|
||||
inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
|
||||
'output':
|
||||
new_result,
|
||||
'predictions':
|
||||
new_result[0]['pred'],
|
||||
'probabilities':
|
||||
','.join([str(t) for t in inputs['probabilities'][b]]),
|
||||
'logits':
|
||||
','.join([str(t) for t in inputs['logits'][b]])
|
||||
})
|
||||
probs = inputs['probabilities'][0]
|
||||
num_classes = probs.shape[0]
|
||||
topk = min(topk, num_classes)
|
||||
top_indices = np.argpartition(probs, -topk)[-topk:]
|
||||
cls_ids = top_indices[np.argsort(probs[top_indices])]
|
||||
probs = probs[cls_ids].tolist()
|
||||
|
||||
return new_results[0]
|
||||
cls_names = [self.model.id2label[cid] for cid in cls_ids]
|
||||
|
||||
return {'scores': probs, 'labels': cls_names}
|
||||
|
||||
@@ -36,25 +36,26 @@ class SentimentClassificationPipeline(Pipeline):
|
||||
"""
|
||||
assert isinstance(model, str) or isinstance(model, SbertForSentimentClassification), \
|
||||
'model must be a single str or SbertForSentimentClassification'
|
||||
sc_model = model if isinstance(
|
||||
model = model if isinstance(
|
||||
model,
|
||||
SbertForSentimentClassification) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
preprocessor = SentimentClassificationPreprocessor(
|
||||
sc_model.model_dir,
|
||||
model.model_dir,
|
||||
first_sequence=first_sequence,
|
||||
second_sequence=second_sequence)
|
||||
sc_model.eval()
|
||||
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
|
||||
assert len(sc_model.id2label) > 0
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
assert len(model.id2label) > 0
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
with torch.no_grad():
|
||||
return super().forward(inputs, **forward_params)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any],
|
||||
**postprocess_params) -> Dict[str, str]:
|
||||
def postprocess(self,
|
||||
inputs: Dict[str, Any],
|
||||
topk: int = 5) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
Args:
|
||||
@@ -64,30 +65,13 @@ class SentimentClassificationPipeline(Pipeline):
|
||||
Dict[str, str]: the prediction results
|
||||
"""
|
||||
|
||||
probs = inputs['probabilities']
|
||||
logits = inputs['logits']
|
||||
predictions = np.argsort(-probs, axis=-1)
|
||||
preds = predictions[0]
|
||||
b = 0
|
||||
new_result = list()
|
||||
for pred in preds:
|
||||
new_result.append({
|
||||
'pred': self.model.id2label[pred],
|
||||
'prob': float(probs[b][pred]),
|
||||
'logit': float(logits[b][pred])
|
||||
})
|
||||
new_results = list()
|
||||
new_results.append({
|
||||
'id':
|
||||
inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
|
||||
'output':
|
||||
new_result,
|
||||
'predictions':
|
||||
new_result[0]['pred'],
|
||||
'probabilities':
|
||||
','.join([str(t) for t in inputs['probabilities'][b]]),
|
||||
'logits':
|
||||
','.join([str(t) for t in inputs['logits'][b]])
|
||||
})
|
||||
probs = inputs['probabilities'][0]
|
||||
num_classes = probs.shape[0]
|
||||
topk = min(topk, num_classes)
|
||||
top_indices = np.argpartition(probs, -topk)[-topk:]
|
||||
cls_ids = top_indices[np.argsort(probs[top_indices])]
|
||||
probs = probs[cls_ids].tolist()
|
||||
|
||||
return new_results[0]
|
||||
cls_names = [self.model.id2label[cid] for cid in cls_ids]
|
||||
|
||||
return {'scores': probs, 'labels': cls_names}
|
||||
|
||||
@@ -101,6 +101,20 @@ TASK_OUTPUTS = {
|
||||
# }
|
||||
Tasks.sentence_similarity: ['scores', 'labels'],
|
||||
|
||||
# sentiment classification result for single sample
|
||||
# {
|
||||
# "labels": ["happy", "sad", "calm", "angry"],
|
||||
# "scores": [0.9, 0.1, 0.05, 0.05]
|
||||
# }
|
||||
Tasks.sentiment_classification: ['scores', 'labels'],
|
||||
|
||||
# nli result for single sample
|
||||
# {
|
||||
# "labels": ["happy", "sad", "calm", "angry"],
|
||||
# "scores": [0.9, 0.1, 0.05, 0.05]
|
||||
# }
|
||||
Tasks.nli: ['scores', 'labels'],
|
||||
|
||||
# ============ audio tasks ===================
|
||||
|
||||
# audio processed for single file in PCM format
|
||||
|
||||
Reference in New Issue
Block a user