mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
[Refactor] Refactor the modelscope download module (#1683)
This commit is contained in:
4
.github/workflows/docker-image.yml
vendored
4
.github/workflows/docker-image.yml
vendored
@@ -36,6 +36,8 @@ run-name: Docker-${{ inputs.modelscope_branch }}-${{ inputs.image_type }}-${{ in
|
||||
jobs:
|
||||
build:
|
||||
runs-on: [modelscope-self-hosted-us]
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: 'true'
|
||||
|
||||
steps:
|
||||
- name: ResetFileMode
|
||||
@@ -47,7 +49,7 @@ jobs:
|
||||
source ~/.bashrc
|
||||
sudo chown -R $USER:$USER $ACTION_RUNNER_DIR
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
ref: ${{ github.event.inputs.modelscope_branch }}
|
||||
|
||||
|
||||
@@ -112,9 +112,10 @@ RUN set -eux; \
|
||||
done
|
||||
|
||||
# if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value '<VERSION>'"
|
||||
ENV PYTHON_PIP_VERSION 23.0.1
|
||||
# pip>=23.3 required for Python 3.12 (pkgutil.ImpImporter removed; older pip crashes on any install)
|
||||
ENV PYTHON_PIP_VERSION 24.3.1
|
||||
# https://github.com/docker-library/python/issues/365
|
||||
ENV PYTHON_SETUPTOOLS_VERSION 65.5.1
|
||||
ENV PYTHON_SETUPTOOLS_VERSION 75.8.2
|
||||
# https://github.com/pypa/get-pip
|
||||
ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py
|
||||
ENV PYTHON_GET_PIP_SHA256 dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9
|
||||
|
||||
@@ -112,9 +112,10 @@ RUN set -eux; \
|
||||
done
|
||||
|
||||
# if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value '<VERSION>'"
|
||||
ENV PYTHON_PIP_VERSION 23.0.1
|
||||
# pip>=23.3 required for Python 3.12 (pkgutil.ImpImporter removed; older pip crashes on any install)
|
||||
ENV PYTHON_PIP_VERSION 24.3.1
|
||||
# https://github.com/docker-library/python/issues/365
|
||||
ENV PYTHON_SETUPTOOLS_VERSION 65.5.1
|
||||
ENV PYTHON_SETUPTOOLS_VERSION 75.8.2
|
||||
# https://github.com/pypa/get-pip
|
||||
ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py
|
||||
ENV PYTHON_GET_PIP_SHA256 dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9
|
||||
|
||||
@@ -54,6 +54,13 @@ class Builder:
|
||||
def generate_dockerfile(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
def _remove_pynini_related_dependency(content: str) -> str:
    """Drop ``funtextprocessing`` from the pip install command in a Dockerfile.

    The exact install command that lists ``funtextprocessing`` is swapped for
    the same command without that package. Content that does not contain the
    command is returned unchanged.

    Args:
        content: Dockerfile text to rewrite.

    Returns:
        The Dockerfile text with the package removed from the install command.
    """
    # NOTE(review): judging by the method name, funtextprocessing is what
    # pulls in pynini -- confirm against the package's requirements.
    install_with_pkg = 'pip install --no-cache-dir funtextprocessing typeguard==2.13.3 scikit-learn -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html &&'  # noqa: E501
    install_without_pkg = 'pip install --no-cache-dir typeguard==2.13.3 scikit-learn -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html &&'  # noqa: E501
    return content.replace(install_with_pkg, install_without_pkg)
|
||||
|
||||
def _save_dockerfile(self, content: str) -> None:
|
||||
if os.path.exists('./Dockerfile'):
|
||||
os.remove('./Dockerfile')
|
||||
@@ -302,7 +309,7 @@ class StableCPUImageBuilder(Builder):
|
||||
content = content.replace('{modelscope_branch}',
|
||||
self.args.modelscope_branch)
|
||||
content = content.replace('{swift_branch}', self.args.swift_branch)
|
||||
return content
|
||||
return self._remove_pynini_related_dependency(content)
|
||||
|
||||
def image(self) -> str:
|
||||
return (
|
||||
@@ -363,7 +370,7 @@ RUN pip install --no-cache-dir -U icecream soundfile pybind11 py-spy
|
||||
content = content.replace('{modelscope_branch}',
|
||||
self.args.modelscope_branch)
|
||||
content = content.replace('{swift_branch}', self.args.swift_branch)
|
||||
return content
|
||||
return self._remove_pynini_related_dependency(content)
|
||||
|
||||
def image(self) -> str:
|
||||
return (
|
||||
|
||||
@@ -39,6 +39,9 @@ from .utils.utils import (file_integrity_validation, get_endpoint,
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# Maximum number of retries for hash validation failures
|
||||
HASH_RETRY_TIMES = 3
|
||||
|
||||
|
||||
def model_file_download(
|
||||
model_id: str,
|
||||
@@ -711,49 +714,66 @@ def download_file(
|
||||
disable_tqdm=False,
|
||||
progress_callbacks: List[Type[ProgressCallback]] = None,
|
||||
):
|
||||
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
|
||||
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
|
||||
file_digest = parallel_download(
|
||||
url,
|
||||
temporary_cache_dir,
|
||||
file_meta['Path'],
|
||||
headers=headers,
|
||||
cookies=None if cookies is None else cookies.get_dict(),
|
||||
file_size=file_meta['Size'],
|
||||
disable_tqdm=disable_tqdm,
|
||||
progress_callbacks=progress_callbacks,
|
||||
)
|
||||
else:
|
||||
file_digest = http_get_model_file(
|
||||
url,
|
||||
temporary_cache_dir,
|
||||
file_meta['Path'],
|
||||
file_size=file_meta['Size'],
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
disable_tqdm=disable_tqdm,
|
||||
progress_callbacks=progress_callbacks,
|
||||
)
|
||||
|
||||
# check file integrity
|
||||
temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])
|
||||
if FILE_HASH in file_meta:
|
||||
expected_hash = file_meta[FILE_HASH]
|
||||
if file_digest is not None:
|
||||
if file_digest != expected_hash:
|
||||
logger.warning(
|
||||
'Mismatched real-time digest for %s, falling back to full hash check',
|
||||
file_meta['Path'])
|
||||
if not file_integrity_validation(temp_file, expected_hash):
|
||||
raise FileDownloadError(
|
||||
'File %s hash validation failed after download, '
|
||||
'the file may be corrupted. Please retry.'
|
||||
% file_meta['Path'])
|
||||
|
||||
for hash_attempt in range(HASH_RETRY_TIMES):
|
||||
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
|
||||
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
|
||||
file_digest = parallel_download(
|
||||
url,
|
||||
temporary_cache_dir,
|
||||
file_meta['Path'],
|
||||
headers=headers,
|
||||
cookies=None if cookies is None else cookies.get_dict(),
|
||||
file_size=file_meta['Size'],
|
||||
disable_tqdm=disable_tqdm,
|
||||
progress_callbacks=progress_callbacks,
|
||||
)
|
||||
else:
|
||||
if not file_integrity_validation(temp_file, expected_hash):
|
||||
raise FileDownloadError(
|
||||
'File %s hash validation failed after download, '
|
||||
'the file may be corrupted. Please retry.'
|
||||
% file_meta['Path'])
|
||||
# put file into to cache
|
||||
file_digest = http_get_model_file(
|
||||
url,
|
||||
temporary_cache_dir,
|
||||
file_meta['Path'],
|
||||
file_size=file_meta['Size'],
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
disable_tqdm=disable_tqdm,
|
||||
progress_callbacks=progress_callbacks,
|
||||
)
|
||||
|
||||
# Check file integrity
|
||||
if FILE_HASH in file_meta:
|
||||
expected_hash = file_meta[FILE_HASH]
|
||||
hash_valid = True
|
||||
if file_digest is not None:
|
||||
if file_digest != expected_hash:
|
||||
logger.warning(
|
||||
'Mismatched real-time digest for %s, falling back to full hash check',
|
||||
file_meta['Path'])
|
||||
if not file_integrity_validation(temp_file, expected_hash):
|
||||
hash_valid = False
|
||||
else:
|
||||
if not file_integrity_validation(temp_file, expected_hash):
|
||||
hash_valid = False
|
||||
|
||||
if not hash_valid:
|
||||
if hash_attempt < HASH_RETRY_TIMES - 1:
|
||||
logger.warning(
|
||||
'Hash validation failed for %s, '
|
||||
'retrying download (attempt %d/%d)', file_meta['Path'],
|
||||
hash_attempt + 1, HASH_RETRY_TIMES)
|
||||
# Clean up corrupted file before retry
|
||||
if os.path.exists(temp_file):
|
||||
os.remove(temp_file)
|
||||
continue
|
||||
else:
|
||||
raise FileDownloadError(
|
||||
'File %s hash validation failed after %d attempts, '
|
||||
'the file may be corrupted.' %
|
||||
(file_meta['Path'], HASH_RETRY_TIMES))
|
||||
|
||||
# Hash validation passed or no hash to validate, exit retry loop
|
||||
break
|
||||
|
||||
# Put file into cache
|
||||
return cache.put_file(file_meta, temp_file)
|
||||
|
||||
@@ -3,12 +3,16 @@
|
||||
import fnmatch
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import nullcontext
|
||||
from http.cookiejar import CookieJar
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
DEFAULT_MODEL_REVISION,
|
||||
INTRA_CLOUD_ACCELERATION,
|
||||
@@ -20,16 +24,19 @@ from modelscope.utils.thread_utils import thread_executor
|
||||
from .api import HubApi, ModelScopeConfig
|
||||
from .callback import ProgressCallback
|
||||
from .constants import DEFAULT_MAX_WORKERS
|
||||
from .errors import InvalidParameter
|
||||
from .errors import FileDownloadError, InvalidParameter
|
||||
from .file_download import (create_temporary_directory_and_cache,
|
||||
download_file, get_file_download_url)
|
||||
from .utils.caching import ModelFileSystemCache
|
||||
from .utils.utils import (get_model_masked_directory,
|
||||
from .utils.utils import (extract_root_from_patterns,
|
||||
get_model_masked_directory,
|
||||
model_id_to_group_owner_name, strtobool,
|
||||
weak_file_lock)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
DEFAULT_DATASET_PAGE_SIZE = 200
|
||||
|
||||
|
||||
def snapshot_download(
|
||||
model_id: str = None,
|
||||
@@ -335,13 +342,42 @@ def _snapshot_download(
|
||||
snapshot_header[
|
||||
'cached_model_revision'] = cache.cached_model_revision
|
||||
|
||||
# Extract server-side root filter from include patterns
|
||||
extracted_root = extract_root_from_patterns(
|
||||
allow_file_pattern=_normalize_patterns(allow_file_pattern),
|
||||
allow_patterns=_normalize_patterns(allow_patterns))
|
||||
|
||||
repo_files = _api.get_model_files(
|
||||
model_id=repo_id,
|
||||
revision=revision,
|
||||
root=extracted_root,
|
||||
recursive=True,
|
||||
use_cookies=False if cookies is None else cookies,
|
||||
headers=snapshot_header,
|
||||
endpoint=endpoint)
|
||||
|
||||
# Fallback: if root filter yielded no results, retry without it
|
||||
if not repo_files and extracted_root is not None:
|
||||
logger.warning(
|
||||
f"root='{extracted_root}' returned no model files, "
|
||||
f'falling back to root=None for full listing.')
|
||||
repo_files = _api.get_model_files(
|
||||
model_id=repo_id,
|
||||
revision=revision,
|
||||
root=None,
|
||||
recursive=True,
|
||||
use_cookies=False if cookies is None else cookies,
|
||||
headers=snapshot_header,
|
||||
endpoint=endpoint)
|
||||
|
||||
# Apply client-side pattern filtering
|
||||
repo_files = filter_files_by_patterns(
|
||||
repo_files,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns)
|
||||
|
||||
_download_file_lists(
|
||||
repo_files,
|
||||
cache,
|
||||
@@ -354,10 +390,7 @@ def _snapshot_download(
|
||||
repo_type=repo_type,
|
||||
revision=revision,
|
||||
cookies=cookies,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_patterns=ignore_patterns,
|
||||
allow_patterns=allow_patterns,
|
||||
pre_filtered=True,
|
||||
max_workers=max_workers,
|
||||
endpoint=endpoint,
|
||||
progress_callbacks=progress_callbacks,
|
||||
@@ -392,75 +425,148 @@ def _snapshot_download(
|
||||
group_or_owner, name = model_id_to_group_owner_name(repo_id)
|
||||
revision_detail = revision or DEFAULT_DATASET_REVISION
|
||||
|
||||
logger.info('Fetching dataset repo file list...')
|
||||
repo_files = fetch_repo_files(
|
||||
_api, repo_id, revision_detail, endpoint, token=token)
|
||||
# Extract server-side root filter from include patterns
|
||||
extracted_root = extract_root_from_patterns(
|
||||
allow_file_pattern=_normalize_patterns(allow_file_pattern),
|
||||
allow_patterns=_normalize_patterns(allow_patterns))
|
||||
root_path = '/' + extracted_root if extracted_root else '/'
|
||||
|
||||
if repo_files is None:
|
||||
logger.error(
|
||||
f'Failed to retrieve file list for dataset: {repo_id}')
|
||||
return None
|
||||
|
||||
_download_file_lists(
|
||||
repo_files,
|
||||
cache,
|
||||
temporary_cache_dir,
|
||||
repo_id,
|
||||
print(f'Fetching file list (root: {root_path})...')
|
||||
file_page_iter = _iter_dataset_file_pages(
|
||||
_api,
|
||||
name,
|
||||
group_or_owner,
|
||||
headers,
|
||||
repo_type=repo_type,
|
||||
repo_id,
|
||||
revision_detail,
|
||||
endpoint,
|
||||
token=token,
|
||||
root_path=root_path,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns)
|
||||
|
||||
_pipeline_download_dataset(
|
||||
file_page_iter,
|
||||
cache=cache,
|
||||
temporary_cache_dir=temporary_cache_dir,
|
||||
repo_id=repo_id,
|
||||
api=_api,
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
headers=headers,
|
||||
revision=revision,
|
||||
cookies=cookies,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
allow_file_pattern=allow_file_pattern,
|
||||
ignore_patterns=ignore_patterns,
|
||||
allow_patterns=allow_patterns,
|
||||
max_workers=max_workers,
|
||||
endpoint=endpoint,
|
||||
progress_callbacks=progress_callbacks,
|
||||
)
|
||||
progress_callbacks=progress_callbacks)
|
||||
|
||||
cache.save_model_version(revision_info=revision_detail)
|
||||
cache_root_path = cache.get_root_location()
|
||||
return cache_root_path
|
||||
|
||||
|
||||
def fetch_repo_files(
    _api,
    repo_id,
    revision,
    endpoint,
    token=None,
    root_path='/',
    allow_file_pattern=None,
    ignore_file_pattern=None,
    allow_patterns=None,
    ignore_patterns=None,
    page_size=DEFAULT_DATASET_PAGE_SIZE,
):
    """Fetch and filter dataset repo files with pagination and server-side prefix filtering.

    Applies per-page pattern filtering to minimize memory usage.
    Falls back to root_path='/' if the extracted prefix yields no results.

    Args:
        _api: HubApi instance.
        repo_id: Dataset repo identifier (owner/name).
        revision: Git revision.
        endpoint: API endpoint URL.
        token: Authentication token.
        root_path: Server-side directory prefix filter.
        allow_file_pattern: Include patterns for client-side filtering.
        ignore_file_pattern: Exclude patterns for client-side filtering.
        allow_patterns: Additional include patterns (HF-compatible).
        ignore_patterns: Additional exclude patterns (HF-compatible).
        page_size: Number of files per API page request.

    Returns:
        List of filtered file entry dicts. If a page request raises, the
        error is logged and the entries collected so far are returned.

    Raises:
        InvalidParameter: If ``repo_id`` does not contain a '/'.
    """
    if '/' not in repo_id:
        raise InvalidParameter(
            f"Invalid repo_id: '{repo_id}', expected format 'owner/name'")
    # maxsplit=1 keeps any further '/' inside the dataset name component.
    _owner, _dataset_name = repo_id.split('/', 1)
    _hub_id, _ = _api.get_dataset_id_and_type(
        dataset_name=_dataset_name,
        namespace=_owner,
        endpoint=endpoint,
        token=token)

    has_patterns = any([
        allow_file_pattern, ignore_file_pattern, allow_patterns,
        ignore_patterns
    ])

    def _paginate_and_filter(effective_root_path):
        """Fetch all pages with the given root_path, applying per-page filtering."""
        page_number = 1
        repo_files = []

        while True:
            try:
                dataset_files = _api.get_dataset_files(
                    repo_id=repo_id,
                    revision=revision,
                    root_path=effective_root_path,
                    recursive=True,
                    page_number=page_number,
                    page_size=page_size,
                    endpoint=endpoint,
                    token=token,
                    dataset_hub_id=_hub_id)
            except Exception as e:
                # Best effort: stop paginating and keep what was collected.
                logger.error(
                    f'Error fetching dataset files (page {page_number}): {e}')
                break

            if not dataset_files:
                break

            # Per-page filtering: apply patterns immediately to reduce memory
            if has_patterns:
                page_filtered = filter_files_by_patterns(
                    dataset_files,
                    allow_file_pattern=allow_file_pattern,
                    ignore_file_pattern=ignore_file_pattern,
                    allow_patterns=allow_patterns,
                    ignore_patterns=ignore_patterns)
                repo_files.extend(page_filtered)
            else:
                # No patterns: keep all non-tree entries
                repo_files.extend(
                    f for f in dataset_files if f.get('Type') != 'tree')

            # A short page means the server has no further entries.
            if len(dataset_files) < page_size:
                break

            page_number += 1

        return repo_files

    # Primary fetch with optimized root_path
    repo_files = _paginate_and_filter(root_path)

    # Fallback: if optimized root_path yielded nothing and it's not the default
    if not repo_files and root_path != '/':
        logger.warning(f"root_path='{root_path}' returned no results, "
                       f"falling back to root_path='/' for full listing.")
        repo_files = _paginate_and_filter('/')

    return repo_files
|
||||
|
||||
@@ -494,6 +600,310 @@ def _get_valid_regex_pattern(patterns: List[str]):
|
||||
return None
|
||||
|
||||
|
||||
def filter_files_by_patterns(
    repo_files: List[dict],
    *,
    allow_file_pattern: Optional[List[str]] = None,
    ignore_file_pattern: Optional[List[str]] = None,
    allow_patterns: Optional[List[str]] = None,
    ignore_patterns: Optional[List[str]] = None,
) -> List[dict]:
    """Select the file entries that survive the include/exclude patterns.

    Entries whose 'Type' is 'tree' are always dropped. A remaining entry is
    kept only if it matches none of the exclude patterns and, for each
    non-empty include list, matches at least one pattern of that list.

    Args:
        repo_files: File entry dicts with 'Type', 'Path' and 'Name' keys.
        allow_file_pattern: Include globs (fnmatch) tested against 'Path'.
        ignore_file_pattern: Exclude globs (fnmatch) tested against 'Path';
            the regex patterns derived from them are searched in 'Name'.
        allow_patterns: Additional include globs (HF-compatible).
        ignore_patterns: Additional exclude globs (HF-compatible).

    Returns:
        The entries that pass every filter, in their original order.
    """
    deny_globs = _normalize_patterns(ignore_patterns)
    allow_globs = _normalize_patterns(allow_patterns)
    deny_file_globs = _normalize_patterns(ignore_file_pattern)
    allow_file_globs = _normalize_patterns(allow_file_pattern)
    deny_regexes = _get_valid_regex_pattern(deny_file_globs)

    def _glob_hit(entry, globs):
        """Return True when the entry's 'Path' matches any of the globs."""
        return any(fnmatch.fnmatch(entry['Path'], glob) for glob in globs)

    def _is_rejected(entry):
        """Return True when any exclude rule hits or an include rule misses."""
        if deny_globs and _glob_hit(entry, deny_globs):
            return True
        if deny_file_globs and _glob_hit(entry, deny_file_globs):
            return True
        if deny_regexes and any(
                re.search(regex, entry['Name']) is not None
                for regex in deny_regexes):
            return True
        if allow_globs and not _glob_hit(entry, allow_globs):
            return True
        return bool(allow_file_globs) and not _glob_hit(
            entry, allow_file_globs)

    kept = []
    for entry in repo_files:
        if entry['Type'] == 'tree':
            continue
        try:
            rejected = _is_rejected(entry)
        except Exception as e:
            # An entry that cannot be evaluated is dropped, not downloaded.
            logger.warning('Invalid file pattern: %s' % e)
            continue
        if not rejected:
            kept.append(entry)

    return kept
|
||||
|
||||
|
||||
def _iter_dataset_file_pages(
    _api,
    repo_id,
    revision,
    endpoint,
    token=None,
    root_path='/',
    allow_file_pattern=None,
    ignore_file_pattern=None,
    allow_patterns=None,
    ignore_patterns=None,
    page_size=DEFAULT_DATASET_PAGE_SIZE,
):
    """Lazily list a dataset repo, one filtered API page at a time.

    Every yielded item is a non-empty list of file-entry dicts taken from a
    single API page after client-side filtering. Pages with no surviving
    entry are skipped. If listing under ``root_path`` yields nothing and
    ``root_path`` is not '/', the whole repo is listed from '/' instead.
    Progress is written to stdout with ``print``.

    Args:
        _api: HubApi instance.
        repo_id: Dataset repo identifier (owner/name).
        revision: Git revision.
        endpoint: API endpoint URL.
        token: Authentication token.
        root_path: Server-side directory prefix filter.
        allow_file_pattern: Include patterns (fnmatch).
        ignore_file_pattern: Exclude patterns (fnmatch).
        allow_patterns: Additional include patterns (HF-compatible).
        ignore_patterns: Additional exclude patterns (HF-compatible).
        page_size: Number of files per API page request.

    Yields:
        List[dict]: Non-empty list of filtered file entries per page.

    Raises:
        InvalidParameter: If ``repo_id`` does not contain a '/'. As this is
            a generator, the error surfaces on the first iteration step.
    """
    if '/' not in repo_id:
        raise InvalidParameter(
            f"Invalid repo_id: '{repo_id}', expected format 'owner/name'")

    namespace, dataset_name = repo_id.split('/', 1)
    hub_id, _ = _api.get_dataset_id_and_type(
        dataset_name=dataset_name,
        namespace=namespace,
        endpoint=endpoint,
        token=token)

    filtering_requested = any((
        allow_file_pattern,
        ignore_file_pattern,
        allow_patterns,
        ignore_patterns,
    ))

    def _select(entries):
        """Apply the pattern filters, or just drop 'tree' entries."""
        if filtering_requested:
            return filter_files_by_patterns(
                entries,
                allow_file_pattern=allow_file_pattern,
                ignore_file_pattern=ignore_file_pattern,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns)
        return [entry for entry in entries if entry.get('Type') != 'tree']

    def _walk(listing_root):
        """Yield the filtered pages found below ``listing_root``."""
        matched = 0
        page_no = 1

        while True:
            try:
                raw_page = _api.get_dataset_files(
                    repo_id=repo_id,
                    revision=revision,
                    root_path=listing_root,
                    recursive=True,
                    page_number=page_no,
                    page_size=page_size,
                    endpoint=endpoint,
                    token=token,
                    dataset_hub_id=hub_id)
            except Exception as e:
                # Best effort: a failed request ends the listing early.
                logger.error(
                    f'Error fetching dataset files (page {page_no}): {e}')
                return

            if not raw_page:
                return

            selected = _select(raw_page)
            matched += len(selected)
            if selected:
                yield selected

            print(
                f'\r Fetched {matched} matching files '
                f'({page_no} pages)...',
                end='',
                flush=True)

            # A short page means the server has nothing further to return.
            if len(raw_page) < page_size:
                return

            page_no += 1

    try:
        produced = False
        for page in _walk(root_path):
            produced = True
            yield page

        # The narrowed root found nothing: list the entire repo instead.
        if not produced and root_path != '/':
            print(f"\n root_path='{root_path}' returned no results, "
                  f"falling back to root_path='/' for full listing.")
            yield from _walk('/')
    finally:
        # Close the '\r' progress line however the iteration ends.
        print()
|
||||
|
||||
|
||||
def _pipeline_download_dataset(
    file_page_iter,
    cache,
    temporary_cache_dir,
    repo_id,
    api,
    dataset_name,
    namespace,
    headers,
    revision,
    cookies,
    max_workers=DEFAULT_MAX_WORKERS,
    endpoint=None,
    progress_callbacks=None,
):
    """Download dataset files while the file listing is still being paged.

    Each entry of each yielded page is either counted as cached or submitted
    to a thread pool as a ``download_file`` task. A tqdm bar tracks progress
    and its total grows as entries are discovered. Failures are collected,
    logged once all tasks have finished, and then raised as one error.

    Args:
        file_page_iter: Iterator yielding List[dict] file pages.
        cache: Cache object queried with ``exists`` and passed to
            ``download_file``.
        temporary_cache_dir: Temp staging directory.
        repo_id: Dataset repo identifier (not read by this function).
        api: HubApi instance used to build each file's download URL.
        dataset_name: Dataset name component.
        namespace: Owner/namespace component.
        headers: HTTP request headers.
        revision: Git revision.
        cookies: HTTP cookies.
        max_workers: Thread pool concurrency.
        endpoint: API endpoint URL.
        progress_callbacks: Optional progress callback list.

    Raises:
        FileDownloadError: If at least one file failed to download.
    """
    failures = []
    failures_guard = threading.Lock()
    seen = 0
    already_cached = 0

    def _completion_hook(entry):
        """Build the done-callback that records the outcome for ``entry``."""

        def _finish(done_future):
            try:
                done_future.result()
            except Exception as exc:
                with failures_guard:
                    failures.append((entry, exc))
                logger.debug(
                    f"Download failed for {entry.get('Path', '?')}: {exc}")
            finally:
                progress.update(1)

        return _finish

    # The pool is the inner context, so it is shut down (waiting for every
    # task and its callback) before the progress bar is closed.
    with tqdm(total=0, unit=' files', disable=False) as progress, \
            ThreadPoolExecutor(max_workers=max_workers) as pool:
        for page in file_page_iter:
            for entry in page:
                seen += 1
                progress.total = seen
                progress.refresh()

                # Cached files count as finished without a download task.
                if cache.exists(entry):
                    already_cached += 1
                    progress.update(1)
                    continue

                file_url = api.get_dataset_file_url(
                    file_name=entry['Path'],
                    dataset_name=dataset_name,
                    namespace=namespace,
                    revision=revision,
                    endpoint=endpoint)

                task = pool.submit(
                    download_file,
                    file_url,
                    entry,
                    temporary_cache_dir,
                    cache,
                    headers,
                    cookies,
                    disable_tqdm=False,
                    progress_callbacks=progress_callbacks,
                )
                task.add_done_callback(_completion_hook(entry))

    failure_count = len(failures)

    # Report failures after the progress bar has closed.
    if failures:
        failed_paths = []
        for item, _ in failures:
            if isinstance(item, dict):
                failed_paths.append(item.get('Path', '?'))
            else:
                failed_paths.append(str(item))
        logger.error(f'{failure_count} file(s) failed to download:\n'
                     + '\n'.join(f' - {p}' for p in failed_paths))

    # The summary is printed even when an error is raised right after.
    downloaded = seen - already_cached - failure_count
    failure_note = f', {failure_count} failed' if failures else ''
    print(f'Download complete: {seen} files found, '
          f'{already_cached} cached, {downloaded} downloaded'
          + failure_note + '.')

    if failures:
        raise FileDownloadError(
            f'{failure_count} file(s) failed to download out of '
            f'{seen}.')
|
||||
|
||||
|
||||
def _download_file_lists(
|
||||
repo_files: List[str],
|
||||
cache: ModelFileSystemCache,
|
||||
@@ -513,62 +923,77 @@ def _download_file_lists(
|
||||
max_workers: int = 8,
|
||||
endpoint: Optional[str] = None,
|
||||
progress_callbacks: List[Type[ProgressCallback]] = None,
|
||||
pre_filtered: bool = False,
|
||||
):
|
||||
ignore_patterns = _normalize_patterns(ignore_patterns)
|
||||
allow_patterns = _normalize_patterns(allow_patterns)
|
||||
ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
|
||||
allow_file_pattern = _normalize_patterns(allow_file_pattern)
|
||||
# to compatible regex usage.
|
||||
ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)
|
||||
|
||||
filtered_repo_files = []
|
||||
for repo_file in repo_files:
|
||||
if repo_file['Type'] == 'tree':
|
||||
continue
|
||||
try:
|
||||
# processing patterns
|
||||
if ignore_patterns and any([
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in ignore_patterns
|
||||
]):
|
||||
continue
|
||||
|
||||
if ignore_file_pattern and any([
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in ignore_file_pattern
|
||||
]):
|
||||
continue
|
||||
|
||||
if ignore_regex_pattern and any([
|
||||
re.search(pattern, repo_file['Name']) is not None
|
||||
for pattern in ignore_regex_pattern
|
||||
]): # noqa E501
|
||||
continue
|
||||
|
||||
if allow_patterns is not None and allow_patterns:
|
||||
if not any(
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in allow_patterns):
|
||||
continue
|
||||
|
||||
if allow_file_pattern is not None and allow_file_pattern:
|
||||
if not any(
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in allow_file_pattern):
|
||||
continue
|
||||
# check model_file is exist in cache, if existed, skip download
|
||||
if pre_filtered:
|
||||
# Files are already filtered by patterns; only check cache
|
||||
filtered_repo_files = []
|
||||
for repo_file in repo_files:
|
||||
if cache.exists(repo_file):
|
||||
file_name = os.path.basename(repo_file['Name'])
|
||||
logger.debug(
|
||||
f'File {file_name} already in cache with identical hash, skip downloading!'
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning('The file pattern is invalid : %s' % e)
|
||||
else:
|
||||
filtered_repo_files.append(repo_file)
|
||||
else:
|
||||
# Legacy path: apply pattern filtering + cache check
|
||||
ignore_patterns = _normalize_patterns(ignore_patterns)
|
||||
allow_patterns = _normalize_patterns(allow_patterns)
|
||||
ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
|
||||
allow_file_pattern = _normalize_patterns(allow_file_pattern)
|
||||
# Also derive regex patterns, for compatibility with regex-style usage.
|
||||
ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)
|
||||
|
||||
@thread_executor(max_workers=max_workers, disable_tqdm=False)
|
||||
filtered_repo_files = []
|
||||
for repo_file in repo_files:
|
||||
if repo_file['Type'] == 'tree':
|
||||
continue
|
||||
try:
|
||||
# processing patterns
|
||||
if ignore_patterns and any([
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in ignore_patterns
|
||||
]):
|
||||
continue
|
||||
|
||||
if ignore_file_pattern and any([
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in ignore_file_pattern
|
||||
]):
|
||||
continue
|
||||
|
||||
if ignore_regex_pattern and any([
|
||||
re.search(pattern, repo_file['Name']) is not None
|
||||
for pattern in ignore_regex_pattern
|
||||
]): # noqa E501
|
||||
continue
|
||||
|
||||
if allow_patterns is not None and allow_patterns:
|
||||
if not any(
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in allow_patterns):
|
||||
continue
|
||||
|
||||
if allow_file_pattern is not None and allow_file_pattern:
|
||||
if not any(
|
||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||
for pattern in allow_file_pattern):
|
||||
continue
|
||||
# Skip the download if the file already exists in the cache.
|
||||
if cache.exists(repo_file):
|
||||
file_name = os.path.basename(repo_file['Name'])
|
||||
logger.debug(
|
||||
f'File {file_name} already in cache with identical hash, skip downloading!'
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning('The file pattern is invalid : %s' % e)
|
||||
else:
|
||||
filtered_repo_files.append(repo_file)
|
||||
|
||||
@thread_executor(
|
||||
max_workers=max_workers, disable_tqdm=False, fault_tolerant=True)
|
||||
def _download_single_file(repo_file):
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
url = get_file_download_url(
|
||||
@@ -602,5 +1027,26 @@ def _download_file_lists(
|
||||
if len(filtered_repo_files) > 0:
|
||||
logger.info(
|
||||
f'Got {len(filtered_repo_files)} files, start to download ...')
|
||||
_download_single_file(filtered_repo_files)
|
||||
logger.info(f"Download {repo_type} '{repo_id}' successfully.")
|
||||
download_result = _download_single_file(filtered_repo_files)
|
||||
|
||||
# Handle fault-tolerant results: report failed downloads
|
||||
failed_items = []
|
||||
if isinstance(download_result, tuple) and len(download_result) == 2:
|
||||
_, failed_items = download_result
|
||||
if failed_items:
|
||||
failed_paths = [
|
||||
item['Path'] if isinstance(item, dict) else str(item)
|
||||
for item, _ in failed_items
|
||||
]
|
||||
logger.error(
|
||||
f'{len(failed_items)} file(s) failed to download:\n'
|
||||
+ '\n'.join(f' - {p}' for p in failed_paths))
|
||||
|
||||
logger.info(
|
||||
f"Finish downloading {len(filtered_repo_files)} files for repo '{repo_id}'"
|
||||
)
|
||||
|
||||
if failed_items:
|
||||
raise FileDownloadError(
|
||||
f'{len(failed_items)} file(s) failed to download out of '
|
||||
f'{len(filtered_repo_files)}.')
|
||||
|
||||
@@ -66,6 +66,83 @@ def convert_patterns(raw_input: Union[str, List[str]]):
|
||||
return output
|
||||
|
||||
|
||||
def _literal_dir_prefix(pattern: str) -> Optional[str]:
    """Return the wildcard-free directory in front of ``pattern``.

    The text before the first wildcard character (``*``, ``?`` or ``[``) is
    cut at its last ``/``; what is left is a directory every path matched by
    the pattern must start with. ``None`` is returned when there is no such
    directory (no ``/`` before the first wildcard, or only a leading ``/``).
    """
    wildcard_positions = [
        pos for pos in (pattern.find(wc) for wc in ('*', '?', '['))
        if pos != -1
    ]
    literal = pattern[:min(wildcard_positions)] if wildcard_positions else pattern
    last_sep = literal.rfind('/')
    # A separator at index 0 (or none at all) leaves no directory name.
    return literal[:last_sep] if last_sep > 0 else None


def _common_dir_prefix(directories: List[str]) -> Optional[str]:
    """Return the longest path-segment-aware prefix shared by ``directories``.

    Comparison is done segment by segment (split on ``/``), so ``'data1'`` and
    ``'data2'`` share nothing even though they share characters. Returns
    ``None`` when there is no shared, non-empty prefix.
    """
    common_segments = []
    for segments in zip(*(d.split('/') for d in directories)):
        if len(set(segments)) != 1:
            break
        common_segments.append(segments[0])
    # An empty join result (e.g. only an empty leading segment is shared)
    # is not a usable root, so report it as "no prefix".
    return '/'.join(common_segments) or None


def extract_root_from_patterns(
    allow_file_pattern: Optional[List[str]] = None,
    allow_patterns: Optional[List[str]] = None,
) -> Optional[str]:
    """Extract a common directory prefix from include patterns for server-side filtering.

    Only allow/include patterns are considered (ignore/exclude patterns are
    irrelevant for a prefix). Returns None if no meaningful prefix can be
    extracted.

    Algorithm:
        1. For each non-empty pattern list, take the literal directory in
           front of the first wildcard char (* ? [) of every pattern.
        2. If any pattern of a list has no such directory (e.g. '*.json' or
           'config.json'), that list can match files outside every directory,
           so the list contributes no root. Patterns inside one list are
           alternatives, hence dropping only the offending pattern would
           yield a root that hides files the pattern is meant to include.
        3. Otherwise the list's root is the longest common directory prefix
           (path-segment-aware) of its patterns.
        4. The result is the common prefix of the per-list roots. A file has
           to satisfy every supplied list, so a root coming from only one
           list is still a valid restriction.

    Args:
        allow_file_pattern: Include patterns (legacy argument name), or None.
        allow_patterns: Include patterns, or None.

    Returns:
        The common directory prefix without trailing '/', or None.

    Examples:
        ['TacExo/*'] -> 'TacExo'
        ['data/train/*.parquet'] -> 'data/train'
        ['data/train/*', 'data/valid/*'] -> 'data'
        ['*.safetensors'] -> None
        ['TacExo/*', 'OtherDir/*'] -> None (no common prefix)
        ['data/*/train.csv'] -> 'data'
        ['data/*', 'config.json'] -> None (second pattern lives at the root)
    """
    list_roots = []
    for pattern_list in (allow_file_pattern, allow_patterns):
        if not pattern_list:
            continue
        directories = [_literal_dir_prefix(p) for p in pattern_list]
        if any(d is None for d in directories):
            # At least one alternative is not confined to a directory.
            continue
        list_root = _common_dir_prefix(directories)
        if list_root is not None:
            list_roots.append(list_root)

    if not list_roots:
        return None
    return _common_dir_prefix(list_roots)
|
||||
|
||||
|
||||
# during model download, the '.' would be converted to '___' to produce
|
||||
# actual physical (masked) directory for storage
|
||||
def get_model_masked_directory(directory, model_id):
|
||||
@@ -134,7 +211,8 @@ def get_endpoint(cn_site=True):
|
||||
|
||||
|
||||
def compute_hash(file_path):
|
||||
BUFFER_SIZE = 1024 * 64 # 64k buffer size
|
||||
# 16MB buffer for large file hash computation
|
||||
BUFFER_SIZE = 1024 * 1024 * 16
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, 'rb') as f:
|
||||
while True:
|
||||
|
||||
Reference in New Issue
Block a user