This commit is contained in:
yuze.zyz
2025-01-26 16:26:37 +08:00
parent 2eda829499
commit 4723e5c0ff
6 changed files with 306 additions and 205 deletions

View File

@@ -31,6 +31,7 @@ if TYPE_CHECKING:
from .trainers import (EpochBasedTrainer, Hook, Priority, TrainingArgs, from .trainers import (EpochBasedTrainer, Hook, Priority, TrainingArgs,
build_dataset_from_file) build_dataset_from_file)
from .utils.constant import Tasks from .utils.constant import Tasks
from .utils.hf_util import patch_hub, patch_context, unpatch_hub
if is_transformers_available(): if is_transformers_available():
from .utils.hf_util import ( from .utils.hf_util import (
AutoModel, AutoProcessor, AutoFeatureExtractor, GenerationConfig, AutoModel, AutoProcessor, AutoFeatureExtractor, GenerationConfig,
@@ -106,34 +107,6 @@ else:
'msdatasets': ['MsDataset'] 'msdatasets': ['MsDataset']
} }
if is_transformers_available():
_import_structure['utils.hf_util'] = [
'AutoModel', 'AutoProcessor', 'AutoFeatureExtractor',
'GenerationConfig', 'AutoConfig', 'GPTQConfig', 'AwqConfig',
'BitsAndBytesConfig', 'AutoModelForCausalLM',
'AutoModelForSeq2SeqLM', 'AutoModelForVision2Seq',
'AutoModelForSequenceClassification',
'AutoModelForTokenClassification',
'AutoModelForImageClassification', 'AutoModelForImageToImage',
'AutoModelForImageTextToText',
'AutoModelForZeroShotImageClassification',
'AutoModelForKeypointDetection',
'AutoModelForDocumentQuestionAnswering',
'AutoModelForSemanticSegmentation',
'AutoModelForUniversalSegmentation',
'AutoModelForInstanceSegmentation', 'AutoModelForObjectDetection',
'AutoModelForZeroShotObjectDetection',
'AutoModelForAudioClassification', 'AutoModelForSpeechSeq2Seq',
'AutoModelForMaskedImageModeling',
'AutoModelForVisualQuestionAnswering',
'AutoModelForTableQuestionAnswering',
'AutoModelForImageSegmentation', 'AutoModelForQuestionAnswering',
'AutoModelForMaskedLM', 'AutoTokenizer',
'AutoModelForMaskGeneration', 'AutoModelForPreTraining',
'AutoModelForTextEncoding', 'AutoImageProcessor', 'BatchFeature',
'Qwen2VLForConditionalGeneration', 'T5EncoderModel'
]
from modelscope.utils import hf_util from modelscope.utils import hf_util
extra_objects = {} extra_objects = {}

View File

