mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
Merge branch 'feat/whoami' of https://github.com/tastelikefeet/modelscope into yh2
This commit is contained in:
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
|
||||
from .trainers import (EpochBasedTrainer, Hook, Priority, TrainingArgs,
|
||||
build_dataset_from_file)
|
||||
from .utils.constant import Tasks
|
||||
from .utils.hf_util import patch_hub, patch_context, unpatch_hub
|
||||
if is_transformers_available():
|
||||
from .utils.hf_util import (
|
||||
AutoModel, AutoProcessor, AutoFeatureExtractor, GenerationConfig,
|
||||
@@ -106,33 +107,13 @@ else:
|
||||
'msdatasets': ['MsDataset']
|
||||
}
|
||||
|
||||
if is_transformers_available():
|
||||
_import_structure['utils.hf_util'] = [
|
||||
'AutoModel', 'AutoProcessor', 'AutoFeatureExtractor',
|
||||
'GenerationConfig', 'AutoConfig', 'GPTQConfig', 'AwqConfig',
|
||||
'BitsAndBytesConfig', 'AutoModelForCausalLM',
|
||||
'AutoModelForSeq2SeqLM', 'AutoModelForVision2Seq',
|
||||
'AutoModelForSequenceClassification',
|
||||
'AutoModelForTokenClassification',
|
||||
'AutoModelForImageClassification', 'AutoModelForImageToImage',
|
||||
'AutoModelForImageTextToText',
|
||||
'AutoModelForZeroShotImageClassification',
|
||||
'AutoModelForKeypointDetection',
|
||||
'AutoModelForDocumentQuestionAnswering',
|
||||
'AutoModelForSemanticSegmentation',
|
||||
'AutoModelForUniversalSegmentation',
|
||||
'AutoModelForInstanceSegmentation', 'AutoModelForObjectDetection',
|
||||
'AutoModelForZeroShotObjectDetection',
|
||||
'AutoModelForAudioClassification', 'AutoModelForSpeechSeq2Seq',
|
||||
'AutoModelForMaskedImageModeling',
|
||||
'AutoModelForVisualQuestionAnswering',
|
||||
'AutoModelForTableQuestionAnswering',
|
||||
'AutoModelForImageSegmentation', 'AutoModelForQuestionAnswering',
|
||||
'AutoModelForMaskedLM', 'AutoTokenizer',
|
||||
'AutoModelForMaskGeneration', 'AutoModelForPreTraining',
|
||||
'AutoModelForTextEncoding', 'AutoImageProcessor', 'BatchFeature',
|
||||
'Qwen2VLForConditionalGeneration', 'T5EncoderModel'
|
||||
]
|
||||
from modelscope.utils import hf_util
|
||||
|
||||
extra_objects = {}
|
||||
attributes = dir(hf_util)
|
||||
imports = [attr for attr in attributes if not attr.startswith('__')]
|
||||
for _import in imports:
|
||||
extra_objects[_import] = getattr(hf_util, _import)
|
||||
|
||||
import sys
|
||||
|
||||
@@ -141,5 +122,5 @@ else:
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
extra_objects=extra_objects,
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ import pickle
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
@@ -47,7 +48,9 @@ from modelscope.hub.errors import (InvalidParameter, NotExistError,
|
||||
raise_for_http_status, raise_on_error)
|
||||
from modelscope.hub.git import GitCommandWrapper
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.hub.utils.utils import (get_endpoint, get_readable_folder_size,
|
||||
from modelscope.hub.utils.utils import (add_patterns_to_file,
|
||||
add_patterns_to_gitattributes,
|
||||
get_endpoint, get_readable_folder_size,
|
||||
get_release_datetime,
|
||||
model_id_to_group_owner_name)
|
||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
@@ -75,6 +78,7 @@ logger = get_logger()
|
||||
class HubApi:
|
||||
"""Model hub api interface.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
endpoint: Optional[str] = None,
|
||||
timeout=API_HTTP_CLIENT_TIMEOUT,
|
||||
@@ -109,14 +113,14 @@ class HubApi:
|
||||
self.upload_checker = UploadingCheck()
|
||||
|
||||
def login(
|
||||
self,
|
||||
access_token: str,
|
||||
self,
|
||||
access_token: Optional[str] = None,
|
||||
):
|
||||
"""Login with your SDK access token, which can be obtained from
|
||||
https://www.modelscope.cn user center.
|
||||
|
||||
Args:
|
||||
access_token (str): user access token on modelscope.
|
||||
access_token (str): user access token on modelscope, set this argument or set `MODELSCOPE_API_TOKEN`.
|
||||
|
||||
Returns:
|
||||
cookies: to authenticate yourself to ModelScope open-api
|
||||
@@ -125,6 +129,9 @@ class HubApi:
|
||||
Note:
|
||||
You only have to login once within 30 days.
|
||||
"""
|
||||
if access_token is None:
|
||||
access_token = os.environ.get('MODELSCOPE_API_TOKEN')
|
||||
assert access_token is not None, 'Please pass in access_token or set `MODELSCOPE_API_TOKEN`'
|
||||
path = f'{self.endpoint}/api/v1/login'
|
||||
r = self.session.post(
|
||||
path,
|
||||
@@ -147,6 +154,16 @@ class HubApi:
|
||||
return d[API_RESPONSE_FIELD_DATA][
|
||||
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN], cookies
|
||||
|
||||
def try_login(self, access_token: Optional[str] = None) -> bool:
|
||||
"""Wraps the `login` method and returns bool.
|
||||
"""
|
||||
try:
|
||||
self.login(access_token)
|
||||
return True
|
||||
except AssertionError:
|
||||
logger.warning('Login failed.')
|
||||
return False
|
||||
|
||||
def create_model(self,
|
||||
model_id: str,
|
||||
visibility: Optional[int] = ModelVisibility.PUBLIC,
|
||||
@@ -226,9 +243,9 @@ class HubApi:
|
||||
return f'{self.endpoint}/api/v1/models/{model_id}.git'
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
) -> str:
|
||||
"""Get model information at ModelScope
|
||||
|
||||
@@ -264,10 +281,10 @@ class HubApi:
|
||||
raise_for_http_status(r)
|
||||
|
||||
def repo_exists(
|
||||
self,
|
||||
repo_id: str,
|
||||
*,
|
||||
repo_type: Optional[str] = None,
|
||||
self,
|
||||
repo_id: str,
|
||||
*,
|
||||
repo_type: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Checks if a repository exists on ModelScope
|
||||
@@ -475,7 +492,7 @@ class HubApi:
|
||||
r = self.session.put(
|
||||
path,
|
||||
data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' %
|
||||
(owner_or_group, page_number, page_size),
|
||||
(owner_or_group, page_number, page_size),
|
||||
cookies=cookies,
|
||||
headers=self.builder_headers(self.headers))
|
||||
handle_http_response(r, logger, cookies, owner_or_group)
|
||||
@@ -489,9 +506,7 @@ class HubApi:
|
||||
raise_for_http_status(r)
|
||||
return None
|
||||
|
||||
def _check_cookie(self,
|
||||
use_cookies: Union[bool,
|
||||
CookieJar] = False) -> CookieJar:
|
||||
def _check_cookie(self, use_cookies: Union[bool, CookieJar] = False) -> CookieJar: # noqa
|
||||
cookies = None
|
||||
if isinstance(use_cookies, CookieJar):
|
||||
cookies = use_cookies
|
||||
@@ -602,7 +617,8 @@ class HubApi:
|
||||
else:
|
||||
if revision is None: # user not specified revision, use latest revision before release time
|
||||
revisions_detail = [x for x in
|
||||
all_tags_detail if x['CreatedAt'] <= release_timestamp] if all_tags_detail else [] # noqa E501
|
||||
all_tags_detail if
|
||||
x['CreatedAt'] <= release_timestamp] if all_tags_detail else [] # noqa E501
|
||||
if len(revisions_detail) > 0:
|
||||
revision = revisions_detail[0]['Revision'] # use latest revision before release time.
|
||||
revision_detail = revisions_detail[0]
|
||||
@@ -636,9 +652,9 @@ class HubApi:
|
||||
cookies=cookies)['Revision']
|
||||
|
||||
def get_model_branches_and_tags_details(
|
||||
self,
|
||||
model_id: str,
|
||||
use_cookies: Union[bool, CookieJar] = False,
|
||||
self,
|
||||
model_id: str,
|
||||
use_cookies: Union[bool, CookieJar] = False,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""Get model branch and tags.
|
||||
|
||||
@@ -662,9 +678,9 @@ class HubApi:
|
||||
return info['RevisionMap']['Branches'], info['RevisionMap']['Tags']
|
||||
|
||||
def get_model_branches_and_tags(
|
||||
self,
|
||||
model_id: str,
|
||||
use_cookies: Union[bool, CookieJar] = False,
|
||||
self,
|
||||
model_id: str,
|
||||
use_cookies: Union[bool, CookieJar] = False,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""Get model branch and tags.
|
||||
|
||||
@@ -1103,7 +1119,7 @@ class HubApi:
|
||||
def list_oss_dataset_objects(self, dataset_name, namespace, max_limit,
|
||||
is_recursive, is_filter_dir, revision):
|
||||
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \
|
||||
f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}'
|
||||
f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}'
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
resp = self.session.get(url=url, cookies=cookies, timeout=1800)
|
||||
@@ -1132,7 +1148,7 @@ class HubApi:
|
||||
raise ValueError('Args cannot be empty!')
|
||||
|
||||
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/' \
|
||||
f'&Revision={revision}'
|
||||
f'&Revision={revision}'
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
resp = self.session.delete(url=url, cookies=cookies)
|
||||
@@ -1198,21 +1214,22 @@ class HubApi:
|
||||
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
||||
chinese_name: Optional[str] = '',
|
||||
license: Optional[str] = Licenses.APACHE_V2,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
|
||||
# TODO: exist_ok
|
||||
|
||||
if not repo_id:
|
||||
raise ValueError('Repo id cannot be empty!')
|
||||
|
||||
if token:
|
||||
self.login(access_token=token)
|
||||
else:
|
||||
logger.warning('No token provided, will use the cached token.')
|
||||
self.try_login(token)
|
||||
if '/' not in repo_id:
|
||||
user_name = ModelScopeConfig.get_user_info()[0]
|
||||
assert isinstance(user_name, str)
|
||||
repo_id = f'{user_name}/{repo_id}'
|
||||
logger.info(
|
||||
f"'/' not in hub_model_id, pushing to personal repo {repo_id}")
|
||||
|
||||
repo_id_list = repo_id.split('/')
|
||||
if len(repo_id_list) != 2:
|
||||
raise ValueError('Invalid repo id, should be in the format of `owner_name/repo_name`')
|
||||
namespace, repo_name = repo_id_list
|
||||
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
@@ -1228,6 +1245,25 @@ class HubApi:
|
||||
chinese_name=chinese_name,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_cache_dir:
|
||||
from modelscope.hub.repository import Repository
|
||||
repo = Repository(temp_cache_dir, repo_id)
|
||||
add_patterns_to_gitattributes(
|
||||
repo, ['*.safetensors', '*.bin', '*.pt', '*.gguf'])
|
||||
default_config = {
|
||||
'framework': 'pytorch',
|
||||
'task': 'text-generation',
|
||||
'allow_remote': True
|
||||
}
|
||||
config_json = kwargs.get('config_json')
|
||||
if not config_json:
|
||||
config_json = {}
|
||||
config = {**default_config, **config_json}
|
||||
add_patterns_to_file(
|
||||
repo,
|
||||
'configuration.json', [json.dumps(config)],
|
||||
ignore_push_error=True)
|
||||
|
||||
elif repo_type == REPO_TYPE_DATASET:
|
||||
visibilities = {k: v for k, v in DatasetVisibility.__dict__.items() if not k.startswith('__')}
|
||||
visibility: int = visibilities.get(visibility.upper())
|
||||
|
||||
@@ -100,15 +100,12 @@ def check_local_model_is_latest(
|
||||
pass # ignore
|
||||
|
||||
|
||||
def check_model_is_id(model_id: str, token=None):
|
||||
if token is None:
|
||||
token = os.environ.get('MODELSCOPE_API_TOKEN')
|
||||
def check_model_is_id(model_id: str, token: Optional[str] = None):
|
||||
if model_id is None or os.path.exists(model_id):
|
||||
return False
|
||||
else:
|
||||
_api = HubApi()
|
||||
if token is not None:
|
||||
_api.login(token)
|
||||
_api.try_login(token)
|
||||
try:
|
||||
_api.get_model(model_id=model_id, )
|
||||
return True
|
||||
|
||||
@@ -3,7 +3,12 @@
|
||||
import concurrent.futures
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from multiprocessing import Manager, Process, Value
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import json
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.constants import ModelVisibility
|
||||
@@ -19,6 +24,44 @@ _tasks = dict()
|
||||
_manager = None
|
||||
|
||||
|
||||
def _push_files_to_hub(
|
||||
path_or_fileobj: Union[str, Path],
|
||||
path_in_repo: str,
|
||||
repo_id: str,
|
||||
token: Union[str, bool, None] = None,
|
||||
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
|
||||
commit_message: Optional[str] = None,
|
||||
commit_description: Optional[str] = None,
|
||||
):
|
||||
"""Push files to model hub incrementally
|
||||
|
||||
This function if used for patch_hub, user is not recommended to call this.
|
||||
This function will be merged to push_to_hub in later sprints.
|
||||
"""
|
||||
if not os.path.exists(path_or_fileobj):
|
||||
return
|
||||
|
||||
from modelscope import HubApi
|
||||
api = HubApi()
|
||||
api.login(token)
|
||||
if not commit_message:
|
||||
commit_message = 'Updating files'
|
||||
if commit_description:
|
||||
commit_message = commit_message + '\n' + commit_description
|
||||
with tempfile.TemporaryDirectory() as temp_cache_dir:
|
||||
from modelscope.hub.repository import Repository
|
||||
repo = Repository(temp_cache_dir, repo_id, revision=revision)
|
||||
sub_folder = os.path.join(temp_cache_dir, path_in_repo)
|
||||
os.makedirs(sub_folder, exist_ok=True)
|
||||
if os.path.isfile(path_or_fileobj):
|
||||
dest_file = os.path.join(sub_folder,
|
||||
os.path.basename(path_or_fileobj))
|
||||
shutil.copyfile(path_or_fileobj, dest_file)
|
||||
else:
|
||||
shutil.copytree(path_or_fileobj, sub_folder, dirs_exist_ok=True)
|
||||
repo.push(commit_message)
|
||||
|
||||
|
||||
def _api_push_to_hub(repo_name,
|
||||
output_dir,
|
||||
token,
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import BinaryIO, List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
@@ -125,3 +127,67 @@ def file_integrity_validation(file_path, expected_sha256):
|
||||
file_path, expected_sha256, file_sha256)
|
||||
logger.error(msg)
|
||||
raise FileIntegrityError(msg)
|
||||
|
||||
|
||||
def add_patterns_to_file(repo,
|
||||
file_name: str,
|
||||
patterns: List[str],
|
||||
commit_message: Optional[str] = None,
|
||||
ignore_push_error=False) -> None:
|
||||
if isinstance(patterns, str):
|
||||
patterns = [patterns]
|
||||
if commit_message is None:
|
||||
commit_message = f'Add `{patterns[0]}` patterns to {file_name}'
|
||||
|
||||
# Get current file content
|
||||
repo_dir = repo.model_dir
|
||||
file_path = os.path.join(repo_dir, file_name)
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
current_content = f.read()
|
||||
else:
|
||||
current_content = ''
|
||||
# Add the patterns to file
|
||||
content = current_content
|
||||
for pattern in patterns:
|
||||
if pattern not in content:
|
||||
if len(content) > 0 and not content.endswith('\n'):
|
||||
content += '\n'
|
||||
content += f'{pattern}\n'
|
||||
|
||||
# Write the file if it has changed
|
||||
if content != current_content:
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
logger.debug(f'Writing {file_name} file. Content: {content}')
|
||||
f.write(content)
|
||||
try:
|
||||
repo.push(commit_message)
|
||||
except Exception as e:
|
||||
if ignore_push_error:
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
def add_patterns_to_gitignore(repo,
|
||||
patterns: List[str],
|
||||
commit_message: Optional[str] = None) -> None:
|
||||
add_patterns_to_file(
|
||||
repo, '.gitignore', patterns, commit_message, ignore_push_error=True)
|
||||
|
||||
|
||||
def add_patterns_to_gitattributes(
|
||||
repo,
|
||||
patterns: List[str],
|
||||
commit_message: Optional[str] = None) -> None:
|
||||
new_patterns = []
|
||||
suffix = 'filter=lfs diff=lfs merge=lfs -text'
|
||||
for pattern in patterns:
|
||||
if suffix not in pattern:
|
||||
pattern = f'{pattern} {suffix}'
|
||||
new_patterns.append(pattern)
|
||||
file_name = '.gitattributes'
|
||||
if commit_message is None:
|
||||
commit_message = f'Add `{patterns[0]}` patterns to {file_name}'
|
||||
add_patterns_to_file(
|
||||
repo, file_name, new_patterns, commit_message, ignore_push_error=True)
|
||||
|
||||
@@ -1,468 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers import AutoConfig as AutoConfigHF
|
||||
from transformers import AutoFeatureExtractor as AutoFeatureExtractorHF
|
||||
from transformers import AutoImageProcessor as AutoImageProcessorHF
|
||||
from transformers import AutoModel as AutoModelHF
|
||||
from transformers import \
|
||||
AutoModelForAudioClassification as AutoModelForAudioClassificationHF
|
||||
from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
|
||||
from transformers import \
|
||||
AutoModelForDocumentQuestionAnswering as \
|
||||
AutoModelForDocumentQuestionAnsweringHF
|
||||
from transformers import \
|
||||
AutoModelForImageClassification as AutoModelForImageClassificationHF
|
||||
from transformers import \
|
||||
AutoModelForImageSegmentation as AutoModelForImageSegmentationHF
|
||||
from transformers import \
|
||||
AutoModelForInstanceSegmentation as AutoModelForInstanceSegmentationHF
|
||||
from transformers import \
|
||||
AutoModelForMaskedImageModeling as AutoModelForMaskedImageModelingHF
|
||||
from transformers import AutoModelForMaskedLM as AutoModelForMaskedLMHF
|
||||
from transformers import \
|
||||
AutoModelForMaskGeneration as AutoModelForMaskGenerationHF
|
||||
from transformers import \
|
||||
AutoModelForObjectDetection as AutoModelForObjectDetectionHF
|
||||
from transformers import AutoModelForPreTraining as AutoModelForPreTrainingHF
|
||||
from transformers import \
|
||||
AutoModelForQuestionAnswering as AutoModelForQuestionAnsweringHF
|
||||
from transformers import \
|
||||
AutoModelForSemanticSegmentation as AutoModelForSemanticSegmentationHF
|
||||
from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
|
||||
from transformers import \
|
||||
AutoModelForSequenceClassification as AutoModelForSequenceClassificationHF
|
||||
from transformers import \
|
||||
AutoModelForSpeechSeq2Seq as AutoModelForSpeechSeq2SeqHF
|
||||
from transformers import \
|
||||
AutoModelForTableQuestionAnswering as AutoModelForTableQuestionAnsweringHF
|
||||
from transformers import AutoModelForTextEncoding as AutoModelForTextEncodingHF
|
||||
from transformers import \
|
||||
AutoModelForTokenClassification as AutoModelForTokenClassificationHF
|
||||
from transformers import \
|
||||
AutoModelForUniversalSegmentation as AutoModelForUniversalSegmentationHF
|
||||
from transformers import AutoModelForVision2Seq as AutoModelForVision2SeqHF
|
||||
from transformers import \
|
||||
AutoModelForVisualQuestionAnswering as \
|
||||
AutoModelForVisualQuestionAnsweringHF
|
||||
from transformers import \
|
||||
AutoModelForZeroShotImageClassification as \
|
||||
AutoModelForZeroShotImageClassificationHF
|
||||
from transformers import \
|
||||
AutoModelForZeroShotObjectDetection as \
|
||||
AutoModelForZeroShotObjectDetectionHF
|
||||
from transformers import AutoProcessor as AutoProcessorHF
|
||||
from transformers import AutoTokenizer as AutoTokenizerHF
|
||||
from transformers import BatchFeature as BatchFeatureHF
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
|
||||
from transformers import GenerationConfig as GenerationConfigHF
|
||||
from transformers import (PretrainedConfig, PreTrainedModel,
|
||||
PreTrainedTokenizerBase)
|
||||
from transformers import T5EncoderModel as T5EncoderModelHF
|
||||
from transformers import __version__ as transformers_version
|
||||
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
|
||||
from .logger import get_logger
|
||||
|
||||
try:
|
||||
from transformers import GPTQConfig as GPTQConfigHF
|
||||
from transformers import AwqConfig as AwqConfigHF
|
||||
except ImportError:
|
||||
GPTQConfigHF = None
|
||||
AwqConfigHF = None
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class UnsupportedAutoClass:
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.error_msg =\
|
||||
f'{name} is not supported with your installed Transformers version {transformers_version}. ' + \
|
||||
'Please update your Transformers by "pip install transformers -U".'
|
||||
|
||||
def from_pretrained(self, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
raise ImportError(self.error_msg)
|
||||
|
||||
def from_config(self, cls, config):
|
||||
raise ImportError(self.error_msg)
|
||||
|
||||
|
||||
def user_agent(invoked_by=None):
|
||||
if invoked_by is None:
|
||||
invoked_by = Invoke.PRETRAINED
|
||||
uagent = '%s/%s' % (Invoke.KEY, invoked_by)
|
||||
return uagent
|
||||
|
||||
|
||||
def _try_login(token: Optional[str] = None):
|
||||
from modelscope.hub.api import HubApi
|
||||
api = HubApi()
|
||||
if token is None:
|
||||
token = os.environ.get('MODELSCOPE_API_TOKEN')
|
||||
if token:
|
||||
api.login(token)
|
||||
|
||||
|
||||
def _file_exists(
|
||||
self,
|
||||
repo_id: str,
|
||||
filename: str,
|
||||
*,
|
||||
repo_type: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
token: Union[str, bool, None] = None,
|
||||
):
|
||||
"""Patch huggingface_hub.file_exists"""
|
||||
if repo_type is not None:
|
||||
logger.warning(
|
||||
'The passed in repo_type will not be used in modelscope. Now only model repo can be queried.'
|
||||
)
|
||||
_try_login(token)
|
||||
from modelscope.hub.api import HubApi
|
||||
api = HubApi()
|
||||
return api.file_exists(repo_id, filename, revision=revision)
|
||||
|
||||
|
||||
def _file_download(repo_id: str,
|
||||
filename: str,
|
||||
*,
|
||||
subfolder: Optional[str] = None,
|
||||
repo_type: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
cache_dir: Union[str, Path, None] = None,
|
||||
local_dir: Union[str, Path, None] = None,
|
||||
token: Union[bool, str, None] = None,
|
||||
local_files_only: bool = False,
|
||||
**kwargs):
|
||||
"""Patch huggingface_hub.hf_hub_download"""
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(
|
||||
'The passed in library_name,library_version,user_agent,force_download,proxies'
|
||||
'etag_timeout,headers,endpoint '
|
||||
'will not be used in modelscope.')
|
||||
assert repo_type in (
|
||||
None, 'model',
|
||||
'dataset'), f'repo_type={repo_type} is not supported in ModelScope'
|
||||
if repo_type in (None, 'model'):
|
||||
from modelscope.hub.file_download import model_file_download as file_download
|
||||
else:
|
||||
from modelscope.hub.file_download import dataset_file_download as file_download
|
||||
_try_login(token)
|
||||
return file_download(
|
||||
repo_id,
|
||||
file_path=os.path.join(subfolder, filename) if subfolder else filename,
|
||||
cache_dir=cache_dir,
|
||||
local_dir=local_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision)
|
||||
|
||||
|
||||
def _patch_pretrained_class():
|
||||
|
||||
def get_model_dir(pretrained_model_name_or_path, ignore_file_pattern,
|
||||
**kwargs):
|
||||
if not os.path.exists(pretrained_model_name_or_path):
|
||||
revision = kwargs.pop('revision', None)
|
||||
model_dir = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
ignore_file_pattern=ignore_file_pattern)
|
||||
else:
|
||||
model_dir = pretrained_model_name_or_path
|
||||
return model_dir
|
||||
|
||||
def patch_tokenizer_base():
|
||||
""" Monkey patch PreTrainedTokenizerBase.from_pretrained to adapt to modelscope hub.
|
||||
"""
|
||||
ori_from_pretrained = PreTrainedTokenizerBase.from_pretrained.__func__
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
ignore_file_pattern = [
|
||||
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
|
||||
]
|
||||
model_dir = get_model_dir(pretrained_model_name_or_path,
|
||||
ignore_file_pattern, **kwargs)
|
||||
return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
|
||||
|
||||
PreTrainedTokenizerBase.from_pretrained = from_pretrained
|
||||
|
||||
def patch_config_base():
|
||||
""" Monkey patch PretrainedConfig.from_pretrained to adapt to modelscope hub.
|
||||
"""
|
||||
ori_from_pretrained = PretrainedConfig.from_pretrained.__func__
|
||||
ori_get_config_dict = PretrainedConfig.get_config_dict.__func__
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
ignore_file_pattern = [
|
||||
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
|
||||
]
|
||||
model_dir = get_model_dir(pretrained_model_name_or_path,
|
||||
ignore_file_pattern, **kwargs)
|
||||
return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(cls, pretrained_model_name_or_path, **kwargs):
|
||||
ignore_file_pattern = [
|
||||
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
|
||||
]
|
||||
model_dir = get_model_dir(pretrained_model_name_or_path,
|
||||
ignore_file_pattern, **kwargs)
|
||||
return ori_get_config_dict(cls, model_dir, **kwargs)
|
||||
|
||||
PretrainedConfig.from_pretrained = from_pretrained
|
||||
PretrainedConfig.get_config_dict = get_config_dict
|
||||
|
||||
def patch_model_base():
|
||||
""" Monkey patch PreTrainedModel.from_pretrained to adapt to modelscope hub.
|
||||
"""
|
||||
ori_from_pretrained = PreTrainedModel.from_pretrained.__func__
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
model_dir = get_model_dir(pretrained_model_name_or_path, None,
|
||||
**kwargs)
|
||||
return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
|
||||
|
||||
PreTrainedModel.from_pretrained = from_pretrained
|
||||
|
||||
def patch_image_processor_base():
|
||||
""" Monkey patch AutoImageProcessorHF.from_pretrained to adapt to modelscope hub.
|
||||
"""
|
||||
ori_from_pretrained = AutoImageProcessorHF.from_pretrained.__func__
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
model_dir = get_model_dir(pretrained_model_name_or_path, None,
|
||||
**kwargs)
|
||||
return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
|
||||
|
||||
AutoImageProcessorHF.from_pretrained = from_pretrained
|
||||
|
||||
def patch_auto_processor_base():
|
||||
""" Monkey patch AutoProcessorHF.from_pretrained to adapt to modelscope hub.
|
||||
"""
|
||||
ori_from_pretrained = AutoProcessorHF.from_pretrained.__func__
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
model_dir = get_model_dir(pretrained_model_name_or_path, None,
|
||||
**kwargs)
|
||||
return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
|
||||
|
||||
AutoProcessorHF.from_pretrained = from_pretrained
|
||||
|
||||
def patch_feature_extractor_base():
|
||||
""" Monkey patch AutoFeatureExtractorHF.from_pretrained to adapt to modelscope hub.
|
||||
"""
|
||||
ori_from_pretrained = AutoFeatureExtractorHF.from_pretrained.__func__
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
model_dir = get_model_dir(pretrained_model_name_or_path, None,
|
||||
**kwargs)
|
||||
return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
|
||||
|
||||
AutoFeatureExtractorHF.from_pretrained = from_pretrained
|
||||
|
||||
patch_tokenizer_base()
|
||||
patch_config_base()
|
||||
patch_model_base()
|
||||
patch_image_processor_base()
|
||||
patch_auto_processor_base()
|
||||
patch_feature_extractor_base()
|
||||
|
||||
|
||||
def patch_hub():
|
||||
"""Patch hf hub, which to make users can download models from modelscope to speed up.
|
||||
"""
|
||||
import huggingface_hub
|
||||
from huggingface_hub import hf_api
|
||||
from huggingface_hub.hf_api import api
|
||||
|
||||
huggingface_hub.hf_hub_download = _file_download
|
||||
huggingface_hub.file_download.hf_hub_download = _file_download
|
||||
|
||||
hf_api.file_exists = MethodType(_file_exists, api)
|
||||
huggingface_hub.file_exists = hf_api.file_exists
|
||||
huggingface_hub.hf_api.file_exists = hf_api.file_exists
|
||||
|
||||
_patch_pretrained_class()
|
||||
|
||||
|
||||
def get_wrapped_class(module_class,
|
||||
ignore_file_pattern=[],
|
||||
file_filter=None,
|
||||
**kwargs):
|
||||
"""Get a custom wrapper class for auto classes to download the models from the ModelScope hub
|
||||
Args:
|
||||
module_class: The actual module class
|
||||
ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
|
||||
Any file pattern to be ignored in downloading, like exact file names or file extensions.
|
||||
Returns:
|
||||
The wrapper
|
||||
"""
|
||||
default_ignore_file_pattern = ignore_file_pattern
|
||||
default_file_filter = file_filter
|
||||
|
||||
class ClassWrapper(module_class):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
ignore_file_pattern = kwargs.pop('ignore_file_pattern',
|
||||
default_ignore_file_pattern)
|
||||
subfolder = kwargs.pop('subfolder', default_file_filter)
|
||||
file_filter = None
|
||||
if subfolder:
|
||||
file_filter = f'{subfolder}/*'
|
||||
if not os.path.exists(pretrained_model_name_or_path):
|
||||
revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION)
|
||||
if file_filter is None:
|
||||
model_dir = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
user_agent=user_agent())
|
||||
else:
|
||||
model_dir = os.path.join(
|
||||
snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=file_filter,
|
||||
user_agent=user_agent()), subfolder)
|
||||
else:
|
||||
model_dir = pretrained_model_name_or_path
|
||||
|
||||
module_obj = module_class.from_pretrained(model_dir, *model_args,
|
||||
**kwargs)
|
||||
|
||||
if module_class.__name__.startswith('AutoModel'):
|
||||
module_obj.model_dir = model_dir
|
||||
return module_obj
|
||||
|
||||
ClassWrapper.__name__ = module_class.__name__
|
||||
ClassWrapper.__qualname__ = module_class.__qualname__
|
||||
return ClassWrapper
|
||||
|
||||
|
||||
# ModelScope-wrapped versions of the transformers auto classes.
# get_wrapped_class(...) subclasses each HF class so that from_pretrained
# resolves model ids against the ModelScope hub instead of huggingface.co.
AutoModel = get_wrapped_class(AutoModelHF)
AutoModelForCausalLM = get_wrapped_class(AutoModelForCausalLMHF)
AutoModelForSeq2SeqLM = get_wrapped_class(AutoModelForSeq2SeqLMHF)
AutoModelForVision2Seq = get_wrapped_class(AutoModelForVision2SeqHF)
AutoModelForSequenceClassification = get_wrapped_class(
    AutoModelForSequenceClassificationHF)
