fix tqdm bar (#1108)

This commit is contained in:
Yunlin Mao
2024-12-01 15:38:03 +08:00
committed by GitHub
parent a721220fa1
commit a4582012ff

View File

@@ -4,11 +4,12 @@ import fnmatch
import os
import re
import uuid
from concurrent.futures import ThreadPoolExecutor
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, List, Optional, Union
from tqdm.contrib.concurrent import thread_map
from tqdm.auto import tqdm
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.errors import InvalidParameter
@@ -325,8 +326,6 @@ def _snapshot_download(
cache.save_model_version(revision_info=revision_detail)
cache_root_path = cache.get_root_location()
logger.info(f"Download {repo_type} '{repo_id}' successfully.")
return cache_root_path
@@ -391,6 +390,21 @@ def _get_valid_regex_pattern(patterns: List[str]):
return None
def thread_download(func, iterable, max_workers, **kwargs):
# Create a tqdm progress bar with the total number of files to fetch
with tqdm(
total=len(iterable),
desc=f'Fetching {len(iterable)} files') as pbar:
# Define a wrapper function to update the progress bar
def progress_wrapper(*args, **kwargs):
result = func(*args, **kwargs)
pbar.update(1)
return result
with ThreadPoolExecutor(max_workers=max_workers) as executor:
executor.map(progress_wrapper, iterable)
def _download_file_lists(
repo_files: List[str],
cache: ModelFileSystemCache,
@@ -450,19 +464,18 @@ def _download_file_lists(
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_file_pattern):
continue
# check model_file is exist in cache, if existed, skip download
if cache.exists(repo_file):
file_name = os.path.basename(repo_file['Name'])
logger.debug(
f'File {file_name} already in cache, skip downloading!')
continue
except Exception as e:
logger.warning('The file pattern is invalid : %s' % e)
else:
filtered_repo_files.append(repo_file)
def _download_single_file(repo_file):
# check model_file is exist in cache, if existed, skip download, otherwise download
if cache.exists(repo_file):
file_name = os.path.basename(repo_file['Name'])
logger.debug(
f'File {file_name} already in cache, skip downloading!')
return
if repo_type == REPO_TYPE_MODEL:
url = get_file_download_url(
model_id=repo_id,
@@ -481,11 +494,7 @@ def _download_file_lists(
download_file(url, repo_file, temporary_cache_dir, cache, headers,
cookies)
# Use thread_map for parallel downloading
thread_map(
_download_single_file,
filtered_repo_files,
max_workers=max_workers,
desc=f'Fetching {len(filtered_repo_files)} files',
leave=True,
position=max_workers)
if len(filtered_repo_files) > 0:
thread_download(_download_single_file, filtered_repo_files,
max_workers)
logger.info(f"Download {repo_type} '{repo_id}' successfully.")