Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-16 16:27:45 +01:00.
Commit: weak file lock (#1417). This commit is contained in:
@@ -4,17 +4,11 @@ import fnmatch
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from contextlib import nullcontext
|
||||
from http.cookiejar import CookieJar
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.errors import InvalidParameter
|
||||
from modelscope.hub.file_download import (create_temporary_directory_and_cache,
|
||||
download_file, get_file_download_url)
|
||||
from modelscope.hub.utils.caching import ModelFileSystemCache
|
||||
from modelscope.hub.utils.utils import (get_model_masked_directory,
|
||||
model_id_to_group_owner_name)
|
||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
DEFAULT_MODEL_REVISION,
|
||||
INTRA_CLOUD_ACCELERATION,
|
||||
@@ -23,7 +17,15 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
from modelscope.utils.file_utils import get_modelscope_cache_dir
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.thread_utils import thread_executor
|
||||
from .api import HubApi, ModelScopeConfig
|
||||
from .callback import ProgressCallback
|
||||
from .errors import InvalidParameter
|
||||
from .file_download import (create_temporary_directory_and_cache,
|
||||
download_file, get_file_download_url)
|
||||
from .utils.caching import ModelFileSystemCache
|
||||
from .utils.utils import (get_model_masked_directory,
|
||||
model_id_to_group_owner_name, strtobool,
|
||||
weak_file_lock)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -43,6 +45,7 @@ def snapshot_download(
|
||||
max_workers: int = 8,
|
||||
repo_id: str = None,
|
||||
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
||||
enable_file_lock: Optional[bool] = None,
|
||||
progress_callbacks: List[Type[ProgressCallback]] = None,
|
||||
) -> str:
|
||||
"""Download all files of a repo.
|
||||
@@ -79,6 +82,9 @@ def snapshot_download(
|
||||
If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
|
||||
For hugging-face compatibility.
|
||||
max_workers (`int`): The maximum number of workers to download files, default 8.
|
||||
enable_file_lock (`bool`): Enable file lock, this is useful in multiprocessing downloading, default `True`.
|
||||
If you find something wrong with file lock and have a problem modifying your code,
|
||||
change `MODELSCOPE_HUB_FILE_LOCK` env to `false`.
|
||||
progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`):
|
||||
progress callbacks to track the download progress.
|
||||
Raises:
|
||||
@@ -109,6 +115,20 @@ def snapshot_download(
|
||||
if revision is None:
|
||||
revision = DEFAULT_DATASET_REVISION if repo_type == REPO_TYPE_DATASET else DEFAULT_MODEL_REVISION
|
||||
|
||||
if enable_file_lock is None:
|
||||
enable_file_lock = strtobool(
|
||||
os.environ.get('MODELSCOPE_HUB_FILE_LOCK', 'true'))
|
||||
|
||||
if enable_file_lock:
|
||||
system_cache = cache_dir if cache_dir is not None else get_modelscope_cache_dir(
|
||||
)
|
||||
os.makedirs(os.path.join(system_cache, '.lock'), exist_ok=True)
|
||||
lock_file = os.path.join(system_cache, '.lock',
|
||||
repo_id.replace('/', '___'))
|
||||
context = weak_file_lock(lock_file)
|
||||
else:
|
||||
context = nullcontext()
|
||||
with context:
|
||||
return _snapshot_download(
|
||||
repo_id,
|
||||
repo_type=repo_type,
|
||||
@@ -138,6 +158,7 @@ def dataset_snapshot_download(
|
||||
allow_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
enable_file_lock: Optional[bool] = None,
|
||||
max_workers: int = 8,
|
||||
) -> str:
|
||||
"""Download raw files of a dataset.
|
||||
@@ -171,6 +192,9 @@ def dataset_snapshot_download(
|
||||
ignore_patterns (`str` or `List`, *optional*, default to `None`):
|
||||
If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
|
||||
For hugging-face compatibility.
|
||||
enable_file_lock (`bool`): Enable file lock, this is useful in multiprocessing downloading, default `True`.
|
||||
If you find something wrong with file lock and have a problem modifying your code,
|
||||
change `MODELSCOPE_HUB_FILE_LOCK` env to `false`.
|
||||
max_workers (`int`): The maximum number of workers to download files, default 8.
|
||||
Raises:
|
||||
ValueError: the value details.
|
||||
@@ -187,6 +211,20 @@ def dataset_snapshot_download(
|
||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
if enable_file_lock is None:
|
||||
enable_file_lock = strtobool(
|
||||
os.environ.get('MODELSCOPE_HUB_FILE_LOCK', 'true'))
|
||||
|
||||
if enable_file_lock:
|
||||
system_cache = cache_dir if cache_dir is not None else get_modelscope_cache_dir(
|
||||
)
|
||||
os.makedirs(os.path.join(system_cache, '.lock'), exist_ok=True)
|
||||
lock_file = os.path.join(system_cache, '.lock',
|
||||
dataset_id.replace('/', '___'))
|
||||
context = weak_file_lock(lock_file)
|
||||
else:
|
||||
context = nullcontext()
|
||||
with context:
|
||||
return _snapshot_download(
|
||||
dataset_id,
|
||||
repo_type=REPO_TYPE_DATASET,
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import Generator, List, Optional, Union
|
||||
|
||||
from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout
|
||||
|
||||
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
|
||||
DEFAULT_MODELSCOPE_GROUP,
|
||||
@@ -242,3 +245,57 @@ def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
|
||||
for row in rows:
|
||||
lines.append(row_format.format(*row))
|
||||
return '\n'.join(lines)
|
||||
|
||||
|
||||
# Part of the code borrowed from the awesome work of huggingface_hub/transformers
def strtobool(val):
    """Convert a truth-value string to ``1`` (truthy) or ``0`` (falsy).

    Matching is case-insensitive. Truthy values are 'y', 'yes', 't',
    'true', 'on', '1'; falsy values are 'n', 'no', 'f', 'false',
    'off', '0'.

    Raises:
        ValueError: if `val` matches neither set.
    """
    normalized = val.lower()
    if normalized in ('y', 'yes', 't', 'true', 'on', '1'):
        return 1
    elif normalized in ('n', 'no', 'f', 'false', 'off', '0'):
        return 0
    # Note: the lowered value is reported, matching the original behavior.
    raise ValueError(f'invalid truth value {normalized!r}')
|
||||
|
||||
|
||||
@contextlib.contextmanager
def weak_file_lock(lock_file: Union[str, Path],
                   *,
                   timeout: Optional[float] = None
                   ) -> Generator[BaseFileLock, None, None]:
    """Acquire an inter-process file lock on `lock_file`, yielding the lock.

    Acquisition is retried in `default_interval`-second slices so progress
    can be logged while waiting. If the filesystem does not support flock,
    falls back to a `SoftFileLock` (lock-file-existence based). On release,
    an `OSError` is handled by attempting to delete the lock file instead
    ("weak" semantics: never let lock cleanup raise to the caller).

    Args:
        lock_file: Path of the lock file to acquire.
        timeout: Overall time budget in seconds; `None` waits forever.

    Raises:
        Timeout: if the lock cannot be acquired within `timeout` seconds.
    """
    # Each individual acquire() attempt waits at most this many seconds,
    # so the "still waiting" log line fires at least once per interval.
    default_interval = 60
    lock = FileLock(lock_file, timeout=default_interval)
    start_time = time.time()

    while True:
        elapsed_time = time.time() - start_time
        # Enforce the caller's overall budget before starting another attempt.
        # Note: timeout=0 raises here immediately, before any acquire attempt.
        if timeout is not None and elapsed_time >= timeout:
            raise Timeout(str(lock_file))

        try:
            # Wait for the remaining budget, capped at default_interval.
            # NOTE(review): `if timeout` (truthiness) rather than
            # `is not None` — harmless here because timeout=0 already
            # raised above, but worth confirming if the guard changes.
            lock.acquire(
                timeout=min(default_interval, timeout - elapsed_time)
                if timeout else default_interval)  # noqa
        except Timeout:
            # Attempt-level timeout only: log progress and loop to retry.
            logger.info(
                f'Still waiting to acquire lock on {lock_file} (elapsed: {time.time() - start_time:.1f} seconds)'
            )
        except NotImplementedError as e:
            # filelock raises this when the FS lacks flock support (e.g. some
            # network filesystems); switch to the advisory SoftFileLock.
            if 'use SoftFileLock instead' in str(e):
                logger.warning(
                    'FileSystem does not appear to support flock. Falling back to SoftFileLock for %s',
                    lock_file)
                lock = SoftFileLock(lock_file, timeout=default_interval)
                continue
        else:
            # acquire() succeeded without raising: stop retrying.
            break

    try:
        yield lock
    finally:
        try:
            lock.release()
        except OSError:
            # Release failed (e.g. stale/removed lock file); best-effort
            # cleanup by unlinking the lock file, swallowing any error.
            try:
                Path(lock_file).unlink()
            except OSError:
                pass
|
||||
|
||||
79
tests/hub/test_download_file_lock.py
Normal file
79
tests/hub/test_download_file_lock.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import hashlib
|
||||
import multiprocessing
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from modelscope import snapshot_download
|
||||
|
||||
|
||||
def download_model(model_name, cache_dir, enable_lock):
    """Worker entry point: fetch `model_name` into `cache_dir`.

    When `enable_lock` is falsy, the `MODELSCOPE_HUB_FILE_LOCK` environment
    variable is set to 'false' so the download runs without the
    inter-process file lock.
    """
    lock_disabled = not enable_lock
    if lock_disabled:
        os.environ['MODELSCOPE_HUB_FILE_LOCK'] = 'false'
    snapshot_download(model_name, cache_dir=cache_dir)
|
||||
|
||||
|
||||
class FileLockDownloadingTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
def tearDown(self):
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
def test_multi_processing_file_lock(self):
|
||||
|
||||
models = [
|
||||
'iic/nlp_bert_relation-extraction_chinese-base',
|
||||
'iic/nlp_bert_relation-extraction_chinese-base',
|
||||
'iic/nlp_bert_relation-extraction_chinese-base',
|
||||
]
|
||||
args_list = [(model, self.temp_dir.name, True) for model in models]
|
||||
|
||||
with multiprocessing.Pool(processes=3) as pool:
|
||||
pool.starmap(download_model, args_list)
|
||||
|
||||
def get_file_sha256(file_path):
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, 'rb') as f:
|
||||
for chunk in iter(lambda: f.read(4096), b''):
|
||||
sha256_hash.update(chunk)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
tensor_file = os.path.join(
|
||||
self.temp_dir.name, 'iic',
|
||||
'nlp_bert_relation-extraction_chinese-base', 'pytorch_model.bin')
|
||||
sha256 = '2b623d2c06c8101c1283657d35bc22d69bcc10f62ded0ba6d0606e4130f9c8af'
|
||||
self.assertTrue(get_file_sha256(tensor_file) == sha256)
|
||||
|
||||
def test_multi_processing_disabled(self):
|
||||
try:
|
||||
models = [
|
||||
'iic/nlp_bert_backbone_base_std',
|
||||
'iic/nlp_bert_backbone_base_std',
|
||||
'iic/nlp_bert_backbone_base_std',
|
||||
]
|
||||
args_list = [(model, self.temp_dir.name, False)
|
||||
for model in models]
|
||||
|
||||
with multiprocessing.Pool(processes=3) as pool:
|
||||
pool.starmap(download_model, args_list)
|
||||
|
||||
def get_file_sha256(file_path):
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(file_path, 'rb') as f:
|
||||
for chunk in iter(lambda: f.read(4096), b''):
|
||||
sha256_hash.update(chunk)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
tensor_file = os.path.join(self.temp_dir.name, 'iic',
|
||||
'nlp_bert_backbone_base_std',
|
||||
'pytorch_model.bin')
|
||||
sha256 = 'c6a293a8091f7eaa1ac7ecf88fd6f4cc00f6957188b2730d34faa787f15d3caa'
|
||||
self.assertTrue(get_file_sha256(tensor_file) != sha256)
|
||||
except Exception: # noqa
|
||||
pass
|
||||
|
||||
|
||||
# Allow running this test module directly: `python test_download_file_lock.py`.
if __name__ == '__main__':
    unittest.main()
|
||||
Reference in New Issue
Block a user