AutoModelForTokenClassification = get_wrapped_class(
    AutoModelForTokenClassificationHF)
AutoModelForImageSegmentation = get_wrapped_class(
    AutoModelForImageSegmentationHF)
AutoModelForImageClassification = get_wrapped_class(
    AutoModelForImageClassificationHF)
AutoModelForZeroShotImageClassification = get_wrapped_class(
    AutoModelForZeroShotImageClassificationHF)
# Classes below only exist in newer transformers releases; fall back to a
# placeholder that raises a helpful error when the class is unavailable.
try:
    from transformers import AutoModelForImageToImage as AutoModelForImageToImageHF
    AutoModelForImageToImage = get_wrapped_class(AutoModelForImageToImageHF)
except ImportError:
    AutoModelForImageToImage = UnsupportedAutoClass('AutoModelForImageToImage')

try:
    from transformers import AutoModelForImageTextToText as AutoModelForImageTextToTextHF
    AutoModelForImageTextToText = get_wrapped_class(
        AutoModelForImageTextToTextHF)
except ImportError:
    AutoModelForImageTextToText = UnsupportedAutoClass(
        'AutoModelForImageTextToText')

try:
    from transformers import AutoModelForKeypointDetection as AutoModelForKeypointDetectionHF
    AutoModelForKeypointDetection = get_wrapped_class(
        AutoModelForKeypointDetectionHF)
