From a4582012fffd9c6c620fa91021a8c98e713cccdd Mon Sep 17 00:00:00 2001 From: Yunlin Mao Date: Sun, 1 Dec 2024 15:38:03 +0800 Subject: [PATCH] fix tqdm bar (#1108) --- modelscope/hub/snapshot_download.py | 45 +++++++++++++++++------------ 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 7ba0f446..f28c18e0 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -4,11 +4,12 @@ import fnmatch import os import re import uuid +from concurrent.futures import ThreadPoolExecutor from http.cookiejar import CookieJar from pathlib import Path from typing import Dict, List, Optional, Union -from tqdm.contrib.concurrent import thread_map +from tqdm.auto import tqdm from modelscope.hub.api import HubApi, ModelScopeConfig from modelscope.hub.errors import InvalidParameter @@ -325,8 +326,6 @@ def _snapshot_download( cache.save_model_version(revision_info=revision_detail) cache_root_path = cache.get_root_location() - - logger.info(f"Download {repo_type} '{repo_id}' successfully.") return cache_root_path @@ -391,6 +390,21 @@ def _get_valid_regex_pattern(patterns: List[str]): return None +def thread_download(func, iterable, max_workers, **kwargs): + # Create a tqdm progress bar with the total number of files to fetch + with tqdm( + total=len(iterable), + desc=f'Fetching {len(iterable)} files') as pbar: + # Define a wrapper function to update the progress bar + def progress_wrapper(*args, **kwargs): + result = func(*args, **kwargs) + pbar.update(1) + return result + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + executor.map(progress_wrapper, iterable) + + def _download_file_lists( repo_files: List[str], cache: ModelFileSystemCache, @@ -450,19 +464,18 @@ def _download_file_lists( fnmatch.fnmatch(repo_file['Path'], pattern) for pattern in allow_file_pattern): continue + # check model_file is exist in cache, if existed, skip download + if cache.exists(repo_file): + file_name = os.path.basename(repo_file['Name']) + logger.debug( + f'File {file_name} already in cache, skip downloading!') + continue except Exception as e: logger.warning('The file pattern is invalid : %s' % e) else: filtered_repo_files.append(repo_file) def _download_single_file(repo_file): - # check model_file is exist in cache, if existed, skip download, otherwise download - if cache.exists(repo_file): - file_name = os.path.basename(repo_file['Name']) - logger.debug( - f'File {file_name} already in cache, skip downloading!') - return - if repo_type == REPO_TYPE_MODEL: url = get_file_download_url( model_id=repo_id, @@ -481,11 +494,7 @@ def _download_file_lists( download_file(url, repo_file, temporary_cache_dir, cache, headers, cookies) - # Use thread_map for parallel downloading - thread_map( - _download_single_file, - filtered_repo_files, - max_workers=max_workers, - desc=f'Fetching {len(filtered_repo_files)} files', - leave=True, - position=max_workers) + if len(filtered_repo_files) > 0: + thread_download(_download_single_file, filtered_repo_files, + max_workers) + logger.info(f"Download {repo_type} '{repo_id}' successfully.")