Fix dataset download CLI

This commit is contained in:
xingjun.wang
2024-12-29 21:41:48 +08:00
parent c09aabc630
commit 6ef02276e0
6 changed files with 121 additions and 34 deletions

View File

@@ -1,12 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from argparse import ArgumentParser
from modelscope.cli.base import CLICommand
from modelscope.hub.constants import DEFAULT_MAX_WORKERS
from modelscope.hub.file_download import (dataset_file_download,
model_file_download)
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
snapshot_download)
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
def subparser_func(args):
@@ -89,6 +91,11 @@ class DownloadCMD(CLICommand):
help='Glob patterns to exclude from files to download.'
'Ignored if file is specified')
parser.set_defaults(func=subparser_func)
parser.add_argument(
'--max-workers',
type=int,
default=DEFAULT_MAX_WORKERS,
help='The maximum number of workers to download files.')
def execute(self):
if self.args.model or self.args.dataset:
@@ -125,6 +132,7 @@ class DownloadCMD(CLICommand):
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.files,
max_workers=self.args.max_workers,
)
else: # download repo
snapshot_download(
@@ -134,32 +142,36 @@ class DownloadCMD(CLICommand):
local_dir=self.args.local_dir,
allow_file_pattern=self.args.include,
ignore_file_pattern=self.args.exclude,
max_workers=self.args.max_workers,
)
elif self.args.dataset:
dataset_revision: str = self.args.revision if self.args.revision else DEFAULT_DATASET_REVISION
if len(self.args.files) == 1: # download single file
dataset_file_download(
self.args.dataset,
self.args.files[0],
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
revision=self.args.revision)
revision=dataset_revision)
elif len(
self.args.files) > 1: # download specified multiple files.
dataset_snapshot_download(
self.args.dataset,
revision=self.args.revision,
revision=dataset_revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.files,
max_workers=self.args.max_workers,
)
else: # download repo
dataset_snapshot_download(
self.args.dataset,
revision=self.args.revision,
revision=dataset_revision,
cache_dir=self.args.cache_dir,
local_dir=self.args.local_dir,
allow_file_pattern=self.args.include,
ignore_file_pattern=self.args.exclude,
max_workers=self.args.max_workers,
)
else:
pass # noop

View File

@@ -33,6 +33,7 @@ MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION = 'MODELSCOPE_ENABLE_DEFAULT_HASH_VALI
ONE_YEAR_SECONDS = 24 * 365 * 60 * 60
MODELSCOPE_REQUEST_ID = 'X-Request-ID'
TEMPORARY_FOLDER_NAME = '._____temp'
# Default worker count for parallel downloads: CPU count + 4, capped at 8.
# os.cpu_count() may return None on platforms where it is undeterminable,
# so fall back to 1 to avoid a TypeError at import time.
DEFAULT_MAX_WORKERS = min(8, (os.cpu_count() or 1) + 4)
class Licenses(object):

View File

@@ -163,6 +163,7 @@ def _repo_file_download(
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
local_dir: Optional[str] = None,
disable_tqdm: bool = False,
) -> Optional[str]: # pragma: no cover
if not repo_type:
@@ -275,6 +276,9 @@ def _repo_file_download(
dataset_name=name,
namespace=group_or_owner,
revision=revision)
else:
raise ValueError(f'Invalid repo type {repo_type}')
return download_file(url_to_download, file_to_download_meta,
temporary_cache_dir, cache, headers, cookies)
@@ -379,6 +383,7 @@ def parallel_download(
cookies: CookieJar,
headers: Optional[Dict[str, str]] = None,
file_size: int = None,
disable_tqdm: bool = False,
):
# create temp file
with tqdm(
@@ -389,6 +394,7 @@ def parallel_download(
initial=0,
desc='Downloading [' + file_name + ']',
leave=True,
disable=disable_tqdm,
) as progress:
PART_SIZE = 160 * 1024 * 1024 # every part is 160M
tasks = []
@@ -425,6 +431,7 @@ def http_get_model_file(
file_size: int,
cookies: CookieJar,
headers: Optional[Dict[str, str]] = None,
disable_tqdm: bool = False,
):
"""Download remote file, will retry 5 times before giving up on errors.
@@ -441,6 +448,7 @@ def http_get_model_file(
cookies used to authentication the user, which is used for downloading private repos
headers(Dict[str, str], optional):
http headers to carry necessary info when requesting the remote file
disable_tqdm(bool, optional): Disable the progress bar with tqdm.
Raises:
FileDownloadError: File download failed.
@@ -466,6 +474,7 @@ def http_get_model_file(
initial=0,
desc='Downloading [' + file_name + ']',
leave=True,
disable=disable_tqdm,
) as progress:
if file_size == 0:
# Avoid empty file server request
@@ -589,8 +598,15 @@ def http_get_file(
os.replace(temp_file.name, os.path.join(local_dir, file_name))
def download_file(url, file_meta, temporary_cache_dir, cache, headers,
cookies):
def download_file(
url,
file_meta,
temporary_cache_dir,
cache,
headers,
cookies,
disable_tqdm=False,
):
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
parallel_download(
@@ -599,7 +615,9 @@ def download_file(url, file_meta, temporary_cache_dir, cache, headers,
file_meta['Path'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=file_meta['Size'])
file_size=file_meta['Size'],
disable_tqdm=disable_tqdm,
)
else:
http_get_model_file(
url,
@@ -607,7 +625,9 @@ def download_file(url, file_meta, temporary_cache_dir, cache, headers,
file_meta['Path'],
file_size=file_meta['Size'],
headers=headers,
cookies=cookies)
cookies=cookies,
disable_tqdm=disable_tqdm,
)
# check file integrity
temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])

View File

@@ -4,15 +4,14 @@ import fnmatch
import os
import re
import uuid
from concurrent.futures import ThreadPoolExecutor
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, List, Optional, Union
from tqdm.auto import tqdm
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.errors import InvalidParameter
from modelscope.hub.file_download import (create_temporary_directory_and_cache,
download_file, get_file_download_url)
from modelscope.hub.utils.caching import ModelFileSystemCache
from modelscope.hub.utils.utils import (get_model_masked_directory,
model_id_to_group_owner_name)
@@ -21,8 +20,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
REPO_TYPE_SUPPORT)
from modelscope.utils.logger import get_logger
from .file_download import (create_temporary_directory_and_cache,
download_file, get_file_download_url)
from modelscope.utils.thread_utils import thread_executor
logger = get_logger()
@@ -390,21 +388,6 @@ def _get_valid_regex_pattern(patterns: List[str]):
return None
def thread_download(func, iterable, max_workers, **kwargs):
    """Apply ``func`` to every item of ``iterable`` using a thread pool.

    Args:
        func: Callable invoked once per item; extra ``kwargs`` are forwarded.
        iterable: Sized collection of items to process.
        max_workers (int): Maximum number of worker threads.
        **kwargs: Additional keyword arguments passed through to ``func``.
    """
    # Create a tqdm progress bar with the total number of files to fetch
    with tqdm(
            total=len(iterable),
            desc=f'Fetching {len(iterable)} files') as pbar:

        # Wrapper that updates the progress bar after each item completes.
        def progress_wrapper(item):
            # Forward the extra keyword arguments; the original accepted
            # **kwargs but silently dropped them.
            result = func(item, **kwargs)
            pbar.update(1)
            return result

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Drain the lazy map iterator so exceptions raised inside the
            # workers propagate instead of being silently discarded.
            list(executor.map(progress_wrapper, iterable))
def _download_file_lists(
repo_files: List[str],
cache: ModelFileSystemCache,
@@ -476,6 +459,7 @@ def _download_file_lists(
else:
filtered_repo_files.append(repo_file)
@thread_executor(max_workers=max_workers, disable_tqdm=False)
def _download_single_file(repo_file):
if repo_type == REPO_TYPE_MODEL:
url = get_file_download_url(
@@ -492,10 +476,18 @@ def _download_file_lists(
raise InvalidParameter(
f'Invalid repo type: {repo_type}, supported types: {REPO_TYPE_SUPPORT}'
)
download_file(url, repo_file, temporary_cache_dir, cache, headers,
cookies)
download_file(
url,
repo_file,
temporary_cache_dir,
cache,
headers,
cookies,
disable_tqdm=True,
)
if len(filtered_repo_files) > 0:
thread_download(_download_single_file, filtered_repo_files,
max_workers)
logger.info(
f'Got {len(filtered_repo_files)} files, start to download ...')
_download_single_file(filtered_repo_files)
logger.info(f"Download {repo_type} '{repo_id}' successfully.")

View File

@@ -0,0 +1,62 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import wraps
from tqdm import tqdm
from modelscope.hub.constants import DEFAULT_MAX_WORKERS
from modelscope.utils.logger import get_logger
logger = get_logger()
def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS,
                    disable_tqdm=False):
    """Decorator that fans a function out over an iterable using threads.

    The decorated function is written for a single item; calling it with an
    iterable as the first positional argument submits one task per item to a
    thread pool and gathers the results as the tasks complete.

    Args:
        max_workers (int): The maximum number of threads to use.
        disable_tqdm (bool): disable progress bar.

    Returns:
        function: A wrapped function that executes with threading and a progress bar.

    Examples:
        >>> from modelscope.utils.thread_utils import thread_executor
        >>> import time
        >>> @thread_executor(max_workers=8)
        ... def process_item(item, x, y):
        ...     # do something to single item
        ...     time.sleep(1)
        ...     return str(item) + str(x) + str(y)
        >>> items = [1, 2, 3]
        >>> process_item(items, x='abc', y='xyz')
    """

    def decorator(func):

        @wraps(func)
        def wrapper(iterable, *args, **kwargs):
            gathered = []
            # One progress tick per finished task; a single `with` manages
            # both the progress bar and the thread pool.
            with tqdm(
                    total=len(iterable),
                    desc=f'Processing {len(iterable)} items',
                    disable=disable_tqdm,
            ) as bar, ThreadPoolExecutor(max_workers=max_workers) as pool:
                pending = [
                    pool.submit(func, element, *args, **kwargs)
                    for element in iterable
                ]
                # Collect results in completion order, advancing the bar.
                for done in as_completed(pending):
                    bar.update(1)
                    gathered.append(done.result())
            return gathered

        return wrapper

    return decorator

View File

@@ -1,5 +1,5 @@
# Make sure to modify __release_datetime__ to release time when making official release.
__version__ = '1.21.0'
__version__ = '1.21.1'
# default release datetime for branches under active development is set
# to be a time far-far-away-into-the-future
__release_datetime__ = '2024-12-03 08:00:00'
__release_datetime__ = '2024-12-29 23:00:00'