except ImportError:
    AutoModelForKeypointDetection = UnsupportedAutoClass(
        'AutoModelForKeypointDetection')

AutoModelForQuestionAnswering = get_wrapped_class(
    AutoModelForQuestionAnsweringHF)
AutoModelForTableQuestionAnswering = get_wrapped_class(
    AutoModelForTableQuestionAnsweringHF)
AutoModelForVisualQuestionAnswering = get_wrapped_class(
    AutoModelForVisualQuestionAnsweringHF)
AutoModelForDocumentQuestionAnswering = get_wrapped_class(
    AutoModelForDocumentQuestionAnsweringHF)
AutoModelForSemanticSegmentation = get_wrapped_class(
    AutoModelForSemanticSegmentationHF)
AutoModelForUniversalSegmentation = get_wrapped_class(
    AutoModelForUniversalSegmentationHF)
AutoModelForInstanceSegmentation = get_wrapped_class(
    AutoModelForInstanceSegmentationHF)
AutoModelForObjectDetection = get_wrapped_class(AutoModelForObjectDetectionHF)
AutoModelForZeroShotObjectDetection = get_wrapped_class(
    AutoModelForZeroShotObjectDetectionHF)
AutoModelForAudioClassification = get_wrapped_class(
    AutoModelForAudioClassificationHF)
AutoModelForSpeechSeq2Seq = get_wrapped_class(AutoModelForSpeechSeq2SeqHF)
AutoModelForMaskedImageModeling = get_wrapped_class(
    AutoModelForMaskedImageModelingHF)
AutoModelForMaskedLM = get_wrapped_class(AutoModelForMaskedLMHF)
AutoModelForMaskGeneration = get_wrapped_class(AutoModelForMaskGenerationHF)
AutoModelForPreTraining = get_wrapped_class(AutoModelForPreTrainingHF)
AutoModelForTextEncoding = get_wrapped_class(AutoModelForTextEncodingHF)
T5EncoderModel = get_wrapped_class(T5EncoderModelHF)
try:
    from transformers import \
        Qwen2VLForConditionalGeneration as Qwen2VLForConditionalGenerationHF
    Qwen2VLForConditionalGeneration = get_wrapped_class(
        Qwen2VLForConditionalGenerationHF)
except ImportError:
    Qwen2VLForConditionalGeneration = UnsupportedAutoClass(
        'Qwen2VLForConditionalGeneration')

# Tokenizer/config-style classes never need weight files, so downloads skip
# the large checkpoint formats.
AutoTokenizer = get_wrapped_class(
    AutoTokenizerHF,
    ignore_file_pattern=[
        r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
    ])
AutoProcessor = get_wrapped_class(
    AutoProcessorHF,
    ignore_file_pattern=[
        r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
    ])
AutoConfig = get_wrapped_class(
    AutoConfigHF,
    ignore_file_pattern=[
        r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
    ])
GenerationConfig = get_wrapped_class(
    GenerationConfigHF,
    ignore_file_pattern=[
        r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
    ])
BitsAndBytesConfig = get_wrapped_class(
    BitsAndBytesConfigHF,
    ignore_file_pattern=[
        r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
    ])
AutoImageProcessor = get_wrapped_class(
    AutoImageProcessorHF,
    ignore_file_pattern=[
        r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
    ])