@@ -29,7 +29,7 @@ def push_files_to_hub(
path_in_repo: str, path_in_repo: str,
repo_id: str, repo_id: str,
token: Union[str, bool, None] = None, token: Union[str, bool, None] = None,
revision: Optional[str] = None, revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
commit_message: Optional[str] = None, commit_message: Optional[str] = None,
commit_description: Optional[str] = None, commit_description: Optional[str] = None,
): ):
@@ -49,56 +49,12 @@ def push_files_to_hub(
sub_folder = os.path.join(temp_cache_dir, path_in_repo) sub_folder = os.path.join(temp_cache_dir, path_in_repo)
os.makedirs(sub_folder, exist_ok=True) os.makedirs(sub_folder, exist_ok=True)
if os.path.isfile(path_or_fileobj): if os.path.isfile(path_or_fileobj):
shutil.copyfile(path_or_fileobj, sub_folder) dest_file = os.path.join(sub_folder,
os.path.basename(path_or_fileobj))
shutil.copyfile(path_or_fileobj, dest_file)
else: else:
shutil.copytree(path_or_fileobj, sub_folder, dirs_exist_ok=True) shutil.copytree(path_or_fileobj, sub_folder, dirs_exist_ok=True)
repo.push(commit_message) repo.push(commit_message)
def push_model_to_hub(repo_id: str,
                      folder_path: Union[str, Path],
                      path_in_repo: Optional[str] = None,
                      commit_message: Optional[str] = None,
                      commit_description: Optional[str] = None,
                      token: Union[str, bool, None] = None,
                      private: bool = False,
                      revision: Optional[str] = 'master',
                      ignore_patterns: Optional[Union[List[str], str]] = None,
                      **kwargs):
    """Call `create_model_repo` for `repo_id`, then upload a folder with `push_to_hub`.

    Args:
        repo_id: Id of the target repository.
        folder_path: Local folder to upload. A default `configuration.json`
            is written into it when the folder does not contain one.
        path_in_repo: When truthy, the parent directory of `folder_path` is
            pushed instead and the folder's base name is forwarded to
            `push_to_hub` as `tag`; `ignore_patterns` is reset to `[]`.
        commit_message: Commit message, defaults to 'Upload folder using api'.
        commit_description: Optional text appended to the commit message on
            a new line.
        token: Forwarded to `create_model_repo` and `push_to_hub`.
        private: Forwarded to `create_model_repo` and `push_to_hub`.
        revision: Target revision; `None` and 'main' are mapped to 'master'.
        ignore_patterns: A single pattern or a list of patterns. The '_*'
            pattern is dropped before forwarding.
        **kwargs: Only `config_json` is read; its entries override the
            default `configuration.json` content.
    """
    from modelscope.hub.create_model import create_model_repo
    create_model_repo(repo_id, token, private)

    from modelscope import push_to_hub
    commit_message = commit_message or 'Upload folder using api'
    if commit_description:
        commit_message = commit_message + '\n' + commit_description

    config_path = os.path.join(folder_path, 'configuration.json')
    if not os.path.exists(config_path):
        default_config = {
            'framework': 'pytorch',
            'task': 'text-generation',
            'allow_remote': True
        }
        config_json = kwargs.get('config_json') or {}
        config = {**default_config, **config_json}
        with open(config_path, 'w') as f:
            f.write(json.dumps(config))

    # The annotation allows a bare string; without this normalisation the
    # comprehension below would iterate over its characters.
    if isinstance(ignore_patterns, str):
        ignore_patterns = [ignore_patterns]
    if ignore_patterns:
        ignore_patterns = [p for p in ignore_patterns if p != '_*']

    if path_in_repo:
        # We don't support part submit for now
        path_in_repo = os.path.basename(folder_path)
        folder_path = os.path.dirname(folder_path)
        ignore_patterns = []

    if revision is None or revision == 'main':
        revision = 'master'

    push_to_hub(
        repo_id,
        folder_path,
        token,
        private,
        commit_message=commit_message,
        ignore_patterns=ignore_patterns,
        revision=revision,
        tag=path_in_repo)
def _api_push_to_hub(repo_name, def _api_push_to_hub(repo_name,

View File

@@ -5,6 +5,44 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import __version__ as transformers_version from transformers import __version__ as transformers_version
from transformers import AutoConfig
from transformers import AutoFeatureExtractor
from transformers import AutoImageProcessor
from transformers import AutoModel
from transformers import AutoModelForAudioClassification
from transformers import AutoModelForCausalLM
from transformers import AutoModelForDocumentQuestionAnswering
from transformers import AutoModelForImageClassification
from transformers import AutoModelForImageSegmentation
from transformers import AutoModelForInstanceSegmentation
from transformers import AutoModelForMaskedImageModeling
from transformers import AutoModelForMaskedLM
from transformers import AutoModelForMaskGeneration
from transformers import AutoModelForObjectDetection
from transformers import AutoModelForPreTraining
from transformers import AutoModelForQuestionAnswering
from transformers import AutoModelForSemanticSegmentation
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoModelForSequenceClassification
from transformers import AutoModelForSpeechSeq2Seq
from transformers import AutoModelForTableQuestionAnswering
from transformers import AutoModelForTextEncoding
from transformers import AutoModelForTokenClassification
from transformers import AutoModelForUniversalSegmentation
from transformers import AutoModelForVision2Seq
from transformers import AutoModelForVisualQuestionAnswering
from transformers import AutoModelForZeroShotImageClassification
from transformers import AutoModelForZeroShotObjectDetection
from transformers import AutoProcessor
from transformers import AutoTokenizer
from transformers import BatchFeature
from transformers import BitsAndBytesConfig
from transformers import GenerationConfig
from transformers import (PretrainedConfig, PreTrainedModel,
PreTrainedTokenizerBase)
from transformers import T5EncoderModel
try: try:
from transformers import Qwen2VLForConditionalGeneration from transformers import Qwen2VLForConditionalGeneration
except ImportError: except ImportError:
@@ -55,89 +93,10 @@ else:
uagent = '%s/%s' % (Invoke.KEY, invoked_by) uagent = '%s/%s' % (Invoke.KEY, invoked_by)
return uagent return uagent
def get_wrapped_class(module_class, from .patcher import get_all_imported_modules, _patch_pretrained_class
ignore_file_pattern=[],
file_filter=None,
**kwargs):
"""Get a custom wrapper class for auto classes to download the models from the ModelScope hub
Args:
module_class: The actual module class
ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be ignored in downloading, like exact file names or file extensions.
Returns:
The wrapper
"""
default_ignore_file_pattern = ignore_file_pattern
default_file_filter = file_filter
class ClassWrapper(module_class):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path,
*model_args, **kwargs):
from modelscope import snapshot_download
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
ignore_file_pattern = kwargs.pop('ignore_file_pattern',
default_ignore_file_pattern)
subfolder = kwargs.pop('subfolder', default_file_filter)
file_filter = None
if subfolder:
file_filter = f'{subfolder}/*'
if not os.path.exists(pretrained_model_name_or_path):
revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION)
if file_filter is None:
model_dir = snapshot_download(
pretrained_model_name_or_path,
revision=revision,
ignore_file_pattern=ignore_file_pattern,
user_agent=user_agent())
else:
model_dir = os.path.join(
snapshot_download(
pretrained_model_name_or_path,
revision=revision,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=file_filter,
user_agent=user_agent()), subfolder)
else:
model_dir = pretrained_model_name_or_path
module_obj = module_class.from_pretrained(
model_dir, *model_args, **kwargs)
if module_class.__name__.startswith('AutoModel'):
module_obj.model_dir = model_dir
return module_obj
ClassWrapper.__name__ = module_class.__name__
ClassWrapper.__qualname__ = module_class.__qualname__
return ClassWrapper
from .patcher import get_all_imported_modules
all_imported_modules = get_all_imported_modules() all_imported_modules = get_all_imported_modules()
all_available_modules = [] all_available_modules = _patch_pretrained_class(all_imported_modules, wrap=True)
large_file_free = ['config', 'tokenizer']
for module in all_imported_modules:
try:
if (hasattr(module, 'from_pretrained')
and 'pretrained_model_name_or_path' in inspect.signature(
module.from_pretrained).parameters):
if any(lf in module.__name__.lower()
for lf in large_file_free):
ignore_file_patterns = {
'ignore_file_pattern': [
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth',
r'\w+\.pt', r'\w+\.h5'
]
}
else:
ignore_file_patterns = {}
all_available_modules.append(
get_wrapped_class(module, **ignore_file_patterns))
except (ImportError, AttributeError):
pass
for module in all_available_modules: for module in all_available_modules:
globals()[module.__name__] = module globals()[module.__name__] = module

View File

@@ -11,18 +11,20 @@ from typing import BinaryIO, Dict, List, Optional, Union
def get_all_imported_modules(): def get_all_imported_modules():
"""Find all modules in transformers/peft/diffusers"""
all_imported_modules = [] all_imported_modules = []
transformers_include_names = ['Auto', 'T5', 'BitsAndBytes', 'GenerationConfig',
'Quant', 'Awq', 'GPTQ', 'BatchFeature', 'Qwen2']
diffusers_include_names = ['Pipeline']
if importlib.util.find_spec('transformers') is not None: if importlib.util.find_spec('transformers') is not None:
import transformers import transformers
extra_modules = ['T5']
lazy_module = sys.modules['transformers'] lazy_module = sys.modules['transformers']
_import_structure = lazy_module._import_structure _import_structure = lazy_module._import_structure
for key in _import_structure: for key in _import_structure:
values = _import_structure[key] values = _import_structure[key]
for value in values: for value in values:
# pretrained # pretrained
if 'auto' in value.lower() or any(m in value if any([name in value for name in transformers_include_names]):
for m in extra_modules):
try: try:
module = importlib.import_module( module = importlib.import_module(
f'.{key}', transformers.__name__) f'.{key}', transformers.__name__)
@@ -46,7 +48,7 @@ def get_all_imported_modules():
for key in _import_structure: for key in _import_structure:
values = _import_structure[key] values = _import_structure[key]
for value in values: for value in values:
if 'pipeline' in value.lower(): if any([name in value for name in diffusers_include_names]):
try: try:
module = importlib.import_module( module = importlib.import_module(
f'.{key}', diffusers.__name__) f'.{key}', diffusers.__name__)
@@ -57,9 +59,11 @@ def get_all_imported_modules():
return all_imported_modules return all_imported_modules
def _patch_pretrained_class(all_imported_modules): def _patch_pretrained_class(all_imported_modules, wrap=False):
def get_model_dir(pretrained_model_name_or_path, ignore_file_pattern, def get_model_dir(pretrained_model_name_or_path,
ignore_file_pattern=None,
allow_file_pattern=None,
**kwargs): **kwargs):
from modelscope import snapshot_download from modelscope import snapshot_download
if not os.path.exists(pretrained_model_name_or_path): if not os.path.exists(pretrained_model_name_or_path):
@@ -67,14 +71,13 @@ def _patch_pretrained_class(all_imported_modules):
model_dir = snapshot_download( model_dir = snapshot_download(
pretrained_model_name_or_path, pretrained_model_name_or_path,
revision=revision, revision=revision,
ignore_file_pattern=ignore_file_pattern) ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern)
else: else:
model_dir = pretrained_model_name_or_path model_dir = pretrained_model_name_or_path
return model_dir return model_dir
ignore_file_pattern = [ ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5']
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
]
def patch_pretrained_model_name_or_path(pretrained_model_name_or_path, def patch_pretrained_model_name_or_path(pretrained_model_name_or_path,
*model_args, **kwargs): *model_args, **kwargs):
@@ -93,6 +96,88 @@ def _patch_pretrained_class(all_imported_modules):
model_dir = get_model_dir(model_id, ignore_file_pattern, **kwargs) model_dir = get_model_dir(model_id, ignore_file_pattern, **kwargs)
return kwargs.pop('ori_func')(model_dir, **kwargs) return kwargs.pop('ori_func')(model_dir, **kwargs)
def get_wrapped_class(module_class: 'PreTrainedModel',
                      ignore_file_pattern: Optional[Union[str, List[str]]] = None,
                      allow_file_pattern: Optional[Union[str, List[str]]] = None,
                      **kwargs):
    """Get a custom wrapper class for auto classes to download the models from the ModelScope hub

    Args:
        module_class (`PreTrainedModel`): The actual module class
        ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
            Any file pattern to be ignored, like exact file names or file extensions.
            Used as the default for `from_pretrained`; a per-call
            `ignore_file_pattern` keyword overrides it.
        allow_file_pattern (`str` or `List`, *optional*, default to `None`):
            Any file pattern to be included, like exact file names or file extensions.
            Used as the default for `from_pretrained`; a per-call
            `allow_file_pattern` keyword overrides it.

    Returns:
        The wrapper: a subclass of `module_class` with the same `__name__`
        and `__qualname__`, whose `from_pretrained`, `_get_peft_type` and
        `get_config_dict` (each only when `module_class` defines it) resolve
        the model id through `get_model_dir` before delegating.
    """
    default_ignore_file_pattern = ignore_file_pattern
    default_allow_file_pattern = allow_file_pattern

    def _resolve_model_dir(model_id, call_kwargs):
        # Pop per-call overrides so they are neither passed twice to
        # get_model_dir nor forwarded to the wrapped from_pretrained.
        ignore = call_kwargs.pop('ignore_file_pattern',
                                 default_ignore_file_pattern)
        allow = call_kwargs.pop('allow_file_pattern',
                                default_allow_file_pattern)
        return get_model_dir(
            model_id,
            ignore_file_pattern=ignore,
            allow_file_pattern=allow,
            **call_kwargs)

    def from_pretrained(model, model_id, *model_args, **kwargs):
        # peft-style signature: (model, model_id, ...)
        model_dir = _resolve_model_dir(model_id, kwargs)
        module_obj = module_class.from_pretrained(model, model_dir,
                                                  *model_args, **kwargs)
        return module_obj

    class ClassWrapper(module_class):

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path,
                            *model_args, **kwargs):
            model_dir = _resolve_model_dir(pretrained_model_name_or_path,
                                           kwargs)
            module_obj = module_class.from_pretrained(
                model_dir, *model_args, **kwargs)
            if module_class.__name__.startswith('AutoModel'):
                # expose where the files were resolved from
                module_obj.model_dir = model_dir
            return module_obj

        @classmethod
        def _get_peft_type(cls, model_id, **kwargs):
            model_dir = get_model_dir(model_id,
                                      kwargs.pop('ignore_file_pattern', None),
                                      **kwargs)
            module_obj = module_class._get_peft_type(model_dir, **kwargs)
            return module_obj

        @classmethod
        def get_config_dict(cls, pretrained_model_name_or_path, *model_args,
                            **kwargs):
            model_dir = get_model_dir(pretrained_model_name_or_path,
                                      kwargs.pop('ignore_file_pattern', None),
                                      **kwargs)
            module_obj = module_class.get_config_dict(model_dir, *model_args,
                                                      **kwargs)
            return module_obj

    # Drop the overrides that the wrapped class does not provide itself.
    if not hasattr(module_class, 'from_pretrained'):
        del ClassWrapper.from_pretrained
    else:
        # Inspect the class being wrapped (the parameter), not a variable
        # of the enclosing scope.
        parameters = inspect.signature(module_class.from_pretrained).parameters
        if 'model' in parameters and 'model_id' in parameters:
            # peft
            ClassWrapper.from_pretrained = from_pretrained

    if not hasattr(module_class, '_get_peft_type'):
        del ClassWrapper._get_peft_type

    if not hasattr(module_class, 'get_config_dict'):
        del ClassWrapper.get_config_dict

    ClassWrapper.__name__ = module_class.__name__
    ClassWrapper.__qualname__ = module_class.__qualname__
    return ClassWrapper
