diff --git a/modelscope/cli/download.py b/modelscope/cli/download.py index aa23a301..b3281906 100644 --- a/modelscope/cli/download.py +++ b/modelscope/cli/download.py @@ -1,12 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - +import os from argparse import ArgumentParser from modelscope.cli.base import CLICommand +from modelscope.hub.constants import DEFAULT_MAX_WORKERS from modelscope.hub.file_download import (dataset_file_download, model_file_download) from modelscope.hub.snapshot_download import (dataset_snapshot_download, snapshot_download) +from modelscope.utils.constant import DEFAULT_DATASET_REVISION def subparser_func(args): @@ -89,6 +91,11 @@ class DownloadCMD(CLICommand): help='Glob patterns to exclude from files to download.' 'Ignored if file is specified') parser.set_defaults(func=subparser_func) + parser.add_argument( + '--max-workers', + type=int, + default=DEFAULT_MAX_WORKERS, + help='The maximum number of workers to download files.') def execute(self): if self.args.model or self.args.dataset: @@ -125,6 +132,7 @@ class DownloadCMD(CLICommand): cache_dir=self.args.cache_dir, local_dir=self.args.local_dir, allow_file_pattern=self.args.files, + max_workers=self.args.max_workers, ) else: # download repo snapshot_download( @@ -134,32 +142,36 @@ class DownloadCMD(CLICommand): local_dir=self.args.local_dir, allow_file_pattern=self.args.include, ignore_file_pattern=self.args.exclude, + max_workers=self.args.max_workers, ) elif self.args.dataset: + dataset_revision: str = self.args.revision if self.args.revision else DEFAULT_DATASET_REVISION if len(self.args.files) == 1: # download single file dataset_file_download( self.args.dataset, self.args.files[0], cache_dir=self.args.cache_dir, local_dir=self.args.local_dir, - revision=self.args.revision) + revision=dataset_revision) elif len( self.args.files) > 1: # download specified multiple files. 
dataset_snapshot_download( self.args.dataset, - revision=self.args.revision, + revision=dataset_revision, cache_dir=self.args.cache_dir, local_dir=self.args.local_dir, allow_file_pattern=self.args.files, + max_workers=self.args.max_workers, ) else: # download repo dataset_snapshot_download( self.args.dataset, - revision=self.args.revision, + revision=dataset_revision, cache_dir=self.args.cache_dir, local_dir=self.args.local_dir, allow_file_pattern=self.args.include, ignore_file_pattern=self.args.exclude, + max_workers=self.args.max_workers, ) else: pass # noop diff --git a/modelscope/hub/constants.py b/modelscope/hub/constants.py index b3d03e1a..53739a58 100644 --- a/modelscope/hub/constants.py +++ b/modelscope/hub/constants.py @@ -33,6 +33,7 @@ MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION = 'MODELSCOPE_ENABLE_DEFAULT_HASH_VALI ONE_YEAR_SECONDS = 24 * 365 * 60 * 60 MODELSCOPE_REQUEST_ID = 'X-Request-ID' TEMPORARY_FOLDER_NAME = '._____temp' +DEFAULT_MAX_WORKERS = min(8, (os.cpu_count() or 1) + 4) class Licenses(object): diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index 40ac8a03..64ada050 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -163,6 +163,7 @@ def _repo_file_download( local_files_only: Optional[bool] = False, cookies: Optional[CookieJar] = None, local_dir: Optional[str] = None, + disable_tqdm: bool = False, ) -> Optional[str]: # pragma: no cover if not repo_type: @@ -275,6 +276,9 @@ dataset_name=name, namespace=group_or_owner, revision=revision) + else: + raise ValueError(f'Invalid repo type {repo_type}') + - return download_file(url_to_download, file_to_download_meta, temporary_cache_dir, cache, headers, cookies) + return download_file(url_to_download, file_to_download_meta, temporary_cache_dir, cache, headers, cookies, disable_tqdm=disable_tqdm) @@ -379,6 +383,7 @@ def parallel_download( cookies: CookieJar, headers: Optional[Dict[str, str]] = None, file_size: int = None, + disable_tqdm: bool = False, ): # create temp file with tqdm( @@ -389,6 +394,7 @@ initial=0, desc='Downloading 
[' + file_name + ']', leave=True, + disable=disable_tqdm, ) as progress: PART_SIZE = 160 * 1024 * 1024 # every part is 160M tasks = [] @@ -425,6 +431,7 @@ def http_get_model_file( file_size: int, cookies: CookieJar, headers: Optional[Dict[str, str]] = None, + disable_tqdm: bool = False, ): """Download remote file, will retry 5 times before giving up on errors. @@ -441,6 +448,7 @@ def http_get_model_file( cookies used to authentication the user, which is used for downloading private repos headers(Dict[str, str], optional): http headers to carry necessary info when requesting the remote file + disable_tqdm(bool, optional): Disable the progress bar with tqdm. Raises: FileDownloadError: File download failed. @@ -466,6 +474,7 @@ def http_get_model_file( initial=0, desc='Downloading [' + file_name + ']', leave=True, + disable=disable_tqdm, ) as progress: if file_size == 0: # Avoid empty file server request @@ -589,8 +598,15 @@ def http_get_file( os.replace(temp_file.name, os.path.join(local_dir, file_name)) -def download_file(url, file_meta, temporary_cache_dir, cache, headers, - cookies): +def download_file( + url, + file_meta, + temporary_cache_dir, + cache, + headers, + cookies, + disable_tqdm=False, +): if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[ 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file. 
parallel_download( @@ -599,7 +615,9 @@ def download_file(url, file_meta, temporary_cache_dir, cache, headers, file_meta['Path'], headers=headers, cookies=None if cookies is None else cookies.get_dict(), - file_size=file_meta['Size']) + file_size=file_meta['Size'], + disable_tqdm=disable_tqdm, + ) else: http_get_model_file( url, @@ -607,7 +625,9 @@ def download_file(url, file_meta, temporary_cache_dir, cache, headers, file_meta['Path'], file_size=file_meta['Size'], headers=headers, - cookies=cookies) + cookies=cookies, + disable_tqdm=disable_tqdm, + ) # check file integrity temp_file = os.path.join(temporary_cache_dir, file_meta['Path']) diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 015cadbd..7b0da872 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -4,15 +4,14 @@ import fnmatch import os import re import uuid -from concurrent.futures import ThreadPoolExecutor from http.cookiejar import CookieJar from pathlib import Path from typing import Dict, List, Optional, Union -from tqdm.auto import tqdm - from modelscope.hub.api import HubApi, ModelScopeConfig from modelscope.hub.errors import InvalidParameter +from modelscope.hub.file_download import (create_temporary_directory_and_cache, + download_file, get_file_download_url) from modelscope.hub.utils.caching import ModelFileSystemCache from modelscope.hub.utils.utils import (get_model_masked_directory, model_id_to_group_owner_name) @@ -21,8 +20,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, REPO_TYPE_DATASET, REPO_TYPE_MODEL, REPO_TYPE_SUPPORT) from modelscope.utils.logger import get_logger -from .file_download import (create_temporary_directory_and_cache, - download_file, get_file_download_url) +from modelscope.utils.thread_utils import thread_executor logger = get_logger() @@ -390,21 +388,6 @@ def _get_valid_regex_pattern(patterns: List[str]): return None -def thread_download(func, iterable, max_workers, 
**kwargs): - # Create a tqdm progress bar with the total number of files to fetch - with tqdm( - total=len(iterable), - desc=f'Fetching {len(iterable)} files') as pbar: - # Define a wrapper function to update the progress bar - def progress_wrapper(*args, **kwargs): - result = func(*args, **kwargs) - pbar.update(1) - return result - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - executor.map(progress_wrapper, iterable) - - def _download_file_lists( repo_files: List[str], cache: ModelFileSystemCache, @@ -476,6 +459,7 @@ def _download_file_lists( else: filtered_repo_files.append(repo_file) + @thread_executor(max_workers=max_workers, disable_tqdm=False) def _download_single_file(repo_file): if repo_type == REPO_TYPE_MODEL: url = get_file_download_url( @@ -492,10 +476,18 @@ def _download_file_lists( raise InvalidParameter( f'Invalid repo type: {repo_type}, supported types: {REPO_TYPE_SUPPORT}' ) - download_file(url, repo_file, temporary_cache_dir, cache, headers, - cookies) + download_file( + url, + repo_file, + temporary_cache_dir, + cache, + headers, + cookies, + disable_tqdm=True, + ) if len(filtered_repo_files) > 0: - thread_download(_download_single_file, filtered_repo_files, - max_workers) + logger.info( + f'Got {len(filtered_repo_files)} files, start to download ...') + _download_single_file(filtered_repo_files) logger.info(f"Download {repo_type} '{repo_id}' successfully.") diff --git a/modelscope/utils/thread_utils.py b/modelscope/utils/thread_utils.py new file mode 100644 index 00000000..bb323c51 --- /dev/null +++ b/modelscope/utils/thread_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import wraps + +from tqdm import tqdm + +from modelscope.hub.constants import DEFAULT_MAX_WORKERS +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS, + disable_tqdm=False): + """ + A decorator to execute a function in a threaded manner using ThreadPoolExecutor. + Args: + max_workers (int): The maximum number of threads to use. + disable_tqdm (bool): disable progress bar. + Returns: + function: A wrapped function that executes with threading and a progress bar. + Examples: + >>> from modelscope.utils.thread_utils import thread_executor + >>> import time + >>> @thread_executor(max_workers=8) + ... def process_item(item, x, y): + ... # do something to single item + ... time.sleep(1) + ... return str(item) + str(x) + str(y) + >>> items = [1, 2, 3] + >>> process_item(items, x='abc', y='xyz') + """ + + def decorator(func): + + @wraps(func) + def wrapper(iterable, *args, **kwargs): + results = [] + # Create a tqdm progress bar with the total number of items to process + with tqdm( + total=len(iterable), + desc=f'Processing {len(iterable)} items', + disable=disable_tqdm, + ) as pbar: + # Define a wrapper function to update the progress bar + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all tasks + futures = { + executor.submit(func, item, *args, **kwargs): item + for item in iterable + } + + # Update the progress bar as tasks complete + for future in as_completed(futures): + pbar.update(1) + results.append(future.result()) + return results + + return wrapper + + return decorator diff --git a/modelscope/version.py b/modelscope/version.py index f3062f40..4d0af473 100644 --- a/modelscope/version.py +++ b/modelscope/version.py @@ -1,5 +1,5 @@ # Make sure to modify __release_datetime__ to release time when making official release. 
-__version__ = '1.21.0' +__version__ = '1.21.1' # default release datetime for branches under active development is set # to be a time far-far-away-into-the-future -__release_datetime__ = '2024-12-03 08:00:00' +__release_datetime__ = '2024-12-29 23:00:00'