From 2eda82949937d9bb8a263dbeab59de4bbf435853 Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Wed, 22 Jan 2025 23:06:20 +0800 Subject: [PATCH] fix --- modelscope/__init__.py | 10 +- modelscope/utils/hf_util/__init__.py | 2 + modelscope/utils/hf_util/auto_class.py | 288 +++++++++++-------------- modelscope/utils/hf_util/patcher.py | 214 +++++++++++------- modelscope/utils/import_utils.py | 4 + tests/hub/test_patch_hf.py | 3 - tests/utils/test_hf_util.py | 69 +++++- 7 files changed, 334 insertions(+), 256 deletions(-) diff --git a/modelscope/__init__.py b/modelscope/__init__.py index c969be68..0f0469b0 100644 --- a/modelscope/__init__.py +++ b/modelscope/__init__.py @@ -134,6 +134,14 @@ else: 'Qwen2VLForConditionalGeneration', 'T5EncoderModel' ] + from modelscope.utils import hf_util + + extra_objects = {} + attributes = dir(hf_util) + imports = [attr for attr in attributes if not attr.startswith('__')] + for _import in imports: + extra_objects[_import] = getattr(hf_util, _import) + import sys sys.modules[__name__] = LazyImportModule( @@ -141,5 +149,5 @@ else: globals()['__file__'], _import_structure, module_spec=__spec__, - extra_objects={}, + extra_objects=extra_objects, ) diff --git a/modelscope/utils/hf_util/__init__.py b/modelscope/utils/hf_util/__init__.py index e69de29b..a138ff7a 100644 --- a/modelscope/utils/hf_util/__init__.py +++ b/modelscope/utils/hf_util/__init__.py @@ -0,0 +1,2 @@ +from .auto_class import * +from .patcher import patch_context, patch_hub, unpatch_hub diff --git a/modelscope/utils/hf_util/auto_class.py b/modelscope/utils/hf_util/auto_class.py index 2a43f0e9..157158fd 100644 --- a/modelscope/utils/hf_util/auto_class.py +++ b/modelscope/utils/hf_util/auto_class.py @@ -1,187 +1,143 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import inspect import os -import sys -from functools import partial -from pathlib import Path -import importlib -from types import MethodType -from typing import BinaryIO, Dict, List, Optional, Union +from typing import TYPE_CHECKING -from huggingface_hub.hf_api import CommitInfo, future_compatible -from modelscope import snapshot_download -from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke -from modelscope.utils.logger import get_logger +if TYPE_CHECKING: + from transformers import __version__ as transformers_version + try: + from transformers import Qwen2VLForConditionalGeneration + except ImportError: + pass -try: - from transformers import AutoModelForImageToImage as AutoModelForImageToImageHF + try: + from transformers import GPTQConfig + from transformers import AwqConfig + except ImportError: + pass - AutoModelForImageToImage = get_wrapped_class(AutoModelForImageToImageHF) -except ImportError: - AutoModelForImageToImage = UnsupportedAutoClass('AutoModelForImageToImage') + try: + from transformers import AutoModelForImageToImage + except ImportError: + pass -try: - from transformers import AutoModelForImageTextToText as AutoModelForImageTextToTextHF + try: + from transformers import AutoModelForImageTextToText + except ImportError: + pass - AutoModelForImageTextToText = get_wrapped_class( - AutoModelForImageTextToTextHF) -except ImportError: - AutoModelForImageTextToText = UnsupportedAutoClass( - 'AutoModelForImageTextToText') + try: + from transformers import AutoModelForKeypointDetection + except ImportError: + pass -try: - from transformers import AutoModelForKeypointDetection as AutoModelForKeypointDetectionHF +else: - AutoModelForKeypointDetection = get_wrapped_class( - AutoModelForKeypointDetectionHF) -except ImportError: - AutoModelForKeypointDetection = UnsupportedAutoClass( - 'AutoModelForKeypointDetection') + class UnsupportedAutoClass: -try: - from transformers import \ - Qwen2VLForConditionalGeneration as Qwen2VLForConditionalGenerationHF + def __init__(self, name: str): + self.error_msg =\ + f'{name} is not supported with your installed Transformers version {transformers_version}. ' + \ + 'Please update your Transformers by "pip install transformers -U".' - Qwen2VLForConditionalGeneration = get_wrapped_class( - Qwen2VLForConditionalGenerationHF) -except ImportError: - Qwen2VLForConditionalGeneration = UnsupportedAutoClass( - 'Qwen2VLForConditionalGeneration') - - -logger = get_logger() - - -def get_wrapped_class(module_class, - ignore_file_pattern=[], - file_filter=None, - **kwargs): - """Get a custom wrapper class for auto classes to download the models from the ModelScope hub - Args: - module_class: The actual module class - ignore_file_pattern (`str` or `List`, *optional*, default to `None`): - Any file pattern to be ignored in downloading, like exact file names or file extensions. - Returns: - The wrapper - """ - default_ignore_file_pattern = ignore_file_pattern - default_file_filter = file_filter - - class ClassWrapper(module_class): - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, + def from_pretrained(self, pretrained_model_name_or_path, *model_args, **kwargs): - ignore_file_pattern = kwargs.pop('ignore_file_pattern', - default_ignore_file_pattern) - subfolder = kwargs.pop('subfolder', default_file_filter) - file_filter = None - if subfolder: - file_filter = f'{subfolder}/*' - if not os.path.exists(pretrained_model_name_or_path): - revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION) - if file_filter is None: - model_dir = snapshot_download( - pretrained_model_name_or_path, - revision=revision, - ignore_file_pattern=ignore_file_pattern, - user_agent=user_agent()) - else: - model_dir = os.path.join( - snapshot_download( + raise ImportError(self.error_msg) + + def from_config(self, cls, config): + raise ImportError(self.error_msg) + + def user_agent(invoked_by=None): + from modelscope.utils.constant import Invoke + + if invoked_by is None: + invoked_by = Invoke.PRETRAINED + uagent = '%s/%s' % (Invoke.KEY, invoked_by) + return uagent + + def get_wrapped_class(module_class, + ignore_file_pattern=[], + file_filter=None, + **kwargs): + """Get a custom wrapper class for auto classes to download the models from the ModelScope hub + Args: + module_class: The actual module class + ignore_file_pattern (`str` or `List`, *optional*, default to `None`): + Any file pattern to be ignored in downloading, like exact file names or file extensions. + Returns: + The wrapper + """ + default_ignore_file_pattern = ignore_file_pattern + default_file_filter = file_filter + + class ClassWrapper(module_class): + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, + *model_args, **kwargs): + + from modelscope import snapshot_download + from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke + + ignore_file_pattern = kwargs.pop('ignore_file_pattern', + default_ignore_file_pattern) + subfolder = kwargs.pop('subfolder', default_file_filter) + file_filter = None + if subfolder: + file_filter = f'{subfolder}/*' + if not os.path.exists(pretrained_model_name_or_path): + revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION) + if file_filter is None: + model_dir = snapshot_download( pretrained_model_name_or_path, revision=revision, ignore_file_pattern=ignore_file_pattern, - allow_file_pattern=file_filter, - user_agent=user_agent()), subfolder) - else: - model_dir = pretrained_model_name_or_path + user_agent=user_agent()) + else: + model_dir = os.path.join( + snapshot_download( + pretrained_model_name_or_path, + revision=revision, + ignore_file_pattern=ignore_file_pattern, + allow_file_pattern=file_filter, + user_agent=user_agent()), subfolder) + else: + model_dir = pretrained_model_name_or_path - module_obj = module_class.from_pretrained(model_dir, *model_args, - **kwargs) + module_obj = module_class.from_pretrained( + model_dir, *model_args, **kwargs) - if module_class.__name__.startswith('AutoModel'): - module_obj.model_dir = model_dir - return module_obj + if module_class.__name__.startswith('AutoModel'): + module_obj.model_dir = model_dir + return module_obj - ClassWrapper.__name__ = module_class.__name__ - ClassWrapper.__qualname__ = module_class.__qualname__ - return ClassWrapper + ClassWrapper.__name__ = module_class.__name__ + ClassWrapper.__qualname__ = module_class.__qualname__ + return ClassWrapper + from .patcher import get_all_imported_modules + all_imported_modules = get_all_imported_modules() + all_available_modules = [] + large_file_free = ['config', 'tokenizer'] + for module in all_imported_modules: + try: + if (hasattr(module, 'from_pretrained') + and 'pretrained_model_name_or_path' in inspect.signature( + module.from_pretrained).parameters): + if any(lf in module.__name__.lower() + for lf in large_file_free): + ignore_file_patterns = { + 'ignore_file_pattern': [ + r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', + r'\w+\.pt', r'\w+\.h5' + ] + } + else: + ignore_file_patterns = {} + all_available_modules.append( + get_wrapped_class(module, **ignore_file_patterns)) + except (ImportError, AttributeError): + pass -AutoModel = get_wrapped_class(AutoModelHF) -AutoModelForCausalLM = get_wrapped_class(AutoModelForCausalLMHF) -AutoModelForSeq2SeqLM = get_wrapped_class(AutoModelForSeq2SeqLMHF) -AutoModelForVision2Seq = get_wrapped_class(AutoModelForVision2SeqHF) -AutoModelForSequenceClassification = get_wrapped_class( - AutoModelForSequenceClassificationHF) -AutoModelForTokenClassification = get_wrapped_class( - AutoModelForTokenClassificationHF) -AutoModelForImageSegmentation = get_wrapped_class( - AutoModelForImageSegmentationHF) -AutoModelForImageClassification = get_wrapped_class( - AutoModelForImageClassificationHF) -AutoModelForZeroShotImageClassification = get_wrapped_class( - AutoModelForZeroShotImageClassificationHF) -AutoModelForQuestionAnswering = get_wrapped_class( - AutoModelForQuestionAnsweringHF) -AutoModelForTableQuestionAnswering = get_wrapped_class( - AutoModelForTableQuestionAnsweringHF) -AutoModelForVisualQuestionAnswering = get_wrapped_class( - AutoModelForVisualQuestionAnsweringHF) -AutoModelForDocumentQuestionAnswering = get_wrapped_class( - AutoModelForDocumentQuestionAnsweringHF) -AutoModelForSemanticSegmentation = get_wrapped_class( - AutoModelForSemanticSegmentationHF) -AutoModelForUniversalSegmentation = get_wrapped_class( - AutoModelForUniversalSegmentationHF) -AutoModelForInstanceSegmentation = get_wrapped_class( - AutoModelForInstanceSegmentationHF) -AutoModelForObjectDetection = get_wrapped_class(AutoModelForObjectDetectionHF) -AutoModelForZeroShotObjectDetection = get_wrapped_class( - AutoModelForZeroShotObjectDetectionHF) -AutoModelForAudioClassification = get_wrapped_class( - AutoModelForAudioClassificationHF) -AutoModelForSpeechSeq2Seq = get_wrapped_class(AutoModelForSpeechSeq2SeqHF) -AutoModelForMaskedImageModeling = get_wrapped_class( - AutoModelForMaskedImageModelingHF) -AutoModelForMaskedLM = get_wrapped_class(AutoModelForMaskedLMHF) -AutoModelForMaskGeneration = get_wrapped_class(AutoModelForMaskGenerationHF) -AutoModelForPreTraining = get_wrapped_class(AutoModelForPreTrainingHF) -AutoModelForTextEncoding = get_wrapped_class(AutoModelForTextEncodingHF) -T5EncoderModel = get_wrapped_class(T5EncoderModelHF) - -AutoTokenizer = get_wrapped_class( - AutoTokenizerHF, - ignore_file_pattern=[ - r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5' - ]) -AutoProcessor = get_wrapped_class( - AutoProcessorHF, - ignore_file_pattern=[ - r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5' - ]) -AutoConfig = get_wrapped_class( - AutoConfigHF, - ignore_file_pattern=[ - r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5' - ]) -GenerationConfig = get_wrapped_class( - GenerationConfigHF, - ignore_file_pattern=[ - r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5' - ]) -BitsAndBytesConfig = get_wrapped_class( - BitsAndBytesConfigHF, - ignore_file_pattern=[ - r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5' - ]) -AutoImageProcessor = get_wrapped_class( - AutoImageProcessorHF, - ignore_file_pattern=[ - r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5' - ]) - -GPTQConfig = GPTQConfigHF -AwqConfig = AwqConfigHF -BatchFeature = get_wrapped_class(BatchFeatureHF) + for module in all_available_modules: + globals()[module.__name__] = module diff --git a/modelscope/utils/hf_util/patcher.py b/modelscope/utils/hf_util/patcher.py index fd5103f1..c81c70d9 100644 --- a/modelscope/utils/hf_util/patcher.py +++ b/modelscope/utils/hf_util/patcher.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import contextlib import importlib import inspect import os @@ -8,27 +9,59 @@ from pathlib import Path from types import MethodType from typing import BinaryIO, Dict, List, Optional, Union -from huggingface_hub.hf_api import CommitInfo, future_compatible -from modelscope import snapshot_download -from modelscope.utils.logger import get_logger +def get_all_imported_modules(): + all_imported_modules = [] + if importlib.util.find_spec('transformers') is not None: + import transformers + extra_modules = ['T5'] + lazy_module = sys.modules['transformers'] + _import_structure = lazy_module._import_structure + for key in _import_structure: + values = _import_structure[key] + for value in values: + # pretrained + if 'auto' in value.lower() or any(m in value + for m in extra_modules): + try: + module = importlib.import_module( + f'.{key}', transformers.__name__) + value = getattr(module, value) + all_imported_modules.append(value) + except (ImportError, AttributeError): + pass -logger = get_logger() + if importlib.util.find_spec('peft') is not None: + import peft + attributes = dir(peft) + imports = [attr for attr in attributes if not attr.startswith('__')] + all_imported_modules.extend( + [getattr(peft, _import) for _import in imports]) + + if importlib.util.find_spec('diffusers') is not None: + import diffusers + if importlib.util.find_spec('diffusers') is not None: + lazy_module = sys.modules['diffusers'] + _import_structure = lazy_module._import_structure + for key in _import_structure: + values = _import_structure[key] + for value in values: + if 'pipeline' in value.lower(): + try: + module = importlib.import_module( + f'.{key}', diffusers.__name__) + value = getattr(module, value) + all_imported_modules.append(value) + except (ImportError, AttributeError): + pass + return all_imported_modules -extra_modules = ['T5'] -lazy_module = sys.modules['transformers'] -all_modules = lazy_module._modules -all_imported_modules = [] -for module in all_modules: - if 'auto' in module.lower() or any(m in module for m in extra_modules): - all_imported_modules.append(importlib.import_module(f'transformers.{module}')) - - -def _patch_pretrained_class(): +def _patch_pretrained_class(all_imported_modules): def get_model_dir(pretrained_model_name_or_path, ignore_file_pattern, **kwargs): + from modelscope import snapshot_download if not os.path.exists(pretrained_model_name_or_path): revision = kwargs.pop('revision', None) model_dir = snapshot_download( @@ -43,94 +76,109 @@ def _patch_pretrained_class(): r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt' ] - def patch_pretrained_model_name_or_path(cls, pretrained_model_name_or_path, + def patch_pretrained_model_name_or_path(pretrained_model_name_or_path, *model_args, **kwargs): model_dir = get_model_dir(pretrained_model_name_or_path, kwargs.pop('ignore_file_pattern', None), **kwargs) - return kwargs.pop('ori_func')(cls, model_dir, *model_args, **kwargs) + return kwargs.pop('ori_func')(model_dir, *model_args, **kwargs) - def patch_peft_model_id(cls, model, model_id, *model_args, **kwargs): + def patch_peft_model_id(model, model_id, *model_args, **kwargs): model_dir = get_model_dir(model_id, kwargs.pop('ignore_file_pattern', None), **kwargs) - return kwargs.pop('ori_func')(cls, model, model_dir, *model_args, - **kwargs) + return kwargs.pop('ori_func')(model, model_dir, *model_args, **kwargs) - def _get_peft_type(cls, model_id, **kwargs): + def _get_peft_type(model_id, **kwargs): model_dir = get_model_dir(model_id, ignore_file_pattern, **kwargs) - return kwargs.pop('ori_func')(cls, model_dir, **kwargs) + return kwargs.pop('ori_func')(model_dir, **kwargs) for var in all_imported_modules: - if var is None: + if var is None or not hasattr(var, '__name__'): continue name = var.__name__ - need_model = 'model' in name.lower() or 'processor' in name.lower() or 'extractor' in name.lower() + need_model = 'model' in name.lower() or 'processor' in name.lower( + ) or 'extractor' in name.lower() if need_model: ignore_file_pattern_kwargs = {} else: - ignore_file_pattern_kwargs = {'ignore_file_pattern': ignore_file_pattern} + ignore_file_pattern_kwargs = { + 'ignore_file_pattern': ignore_file_pattern + } - has_from_pretrained = hasattr(var, 'from_pretrained') - has_get_peft_type = hasattr(var, '_get_peft_type') - has_get_config_dict = hasattr(var, 'get_config_dict') - parameters = inspect.signature(var.from_pretrained).parameters - is_peft = 'model' in parameters and 'model_id' in parameters + try: + has_from_pretrained = hasattr(var, 'from_pretrained') + has_get_peft_type = hasattr(var, '_get_peft_type') + has_get_config_dict = hasattr(var, 'get_config_dict') + except ImportError: + continue if has_from_pretrained and not hasattr(var, '_from_pretrained_origin'): + parameters = inspect.signature(var.from_pretrained).parameters + is_peft = 'model' in parameters and 'model_id' in parameters var._from_pretrained_origin = var.from_pretrained if not is_peft: - var.from_pretrained = partial(patch_pretrained_model_name_or_path, - ori_func=var._from_pretrained_origin, - **ignore_file_pattern_kwargs) + var.from_pretrained = partial( + patch_pretrained_model_name_or_path, + ori_func=var._from_pretrained_origin, + **ignore_file_pattern_kwargs) else: - var.from_pretrained = partial(patch_peft_model_id, - ori_func=var._from_pretrained_origin, - **ignore_file_pattern_kwargs) - delattr(var, '_from_pretrained_origin') + var.from_pretrained = partial( + patch_peft_model_id, + ori_func=var._from_pretrained_origin, + **ignore_file_pattern_kwargs) if has_get_peft_type and not hasattr(var, '_get_peft_type_origin'): var._get_peft_type_origin = var._get_peft_type - var._get_peft_type = partial(_get_peft_type, - ori_func=var._get_peft_type_origin, - **ignore_file_pattern_kwargs) - delattr(var, '_get_peft_type_origin') + var._get_peft_type = partial( + _get_peft_type, + ori_func=var._get_peft_type_origin, + **ignore_file_pattern_kwargs) if has_get_config_dict and not hasattr(var, '_get_config_dict_origin'): var._get_config_dict_origin = var.get_config_dict - var.get_config_dict = partial(patch_pretrained_model_name_or_path, - ori_func=var._get_config_dict_origin, - **ignore_file_pattern_kwargs) - delattr(var, '_get_config_dict_origin') + var.get_config_dict = partial( + patch_pretrained_model_name_or_path, + ori_func=var._get_config_dict_origin, + **ignore_file_pattern_kwargs) -def _unpatch_pretrained_class(): +def _unpatch_pretrained_class(all_imported_modules): for var in all_imported_modules: if var is None: continue - has_from_pretrained = hasattr(var, 'from_pretrained') - has_get_peft_type = hasattr(var, '_get_peft_type') - has_get_config_dict = hasattr(var, 'get_config_dict') + try: + has_from_pretrained = hasattr(var, 'from_pretrained') + has_get_peft_type = hasattr(var, '_get_peft_type') + has_get_config_dict = hasattr(var, 'get_config_dict') + except ImportError: + continue if has_from_pretrained and hasattr(var, '_from_pretrained_origin'): var.from_pretrained = var._from_pretrained_origin + delattr(var, '_from_pretrained_origin') if has_get_peft_type and hasattr(var, '_get_peft_type_origin'): var._get_peft_type = var._get_peft_type_origin + delattr(var, '_get_peft_type_origin') if has_get_config_dict and hasattr(var, '_get_config_dict_origin'): var.get_config_dict = var._get_config_dict_origin + delattr(var, '_get_config_dict_origin') def _patch_hub(): import huggingface_hub from huggingface_hub import hf_api from huggingface_hub.hf_api import api + from huggingface_hub.hf_api import CommitInfo, future_compatible + from modelscope import get_logger + logger = get_logger() def _file_exists( - self, - repo_id: str, - filename: str, - *, - repo_type: Optional[str] = None, - revision: Optional[str] = None, - token: Union[str, bool, None] = None, + self, + repo_id: str, + filename: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Union[str, bool, None] = None, ): """Patch huggingface_hub.file_exists""" if repo_type is not None: @@ -171,7 +219,8 @@ def _patch_hub(): api.try_login(token) return file_download( repo_id, - file_path=os.path.join(subfolder, filename) if subfolder else filename, + file_path=os.path.join(subfolder, filename) + if subfolder else filename, cache_dir=cache_dir, local_dir=local_dir, local_files_only=local_files_only, @@ -209,16 +258,16 @@ def _patch_hub(): @future_compatible def upload_folder( - *, - repo_id: str, - folder_path: Union[str, Path], - path_in_repo: Optional[str] = None, - commit_message: Optional[str] = None, - commit_description: Optional[str] = None, - token: Union[str, bool, None] = None, - revision: Optional[str] = 'master', - ignore_patterns: Optional[Union[List[str], str]] = None, - **kwargs, + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Union[str, bool, None] = None, + revision: Optional[str] = 'master', + ignore_patterns: Optional[Union[List[str], str]] = None, + **kwargs, ): from modelscope.hub.push_to_hub import push_model_to_hub push_model_to_hub(repo_id, folder_path, path_in_repo, commit_message, @@ -233,16 +282,16 @@ def _patch_hub(): @future_compatible def upload_file( - self, - *, - path_or_fileobj: Union[str, Path, bytes, BinaryIO], - path_in_repo: str, - repo_id: str, - token: Union[str, bool, None] = None, - revision: Optional[str] = None, - commit_message: Optional[str] = None, - commit_description: Optional[str] = None, - **kwargs, + self, + *, + path_or_fileobj: Union[str, Path, bytes, BinaryIO], + path_in_repo: str, + repo_id: str, + token: Union[str, bool, None] = None, + revision: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + **kwargs, ): from modelscope.hub.push_to_hub import push_files_to_hub push_files_to_hub(path_or_fileobj, path_in_repo, repo_id, token, @@ -347,10 +396,19 @@ def _unpatch_hub(): repocard.upload_file = hf_api.upload_file delattr(hf_api, '_upload_file_origin') + def patch_hub(): _patch_hub() - _patch_pretrained_class() + _patch_pretrained_class(get_all_imported_modules()) def unpatch_hub(): - _unpatch_pretrained_class() + _unpatch_pretrained_class(get_all_imported_modules()) + _unpatch_hub() + + +@contextlib.contextmanager +def patch_context(): + patch_hub() + yield + unpatch_hub() diff --git a/modelscope/utils/import_utils.py b/modelscope/utils/import_utils.py index 984df7af..51ff7a96 100644 --- a/modelscope/utils/import_utils.py +++ b/modelscope/utils/import_utils.py @@ -282,6 +282,10 @@ def is_transformers_available(): return importlib.util.find_spec('transformers') is not None +def is_diffusers_available(): + return importlib.util.find_spec('diffusers') is not None + + def is_tensorrt_llm_available(): return importlib.util.find_spec('tensorrt_llm') is not None diff --git a/tests/hub/test_patch_hf.py b/tests/hub/test_patch_hf.py index dbaf2c11..13754923 100644 --- a/tests/hub/test_patch_hf.py +++ b/tests/hub/test_patch_hf.py @@ -14,6 +14,3 @@ class DownloadDatasetTest(unittest.TestCase): from transformers import AutoModel model = AutoModel.from_pretrained('AI-ModelScope/bert-base-uncased') self.assertTrue(model is not None) - - - diff --git a/tests/utils/test_hf_util.py b/tests/utils/test_hf_util.py index 9d6b61bd..03de5aea 100644 --- a/tests/utils/test_hf_util.py +++ b/tests/utils/test_hf_util.py @@ -2,8 +2,7 @@ import unittest -from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM, - AutoTokenizer, GenerationConfig) +from modelscope.utils.hf_util.patcher import patch_context class HFUtilTest(unittest.TestCase): @@ -15,6 +14,7 @@ class HFUtilTest(unittest.TestCase): pass def test_auto_tokenizer(self): + from modelscope import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained( 'baichuan-inc/Baichuan2-7B-Chat', trust_remote_code=True, @@ -28,11 +28,13 @@ class HFUtilTest(unittest.TestCase): self.assertTrue(BitsAndBytesConfig is not None) def test_auto_model(self): + from modelscope import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained( 'baichuan-inc/baichuan-7B', trust_remote_code=True) self.assertTrue(model is not None) def test_auto_config(self): + from modelscope import AutoConfig, GenerationConfig config = AutoConfig.from_pretrained( 'baichuan-inc/Baichuan-13B-Chat', trust_remote_code=True, @@ -45,12 +47,63 @@ class HFUtilTest(unittest.TestCase): self.assertEqual(gen_config.assistant_token_id, 196) def test_transformer_patch(self): - tokenizer = AutoTokenizer.from_pretrained( - 'iic/nlp_structbert_sentiment-classification_chinese-base') - self.assertIsNotNone(tokenizer) - model = AutoModelForCausalLM.from_pretrained( - 'iic/nlp_structbert_sentiment-classification_chinese-base') - self.assertIsNotNone(model) + with patch_context(): + from transformers import AutoTokenizer, AutoModelForCausalLM + tokenizer = AutoTokenizer.from_pretrained( + 'iic/nlp_structbert_sentiment-classification_chinese-base') + self.assertIsNotNone(tokenizer) + model = AutoModelForCausalLM.from_pretrained( + 'iic/nlp_structbert_sentiment-classification_chinese-base') + self.assertIsNotNone(model) + + def test_patch_model(self): + from modelscope.utils.hf_util.patcher import patch_context + with patch_context(): + from transformers import AutoModel + model = AutoModel.from_pretrained( + 'iic/nlp_structbert_sentiment-classification_chinese-tiny') + self.assertTrue(model is not None) + try: + model = AutoModel.from_pretrained( + 'iic/nlp_structbert_sentiment-classification_chinese-tiny') + except Exception: + pass + else: + self.assertTrue(False) + + def test_patch_config(self): + with patch_context(): + from transformers import AutoConfig + config = AutoConfig.from_pretrained( + 'iic/nlp_structbert_sentiment-classification_chinese-tiny') + self.assertTrue(config is not None) + try: + config = AutoConfig.from_pretrained( + 'iic/nlp_structbert_sentiment-classification_chinese-tiny') + except Exception: + pass + else: + self.assertTrue(False) + + def test_patch_diffusers(self): + with patch_context(): + from diffusers import StableDiffusionPipeline + pipe = StableDiffusionPipeline.from_pretrained( + 'AI-ModelScope/stable-diffusion-v1-5') + self.assertTrue(pipe is not None) + try: + pipe = StableDiffusionPipeline.from_pretrained( + 'AI-ModelScope/stable-diffusion-v1-5') + except Exception: + pass + else: + self.assertTrue(False) + + def test_patch_peft(self): + with patch_context(): + from peft import PeftModel + self.assertTrue(hasattr(PeftModel, '_from_pretrained_origin')) + self.assertFalse(hasattr(PeftModel, '_from_pretrained_origin')) if __name__ == '__main__':