add download command line and local_dir parameter (#866)

* add download command line and local_dir parameter

Co-authored-by: mulin.lyh <mulin.lyh@taobao.com>
This commit is contained in:
liuyhwangyh
2024-05-25 14:21:55 +08:00
committed by GitHub
parent 5c470f8941
commit f93a184d88
8 changed files with 147 additions and 38 deletions

View File

@@ -3,6 +3,7 @@
import argparse
from modelscope.cli.download import DownloadCMD
from modelscope.cli.login import LoginCMD
from modelscope.cli.modelcard import ModelCardCMD
from modelscope.cli.pipeline import PipelineCMD
from modelscope.cli.plugins import PluginsCMD
@@ -19,6 +20,7 @@ def run_cmd():
PipelineCMD.define_args(subparsers)
ModelCardCMD.define_args(subparsers)
ServerCMD.define_args(subparsers)
LoginCMD.define_args(subparsers)
args = parser.parse_args()

View File

@@ -3,6 +3,7 @@
from argparse import ArgumentParser
from modelscope.cli.base import CLICommand
from modelscope.hub.file_download import model_file_download
from modelscope.hub.snapshot_download import snapshot_download
@@ -22,9 +23,12 @@ class DownloadCMD(CLICommand):
def define_args(parsers: ArgumentParser):
""" define args for download command.
"""
parser = parsers.add_parser(DownloadCMD.name)
parser: ArgumentParser = parsers.add_parser(DownloadCMD.name)
parser.add_argument(
'model', type=str, help='Name of the model to be downloaded.')
'--model',
type=str,
required=True,
help='The model id to be downloaded.')
parser.add_argument(
'--revision',
type=str,
@@ -35,10 +39,57 @@ class DownloadCMD(CLICommand):
type=str,
default=None,
help='Cache directory to save model.')
parser.add_argument(
'--local_dir',
type=str,
default=None,
help='File will be downloaded to local location specified by'
'local_dir, in this case, cache_dir parameter will be ignored.')
parser.add_argument(
'files',
type=str,
default=None,
nargs='*',
help='Specify relative path to the repository file(s) to download.'
"(e.g 'tokenizer.json', 'onnx/decoder_model.onnx').")
parser.add_argument(
'--include',
nargs='*',
default=None,
type=str,
help='Glob patterns to match files to download.'
'Ignored if file is specified')
parser.add_argument(
'--exclude',
nargs='*',
type=str,
default=None,
help='Glob patterns to exclude from files to download.'
'Ignored if file is specified')
parser.set_defaults(func=subparser_func)
def execute(self):
    """Run the download selected by the parsed command-line arguments.

    Dispatches on the positional ``files`` argument:

    * exactly one file: fetch that file with ``model_file_download``;
    * several files: call ``snapshot_download`` with the file list as
      ``allow_file_pattern`` (``--include`` / ``--exclude`` are not
      forwarded in this case);
    * no files: call ``snapshot_download`` for the repository, forwarding
      ``--include`` as ``allow_file_pattern`` and ``--exclude`` as
      ``ignore_file_pattern``.

    ``model``, ``revision``, ``cache_dir`` and ``local_dir`` are forwarded
    unchanged in every case. Exactly one download call is made per
    invocation.
    """
    # ``files`` is declared with ``default=None``; callers that build the
    # namespace by hand may leave it unset, so treat None as "no files".
    files = self.args.files or []

    if len(files) == 1:  # download single file
        model_file_download(
            self.args.model,
            files[0],
            cache_dir=self.args.cache_dir,
            local_dir=self.args.local_dir,
            revision=self.args.revision)
    elif files:  # download specified multiple files.
        snapshot_download(
            self.args.model,
            revision=self.args.revision,
            cache_dir=self.args.cache_dir,
            local_dir=self.args.local_dir,
            allow_file_pattern=files,
        )
    else:  # download repo
        snapshot_download(
            self.args.model,
            revision=self.args.revision,
            cache_dir=self.args.cache_dir,
            local_dir=self.args.local_dir,
            allow_file_pattern=self.args.include,
            ignore_file_pattern=self.args.exclude,
        )

35
modelscope/cli/login.py Normal file
View File

