mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
Fix downloading repos in automap (#1630)
This commit is contained in:
@@ -12,8 +12,7 @@ from types import MethodType
|
||||
from typing import BinaryIO, Dict, Iterable, List, Optional, Union
|
||||
|
||||
from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT
|
||||
from modelscope.utils.repo_utils import (CommitInfo, CommitOperation,
|
||||
CommitOperationAdd)
|
||||
from modelscope.utils.repo_utils import CommitInfo, CommitOperation
|
||||
|
||||
ignore_file_pattern = [
|
||||
r'*.bin',
|
||||
@@ -133,6 +132,29 @@ def get_all_imported_modules():
|
||||
return all_imported_modules
|
||||
|
||||
|
||||
def _decide_allow_file_pattern(module_name, cls=None):
|
||||
extra_allow_file_pattern = None
|
||||
if 'GenerationConfig' in module_name:
|
||||
from transformers.utils import GENERATION_CONFIG_NAME
|
||||
extra_allow_file_pattern = [GENERATION_CONFIG_NAME, r'*.py']
|
||||
elif 'Config' in module_name:
|
||||
from transformers import CONFIG_NAME
|
||||
extra_allow_file_pattern = [CONFIG_NAME, r'*.py']
|
||||
elif 'Tokenizer' in module_name:
|
||||
extra_allow_file_pattern = list((
|
||||
cls.vocab_files_names.values()
|
||||
) if cls is not None and hasattr(cls, 'vocab_files_names') else []) + [
|
||||
'chat_template.jinja', r'*.json', r'*.py', r'*.txt', r'*.model',
|
||||
r'*.tiktoken'
|
||||
] # noqa
|
||||
elif 'Processor' in module_name:
|
||||
extra_allow_file_pattern = [
|
||||
'chat_template.jinja', r'*.json', r'*.py', r'*.txt', r'*.model',
|
||||
r'*.tiktoken'
|
||||
]
|
||||
return extra_allow_file_pattern
|
||||
|
||||
|
||||
def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
"""Patch all class to download from modelscope
|
||||
|
||||
@@ -231,28 +253,8 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
|
||||
if kwargs.get(
|
||||
'allow_file_pattern') is None and module_class is not None:
|
||||
extra_allow_file_pattern = None
|
||||
if 'GenerationConfig' == module_class.__name__:
|
||||
from transformers.utils import GENERATION_CONFIG_NAME
|
||||
extra_allow_file_pattern = [
|
||||
GENERATION_CONFIG_NAME, r'*.py'
|
||||
]
|
||||
elif 'Config' in module_class.__name__:
|
||||
from transformers import CONFIG_NAME
|
||||
extra_allow_file_pattern = [CONFIG_NAME, r'*.py']
|
||||
elif 'Tokenizer' in module_class.__name__:
|
||||
extra_allow_file_pattern = list(
|
||||
(cls.vocab_files_names.values()) if cls is not None
|
||||
and hasattr(cls, 'vocab_files_names') else []) + [
|
||||
'chat_template.jinja', r'*.json', r'*.py',
|
||||
r'*.txt', r'*.model', r'*.tiktoken'
|
||||
] # noqa
|
||||
elif 'Processor' in module_class.__name__:
|
||||
extra_allow_file_pattern = [
|
||||
'chat_template.jinja', r'*.json', r'*.py', r'*.txt',
|
||||
r'*.model', r'*.tiktoken'
|
||||
]
|
||||
|
||||
extra_allow_file_pattern = _decide_allow_file_pattern(
|
||||
module_class.__name__, cls)
|
||||
kwargs['allow_file_pattern'] = extra_allow_file_pattern
|
||||
yield
|
||||
kwargs.pop('ignore_file_pattern', None)
|
||||
@@ -439,11 +441,26 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
|
||||
def get_class_from_dynamic_module(class_reference, *args, **kwargs):
|
||||
from transformers.dynamic_module_utils import origin_get_class_from_dynamic_module
|
||||
if 'pretrained_model_name_or_path' in inspect.signature(
|
||||
origin_get_class_from_dynamic_module).parameters:
|
||||
pretrained_model_name_or_path = args[0]
|
||||
if not os.path.exists(pretrained_model_name_or_path):
|
||||
from modelscope import snapshot_download
|
||||
args[0] = snapshot_download(pretrained_model_name_or_path)
|
||||
if '--' in class_reference:
|
||||
repo_id, class_reference = class_reference.split('--')
|
||||
if not os.path.exists(repo_id):
|
||||
download_kwargs = {}
|
||||
extra_allow_file_pattern = _decide_allow_file_pattern(
|
||||
class_reference)
|
||||
if extra_allow_file_pattern is not None:
|
||||
download_kwargs[
|
||||
'allow_file_pattern'] = extra_allow_file_pattern
|
||||
if 'Config' in class_reference or 'Processor' in class_reference or 'Tokenizer' in class_reference:
|
||||
download_kwargs[
|
||||
'ignore_file_pattern'] = ignore_file_pattern
|
||||
from modelscope import snapshot_download
|
||||
repo_id = snapshot_download(repo_id)
|
||||
repo_id = snapshot_download(repo_id, **download_kwargs)
|
||||
class_reference = repo_id + '--' + class_reference
|
||||
return origin_get_class_from_dynamic_module(class_reference, *args,
|
||||
**kwargs)
|
||||
|
||||
@@ -253,6 +253,33 @@ class HFUtilTest(unittest.TestCase):
|
||||
from huggingface_hub import whoami
|
||||
self.assertTrue(whoami()['name'] == self.user)
|
||||
|
||||
def test_automapping_download(self):
|
||||
from modelscope import AutoConfig
|
||||
model = 'nomic-ai/nomic-embed-text-v1.5'
|
||||
config = AutoConfig.from_pretrained(model, trust_remote_code=True)
|
||||
model_dir = config.name_or_path
|
||||
files = os.listdir(model_dir)
|
||||
has_weight_files = any(
|
||||
f.endswith('.safetensors') or f.endswith('.bin') for f in files)
|
||||
self.assertFalse(
|
||||
has_weight_files,
|
||||
f'Expected no weight files in {model_dir}, but found: '
|
||||
f"{[f for f in files if f.endswith('.safetensors') or f.endswith('.bin')]}"
|
||||
)
|
||||
cache_dir = os.path.dirname(model_dir)
|
||||
cache_dir = os.path.dirname(cache_dir)
|
||||
model_dir_2 = os.path.join(cache_dir, 'nomic-ai', 'nomic-bert-2048')
|
||||
if os.path.exists(model_dir_2):
|
||||
files = os.listdir(model_dir_2)
|
||||
has_weight_files = any(
|
||||
f.endswith('.safetensors') or f.endswith('.bin')
|
||||
for f in files)
|
||||
self.assertFalse(
|
||||
has_weight_files,
|
||||
f'Expected no weight files in {model_dir}, but found: '
|
||||
f"{[f for f in files if f.endswith('.safetensors') or f.endswith('.bin')]}"
|
||||
)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_push_to_hub(self):
|
||||
with patch_context():
|
||||
|
||||
Reference in New Issue
Block a user