From b713e3de1c7fcc0cd4bdec39772be525c85d0287 Mon Sep 17 00:00:00 2001
From: "zhangzhicheng.zzc"
Date: Thu, 27 Oct 2022 22:53:16 +0800
Subject: [PATCH] [to #42322933]fix token classification bugs

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10550136
---
 modelscope/metainfo.py                                     | 1 +
 modelscope/models/nlp/task_models/token_classification.py | 1 -
 modelscope/pipelines/nlp/token_classification_pipeline.py | 4 +++-
 3 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 2aeb86da..a671ded5 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -254,6 +254,7 @@ class Pipelines(object):
     translation_en_to_de = 'translation_en_to_de'  # keep it underscore
     translation_en_to_ro = 'translation_en_to_ro'  # keep it underscore
     translation_en_to_fr = 'translation_en_to_fr'  # keep it underscore
+    token_classification = 'token-classification'
 
     # audio tasks
     sambert_hifigan_tts = 'sambert-hifigan-tts'
diff --git a/modelscope/models/nlp/task_models/token_classification.py b/modelscope/models/nlp/task_models/token_classification.py
index 2739bf11..8b523baf 100644
--- a/modelscope/models/nlp/task_models/token_classification.py
+++ b/modelscope/models/nlp/task_models/token_classification.py
@@ -66,7 +66,6 @@ class TokenClassificationModel(SingleBackboneTaskModelBase):
             attentions=outputs.attentions,
             offset_mapping=input['offset_mapping'],
         )
-        return outputs
 
     def extract_logits(self, outputs):
         return outputs[OutputKeys.LOGITS].cpu().detach()
diff --git a/modelscope/pipelines/nlp/token_classification_pipeline.py b/modelscope/pipelines/nlp/token_classification_pipeline.py
index c36f0dfc..75bc538d 100644
--- a/modelscope/pipelines/nlp/token_classification_pipeline.py
+++ b/modelscope/pipelines/nlp/token_classification_pipeline.py
@@ -17,6 +17,8 @@ from modelscope.utils.tensor_utils import (torch_nested_detach,
 
 __all__ = ['TokenClassificationPipeline']
 
+@PIPELINES.register_module(
+    Tasks.token_classification, module_name=Pipelines.token_classification)
 @PIPELINES.register_module(
     Tasks.token_classification, module_name=Pipelines.part_of_speech)
 @PIPELINES.register_module(
@@ -41,7 +43,7 @@ class TokenClassificationPipeline(Pipeline):
                                                            str) else model
 
         if preprocessor is None:
-            preprocessor = Model.from_pretrained(
+            preprocessor = Preprocessor.from_pretrained(
                 model.model_dir,
                 sequence_length=kwargs.pop('sequence_length', 128))
         model.eval()
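
For reference, a minimal usage sketch of the pipeline this patch fixes, assuming a token-classification model is available on the ModelScope hub; the model id below is a placeholder and not part of this change:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # With Pipelines.token_classification registered, the generic
    # token-classification task resolves to TokenClassificationPipeline,
    # and the Preprocessor.from_pretrained fix lets it build its own
    # preprocessor from the model directory when none is passed in.
    token_cls = pipeline(
        Tasks.token_classification,
        model='<your-token-classification-model-id>')  # placeholder id
    print(token_cls('ModelScope is developed by Alibaba DAMO Academy.'))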