This commit is contained in:
yuze.zyz
2025-01-22 23:06:20 +08:00
parent d90c27524e
commit 2eda829499
7 changed files with 334 additions and 256 deletions

View File

@@ -134,6 +134,14 @@ else:
'Qwen2VLForConditionalGeneration', 'T5EncoderModel'
]
from modelscope.utils import hf_util
extra_objects = {}
attributes = dir(hf_util)
imports = [attr for attr in attributes if not attr.startswith('__')]
for _import in imports:
extra_objects[_import] = getattr(hf_util, _import)
import sys
sys.modules[__name__] = LazyImportModule(
@@ -141,5 +149,5 @@ else:
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
extra_objects=extra_objects,
)

View File

@@ -0,0 +1,2 @@
from .auto_class import *
from .patcher import patch_context, patch_hub, unpatch_hub

View File

@@ -1,187 +1,143 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import os
import sys
from functools import partial
from pathlib import Path
import importlib
from types import MethodType
from typing import BinaryIO, Dict, List, Optional, Union
from typing import TYPE_CHECKING
from huggingface_hub.hf_api import CommitInfo, future_compatible
from modelscope import snapshot_download
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
from modelscope.utils.logger import get_logger
if TYPE_CHECKING:
from transformers import __version__ as transformers_version
try:
from transformers import Qwen2VLForConditionalGeneration
except ImportError:
pass
try:
from transformers import AutoModelForImageToImage as AutoModelForImageToImageHF
try:
from transformers import GPTQConfig
from transformers import AwqConfig
except ImportError:
pass
AutoModelForImageToImage = get_wrapped_class(AutoModelForImageToImageHF)
except ImportError:
AutoModelForImageToImage = UnsupportedAutoClass('AutoModelForImageToImage')
try:
from transformers import AutoModelForImageToImage
except ImportError:
pass
try:
from transformers import AutoModelForImageTextToText as AutoModelForImageTextToTextHF
try:
from transformers import AutoModelForImageTextToText
except ImportError:
pass
AutoModelForImageTextToText = get_wrapped_class(
AutoModelForImageTextToTextHF)
except ImportError:
AutoModelForImageTextToText = UnsupportedAutoClass(
'AutoModelForImageTextToText')
try:
from transformers import AutoModelForKeypointDetection
except ImportError:
pass
try:
from transformers import AutoModelForKeypointDetection as AutoModelForKeypointDetectionHF
else:
AutoModelForKeypointDetection = get_wrapped_class(
AutoModelForKeypointDetectionHF)
except ImportError:
AutoModelForKeypointDetection = UnsupportedAutoClass(
'AutoModelForKeypointDetection')
class UnsupportedAutoClass:
try:
from transformers import \
Qwen2VLForConditionalGeneration as Qwen2VLForConditionalGenerationHF
def __init__(self, name: str):
self.error_msg =\
f'{name} is not supported with your installed Transformers version {transformers_version}. ' + \
'Please update your Transformers by "pip install transformers -U".'
Qwen2VLForConditionalGeneration = get_wrapped_class(
Qwen2VLForConditionalGenerationHF)
except ImportError:
Qwen2VLForConditionalGeneration = UnsupportedAutoClass(
'Qwen2VLForConditionalGeneration')
logger = get_logger()
def get_wrapped_class(module_class,
ignore_file_pattern=[],
file_filter=None,
**kwargs):
"""Get a custom wrapper class for auto classes to download the models from the ModelScope hub
Args:
module_class: The actual module class
ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be ignored in downloading, like exact file names or file extensions.
Returns:
The wrapper
"""
default_ignore_file_pattern = ignore_file_pattern
default_file_filter = file_filter
class ClassWrapper(module_class):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
def from_pretrained(self, pretrained_model_name_or_path, *model_args,
**kwargs):
ignore_file_pattern = kwargs.pop('ignore_file_pattern',
default_ignore_file_pattern)
subfolder = kwargs.pop('subfolder', default_file_filter)
file_filter = None
if subfolder:
file_filter = f'{subfolder}/*'
if not os.path.exists(pretrained_model_name_or_path):
revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION)
if file_filter is None:
model_dir = snapshot_download(
pretrained_model_name_or_path,
revision=revision,
ignore_file_pattern=ignore_file_pattern,
user_agent=user_agent())
else:
model_dir = os.path.join(
snapshot_download(
raise ImportError(self.error_msg)
def from_config(self, cls, config):
raise ImportError(self.error_msg)
def user_agent(invoked_by=None):
    """Build the user-agent string attached to ModelScope hub requests.

    Args:
        invoked_by: Name of the invoking entry point; defaults to
            ``Invoke.PRETRAINED`` when omitted.

    Returns:
        A string of the form ``'<Invoke.KEY>/<invoked_by>'``.
    """
    from modelscope.utils.constant import Invoke
    caller = Invoke.PRETRAINED if invoked_by is None else invoked_by
    return f'{Invoke.KEY}/{caller}'
def get_wrapped_class(module_class,
ignore_file_pattern=[],
file_filter=None,
**kwargs):
"""Get a custom wrapper class for auto classes to download the models from the ModelScope hub
Args:
module_class: The actual module class
ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be ignored in downloading, like exact file names or file extensions.
Returns:
The wrapper
"""
default_ignore_file_pattern = ignore_file_pattern
default_file_filter = file_filter
class ClassWrapper(module_class):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path,
*model_args, **kwargs):
from modelscope import snapshot_download
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
ignore_file_pattern = kwargs.pop('ignore_file_pattern',
default_ignore_file_pattern)
subfolder = kwargs.pop('subfolder', default_file_filter)
file_filter = None
if subfolder:
file_filter = f'{subfolder}/*'
if not os.path.exists(pretrained_model_name_or_path):
revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION)
if file_filter is None:
model_dir = snapshot_download(
pretrained_model_name_or_path,
revision=revision,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=file_filter,
user_agent=user_agent()), subfolder)
else:
model_dir = pretrained_model_name_or_path
user_agent=user_agent())
else:
model_dir = os.path.join(
snapshot_download(
pretrained_model_name_or_path,
revision=revision,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=file_filter,
user_agent=user_agent()), subfolder)
else:
model_dir = pretrained_model_name_or_path
module_obj = module_class.from_pretrained(model_dir, *model_args,
**kwargs)
module_obj = module_class.from_pretrained(
model_dir, *model_args, **kwargs)
if module_class.__name__.startswith('AutoModel'):
module_obj.model_dir = model_dir
return module_obj
if module_class.__name__.startswith('AutoModel'):
module_obj.model_dir = model_dir
return module_obj
ClassWrapper.__name__ = module_class.__name__
ClassWrapper.__qualname__ = module_class.__qualname__
return ClassWrapper
ClassWrapper.__name__ = module_class.__name__
ClassWrapper.__qualname__ = module_class.__qualname__
return ClassWrapper
from .patcher import get_all_imported_modules

# Wrap every imported transformers/diffusers/peft class that exposes the
# standard `from_pretrained(pretrained_model_name_or_path, ...)` entry point
# so model ids are resolved through the ModelScope hub.
all_imported_modules = get_all_imported_modules()
all_available_modules = []
# Classes whose names contain these substrings (configs, tokenizers)
# presumably never need weight files — large binaries are excluded from
# their downloads.
large_file_free = ['config', 'tokenizer']
for module in all_imported_modules:
    try:
        if (hasattr(module, 'from_pretrained')
                and 'pretrained_model_name_or_path' in inspect.signature(
                    module.from_pretrained).parameters):
            if any(lf in module.__name__.lower()
                   for lf in large_file_free):
                # Skip checkpoint/weight files for config/tokenizer classes.
                ignore_file_patterns = {
                    'ignore_file_pattern': [
                        r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth',
                        r'\w+\.pt', r'\w+\.h5'
                    ]
                }
            else:
                ignore_file_patterns = {}
            all_available_modules.append(
                get_wrapped_class(module, **ignore_file_patterns))
    except (ImportError, AttributeError):
        # Lazily-loaded classes may fail on attribute access or signature
        # inspection; such classes are simply left unwrapped.
        pass
AutoModel = get_wrapped_class(AutoModelHF)
AutoModelForCausalLM = get_wrapped_class(AutoModelForCausalLMHF)
AutoModelForSeq2SeqLM = get_wrapped_class(AutoModelForSeq2SeqLMHF)
AutoModelForVision2Seq = get_wrapped_class(AutoModelForVision2SeqHF)
AutoModelForSequenceClassification = get_wrapped_class(
AutoModelForSequenceClassificationHF)
AutoModelForTokenClassification = get_wrapped_class(
AutoModelForTokenClassificationHF)
AutoModelForImageSegmentation = get_wrapped_class(
AutoModelForImageSegmentationHF)
AutoModelForImageClassification = get_wrapped_class(
AutoModelForImageClassificationHF)
AutoModelForZeroShotImageClassification = get_wrapped_class(
AutoModelForZeroShotImageClassificationHF)
AutoModelForQuestionAnswering = get_wrapped_class(
AutoModelForQuestionAnsweringHF)
AutoModelForTableQuestionAnswering = get_wrapped_class(
AutoModelForTableQuestionAnsweringHF)
AutoModelForVisualQuestionAnswering = get_wrapped_class(
AutoModelForVisualQuestionAnsweringHF)
AutoModelForDocumentQuestionAnswering = get_wrapped_class(
AutoModelForDocumentQuestionAnsweringHF)
AutoModelForSemanticSegmentation = get_wrapped_class(
AutoModelForSemanticSegmentationHF)
AutoModelForUniversalSegmentation = get_wrapped_class(
AutoModelForUniversalSegmentationHF)
AutoModelForInstanceSegmentation = get_wrapped_class(
AutoModelForInstanceSegmentationHF)
AutoModelForObjectDetection = get_wrapped_class(AutoModelForObjectDetectionHF)
AutoModelForZeroShotObjectDetection = get_wrapped_class(
AutoModelForZeroShotObjectDetectionHF)
AutoModelForAudioClassification = get_wrapped_class(
AutoModelForAudioClassificationHF)
AutoModelForSpeechSeq2Seq = get_wrapped_class(AutoModelForSpeechSeq2SeqHF)
AutoModelForMaskedImageModeling = get_wrapped_class(
AutoModelForMaskedImageModelingHF)
AutoModelForMaskedLM = get_wrapped_class(AutoModelForMaskedLMHF)
AutoModelForMaskGeneration = get_wrapped_class(AutoModelForMaskGenerationHF)
AutoModelForPreTraining = get_wrapped_class(AutoModelForPreTrainingHF)
AutoModelForTextEncoding = get_wrapped_class(AutoModelForTextEncodingHF)
T5EncoderModel = get_wrapped_class(T5EncoderModelHF)
AutoTokenizer = get_wrapped_class(
AutoTokenizerHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
])
AutoProcessor = get_wrapped_class(
AutoProcessorHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
])
AutoConfig = get_wrapped_class(
AutoConfigHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
])
GenerationConfig = get_wrapped_class(
GenerationConfigHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
])
BitsAndBytesConfig = get_wrapped_class(
BitsAndBytesConfigHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
])
AutoImageProcessor = get_wrapped_class(
AutoImageProcessorHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
])
GPTQConfig = GPTQConfigHF
AwqConfig = AwqConfigHF
BatchFeature = get_wrapped_class(BatchFeatureHF)
# Publish every wrapped class at module level under its original
# transformers/diffusers/peft name so callers can import it from here.
for module in all_available_modules:
    globals()[module.__name__] = module

