mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
support subfolder of from_pretrained
This commit is contained in:
@@ -108,15 +108,23 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
|||||||
allow_file_pattern=None,
|
allow_file_pattern=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
|
subfolder = kwargs.pop('subfolder', None)
|
||||||
|
file_filter = None
|
||||||
|
if subfolder:
|
||||||
|
file_filter = f'{subfolder}/*'
|
||||||
if not os.path.exists(pretrained_model_name_or_path):
|
if not os.path.exists(pretrained_model_name_or_path):
|
||||||
revision = kwargs.pop('revision', None)
|
revision = kwargs.pop('revision', None)
|
||||||
if revision is None or revision == 'main':
|
if revision is None or revision == 'main':
|
||||||
revision = 'master'
|
revision = 'master'
|
||||||
|
if file_filter is not None:
|
||||||
|
allow_file_pattern = file_filter
|
||||||
model_dir = snapshot_download(
|
model_dir = snapshot_download(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
ignore_file_pattern=ignore_file_pattern,
|
ignore_file_pattern=ignore_file_pattern,
|
||||||
allow_file_pattern=allow_file_pattern)
|
allow_file_pattern=allow_file_pattern)
|
||||||
|
if subfolder:
|
||||||
|
model_dir = os.path.join(model_dir, subfolder)
|
||||||
else:
|
else:
|
||||||
model_dir = pretrained_model_name_or_path
|
model_dir = pretrained_model_name_or_path
|
||||||
return model_dir
|
return model_dir
|
||||||
|
|||||||
Reference in New Issue
Block a user