# Pure-config classes that never download anything are re-exported unchanged.
GPTQConfig = GPTQConfigHF
AwqConfig = AwqConfigHF
BatchFeature = get_wrapped_class(BatchFeatureHF)
|
||||
2
modelscope/utils/hf_util/__init__.py
Normal file
2
modelscope/utils/hf_util/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .auto_class import *
|
||||
from .patcher import patch_context, patch_hub, unpatch_hub
|
||||
75
modelscope/utils/hf_util/auto_class.py
Normal file
75
modelscope/utils/hf_util/auto_class.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import AutoConfig
|
||||
from transformers import AutoFeatureExtractor
|
||||
from transformers import AutoImageProcessor
|
||||
from transformers import AutoModel
|
||||
from transformers import AutoModelForAudioClassification
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers import AutoModelForDocumentQuestionAnswering
|
||||
from transformers import AutoModelForImageClassification
|
||||
from transformers import AutoModelForImageSegmentation
|
||||
from transformers import AutoModelForInstanceSegmentation
|
||||
from transformers import AutoModelForMaskedImageModeling
|
||||
from transformers import AutoModelForMaskedLM
|
||||
from transformers import AutoModelForMaskGeneration
|
||||
from transformers import AutoModelForObjectDetection
|
||||
from transformers import AutoModelForPreTraining
|
||||
from transformers import AutoModelForQuestionAnswering
|
||||
from transformers import AutoModelForSemanticSegmentation
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
from transformers import AutoModelForSpeechSeq2Seq
|
||||
from transformers import AutoModelForTableQuestionAnswering
|
||||
from transformers import AutoModelForTextEncoding
|
||||
from transformers import AutoModelForTokenClassification
|
||||
from transformers import AutoModelForUniversalSegmentation
|
||||
from transformers import AutoModelForVision2Seq
|
||||
from transformers import AutoModelForVisualQuestionAnswering
|
||||
from transformers import AutoModelForZeroShotImageClassification
|
||||
from transformers import AutoModelForZeroShotObjectDetection
|
||||
from transformers import AutoProcessor
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import BatchFeature
|
||||
from transformers import BitsAndBytesConfig
|
||||
from transformers import GenerationConfig
|
||||
from transformers import (PretrainedConfig, PreTrainedModel,
|
||||
PreTrainedTokenizerBase)
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
try:
|
||||
from transformers import Qwen2VLForConditionalGeneration
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from transformers import GPTQConfig
|
||||
from transformers import AwqConfig
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForImageToImage
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForImageTextToText
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForKeypointDetection
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
else:
|
||||
|
||||
from .patcher import get_all_imported_modules, _patch_pretrained_class
|
||||
all_available_modules = _patch_pretrained_class(
|
||||
get_all_imported_modules(), wrap=True)
|
||||
|
||||
for module in all_available_modules:
|
||||
globals()[module.__name__] = module
|
||||
559
modelscope/utils/hf_util/patcher.py
Normal file
559
modelscope/utils/hf_util/patcher.py
Normal file
@@ -0,0 +1,559 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import contextlib
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
from typing import BinaryIO, Dict, List, Optional, Union
|
||||
|
||||
|
||||
def get_all_imported_modules():
    """Collect the hub-related classes exported by transformers/peft/diffusers.

    transformers and diffusers expose a lazy module whose ``_import_structure``
    maps submodule names to exported symbols; only symbols whose names match
    the include lists are imported. peft is small enough to scan via ``dir``.

    Returns:
        A list of the class objects found; empty when none of the packages
        are installed.
    """
    all_imported_modules = []
    transformers_include_names = [
        'Auto', 'T5', 'BitsAndBytes', 'GenerationConfig', 'Quant', 'Awq',
        'GPTQ', 'BatchFeature', 'Qwen2'
    ]
    diffusers_include_names = ['Pipeline']

    def _collect_from_lazy_module(package_name, include_names):
        # Shared scan for transformers/diffusers style lazy modules.  Some
        # entries (e.g. TF-only classes) fail to import; those are skipped.
        package = importlib.import_module(package_name)
        lazy_module = sys.modules[package_name]
        _import_structure = lazy_module._import_structure
        collected = []
        for key, values in _import_structure.items():
            for value in values:
                if any(name in value for name in include_names):
                    try:
                        module = importlib.import_module(
                            f'.{key}', package.__name__)
                        collected.append(getattr(module, value))
                    except (ImportError, AttributeError):
                        pass
        return collected

    if importlib.util.find_spec('transformers') is not None:
        all_imported_modules.extend(
            _collect_from_lazy_module('transformers',
                                      transformers_include_names))

    if importlib.util.find_spec('peft') is not None:
        import peft
        imports = [attr for attr in dir(peft) if not attr.startswith('__')]
        all_imported_modules.extend(
            [getattr(peft, _import) for _import in imports])

    # The original code re-checked find_spec('diffusers') in a redundant
    # nested if; a single check is sufficient.
    if importlib.util.find_spec('diffusers') is not None:
        all_imported_modules.extend(
            _collect_from_lazy_module('diffusers', diffusers_include_names))
    return all_imported_modules
