mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 20:19:22 +01:00
fix tqdm bar (#1108)
This commit is contained in:
@@ -4,11 +4,12 @@ import fnmatch
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from http.cookiejar import CookieJar
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from tqdm.contrib.concurrent import thread_map
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.errors import InvalidParameter
|
||||
@@ -325,8 +326,6 @@ def _snapshot_download(
|
||||
|
||||
cache.save_model_version(revision_info=revision_detail)
|
||||
cache_root_path = cache.get_root_location()
|
||||
|
||||
logger.info(f"Download {repo_type} '{repo_id}' successfully.")
|
||||
return cache_root_path
|
||||
|
||||
|
||||
@@ -391,6 +390,21 @@ def _get_valid_regex_pattern(patterns: List[str]):
|
||||
return None
|
||||
|
||||
|
||||
def thread_download(func, iterable, max_workers, **kwargs):
|
||||
# Create a tqdm progress bar with the total number of files to fetch
|
||||
with tqdm(
|
||||
total=len(iterable),
|
||||
desc=f'Fetching {len(iterable)} files') as pbar:
|
||||
# Define a wrapper function to update the progress bar
|
||||
def progress_wrapper(*args, **kwargs):
|
||||
result = func(*args, **kwargs)
|
||||
pbar.update(1)
|
||||
return result
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
executor.map(progress_wrapper, iterable)
|
||||
|
||||
|
||||
def _download_file_lists(
|
||||
repo_files: List[str],
|
||||
cache: ModelFileSystemCache,
|
||||
@@ -450,19 +464,18 @@ def _download_file_lists(
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in allow_file_pattern):
|
||||
continue
|
||||
# check model_file is exist in cache, if existed, skip download
|
||||
if cache.exists(repo_file):
|
||||
file_name = os.path.basename(repo_file['Name'])
|
||||
logger.debug(
|
||||
f'File {file_name} already in cache, skip downloading!')
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning('The file pattern is invalid : %s' % e)
|
||||
else:
|
||||
filtered_repo_files.append(repo_file)
|
||||
|
||||
def _download_single_file(repo_file):
|
||||
# check model_file is exist in cache, if existed, skip download, otherwise download
|
||||
if cache.exists(repo_file):
|
||||
file_name = os.path.basename(repo_file['Name'])
|
||||
logger.debug(
|
||||
f'File {file_name} already in cache, skip downloading!')
|
||||
return
|
||||
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
url = get_file_download_url(
|
||||
model_id=repo_id,
|
||||
@@ -481,11 +494,7 @@ def _download_file_lists(
|
||||
download_file(url, repo_file, temporary_cache_dir, cache, headers,
|
||||
cookies)
|
||||
|
||||
# Use thread_map for parallel downloading
|
||||
thread_map(
|
||||
_download_single_file,
|
||||
filtered_repo_files,
|
||||
max_workers=max_workers,
|
||||
desc=f'Fetching {len(filtered_repo_files)} files',
|
||||
leave=True,
|
||||
position=max_workers)
|
||||
if len(filtered_repo_files) > 0:
|
||||
thread_download(_download_single_file, filtered_repo_files,
|
||||
max_workers)
|
||||
logger.info(f"Download {repo_type} '{repo_id}' successfully.")
|
||||
|
||||
Reference in New Issue
Block a user