@@ -0,0 +1,35 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from argparse import ArgumentParser
from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi
def subparser_func(args):
    """Build the command object for the ``login`` sub-parser.

    Registered as the ``func`` default of the login sub-parser, so it is
    called with the parsed arguments and returns a ``LoginCMD`` wrapping
    them.
    """
    command = LoginCMD(args)
    return command
class LoginCMD(CLICommand):
    """The ``login`` sub-command: passes an access token to ``HubApi.login``."""

    name = 'login'

    def __init__(self, args):
        # Parsed argument namespace; ``execute`` reads ``token`` from it.
        self.args = args

    @staticmethod
    def define_args(parsers: ArgumentParser):
        """Register the ``login`` sub-parser and its required ``--token`` option.

        Args:
            parsers: The sub-parsers object the ``login`` parser is added to.
        """
        login_parser = parsers.add_parser(LoginCMD.name)
        token_options = dict(
            type=str,
            required=True,
            help='The Access Token for modelscope.',
        )
        login_parser.add_argument('--token', **token_options)
        # The dispatcher uses ``func`` to build the command object.
        login_parser.set_defaults(func=subparser_func)

    def execute(self):
        """Create a ``HubApi`` and log in with the token given on the command line."""
        HubApi().login(self.args.token)

View File

@@ -92,7 +92,7 @@ class HubApi:
def login(
self,
access_token: str,
) -> tuple():
):
"""Login with your SDK access token, which can be obtained from
https://www.modelscope.cn user center.

View File

@@ -30,6 +30,7 @@ MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME'
MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'
ONE_YEAR_SECONDS = 24 * 365 * 60 * 60
MODELSCOPE_REQUEST_ID = 'X-Request-ID'
TEMPORARY_FOLDER_NAME = '._____temp'
class Licenses(object):

View File

@@ -18,7 +18,7 @@ from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import (
API_FILE_DOWNLOAD_CHUNK_SIZE, API_FILE_DOWNLOAD_RETRY_TIMES,
API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB, TEMPORARY_FOLDER_NAME)
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.utils.logger import get_logger
@@ -38,6 +38,7 @@ def model_file_download(
user_agent: Union[Dict, str, None] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
local_dir: Optional[str] = None,
) -> Optional[str]: # pragma: no cover
"""Download from a given URL and cache it if it's not already present in the local cache.
@@ -55,6 +56,7 @@ def model_file_download(
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
local cached file if it exists. if `False`, download the file anyway even it exists.
cookies (CookieJar, optional): The cookie of download request.
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
Returns:
string: string of local file or if networking is off, last version of
@@ -74,14 +76,8 @@ def model_file_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
if cache_dir is None:
cache_dir = get_model_cache_root()
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
group_or_owner, name = model_id_to_group_owner_name(model_id)
temporary_cache_dir = os.path.join(cache_dir, 'temp', group_or_owner, name)
os.makedirs(temporary_cache_dir, exist_ok=True)
cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
temporary_cache_dir, cache = create_temporary_directory_and_cache(
model_id, local_dir, cache_dir)
# if local_files_only is `True` and the file already exists in cached_path
# return the cached path
@@ -164,6 +160,26 @@ def model_file_download(
os.path.join(temporary_cache_dir, file_path))
def create_temporary_directory_and_cache(model_id: str, local_dir: str,
                                         cache_dir: str):
    """Prepare the temporary download directory and the file cache object.

    Args:
        model_id: Model id; split into group/owner and name via
            ``model_id_to_group_owner_name``.
        local_dir: When not ``None``, the temporary directory is
            ``<local_dir>/TEMPORARY_FOLDER_NAME``, the cache is built
            directly on ``local_dir``, and ``cache_dir`` is not used.
        cache_dir: Cache root used when ``local_dir`` is ``None``. Falls
            back to ``get_model_cache_root()`` when ``None``; a ``Path``
            is converted to ``str``.

    Returns:
        A ``(temporary_cache_dir, cache)`` tuple. The temporary directory
        is created if it does not exist yet.
    """
    owner, repo_name = model_id_to_group_owner_name(model_id)

    if local_dir is None:
        root = get_model_cache_root() if cache_dir is None else cache_dir
        if isinstance(root, Path):
            root = str(root)
        # The temporary path keeps the original name; only the cache
        # receives the name with '.' replaced by '___'.
        staging_dir = os.path.join(root, TEMPORARY_FOLDER_NAME, owner,
                                   repo_name)
        file_cache = ModelFileSystemCache(root, owner,
                                          repo_name.replace('.', '___'))
    else:
        staging_dir = os.path.join(local_dir, TEMPORARY_FOLDER_NAME)
        file_cache = ModelFileSystemCache(local_dir)

    os.makedirs(staging_dir, exist_ok=True)
    return staging_dir, file_cache
def get_file_download_url(model_id: str, file_path: str, revision: str):
"""Format file download url according to `model_id`, `revision` and `file_path`.
e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,

View File

@@ -1,22 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import fnmatch
import os
import re
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, List, Optional, Union
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.utils.logger import get_logger
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
from .file_download import (get_file_download_url, http_get_file,
from .file_download import (create_temporary_directory_and_cache,
get_file_download_url, http_get_file,
parallel_download)
from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation,
model_id_to_group_owner_name)
from .utils.utils import file_integrity_validation
logger = get_logger()
@@ -28,7 +26,9 @@ def snapshot_download(
user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
ignore_file_pattern: List = None,
ignore_file_pattern: Optional[Union[str, List[str]]] = None,
allow_file_pattern: Optional[Union[str, List[str]]] = None,
local_dir: Optional[str] = None,
) -> str:
"""Download all files of a repo.
Downloads a whole snapshot of a repo's files at the specified revision. This
@@ -50,6 +50,9 @@ def snapshot_download(
cookies (CookieJar, optional): The cookie of the request, default None.
ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be ignored in downloading, like exact file names or file extensions.
allow_file_pattern (`str` or `List`, *optional*, default to `None`):
Any file pattern to be downloaded, like exact file names or file extensions.
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
Raises:
ValueError: the value details.
@@ -65,17 +68,8 @@ def snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid
"""
if cache_dir is None:
cache_dir = get_model_cache_root()
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
group_or_owner, name = model_id_to_group_owner_name(model_id)
temporary_cache_dir = os.path.join(cache_dir, 'temp', group_or_owner, name)
os.makedirs(temporary_cache_dir, exist_ok=True)
name = name.replace('.', '___')
cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
temporary_cache_dir, cache = create_temporary_directory_and_cache(
model_id, local_dir, cache_dir)
if local_files_only:
if len(cache.cached_files) == 0:
@@ -123,11 +117,21 @@ def snapshot_download(
if isinstance(ignore_file_pattern, str):
ignore_file_pattern = [ignore_file_pattern]
if allow_file_pattern is not None:
if isinstance(allow_file_pattern, str):
allow_file_pattern = [allow_file_pattern]
for model_file in model_files:
if model_file['Type'] == 'tree' or \
any([re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern]):
any(fnmatch.fnmatch(model_file['Path'], pattern) for pattern in ignore_file_pattern):
continue
if allow_file_pattern is not None and allow_file_pattern:
if not any(
fnmatch.fnmatch(model_file['Path'], pattern)
for pattern in allow_file_pattern):
continue
# check model_file is exist in cache, if existed, skip download, otherwise download
if cache.exists(model_file):
file_name = os.path.basename(model_file['Name'])

View File

@@ -53,12 +53,12 @@ class DownloadCMDTest(unittest.TestCase):
super().tearDown()
def test_download(self):
cmd = f'python -m modelscope.cli.cli download {self.model_id}'
cmd = f'python -m modelscope.cli.cli download --model {self.model_id}'
stat, output = subprocess.getstatusoutput(cmd)
self.assertEqual(stat, 0)
def test_download_with_cache(self):
cmd = f'python -m modelscope.cli.cli download {self.model_id} --cache_dir {self.tmp_dir}'
cmd = f'python -m modelscope.cli.cli download --model {self.model_id} --cache_dir {self.tmp_dir}'
stat, output = subprocess.getstatusoutput(cmd)
if stat != 0:
print(output)
@@ -68,7 +68,7 @@ class DownloadCMDTest(unittest.TestCase):
f'{self.tmp_dir}/{self.model_id}/{download_model_file_name}'))
def test_download_with_revision(self):
cmd = f'python -m modelscope.cli.cli download {self.model_id} --revision {self.revision}'
cmd = f'python -m modelscope.cli.cli download --model {self.model_id} --revision {self.revision}'
stat, output = subprocess.getstatusoutput(cmd)
if stat != 0:
print(output)