[to #42322933]fix token classification bugs

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10550136
Authored by zhangzhicheng.zzc on 2022-10-27 22:53:16 +08:00
Committed by yingda.chen
parent 9df3f5c41f
commit b713e3de1c
3 changed files with 4 additions and 2 deletions

@@ -254,6 +254,7 @@ class Pipelines(object):
     translation_en_to_de = 'translation_en_to_de'  # keep it underscore
     translation_en_to_ro = 'translation_en_to_ro'  # keep it underscore
     translation_en_to_fr = 'translation_en_to_fr'  # keep it underscore
+    token_classification = 'token-classification'
     # audio tasks
     sambert_hifigan_tts = 'sambert-hifigan-tts'
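For context, a hedged usage sketch (not part of this commit) of how the newly registered 'token-classification' pipeline name is typically exercised: the constant added above gives TokenClassificationPipeline an explicit module name that the task can resolve to. The model id below is a placeholder, and the snippet assumes a standard ModelScope installation where Tasks.token_classification is defined.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# '<model-id-or-local-dir>' is a placeholder, not a model referenced by this commit.
token_cls = pipeline(task=Tasks.token_classification, model='<model-id-or-local-dir>')
print(token_cls('今天天气不错，适合出去游玩'))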

@@ -66,7 +66,6 @@ class TokenClassificationModel(SingleBackboneTaskModelBase):
             attentions=outputs.attentions,
             offset_mapping=input['offset_mapping'],
         )
         return outputs

     def extract_logits(self, outputs):
         return outputs[OutputKeys.LOGITS].cpu().detach()
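The model output above carries offset_mapping through from the preprocessor, and extract_logits returns detached CPU logits; together they let the pipeline map token-level predictions back to character spans in the original text. Below is a minimal, self-contained sketch of that idea using made-up offsets and labels (none of these values come from the commit); only the detach()/cpu() pattern mirrors extract_logits.

import torch

text = 'ModelScope is developed by Alibaba'
offsets = [(0, 10), (11, 13), (14, 23), (24, 26), (27, 34)]  # (start, end) character offsets, as a tokenizer would emit
logits = torch.randn(len(offsets), 3)                        # stand-in for model logits
id2label = {0: 'O', 1: 'B-ORG', 2: 'I-ORG'}                  # illustrative label map

pred_ids = logits.detach().cpu().argmax(dim=-1).tolist()     # same detach()/cpu() pattern as extract_logits
spans = [(text[s:e], id2label[p]) for (s, e), p in zip(offsets, pred_ids)]
print(spans)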

@@ -17,6 +17,8 @@ from modelscope.utils.tensor_utils import (torch_nested_detach,
 __all__ = ['TokenClassificationPipeline']

+@PIPELINES.register_module(
+    Tasks.token_classification, module_name=Pipelines.token_classification)
 @PIPELINES.register_module(
     Tasks.token_classification, module_name=Pipelines.part_of_speech)
 @PIPELINES.register_module(
@@ -41,7 +43,7 @@ class TokenClassificationPipeline(Pipeline):
                                                       str) else model
         if preprocessor is None:
-            preprocessor = Model.from_pretrained(
+            preprocessor = Preprocessor.from_pretrained(
                 model.model_dir,
                 sequence_length=kwargs.pop('sequence_length', 128))
         model.eval()
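The second hunk is the core bug fix: when no preprocessor is supplied, the pipeline previously tried to build one with Model.from_pretrained and now uses Preprocessor.from_pretrained on the model directory. A hedged sketch of the equivalent manual construction is below; the model id is a placeholder, and the printed keys depend on which preprocessor the model directory configures.

from modelscope.models import Model
from modelscope.preprocessors import Preprocessor

model = Model.from_pretrained('<model-id-or-local-dir>')  # placeholder id
preprocessor = Preprocessor.from_pretrained(
    model.model_dir,                  # same directory the fixed pipeline code passes in
    sequence_length=128)              # mirrors the pipeline's default sequence_length
features = preprocessor('今天天气不错，适合出去游玩')
print(type(preprocessor).__name__, list(features.keys()))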