mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
[Refactor] Refactor the modelscope download module (#1683)
This commit is contained in:
4
.github/workflows/docker-image.yml
vendored
4
.github/workflows/docker-image.yml
vendored
@@ -36,6 +36,8 @@ run-name: Docker-${{ inputs.modelscope_branch }}-${{ inputs.image_type }}-${{ in
|
|||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
runs-on: [modelscope-self-hosted-us]
|
runs-on: [modelscope-self-hosted-us]
|
||||||
|
env:
|
||||||
|
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: 'true'
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: ResetFileMode
|
- name: ResetFileMode
|
||||||
@@ -47,7 +49,7 @@ jobs:
|
|||||||
source ~/.bashrc
|
source ~/.bashrc
|
||||||
sudo chown -R $USER:$USER $ACTION_RUNNER_DIR
|
sudo chown -R $USER:$USER $ACTION_RUNNER_DIR
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
with:
|
with:
|
||||||
ref: ${{ github.event.inputs.modelscope_branch }}
|
ref: ${{ github.event.inputs.modelscope_branch }}
|
||||||
|
|
||||||
|
|||||||
@@ -112,9 +112,10 @@ RUN set -eux; \
|
|||||||
done
|
done
|
||||||
|
|
||||||
# if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value '<VERSION>'"
|
# if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value '<VERSION>'"
|
||||||
ENV PYTHON_PIP_VERSION 23.0.1
|
# pip>=23.3 required for Python 3.12 (pkgutil.ImpImporter removed; older pip crashes on any install)
|
||||||
|
ENV PYTHON_PIP_VERSION 24.3.1
|
||||||
# https://github.com/docker-library/python/issues/365
|
# https://github.com/docker-library/python/issues/365
|
||||||
ENV PYTHON_SETUPTOOLS_VERSION 65.5.1
|
ENV PYTHON_SETUPTOOLS_VERSION 75.8.2
|
||||||
# https://github.com/pypa/get-pip
|
# https://github.com/pypa/get-pip
|
||||||
ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py
|
ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py
|
||||||
ENV PYTHON_GET_PIP_SHA256 dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9
|
ENV PYTHON_GET_PIP_SHA256 dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9
|
||||||
|
|||||||
@@ -112,9 +112,10 @@ RUN set -eux; \
|
|||||||
done
|
done
|
||||||
|
|
||||||
# if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value '<VERSION>'"
|
# if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value '<VERSION>'"
|
||||||
ENV PYTHON_PIP_VERSION 23.0.1
|
# pip>=23.3 required for Python 3.12 (pkgutil.ImpImporter removed; older pip crashes on any install)
|
||||||
|
ENV PYTHON_PIP_VERSION 24.3.1
|
||||||
# https://github.com/docker-library/python/issues/365
|
# https://github.com/docker-library/python/issues/365
|
||||||
ENV PYTHON_SETUPTOOLS_VERSION 65.5.1
|
ENV PYTHON_SETUPTOOLS_VERSION 75.8.2
|
||||||
# https://github.com/pypa/get-pip
|
# https://github.com/pypa/get-pip
|
||||||
ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py
|
ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/dbf0c85f76fb6e1ab42aa672ffca6f0a675d9ee4/public/get-pip.py
|
||||||
ENV PYTHON_GET_PIP_SHA256 dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9
|
ENV PYTHON_GET_PIP_SHA256 dfe9fd5c28dc98b5ac17979a953ea550cec37ae1b47a5116007395bfacff2ab9
|
||||||
|
|||||||
@@ -54,6 +54,13 @@ class Builder:
|
|||||||
def generate_dockerfile(self) -> str:
|
def generate_dockerfile(self) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _remove_pynini_related_dependency(content: str) -> str:
|
||||||
|
return content.replace(
|
||||||
|
'pip install --no-cache-dir funtextprocessing typeguard==2.13.3 scikit-learn -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html &&', # noqa: E501
|
||||||
|
'pip install --no-cache-dir typeguard==2.13.3 scikit-learn -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html &&' # noqa: E501
|
||||||
|
)
|
||||||
|
|
||||||
def _save_dockerfile(self, content: str) -> None:
|
def _save_dockerfile(self, content: str) -> None:
|
||||||
if os.path.exists('./Dockerfile'):
|
if os.path.exists('./Dockerfile'):
|
||||||
os.remove('./Dockerfile')
|
os.remove('./Dockerfile')
|
||||||
@@ -302,7 +309,7 @@ class StableCPUImageBuilder(Builder):
|
|||||||
content = content.replace('{modelscope_branch}',
|
content = content.replace('{modelscope_branch}',
|
||||||
self.args.modelscope_branch)
|
self.args.modelscope_branch)
|
||||||
content = content.replace('{swift_branch}', self.args.swift_branch)
|
content = content.replace('{swift_branch}', self.args.swift_branch)
|
||||||
return content
|
return self._remove_pynini_related_dependency(content)
|
||||||
|
|
||||||
def image(self) -> str:
|
def image(self) -> str:
|
||||||
return (
|
return (
|
||||||
@@ -363,7 +370,7 @@ RUN pip install --no-cache-dir -U icecream soundfile pybind11 py-spy
|
|||||||
content = content.replace('{modelscope_branch}',
|
content = content.replace('{modelscope_branch}',
|
||||||
self.args.modelscope_branch)
|
self.args.modelscope_branch)
|
||||||
content = content.replace('{swift_branch}', self.args.swift_branch)
|
content = content.replace('{swift_branch}', self.args.swift_branch)
|
||||||
return content
|
return self._remove_pynini_related_dependency(content)
|
||||||
|
|
||||||
def image(self) -> str:
|
def image(self) -> str:
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -39,6 +39,9 @@ from .utils.utils import (file_integrity_validation, get_endpoint,
|
|||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
# Maximum number of retries for hash validation failures
|
||||||
|
HASH_RETRY_TIMES = 3
|
||||||
|
|
||||||
|
|
||||||
def model_file_download(
|
def model_file_download(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
@@ -711,49 +714,66 @@ def download_file(
|
|||||||
disable_tqdm=False,
|
disable_tqdm=False,
|
||||||
progress_callbacks: List[Type[ProgressCallback]] = None,
|
progress_callbacks: List[Type[ProgressCallback]] = None,
|
||||||
):
|
):
|
||||||
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
|
|
||||||
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
|
|
||||||
file_digest = parallel_download(
|
|
||||||
url,
|
|
||||||
temporary_cache_dir,
|
|
||||||
file_meta['Path'],
|
|
||||||
headers=headers,
|
|
||||||
cookies=None if cookies is None else cookies.get_dict(),
|
|
||||||
file_size=file_meta['Size'],
|
|
||||||
disable_tqdm=disable_tqdm,
|
|
||||||
progress_callbacks=progress_callbacks,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
file_digest = http_get_model_file(
|
|
||||||
url,
|
|
||||||
temporary_cache_dir,
|
|
||||||
file_meta['Path'],
|
|
||||||
file_size=file_meta['Size'],
|
|
||||||
headers=headers,
|
|
||||||
cookies=cookies,
|
|
||||||
disable_tqdm=disable_tqdm,
|
|
||||||
progress_callbacks=progress_callbacks,
|
|
||||||
)
|
|
||||||
|
|
||||||
# check file integrity
|
|
||||||
temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])
|
temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])
|
||||||
if FILE_HASH in file_meta:
|
|
||||||
expected_hash = file_meta[FILE_HASH]
|
for hash_attempt in range(HASH_RETRY_TIMES):
|
||||||
if file_digest is not None:
|
if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
|
||||||
if file_digest != expected_hash:
|
'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
|
||||||
logger.warning(
|
file_digest = parallel_download(
|
||||||
'Mismatched real-time digest for %s, falling back to full hash check',
|
url,
|
||||||
file_meta['Path'])
|
temporary_cache_dir,
|
||||||
if not file_integrity_validation(temp_file, expected_hash):
|
file_meta['Path'],
|
||||||
raise FileDownloadError(
|
headers=headers,
|
||||||
'File %s hash validation failed after download, '
|
cookies=None if cookies is None else cookies.get_dict(),
|
||||||
'the file may be corrupted. Please retry.'
|
file_size=file_meta['Size'],
|
||||||
% file_meta['Path'])
|
disable_tqdm=disable_tqdm,
|
||||||
|
progress_callbacks=progress_callbacks,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if not file_integrity_validation(temp_file, expected_hash):
|
file_digest = http_get_model_file(
|
||||||
raise FileDownloadError(
|
url,
|
||||||
'File %s hash validation failed after download, '
|
temporary_cache_dir,
|
||||||
'the file may be corrupted. Please retry.'
|
file_meta['Path'],
|
||||||
% file_meta['Path'])
|
file_size=file_meta['Size'],
|
||||||
# put file into to cache
|
headers=headers,
|
||||||
|
cookies=cookies,
|
||||||
|
disable_tqdm=disable_tqdm,
|
||||||
|
progress_callbacks=progress_callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check file integrity
|
||||||
|
if FILE_HASH in file_meta:
|
||||||
|
expected_hash = file_meta[FILE_HASH]
|
||||||
|
hash_valid = True
|
||||||
|
if file_digest is not None:
|
||||||
|
if file_digest != expected_hash:
|
||||||
|
logger.warning(
|
||||||
|
'Mismatched real-time digest for %s, falling back to full hash check',
|
||||||
|
file_meta['Path'])
|
||||||
|
if not file_integrity_validation(temp_file, expected_hash):
|
||||||
|
hash_valid = False
|
||||||
|
else:
|
||||||
|
if not file_integrity_validation(temp_file, expected_hash):
|
||||||
|
hash_valid = False
|
||||||
|
|
||||||
|
if not hash_valid:
|
||||||
|
if hash_attempt < HASH_RETRY_TIMES - 1:
|
||||||
|
logger.warning(
|
||||||
|
'Hash validation failed for %s, '
|
||||||
|
'retrying download (attempt %d/%d)', file_meta['Path'],
|
||||||
|
hash_attempt + 1, HASH_RETRY_TIMES)
|
||||||
|
# Clean up corrupted file before retry
|
||||||
|
if os.path.exists(temp_file):
|
||||||
|
os.remove(temp_file)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise FileDownloadError(
|
||||||
|
'File %s hash validation failed after %d attempts, '
|
||||||
|
'the file may be corrupted.' %
|
||||||
|
(file_meta['Path'], HASH_RETRY_TIMES))
|
||||||
|
|
||||||
|
# Hash validation passed or no hash to validate, exit retry loop
|
||||||
|
break
|
||||||
|
|
||||||
|
# Put file into cache
|
||||||
return cache.put_file(file_meta, temp_file)
|
return cache.put_file(file_meta, temp_file)
|
||||||
|
|||||||
@@ -3,12 +3,16 @@
|
|||||||
import fnmatch
|
import fnmatch
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from http.cookiejar import CookieJar
|
from http.cookiejar import CookieJar
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Type, Union
|
from typing import Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||||
DEFAULT_MODEL_REVISION,
|
DEFAULT_MODEL_REVISION,
|
||||||
INTRA_CLOUD_ACCELERATION,
|
INTRA_CLOUD_ACCELERATION,
|
||||||
@@ -20,16 +24,19 @@ from modelscope.utils.thread_utils import thread_executor
|
|||||||
from .api import HubApi, ModelScopeConfig
|
from .api import HubApi, ModelScopeConfig
|
||||||
from .callback import ProgressCallback
|
from .callback import ProgressCallback
|
||||||
from .constants import DEFAULT_MAX_WORKERS
|
from .constants import DEFAULT_MAX_WORKERS
|
||||||
from .errors import InvalidParameter
|
from .errors import FileDownloadError, InvalidParameter
|
||||||
from .file_download import (create_temporary_directory_and_cache,
|
from .file_download import (create_temporary_directory_and_cache,
|
||||||
download_file, get_file_download_url)
|
download_file, get_file_download_url)
|
||||||
from .utils.caching import ModelFileSystemCache
|
from .utils.caching import ModelFileSystemCache
|
||||||
from .utils.utils import (get_model_masked_directory,
|
from .utils.utils import (extract_root_from_patterns,
|
||||||
|
get_model_masked_directory,
|
||||||
model_id_to_group_owner_name, strtobool,
|
model_id_to_group_owner_name, strtobool,
|
||||||
weak_file_lock)
|
weak_file_lock)
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
DEFAULT_DATASET_PAGE_SIZE = 200
|
||||||
|
|
||||||
|
|
||||||
def snapshot_download(
|
def snapshot_download(
|
||||||
model_id: str = None,
|
model_id: str = None,
|
||||||
@@ -335,13 +342,42 @@ def _snapshot_download(
|
|||||||
snapshot_header[
|
snapshot_header[
|
||||||
'cached_model_revision'] = cache.cached_model_revision
|
'cached_model_revision'] = cache.cached_model_revision
|
||||||
|
|
||||||
|
# Extract server-side root filter from include patterns
|
||||||
|
extracted_root = extract_root_from_patterns(
|
||||||
|
allow_file_pattern=_normalize_patterns(allow_file_pattern),
|
||||||
|
allow_patterns=_normalize_patterns(allow_patterns))
|
||||||
|
|
||||||
repo_files = _api.get_model_files(
|
repo_files = _api.get_model_files(
|
||||||
model_id=repo_id,
|
model_id=repo_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
|
root=extracted_root,
|
||||||
recursive=True,
|
recursive=True,
|
||||||
use_cookies=False if cookies is None else cookies,
|
use_cookies=False if cookies is None else cookies,
|
||||||
headers=snapshot_header,
|
headers=snapshot_header,
|
||||||
endpoint=endpoint)
|
endpoint=endpoint)
|
||||||
|
|
||||||
|
# Fallback: if root filter yielded no results, retry without it
|
||||||
|
if not repo_files and extracted_root is not None:
|
||||||
|
logger.warning(
|
||||||
|
f"root='{extracted_root}' returned no model files, "
|
||||||
|
f'falling back to root=None for full listing.')
|
||||||
|
repo_files = _api.get_model_files(
|
||||||
|
model_id=repo_id,
|
||||||
|
revision=revision,
|
||||||
|
root=None,
|
||||||
|
recursive=True,
|
||||||
|
use_cookies=False if cookies is None else cookies,
|
||||||
|
headers=snapshot_header,
|
||||||
|
endpoint=endpoint)
|
||||||
|
|
||||||
|
# Apply client-side pattern filtering
|
||||||
|
repo_files = filter_files_by_patterns(
|
||||||
|
repo_files,
|
||||||
|
allow_file_pattern=allow_file_pattern,
|
||||||
|
ignore_file_pattern=ignore_file_pattern,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
ignore_patterns=ignore_patterns)
|
||||||
|
|
||||||
_download_file_lists(
|
_download_file_lists(
|
||||||
repo_files,
|
repo_files,
|
||||||
cache,
|
cache,
|
||||||
@@ -354,10 +390,7 @@ def _snapshot_download(
|
|||||||
repo_type=repo_type,
|
repo_type=repo_type,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
cookies=cookies,
|
cookies=cookies,
|
||||||
ignore_file_pattern=ignore_file_pattern,
|
pre_filtered=True,
|
||||||
allow_file_pattern=allow_file_pattern,
|
|
||||||
ignore_patterns=ignore_patterns,
|
|
||||||
allow_patterns=allow_patterns,
|
|
||||||
max_workers=max_workers,
|
max_workers=max_workers,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
progress_callbacks=progress_callbacks,
|
progress_callbacks=progress_callbacks,
|
||||||
@@ -392,75 +425,148 @@ def _snapshot_download(
|
|||||||
group_or_owner, name = model_id_to_group_owner_name(repo_id)
|
group_or_owner, name = model_id_to_group_owner_name(repo_id)
|
||||||
revision_detail = revision or DEFAULT_DATASET_REVISION
|
revision_detail = revision or DEFAULT_DATASET_REVISION
|
||||||
|
|
||||||
logger.info('Fetching dataset repo file list...')
|
# Extract server-side root filter from include patterns
|
||||||
repo_files = fetch_repo_files(
|
extracted_root = extract_root_from_patterns(
|
||||||
_api, repo_id, revision_detail, endpoint, token=token)
|
allow_file_pattern=_normalize_patterns(allow_file_pattern),
|
||||||
|
allow_patterns=_normalize_patterns(allow_patterns))
|
||||||
|
root_path = '/' + extracted_root if extracted_root else '/'
|
||||||
|
|
||||||
if repo_files is None:
|
print(f'Fetching file list (root: {root_path})...')
|
||||||
logger.error(
|
file_page_iter = _iter_dataset_file_pages(
|
||||||
f'Failed to retrieve file list for dataset: {repo_id}')
|
|
||||||
return None
|
|
||||||
|
|
||||||
_download_file_lists(
|
|
||||||
repo_files,
|
|
||||||
cache,
|
|
||||||
temporary_cache_dir,
|
|
||||||
repo_id,
|
|
||||||
_api,
|
_api,
|
||||||
name,
|
repo_id,
|
||||||
group_or_owner,
|
revision_detail,
|
||||||
headers,
|
endpoint,
|
||||||
repo_type=repo_type,
|
token=token,
|
||||||
|
root_path=root_path,
|
||||||
|
allow_file_pattern=allow_file_pattern,
|
||||||
|
ignore_file_pattern=ignore_file_pattern,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
ignore_patterns=ignore_patterns)
|
||||||
|
|
||||||
|
_pipeline_download_dataset(
|
||||||
|
file_page_iter,
|
||||||
|
cache=cache,
|
||||||
|
temporary_cache_dir=temporary_cache_dir,
|
||||||
|
repo_id=repo_id,
|
||||||
|
api=_api,
|
||||||
|
dataset_name=name,
|
||||||
|
namespace=group_or_owner,
|
||||||
|
headers=headers,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
cookies=cookies,
|
cookies=cookies,
|
||||||
ignore_file_pattern=ignore_file_pattern,
|
|
||||||
allow_file_pattern=allow_file_pattern,
|
|
||||||
ignore_patterns=ignore_patterns,
|
|
||||||
allow_patterns=allow_patterns,
|
|
||||||
max_workers=max_workers,
|
max_workers=max_workers,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
progress_callbacks=progress_callbacks,
|
progress_callbacks=progress_callbacks)
|
||||||
)
|
|
||||||
|
|
||||||
cache.save_model_version(revision_info=revision_detail)
|
cache.save_model_version(revision_info=revision_detail)
|
||||||
cache_root_path = cache.get_root_location()
|
cache_root_path = cache.get_root_location()
|
||||||
return cache_root_path
|
return cache_root_path
|
||||||
|
|
||||||
|
|
||||||
def fetch_repo_files(_api, repo_id, revision, endpoint, token=None):
|
def fetch_repo_files(
|
||||||
_owner, _dataset_name = repo_id.split('/')
|
_api,
|
||||||
|
repo_id,
|
||||||
|
revision,
|
||||||
|
endpoint,
|
||||||
|
token=None,
|
||||||
|
root_path='/',
|
||||||
|
allow_file_pattern=None,
|
||||||
|
ignore_file_pattern=None,
|
||||||
|
allow_patterns=None,
|
||||||
|
ignore_patterns=None,
|
||||||
|
page_size=DEFAULT_DATASET_PAGE_SIZE,
|
||||||
|
):
|
||||||
|
"""Fetch and filter dataset repo files with pagination and server-side prefix filtering.
|
||||||
|
|
||||||
|
Applies per-page pattern filtering to minimize memory usage.
|
||||||
|
Falls back to root_path='/' if the extracted prefix yields no results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_api: HubApi instance.
|
||||||
|
repo_id: Dataset repo identifier (owner/name).
|
||||||
|
revision: Git revision.
|
||||||
|
endpoint: API endpoint URL.
|
||||||
|
token: Authentication token.
|
||||||
|
root_path: Server-side directory prefix filter.
|
||||||
|
allow_file_pattern: Include patterns for client-side filtering.
|
||||||
|
ignore_file_pattern: Exclude patterns for client-side filtering.
|
||||||
|
allow_patterns: Additional include patterns (HF-compatible).
|
||||||
|
ignore_patterns: Additional exclude patterns (HF-compatible).
|
||||||
|
page_size: Number of files per API page request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of filtered file entry dicts.
|
||||||
|
"""
|
||||||
|
if '/' not in repo_id:
|
||||||
|
raise InvalidParameter(
|
||||||
|
f"Invalid repo_id: '{repo_id}', expected format 'owner/name'")
|
||||||
|
_owner, _dataset_name = repo_id.split('/', 1)
|
||||||
_hub_id, _ = _api.get_dataset_id_and_type(
|
_hub_id, _ = _api.get_dataset_id_and_type(
|
||||||
dataset_name=_dataset_name,
|
dataset_name=_dataset_name,
|
||||||
namespace=_owner,
|
namespace=_owner,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
token=token)
|
token=token)
|
||||||
|
|
||||||
page_number = 1
|
has_patterns = any([
|
||||||
page_size = 150
|
allow_file_pattern, ignore_file_pattern, allow_patterns,
|
||||||
repo_files = []
|
ignore_patterns
|
||||||
|
])
|
||||||
|
|
||||||
while True:
|
def _paginate_and_filter(effective_root_path):
|
||||||
try:
|
"""Fetch all pages with the given root_path, applying per-page filtering."""
|
||||||
dataset_files = _api.get_dataset_files(
|
page_number = 1
|
||||||
repo_id=repo_id,
|
repo_files = []
|
||||||
revision=revision,
|
|
||||||
root_path='/',
|
|
||||||
recursive=True,
|
|
||||||
page_number=page_number,
|
|
||||||
page_size=page_size,
|
|
||||||
endpoint=endpoint,
|
|
||||||
token=token,
|
|
||||||
dataset_hub_id=_hub_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f'Error fetching dataset files: {e}')
|
|
||||||
break
|
|
||||||
|
|
||||||
repo_files.extend(dataset_files)
|
while True:
|
||||||
|
try:
|
||||||
|
dataset_files = _api.get_dataset_files(
|
||||||
|
repo_id=repo_id,
|
||||||
|
revision=revision,
|
||||||
|
root_path=effective_root_path,
|
||||||
|
recursive=True,
|
||||||
|
page_number=page_number,
|
||||||
|
page_size=page_size,
|
||||||
|
endpoint=endpoint,
|
||||||
|
token=token,
|
||||||
|
dataset_hub_id=_hub_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f'Error fetching dataset files (page {page_number}): {e}')
|
||||||
|
break
|
||||||
|
|
||||||
if len(dataset_files) < page_size:
|
if not dataset_files:
|
||||||
break
|
break
|
||||||
|
|
||||||
page_number += 1
|
# Per-page filtering: apply patterns immediately to reduce memory
|
||||||
|
if has_patterns:
|
||||||
|
page_filtered = filter_files_by_patterns(
|
||||||
|
dataset_files,
|
||||||
|
allow_file_pattern=allow_file_pattern,
|
||||||
|
ignore_file_pattern=ignore_file_pattern,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
ignore_patterns=ignore_patterns)
|
||||||
|
repo_files.extend(page_filtered)
|
||||||
|
else:
|
||||||
|
# No patterns: keep all non-tree entries
|
||||||
|
repo_files.extend(
|
||||||
|
f for f in dataset_files if f.get('Type') != 'tree')
|
||||||
|
|
||||||
|
if len(dataset_files) < page_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
page_number += 1
|
||||||
|
|
||||||
|
return repo_files
|
||||||
|
|
||||||
|
# Primary fetch with optimized root_path
|
||||||
|
repo_files = _paginate_and_filter(root_path)
|
||||||
|
|
||||||
|
# Fallback: if optimized root_path yielded nothing and it's not the default
|
||||||
|
if not repo_files and root_path != '/':
|
||||||
|
logger.warning(f"root_path='{root_path}' returned no results, "
|
||||||
|
f"falling back to root_path='/' for full listing.")
|
||||||
|
repo_files = _paginate_and_filter('/')
|
||||||
|
|
||||||
return repo_files
|
return repo_files
|
||||||
|
|
||||||
@@ -494,6 +600,310 @@ def _get_valid_regex_pattern(patterns: List[str]):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def filter_files_by_patterns(
    repo_files: List[dict],
    *,
    allow_file_pattern: Optional[List[str]] = None,
    ignore_file_pattern: Optional[List[str]] = None,
    allow_patterns: Optional[List[str]] = None,
    ignore_patterns: Optional[List[str]] = None,
) -> List[dict]:
    """Filter repo file entries by include/exclude patterns.

    Entries whose 'Type' is 'tree' are always dropped. For the remaining
    entries the checks run in this order, and the first one that rejects
    the entry wins:

    1. ``ignore_patterns``      - fnmatch against 'Path', match => skip.
    2. ``ignore_file_pattern``  - fnmatch against 'Path', match => skip.
    3. regex patterns derived from ``ignore_file_pattern`` via
       ``_get_valid_regex_pattern`` - ``re.search`` against 'Name',
       match => skip.
    4. ``allow_patterns``       - fnmatch against 'Path', no match => skip.
    5. ``allow_file_pattern``   - fnmatch against 'Path', no match => skip.

    Any exception raised while evaluating the checks for an entry is
    logged as a warning and that entry is skipped.

    Args:
        repo_files: List of file entry dicts with 'Type', 'Path', 'Name' keys.
        allow_file_pattern: Include patterns (fnmatch). Files must match at least one.
        ignore_file_pattern: Exclude patterns (fnmatch). Matching files are skipped.
        allow_patterns: Additional include patterns (HF-compatible).
        ignore_patterns: Additional exclude patterns (HF-compatible).

    Returns:
        List of file entries that pass all filters, in input order.
    """
    ignore_patterns = _normalize_patterns(ignore_patterns)
    allow_patterns = _normalize_patterns(allow_patterns)
    ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
    allow_file_pattern = _normalize_patterns(allow_file_pattern)
    ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)

    def _glob_match(path, patterns):
        """Return True when ``path`` matches at least one fnmatch pattern."""
        return any(fnmatch.fnmatch(path, p) for p in patterns)

    filtered = []
    for repo_file in repo_files:
        # Use .get() so an entry without a 'Type' key is treated as a file
        # instead of raising KeyError outside the try block; this matches
        # the no-pattern code paths elsewhere in this module.
        if repo_file.get('Type') == 'tree':
            continue
        try:
            # 'Path' / 'Name' are only read when the corresponding pattern
            # list is non-empty, so entries are not rejected for a missing
            # key that no active filter needs.
            if ignore_patterns and _glob_match(repo_file['Path'],
                                               ignore_patterns):
                continue

            if ignore_file_pattern and _glob_match(repo_file['Path'],
                                                   ignore_file_pattern):
                continue

            if ignore_regex_pattern and any(
                    re.search(p, repo_file['Name']) is not None
                    for p in ignore_regex_pattern):
                continue

            if allow_patterns and not _glob_match(repo_file['Path'],
                                                  allow_patterns):
                continue

            if allow_file_pattern and not _glob_match(repo_file['Path'],
                                                      allow_file_pattern):
                continue
        except Exception as e:
            # Best-effort: a failing pattern check drops the entry rather
            # than aborting the whole listing.
            logger.warning('Invalid file pattern: %s', e)
            continue

        filtered.append(repo_file)

    return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_dataset_file_pages(
    _api,
    repo_id,
    revision,
    endpoint,
    token=None,
    root_path='/',
    allow_file_pattern=None,
    ignore_file_pattern=None,
    allow_patterns=None,
    ignore_patterns=None,
    page_size=DEFAULT_DATASET_PAGE_SIZE,
):
    """Lazily list a dataset repo, yielding one filtered page at a time.

    The repo is listed recursively through ``_api.get_dataset_files`` in
    pages of ``page_size`` entries. Each page is filtered right away
    (through ``filter_files_by_patterns`` when any pattern is given,
    otherwise by dropping 'tree' entries) and yielded only if something
    survives the filter. A running count is printed on a single
    ``\\r``-refreshed line after every page.

    If listing under ``root_path`` yields no page at all and ``root_path``
    is not '/', the whole listing is retried with root_path='/'.

    A fetch error is logged and ends pagination for the current root; it
    is not re-raised.

    Args:
        _api: HubApi instance.
        repo_id: Dataset repo identifier (owner/name).
        revision: Git revision.
        endpoint: API endpoint URL.
        token: Authentication token.
        root_path: Server-side directory prefix filter.
        allow_file_pattern: Include patterns (fnmatch).
        ignore_file_pattern: Exclude patterns (fnmatch).
        allow_patterns: Additional include patterns (HF-compatible).
        ignore_patterns: Additional exclude patterns (HF-compatible).
        page_size: Number of files per API page request.

    Yields:
        List[dict]: Non-empty list of filtered file entries per page.

    Raises:
        InvalidParameter: If ``repo_id`` contains no '/'.
    """
    if '/' not in repo_id:
        raise InvalidParameter(
            f"Invalid repo_id: '{repo_id}', expected format 'owner/name'")

    owner, short_name = repo_id.split('/', 1)
    # Resolve the hub id once up front; every page request reuses it.
    hub_id, _ = _api.get_dataset_id_and_type(
        dataset_name=short_name,
        namespace=owner,
        endpoint=endpoint,
        token=token)

    filtering_enabled = bool(allow_file_pattern or ignore_file_pattern
                             or allow_patterns or ignore_patterns)

    def _walk_pages(effective_root_path):
        """Yield the non-empty filtered pages found under one root path."""
        page_number = 1
        total_found = 0

        while True:
            try:
                raw_entries = _api.get_dataset_files(
                    repo_id=repo_id,
                    revision=revision,
                    root_path=effective_root_path,
                    recursive=True,
                    page_number=page_number,
                    page_size=page_size,
                    endpoint=endpoint,
                    token=token,
                    dataset_hub_id=hub_id)
            except Exception as e:
                logger.error(
                    f'Error fetching dataset files (page {page_number}): {e}')
                return

            if not raw_entries:
                return

            # Filter page by page so the full listing is never held in memory.
            if filtering_enabled:
                kept = filter_files_by_patterns(
                    raw_entries,
                    allow_file_pattern=allow_file_pattern,
                    ignore_file_pattern=ignore_file_pattern,
                    allow_patterns=allow_patterns,
                    ignore_patterns=ignore_patterns)
            else:
                # Without patterns only directory ('tree') entries are dropped.
                kept = [
                    entry for entry in raw_entries
                    if entry.get('Type') != 'tree'
                ]

            total_found += len(kept)
            if kept:
                yield kept

            print(
                f'\r Fetched {total_found} matching files '
                f'({page_number} pages)...',
                end='',
                flush=True)

            # A short page means the server has nothing more to return.
            if len(raw_entries) < page_size:
                return

            page_number += 1

    try:
        # First pass: list only under the requested root.
        yielded_any = False
        for page in _walk_pages(root_path):
            yielded_any = True
            yield page

        # Second pass: nothing came back for a non-default root, so list
        # the whole repo instead.
        if not yielded_any and root_path != '/':
            print(f"\n root_path='{root_path}' returned no results, "
                  f"falling back to root_path='/' for full listing.")
            yield from _walk_pages('/')
    finally:
        # Terminate the \r progress line regardless of how iteration ends.
        print()
|
||||||
|
|
||||||
|
|
||||||
|
def _pipeline_download_dataset(
    file_page_iter,
    cache,
    temporary_cache_dir,
    repo_id,
    api,
    dataset_name,
    namespace,
    headers,
    revision,
    cookies,
    max_workers=DEFAULT_MAX_WORKERS,
    endpoint=None,
    progress_callbacks=None,
):
    """Download dataset files concurrently while the file listing streams in.

    Each page produced by ``file_page_iter`` is walked file by file. Files
    for which ``cache.exists`` is true are counted as cached and skipped;
    every other file is handed to a thread pool that runs ``download_file``.
    A tqdm bar whose total grows as files are discovered tracks progress.

    Args:
        file_page_iter: Iterable of pages, each page an iterable of file
            dicts (the ``'Path'`` key is read).
        cache: Cache object; ``cache.exists(repo_file)`` decides skipping
            and the object is forwarded to ``download_file``.
        temporary_cache_dir: Staging directory forwarded to ``download_file``.
        repo_id: Dataset repo identifier (not referenced in this function).
        api: Object providing ``get_dataset_file_url``.
        dataset_name: Dataset name used to build the download URL.
        namespace: Owner/namespace used to build the download URL.
        headers: HTTP headers forwarded to ``download_file``.
        revision: Revision used to build the download URL.
        cookies: HTTP cookies forwarded to ``download_file``.
        max_workers: Size of the thread pool.
        endpoint: Endpoint forwarded to ``get_dataset_file_url``.
        progress_callbacks: Forwarded to ``download_file``.

    Raises:
        FileDownloadError: If at least one submitted download raised.
    """
    found_count = 0
    cached_count = 0
    failures = []
    failures_lock = threading.Lock()

    def _make_done_callback(file_info):
        """Build the completion handler bound to one file entry."""

        def _handle_finished(finished):
            try:
                finished.result()
            except Exception as err:
                # Done-callbacks may run on worker threads, hence the lock.
                with failures_lock:
                    failures.append((file_info, err))
                logger.debug(
                    f"Download failed for {file_info.get('Path', '?')}: {err}")
            finally:
                progress.update(1)

        return _handle_finished

    # The pool is the inner context, so it is shut down (waiting for every
    # pending download and its callback) before the progress bar closes.
    with tqdm(total=0, unit=' files', disable=False) as progress, \
            ThreadPoolExecutor(max_workers=max_workers) as pool:
        for page in file_page_iter:
            for file_info in page:
                # Grow the bar's total as new files are discovered.
                found_count += 1
                progress.total = found_count
                progress.refresh()

                if cache.exists(file_info):
                    # Already cached: counts as done immediately.
                    cached_count += 1
                    progress.update(1)
                    continue

                download_url = api.get_dataset_file_url(
                    file_name=file_info['Path'],
                    dataset_name=dataset_name,
                    namespace=namespace,
                    revision=revision,
                    endpoint=endpoint)

                task = pool.submit(
                    download_file,
                    download_url,
                    file_info,
                    temporary_cache_dir,
                    cache,
                    headers,
                    cookies,
                    disable_tqdm=False,
                    progress_callbacks=progress_callbacks,
                )
                task.add_done_callback(_make_done_callback(file_info))

    failure_count = len(failures)

    # Report failures only after the progress bar has been closed.
    if failures:
        failed_paths = []
        for entry, _ in failures:
            if isinstance(entry, dict):
                failed_paths.append(entry.get('Path', '?'))
            else:
                failed_paths.append(str(entry))
        logger.error(f'{failure_count} file(s) failed to download:\n'
                     + '\n'.join(f' - {p}' for p in failed_paths))

    # The summary is printed even when an error is raised right after.
    downloaded = found_count - cached_count - failure_count
    summary = (f'Download complete: {found_count} files found, '
               f'{cached_count} cached, {downloaded} downloaded')
    if failures:
        summary += f', {failure_count} failed'
    print(summary + '.')

    if failures:
        raise FileDownloadError(
            f'{failure_count} file(s) failed to download out of '
            f'{found_count}.')
|
||||||
|
|
||||||
|
|
||||||
def _download_file_lists(
|
def _download_file_lists(
|
||||||
repo_files: List[str],
|
repo_files: List[str],
|
||||||
cache: ModelFileSystemCache,
|
cache: ModelFileSystemCache,
|
||||||
@@ -513,62 +923,77 @@ def _download_file_lists(
|
|||||||
max_workers: int = 8,
|
max_workers: int = 8,
|
||||||
endpoint: Optional[str] = None,
|
endpoint: Optional[str] = None,
|
||||||
progress_callbacks: List[Type[ProgressCallback]] = None,
|
progress_callbacks: List[Type[ProgressCallback]] = None,
|
||||||
|
pre_filtered: bool = False,
|
||||||
):
|
):
|
||||||
ignore_patterns = _normalize_patterns(ignore_patterns)
|
if pre_filtered:
|
||||||
allow_patterns = _normalize_patterns(allow_patterns)
|
# Files are already filtered by patterns; only check cache
|
||||||
ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
|
filtered_repo_files = []
|
||||||
allow_file_pattern = _normalize_patterns(allow_file_pattern)
|
for repo_file in repo_files:
|
||||||
# to compatible regex usage.
|
|
||||||
ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)
|
|
||||||
|
|
||||||
filtered_repo_files = []
|
|
||||||
for repo_file in repo_files:
|
|
||||||
if repo_file['Type'] == 'tree':
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
# processing patterns
|
|
||||||
if ignore_patterns and any([
|
|
||||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
|
||||||
for pattern in ignore_patterns
|
|
||||||
]):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if ignore_file_pattern and any([
|
|
||||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
|
||||||
for pattern in ignore_file_pattern
|
|
||||||
]):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if ignore_regex_pattern and any([
|
|
||||||
re.search(pattern, repo_file['Name']) is not None
|
|
||||||
for pattern in ignore_regex_pattern
|
|
||||||
]): # noqa E501
|
|
||||||
continue
|
|
||||||
|
|
||||||
if allow_patterns is not None and allow_patterns:
|
|
||||||
if not any(
|
|
||||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
|
||||||
for pattern in allow_patterns):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if allow_file_pattern is not None and allow_file_pattern:
|
|
||||||
if not any(
|
|
||||||
fnmatch.fnmatch(repo_file['Path'], pattern)
|
|
||||||
for pattern in allow_file_pattern):
|
|
||||||
continue
|
|
||||||
# check model_file is exist in cache, if existed, skip download
|
|
||||||
if cache.exists(repo_file):
|
if cache.exists(repo_file):
|
||||||
file_name = os.path.basename(repo_file['Name'])
|
file_name = os.path.basename(repo_file['Name'])
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'File {file_name} already in cache with identical hash, skip downloading!'
|
f'File {file_name} already in cache with identical hash, skip downloading!'
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
|
||||||
logger.warning('The file pattern is invalid : %s' % e)
|
|
||||||
else:
|
|
||||||
filtered_repo_files.append(repo_file)
|
filtered_repo_files.append(repo_file)
|
||||||
|
else:
|
||||||
|
# Legacy path: apply pattern filtering + cache check
|
||||||
|
ignore_patterns = _normalize_patterns(ignore_patterns)
|
||||||
|
allow_patterns = _normalize_patterns(allow_patterns)
|
||||||
|
ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
|
||||||
|
allow_file_pattern = _normalize_patterns(allow_file_pattern)
|
||||||
|
# to compatible regex usage.
|
||||||
|
ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)
|
||||||
|
|
||||||
@thread_executor(max_workers=max_workers, disable_tqdm=False)
|
filtered_repo_files = []
|
||||||
|
for repo_file in repo_files:
|
||||||
|
if repo_file['Type'] == 'tree':
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
# processing patterns
|
||||||
|
if ignore_patterns and any([
|
||||||
|
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||||
|
for pattern in ignore_patterns
|
||||||
|
]):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ignore_file_pattern and any([
|
||||||
|
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||||
|
for pattern in ignore_file_pattern
|
||||||
|
]):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ignore_regex_pattern and any([
|
||||||
|
re.search(pattern, repo_file['Name']) is not None
|
||||||
|
for pattern in ignore_regex_pattern
|
||||||
|
]): # noqa E501
|
||||||
|
continue
|
||||||
|
|
||||||
|
if allow_patterns is not None and allow_patterns:
|
||||||
|
if not any(
|
||||||
|
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||||
|
for pattern in allow_patterns):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if allow_file_pattern is not None and allow_file_pattern:
|
||||||
|
if not any(
|
||||||
|
fnmatch.fnmatch(repo_file['Path'], pattern)
|
||||||
|
for pattern in allow_file_pattern):
|
||||||
|
continue
|
||||||
|
# check model_file is exist in cache, if existed, skip download
|
||||||
|
if cache.exists(repo_file):
|
||||||
|
file_name = os.path.basename(repo_file['Name'])
|
||||||
|
logger.debug(
|
||||||
|
f'File {file_name} already in cache with identical hash, skip downloading!'
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning('The file pattern is invalid : %s' % e)
|
||||||
|
else:
|
||||||
|
filtered_repo_files.append(repo_file)
|
||||||
|
|
||||||
|
@thread_executor(
|
||||||
|
max_workers=max_workers, disable_tqdm=False, fault_tolerant=True)
|
||||||
def _download_single_file(repo_file):
|
def _download_single_file(repo_file):
|
||||||
if repo_type == REPO_TYPE_MODEL:
|
if repo_type == REPO_TYPE_MODEL:
|
||||||
url = get_file_download_url(
|
url = get_file_download_url(
|
||||||
@@ -602,5 +1027,26 @@ def _download_file_lists(
|
|||||||
if len(filtered_repo_files) > 0:
|
if len(filtered_repo_files) > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
f'Got {len(filtered_repo_files)} files, start to download ...')
|
f'Got {len(filtered_repo_files)} files, start to download ...')
|
||||||
_download_single_file(filtered_repo_files)
|
download_result = _download_single_file(filtered_repo_files)
|
||||||
logger.info(f"Download {repo_type} '{repo_id}' successfully.")
|
|
||||||
|
# Handle fault-tolerant results: report failed downloads
|
||||||
|
failed_items = []
|
||||||
|
if isinstance(download_result, tuple) and len(download_result) == 2:
|
||||||
|
_, failed_items = download_result
|
||||||
|
if failed_items:
|
||||||
|
failed_paths = [
|
||||||
|
item['Path'] if isinstance(item, dict) else str(item)
|
||||||
|
for item, _ in failed_items
|
||||||
|
]
|
||||||
|
logger.error(
|
||||||
|
f'{len(failed_items)} file(s) failed to download:\n'
|
||||||
|
+ '\n'.join(f' - {p}' for p in failed_paths))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Finish downloading {len(filtered_repo_files)} files for repo '{repo_id}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
if failed_items:
|
||||||
|
raise FileDownloadError(
|
||||||
|
f'{len(failed_items)} file(s) failed to download out of '
|
||||||
|
f'{len(filtered_repo_files)}.')
|
||||||
|
|||||||
@@ -66,6 +66,83 @@ def convert_patterns(raw_input: Union[str, List[str]]):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def extract_root_from_patterns(
    allow_file_pattern: Optional[List[str]] = None,
    allow_patterns: Optional[List[str]] = None,
) -> Optional[str]:
    """Extract common directory prefix from include patterns for server-side filtering.

    Only allow/include patterns are considered (ignore/exclude patterns
    cannot narrow the listing root). The returned prefix must cover every
    include pattern: as soon as one pattern has no literal directory prefix
    (e.g. ``'*.json'`` or ``'README.md'``), it may match files anywhere
    from the repository root, so no prefix can be used and ``None`` is
    returned.

    Algorithm:
        1. Merge allow_file_pattern and allow_patterns into one list.
        2. For each pattern, find the position of the first wildcard
           character (``*``, ``?``, ``[``).
        3. Take the text before it and keep the directory part (up to the
           last ``'/'``); bail out with ``None`` if there is none.
        4. Compute the longest common prefix, compared segment by segment.
        5. Return ``None`` if that common prefix is empty.

    Args:
        allow_file_pattern: Include patterns (may be None or empty).
        allow_patterns: Further include patterns (may be None or empty).

    Returns:
        The common directory prefix without trailing slash, or ``None``
        when no meaningful prefix exists.

    Examples:
        ['TacExo/*']                       -> 'TacExo'
        ['data/train/*.parquet']           -> 'data/train'
        ['data/train/*', 'data/valid/*']   -> 'data'
        ['*.safetensors']                  -> None
        ['TacExo/*', 'OtherDir/*']         -> None  (no common prefix)
        ['data/*/train.csv']               -> 'data'
        ['data/*', '*.json']               -> None  (second pattern is root-level)
    """
    # Merge both pattern lists
    patterns = []
    if allow_file_pattern:
        patterns.extend(allow_file_pattern)
    if allow_patterns:
        patterns.extend(allow_patterns)

    if not patterns:
        return None

    extracted_dirs = []
    for pattern in patterns:
        # Literal text ends at the first wildcard (or at the end of the
        # pattern when it contains none).
        wildcard_positions = [
            pos for pos in (pattern.find(wc) for wc in ('*', '?', '['))
            if pos != -1
        ]
        first_wildcard = min(wildcard_positions, default=len(pattern))
        prefix = pattern[:first_wildcard]

        # Directory part is everything before the last '/'.
        last_sep = prefix.rfind('/')
        if last_sep <= 0:
            # No directory prefix: this pattern can match from the root, so
            # narrowing the listing to any sub-directory would hide files
            # the caller explicitly asked for.
            return None
        extracted_dirs.append(prefix[:last_sep])

    # Longest common prefix, path-segment-aware ('data1' vs 'data2' share
    # nothing even though they share characters).
    common_segments = []
    for segments in zip(*(d.split('/') for d in extracted_dirs)):
        if len(set(segments)) != 1:
            break
        common_segments.append(segments[0])

    # An empty join (no shared segment, or only a shared leading '/')
    # carries no information; report it as "no prefix".
    return '/'.join(common_segments) or None
|
||||||
|
|
||||||
|
|
||||||
# during model download, the '.' would be converted to '___' to produce
|
# during model download, the '.' would be converted to '___' to produce
|
||||||
# actual physical (masked) directory for storage
|
# actual physical (masked) directory for storage
|
||||||
def get_model_masked_directory(directory, model_id):
|
def get_model_masked_directory(directory, model_id):
|
||||||
@@ -134,7 +211,8 @@ def get_endpoint(cn_site=True):
|
|||||||
|
|
||||||
|
|
||||||
def compute_hash(file_path):
|
def compute_hash(file_path):
|
||||||
BUFFER_SIZE = 1024 * 64 # 64k buffer size
|
# 16MB buffer for large file hash computation
|
||||||
|
BUFFER_SIZE = 1024 * 1024 * 16
|
||||||
sha256_hash = hashlib.sha256()
|
sha256_hash = hashlib.sha256()
|
||||||
with open(file_path, 'rb') as f:
|
with open(file_path, 'rb') as f:
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
Reference in New Issue
Block a user