|
||||
|
||||
|
||||
def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
"""Patch all class to download from modelscope
|
||||
|
||||
Args:
|
||||
wrap: Wrap the class or monkey patch the original class
|
||||
|
||||
Returns:
|
||||
The classes after patched
|
||||
"""
|
||||
|
||||
def get_model_dir(pretrained_model_name_or_path,
|
||||
ignore_file_pattern=None,
|
||||
allow_file_pattern=None,
|
||||
**kwargs):
|
||||
from modelscope import snapshot_download
|
||||
if not os.path.exists(pretrained_model_name_or_path):
|
||||
revision = kwargs.pop('revision', None)
|
||||
model_dir = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern)
|
||||
else:
|
||||
model_dir = pretrained_model_name_or_path
|
||||
return model_dir
|
||||
|
||||
ignore_file_pattern = [
|
||||
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5',
|
||||
r'\w+\.ckpt'
|
||||
]
|
||||
|
||||
def patch_pretrained_model_name_or_path(pretrained_model_name_or_path,
|
||||
*model_args, **kwargs):
|
||||
"""Patch all from_pretrained/get_config_dict"""
|
||||
model_dir = get_model_dir(pretrained_model_name_or_path,
|
||||
kwargs.pop('ignore_file_pattern', None),
|
||||
kwargs.pop('allow_file_pattern', None),
|
||||
**kwargs)
|
||||
return kwargs.pop('ori_func')(model_dir, *model_args, **kwargs)
|
||||
|
||||
def patch_peft_model_id(model, model_id, *model_args, **kwargs):
|
||||
"""Patch all peft.from_pretrained"""
|
||||
model_dir = get_model_dir(model_id,
|
||||
kwargs.pop('ignore_file_pattern', None),
|
||||
kwargs.pop('allow_file_pattern', None),
|
||||
**kwargs)
|
||||
return kwargs.pop('ori_func')(model, model_dir, *model_args, **kwargs)
|
||||
|
||||
def _get_peft_type(model_id, **kwargs):
|
||||
"""Patch all _get_peft_type"""
|
||||
model_dir = get_model_dir(model_id,
|
||||
kwargs.pop('ignore_file_pattern', None),
|
||||
kwargs.pop('allow_file_pattern', None),
|
||||
**kwargs)
|
||||
return kwargs.pop('ori_func')(model_dir, **kwargs)
|
||||
|
||||
def get_wrapped_class(
|
||||
module_class: 'PreTrainedModel',
|
||||
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
allow_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
**kwargs):
|
||||
"""Get a custom wrapper class for auto classes to download the models from the ModelScope hub
|
||||
Args:
|
||||
module_class (`PreTrainedModel`): The actual module class
|
||||
ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
|
||||
Any file pattern to be ignored, like exact file names or file extensions.
|
||||
allow_file_pattern (`str` or `List`, *optional*, default to `None`):
|
||||
Any file pattern to be included, like exact file names or file extensions.
|
||||
Returns:
|
||||
The wrapped class
|
||||
"""
|
||||
|
||||
def from_pretrained(model, model_id, *model_args, **kwargs):
|
||||
# model is an instance
|
||||
model_dir = get_model_dir(
|
||||
model_id,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
**kwargs)
|
||||
|
||||
module_obj = module_class.from_pretrained(model, model_dir,
|
||||
*model_args, **kwargs)
|
||||
|
||||
return module_obj
|
||||
|
||||
class ClassWrapper(module_class):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path,
|
||||
*model_args, **kwargs):
|
||||
model_dir = get_model_dir(
|
||||
pretrained_model_name_or_path,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
**kwargs)
|
||||
|
||||
module_obj = module_class.from_pretrained(
|
||||
model_dir, *model_args, **kwargs)
|
||||
|
||||
if module_class.__name__.startswith('AutoModel'):
|
||||
module_obj.model_dir = model_dir
|
||||
return module_obj
|
||||
|
||||
@classmethod
|
||||
def _get_peft_type(cls, model_id, **kwargs):
|
||||
model_dir = get_model_dir(
|
||||
model_id,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
**kwargs)
|
||||
module_obj = module_class._get_peft_type(model_dir, **kwargs)
|
||||
return module_obj
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(cls, pretrained_model_name_or_path,
|
||||
*model_args, **kwargs):
|
||||
model_dir = get_model_dir(
|
||||
pretrained_model_name_or_path,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
**kwargs)
|
||||
|
||||
module_obj = module_class.get_config_dict(
|
||||
model_dir, *model_args, **kwargs)
|
||||
return module_obj
|
||||
|
||||
if not hasattr(module_class, 'from_pretrained'):
|
||||
del ClassWrapper.from_pretrained
|
||||
else:
|
||||
parameters = inspect.signature(var.from_pretrained).parameters
|
||||
if 'model' in parameters and 'model_id' in parameters:
|
||||
# peft
|
||||
ClassWrapper.from_pretrained = from_pretrained
|
||||
|
||||
if not hasattr(module_class, '_get_peft_type'):
|
||||
del ClassWrapper._get_peft_type
|
||||
|
||||
if not hasattr(module_class, 'get_config_dict'):
|
||||
del ClassWrapper.get_config_dict
|
||||
|
||||
ClassWrapper.__name__ = module_class.__name__
|
||||
ClassWrapper.__qualname__ = module_class.__qualname__
|
||||
return ClassWrapper
|
||||
|
||||
all_available_modules = []
|
||||
for var in all_imported_modules:
|
||||
if var is None or not hasattr(var, '__name__'):
|
||||
continue
|
||||
name = var.__name__
|
||||
need_model = 'model' in name.lower() or 'processor' in name.lower(
|
||||
) or 'extractor' in name.lower() or 'pipeline' in name.lower()
|
||||
if need_model:
|
||||
ignore_file_pattern_kwargs = {}
|
||||
else:
|
||||
ignore_file_pattern_kwargs = {
|
||||
'ignore_file_pattern': ignore_file_pattern
|
||||
}
|
||||
|
||||
try:
|
||||
# some TFxxx classes has import errors
|
||||
has_from_pretrained = hasattr(var, 'from_pretrained')
|
||||
has_get_peft_type = hasattr(var, '_get_peft_type')
|
||||
has_get_config_dict = hasattr(var, 'get_config_dict')
|
||||
except ImportError:
|
||||
continue
|
||||
|
||||
if wrap:
|
||||
try:
|
||||
if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type:
|
||||
all_available_modules.append(var)
|
||||
else:
|
||||
all_available_modules.append(
|
||||
get_wrapped_class(var, **ignore_file_pattern_kwargs))
|
||||
except Exception:
|
||||
all_available_modules.append(var)
|
||||
else:
|
||||
if has_from_pretrained and not hasattr(var,
|
||||
'_from_pretrained_origin'):
|
||||
parameters = inspect.signature(var.from_pretrained).parameters
|
||||
# different argument names
|
||||
is_peft = 'model' in parameters and 'model_id' in parameters
|
||||
var._from_pretrained_origin = var.from_pretrained
|
||||
if not is_peft:
|
||||
var.from_pretrained = partial(
|
||||
patch_pretrained_model_name_or_path,
|
||||
ori_func=var._from_pretrained_origin,
|
||||
**ignore_file_pattern_kwargs)
|
||||
else:
|
||||
var.from_pretrained = partial(
|
||||
patch_peft_model_id,
|
||||
ori_func=var._from_pretrained_origin,
|
||||
**ignore_file_pattern_kwargs)
|
||||
if has_get_peft_type and not hasattr(var, '_get_peft_type_origin'):
|
||||
var._get_peft_type_origin = var._get_peft_type
|
||||
var._get_peft_type = partial(
|
||||
_get_peft_type,
|
||||
ori_func=var._get_peft_type_origin,
|
||||
**ignore_file_pattern_kwargs)
|
||||
|
||||
if has_get_config_dict and not hasattr(var,
|
||||
'_get_config_dict_origin'):
|
||||
var._get_config_dict_origin = var.get_config_dict
|
||||
var.get_config_dict = partial(
|
||||
patch_pretrained_model_name_or_path,
|
||||
ori_func=var._get_config_dict_origin,
|
||||
**ignore_file_pattern_kwargs)
|
||||
|
||||
all_available_modules.append(var)
|
||||
return all_available_modules
|
||||
|
||||
|
||||
def _unpatch_pretrained_class(all_imported_modules):
|
||||
for var in all_imported_modules:
|
||||
if var is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
has_from_pretrained = hasattr(var, 'from_pretrained')
|
||||
has_get_peft_type = hasattr(var, '_get_peft_type')
|
||||
has_get_config_dict = hasattr(var, 'get_config_dict')
|
||||
except ImportError:
|
||||
continue
|
||||
if has_from_pretrained and hasattr(var, '_from_pretrained_origin'):
|
||||
var.from_pretrained = var._from_pretrained_origin
|
||||
delattr(var, '_from_pretrained_origin')
|
||||
if has_get_peft_type and hasattr(var, '_get_peft_type_origin'):
|
||||
var._get_peft_type = var._get_peft_type_origin
|
||||
delattr(var, '_get_peft_type_origin')
|
||||
if has_get_config_dict and hasattr(var, '_get_config_dict_origin'):
|
||||
var.get_config_dict = var._get_config_dict_origin
|
||||
delattr(var, '_get_config_dict_origin')
|
||||
|
||||
|
||||
def _patch_hub():
    """Monkey patch huggingface_hub entry points to target the ModelScope hub.

    Replaces hf_hub_download, file_exists, whoami, create_repo, upload_folder
    and upload_file (plus RepoCard.validate) with ModelScope-backed versions.
    Every replacement stores the original under a ``*_origin`` attribute so
    the patching is idempotent and reversible via ``_unpatch_hub``.
    """
    import huggingface_hub
    from huggingface_hub import hf_api
    from huggingface_hub.hf_api import api
    from huggingface_hub.hf_api import future_compatible
    from modelscope import get_logger
    logger = get_logger()

    def _file_exists(
        self,
        repo_id: str,
        filename: str,
        *,
        repo_type: Optional[str] = None,
        revision: Optional[str] = None,
        token: Union[str, bool, None] = None,
    ):
        """Patch huggingface_hub.file_exists"""
        if repo_type is not None:
            logger.warning(
                'The passed in repo_type will not be used in modelscope. Now only model repo can be queried.'
            )
        from modelscope.hub.api import HubApi
        api = HubApi()
        api.try_login(token)
        return api.file_exists(repo_id, filename, revision=revision)

    def _file_download(repo_id: str,
                       filename: str,
                       *,
                       subfolder: Optional[str] = None,
                       repo_type: Optional[str] = None,
                       revision: Optional[str] = None,
                       cache_dir: Union[str, Path, None] = None,
                       local_dir: Union[str, Path, None] = None,
                       token: Union[bool, str, None] = None,
                       local_files_only: bool = False,
                       **kwargs):
        """Patch huggingface_hub.hf_hub_download"""
        # Extra HF-only kwargs are accepted but silently dropped (with a
        # warning) because the ModelScope download API has no equivalents.
        if len(kwargs) > 0:
            logger.warning(
                'The passed in library_name,library_version,user_agent,force_download,proxies'
                'etag_timeout,headers,endpoint '
                'will not be used in modelscope.')
        assert repo_type in (
            None, 'model',
            'dataset'), f'repo_type={repo_type} is not supported in ModelScope'
        if repo_type in (None, 'model'):
            from modelscope.hub.file_download import model_file_download as file_download
        else:
            from modelscope.hub.file_download import dataset_file_download as file_download
        from modelscope import HubApi
        api = HubApi()
        api.try_login(token)
        return file_download(
            repo_id,
            file_path=os.path.join(subfolder, filename)
            if subfolder else filename,
            cache_dir=cache_dir,
            local_dir=local_dir,
            local_files_only=local_files_only,
            revision=revision)

    def _whoami(self, token: Union[bool, str, None] = None) -> Dict:
        """Patch huggingface_hub.whoami with the ModelScope login identity."""
        from modelscope.hub.api import ModelScopeConfig
        from modelscope.hub.api import HubApi
        api = HubApi()
        api.try_login(token)
        return {'name': ModelScopeConfig.get_user_info()[0] or 'unknown'}

    def create_repo(self,
                    repo_id: str,
                    *,
                    token: Union[str, bool, None] = None,
                    private: bool = False,
                    **kwargs) -> 'RepoUrl':
        """
        Create a new repository on the hub.

        Args:
            repo_id: The ID of the repository to create.
            token: The authentication token to use.
            private: Whether the repository should be private.
            **kwargs: Additional arguments.

        Returns:
            RepoUrl: The URL of the created repository.
        """
        from modelscope.hub.api import HubApi
        api = HubApi()
        from modelscope.hub.constants import ModelVisibility
        visibility = ModelVisibility.PRIVATE if private else ModelVisibility.PUBLIC
        hub_model_id = api.create_repo(
            repo_id, token=token, visibility=visibility, **kwargs)
        from huggingface_hub import RepoUrl
        return RepoUrl(url=hub_model_id, )

    @future_compatible
    def upload_folder(
        self,
        *,
        repo_id: str,
        folder_path: Union[str, Path],
        path_in_repo: Optional[str] = None,
        commit_message: Optional[str] = None,
        commit_description: Optional[str] = None,
        token: Union[str, bool, None] = None,
        revision: Optional[str] = 'master',
        ignore_patterns: Optional[Union[List[str], str]] = None,
        **kwargs,
    ):
        """Patch huggingface_hub.upload_folder to push to ModelScope."""
        from modelscope.hub.push_to_hub import _push_files_to_hub
        _push_files_to_hub(
            path_or_fileobj=folder_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            commit_message=commit_message,
            commit_description=commit_description,
            revision=revision,
            token=token)
        from modelscope.utils.repo_utils import CommitInfo
        return CommitInfo(
            commit_url=f'https://www.modelscope.cn/models/{repo_id}/files',
            commit_message=commit_message,
            commit_description=commit_description,
            oid=None,
        )

    from modelscope.utils.constant import DEFAULT_REPOSITORY_REVISION

    @future_compatible
    def upload_file(
        self,
        *,
        path_or_fileobj: Union[str, Path, bytes, BinaryIO],
        path_in_repo: str,
        repo_id: str,
        token: Union[str, bool, None] = None,
        revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
        commit_message: Optional[str] = None,
        commit_description: Optional[str] = None,
        **kwargs,
    ):
        """Patch huggingface_hub.upload_file to push to ModelScope."""
        from modelscope.hub.push_to_hub import _push_files_to_hub
        _push_files_to_hub(path_or_fileobj, path_in_repo, repo_id, token,
                           revision, commit_message, commit_description)

    # Patch repocard.validate
    # ModelScope model cards do not follow the HF card schema; disable the check.
    from huggingface_hub import repocard
    if not hasattr(repocard.RepoCard, '_validate_origin'):
        repocard.RepoCard._validate_origin = repocard.RepoCard.validate
        repocard.RepoCard.validate = lambda *args, **kwargs: None

    if not hasattr(hf_api, '_hf_hub_download_origin'):
        # Patch hf_hub_download
        hf_api._hf_hub_download_origin = huggingface_hub.file_download.hf_hub_download
        huggingface_hub.hf_hub_download = _file_download
        huggingface_hub.file_download.hf_hub_download = _file_download

    if not hasattr(hf_api, '_file_exists_origin'):
        # Patch file_exists
        hf_api._file_exists_origin = hf_api.file_exists
        hf_api.file_exists = MethodType(_file_exists, api)
        huggingface_hub.file_exists = hf_api.file_exists
        huggingface_hub.hf_api.file_exists = hf_api.file_exists

    if not hasattr(hf_api, '_whoami_origin'):
        # Patch whoami
        hf_api._whoami_origin = hf_api.whoami
        hf_api.whoami = MethodType(_whoami, api)
        huggingface_hub.whoami = hf_api.whoami
        huggingface_hub.hf_api.whoami = hf_api.whoami

    if not hasattr(hf_api, '_create_repo_origin'):
        # Patch create_repo
        # transformers caches its own reference, so patch that alias too.
        from transformers.utils import hub
        hf_api._create_repo_origin = hf_api.create_repo
        hf_api.create_repo = MethodType(create_repo, api)
        huggingface_hub.create_repo = hf_api.create_repo
        huggingface_hub.hf_api.create_repo = hf_api.create_repo
        hub.create_repo = hf_api.create_repo

    if not hasattr(hf_api, '_upload_folder_origin'):
        # Patch upload_folder
        hf_api._upload_folder_origin = hf_api.upload_folder
        hf_api.upload_folder = MethodType(upload_folder, api)
        huggingface_hub.upload_folder = hf_api.upload_folder
        huggingface_hub.hf_api.upload_folder = hf_api.upload_folder

    if not hasattr(hf_api, '_upload_file_origin'):
        # Patch upload_file
        hf_api._upload_file_origin = hf_api.upload_file
        hf_api.upload_file = MethodType(upload_file, api)
        huggingface_hub.upload_file = hf_api.upload_file
        huggingface_hub.hf_api.upload_file = hf_api.upload_file
        repocard.upload_file = hf_api.upload_file
