diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 58de50af..0da55ffc 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -36,6 +36,8 @@ run-name: Docker-${{ inputs.modelscope_branch }}-${{ inputs.image_type }}-${{ in jobs: build: runs-on: [modelscope-self-hosted-us] + env: + FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: 'true' steps: - name: ResetFileMode @@ -47,7 +49,7 @@ jobs: source ~/.bashrc sudo chown -R $USER:$USER $ACTION_RUNNER_DIR - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: ref: ${{ github.event.inputs.modelscope_branch }} diff --git a/docker/Dockerfile.extra_install b/docker/Dockerfile.extra_install index f8070bd3..8068c0b8 100644 --- a/docker/Dockerfile.extra_install +++ b/docker/Dockerfile.extra_install @@ -112,9 +112,10 @@ RUN set -eux; \ done # if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value ''" -ENV PYTHON_PIP_VERSION 23.0.1 +# pip>=23.3 required for Python 3.12 (pkgutil.ImpImporter removed; older pip crashes on any install) +ENV PYTHON_PIP_VERSION 24.3.1 # https://github.com/docker-library/python/issues/365 -ENV PYTHON_SETUPTOOLS_VERSION 65.5.1 +ENV PYTHON_SETUPTOOLS_VERSION 75.8.2 # https://github.com/pypa/get-pip ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py ENV PYTHON_GET_PIP_SHA256 dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9 diff --git a/docker/Dockerfile.ubuntu_base b/docker/Dockerfile.ubuntu_base index f56be6f7..b3a3ee33 100644 --- a/docker/Dockerfile.ubuntu_base +++ b/docker/Dockerfile.ubuntu_base @@ -112,9 +112,10 @@ RUN set -eux; \ done # if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value ''" -ENV PYTHON_PIP_VERSION 23.0.1 +# pip>=23.3 required for Python 3.12 (pkgutil.ImpImporter removed; older pip crashes on any install) +ENV PYTHON_PIP_VERSION 24.3.1 # https://github.com/docker-library/python/issues/365 -ENV PYTHON_SETUPTOOLS_VERSION 65.5.1 +ENV PYTHON_SETUPTOOLS_VERSION 75.8.2 # https://github.com/pypa/get-pip ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py ENV PYTHON_GET_PIP_SHA256 dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9 diff --git a/docker/build_image.py b/docker/build_image.py index 4a8b0c48..37d35787 100644 --- a/docker/build_image.py +++ b/docker/build_image.py @@ -54,6 +54,13 @@ class Builder: def generate_dockerfile(self) -> str: raise NotImplementedError + @staticmethod + def _remove_pynini_related_dependency(content: str) -> str: + return content.replace( + 'pip install --no-cache-dir funtextprocessing typeguard==2.13.3 scikit-learn -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html &&', # noqa: E501 + 'pip install --no-cache-dir typeguard==2.13.3 scikit-learn -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html &&' # noqa: E501 + ) + def _save_dockerfile(self, content: str) -> None: if os.path.exists('./Dockerfile'): os.remove('./Dockerfile') @@ -302,7 +309,7 @@ class StableCPUImageBuilder(Builder): content = content.replace('{modelscope_branch}', self.args.modelscope_branch) content = content.replace('{swift_branch}', self.args.swift_branch) - return content + return self._remove_pynini_related_dependency(content) def image(self) -> str: return ( @@ -363,7 +370,7 @@ RUN pip install --no-cache-dir -U icecream soundfile pybind11 py-spy content = content.replace('{modelscope_branch}', self.args.modelscope_branch) content = content.replace('{swift_branch}', self.args.swift_branch) - return content + return self._remove_pynini_related_dependency(content) def image(self) -> str: return ( diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index d8dc545f..41f4ebeb 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -39,6 +39,9 @@ from .utils.utils import (file_integrity_validation, get_endpoint, logger = get_logger() +# Maximum number of retries for hash validation failures +HASH_RETRY_TIMES = 3 + def model_file_download( model_id: str, @@ -711,49 +714,66 @@ def download_file( disable_tqdm=False, progress_callbacks: List[Type[ProgressCallback]] = None, ): - if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[ - 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file. - file_digest = parallel_download( - url, - temporary_cache_dir, - file_meta['Path'], - headers=headers, - cookies=None if cookies is None else cookies.get_dict(), - file_size=file_meta['Size'], - disable_tqdm=disable_tqdm, - progress_callbacks=progress_callbacks, - ) - else: - file_digest = http_get_model_file( - url, - temporary_cache_dir, - file_meta['Path'], - file_size=file_meta['Size'], - headers=headers, - cookies=cookies, - disable_tqdm=disable_tqdm, - progress_callbacks=progress_callbacks, - ) - - # check file integrity temp_file = os.path.join(temporary_cache_dir, file_meta['Path']) - if FILE_HASH in file_meta: - expected_hash = file_meta[FILE_HASH] - if file_digest is not None: - if file_digest != expected_hash: - logger.warning( - 'Mismatched real-time digest for %s, falling back to full hash check', - file_meta['Path']) - if not file_integrity_validation(temp_file, expected_hash): - raise FileDownloadError( - 'File %s hash validation failed after download, ' - 'the file may be corrupted. Please retry.' - % file_meta['Path']) + + for hash_attempt in range(HASH_RETRY_TIMES): + if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[ + 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file. + file_digest = parallel_download( + url, + temporary_cache_dir, + file_meta['Path'], + headers=headers, + cookies=None if cookies is None else cookies.get_dict(), + file_size=file_meta['Size'], + disable_tqdm=disable_tqdm, + progress_callbacks=progress_callbacks, + ) else: - if not file_integrity_validation(temp_file, expected_hash): - raise FileDownloadError( - 'File %s hash validation failed after download, ' - 'the file may be corrupted. Please retry.' - % file_meta['Path']) - # put file into to cache + file_digest = http_get_model_file( + url, + temporary_cache_dir, + file_meta['Path'], + file_size=file_meta['Size'], + headers=headers, + cookies=cookies, + disable_tqdm=disable_tqdm, + progress_callbacks=progress_callbacks, + ) + + # Check file integrity + if FILE_HASH in file_meta: + expected_hash = file_meta[FILE_HASH] + hash_valid = True + if file_digest is not None: + if file_digest != expected_hash: + logger.warning( + 'Mismatched real-time digest for %s, falling back to full hash check', + file_meta['Path']) + if not file_integrity_validation(temp_file, expected_hash): + hash_valid = False + else: + if not file_integrity_validation(temp_file, expected_hash): + hash_valid = False + + if not hash_valid: + if hash_attempt < HASH_RETRY_TIMES - 1: + logger.warning( + 'Hash validation failed for %s, ' + 'retrying download (attempt %d/%d)', file_meta['Path'], + hash_attempt + 1, HASH_RETRY_TIMES) + # Clean up corrupted file before retry + if os.path.exists(temp_file): + os.remove(temp_file) + continue + else: + raise FileDownloadError( + 'File %s hash validation failed after %d attempts, ' + 'the file may be corrupted.' % + (file_meta['Path'], HASH_RETRY_TIMES)) + + # Hash validation passed or no hash to validate, exit retry loop + break + + # Put file into cache return cache.put_file(file_meta, temp_file) diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 635acd9c..28782785 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -3,12 +3,16 @@ import fnmatch import os import re +import threading import uuid +from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext from http.cookiejar import CookieJar from pathlib import Path from typing import Dict, List, Optional, Type, Union +from tqdm.auto import tqdm + from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DEFAULT_MODEL_REVISION, INTRA_CLOUD_ACCELERATION, @@ -20,16 +24,19 @@ from modelscope.utils.thread_utils import thread_executor from .api import HubApi, ModelScopeConfig from .callback import ProgressCallback from .constants import DEFAULT_MAX_WORKERS -from .errors import InvalidParameter +from .errors import FileDownloadError, InvalidParameter from .file_download import (create_temporary_directory_and_cache, download_file, get_file_download_url) from .utils.caching import ModelFileSystemCache -from .utils.utils import (get_model_masked_directory, +from .utils.utils import (extract_root_from_patterns, + get_model_masked_directory, model_id_to_group_owner_name, strtobool, weak_file_lock) logger = get_logger() +DEFAULT_DATASET_PAGE_SIZE = 200 + def snapshot_download( model_id: str = None, @@ -335,13 +342,42 @@ def _snapshot_download( snapshot_header[ 'cached_model_revision'] = cache.cached_model_revision + # Extract server-side root filter from include patterns + extracted_root = extract_root_from_patterns( + allow_file_pattern=_normalize_patterns(allow_file_pattern), + allow_patterns=_normalize_patterns(allow_patterns)) + repo_files = _api.get_model_files( model_id=repo_id, revision=revision, + root=extracted_root, recursive=True, use_cookies=False if cookies is None else cookies, headers=snapshot_header, endpoint=endpoint) + + # Fallback: if root filter yielded no results, retry without it + if not repo_files and extracted_root is not None: + logger.warning( + f"root='{extracted_root}' returned no model files, " + f'falling back to root=None for full listing.') + repo_files = _api.get_model_files( + model_id=repo_id, + revision=revision, + root=None, + recursive=True, + use_cookies=False if cookies is None else cookies, + headers=snapshot_header, + endpoint=endpoint) + + # Apply client-side pattern filtering + repo_files = filter_files_by_patterns( + repo_files, + allow_file_pattern=allow_file_pattern, + ignore_file_pattern=ignore_file_pattern, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns) + _download_file_lists( repo_files, cache, @@ -354,10 +390,7 @@ def _snapshot_download( repo_type=repo_type, revision=revision, cookies=cookies, - ignore_file_pattern=ignore_file_pattern, - allow_file_pattern=allow_file_pattern, - ignore_patterns=ignore_patterns, - allow_patterns=allow_patterns, + pre_filtered=True, max_workers=max_workers, endpoint=endpoint, progress_callbacks=progress_callbacks, @@ -392,75 +425,148 @@ def _snapshot_download( group_or_owner, name = model_id_to_group_owner_name(repo_id) revision_detail = revision or DEFAULT_DATASET_REVISION - logger.info('Fetching dataset repo file list...') - repo_files = fetch_repo_files( - _api, repo_id, revision_detail, endpoint, token=token) + # Extract server-side root filter from include patterns + extracted_root = extract_root_from_patterns( + allow_file_pattern=_normalize_patterns(allow_file_pattern), + allow_patterns=_normalize_patterns(allow_patterns)) + root_path = '/' + extracted_root if extracted_root else '/' - if repo_files is None: - logger.error( - f'Failed to retrieve file list for dataset: {repo_id}') - return None - - _download_file_lists( - repo_files, - cache, - temporary_cache_dir, - repo_id, + print(f'Fetching file list (root: {root_path})...') + file_page_iter = _iter_dataset_file_pages( _api, - name, - group_or_owner, - headers, - repo_type=repo_type, + repo_id, + revision_detail, + endpoint, + token=token, + root_path=root_path, + allow_file_pattern=allow_file_pattern, + ignore_file_pattern=ignore_file_pattern, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns) + + _pipeline_download_dataset( + file_page_iter, + cache=cache, + temporary_cache_dir=temporary_cache_dir, + repo_id=repo_id, + api=_api, + dataset_name=name, + namespace=group_or_owner, + headers=headers, revision=revision, cookies=cookies, - ignore_file_pattern=ignore_file_pattern, - allow_file_pattern=allow_file_pattern, - ignore_patterns=ignore_patterns, - allow_patterns=allow_patterns, max_workers=max_workers, endpoint=endpoint, - progress_callbacks=progress_callbacks, - ) + progress_callbacks=progress_callbacks) cache.save_model_version(revision_info=revision_detail) cache_root_path = cache.get_root_location() return cache_root_path -def fetch_repo_files(_api, repo_id, revision, endpoint, token=None): - _owner, _dataset_name = repo_id.split('/') +def fetch_repo_files( + _api, + repo_id, + revision, + endpoint, + token=None, + root_path='/', + allow_file_pattern=None, + ignore_file_pattern=None, + allow_patterns=None, + ignore_patterns=None, + page_size=DEFAULT_DATASET_PAGE_SIZE, +): + """Fetch and filter dataset repo files with pagination and server-side prefix filtering. + + Applies per-page pattern filtering to minimize memory usage. + Falls back to root_path='/' if the extracted prefix yields no results. + + Args: + _api: HubApi instance. + repo_id: Dataset repo identifier (owner/name). + revision: Git revision. + endpoint: API endpoint URL. + token: Authentication token. + root_path: Server-side directory prefix filter. + allow_file_pattern: Include patterns for client-side filtering. + ignore_file_pattern: Exclude patterns for client-side filtering. + allow_patterns: Additional include patterns (HF-compatible). + ignore_patterns: Additional exclude patterns (HF-compatible). + page_size: Number of files per API page request. + + Returns: + List of filtered file entry dicts. + """ + if '/' not in repo_id: + raise InvalidParameter( + f"Invalid repo_id: '{repo_id}', expected format 'owner/name'") + _owner, _dataset_name = repo_id.split('/', 1) _hub_id, _ = _api.get_dataset_id_and_type( dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint, token=token) - page_number = 1 - page_size = 150 - repo_files = [] + has_patterns = any([ + allow_file_pattern, ignore_file_pattern, allow_patterns, + ignore_patterns + ]) - while True: - try: - dataset_files = _api.get_dataset_files( - repo_id=repo_id, - revision=revision, - root_path='/', - recursive=True, - page_number=page_number, - page_size=page_size, - endpoint=endpoint, - token=token, - dataset_hub_id=_hub_id) - except Exception as e: - logger.error(f'Error fetching dataset files: {e}') - break + def _paginate_and_filter(effective_root_path): + """Fetch all pages with the given root_path, applying per-page filtering.""" + page_number = 1 + repo_files = [] - repo_files.extend(dataset_files) + while True: + try: + dataset_files = _api.get_dataset_files( + repo_id=repo_id, + revision=revision, + root_path=effective_root_path, + recursive=True, + page_number=page_number, + page_size=page_size, + endpoint=endpoint, + token=token, + dataset_hub_id=_hub_id) + except Exception as e: + logger.error( + f'Error fetching dataset files (page {page_number}): {e}') + break - if len(dataset_files) < page_size: - break + if not dataset_files: + break - page_number += 1 + # Per-page filtering: apply patterns immediately to reduce memory + if has_patterns: + page_filtered = filter_files_by_patterns( + dataset_files, + allow_file_pattern=allow_file_pattern, + ignore_file_pattern=ignore_file_pattern, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns) + repo_files.extend(page_filtered) + else: + # No patterns: keep all non-tree entries + repo_files.extend( + f for f in dataset_files if f.get('Type') != 'tree') + + if len(dataset_files) < page_size: + break + + page_number += 1 + + return repo_files + + # Primary fetch with optimized root_path + repo_files = _paginate_and_filter(root_path) + + # Fallback: if optimized root_path yielded nothing and it's not the default + if not repo_files and root_path != '/': + logger.warning(f"root_path='{root_path}' returned no results, " + f"falling back to root_path='/' for full listing.") + repo_files = _paginate_and_filter('/') return repo_files @@ -494,6 +600,310 @@ def _get_valid_regex_pattern(patterns: List[str]): return None +def filter_files_by_patterns( + repo_files: List[dict], + *, + allow_file_pattern: Optional[List[str]] = None, + ignore_file_pattern: Optional[List[str]] = None, + allow_patterns: Optional[List[str]] = None, + ignore_patterns: Optional[List[str]] = None, +) -> List[dict]: + """Filter repo file entries by include/exclude patterns. + + Skips 'tree' type entries. Applies fnmatch and regex pattern matching. + Returns only file entries that pass all filter criteria. + + Args: + repo_files: List of file entry dicts with 'Type', 'Path', 'Name' keys. + allow_file_pattern: Include patterns (fnmatch). Files must match at least one. + ignore_file_pattern: Exclude patterns (fnmatch). Matching files are skipped. + allow_patterns: Additional include patterns (HF-compatible). + ignore_patterns: Additional exclude patterns (HF-compatible). + + Returns: + List of file entries that pass all filters. + """ + ignore_patterns = _normalize_patterns(ignore_patterns) + allow_patterns = _normalize_patterns(allow_patterns) + ignore_file_pattern = _normalize_patterns(ignore_file_pattern) + allow_file_pattern = _normalize_patterns(allow_file_pattern) + ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern) + + filtered = [] + for repo_file in repo_files: + if repo_file['Type'] == 'tree': + continue + try: + if ignore_patterns and any( + fnmatch.fnmatch(repo_file['Path'], p) + for p in ignore_patterns): + continue + + if ignore_file_pattern and any( + fnmatch.fnmatch(repo_file['Path'], p) + for p in ignore_file_pattern): + continue + + if ignore_regex_pattern and any( + re.search(p, repo_file['Name']) is not None + for p in ignore_regex_pattern): + continue + + if allow_patterns and not any( + fnmatch.fnmatch(repo_file['Path'], p) + for p in allow_patterns): + continue + + if allow_file_pattern and not any( + fnmatch.fnmatch(repo_file['Path'], p) + for p in allow_file_pattern): + continue + except Exception as e: + logger.warning('Invalid file pattern: %s' % e) + continue + + filtered.append(repo_file) + + return filtered + + +def _iter_dataset_file_pages( + _api, + repo_id, + revision, + endpoint, + token=None, + root_path='/', + allow_file_pattern=None, + ignore_file_pattern=None, + allow_patterns=None, + ignore_patterns=None, + page_size=DEFAULT_DATASET_PAGE_SIZE, +): + """Generator that yields filtered file pages from a dataset repo. + + Each yield is a non-empty list of file-entry dicts for one API page. + Applies per-page pattern filtering to minimize memory usage. + Falls back to root_path='/' if the extracted prefix yields no results. + + Args: + _api: HubApi instance. + repo_id: Dataset repo identifier (owner/name). + revision: Git revision. + endpoint: API endpoint URL. + token: Authentication token. + root_path: Server-side directory prefix filter. + allow_file_pattern: Include patterns (fnmatch). + ignore_file_pattern: Exclude patterns (fnmatch). + allow_patterns: Additional include patterns (HF-compatible). + ignore_patterns: Additional exclude patterns (HF-compatible). + page_size: Number of files per API page request. + + Yields: + List[dict]: Non-empty list of filtered file entries per page. + """ + if '/' not in repo_id: + raise InvalidParameter( + f"Invalid repo_id: '{repo_id}', expected format 'owner/name'") + + _owner, _dataset_name = repo_id.split('/', 1) + _hub_id, _ = _api.get_dataset_id_and_type( + dataset_name=_dataset_name, + namespace=_owner, + endpoint=endpoint, + token=token) + + has_patterns = any([ + allow_file_pattern, ignore_file_pattern, allow_patterns, + ignore_patterns + ]) + + def _paginate_pages(effective_root_path): + """Yield filtered file pages for the given root_path.""" + page_number = 1 + total_found = 0 + + while True: + try: + dataset_files = _api.get_dataset_files( + repo_id=repo_id, + revision=revision, + root_path=effective_root_path, + recursive=True, + page_number=page_number, + page_size=page_size, + endpoint=endpoint, + token=token, + dataset_hub_id=_hub_id) + except Exception as e: + logger.error( + f'Error fetching dataset files (page {page_number}): {e}') + break + + if not dataset_files: + break + + # Per-page filtering to reduce memory footprint + if has_patterns: + page_filtered = filter_files_by_patterns( + dataset_files, + allow_file_pattern=allow_file_pattern, + ignore_file_pattern=ignore_file_pattern, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns) + else: + # No patterns: keep all non-tree entries + page_filtered = [ + f for f in dataset_files if f.get('Type') != 'tree' + ] + + total_found += len(page_filtered) + if page_filtered: + yield page_filtered + + print( + f'\r Fetched {total_found} matching files ' + f'({page_number} pages)...', + end='', + flush=True) + + if len(dataset_files) < page_size: + break + + page_number += 1 + + # Primary fetch with optimized root_path + try: + yielded_any = False + for page in _paginate_pages(root_path): + yielded_any = True + yield page + + # Fallback: if optimized root_path yielded nothing and it's not the default + if not yielded_any and root_path != '/': + print(f"\n root_path='{root_path}' returned no results, " + f"falling back to root_path='/' for full listing.") + for page in _paginate_pages('/'): + yield page + finally: + # Terminate the \r progress line regardless of how iteration ends + print() + + +def _pipeline_download_dataset( + file_page_iter, + cache, + temporary_cache_dir, + repo_id, + api, + dataset_name, + namespace, + headers, + revision, + cookies, + max_workers=DEFAULT_MAX_WORKERS, + endpoint=None, + progress_callbacks=None, +): + """Pipeline consumer: download dataset files as pages are yielded. + + Consumes the page iterator from _iter_dataset_file_pages, submitting + each file to a thread pool for concurrent download. Uses tqdm for + real-time progress and thread-safe error collection. + + Args: + file_page_iter: Iterator yielding List[dict] file pages. + cache: ModelFileSystemCache instance for dedup. + temporary_cache_dir: Temp staging directory. + repo_id: Dataset repo identifier. + api: HubApi instance. + dataset_name: Dataset name component. + namespace: Owner/namespace component. + headers: HTTP request headers. + revision: Git revision. + cookies: HTTP cookies. + max_workers: Thread pool concurrency. + endpoint: API endpoint URL. + progress_callbacks: Optional progress callback list. + """ + total_found = 0 + total_cached = 0 + failed_items = [] + lock = threading.Lock() + + def _on_done(future, repo_file): + """Done callback: update progress bar and collect failures.""" + try: + future.result() + except Exception as exc: + with lock: + failed_items.append((repo_file, exc)) + logger.debug( + f"Download failed for {repo_file.get('Path', '?')}: {exc}") + finally: + pbar.update(1) + + # tqdm wraps the executor so all callbacks fire before pbar closes + with tqdm(total=0, unit=' files', disable=False) as pbar: + with ThreadPoolExecutor(max_workers=max_workers) as executor: + for page_files in file_page_iter: + for repo_file in page_files: + total_found += 1 + pbar.total = total_found + pbar.refresh() + + # Skip files already in cache + if cache.exists(repo_file): + total_cached += 1 + pbar.update(1) + continue + + # Build download URL + url = api.get_dataset_file_url( + file_name=repo_file['Path'], + dataset_name=dataset_name, + namespace=namespace, + revision=revision, + endpoint=endpoint) + + # Submit download task + future = executor.submit( + download_file, + url, + repo_file, + temporary_cache_dir, + cache, + headers, + cookies, + disable_tqdm=False, + progress_callbacks=progress_callbacks, + ) + future.add_done_callback( + lambda f, rf=repo_file: _on_done(f, rf)) + + # Executor __exit__ waits for all futures to complete + + # Report failures after progress bar closes + if failed_items: + failed_paths = [ + item.get('Path', '?') if isinstance(item, dict) else str(item) + for item, _ in failed_items + ] + logger.error(f'{len(failed_items)} file(s) failed to download:\n' + + '\n'.join(f' - {p}' for p in failed_paths)) + + # Completion summary (always print, even if raising after) + downloaded = total_found - total_cached - len(failed_items) + print(f'Download complete: {total_found} files found, ' + f'{total_cached} cached, {downloaded} downloaded' + + (f', {len(failed_items)} failed' if failed_items else '') + '.') + + if failed_items: + raise FileDownloadError( + f'{len(failed_items)} file(s) failed to download out of ' + f'{total_found}.') + + def _download_file_lists( repo_files: List[str], cache: ModelFileSystemCache, @@ -513,62 +923,77 @@ def _download_file_lists( max_workers: int = 8, endpoint: Optional[str] = None, progress_callbacks: List[Type[ProgressCallback]] = None, + pre_filtered: bool = False, ): - ignore_patterns = _normalize_patterns(ignore_patterns) - allow_patterns = _normalize_patterns(allow_patterns) - ignore_file_pattern = _normalize_patterns(ignore_file_pattern) - allow_file_pattern = _normalize_patterns(allow_file_pattern) - # to compatible regex usage. - ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern) - - filtered_repo_files = [] - for repo_file in repo_files: - if repo_file['Type'] == 'tree': - continue - try: - # processing patterns - if ignore_patterns and any([ - fnmatch.fnmatch(repo_file['Path'], pattern) - for pattern in ignore_patterns - ]): - continue - - if ignore_file_pattern and any([ - fnmatch.fnmatch(repo_file['Path'], pattern) - for pattern in ignore_file_pattern - ]): - continue - - if ignore_regex_pattern and any([ - re.search(pattern, repo_file['Name']) is not None - for pattern in ignore_regex_pattern - ]): # noqa E501 - continue - - if allow_patterns is not None and allow_patterns: - if not any( - fnmatch.fnmatch(repo_file['Path'], pattern) - for pattern in allow_patterns): - continue - - if allow_file_pattern is not None and allow_file_pattern: - if not any( - fnmatch.fnmatch(repo_file['Path'], pattern) - for pattern in allow_file_pattern): - continue - # check model_file is exist in cache, if existed, skip download + if pre_filtered: + # Files are already filtered by patterns; only check cache + filtered_repo_files = [] + for repo_file in repo_files: if cache.exists(repo_file): file_name = os.path.basename(repo_file['Name']) logger.debug( f'File {file_name} already in cache with identical hash, skip downloading!' ) continue - except Exception as e: - logger.warning('The file pattern is invalid : %s' % e) - else: filtered_repo_files.append(repo_file) + else: + # Legacy path: apply pattern filtering + cache check + ignore_patterns = _normalize_patterns(ignore_patterns) + allow_patterns = _normalize_patterns(allow_patterns) + ignore_file_pattern = _normalize_patterns(ignore_file_pattern) + allow_file_pattern = _normalize_patterns(allow_file_pattern) + # to compatible regex usage. + ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern) - @thread_executor(max_workers=max_workers, disable_tqdm=False) + filtered_repo_files = [] + for repo_file in repo_files: + if repo_file['Type'] == 'tree': + continue + try: + # processing patterns + if ignore_patterns and any([ + fnmatch.fnmatch(repo_file['Path'], pattern) + for pattern in ignore_patterns + ]): + continue + + if ignore_file_pattern and any([ + fnmatch.fnmatch(repo_file['Path'], pattern) + for pattern in ignore_file_pattern + ]): + continue + + if ignore_regex_pattern and any([ + re.search(pattern, repo_file['Name']) is not None + for pattern in ignore_regex_pattern + ]): # noqa E501 + continue + + if allow_patterns is not None and allow_patterns: + if not any( + fnmatch.fnmatch(repo_file['Path'], pattern) + for pattern in allow_patterns): + continue + + if allow_file_pattern is not None and allow_file_pattern: + if not any( + fnmatch.fnmatch(repo_file['Path'], pattern) + for pattern in allow_file_pattern): + continue + # check model_file is exist in cache, if existed, skip download + if cache.exists(repo_file): + file_name = os.path.basename(repo_file['Name']) + logger.debug( + f'File {file_name} already in cache with identical hash, skip downloading!' + ) + continue + except Exception as e: + logger.warning('The file pattern is invalid : %s' % e) + else: + filtered_repo_files.append(repo_file) + + @thread_executor( + max_workers=max_workers, disable_tqdm=False, fault_tolerant=True) def _download_single_file(repo_file): if repo_type == REPO_TYPE_MODEL: url = get_file_download_url( @@ -602,5 +1027,26 @@ def _download_file_lists( if len(filtered_repo_files) > 0: logger.info( f'Got {len(filtered_repo_files)} files, start to download ...') - _download_single_file(filtered_repo_files) - logger.info(f"Download {repo_type} '{repo_id}' successfully.") + download_result = _download_single_file(filtered_repo_files) + + # Handle fault-tolerant results: report failed downloads + failed_items = [] + if isinstance(download_result, tuple) and len(download_result) == 2: + _, failed_items = download_result + if failed_items: + failed_paths = [ + item['Path'] if isinstance(item, dict) else str(item) + for item, _ in failed_items + ] + logger.error( + f'{len(failed_items)} file(s) failed to download:\n' + + '\n'.join(f' - {p}' for p in failed_paths)) + + logger.info( + f"Finish downloading {len(filtered_repo_files)} files for repo '{repo_id}'" + ) + + if failed_items: + raise FileDownloadError( + f'{len(failed_items)} file(s) failed to download out of ' + f'{len(filtered_repo_files)}.') diff --git a/modelscope/hub/utils/utils.py b/modelscope/hub/utils/utils.py index e7b3c502..d93e8a90 100644 --- a/modelscope/hub/utils/utils.py +++ b/modelscope/hub/utils/utils.py @@ -66,6 +66,83 @@ def convert_patterns(raw_input: Union[str, List[str]]): return output +def extract_root_from_patterns( + allow_file_pattern: Optional[List[str]] = None, + allow_patterns: Optional[List[str]] = None, +) -> Optional[str]: + """Extract common directory prefix from include patterns for server-side filtering. + + Only processes allow/include patterns (ignore/exclude is irrelevant for prefix). + Returns None if no meaningful prefix can be extracted. + + Algorithm: + 1. Merge allow_file_pattern and allow_patterns into one list + 2. For each pattern, find position of first wildcard char (* ? [) + 3. Extract text before that position, then take directory part (up to last '/') + 4. Find longest common directory prefix (path-segment-aware) + 5. Return None if no valid common prefix + + Examples: + ['TacExo/*'] -> 'TacExo' + ['data/train/*.parquet'] -> 'data/train' + ['data/train/*', 'data/valid/*'] -> 'data' + ['*.safetensors'] -> None + ['TacExo/*', 'OtherDir/*'] -> None (no common prefix) + ['data/*/train.csv'] -> 'data' + """ + # Merge both pattern lists + patterns = [] + if allow_file_pattern: + patterns.extend(allow_file_pattern) + if allow_patterns: + patterns.extend(allow_patterns) + + if not patterns: + return None + + extracted_dirs = [] + for pattern in patterns: + # Find position of first wildcard character + first_wildcard = len(pattern) + for wc in ('*', '?', '['): + pos = pattern.find(wc) + if pos != -1: + first_wildcard = min(first_wildcard, pos) + + # Get text before wildcard + prefix = pattern[:first_wildcard] + + # Extract directory part (up to last '/') + last_sep = prefix.rfind('/') + if last_sep > 0: + dir_part = prefix[:last_sep] + # Validate: no wildcards should remain in dir_part + if not any(c in dir_part for c in ('*', '?', '[')): + extracted_dirs.append(dir_part) + # If no '/' found or only at position 0, this pattern has no directory prefix + + if not extracted_dirs: + return None + + # Find longest common directory prefix (path-segment-aware) + if len(extracted_dirs) == 1: + return extracted_dirs[0] + + # Split all paths into segments and find common prefix segments + split_paths = [d.split('/') for d in extracted_dirs] + common_segments = [] + for segments in zip(*split_paths): + if len(set(segments)) == 1: + common_segments.append(segments[0]) + else: + break + + if not common_segments: + return None + + return '/'.join(common_segments) + + # during model download, the '.' would be converted to '___' to produce # actual physical (masked) directory for storage def get_model_masked_directory(directory, model_id): @@ -134,7 +211,8 @@ def get_endpoint(cn_site=True): def compute_hash(file_path): - BUFFER_SIZE = 1024 * 64 # 64k buffer size + # 16MB buffer for large file hash computation + BUFFER_SIZE = 1024 * 1024 * 16 sha256_hash = hashlib.sha256() with open(file_path, 'rb') as f: while True: