weak file lock (#1417)

This commit is contained in:
tastelikefeet
2025-07-17 18:38:24 +08:00
committed by hjh0119
parent 3d11b891ca
commit d031f3e20b
3 changed files with 211 additions and 37 deletions

View File

@@ -4,17 +4,11 @@ import fnmatch
import os import os
import re import re
import uuid import uuid
from contextlib import nullcontext
from http.cookiejar import CookieJar from http.cookiejar import CookieJar
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Type, Union from typing import Dict, List, Optional, Type, Union
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.errors import InvalidParameter
from modelscope.hub.file_download import (create_temporary_directory_and_cache,
download_file, get_file_download_url)
from modelscope.hub.utils.caching import ModelFileSystemCache
from modelscope.hub.utils.utils import (get_model_masked_directory,
model_id_to_group_owner_name)
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION, DEFAULT_MODEL_REVISION,
INTRA_CLOUD_ACCELERATION, INTRA_CLOUD_ACCELERATION,
@@ -23,7 +17,15 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
from modelscope.utils.file_utils import get_modelscope_cache_dir from modelscope.utils.file_utils import get_modelscope_cache_dir
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
from modelscope.utils.thread_utils import thread_executor from modelscope.utils.thread_utils import thread_executor
from .api import HubApi, ModelScopeConfig
from .callback import ProgressCallback from .callback import ProgressCallback
from .errors import InvalidParameter
from .file_download import (create_temporary_directory_and_cache,
download_file, get_file_download_url)
from .utils.caching import ModelFileSystemCache
from .utils.utils import (get_model_masked_directory,
model_id_to_group_owner_name, strtobool,
weak_file_lock)
logger = get_logger() logger = get_logger()
@@ -43,6 +45,7 @@ def snapshot_download(
max_workers: int = 8, max_workers: int = 8,
repo_id: str = None, repo_id: str = None,
repo_type: Optional[str] = REPO_TYPE_MODEL, repo_type: Optional[str] = REPO_TYPE_MODEL,
enable_file_lock: Optional[bool] = None,
progress_callbacks: List[Type[ProgressCallback]] = None, progress_callbacks: List[Type[ProgressCallback]] = None,
) -> str: ) -> str:
"""Download all files of a repo. """Download all files of a repo.
@@ -79,6 +82,9 @@ def snapshot_download(
If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern. If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
For hugging-face compatibility. For hugging-face compatibility.
max_workers (`int`): The maximum number of workers to download files, default 8. max_workers (`int`): The maximum number of workers to download files, default 8.
enable_file_lock (`bool`): Enable file lock, this is useful in multiprocessing downloading, default `True`.
If you find something wrong with file lock and have a problem modifying your code,
change `MODELSCOPE_HUB_FILE_LOCK` env to `false`.
progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`): progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`):
progress callbacks to track the download progress. progress callbacks to track the download progress.
Raises: Raises:
@@ -109,21 +115,35 @@ def snapshot_download(
if revision is None: if revision is None:
revision = DEFAULT_DATASET_REVISION if repo_type == REPO_TYPE_DATASET else DEFAULT_MODEL_REVISION revision = DEFAULT_DATASET_REVISION if repo_type == REPO_TYPE_DATASET else DEFAULT_MODEL_REVISION
return _snapshot_download( if enable_file_lock is None:
repo_id, enable_file_lock = strtobool(
repo_type=repo_type, os.environ.get('MODELSCOPE_HUB_FILE_LOCK', 'true'))
revision=revision,
cache_dir=cache_dir, if enable_file_lock:
user_agent=user_agent, system_cache = cache_dir if cache_dir is not None else get_modelscope_cache_dir(
local_files_only=local_files_only, )
cookies=cookies, os.makedirs(os.path.join(system_cache, '.lock'), exist_ok=True)
ignore_file_pattern=ignore_file_pattern, lock_file = os.path.join(system_cache, '.lock',
allow_file_pattern=allow_file_pattern, repo_id.replace('/', '___'))
local_dir=local_dir, context = weak_file_lock(lock_file)
ignore_patterns=ignore_patterns, else:
allow_patterns=allow_patterns, context = nullcontext()
max_workers=max_workers, with context:
progress_callbacks=progress_callbacks) return _snapshot_download(
repo_id,
repo_type=repo_type,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns,
max_workers=max_workers,
progress_callbacks=progress_callbacks)
def dataset_snapshot_download( def dataset_snapshot_download(
@@ -138,6 +158,7 @@ def dataset_snapshot_download(
allow_file_pattern: Optional[Union[str, List[str]]] = None, allow_file_pattern: Optional[Union[str, List[str]]] = None,
allow_patterns: Optional[Union[List[str], str]] = None, allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None,
enable_file_lock: Optional[bool] = None,
max_workers: int = 8, max_workers: int = 8,
) -> str: ) -> str:
"""Download raw files of a dataset. """Download raw files of a dataset.
@@ -171,6 +192,9 @@ def dataset_snapshot_download(
ignore_patterns (`str` or `List`, *optional*, default to `None`): ignore_patterns (`str` or `List`, *optional*, default to `None`):
If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern. If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
For hugging-face compatibility. For hugging-face compatibility.
enable_file_lock (`bool`): Enable file lock, this is useful in multiprocessing downloading, default `True`.
If you find something wrong with file lock and have a problem modifying your code,
change `MODELSCOPE_HUB_FILE_LOCK` env to `false`.
max_workers (`int`): The maximum number of workers to download files, default 8. max_workers (`int`): The maximum number of workers to download files, default 8.
Raises: Raises:
ValueError: the value details. ValueError: the value details.
@@ -187,20 +211,34 @@ def dataset_snapshot_download(
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
if some parameter value is invalid if some parameter value is invalid
""" """
return _snapshot_download( if enable_file_lock is None:
dataset_id, enable_file_lock = strtobool(
repo_type=REPO_TYPE_DATASET, os.environ.get('MODELSCOPE_HUB_FILE_LOCK', 'true'))
revision=revision,
cache_dir=cache_dir, if enable_file_lock:
user_agent=user_agent, system_cache = cache_dir if cache_dir is not None else get_modelscope_cache_dir(
local_files_only=local_files_only, )
cookies=cookies, os.makedirs(os.path.join(system_cache, '.lock'), exist_ok=True)
ignore_file_pattern=ignore_file_pattern, lock_file = os.path.join(system_cache, '.lock',
allow_file_pattern=allow_file_pattern, dataset_id.replace('/', '___'))
local_dir=local_dir, context = weak_file_lock(lock_file)
ignore_patterns=ignore_patterns, else:
allow_patterns=allow_patterns, context = nullcontext()
max_workers=max_workers) with context:
return _snapshot_download(
dataset_id,
repo_type=REPO_TYPE_DATASET,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
local_files_only=local_files_only,
cookies=cookies,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=allow_file_pattern,
local_dir=local_dir,
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns,
max_workers=max_workers)
def _snapshot_download( def _snapshot_download(

View File

@@ -1,12 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import contextlib
import hashlib import hashlib
import os import os
import sys import sys
import time import time
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import Generator, List, Optional, Union
from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
DEFAULT_MODELSCOPE_GROUP, DEFAULT_MODELSCOPE_GROUP,
@@ -242,3 +245,57 @@ def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
for row in rows: for row in rows:
lines.append(row_format.format(*row)) lines.append(row_format.format(*row))
return '\n'.join(lines) return '\n'.join(lines)
# Part of the code borrowed from the awesome work of huggingface_hub/transformers
def strtobool(val):
    """Convert a string representation of truth to ``1`` (true) or ``0`` (false).

    Matching is case-insensitive. True values are 'y', 'yes', 't', 'true',
    'on' and '1'; false values are 'n', 'no', 'f', 'false', 'off' and '0'.

    Raises:
        ValueError: if ``val`` is not a recognized truth value.
    """
    val = val.lower()
    truthy = ('y', 'yes', 't', 'true', 'on', '1')
    falsy = ('n', 'no', 'f', 'false', 'off', '0')
    if val in truthy:
        return 1
    if val in falsy:
        return 0
    raise ValueError(f'invalid truth value {val!r}')
@contextlib.contextmanager
def weak_file_lock(lock_file: Union[str, Path],
                   *,
                   timeout: Optional[float] = None
                   ) -> Generator[BaseFileLock, None, None]:
    """Context manager that holds an inter-process file lock on ``lock_file``.

    Acquisition is retried in 60-second slices so a long wait emits periodic
    log lines instead of blocking silently. If the filesystem does not
    support ``flock`` (filelock raises ``NotImplementedError`` suggesting
    ``SoftFileLock``), the lock transparently falls back to a soft
    (lock-file based) lock.

    Args:
        lock_file: Path of the lock file to acquire.
        timeout: Overall deadline in seconds. ``None`` (default) waits
            forever. When the deadline passes, ``filelock.Timeout`` is
            raised with the lock file path.

    Yields:
        BaseFileLock: the acquired lock object (a ``FileLock`` or, after
        fallback, a ``SoftFileLock``).

    Raises:
        Timeout: if ``timeout`` elapses before the lock is acquired.
    """
    # Re-log / re-check the deadline every 60 seconds while waiting.
    default_interval = 60
    lock = FileLock(lock_file, timeout=default_interval)
    start_time = time.time()
    while True:
        elapsed_time = time.time() - start_time
        if timeout is not None and elapsed_time >= timeout:
            raise Timeout(str(lock_file))
        try:
            # Wait at most one interval per attempt; a falsy timeout (0) is
            # unreachable here because the deadline check above fires first.
            lock.acquire(
                timeout=min(default_interval, timeout - elapsed_time)
                if timeout else default_interval)  # noqa
        except Timeout:
            # Interval expired without acquiring: log progress and retry.
            logger.info(
                f'Still waiting to acquire lock on {lock_file} (elapsed: {time.time() - start_time:.1f} seconds)'
            )
        except NotImplementedError as e:
            if 'use SoftFileLock instead' in str(e):
                logger.warning(
                    'FileSystem does not appear to support flock. Falling back to SoftFileLock for %s',
                    lock_file)
                lock = SoftFileLock(lock_file, timeout=default_interval)
                continue
            # NOTE(review): a NotImplementedError whose message does not
            # mention SoftFileLock is swallowed here and the loop retries the
            # same FileLock — confirm this is intended rather than re-raising.
        else:
            # Acquired successfully: leave the retry loop.
            break
    try:
        yield lock
    finally:
        # Best-effort release; if the lock file vanished or the OS refuses
        # the release, fall back to removing the lock file itself.
        try:
            lock.release()
        except OSError:
            try:
                Path(lock_file).unlink()
            except OSError:
                pass

View File

@@ -0,0 +1,79 @@
import hashlib
import multiprocessing
import os
import tempfile
import unittest
from modelscope import snapshot_download
def download_model(model_name, cache_dir, enable_lock):
    """Worker entry point: download one model snapshot into ``cache_dir``.

    When ``enable_lock`` is false, the ``MODELSCOPE_HUB_FILE_LOCK``
    environment variable is set to ``'false'`` so ``snapshot_download``
    skips the cross-process file lock.
    """
    if not enable_lock:
        # Opt out of the file lock via the documented environment switch.
        os.environ.update(MODELSCOPE_HUB_FILE_LOCK='false')
    snapshot_download(model_name, cache_dir=cache_dir)
class FileLockDownloadingTest(unittest.TestCase):
    """Exercise concurrent snapshot downloads with and without the file lock."""

    def setUp(self):
        # Fresh cache dir per test so downloads never hit a warm cache.
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    @staticmethod
    def _sha256_of(file_path):
        # Stream in 4KB chunks so large weight files are not read into memory.
        digest = hashlib.sha256()
        with open(file_path, 'rb') as stream:
            for chunk in iter(lambda: stream.read(4096), b''):
                digest.update(chunk)
        return digest.hexdigest()

    def _download_concurrently(self, model_id, enable_lock):
        # Three worker processes race to download the same repository.
        jobs = [(model_id, self.temp_dir.name, enable_lock)] * 3
        with multiprocessing.Pool(processes=3) as pool:
            pool.starmap(download_model, jobs)

    def test_multi_processing_file_lock(self):
        # With the lock enabled, concurrent workers must not corrupt files:
        # the downloaded weights end up with the expected checksum.
        self._download_concurrently(
            'iic/nlp_bert_relation-extraction_chinese-base', True)
        tensor_file = os.path.join(
            self.temp_dir.name, 'iic',
            'nlp_bert_relation-extraction_chinese-base', 'pytorch_model.bin')
        sha256 = '2b623d2c06c8101c1283657d35bc22d69bcc10f62ded0ba6d0606e4130f9c8af'
        self.assertTrue(self._sha256_of(tensor_file) == sha256)

    def test_multi_processing_disabled(self):
        # Without the lock, racing workers may interleave writes; the
        # checksum is then expected NOT to match. Corruption is not
        # guaranteed, so any failure here is deliberately swallowed.
        try:
            self._download_concurrently('iic/nlp_bert_backbone_base_std',
                                        False)
            tensor_file = os.path.join(self.temp_dir.name, 'iic',
                                       'nlp_bert_backbone_base_std',
                                       'pytorch_model.bin')
            sha256 = 'c6a293a8091f7eaa1ac7ecf88fd6f4cc00f6957188b2730d34faa787f15d3caa'
            self.assertTrue(self._sha256_of(tensor_file) != sha256)
        except Exception:  # noqa
            pass
# Allow running this test module directly (e.g. `python test_file_lock.py`).
if __name__ == '__main__':
    unittest.main()