mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
[to #42322933]fix token classification bugs
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10550136
This commit is contained in:
committed by
yingda.chen
parent
9df3f5c41f
commit
b713e3de1c
@@ -254,6 +254,7 @@ class Pipelines(object):
|
||||
translation_en_to_de = 'translation_en_to_de' # keep it underscore
|
||||
translation_en_to_ro = 'translation_en_to_ro' # keep it underscore
|
||||
translation_en_to_fr = 'translation_en_to_fr' # keep it underscore
|
||||
token_classification = 'token-classification'
|
||||
|
||||
# audio tasks
|
||||
sambert_hifigan_tts = 'sambert-hifigan-tts'
|
||||
|
||||
@@ -66,7 +66,6 @@ class TokenClassificationModel(SingleBackboneTaskModelBase):
|
||||
attentions=outputs.attentions,
|
||||
offset_mapping=input['offset_mapping'],
|
||||
)
|
||||
return outputs
|
||||
|
||||
def extract_logits(self, outputs):
|
||||
return outputs[OutputKeys.LOGITS].cpu().detach()
|
||||
|
||||
@@ -17,6 +17,8 @@ from modelscope.utils.tensor_utils import (torch_nested_detach,
|
||||
__all__ = ['TokenClassificationPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.token_classification, module_name=Pipelines.token_classification)
|
||||
@PIPELINES.register_module(
|
||||
Tasks.token_classification, module_name=Pipelines.part_of_speech)
|
||||
@PIPELINES.register_module(
|
||||
@@ -41,7 +43,7 @@ class TokenClassificationPipeline(Pipeline):
|
||||
str) else model
|
||||
|
||||
if preprocessor is None:
|
||||
preprocessor = Model.from_pretrained(
|
||||
preprocessor = Preprocessor.from_pretrained(
|
||||
model.model_dir,
|
||||
sequence_length=kwargs.pop('sequence_length', 128))
|
||||
model.eval()
|
||||
|
||||
Reference in New Issue
Block a user