[Refactor] Refactor the modelscope download module (#1683)

This commit is contained in:
Xingjun.Wang
2026-04-21 16:00:56 +08:00
committed by GitHub
parent 55af6ce8e4
commit 862afeea7b
7 changed files with 708 additions and 153 deletions

View File

@@ -36,6 +36,8 @@ run-name: Docker-${{ inputs.modelscope_branch }}-${{ inputs.image_type }}-${{ in
jobs: jobs:
build: build:
runs-on: [modelscope-self-hosted-us] runs-on: [modelscope-self-hosted-us]
env:
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: 'true'
steps: steps:
- name: ResetFileMode - name: ResetFileMode
@@ -47,7 +49,7 @@ jobs:
source ~/.bashrc source ~/.bashrc
sudo chown -R $USER:$USER $ACTION_RUNNER_DIR sudo chown -R $USER:$USER $ACTION_RUNNER_DIR
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v5
with: with:
ref: ${{ github.event.inputs.modelscope_branch }} ref: ${{ github.event.inputs.modelscope_branch }}

View File

@@ -112,9 +112,10 @@ RUN set -eux; \
done done
# if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value '<VERSION>'" # if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value '<VERSION>'"
ENV PYTHON_PIP_VERSION 23.0.1 # pip>=23.3 required for Python 3.12 (pkgutil.ImpImporter removed; older pip crashes on any install)
ENV PYTHON_PIP_VERSION 24.3.1
# https://github.com/docker-library/python/issues/365 # https://github.com/docker-library/python/issues/365
ENV PYTHON_SETUPTOOLS_VERSION 65.5.1 ENV PYTHON_SETUPTOOLS_VERSION 75.8.2
# https://github.com/pypa/get-pip # https://github.com/pypa/get-pip
ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py
ENV PYTHON_GET_PIP_SHA256 dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9 ENV PYTHON_GET_PIP_SHA256 dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9

View File

@@ -112,9 +112,10 @@ RUN set -eux; \
done done
# if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value '<VERSION>'" # if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value '<VERSION>'"
ENV PYTHON_PIP_VERSION 23.0.1 # pip>=23.3 required for Python 3.12 (pkgutil.ImpImporter removed; older pip crashes on any install)
ENV PYTHON_PIP_VERSION 24.3.1
# https://github.com/docker-library/python/issues/365 # https://github.com/docker-library/python/issues/365
ENV PYTHON_SETUPTOOLS_VERSION 65.5.1 ENV PYTHON_SETUPTOOLS_VERSION 75.8.2
# https://github.com/pypa/get-pip # https://github.com/pypa/get-pip
ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py
ENV PYTHON_GET_PIP_SHA256 dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9 ENV PYTHON_GET_PIP_SHA256 dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9

View File

@@ -54,6 +54,13 @@ class Builder:
def generate_dockerfile(self) -> str: def generate_dockerfile(self) -> str:
raise NotImplementedError raise NotImplementedError
@staticmethod
def _remove_pynini_related_dependency(content: str) -> str:
return content.replace(
'pip install --no-cache-dir funtextprocessing typeguard==2.13.3 scikit-learn -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html &&', # noqa: E501
'pip install --no-cache-dir typeguard==2.13.3 scikit-learn -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html &&' # noqa: E501
)
def _save_dockerfile(self, content: str) -> None: def _save_dockerfile(self, content: str) -> None:
if os.path.exists('./Dockerfile'): if os.path.exists('./Dockerfile'):
os.remove('./Dockerfile') os.remove('./Dockerfile')
@@ -302,7 +309,7 @@ class StableCPUImageBuilder(Builder):
content = content.replace('{modelscope_branch}', content = content.replace('{modelscope_branch}',
self.args.modelscope_branch) self.args.modelscope_branch)
content = content.replace('{swift_branch}', self.args.swift_branch) content = content.replace('{swift_branch}', self.args.swift_branch)
return content return self._remove_pynini_related_dependency(content)
def image(self) -> str: def image(self) -> str:
return ( return (
@@ -363,7 +370,7 @@ RUN pip install --no-cache-dir -U icecream soundfile pybind11 py-spy
content = content.replace('{modelscope_branch}', content = content.replace('{modelscope_branch}',
self.args.modelscope_branch) self.args.modelscope_branch)
content = content.replace('{swift_branch}', self.args.swift_branch) content = content.replace('{swift_branch}', self.args.swift_branch)
return content return self._remove_pynini_related_dependency(content)
def image(self) -> str: def image(self) -> str:
return ( return (

View File

@@ -39,6 +39,9 @@ from .utils.utils import (file_integrity_validation, get_endpoint,
logger = get_logger() logger = get_logger()
# Maximum number of retries for hash validation failures
HASH_RETRY_TIMES = 3
def model_file_download( def model_file_download(
model_id: str, model_id: str,
@@ -711,49 +714,66 @@ def download_file(
disable_tqdm=False, disable_tqdm=False,
progress_callbacks: List[Type[ProgressCallback]] = None, progress_callbacks: List[Type[ProgressCallback]] = None,
): ):
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
file_digest = parallel_download(
url,
temporary_cache_dir,
file_meta['Path'],
headers=headers,
cookies=None if cookies is None else cookies.get_dict(),
file_size=file_meta['Size'],
disable_tqdm=disable_tqdm,
progress_callbacks=progress_callbacks,
)
else:
file_digest = http_get_model_file(
url,
temporary_cache_dir,
file_meta['Path'],
file_size=file_meta['Size'],
headers=headers,
cookies=cookies,
disable_tqdm=disable_tqdm,
progress_callbacks=progress_callbacks,
)
# check file integrity
temp_file = os.path.join(temporary_cache_dir, file_meta['Path']) temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])
if FILE_HASH in file_meta:
expected_hash = file_meta[FILE_HASH] for hash_attempt in range(HASH_RETRY_TIMES):
if file_digest is not None: if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
if file_digest != expected_hash: 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
logger.warning( file_digest = parallel_download(
'Mismatched real-time digest for %s, falling back to full hash check', url,
file_meta['Path']) temporary_cache_dir,
if not file_integrity_validation(temp_file, expected_hash): file_meta['Path'],
raise FileDownloadError( headers=headers,
'File %s hash validation failed after download, ' cookies=None if cookies is None else cookies.get_dict(),
'the file may be corrupted. Please retry.' file_size=file_meta['Size'],
% file_meta['Path']) disable_tqdm=disable_tqdm,
progress_callbacks=progress_callbacks,
)
else: else:
if not file_integrity_validation(temp_file, expected_hash): file_digest = http_get_model_file(
raise FileDownloadError( url,
'File %s hash validation failed after download, ' temporary_cache_dir,
'the file may be corrupted. Please retry.' file_meta['Path'],
% file_meta['Path']) file_size=file_meta['Size'],
# put file into to cache headers=headers,
cookies=cookies,
disable_tqdm=disable_tqdm,
progress_callbacks=progress_callbacks,
)
# Check file integrity
if FILE_HASH in file_meta:
expected_hash = file_meta[FILE_HASH]
hash_valid = True
if file_digest is not None:
if file_digest != expected_hash:
logger.warning(
'Mismatched real-time digest for %s, falling back to full hash check',
file_meta['Path'])
if not file_integrity_validation(temp_file, expected_hash):
hash_valid = False
else:
if not file_integrity_validation(temp_file, expected_hash):
hash_valid = False
if not hash_valid:
if hash_attempt < HASH_RETRY_TIMES - 1:
logger.warning(
'Hash validation failed for %s, '
'retrying download (attempt %d/%d)', file_meta['Path'],
hash_attempt + 1, HASH_RETRY_TIMES)
# Clean up corrupted file before retry
if os.path.exists(temp_file):
os.remove(temp_file)
continue
else:
raise FileDownloadError(
'File %s hash validation failed after %d attempts, '
'the file may be corrupted.' %
(file_meta['Path'], HASH_RETRY_TIMES))
# Hash validation passed or no hash to validate, exit retry loop
break
# Put file into cache
return cache.put_file(file_meta, temp_file) return cache.put_file(file_meta, temp_file)

View File

@@ -3,12 +3,16 @@
import fnmatch import fnmatch
import os import os
import re import re
import threading
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext from contextlib import nullcontext
from http.cookiejar import CookieJar from http.cookiejar import CookieJar
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Type, Union from typing import Dict, List, Optional, Type, Union
from tqdm.auto import tqdm
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION, DEFAULT_MODEL_REVISION,
INTRA_CLOUD_ACCELERATION, INTRA_CLOUD_ACCELERATION,
@@ -20,16 +24,19 @@ from modelscope.utils.thread_utils import thread_executor
from .api import HubApi, ModelScopeConfig from .api import HubApi, ModelScopeConfig
from .callback import ProgressCallback from .callback import ProgressCallback
from .constants import DEFAULT_MAX_WORKERS from .constants import DEFAULT_MAX_WORKERS
from .errors import InvalidParameter from .errors import FileDownloadError, InvalidParameter
from .file_download import (create_temporary_directory_and_cache, from .file_download import (create_temporary_directory_and_cache,
download_file, get_file_download_url) download_file, get_file_download_url)
from .utils.caching import ModelFileSystemCache from .utils.caching import ModelFileSystemCache
from .utils.utils import (get_model_masked_directory, from .utils.utils import (extract_root_from_patterns,
get_model_masked_directory,
model_id_to_group_owner_name, strtobool, model_id_to_group_owner_name, strtobool,
weak_file_lock) weak_file_lock)
logger = get_logger() logger = get_logger()
DEFAULT_DATASET_PAGE_SIZE = 200
def snapshot_download( def snapshot_download(
model_id: str = None, model_id: str = None,
@@ -335,13 +342,42 @@ def _snapshot_download(
snapshot_header[ snapshot_header[
'cached_model_revision'] = cache.cached_model_revision 'cached_model_revision'] = cache.cached_model_revision
# Extract server-side root filter from include patterns
extracted_root = extract_root_from_patterns(
allow_file_pattern=_normalize_patterns(allow_file_pattern),
allow_patterns=_normalize_patterns(allow_patterns))
repo_files = _api.get_model_files( repo_files = _api.get_model_files(
model_id=repo_id, model_id=repo_id,
revision=revision, revision=revision,
root=extracted_root,
recursive=True, recursive=True,
use_cookies=False if cookies is None else cookies, use_cookies=False if cookies is None else cookies,
headers=snapshot_header, headers=snapshot_header,
endpoint=endpoint) endpoint=endpoint)
# Fallback: if root filter yielded no results, retry without it
if not repo_files and extracted_root is not None:
logger.warning(
f"root='{extracted_root}' returned no model files, "
f'falling back to root=None for full listing.')
repo_files = _api.get_model_files(
model_id=repo_id,
revision=revision,
root=None,
recursive=True,
use_cookies=False if cookies is None else cookies,
headers=snapshot_header,
endpoint=endpoint)
# Apply client-side pattern filtering
repo_files = filter_files_by_patterns(
repo_files,
allow_file_pattern=allow_file_pattern,
ignore_file_pattern=ignore_file_pattern,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns)
_download_file_lists( _download_file_lists(
repo_files, repo_files,
cache, cache,
@@ -354,10 +390,7 @@ def _snapshot_download(
repo_type=repo_type, repo_type=repo_type,
revision=revision, revision=revision,
cookies=cookies, cookies=cookies,
ignore_file_pattern=ignore_file_pattern, pre_filtered=True,
allow_file_pattern=allow_file_pattern,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns,
max_workers=max_workers, max_workers=max_workers,
endpoint=endpoint, endpoint=endpoint,
progress_callbacks=progress_callbacks, progress_callbacks=progress_callbacks,
@@ -392,75 +425,148 @@ def _snapshot_download(
group_or_owner, name = model_id_to_group_owner_name(repo_id) group_or_owner, name = model_id_to_group_owner_name(repo_id)
revision_detail = revision or DEFAULT_DATASET_REVISION revision_detail = revision or DEFAULT_DATASET_REVISION
logger.info('Fetching dataset repo file list...') # Extract server-side root filter from include patterns
repo_files = fetch_repo_files( extracted_root = extract_root_from_patterns(
_api, repo_id, revision_detail, endpoint, token=token) allow_file_pattern=_normalize_patterns(allow_file_pattern),
allow_patterns=_normalize_patterns(allow_patterns))
root_path = '/' + extracted_root if extracted_root else '/'
if repo_files is None: print(f'Fetching file list (root: {root_path})...')
logger.error( file_page_iter = _iter_dataset_file_pages(
f'Failed to retrieve file list for dataset: {repo_id}')
return None
_download_file_lists(
repo_files,
cache,
temporary_cache_dir,
repo_id,
_api, _api,
name, repo_id,
group_or_owner, revision_detail,
headers, endpoint,
repo_type=repo_type, token=token,
root_path=root_path,
allow_file_pattern=allow_file_pattern,
ignore_file_pattern=ignore_file_pattern,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns)
_pipeline_download_dataset(
file_page_iter,
cache=cache,
temporary_cache_dir=temporary_cache_dir,
repo_id=repo_id,
api=_api,
dataset_name=name,
namespace=group_or_owner,
headers=headers,
revision=revision, revision=revision,
cookies=cookies, cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns,
max_workers=max_workers, max_workers=max_workers,
endpoint=endpoint, endpoint=endpoint,
progress_callbacks=progress_callbacks, progress_callbacks=progress_callbacks)
)
cache.save_model_version(revision_info=revision_detail) cache.save_model_version(revision_info=revision_detail)
cache_root_path = cache.get_root_location() cache_root_path = cache.get_root_location()
return cache_root_path return cache_root_path
def fetch_repo_files(_api, repo_id, revision, endpoint, token=None): def fetch_repo_files(
_owner, _dataset_name = repo_id.split('/') _api,
repo_id,
revision,
endpoint,
token=None,
root_path='/',
allow_file_pattern=None,
ignore_file_pattern=None,
allow_patterns=None,
ignore_patterns=None,
page_size=DEFAULT_DATASET_PAGE_SIZE,
):
"""Fetch and filter dataset repo files with pagination and server-side prefix filtering.
Applies per-page pattern filtering to minimize memory usage.
Falls back to root_path='/' if the extracted prefix yields no results.
Args:
_api: HubApi instance.
repo_id: Dataset repo identifier (owner/name).
revision: Git revision.
endpoint: API endpoint URL.
token: Authentication token.
root_path: Server-side directory prefix filter.
allow_file_pattern: Include patterns for client-side filtering.
ignore_file_pattern: Exclude patterns for client-side filtering.
allow_patterns: Additional include patterns (HF-compatible).
ignore_patterns: Additional exclude patterns (HF-compatible).
page_size: Number of files per API page request.
Returns:
List of filtered file entry dicts.
"""
if '/' not in repo_id:
raise InvalidParameter(
f"Invalid repo_id: '{repo_id}', expected format 'owner/name'")
_owner, _dataset_name = repo_id.split('/', 1)
_hub_id, _ = _api.get_dataset_id_and_type( _hub_id, _ = _api.get_dataset_id_and_type(
dataset_name=_dataset_name, dataset_name=_dataset_name,
namespace=_owner, namespace=_owner,
endpoint=endpoint, endpoint=endpoint,
token=token) token=token)
page_number = 1 has_patterns = any([
page_size = 150 allow_file_pattern, ignore_file_pattern, allow_patterns,
repo_files = [] ignore_patterns
])
while True: def _paginate_and_filter(effective_root_path):
try: """Fetch all pages with the given root_path, applying per-page filtering."""
dataset_files = _api.get_dataset_files( page_number = 1
repo_id=repo_id, repo_files = []
revision=revision,
root_path='/',
recursive=True,
page_number=page_number,
page_size=page_size,
endpoint=endpoint,
token=token,
dataset_hub_id=_hub_id)
except Exception as e:
logger.error(f'Error fetching dataset files: {e}')
break
repo_files.extend(dataset_files) while True:
try:
dataset_files = _api.get_dataset_files(
repo_id=repo_id,
revision=revision,
root_path=effective_root_path,
recursive=True,
page_number=page_number,
page_size=page_size,
endpoint=endpoint,
token=token,
dataset_hub_id=_hub_id)
except Exception as e:
logger.error(
f'Error fetching dataset files (page {page_number}): {e}')
break
if len(dataset_files) < page_size: if not dataset_files:
break break
page_number += 1 # Per-page filtering: apply patterns immediately to reduce memory
if has_patterns:
page_filtered = filter_files_by_patterns(
dataset_files,
allow_file_pattern=allow_file_pattern,
ignore_file_pattern=ignore_file_pattern,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns)
repo_files.extend(page_filtered)
else:
# No patterns: keep all non-tree entries
repo_files.extend(
f for f in dataset_files if f.get('Type') != 'tree')
if len(dataset_files) < page_size:
break
page_number += 1
return repo_files
# Primary fetch with optimized root_path
repo_files = _paginate_and_filter(root_path)
# Fallback: if optimized root_path yielded nothing and it's not the default
if not repo_files and root_path != '/':
logger.warning(f"root_path='{root_path}' returned no results, "
f"falling back to root_path='/' for full listing.")
repo_files = _paginate_and_filter('/')
return repo_files return repo_files
@@ -494,6 +600,310 @@ def _get_valid_regex_pattern(patterns: List[str]):
return None return None
def filter_files_by_patterns(
repo_files: List[dict],
*,
allow_file_pattern: Optional[List[str]] = None,
ignore_file_pattern: Optional[List[str]] = None,
allow_patterns: Optional[List[str]] = None,
ignore_patterns: Optional[List[str]] = None,
) -> List[dict]:
"""Filter repo file entries by include/exclude patterns.
Skips 'tree' type entries. Applies fnmatch and regex pattern matching.
Returns only file entries that pass all filter criteria.
Args:
repo_files: List of file entry dicts with 'Type', 'Path', 'Name' keys.
allow_file_pattern: Include patterns (fnmatch). Files must match at least one.
ignore_file_pattern: Exclude patterns (fnmatch). Matching files are skipped.
allow_patterns: Additional include patterns (HF-compatible).
ignore_patterns: Additional exclude patterns (HF-compatible).
Returns:
List of file entries that pass all filters.
"""
ignore_patterns = _normalize_patterns(ignore_patterns)
allow_patterns = _normalize_patterns(allow_patterns)
ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
allow_file_pattern = _normalize_patterns(allow_file_pattern)
ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)
filtered = []
for repo_file in repo_files:
if repo_file['Type'] == 'tree':
continue
try:
if ignore_patterns and any(
fnmatch.fnmatch(repo_file['Path'], p)
for p in ignore_patterns):
continue
if ignore_file_pattern and any(
fnmatch.fnmatch(repo_file['Path'], p)
for p in ignore_file_pattern):
continue
if ignore_regex_pattern and any(
re.search(p, repo_file['Name']) is not None
for p in ignore_regex_pattern):
continue
if allow_patterns and not any(
fnmatch.fnmatch(repo_file['Path'], p)
for p in allow_patterns):
continue
if allow_file_pattern and not any(
fnmatch.fnmatch(repo_file['Path'], p)
for p in allow_file_pattern):
continue
except Exception as e:
logger.warning('Invalid file pattern: %s' % e)
continue
filtered.append(repo_file)
return filtered
def _iter_dataset_file_pages(
_api,
repo_id,
revision,
endpoint,
token=None,
root_path='/',
allow_file_pattern=None,
ignore_file_pattern=None,
allow_patterns=None,
ignore_patterns=None,
page_size=DEFAULT_DATASET_PAGE_SIZE,
):
"""Generator that yields filtered file pages from a dataset repo.
Each yield is a non-empty list of file-entry dicts for one API page.
Applies per-page pattern filtering to minimize memory usage.
Falls back to root_path='/' if the extracted prefix yields no results.
Args:
_api: HubApi instance.
repo_id: Dataset repo identifier (owner/name).
revision: Git revision.
endpoint: API endpoint URL.
token: Authentication token.
root_path: Server-side directory prefix filter.
allow_file_pattern: Include patterns (fnmatch).
ignore_file_pattern: Exclude patterns (fnmatch).
allow_patterns: Additional include patterns (HF-compatible).
ignore_patterns: Additional exclude patterns (HF-compatible).
page_size: Number of files per API page request.
Yields:
List[dict]: Non-empty list of filtered file entries per page.
"""
if '/' not in repo_id:
raise InvalidParameter(
f"Invalid repo_id: '{repo_id}', expected format 'owner/name'")
_owner, _dataset_name = repo_id.split('/', 1)
_hub_id, _ = _api.get_dataset_id_and_type(
dataset_name=_dataset_name,
namespace=_owner,
endpoint=endpoint,
token=token)
has_patterns = any([
allow_file_pattern, ignore_file_pattern, allow_patterns,
ignore_patterns
])
def _paginate_pages(effective_root_path):
"""Yield filtered file pages for the given root_path."""
page_number = 1
total_found = 0
while True:
try:
dataset_files = _api.get_dataset_files(
repo_id=repo_id,
revision=revision,
root_path=effective_root_path,
recursive=True,
page_number=page_number,
page_size=page_size,
endpoint=endpoint,
token=token,
dataset_hub_id=_hub_id)
except Exception as e:
logger.error(
f'Error fetching dataset files (page {page_number}): {e}')
break
if not dataset_files:
break
# Per-page filtering to reduce memory footprint
if has_patterns:
page_filtered = filter_files_by_patterns(
dataset_files,
allow_file_pattern=allow_file_pattern,
ignore_file_pattern=ignore_file_pattern,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns)
else:
# No patterns: keep all non-tree entries
page_filtered = [
f for f in dataset_files if f.get('Type') != 'tree'
]
total_found += len(page_filtered)
if page_filtered:
yield page_filtered
print(
f'\r Fetched {total_found} matching files '
f'({page_number} pages)...',
end='',
flush=True)
if len(dataset_files) < page_size:
break
page_number += 1
# Primary fetch with optimized root_path
try:
yielded_any = False
for page in _paginate_pages(root_path):
yielded_any = True
yield page
# Fallback: if optimized root_path yielded nothing and it's not the default
if not yielded_any and root_path != '/':
print(f"\n root_path='{root_path}' returned no results, "
f"falling back to root_path='/' for full listing.")
for page in _paginate_pages('/'):
yield page
finally:
# Terminate the \r progress line regardless of how iteration ends
print()
def _pipeline_download_dataset(
file_page_iter,
cache,
temporary_cache_dir,
repo_id,
api,
dataset_name,
namespace,
headers,
revision,
cookies,
max_workers=DEFAULT_MAX_WORKERS,
endpoint=None,
progress_callbacks=None,
):
"""Pipeline consumer: download dataset files as pages are yielded.
Consumes the page iterator from _iter_dataset_file_pages, submitting
each file to a thread pool for concurrent download. Uses tqdm for
real-time progress and thread-safe error collection.
Args:
file_page_iter: Iterator yielding List[dict] file pages.
cache: ModelFileSystemCache instance for dedup.
temporary_cache_dir: Temp staging directory.
repo_id: Dataset repo identifier.
api: HubApi instance.
dataset_name: Dataset name component.
namespace: Owner/namespace component.
headers: HTTP request headers.
revision: Git revision.
cookies: HTTP cookies.
max_workers: Thread pool concurrency.
endpoint: API endpoint URL.
progress_callbacks: Optional progress callback list.
"""
total_found = 0
total_cached = 0
failed_items = []
lock = threading.Lock()
def _on_done(future, repo_file):
"""Done callback: update progress bar and collect failures."""
try:
future.result()
except Exception as exc:
with lock:
failed_items.append((repo_file, exc))
logger.debug(
f"Download failed for {repo_file.get('Path', '?')}: {exc}")
finally:
pbar.update(1)
# tqdm wraps the executor so all callbacks fire before pbar closes
with tqdm(total=0, unit=' files', disable=False) as pbar:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
for page_files in file_page_iter:
for repo_file in page_files:
total_found += 1
pbar.total = total_found
pbar.refresh()
# Skip files already in cache
if cache.exists(repo_file):
total_cached += 1
pbar.update(1)
continue
# Build download URL
url = api.get_dataset_file_url(
file_name=repo_file['Path'],
dataset_name=dataset_name,
namespace=namespace,
revision=revision,
endpoint=endpoint)
# Submit download task
future = executor.submit(
download_file,
url,
repo_file,
temporary_cache_dir,
cache,
headers,
cookies,
disable_tqdm=False,
progress_callbacks=progress_callbacks,
)
future.add_done_callback(
lambda f, rf=repo_file: _on_done(f, rf))
# Executor __exit__ waits for all futures to complete
# Report failures after progress bar closes
if failed_items:
failed_paths = [
item.get('Path', '?') if isinstance(item, dict) else str(item)
for item, _ in failed_items
]
logger.error(f'{len(failed_items)} file(s) failed to download:\n'
+ '\n'.join(f' - {p}' for p in failed_paths))
# Completion summary (always print, even if raising after)
downloaded = total_found - total_cached - len(failed_items)
print(f'Download complete: {total_found} files found, '
f'{total_cached} cached, {downloaded} downloaded'
+ (f', {len(failed_items)} failed' if failed_items else '') + '.')
if failed_items:
raise FileDownloadError(
f'{len(failed_items)} file(s) failed to download out of '
f'{total_found}.')
def _download_file_lists( def _download_file_lists(
repo_files: List[str], repo_files: List[str],
cache: ModelFileSystemCache, cache: ModelFileSystemCache,
@@ -513,62 +923,77 @@ def _download_file_lists(
max_workers: int = 8, max_workers: int = 8,
endpoint: Optional[str] = None, endpoint: Optional[str] = None,
progress_callbacks: List[Type[ProgressCallback]] = None, progress_callbacks: List[Type[ProgressCallback]] = None,
pre_filtered: bool = False,
): ):
ignore_patterns = _normalize_patterns(ignore_patterns) if pre_filtered:
allow_patterns = _normalize_patterns(allow_patterns) # Files are already filtered by patterns; only check cache
ignore_file_pattern = _normalize_patterns(ignore_file_pattern) filtered_repo_files = []
allow_file_pattern = _normalize_patterns(allow_file_pattern) for repo_file in repo_files:
# to compatible regex usage.
ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)
filtered_repo_files = []
for repo_file in repo_files:
if repo_file['Type'] == 'tree':
continue
try:
# processing patterns
if ignore_patterns and any([
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in ignore_patterns
]):
continue
if ignore_file_pattern and any([
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in ignore_file_pattern
]):
continue
if ignore_regex_pattern and any([
re.search(pattern, repo_file['Name']) is not None
for pattern in ignore_regex_pattern
]): # noqa E501
continue
if allow_patterns is not None and allow_patterns:
if not any(
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_patterns):
continue
if allow_file_pattern is not None and allow_file_pattern:
if not any(
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_file_pattern):
continue
# check model_file is exist in cache, if existed, skip download
if cache.exists(repo_file): if cache.exists(repo_file):
file_name = os.path.basename(repo_file['Name']) file_name = os.path.basename(repo_file['Name'])
logger.debug( logger.debug(
f'File {file_name} already in cache with identical hash, skip downloading!' f'File {file_name} already in cache with identical hash, skip downloading!'
) )
continue continue
except Exception as e:
logger.warning('The file pattern is invalid : %s' % e)
else:
filtered_repo_files.append(repo_file) filtered_repo_files.append(repo_file)
else:
# Legacy path: apply pattern filtering + cache check
ignore_patterns = _normalize_patterns(ignore_patterns)
allow_patterns = _normalize_patterns(allow_patterns)
ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
allow_file_pattern = _normalize_patterns(allow_file_pattern)
# to compatible regex usage.
ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)
@thread_executor(max_workers=max_workers, disable_tqdm=False) filtered_repo_files = []
for repo_file in repo_files:
if repo_file['Type'] == 'tree':
continue
try:
# processing patterns
if ignore_patterns and any([
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in ignore_patterns
]):
continue
if ignore_file_pattern and any([
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in ignore_file_pattern
]):
continue
if ignore_regex_pattern and any([
re.search(pattern, repo_file['Name']) is not None
for pattern in ignore_regex_pattern
]): # noqa E501
continue
if allow_patterns is not None and allow_patterns:
if not any(
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_patterns):
continue
if allow_file_pattern is not None and allow_file_pattern:
if not any(
fnmatch.fnmatch(repo_file['Path'], pattern)
for pattern in allow_file_pattern):
continue
# check model_file is exist in cache, if existed, skip download
if cache.exists(repo_file):
file_name = os.path.basename(repo_file['Name'])
logger.debug(
f'File {file_name} already in cache with identical hash, skip downloading!'
)
continue
except Exception as e:
logger.warning('The file pattern is invalid : %s' % e)
else:
filtered_repo_files.append(repo_file)
@thread_executor(
max_workers=max_workers, disable_tqdm=False, fault_tolerant=True)
def _download_single_file(repo_file): def _download_single_file(repo_file):
if repo_type == REPO_TYPE_MODEL: if repo_type == REPO_TYPE_MODEL:
url = get_file_download_url( url = get_file_download_url(
@@ -602,5 +1027,26 @@ def _download_file_lists(
if len(filtered_repo_files) > 0: if len(filtered_repo_files) > 0:
logger.info( logger.info(
f'Got {len(filtered_repo_files)} files, start to download ...') f'Got {len(filtered_repo_files)} files, start to download ...')
_download_single_file(filtered_repo_files) download_result = _download_single_file(filtered_repo_files)
logger.info(f"Download {repo_type} '{repo_id}' successfully.")
# Handle fault-tolerant results: report failed downloads
failed_items = []
if isinstance(download_result, tuple) and len(download_result) == 2:
_, failed_items = download_result
if failed_items:
failed_paths = [
item['Path'] if isinstance(item, dict) else str(item)
for item, _ in failed_items
]
logger.error(
f'{len(failed_items)} file(s) failed to download:\n'
+ '\n'.join(f' - {p}' for p in failed_paths))
logger.info(
f"Finish downloading {len(filtered_repo_files)} files for repo '{repo_id}'"
)
if failed_items:
raise FileDownloadError(
f'{len(failed_items)} file(s) failed to download out of '
f'{len(filtered_repo_files)}.')

View File

@@ -66,6 +66,83 @@ def convert_patterns(raw_input: Union[str, List[str]]):
return output return output
def extract_root_from_patterns(
allow_file_pattern: Optional[List[str]] = None,
allow_patterns: Optional[List[str]] = None,
) -> Optional[str]:
"""Extract common directory prefix from include patterns for server-side filtering.
Only processes allow/include patterns (ignore/exclude is irrelevant for prefix).
Returns None if no meaningful prefix can be extracted.
Algorithm:
1. Merge allow_file_pattern and allow_patterns into one list
2. For each pattern, find position of first wildcard char (* ? [)
3. Extract text before that position, then take directory part (up to last '/')
4. Find longest common directory prefix (path-segment-aware)
5. Return None if no valid common prefix
Examples:
['TacExo/*'] -> 'TacExo'
['data/train/*.parquet'] -> 'data/train'
['data/train/*', 'data/valid/*'] -> 'data'
['*.safetensors'] -> None
['TacExo/*', 'OtherDir/*'] -> None (no common prefix)
['data/*/train.csv'] -> 'data'
"""
# Merge both pattern lists
patterns = []
if allow_file_pattern:
patterns.extend(allow_file_pattern)
if allow_patterns:
patterns.extend(allow_patterns)
if not patterns:
return None
extracted_dirs = []
for pattern in patterns:
# Find position of first wildcard character
first_wildcard = len(pattern)
for wc in ('*', '?', '['):
pos = pattern.find(wc)
if pos != -1:
first_wildcard = min(first_wildcard, pos)
# Get text before wildcard
prefix = pattern[:first_wildcard]
# Extract directory part (up to last '/')
last_sep = prefix.rfind('/')
if last_sep > 0:
dir_part = prefix[:last_sep]
# Validate: no wildcards should remain in dir_part
if not any(c in dir_part for c in ('*', '?', '[')):
extracted_dirs.append(dir_part)
# If no '/' found or only at position 0, this pattern has no directory prefix
if not extracted_dirs:
return None
# Find longest common directory prefix (path-segment-aware)
if len(extracted_dirs) == 1:
return extracted_dirs[0]
# Split all paths into segments and find common prefix segments
split_paths = [d.split('/') for d in extracted_dirs]
common_segments = []
for segments in zip(*split_paths):
if len(set(segments)) == 1:
common_segments.append(segments[0])
else:
break
if not common_segments:
return None
return '/'.join(common_segments)
# during model download, the '.' would be converted to '___' to produce # during model download, the '.' would be converted to '___' to produce
# actual physical (masked) directory for storage # actual physical (masked) directory for storage
def get_model_masked_directory(directory, model_id): def get_model_masked_directory(directory, model_id):
@@ -134,7 +211,8 @@ def get_endpoint(cn_site=True):
def compute_hash(file_path): def compute_hash(file_path):
BUFFER_SIZE = 1024 * 64 # 64k buffer size # 16MB buffer for large file hash computation
BUFFER_SIZE = 1024 * 1024 * 16
sha256_hash = hashlib.sha256() sha256_hash = hashlib.sha256()
with open(file_path, 'rb') as f: with open(file_path, 'rb') as f:
while True: while True: