From 102943923e37ec167bf70716ff0768414159b689 Mon Sep 17 00:00:00 2001 From: suluyan Date: Thu, 23 Jun 2022 12:53:42 +0800 Subject: [PATCH] fix --- modelscope/metainfo.py | 1 + modelscope/models/nlp/masked_language_model.py | 7 ++++--- tests/pipelines/test_fill_mask.py | 1 - 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index cea10739..af39f3f4 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -15,6 +15,7 @@ class Models(object): bert = 'bert' palm = 'palm-v2' structbert = 'structbert' + veco = 'veco' # audio models sambert_hifi_16k = 'sambert-hifi-16k' diff --git a/modelscope/models/nlp/masked_language_model.py b/modelscope/models/nlp/masked_language_model.py index 514c72c7..fd5f97e6 100644 --- a/modelscope/models/nlp/masked_language_model.py +++ b/modelscope/models/nlp/masked_language_model.py @@ -2,7 +2,8 @@ from typing import Any, Dict, Optional, Union import numpy as np -from ...utils.constant import Tasks +from modelscope.metainfo import Models +from modelscope.utils.constant import Tasks from ..base import Model, Tensor from ..builder import MODELS @@ -36,14 +37,14 @@ class AliceMindBaseForMaskedLM(Model): return {'logits': rst['logits'], 'input_ids': inputs['input_ids']} -@MODELS.register_module(Tasks.fill_mask, module_name=r'sbert') +@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert) class StructBertForMaskedLM(AliceMindBaseForMaskedLM): # The StructBert for MaskedLM uses the same underlying model structure # as the base model class. pass -@MODELS.register_module(Tasks.fill_mask, module_name=r'veco') +@MODELS.register_module(Tasks.fill_mask, module_name=Models.veco) class VecoForMaskedLM(AliceMindBaseForMaskedLM): # The Veco for MaskedLM uses the same underlying model structure # as the base model class. diff --git a/tests/pipelines/test_fill_mask.py b/tests/pipelines/test_fill_mask.py index a4d53403..d56ffe90 100644 --- a/tests/pipelines/test_fill_mask.py +++ b/tests/pipelines/test_fill_mask.py @@ -10,7 +10,6 @@ from modelscope.models.nlp import StructBertForMaskedLM, VecoForMaskedLM from modelscope.pipelines import FillMaskPipeline, pipeline from modelscope.preprocessors import FillMaskPreprocessor from modelscope.utils.constant import Tasks -from modelscope.utils.hub import get_model_cache_dir from modelscope.utils.test_utils import test_level