Support downloading exact file for hf wrapper (#1323)

Author: tastelikefeet
Date: 2025-04-30 14:57:59 +08:00
Committed by: GitHub
Parent commit: 806ac2b05e
Commit: a91f19ea54

4 changed files with 86 additions and 24 deletions
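In practice, the wrapped classes keep accepting explicit ignore_file_pattern / allow_file_pattern arguments and, when neither is passed, now derive a per-class allow-list so only the files the class actually needs are downloaded. A minimal usage sketch (the model id and the ignore pattern are taken from the new tests; everything else describes the patched ModelScope wrappers, not a new API):

from modelscope import Qwen2Tokenizer

# Explicit pattern: skip any *.h5 weight files while resolving the tokenizer.
tokenizer = Qwen2Tokenizer.from_pretrained(
    'Qwen/Qwen2-Math-7B-Instruct', ignore_file_pattern=[r'\w+\.h5'])

# No pattern given: the wrapper fills in a tokenizer allow-list
# (vocab files, tokenizer_config.json, chat template, *.py) instead of
# downloading the whole repository.
tokenizer = Qwen2Tokenizer.from_pretrained('Qwen/Qwen2-Math-7B-Instruct')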


@@ -2058,7 +2058,8 @@ class HubApi:
         if query_addr:
             domain_response = send_request(query_addr, timeout=internal_timeout)
-            region_id = domain_response.text.strip()
+            if domain_response is not None:
+                region_id = domain_response.text.strip()
         return region_id


@@ -326,9 +326,10 @@ class Preprocessor(ABC):
             )
             return None
         if (model_type, task) not in PREPROCESSOR_MAP:
-            logger.warning(
+            logger.info(
                 f'No preprocessor key {(model_type, task)} found in PREPROCESSOR_MAP, '
-                f'skip building preprocessor.')
+                f'skip building preprocessor. If the pipeline runs normally, please ignore this log.'
+            )
             return None
         sub_cfg = ConfigDict({


@@ -220,13 +220,60 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
         The wrapped class
         """

+        @contextlib.contextmanager
+        def file_pattern_context(kwargs, module_class, cls):
+            if 'allow_file_pattern' not in kwargs:
+                kwargs['allow_file_pattern'] = allow_file_pattern
+            if 'ignore_file_pattern' not in kwargs:
+                kwargs['ignore_file_pattern'] = ignore_file_pattern
+            if kwargs.get(
+                    'allow_file_pattern') is None and module_class is not None:
+                extra_allow_file_pattern = None
+                if 'GenerationConfig' == module_class.__name__:
+                    from transformers.utils import GENERATION_CONFIG_NAME
+                    extra_allow_file_pattern = [
+                        GENERATION_CONFIG_NAME, r'*.py'
+                    ]
+                elif 'Config' in module_class.__name__:
+                    from transformers import CONFIG_NAME
+                    extra_allow_file_pattern = [CONFIG_NAME, r'*.py']
+                elif 'Tokenizer' in module_class.__name__:
+                    from transformers.tokenization_utils import ADDED_TOKENS_FILE
+                    from transformers.tokenization_utils import SPECIAL_TOKENS_MAP_FILE
+                    from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE
+                    from transformers.tokenization_utils_base import FULL_TOKENIZER_FILE
+                    from transformers.tokenization_utils_base import CHAT_TEMPLATE_FILE
+                    extra_allow_file_pattern = list(
+                        (cls.vocab_files_names.values()) if cls is not None
+                        and hasattr(cls, 'vocab_files_names') else []) + [
+                            ADDED_TOKENS_FILE, SPECIAL_TOKENS_MAP_FILE,
+                            TOKENIZER_CONFIG_FILE, FULL_TOKENIZER_FILE,
+                            CHAT_TEMPLATE_FILE, r'*.py'
+                        ]  # noqa
+                elif 'Processor' in module_class.__name__:
+                    from transformers.utils import FEATURE_EXTRACTOR_NAME
+                    from transformers.utils import PROCESSOR_NAME
+                    from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE
+                    extra_allow_file_pattern = [
+                        FEATURE_EXTRACTOR_NAME, TOKENIZER_CONFIG_FILE,
+                        PROCESSOR_NAME, r'*.py'
+                    ]
+                kwargs['allow_file_pattern'] = extra_allow_file_pattern
+            yield
+            kwargs.pop('ignore_file_pattern', None)
+            kwargs.pop('allow_file_pattern', None)
+
         def from_pretrained(model, model_id, *model_args, **kwargs):
-            # model is an instance
-            model_dir = get_model_dir(
-                model_id,
-                ignore_file_pattern=ignore_file_pattern,
-                allow_file_pattern=allow_file_pattern,
-                **kwargs)
+            with file_pattern_context(kwargs):
+                # model is an instance
+                model_dir = get_model_dir(
+                    model_id,
+                    module_class=module_class,
+                    cls=module_class,
+                    **kwargs)
             module_obj = module_class.from_pretrained(model, model_dir,
                                                       *model_args, **kwargs)
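The helper above mutates the caller's kwargs for the duration of the download and then removes the pattern keys so they do not leak into the later from_pretrained call. A self-contained toy sketch of that contextmanager pattern (default_kwargs and fake_download are hypothetical names for illustration, not part of ModelScope):

import contextlib

@contextlib.contextmanager
def default_kwargs(kwargs, **defaults):
    # Inject defaults only for keys the caller did not pass explicitly.
    added = [key for key in defaults if key not in kwargs]
    for key in added:
        kwargs[key] = defaults[key]
    try:
        yield kwargs
    finally:
        # Drop the injected keys again so downstream calls that spread
        # **kwargs never see them; file_pattern_context pops its two keys
        # unconditionally for the same reason.
        for key in added:
            kwargs.pop(key, None)

def fake_download(model_id, **kwargs):
    return model_id, kwargs.get('allow_file_pattern')

kwargs = {}
with default_kwargs(kwargs, allow_file_pattern=['config.json', '*.py']):
    fake_download('some/model', **kwargs)  # sees the injected allow-list
assert 'allow_file_pattern' not in kwargs  # cleaned up after the block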
@@ -238,11 +285,9 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
         @classmethod
         def from_pretrained(cls, pretrained_model_name_or_path,
                             *model_args, **kwargs):
-            model_dir = get_model_dir(
-                pretrained_model_name_or_path,
-                ignore_file_pattern=ignore_file_pattern,
-                allow_file_pattern=allow_file_pattern,
-                **kwargs)
+            with file_pattern_context(kwargs, module_class, cls):
+                model_dir = get_model_dir(pretrained_model_name_or_path,
+                                          **kwargs)
             module_obj = module_class.from_pretrained(
                 model_dir, *model_args, **kwargs)
@@ -253,22 +298,25 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
         @classmethod
         def _get_peft_type(cls, model_id, **kwargs):
-            model_dir = get_model_dir(
-                model_id,
-                ignore_file_pattern=ignore_file_pattern,
-                allow_file_pattern=allow_file_pattern,
-                **kwargs)
+            with file_pattern_context(kwargs, module_class, cls):
+                model_dir = get_model_dir(
+                    model_id,
+                    ignore_file_pattern=ignore_file_pattern,
+                    allow_file_pattern=allow_file_pattern,
+                    **kwargs)
             module_obj = module_class._get_peft_type(model_dir, **kwargs)
             return module_obj

         @classmethod
         def get_config_dict(cls, pretrained_model_name_or_path,
                             *model_args, **kwargs):
-            model_dir = get_model_dir(
-                pretrained_model_name_or_path,
-                ignore_file_pattern=ignore_file_pattern,
-                allow_file_pattern=allow_file_pattern,
-                **kwargs)
+            with file_pattern_context(kwargs, module_class, cls):
+                model_dir = get_model_dir(
+                    pretrained_model_name_or_path,
+                    ignore_file_pattern=ignore_file_pattern,
+                    allow_file_pattern=allow_file_pattern,
+                    **kwargs)
             module_obj = module_class.get_config_dict(
                 model_dir, *model_args, **kwargs)
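For the config-style classes wrapped above, the net effect is that a bare from_pretrained call fetches only the matching config file plus any remote code. A hedged illustration (the model id is illustrative, and GenerationConfig is assumed to be the existing ModelScope wrapper export used elsewhere in these tests):

from modelscope import GenerationConfig

# With no explicit patterns, the GenerationConfig branch of the new logic
# restricts the download to generation_config.json and *.py files.
gen_config = GenerationConfig.from_pretrained('Qwen/Qwen2-Math-7B-Instruct')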


@@ -91,6 +91,18 @@ class HFUtilTest(unittest.TestCase):
                                              revision='v1.0.3')
         self.assertEqual(gen_config.assistant_token_id, 196)

+    def test_qwen_tokenizer(self):
+        from modelscope import Qwen2Tokenizer
+        tokenizer = Qwen2Tokenizer.from_pretrained(
+            'Qwen/Qwen2-Math-7B-Instruct')
+        self.assertTrue(tokenizer is not None)
+
+    def test_extra_ignore_args(self):
+        from modelscope import Qwen2Tokenizer
+        tokenizer = Qwen2Tokenizer.from_pretrained(
+            'Qwen/Qwen2-Math-7B-Instruct', ignore_file_pattern=[r'\w+\.h5'])
+        self.assertTrue(tokenizer is not None)
+
     def test_transformer_patch(self):
         with patch_context():
             from transformers import AutoTokenizer, AutoModelForCausalLM
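The test_transformer_patch case exercises the same logic through native transformers classes. A short sketch of that path (the patch_context import path is assumed from the test's usage; the model id is illustrative):

from modelscope.utils.hf_util import patch_context

with patch_context():
    from transformers import AutoTokenizer
    # AutoTokenizer matches the 'Tokenizer' branch of the allow-pattern
    # logic, so only tokenizer files and *.py should be fetched.
    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-Math-7B-Instruct')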