mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
fix dataset download cli
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.hub.constants import DEFAULT_MAX_WORKERS
|
||||
from modelscope.hub.file_download import (dataset_file_download,
|
||||
model_file_download)
|
||||
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
|
||||
snapshot_download)
|
||||
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
|
||||
|
||||
|
||||
def subparser_func(args):
|
||||
@@ -89,6 +91,11 @@ class DownloadCMD(CLICommand):
|
||||
help='Glob patterns to exclude from files to download.'
|
||||
'Ignored if file is specified')
|
||||
parser.set_defaults(func=subparser_func)
|
||||
parser.add_argument(
|
||||
'--max-workers',
|
||||
type=int,
|
||||
default=DEFAULT_MAX_WORKERS,
|
||||
help='The maximum number of workers to download files.')
|
||||
|
||||
def execute(self):
|
||||
if self.args.model or self.args.dataset:
|
||||
@@ -125,6 +132,7 @@ class DownloadCMD(CLICommand):
|
||||
cache_dir=self.args.cache_dir,
|
||||
local_dir=self.args.local_dir,
|
||||
allow_file_pattern=self.args.files,
|
||||
max_workers=self.args.max_workers,
|
||||
)
|
||||
else: # download repo
|
||||
snapshot_download(
|
||||
@@ -134,32 +142,36 @@ class DownloadCMD(CLICommand):
|
||||
local_dir=self.args.local_dir,
|
||||
allow_file_pattern=self.args.include,
|
||||
ignore_file_pattern=self.args.exclude,
|
||||
max_workers=self.args.max_workers,
|
||||
)
|
||||
elif self.args.dataset:
|
||||
dataset_revision: str = self.args.revision if self.args.revision else DEFAULT_DATASET_REVISION
|
||||
if len(self.args.files) == 1: # download single file
|
||||
dataset_file_download(
|
||||
self.args.dataset,
|
||||
self.args.files[0],
|
||||
cache_dir=self.args.cache_dir,
|
||||
local_dir=self.args.local_dir,
|
||||
revision=self.args.revision)
|
||||
revision=dataset_revision)
|
||||
elif len(
|
||||
self.args.files) > 1: # download specified multiple files.
|
||||
dataset_snapshot_download(
|
||||
self.args.dataset,
|
||||
revision=self.args.revision,
|
||||
revision=dataset_revision,
|
||||
cache_dir=self.args.cache_dir,
|
||||
local_dir=self.args.local_dir,
|
||||
allow_file_pattern=self.args.files,
|
||||
max_workers=self.args.max_workers,
|
||||
)
|
||||
else: # download repo
|
||||
dataset_snapshot_download(
|
||||
self.args.dataset,
|
||||
revision=self.args.revision,
|
||||
revision=dataset_revision,
|
||||
cache_dir=self.args.cache_dir,
|
||||
local_dir=self.args.local_dir,
|
||||
allow_file_pattern=self.args.include,
|
||||
ignore_file_pattern=self.args.exclude,
|
||||
max_workers=self.args.max_workers,
|
||||
)
|
||||
else:
|
||||
pass # noop
|
||||
|
||||
@@ -33,6 +33,7 @@ MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION = 'MODELSCOPE_ENABLE_DEFAULT_HASH_VALI
|
||||
ONE_YEAR_SECONDS = 24 * 365 * 60 * 60
|
||||
MODELSCOPE_REQUEST_ID = 'X-Request-ID'
|
||||
TEMPORARY_FOLDER_NAME = '._____temp'
|
||||
# Default thread-pool size for downloads: CPU count plus 4, capped at 8.
# os.cpu_count() returns None when the count cannot be determined, so fall
# back to 1 to keep this module importable instead of raising a TypeError.
DEFAULT_MAX_WORKERS = min(8, (os.cpu_count() or 1) + 4)
|
||||
|
||||
|
||||
class Licenses(object):
|
||||
|
||||
@@ -163,6 +163,7 @@ def _repo_file_download(
|
||||
local_files_only: Optional[bool] = False,
|
||||
cookies: Optional[CookieJar] = None,
|
||||
local_dir: Optional[str] = None,
|
||||
disable_tqdm: bool = False,
|
||||
) -> Optional[str]: # pragma: no cover
|
||||
|
||||
if not repo_type:
|
||||
@@ -275,6 +276,9 @@ def _repo_file_download(
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
revision=revision)
|
||||
else:
|
||||
raise ValueError(f'Invalid repo type {repo_type}')
|
||||
|
||||
return download_file(url_to_download, file_to_download_meta,
|
||||
temporary_cache_dir, cache, headers, cookies)
|
||||
|
||||
@@ -379,6 +383,7 @@ def parallel_download(
|
||||
cookies: CookieJar,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
file_size: int = None,
|
||||
disable_tqdm: bool = False,
|
||||
):
|
||||
# create temp file
|
||||
with tqdm(
|
||||
@@ -389,6 +394,7 @@ def parallel_download(
|
||||
initial=0,
|
||||
desc='Downloading [' + file_name + ']',
|
||||
leave=True,
|
||||
disable=disable_tqdm,
|
||||
) as progress:
|
||||
PART_SIZE = 160 * 1024 * 1024 # every part is 160M
|
||||
tasks = []
|
||||
@@ -425,6 +431,7 @@ def http_get_model_file(
|
||||
file_size: int,
|
||||
cookies: CookieJar,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
disable_tqdm: bool = False,
|
||||
):
|
||||
"""Download remote file, will retry 5 times before giving up on errors.
|
||||
|
||||
@@ -441,6 +448,7 @@ def http_get_model_file(
|
||||
cookies used to authenticate the user, which is used for downloading private repos
|
||||
headers(Dict[str, str], optional):
|
||||
http headers to carry necessary info when requesting the remote file
|
||||
disable_tqdm(bool, optional): Disable the progress bar with tqdm.
|
||||
|
||||
Raises:
|
||||
FileDownloadError: File download failed.
|
||||
@@ -466,6 +474,7 @@ def http_get_model_file(
|
||||
initial=0,
|
||||
desc='Downloading [' + file_name + ']',
|
||||
leave=True,
|
||||
disable=disable_tqdm,
|
||||
) as progress:
|
||||
if file_size == 0:
|
||||
# Avoid empty file server request
|
||||
@@ -589,8 +598,15 @@ def http_get_file(
|
||||
os.replace(temp_file.name, os.path.join(local_dir, file_name))
|
||||
|
||||
|
||||
def download_file(url, file_meta, temporary_cache_dir, cache, headers,
|
||||
cookies):
|
||||
def download_file(
|
||||
url,
|
||||
file_meta,
|
||||
temporary_cache_dir,
|
||||
cache,
|
||||
headers,
|
||||
cookies,
|
||||
disable_tqdm=False,
|
||||
):
|
||||
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
|
||||
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
|
||||
parallel_download(
|
||||
@@ -599,7 +615,9 @@ def download_file(url, file_meta, temporary_cache_dir, cache, headers,
|
||||
file_meta['Path'],
|
||||
headers=headers,
|
||||
cookies=None if cookies is None else cookies.get_dict(),
|
||||
file_size=file_meta['Size'])
|
||||
file_size=file_meta['Size'],
|
||||
disable_tqdm=disable_tqdm,
|
||||
)
|
||||
else:
|
||||
http_get_model_file(
|
||||
url,
|
||||
@@ -607,7 +625,9 @@ def download_file(url, file_meta, temporary_cache_dir, cache, headers,
|
||||
file_meta['Path'],
|
||||
file_size=file_meta['Size'],
|
||||
headers=headers,
|
||||
cookies=cookies)
|
||||
cookies=cookies,
|
||||
disable_tqdm=disable_tqdm,
|
||||
)
|
||||
|
||||
# check file integrity
|
||||
temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])
|
||||
|
||||
@@ -4,15 +4,14 @@ import fnmatch
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from http.cookiejar import CookieJar
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.errors import InvalidParameter
|
||||
from modelscope.hub.file_download import (create_temporary_directory_and_cache,
|
||||
download_file, get_file_download_url)
|
||||
from modelscope.hub.utils.caching import ModelFileSystemCache
|
||||
from modelscope.hub.utils.utils import (get_model_masked_directory,
|
||||
model_id_to_group_owner_name)
|
||||
@@ -21,8 +20,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
|
||||
REPO_TYPE_SUPPORT)
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .file_download import (create_temporary_directory_and_cache,
|
||||
download_file, get_file_download_url)
|
||||
from modelscope.utils.thread_utils import thread_executor
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -390,21 +388,6 @@ def _get_valid_regex_pattern(patterns: List[str]):
|
||||
return None
|
||||
|
||||
|
||||
def thread_download(func, iterable, max_workers, **kwargs):
    """Call ``func`` on every item of ``iterable`` in a thread pool.

    A tqdm progress bar is advanced by one each time a call finishes.

    Args:
        func: Callable invoked as ``func(item, **kwargs)`` for each item.
        iterable: The items to process. It is materialized into a list so
            that generators and other unsized iterables are accepted.
        max_workers (int): The maximum number of worker threads.
        **kwargs: Extra keyword arguments forwarded to every ``func`` call.

    Raises:
        Exception: Whatever ``func`` raises for an item is re-raised here
            instead of being silently dropped.
    """
    items = list(iterable)
    # Create a tqdm progress bar with the total number of files to fetch
    with tqdm(
            total=len(items), desc=f'Fetching {len(items)} files') as pbar:

        # Wrap ``func`` so that the progress bar advances after each call
        def progress_wrapper(item):
            result = func(item, **kwargs)
            pbar.update(1)
            return result

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # executor.map returns a lazy iterator; exceptions raised in the
            # workers only surface when results are retrieved, so drain it.
            for _ in executor.map(progress_wrapper, items):
                pass
|
||||
|
||||
|
||||
def _download_file_lists(
|
||||
repo_files: List[str],
|
||||
cache: ModelFileSystemCache,
|
||||
@@ -476,6 +459,7 @@ def _download_file_lists(
|
||||
else:
|
||||
filtered_repo_files.append(repo_file)
|
||||
|
||||
@thread_executor(max_workers=max_workers, disable_tqdm=False)
|
||||
def _download_single_file(repo_file):
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
url = get_file_download_url(
|
||||
@@ -492,10 +476,18 @@ def _download_file_lists(
|
||||
raise InvalidParameter(
|
||||
f'Invalid repo type: {repo_type}, supported types: {REPO_TYPE_SUPPORT}'
|
||||
)
|
||||
download_file(url, repo_file, temporary_cache_dir, cache, headers,
|
||||
cookies)
|
||||
download_file(
|
||||
url,
|
||||
repo_file,
|
||||
temporary_cache_dir,
|
||||
cache,
|
||||
headers,
|
||||
cookies,
|
||||
disable_tqdm=True,
|
||||
)
|
||||
|
||||
if len(filtered_repo_files) > 0:
|
||||
thread_download(_download_single_file, filtered_repo_files,
|
||||
max_workers)
|
||||
logger.info(
|
||||
f'Got {len(filtered_repo_files)} files, start to download ...')
|
||||
_download_single_file(filtered_repo_files)
|
||||
logger.info(f"Download {repo_type} '{repo_id}' successfully.")
|
||||
|
||||
62
modelscope/utils/thread_utils.py
Normal file
62
modelscope/utils/thread_utils.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from functools import wraps
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.hub.constants import DEFAULT_MAX_WORKERS
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS,
                    disable_tqdm=False):
    """
    A decorator to execute a function in a threaded manner using ThreadPoolExecutor.

    The decorated function is written to process a single item. After
    decoration it is called with an iterable of items as its first argument;
    the original function is then submitted once per item, with any extra
    positional and keyword arguments forwarded unchanged to every call.

    Args:
        max_workers (int): The maximum number of threads to use.
        disable_tqdm (bool): disable progress bar.

    Returns:
        function: A wrapped function that executes with threading and a progress bar.
            It returns a list with one result per item, in the order the
            tasks complete (not necessarily the order of the input).

    Raises:
        Exception: An exception raised by the decorated function for any item
            is re-raised by the wrapped call when that task's result is
            collected.

    Examples:
        >>> from modelscope.utils.thread_utils import thread_executor
        >>> import time
        >>> @thread_executor(max_workers=8)
        ... def process_item(item, x, y):
        ...     # do something to single item
        ...     time.sleep(1)
        ...     return str(item) + str(x) + str(y)
        >>> items = [1, 2, 3]
        >>> process_item(items, x='abc', y='xyz')
    """

    def decorator(func):

        @wraps(func)
        def wrapper(iterable, *args, **kwargs):
            # Materialize the input once: generators and other one-shot
            # iterables have no len(), which the progress bar total needs.
            items = list(iterable)
            results = []
            # Create a tqdm progress bar with the total number of items to process
            with tqdm(
                    total=len(items),
                    desc=f'Processing {len(items)} items',
                    disable=disable_tqdm,
            ) as pbar:
                with ThreadPoolExecutor(max_workers=max_workers) as executor:
                    # Submit all tasks
                    futures = [
                        executor.submit(func, item, *args, **kwargs)
                        for item in items
                    ]

                    # Update the progress bar as tasks complete
                    for future in as_completed(futures):
                        pbar.update(1)
                        # result() re-raises any exception from the worker
                        results.append(future.result())
            return results

        return wrapper

    return decorator
|
||||
@@ -1,5 +1,5 @@
|
||||
# Make sure to modify __release_datetime__ to release time when making official release.
|
||||
__version__ = '1.21.0'
|
||||
__version__ = '1.21.1'
|
||||
# default release datetime for branches under active development is set
|
||||
# to be a time far-far-away-into-the-future
|
||||
__release_datetime__ = '2024-12-03 08:00:00'
|
||||
__release_datetime__ = '2024-12-29 23:00:00'
|
||||
|
||||
Reference in New Issue
Block a user