|
||||
|
||||
|
||||
def _unpatch_hub():
    """Restore every huggingface_hub entry point patched by ``_patch_hub``.

    Each patched function keeps its original under a ``*_origin`` attribute
    on ``hf_api``; restoring copies it back to every alias location and then
    removes the backup so the patch can be applied again later.
    """
    import huggingface_hub
    from huggingface_hub import hf_api

    from huggingface_hub import repocard
    if hasattr(repocard.RepoCard, '_validate_origin'):
        repocard.RepoCard.validate = repocard.RepoCard._validate_origin
        delattr(repocard.RepoCard, '_validate_origin')

    if hasattr(hf_api, '_hf_hub_download_origin'):
        # Fix: the original restored file_download.hf_hub_download twice;
        # each alias only needs to be restored once.
        huggingface_hub.hf_hub_download = hf_api._hf_hub_download_origin
        huggingface_hub.file_download.hf_hub_download = hf_api._hf_hub_download_origin
        delattr(hf_api, '_hf_hub_download_origin')

    if hasattr(hf_api, '_file_exists_origin'):
        hf_api.file_exists = hf_api._file_exists_origin
        huggingface_hub.file_exists = hf_api.file_exists
        huggingface_hub.hf_api.file_exists = hf_api.file_exists
        delattr(hf_api, '_file_exists_origin')

    if hasattr(hf_api, '_whoami_origin'):
        hf_api.whoami = hf_api._whoami_origin
        huggingface_hub.whoami = hf_api.whoami
        huggingface_hub.hf_api.whoami = hf_api.whoami
        delattr(hf_api, '_whoami_origin')

    if hasattr(hf_api, '_create_repo_origin'):
        # transformers caches its own reference; restore that alias too.
        from transformers.utils import hub
        hf_api.create_repo = hf_api._create_repo_origin
        huggingface_hub.create_repo = hf_api.create_repo
        huggingface_hub.hf_api.create_repo = hf_api.create_repo
        hub.create_repo = hf_api.create_repo
        delattr(hf_api, '_create_repo_origin')

    if hasattr(hf_api, '_upload_folder_origin'):
        hf_api.upload_folder = hf_api._upload_folder_origin
        huggingface_hub.upload_folder = hf_api.upload_folder
        huggingface_hub.hf_api.upload_folder = hf_api.upload_folder
        delattr(hf_api, '_upload_folder_origin')

    if hasattr(hf_api, '_upload_file_origin'):
        hf_api.upload_file = hf_api._upload_file_origin
        huggingface_hub.upload_file = hf_api.upload_file
        huggingface_hub.hf_api.upload_file = hf_api.upload_file
        repocard.upload_file = hf_api.upload_file
        delattr(hf_api, '_upload_file_origin')
|
||||
|
||||
|
||||
def patch_hub():
    """Patch huggingface_hub APIs and all from_pretrained-style entry points
    so that models/files are downloaded from the ModelScope hub instead."""
    _patch_hub()
    _patch_pretrained_class(get_all_imported_modules())
|
||||
|
||||
|
||||
def unpatch_hub():
    """Revert everything done by ``patch_hub`` and restore the original
    huggingface_hub and transformers/peft/diffusers behavior."""
    _unpatch_pretrained_class(get_all_imported_modules())
    _unpatch_hub()
|
||||
|
||||
|
||||
@contextlib.contextmanager
def patch_context():
    """Context manager applying ``patch_hub`` on entry and reverting on exit.

    Fix: the unpatch now runs in a ``finally`` clause so the hub is restored
    even when the body of the ``with`` block raises; previously an exception
    left the global patches permanently installed.
    """
    patch_hub()
    try:
        yield
    finally:
        unpatch_hub()
|
||||
@@ -282,6 +282,10 @@ def is_transformers_available():
|
||||
return importlib.util.find_spec('transformers') is not None
|
||||
|
||||
|
||||
def is_diffusers_available():
    """Return True when the `diffusers` package can be imported."""
    spec = importlib.util.find_spec('diffusers')
    return spec is not None
