diff --git a/modelscope/__init__.py b/modelscope/__init__.py index 11f28767..97abdbd3 100644 --- a/modelscope/__init__.py +++ b/modelscope/__init__.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule +from .utils.automodel_utils import fix_transformers_upgrade if TYPE_CHECKING: from .exporters import Exporter, TfModelExporter, TorchModelExporter @@ -33,7 +34,8 @@ if TYPE_CHECKING: AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoTokenizer, - GenerationConfig) + GenerationConfig, AutoImageProcessor, + BatchFeature) from .utils.hub import create_model_if_not_exist, read_config from .utils.logger import get_logger from .version import __release_datetime__, __version__ @@ -81,7 +83,8 @@ else: 'BitsAndBytesConfig', 'AutoModelForCausalLM', 'AutoModelForSeq2SeqLM', 'AutoTokenizer', 'AutoModelForSequenceClassification', - 'AutoModelForTokenClassification' + 'AutoModelForTokenClassification', 'AutoImageProcessor', + 'BatchFeature' ], 'msdatasets': ['MsDataset'] } @@ -95,3 +98,5 @@ else: module_spec=__spec__, extra_objects={}, ) + +fix_transformers_upgrade() diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index 8a317218..9beb156b 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Tasks from modelscope.models.builder import build_backbone, build_model -from modelscope.utils.automodel_utils import (can_load_by_ms, fix_upgrade, +from modelscope.utils.automodel_utils import (can_load_by_ms, try_to_load_hf_model) from modelscope.utils.config import Config, ConfigDict from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile @@ -192,7 +192,6 @@ class Model(ABC): model_cfg.pop('model_dir', None) model.name = model_name_or_path model.model_dir = local_model_dir - fix_upgrade(model) return model def save_pretrained(self, diff --git a/modelscope/models/nlp/dgds/document_grounded_dialog_retrieval.py b/modelscope/models/nlp/dgds/document_grounded_dialog_retrieval.py index 07685673..43051b57 100644 --- a/modelscope/models/nlp/dgds/document_grounded_dialog_retrieval.py +++ b/modelscope/models/nlp/dgds/document_grounded_dialog_retrieval.py @@ -28,6 +28,8 @@ class DocumentGroundedDialogRetrievalModel(TorchModel): map_location='cpu') compatible_position_ids(state_dict, 'ctx_encoder.encoder.embeddings.position_ids') + compatible_position_ids(state_dict, + 'qry_encoder.encoder.embeddings.position_ids') self.model.load_state_dict(state_dict) def forward(self, input: Dict[str, Tensor], gck_segment=32): diff --git a/modelscope/utils/automodel_utils.py b/modelscope/utils/automodel_utils.py index 706a1544..f96046ff 100644 --- a/modelscope/utils/automodel_utils.py +++ b/modelscope/utils/automodel_utils.py @@ -25,15 +25,27 @@ def can_load_by_ms(model_dir: str, task_name: Optional[str], def fix_upgrade(module_obj: Any): + from transformers import PreTrainedModel + if hasattr(module_obj, '_set_gradient_checkpointing') \ + and 'value' in inspect.signature(module_obj._set_gradient_checkpointing).parameters.keys(): + module_obj._set_gradient_checkpointing = MethodType( + PreTrainedModel._set_gradient_checkpointing, module_obj) + + +def post_init(self, *args, **kwargs): + fix_upgrade(self) + self.post_init_origin(*args, **kwargs) + + +def fix_transformers_upgrade(): if is_transformers_available(): # from 4.35.0, transformers changes its arguments of _set_gradient_checkpointing import transformers from transformers import PreTrainedModel - if version.parse(transformers.__version__) >= version.parse('4.35.0'): - if isinstance(module_obj, PreTrainedModel) and hasattr(module_obj, '_set_gradient_checkpointing') \ - and 'value' in inspect.signature(module_obj._set_gradient_checkpointing).parameters.keys(): - module_obj._set_gradient_checkpointing = MethodType( - PreTrainedModel._set_gradient_checkpointing, module_obj) + if version.parse(transformers.__version__) >= version.parse('4.35.0') \ + and not hasattr(PreTrainedModel, 'post_init_origin'): + PreTrainedModel.post_init_origin = PreTrainedModel.post_init + PreTrainedModel.post_init = post_init def _can_load_by_hf_automodel(automodel_class: type, config) -> bool: diff --git a/modelscope/utils/hf_util.py b/modelscope/utils/hf_util.py index f98dbea2..5a81d52d 100644 --- a/modelscope/utils/hf_util.py +++ b/modelscope/utils/hf_util.py @@ -2,6 +2,7 @@ import os from transformers import AutoConfig as AutoConfigHF +from transformers import AutoImageProcessor as AutoImageProcessorHF from transformers import AutoModel as AutoModelHF from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF @@ -10,12 +11,12 @@ from transformers import \ from transformers import \ AutoModelForTokenClassification as AutoModelForTokenClassificationHF from transformers import AutoTokenizer as AutoTokenizerHF +from transformers import BatchFeature as BatchFeatureHF from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF from transformers import GenerationConfig as GenerationConfigHF from transformers import PreTrainedModel, PreTrainedTokenizerBase from modelscope import snapshot_download -from modelscope.utils.automodel_utils import fix_upgrade from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke try: @@ -112,7 +113,6 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs): if module_class.__name__.startswith('AutoModel'): module_obj.model_dir = model_dir - fix_upgrade(module_obj) return module_obj ClassWrapper.__name__ = module_class.__name__ @@ -136,3 +136,5 @@ GenerationConfig = get_wrapped_class( GenerationConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors']) GPTQConfig = GPTQConfigHF BitsAndBytesConfig = BitsAndBytesConfigHF +AutoImageProcessor = get_wrapped_class(AutoImageProcessorHF) +BatchFeature = get_wrapped_class(BatchFeatureHF) diff --git a/modelscope/utils/test_utils.py b/modelscope/utils/test_utils.py index 03d293ec..bc7b4311 100644 --- a/modelscope/utils/test_utils.py +++ b/modelscope/utils/test_utils.py @@ -115,6 +115,9 @@ def get_case_model_info(): elements = line.split(':') test_file = elements[0] model_pos = line.find('damo') + if model_pos == -1 or (model_pos - 1) > len(line): + print('Processing line: %s failed' % line) + continue left_quote = line[model_pos - 1] rquote_idx = line.rfind(left_quote) model_name = line[model_pos:rquote_idx]