all_available_modules = []
for var in all_imported_modules: for var in all_imported_modules:
if var is None or not hasattr(var, '__name__'): if var is None or not hasattr(var, '__name__'):
continue continue
@@ -107,38 +192,53 @@ def _patch_pretrained_class(all_imported_modules):
} }
try: try:
# some TFxxx classes raise ImportError when their attributes are accessed
has_from_pretrained = hasattr(var, 'from_pretrained') has_from_pretrained = hasattr(var, 'from_pretrained')
has_get_peft_type = hasattr(var, '_get_peft_type') has_get_peft_type = hasattr(var, '_get_peft_type')
has_get_config_dict = hasattr(var, 'get_config_dict') has_get_config_dict = hasattr(var, 'get_config_dict')
except ImportError: except ImportError:
continue continue
if has_from_pretrained and not hasattr(var, '_from_pretrained_origin'):
parameters = inspect.signature(var.from_pretrained).parameters
is_peft = 'model' in parameters and 'model_id' in parameters
var._from_pretrained_origin = var.from_pretrained
if not is_peft:
var.from_pretrained = partial(
patch_pretrained_model_name_or_path,
ori_func=var._from_pretrained_origin,
**ignore_file_pattern_kwargs)
else:
var.from_pretrained = partial(
patch_peft_model_id,
ori_func=var._from_pretrained_origin,
**ignore_file_pattern_kwargs)
if has_get_peft_type and not hasattr(var, '_get_peft_type_origin'):
var._get_peft_type_origin = var._get_peft_type
var._get_peft_type = partial(
_get_peft_type,
ori_func=var._get_peft_type_origin,
**ignore_file_pattern_kwargs)
if has_get_config_dict and not hasattr(var, '_get_config_dict_origin'): if wrap:
var._get_config_dict_origin = var.get_config_dict try:
var.get_config_dict = partial( if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type:
patch_pretrained_model_name_or_path, all_available_modules.append(var)
ori_func=var._get_config_dict_origin, else:
**ignore_file_pattern_kwargs) all_available_modules.append(get_wrapped_class(var, ignore_file_pattern))
except Exception:
all_available_modules.append(var)
else:
if has_from_pretrained and not hasattr(var, '_from_pretrained_origin'):
parameters = inspect.signature(var.from_pretrained).parameters
# different argument names
is_peft = 'model' in parameters and 'model_id' in parameters
var._from_pretrained_origin = var.from_pretrained
if not is_peft:
var.from_pretrained = partial(
patch_pretrained_model_name_or_path,
ori_func=var._from_pretrained_origin,
**ignore_file_pattern_kwargs)
else:
var.from_pretrained = partial(
patch_peft_model_id,
ori_func=var._from_pretrained_origin,
**ignore_file_pattern_kwargs)
if has_get_peft_type and not hasattr(var, '_get_peft_type_origin'):
var._get_peft_type_origin = var._get_peft_type
var._get_peft_type = partial(
_get_peft_type,
ori_func=var._get_peft_type_origin,
**ignore_file_pattern_kwargs)
if has_get_config_dict and not hasattr(var, '_get_config_dict_origin'):
var._get_config_dict_origin = var.get_config_dict
var.get_config_dict = partial(
patch_pretrained_model_name_or_path,
ori_func=var._get_config_dict_origin,
**ignore_file_pattern_kwargs)
all_available_modules.append(var)
return all_available_modules
def _unpatch_pretrained_class(all_imported_modules): def _unpatch_pretrained_class(all_imported_modules):
@@ -167,7 +267,7 @@ def _patch_hub():
import huggingface_hub import huggingface_hub
from huggingface_hub import hf_api from huggingface_hub import hf_api
from huggingface_hub.hf_api import api from huggingface_hub.hf_api import api
from huggingface_hub.hf_api import CommitInfo, future_compatible from huggingface_hub.hf_api import future_compatible
from modelscope import get_logger from modelscope import get_logger
logger = get_logger() logger = get_logger()
@@ -258,6 +358,7 @@ def _patch_hub():
@future_compatible @future_compatible
def upload_folder( def upload_folder(
self,
*, *,
repo_id: str, repo_id: str,
folder_path: Union[str, Path], folder_path: Union[str, Path],
@@ -269,10 +370,16 @@ def _patch_hub():
ignore_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None,
**kwargs, **kwargs,
): ):
from modelscope.hub.push_to_hub import push_model_to_hub from modelscope.hub.push_to_hub import push_files_to_hub
push_model_to_hub(repo_id, folder_path, path_in_repo, commit_message, push_files_to_hub(
commit_description, token, True, revision, path_or_fileobj=folder_path,
ignore_patterns) path_in_repo=path_in_repo,
repo_id=repo_id,
commit_message=commit_message,
commit_description=commit_description,
revision=revision,
token=token)
from modelscope.utils.repo_utils import CommitInfo
return CommitInfo( return CommitInfo(
commit_url=f'https://www.modelscope.cn/models/{repo_id}/files', commit_url=f'https://www.modelscope.cn/models/{repo_id}/files',
commit_message=commit_message, commit_message=commit_message,
@@ -280,6 +387,8 @@ def _patch_hub():
oid=None, oid=None,
) )
from modelscope.utils.constant import DEFAULT_REPOSITORY_REVISION
@future_compatible @future_compatible
def upload_file( def upload_file(
self, self,
@@ -288,7 +397,7 @@ def _patch_hub():
path_in_repo: str, path_in_repo: str,
repo_id: str, repo_id: str,
token: Union[str, bool, None] = None, token: Union[str, bool, None] = None,
revision: Optional[str] = None, revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
commit_message: Optional[str] = None, commit_message: Optional[str] = None,
commit_description: Optional[str] = None, commit_description: Optional[str] = None,
**kwargs, **kwargs,

View File

@@ -29,7 +29,7 @@ TEST_ACCESS_TOKEN1 = os.environ.get('TEST_ACCESS_TOKEN_CITEST', None)
TEST_ACCESS_TOKEN2 = os.environ.get('TEST_ACCESS_TOKEN_SDKDEV', None) TEST_ACCESS_TOKEN2 = os.environ.get('TEST_ACCESS_TOKEN_SDKDEV', None)
TEST_MODEL_CHINESE_NAME = '内部测试模型' TEST_MODEL_CHINESE_NAME = '内部测试模型'
TEST_MODEL_ORG = 'citest' TEST_MODEL_ORG = 'tastelikefeet'
def delete_credential(): def delete_credential():

View File

@@ -1,17 +1,51 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest import unittest
import uuid
from huggingface_hub import CommitInfo, RepoUrl
from modelscope import HubApi
from modelscope.utils.hf_util.patcher import patch_context from modelscope.utils.hf_util.patcher import patch_context
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import TEST_MODEL_ORG
logger = get_logger()
class HFUtilTest(unittest.TestCase): class HFUtilTest(unittest.TestCase):
def setUp(self):
    """Create a HubApi client, a unique model id and local fixture files.

    Everything local lives under one `tempfile.mkdtemp()` directory that is
    kept in `self.work_dir`.
    """
    logger.info('SetUp')
    self.api = HubApi()
    self.user = TEST_MODEL_ORG
    logger.info('test user: %s' % self.user)
    self.create_model_name = '%s/%s_%s' % (self.user, 'test_model_upload',
                                           uuid.uuid4().hex)
    logger.info('create %s' % self.create_model_name)

    temporary_dir = tempfile.mkdtemp()
    self.work_dir = temporary_dir
    self.model_dir = os.path.join(temporary_dir, self.create_model_name)
    self.repo_path = os.path.join(self.work_dir, 'repo_path')
    self.test_folder = os.path.join(temporary_dir, 'test_folder')
    self.test_file1 = os.path.join(self.test_folder, '1.json')
    self.test_file2 = os.path.join(temporary_dir, '2.json')

    os.makedirs(self.test_folder, exist_ok=True)
    # Two empty JSON documents: one inside test_folder, one next to it.
    for fixture in (self.test_file1, self.test_file2):
        with open(fixture, 'w') as f:
            f.write('{}')
def tearDown(self):
    """Remove the temporary working tree and best-effort delete the test model."""
    logger.info('TearDown')
    # work_dir is the mkdtemp() root from setUp; model_dir and the fixture
    # files all live beneath it, so removing the root cleans up everything.
    shutil.rmtree(self.work_dir, ignore_errors=True)
    try:
        self.api.delete_model(model_id=self.create_model_name)
    except Exception as e:
        # Best effort: a failed delete must not fail the test, but the
        # reason is logged instead of being dropped silently.
        logger.info('delete %s skipped: %s' % (self.create_model_name, e))
def test_auto_tokenizer(self): def test_auto_tokenizer(self):
from modelscope import AutoTokenizer from modelscope import AutoTokenizer
@@ -24,7 +58,7 @@ class HFUtilTest(unittest.TestCase):
self.assertFalse(tokenizer.is_fast) self.assertFalse(tokenizer.is_fast)
def test_quantization_import(self): def test_quantization_import(self):
from modelscope import GPTQConfig, BitsAndBytesConfig from modelscope import BitsAndBytesConfig
self.assertTrue(BitsAndBytesConfig is not None) self.assertTrue(BitsAndBytesConfig is not None)
def test_auto_model(self): def test_auto_model(self):
@@ -71,6 +105,16 @@ class HFUtilTest(unittest.TestCase):
else: else:
self.assertTrue(False) self.assertTrue(False)
def test_patch_config_bert(self):
    """`BertConfig.from_pretrained` must raise here; this test opens no `patch_context`."""
    from transformers import BertConfig
    with self.assertRaises(Exception):
        BertConfig.from_pretrained(
            'iic/nlp_structbert_sentiment-classification_chinese-tiny')
def test_patch_config(self): def test_patch_config(self):
with patch_context(): with patch_context():
from transformers import AutoConfig from transformers import AutoConfig
@@ -85,6 +129,13 @@ class HFUtilTest(unittest.TestCase):
else: else:
self.assertTrue(False) self.assertTrue(False)
# Test patch again
with patch_context():
from transformers import AutoConfig
config = AutoConfig.from_pretrained(
'iic/nlp_structbert_sentiment-classification_chinese-tiny')
self.assertTrue(config is not None)
def test_patch_diffusers(self): def test_patch_diffusers(self):
with patch_context(): with patch_context():
from diffusers import StableDiffusionPipeline from diffusers import StableDiffusionPipeline
@@ -105,6 +156,59 @@ class HFUtilTest(unittest.TestCase):
self.assertTrue(hasattr(PeftModel, '_from_pretrained_origin')) self.assertTrue(hasattr(PeftModel, '_from_pretrained_origin'))
self.assertFalse(hasattr(PeftModel, '_from_pretrained_origin')) self.assertFalse(hasattr(PeftModel, '_from_pretrained_origin'))
def test_patch_file_exists(self):
    """`file_exists` succeeds inside `patch_context` and raises once it has exited."""
    with patch_context():
        from huggingface_hub import file_exists
        self.assertTrue(
            file_exists('AI-ModelScope/stable-diffusion-v1-5',
                        'feature_extractor/preprocessor_config.json'))

    with self.assertRaises(Exception):
        # Import again
        from huggingface_hub import file_exists  # noqa
        file_exists('AI-ModelScope/stable-diffusion-v1-5',
                    'feature_extractor/preprocessor_config.json')
def test_patch_file_download(self):
    """`hf_hub_download` returns a non-None result inside `patch_context`."""
    with patch_context():
        from huggingface_hub import hf_hub_download
        local_dir = hf_hub_download(
            'AI-ModelScope/stable-diffusion-v1-5',
            'feature_extractor/preprocessor_config.json')
        # Assert before logging: concatenating None into the message would
        # raise TypeError and hide the real failure.
        self.assertIsNotNone(local_dir)
        logger.info('patch file_download dir: ' + local_dir)
def test_patch_create_repo(self):
    """Create a repo, upload a folder and a single file, and check both are visible."""
    with patch_context():
        from huggingface_hub import create_repo
        repo_url: RepoUrl = create_repo(self.create_model_name)
        # Assert before dereferencing so a None result fails the assertion
        # instead of raising AttributeError in the log line.
        self.assertIsNotNone(repo_url)
        logger.info('patch create repo result: ' + repo_url.repo_id)

        from huggingface_hub import upload_folder
        commit_info: CommitInfo = upload_folder(
            repo_id=self.create_model_name,
            folder_path=self.test_folder,
            path_in_repo='')
        self.assertIsNotNone(commit_info)
        logger.info('patch upload folder result: ' + commit_info.commit_url)

        from huggingface_hub import file_exists
        self.assertTrue(file_exists(self.create_model_name, '1.json'))

        from huggingface_hub import upload_file
        upload_file(
            path_or_fileobj=self.test_file2,
            path_in_repo='test_folder2',
            repo_id=self.create_model_name)
        self.assertTrue(
            file_exists(self.create_model_name, 'test_folder2/2.json'))
def test_who_am_i(self):
    """Inside `patch_context`, `whoami()['name']` equals the configured test user."""
    with patch_context():
        from huggingface_hub import whoami
        # assertEqual reports both values on failure
        self.assertEqual(whoami()['name'], self.user)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()