View File

@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import contextlib
import importlib
import inspect
import os
@@ -8,27 +9,59 @@ from pathlib import Path
from types import MethodType
from typing import BinaryIO, Dict, List, Optional, Union
from huggingface_hub.hf_api import CommitInfo, future_compatible
from modelscope import snapshot_download
from modelscope.utils.logger import get_logger
def get_all_imported_modules():
all_imported_modules = []
if importlib.util.find_spec('transformers') is not None:
import transformers
extra_modules = ['T5']
lazy_module = sys.modules['transformers']
_import_structure = lazy_module._import_structure
for key in _import_structure:
values = _import_structure[key]
for value in values:
# pretrained
if 'auto' in value.lower() or any(m in value
for m in extra_modules):
try:
module = importlib.import_module(
f'.{key}', transformers.__name__)
value = getattr(module, value)
all_imported_modules.append(value)
except (ImportError, AttributeError):
pass
logger = get_logger()
if importlib.util.find_spec('peft') is not None:
import peft
attributes = dir(peft)
imports = [attr for attr in attributes if not attr.startswith('__')]
all_imported_modules.extend(
[getattr(peft, _import) for _import in imports])
if importlib.util.find_spec('diffusers') is not None:
import diffusers
if importlib.util.find_spec('diffusers') is not None:
lazy_module = sys.modules['diffusers']
_import_structure = lazy_module._import_structure
for key in _import_structure:
values = _import_structure[key]
for value in values:
if 'pipeline' in value.lower():
try:
module = importlib.import_module(
f'.{key}', diffusers.__name__)
value = getattr(module, value)
all_imported_modules.append(value)
except (ImportError, AttributeError):
pass
return all_imported_modules
extra_modules = ['T5']
lazy_module = sys.modules['transformers']
all_modules = lazy_module._modules
all_imported_modules = []
for module in all_modules:
if 'auto' in module.lower() or any(m in module for m in extra_modules):
all_imported_modules.append(importlib.import_module(f'transformers.{module}'))
def _patch_pretrained_class():
def _patch_pretrained_class(all_imported_modules):
def get_model_dir(pretrained_model_name_or_path, ignore_file_pattern,
**kwargs):
from modelscope import snapshot_download
if not os.path.exists(pretrained_model_name_or_path):
revision = kwargs.pop('revision', None)
model_dir = snapshot_download(
@@ -43,94 +76,109 @@ def _patch_pretrained_class():
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
]
def patch_pretrained_model_name_or_path(cls, pretrained_model_name_or_path,
def patch_pretrained_model_name_or_path(pretrained_model_name_or_path,
*model_args, **kwargs):
model_dir = get_model_dir(pretrained_model_name_or_path,
kwargs.pop('ignore_file_pattern', None),
**kwargs)
return kwargs.pop('ori_func')(cls, model_dir, *model_args, **kwargs)
return kwargs.pop('ori_func')(model_dir, *model_args, **kwargs)
def patch_peft_model_id(cls, model, model_id, *model_args, **kwargs):
def patch_peft_model_id(model, model_id, *model_args, **kwargs):
model_dir = get_model_dir(model_id,
kwargs.pop('ignore_file_pattern', None),
**kwargs)
return kwargs.pop('ori_func')(cls, model, model_dir, *model_args,
**kwargs)
return kwargs.pop('ori_func')(model, model_dir, *model_args, **kwargs)
def _get_peft_type(cls, model_id, **kwargs):
def _get_peft_type(model_id, **kwargs):
model_dir = get_model_dir(model_id, ignore_file_pattern, **kwargs)
return kwargs.pop('ori_func')(cls, model_dir, **kwargs)
return kwargs.pop('ori_func')(model_dir, **kwargs)
for var in all_imported_modules:
if var is None:
if var is None or not hasattr(var, '__name__'):
continue
name = var.__name__
need_model = 'model' in name.lower() or 'processor' in name.lower() or 'extractor' in name.lower()
need_model = 'model' in name.lower() or 'processor' in name.lower(
) or 'extractor' in name.lower()
if need_model:
ignore_file_pattern_kwargs = {}
else:
ignore_file_pattern_kwargs = {'ignore_file_pattern': ignore_file_pattern}
ignore_file_pattern_kwargs = {
'ignore_file_pattern': ignore_file_pattern
}
has_from_pretrained = hasattr(var, 'from_pretrained')
has_get_peft_type = hasattr(var, '_get_peft_type')
has_get_config_dict = hasattr(var, 'get_config_dict')
parameters = inspect.signature(var.from_pretrained).parameters
is_peft = 'model' in parameters and 'model_id' in parameters
try:
has_from_pretrained = hasattr(var, 'from_pretrained')
has_get_peft_type = hasattr(var, '_get_peft_type')
has_get_config_dict = hasattr(var, 'get_config_dict')
except ImportError:
continue
if has_from_pretrained and not hasattr(var, '_from_pretrained_origin'):
parameters = inspect.signature(var.from_pretrained).parameters
is_peft = 'model' in parameters and 'model_id' in parameters
var._from_pretrained_origin = var.from_pretrained
if not is_peft:
var.from_pretrained = partial(patch_pretrained_model_name_or_path,
ori_func=var._from_pretrained_origin,
**ignore_file_pattern_kwargs)
var.from_pretrained = partial(
patch_pretrained_model_name_or_path,
ori_func=var._from_pretrained_origin,
**ignore_file_pattern_kwargs)
else:
var.from_pretrained = partial(patch_peft_model_id,
ori_func=var._from_pretrained_origin,
**ignore_file_pattern_kwargs)
delattr(var, '_from_pretrained_origin')
var.from_pretrained = partial(
patch_peft_model_id,
ori_func=var._from_pretrained_origin,
**ignore_file_pattern_kwargs)
if has_get_peft_type and not hasattr(var, '_get_peft_type_origin'):
var._get_peft_type_origin = var._get_peft_type
var._get_peft_type = partial(_get_peft_type,
ori_func=var._get_peft_type_origin,
**ignore_file_pattern_kwargs)
delattr(var, '_get_peft_type_origin')
var._get_peft_type = partial(
_get_peft_type,
ori_func=var._get_peft_type_origin,
**ignore_file_pattern_kwargs)
if has_get_config_dict and not hasattr(var, '_get_config_dict_origin'):
var._get_config_dict_origin = var.get_config_dict
var.get_config_dict = partial(patch_pretrained_model_name_or_path,
ori_func=var._get_config_dict_origin,
**ignore_file_pattern_kwargs)
delattr(var, '_get_config_dict_origin')
var.get_config_dict = partial(
patch_pretrained_model_name_or_path,
ori_func=var._get_config_dict_origin,
**ignore_file_pattern_kwargs)
def _unpatch_pretrained_class():
def _unpatch_pretrained_class(all_imported_modules):
for var in all_imported_modules:
if var is None:
continue
has_from_pretrained = hasattr(var, 'from_pretrained')
has_get_peft_type = hasattr(var, '_get_peft_type')
has_get_config_dict = hasattr(var, 'get_config_dict')
try:
has_from_pretrained = hasattr(var, 'from_pretrained')
has_get_peft_type = hasattr(var, '_get_peft_type')
has_get_config_dict = hasattr(var, 'get_config_dict')
except ImportError:
continue
if has_from_pretrained and hasattr(var, '_from_pretrained_origin'):
var.from_pretrained = var._from_pretrained_origin
delattr(var, '_from_pretrained_origin')
if has_get_peft_type and hasattr(var, '_get_peft_type_origin'):
var._get_peft_type = var._get_peft_type_origin
delattr(var, '_get_peft_type_origin')
if has_get_config_dict and hasattr(var, '_get_config_dict_origin'):
var.get_config_dict = var._get_config_dict_origin
delattr(var, '_get_config_dict_origin')
def _patch_hub():
import huggingface_hub
from huggingface_hub import hf_api
from huggingface_hub.hf_api import api
from huggingface_hub.hf_api import CommitInfo, future_compatible
from modelscope import get_logger
logger = get_logger()
def _file_exists(
self,
repo_id: str,
filename: str,
*,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
token: Union[str, bool, None] = None,
self,
repo_id: str,
filename: str,
*,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
token: Union[str, bool, None] = None,
):
"""Patch huggingface_hub.file_exists"""
if repo_type is not None:
@@ -171,7 +219,8 @@ def _patch_hub():
api.try_login(token)
return file_download(
repo_id,
file_path=os.path.join(subfolder, filename) if subfolder else filename,
file_path=os.path.join(subfolder, filename)
if subfolder else filename,
cache_dir=cache_dir,
local_dir=local_dir,
local_files_only=local_files_only,
@@ -209,16 +258,16 @@ def _patch_hub():
@future_compatible
def upload_folder(
*,
repo_id: str,
folder_path: Union[str, Path],
path_in_repo: Optional[str] = None,
commit_message: Optional[str] = None,
commit_description: Optional[str] = None,
token: Union[str, bool, None] = None,
revision: Optional[str] = 'master',
ignore_patterns: Optional[Union[List[str], str]] = None,
**kwargs,
*,
repo_id: str,
folder_path: Union[str, Path],
path_in_repo: Optional[str] = None,
commit_message: Optional[str] = None,
commit_description: Optional[str] = None,
token: Union[str, bool, None] = None,
revision: Optional[str] = 'master',
ignore_patterns: Optional[Union[List[str], str]] = None,
**kwargs,
):
from modelscope.hub.push_to_hub import push_model_to_hub
push_model_to_hub(repo_id, folder_path, path_in_repo, commit_message,
@@ -233,16 +282,16 @@ def _patch_hub():
@future_compatible
def upload_file(
self,
*,
path_or_fileobj: Union[str, Path, bytes, BinaryIO],
path_in_repo: str,
repo_id: str,
token: Union[str, bool, None] = None,
revision: Optional[str] = None,
commit_message: Optional[str] = None,
commit_description: Optional[str] = None,
**kwargs,
self,
*,
path_or_fileobj: Union[str, Path, bytes, BinaryIO],
path_in_repo: str,
repo_id: str,
token: Union[str, bool, None] = None,
revision: Optional[str] = None,
commit_message: Optional[str] = None,
commit_description: Optional[str] = None,
**kwargs,
):
from modelscope.hub.push_to_hub import push_files_to_hub
push_files_to_hub(path_or_fileobj, path_in_repo, repo_id, token,
@@ -347,10 +396,19 @@ def _unpatch_hub():
repocard.upload_file = hf_api.upload_file
delattr(hf_api, '_upload_file_origin')
def patch_hub():
_patch_hub()
_patch_pretrained_class()
_patch_pretrained_class(get_all_imported_modules())
def unpatch_hub():
_unpatch_pretrained_class()
_unpatch_pretrained_class(get_all_imported_modules())
_unpatch_hub()
@contextlib.contextmanager
def patch_context():
    """Context manager that applies the ModelScope hub/auto-class patches.

    Calls ``patch_hub()`` on entry and guarantees ``unpatch_hub()`` runs on
    exit — including when the managed block raises — so the global patches
    to the hub and pretrained classes never leak past the ``with`` statement.
    """
    patch_hub()
    try:
        yield
    finally:
        # Without try/finally, an exception inside the `with` body would
        # leave the hub permanently patched.
        unpatch_hub()

View File

@@ -282,6 +282,10 @@ def is_transformers_available():
return importlib.util.find_spec('transformers') is not None
def is_diffusers_available():
    """Return True when the ``diffusers`` package is importable."""
    spec = importlib.util.find_spec('diffusers')
    return spec is not None
def is_tensorrt_llm_available():
    """Return True when the ``tensorrt_llm`` package is importable."""
    spec = importlib.util.find_spec('tensorrt_llm')
    return spec is not None

View File

@@ -14,6 +14,3 @@ class DownloadDatasetTest(unittest.TestCase):
from transformers import AutoModel
model = AutoModel.from_pretrained('AI-ModelScope/bert-base-uncased')
self.assertTrue(model is not None)

View File

@@ -2,8 +2,7 @@
import unittest
from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
AutoTokenizer, GenerationConfig)
from modelscope.utils.hf_util.patcher import patch_context
class HFUtilTest(unittest.TestCase):
@@ -15,6 +14,7 @@ class HFUtilTest(unittest.TestCase):
pass
def test_auto_tokenizer(self):
from modelscope import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
'baichuan-inc/Baichuan2-7B-Chat',
trust_remote_code=True,
@@ -28,11 +28,13 @@ class HFUtilTest(unittest.TestCase):
self.assertTrue(BitsAndBytesConfig is not None)
def test_auto_model(self):
from modelscope import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
'baichuan-inc/baichuan-7B', trust_remote_code=True)
self.assertTrue(model is not None)
def test_auto_config(self):
from modelscope import AutoConfig, GenerationConfig
config = AutoConfig.from_pretrained(
'baichuan-inc/Baichuan-13B-Chat',
trust_remote_code=True,
@@ -45,12 +47,63 @@ class HFUtilTest(unittest.TestCase):
self.assertEqual(gen_config.assistant_token_id, 196)
def test_transformer_patch(self):
tokenizer = AutoTokenizer.from_pretrained(
'iic/nlp_structbert_sentiment-classification_chinese-base')
self.assertIsNotNone(tokenizer)
model = AutoModelForCausalLM.from_pretrained(
'iic/nlp_structbert_sentiment-classification_chinese-base')
self.assertIsNotNone(model)
with patch_context():
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained(
'iic/nlp_structbert_sentiment-classification_chinese-base')
self.assertIsNotNone(tokenizer)
model = AutoModelForCausalLM.from_pretrained(
'iic/nlp_structbert_sentiment-classification_chinese-base')
self.assertIsNotNone(model)
def test_patch_model(self):
    """AutoModel resolves ModelScope ids only while the patch is active."""
    from modelscope.utils.hf_util.patcher import patch_context
    repo = 'iic/nlp_structbert_sentiment-classification_chinese-tiny'
    with patch_context():
        from transformers import AutoModel
        model = AutoModel.from_pretrained(repo)
        self.assertTrue(model is not None)
    # Outside the context, the raw transformers class must fail to resolve
    # a ModelScope-only repo id.
    try:
        model = AutoModel.from_pretrained(repo)
    except Exception:
        pass
    else:
        self.assertTrue(False)
def test_patch_config(self):
    """AutoConfig resolves ModelScope ids only while the patch is active."""
    repo = 'iic/nlp_structbert_sentiment-classification_chinese-tiny'
    with patch_context():
        from transformers import AutoConfig
        config = AutoConfig.from_pretrained(repo)
        self.assertTrue(config is not None)
    # After unpatching, the same lookup must raise.
    try:
        config = AutoConfig.from_pretrained(repo)
    except Exception:
        pass
    else:
        self.assertTrue(False)
def test_patch_diffusers(self):
    """Diffusers pipelines resolve ModelScope ids only inside the patch."""
    repo = 'AI-ModelScope/stable-diffusion-v1-5'
    with patch_context():
        from diffusers import StableDiffusionPipeline
        pipe = StableDiffusionPipeline.from_pretrained(repo)
        self.assertTrue(pipe is not None)
    # After unpatching, the same lookup must raise.
    try:
        pipe = StableDiffusionPipeline.from_pretrained(repo)
    except Exception:
        pass
    else:
        self.assertTrue(False)
def test_patch_peft(self):
    """patch_context marks PeftModel.from_pretrained and restores it on exit."""
    marker = '_from_pretrained_origin'
    with patch_context():
        from peft import PeftModel
        self.assertTrue(hasattr(PeftModel, marker))
    # The origin marker must be removed once the context unpatches.
    self.assertFalse(hasattr(PeftModel, marker))
if __name__ == '__main__':