do not download pt and pth files for autoconfig, autotoknizer and generation config (#1008)

Co-authored-by: Yingda Chen <yingda.chen@alibaba-inc.com>
This commit is contained in:
Yingda Chen
2024-10-01 15:17:05 +08:00
committed by GitHub
parent 8d72902e20
commit 9f631056f3

View File

@@ -127,7 +127,9 @@ def _patch_pretrained_class():
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
**kwargs):
ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
ignore_file_pattern = [
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
]
model_dir = get_model_dir(pretrained_model_name_or_path,
ignore_file_pattern, **kwargs)
return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
@@ -143,14 +145,18 @@ def _patch_pretrained_class():
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
**kwargs):
ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
ignore_file_pattern = [
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
]
model_dir = get_model_dir(pretrained_model_name_or_path,
ignore_file_pattern, **kwargs)
return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
@classmethod
def get_config_dict(cls, pretrained_model_name_or_path, **kwargs):
ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
ignore_file_pattern = [
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
]
model_dir = get_model_dir(pretrained_model_name_or_path,
ignore_file_pattern, **kwargs)
return ori_get_config_dict(cls, model_dir, **kwargs)
@@ -242,11 +248,20 @@ AutoModelForTokenClassification = get_wrapped_class(
AutoModelForTokenClassificationHF)
AutoTokenizer = get_wrapped_class(
AutoTokenizerHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
AutoTokenizerHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
])
AutoConfig = get_wrapped_class(
AutoConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
AutoConfigHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
])
GenerationConfig = get_wrapped_class(
GenerationConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
GenerationConfigHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
])
GPTQConfig = GPTQConfigHF
AwqConfig = AwqConfigHF
BitsAndBytesConfig = BitsAndBytesConfigHF