|
||||
|
||||
|
||||
def is_tensorrt_llm_available():
    """Return True when the `tensorrt_llm` package can be imported."""
    spec = importlib.util.find_spec('tensorrt_llm')
    return spec is not None
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ TEST_ACCESS_TOKEN1 = os.environ.get('TEST_ACCESS_TOKEN_CITEST', None)
|
||||
TEST_ACCESS_TOKEN2 = os.environ.get('TEST_ACCESS_TOKEN_SDKDEV', None)
|
||||
|
||||
TEST_MODEL_CHINESE_NAME = '内部测试模型'
|
||||
# Organization used by CI tests; overridable via the TEST_MODEL_ORG env var.
# Fix: the unconditional `TEST_MODEL_ORG = 'citest'` assignment immediately
# before this one was dead code and has been removed.
TEST_MODEL_ORG = os.environ.get('TEST_MODEL_ORG', 'citest')
|
||||
|
||||
|
||||
def delete_credential():
|
||||
|
||||
@@ -1,20 +1,54 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
|
||||
AutoTokenizer, GenerationConfig)
|
||||
from huggingface_hub import CommitInfo, RepoUrl
|
||||
|
||||
from modelscope import HubApi
|
||||
from modelscope.utils.hf_util.patcher import patch_context
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.test_utils import TEST_MODEL_ORG
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class HFUtilTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
logger.info('SetUp')
|
||||
self.api = HubApi()
|
||||
self.user = TEST_MODEL_ORG
|
||||
print(self.user)
|
||||
self.create_model_name = '%s/%s_%s' % (self.user, 'test_model_upload',
|
||||
uuid.uuid4().hex)
|
||||
logger.info('create %s' % self.create_model_name)
|
||||
temporary_dir = tempfile.mkdtemp()
|
||||
self.work_dir = temporary_dir
|
||||
self.model_dir = os.path.join(temporary_dir, self.create_model_name)
|
||||
self.repo_path = os.path.join(self.work_dir, 'repo_path')
|
||||
self.test_folder = os.path.join(temporary_dir, 'test_folder')
|
||||
self.test_file1 = os.path.join(
|
||||
os.path.join(temporary_dir, 'test_folder', '1.json'))
|
||||
self.test_file2 = os.path.join(os.path.join(temporary_dir, '2.json'))
|
||||
os.makedirs(self.test_folder, exist_ok=True)
|
||||
with open(self.test_file1, 'w') as f:
|
||||
f.write('{}')
|
||||
with open(self.test_file2, 'w') as f:
|
||||
f.write('{}')
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
logger.info('TearDown')
|
||||
shutil.rmtree(self.model_dir, ignore_errors=True)
|
||||
try:
|
||||
self.api.delete_model(model_id=self.create_model_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_auto_tokenizer(self):
|
||||
from modelscope import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
'baichuan-inc/Baichuan2-7B-Chat',
|
||||
trust_remote_code=True,
|
||||
@@ -24,15 +58,17 @@ class HFUtilTest(unittest.TestCase):
|
||||
self.assertFalse(tokenizer.is_fast)
|
||||
|
||||
def test_quantization_import(self):
|
||||
from modelscope import GPTQConfig, BitsAndBytesConfig
|
||||
from modelscope import BitsAndBytesConfig
|
||||
self.assertTrue(BitsAndBytesConfig is not None)
|
||||
|
||||
    def test_auto_model(self):
        """The AutoModelForCausalLM wrapper should resolve a ModelScope model id
        (downloads from the ModelScope hub, so this test needs network access)."""
        from modelscope import AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(
            'baichuan-inc/baichuan-7B', trust_remote_code=True)
        self.assertTrue(model is not None)
|
||||
|
||||
def test_auto_config(self):
|
||||
from modelscope import AutoConfig, GenerationConfig
|
||||
config = AutoConfig.from_pretrained(
|
||||
'baichuan-inc/Baichuan-13B-Chat',
|
||||
trust_remote_code=True,
|
||||
@@ -45,12 +81,143 @@ class HFUtilTest(unittest.TestCase):
|
||||
self.assertEqual(gen_config.assistant_token_id, 196)
|
||||
|
||||
def test_transformer_patch(self):
    """modelscope's pre-patched Auto* classes resolve ModelScope ids
    directly; the native transformers classes do so inside patch_context().
    """
    model_id = 'iic/nlp_structbert_sentiment-classification_chinese-base'
    # The classes exported from modelscope are already hub-patched.
    self.assertIsNotNone(AutoTokenizer.from_pretrained(model_id))
    self.assertIsNotNone(AutoModelForCausalLM.from_pretrained(model_id))
    with patch_context():
        # Inside the context, the stock transformers classes work too.
        from transformers import AutoTokenizer, AutoModelForCausalLM
        self.assertIsNotNone(AutoTokenizer.from_pretrained(model_id))
        self.assertIsNotNone(AutoModelForCausalLM.from_pretrained(model_id))
||||
def test_patch_model(self):
    """AutoModel resolves a ModelScope id only while patch_context is active."""
    from modelscope.utils.hf_util.patcher import patch_context
    with patch_context():
        from transformers import AutoModel
        model = AutoModel.from_pretrained(
            'iic/nlp_structbert_sentiment-classification_chinese-tiny')
        self.assertTrue(model is not None)
    # After the context exits the hub patch is reverted, so the same
    # ModelScope id must fail to resolve. assertRaises replaces the
    # try/except/else + assertTrue(False) idiom.
    with self.assertRaises(Exception):
        AutoModel.from_pretrained(
            'iic/nlp_structbert_sentiment-classification_chinese-tiny')
|
||||
def test_patch_config_bert(self):
    """Without the hub patch, transformers cannot resolve a ModelScope id."""
    from transformers import BertConfig
    # assertRaises replaces the try/except/else + assertTrue(False) idiom.
    with self.assertRaises(Exception):
        BertConfig.from_pretrained(
            'iic/nlp_structbert_sentiment-classification_chinese-tiny')
||||
def test_patch_config(self):
    """AutoConfig resolves ModelScope ids only inside patch_context, and
    the context manager can be entered repeatedly."""
    with patch_context():
        from transformers import AutoConfig
        config = AutoConfig.from_pretrained(
            'iic/nlp_structbert_sentiment-classification_chinese-tiny')
        self.assertTrue(config is not None)
    # Outside the context the patch is reverted; the ModelScope id must
    # no longer resolve. assertRaises replaces try/except/else.
    with self.assertRaises(Exception):
        AutoConfig.from_pretrained(
            'iic/nlp_structbert_sentiment-classification_chinese-tiny')

    # Test patch again: entering a fresh context must re-apply the patch.
    with patch_context():
        from transformers import AutoConfig
        config = AutoConfig.from_pretrained(
            'iic/nlp_structbert_sentiment-classification_chinese-tiny')
        self.assertTrue(config is not None)
||||
def test_patch_diffusers(self):
    """diffusers pipelines resolve ModelScope ids inside patch_context;
    the modelscope-exported pipeline class works without it."""
    with patch_context():
        from diffusers import StableDiffusionPipeline
        pipe = StableDiffusionPipeline.from_pretrained(
            'AI-ModelScope/stable-diffusion-v1-5')
        self.assertTrue(pipe is not None)
    # Outside the context the stock diffusers class must fail on a
    # ModelScope id. assertRaises replaces try/except/else.
    with self.assertRaises(Exception):
        StableDiffusionPipeline.from_pretrained(
            'AI-ModelScope/stable-diffusion-v1-5')

    # The class re-exported from modelscope remains patched.
    from modelscope import StableDiffusionPipeline
    pipe = StableDiffusionPipeline.from_pretrained(
        'AI-ModelScope/stable-diffusion-v1-5')
    self.assertTrue(pipe is not None)
||||
def test_patch_peft(self):
    """PeftModel loads a ModelScope-hosted adapter under patch_context,
    and the patch is fully removed after the context exits."""
    with patch_context():
        from transformers import AutoModelForCausalLM
        from peft import PeftModel
        base_model = AutoModelForCausalLM.from_pretrained(
            'OpenBMB/MiniCPM3-4B', trust_remote_code=True)
        adapter_model = PeftModel.from_pretrained(
            base_model, 'OpenBMB/MiniCPM3-RAG-LoRA', trust_remote_code=True)
        self.assertIsNotNone(adapter_model)
    # Unpatching must remove the stashed original method from PeftModel.
    self.assertFalse(hasattr(PeftModel, '_from_pretrained_origin'))
||||
def test_patch_file_exists(self):
    """huggingface_hub.file_exists resolves a ModelScope repo only while
    patch_context is active."""
    with patch_context():
        from huggingface_hub import file_exists
        self.assertTrue(
            file_exists('AI-ModelScope/stable-diffusion-v1-5',
                        'feature_extractor/preprocessor_config.json'))
    # assertRaises replaces try/except/else + assertTrue(False).
    with self.assertRaises(Exception):
        # Import again: after the context exits, the original unpatched
        # function is restored and the ModelScope repo cannot resolve.
        from huggingface_hub import file_exists  # noqa
        file_exists('AI-ModelScope/stable-diffusion-v1-5',
                    'feature_extractor/preprocessor_config.json')
||||
def test_patch_file_download(self):
    """hf_hub_download fetches a file from a ModelScope repo under the patch."""
    with patch_context():
        from huggingface_hub import hf_hub_download
        downloaded_path = hf_hub_download(
            'AI-ModelScope/stable-diffusion-v1-5',
            'feature_extractor/preprocessor_config.json')
        logger.info('patch file_download dir: ' + downloaded_path)
        self.assertIsNotNone(downloaded_path)
||||
def test_patch_create_repo(self):
    """Exercise repo creation plus folder/file uploads through the patched
    huggingface_hub APIs against a ModelScope repo."""
    with patch_context():
        from huggingface_hub import create_repo
        repo_url: RepoUrl = create_repo(self.create_model_name)
        logger.info('patch create repo result: ' + repo_url.repo_id)
        self.assertIsNotNone(repo_url)

        # Upload the whole prepared folder to the repo root.
        from huggingface_hub import upload_folder
        folder_commit: CommitInfo = upload_folder(
            folder_path=self.test_folder,
            repo_id=self.create_model_name,
            path_in_repo='')
        logger.info('patch create repo result: ' + folder_commit.commit_url)
        self.assertIsNotNone(folder_commit)

        from huggingface_hub import file_exists
        self.assertTrue(file_exists(self.create_model_name, '1.json'))

        # Upload a single file into a sub-path of the repo.
        from huggingface_hub import upload_file
        file_commit: CommitInfo = upload_file(
            path_or_fileobj=self.test_file2,
            path_in_repo='test_folder2',
            repo_id=self.create_model_name)
        self.assertTrue(
            file_exists(self.create_model_name, 'test_folder2/2.json'))
||||
def test_who_am_i(self):
    """whoami() under the hub patch reports the ModelScope login name."""
    with patch_context():
        from huggingface_hub import whoami
        identity = whoami()
        self.assertEqual(identity['name'], self.user)
if __name__ == '__main__':
    # Standard unittest entry point; the guard body was lost in the
    # mangled paste — unittest.main() is the conventional runner here.
    unittest.main()