mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
weak file lock (#1417)
This commit is contained in:
@@ -4,17 +4,11 @@ import fnmatch
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
|
from contextlib import nullcontext
|
||||||
from http.cookiejar import CookieJar
|
from http.cookiejar import CookieJar
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Type, Union
|
from typing import Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
|
||||||
from modelscope.hub.errors import InvalidParameter
|
|
||||||
from modelscope.hub.file_download import (create_temporary_directory_and_cache,
|
|
||||||
download_file, get_file_download_url)
|
|
||||||
from modelscope.hub.utils.caching import ModelFileSystemCache
|
|
||||||
from modelscope.hub.utils.utils import (get_model_masked_directory,
|
|
||||||
model_id_to_group_owner_name)
|
|
||||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||||
DEFAULT_MODEL_REVISION,
|
DEFAULT_MODEL_REVISION,
|
||||||
INTRA_CLOUD_ACCELERATION,
|
INTRA_CLOUD_ACCELERATION,
|
||||||
@@ -23,7 +17,15 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
|||||||
from modelscope.utils.file_utils import get_modelscope_cache_dir
|
from modelscope.utils.file_utils import get_modelscope_cache_dir
|
||||||
from modelscope.utils.logger import get_logger
|
from modelscope.utils.logger import get_logger
|
||||||
from modelscope.utils.thread_utils import thread_executor
|
from modelscope.utils.thread_utils import thread_executor
|
||||||
|
from .api import HubApi, ModelScopeConfig
|
||||||
from .callback import ProgressCallback
|
from .callback import ProgressCallback
|
||||||
|
from .errors import InvalidParameter
|
||||||
|
from .file_download import (create_temporary_directory_and_cache,
|
||||||
|
download_file, get_file_download_url)
|
||||||
|
from .utils.caching import ModelFileSystemCache
|
||||||
|
from .utils.utils import (get_model_masked_directory,
|
||||||
|
model_id_to_group_owner_name, strtobool,
|
||||||
|
weak_file_lock)
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
@@ -43,6 +45,7 @@ def snapshot_download(
|
|||||||
max_workers: int = 8,
|
max_workers: int = 8,
|
||||||
repo_id: str = None,
|
repo_id: str = None,
|
||||||
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
||||||
|
enable_file_lock: Optional[bool] = None,
|
||||||
progress_callbacks: List[Type[ProgressCallback]] = None,
|
progress_callbacks: List[Type[ProgressCallback]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Download all files of a repo.
|
"""Download all files of a repo.
|
||||||
@@ -79,6 +82,9 @@ def snapshot_download(
|
|||||||
If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
|
If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
|
||||||
For hugging-face compatibility.
|
For hugging-face compatibility.
|
||||||
max_workers (`int`): The maximum number of workers to download files, default 8.
|
max_workers (`int`): The maximum number of workers to download files, default 8.
|
||||||
|
enable_file_lock (`bool`): Enable file lock, this is useful in multiprocessing downloading, default `True`.
|
||||||
|
If you find something wrong with file lock and have a problem modifying your code,
|
||||||
|
change `MODELSCOPE_HUB_FILE_LOCK` env to `false`.
|
||||||
progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`):
|
progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`):
|
||||||
progress callbacks to track the download progress.
|
progress callbacks to track the download progress.
|
||||||
Raises:
|
Raises:
|
||||||
@@ -109,21 +115,35 @@ def snapshot_download(
|
|||||||
if revision is None:
|
if revision is None:
|
||||||
revision = DEFAULT_DATASET_REVISION if repo_type == REPO_TYPE_DATASET else DEFAULT_MODEL_REVISION
|
revision = DEFAULT_DATASET_REVISION if repo_type == REPO_TYPE_DATASET else DEFAULT_MODEL_REVISION
|
||||||
|
|
||||||
return _snapshot_download(
|
if enable_file_lock is None:
|
||||||
repo_id,
|
enable_file_lock = strtobool(
|
||||||
repo_type=repo_type,
|
os.environ.get('MODELSCOPE_HUB_FILE_LOCK', 'true'))
|
||||||
revision=revision,
|
|
||||||
cache_dir=cache_dir,
|
if enable_file_lock:
|
||||||
user_agent=user_agent,
|
system_cache = cache_dir if cache_dir is not None else get_modelscope_cache_dir(
|
||||||
local_files_only=local_files_only,
|
)
|
||||||
cookies=cookies,
|
os.makedirs(os.path.join(system_cache, '.lock'), exist_ok=True)
|
||||||
ignore_file_pattern=ignore_file_pattern,
|
lock_file = os.path.join(system_cache, '.lock',
|
||||||
allow_file_pattern=allow_file_pattern,
|
repo_id.replace('/', '___'))
|
||||||
local_dir=local_dir,
|
context = weak_file_lock(lock_file)
|
||||||
ignore_patterns=ignore_patterns,
|
else:
|
||||||
allow_patterns=allow_patterns,
|
context = nullcontext()
|
||||||
max_workers=max_workers,
|
with context:
|
||||||
progress_callbacks=progress_callbacks)
|
return _snapshot_download(
|
||||||
|
repo_id,
|
||||||
|
repo_type=repo_type,
|
||||||
|
revision=revision,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
user_agent=user_agent,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
cookies=cookies,
|
||||||
|
ignore_file_pattern=ignore_file_pattern,
|
||||||
|
allow_file_pattern=allow_file_pattern,
|
||||||
|
local_dir=local_dir,
|
||||||
|
ignore_patterns=ignore_patterns,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
max_workers=max_workers,
|
||||||
|
progress_callbacks=progress_callbacks)
|
||||||
|
|
||||||
|
|
||||||
def dataset_snapshot_download(
|
def dataset_snapshot_download(
|
||||||
@@ -138,6 +158,7 @@ def dataset_snapshot_download(
|
|||||||
allow_file_pattern: Optional[Union[str, List[str]]] = None,
|
allow_file_pattern: Optional[Union[str, List[str]]] = None,
|
||||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||||
|
enable_file_lock: Optional[bool] = None,
|
||||||
max_workers: int = 8,
|
max_workers: int = 8,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Download raw files of a dataset.
|
"""Download raw files of a dataset.
|
||||||
@@ -171,6 +192,9 @@ def dataset_snapshot_download(
|
|||||||
ignore_patterns (`str` or `List`, *optional*, default to `None`):
|
ignore_patterns (`str` or `List`, *optional*, default to `None`):
|
||||||
If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
|
If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern.
|
||||||
For hugging-face compatibility.
|
For hugging-face compatibility.
|
||||||
|
enable_file_lock (`bool`): Enable file lock, this is useful in multiprocessing downloading, default `True`.
|
||||||
|
If you find something wrong with file lock and have a problem modifying your code,
|
||||||
|
change `MODELSCOPE_HUB_FILE_LOCK` env to `false`.
|
||||||
max_workers (`int`): The maximum number of workers to download files, default 8.
|
max_workers (`int`): The maximum number of workers to download files, default 8.
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: the value details.
|
ValueError: the value details.
|
||||||
@@ -187,20 +211,34 @@ def dataset_snapshot_download(
|
|||||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||||
if some parameter value is invalid
|
if some parameter value is invalid
|
||||||
"""
|
"""
|
||||||
return _snapshot_download(
|
if enable_file_lock is None:
|
||||||
dataset_id,
|
enable_file_lock = strtobool(
|
||||||
repo_type=REPO_TYPE_DATASET,
|
os.environ.get('MODELSCOPE_HUB_FILE_LOCK', 'true'))
|
||||||
revision=revision,
|
|
||||||
cache_dir=cache_dir,
|
if enable_file_lock:
|
||||||
user_agent=user_agent,
|
system_cache = cache_dir if cache_dir is not None else get_modelscope_cache_dir(
|
||||||
local_files_only=local_files_only,
|
)
|
||||||
cookies=cookies,
|
os.makedirs(os.path.join(system_cache, '.lock'), exist_ok=True)
|
||||||
ignore_file_pattern=ignore_file_pattern,
|
lock_file = os.path.join(system_cache, '.lock',
|
||||||
allow_file_pattern=allow_file_pattern,
|
dataset_id.replace('/', '___'))
|
||||||
local_dir=local_dir,
|
context = weak_file_lock(lock_file)
|
||||||
ignore_patterns=ignore_patterns,
|
else:
|
||||||
allow_patterns=allow_patterns,
|
context = nullcontext()
|
||||||
max_workers=max_workers)
|
with context:
|
||||||
|
return _snapshot_download(
|
||||||
|
dataset_id,
|
||||||
|
repo_type=REPO_TYPE_DATASET,
|
||||||
|
revision=revision,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
user_agent=user_agent,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
cookies=cookies,
|
||||||
|
ignore_file_pattern=ignore_file_pattern,
|
||||||
|
allow_file_pattern=allow_file_pattern,
|
||||||
|
local_dir=local_dir,
|
||||||
|
ignore_patterns=ignore_patterns,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
max_workers=max_workers)
|
||||||
|
|
||||||
|
|
||||||
def _snapshot_download(
|
def _snapshot_download(
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Union
|
from typing import Generator, List, Optional, Union
|
||||||
|
|
||||||
|
from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout
|
||||||
|
|
||||||
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
|
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
|
||||||
DEFAULT_MODELSCOPE_GROUP,
|
DEFAULT_MODELSCOPE_GROUP,
|
||||||
@@ -242,3 +245,57 @@ def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
|
|||||||
for row in rows:
|
for row in rows:
|
||||||
lines.append(row_format.format(*row))
|
lines.append(row_format.format(*row))
|
||||||
return '\n'.join(lines)
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
# Part of the code borrowed from the awesome work of huggingface_hub/transformers
|
||||||
|
def strtobool(val):
|
||||||
|
val = val.lower()
|
||||||
|
if val in {'y', 'yes', 't', 'true', 'on', '1'}:
|
||||||
|
return 1
|
||||||
|
if val in {'n', 'no', 'f', 'false', 'off', '0'}:
|
||||||
|
return 0
|
||||||
|
raise ValueError(f'invalid truth value {val!r}')
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def weak_file_lock(lock_file: Union[str, Path],
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None
|
||||||
|
) -> Generator[BaseFileLock, None, None]:
|
||||||
|
default_interval = 60
|
||||||
|
lock = FileLock(lock_file, timeout=default_interval)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
if timeout is not None and elapsed_time >= timeout:
|
||||||
|
raise Timeout(str(lock_file))
|
||||||
|
|
||||||
|
try:
|
||||||
|
lock.acquire(
|
||||||
|
timeout=min(default_interval, timeout - elapsed_time)
|
||||||
|
if timeout else default_interval) # noqa
|
||||||
|
except Timeout:
|
||||||
|
logger.info(
|
||||||
|
f'Still waiting to acquire lock on {lock_file} (elapsed: {time.time() - start_time:.1f} seconds)'
|
||||||
|
)
|
||||||
|
except NotImplementedError as e:
|
||||||
|
if 'use SoftFileLock instead' in str(e):
|
||||||
|
logger.warning(
|
||||||
|
'FileSystem does not appear to support flock. Falling back to SoftFileLock for %s',
|
||||||
|
lock_file)
|
||||||
|
lock = SoftFileLock(lock_file, timeout=default_interval)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield lock
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
lock.release()
|
||||||
|
except OSError:
|
||||||
|
try:
|
||||||
|
Path(lock_file).unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|||||||
79
tests/hub/test_download_file_lock.py
Normal file
79
tests/hub/test_download_file_lock.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import hashlib
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
def download_model(model_name, cache_dir, enable_lock):
|
||||||
|
if not enable_lock:
|
||||||
|
os.environ['MODELSCOPE_HUB_FILE_LOCK'] = 'false'
|
||||||
|
snapshot_download(model_name, cache_dir=cache_dir)
|
||||||
|
|
||||||
|
|
||||||
|
class FileLockDownloadingTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.temp_dir = tempfile.TemporaryDirectory()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.temp_dir.cleanup()
|
||||||
|
|
||||||
|
def test_multi_processing_file_lock(self):
|
||||||
|
|
||||||
|
models = [
|
||||||
|
'iic/nlp_bert_relation-extraction_chinese-base',
|
||||||
|
'iic/nlp_bert_relation-extraction_chinese-base',
|
||||||
|
'iic/nlp_bert_relation-extraction_chinese-base',
|
||||||
|
]
|
||||||
|
args_list = [(model, self.temp_dir.name, True) for model in models]
|
||||||
|
|
||||||
|
with multiprocessing.Pool(processes=3) as pool:
|
||||||
|
pool.starmap(download_model, args_list)
|
||||||
|
|
||||||
|
def get_file_sha256(file_path):
|
||||||
|
sha256_hash = hashlib.sha256()
|
||||||
|
with open(file_path, 'rb') as f:
|
||||||
|
for chunk in iter(lambda: f.read(4096), b''):
|
||||||
|
sha256_hash.update(chunk)
|
||||||
|
return sha256_hash.hexdigest()
|
||||||
|
|
||||||
|
tensor_file = os.path.join(
|
||||||
|
self.temp_dir.name, 'iic',
|
||||||
|
'nlp_bert_relation-extraction_chinese-base', 'pytorch_model.bin')
|
||||||
|
sha256 = '2b623d2c06c8101c1283657d35bc22d69bcc10f62ded0ba6d0606e4130f9c8af'
|
||||||
|
self.assertTrue(get_file_sha256(tensor_file) == sha256)
|
||||||
|
|
||||||
|
def test_multi_processing_disabled(self):
|
||||||
|
try:
|
||||||
|
models = [
|
||||||
|
'iic/nlp_bert_backbone_base_std',
|
||||||
|
'iic/nlp_bert_backbone_base_std',
|
||||||
|
'iic/nlp_bert_backbone_base_std',
|
||||||
|
]
|
||||||
|
args_list = [(model, self.temp_dir.name, False)
|
||||||
|
for model in models]
|
||||||
|
|
||||||
|
with multiprocessing.Pool(processes=3) as pool:
|
||||||
|
pool.starmap(download_model, args_list)
|
||||||
|
|
||||||
|
def get_file_sha256(file_path):
|
||||||
|
sha256_hash = hashlib.sha256()
|
||||||
|
with open(file_path, 'rb') as f:
|
||||||
|
for chunk in iter(lambda: f.read(4096), b''):
|
||||||
|
sha256_hash.update(chunk)
|
||||||
|
return sha256_hash.hexdigest()
|
||||||
|
|
||||||
|
tensor_file = os.path.join(self.temp_dir.name, 'iic',
|
||||||
|
'nlp_bert_backbone_base_std',
|
||||||
|
'pytorch_model.bin')
|
||||||
|
sha256 = 'c6a293a8091f7eaa1ac7ecf88fd6f4cc00f6957188b2730d34faa787f15d3caa'
|
||||||
|
self.assertTrue(get_file_sha256(tensor_file) != sha256)
|
||||||
|
except Exception: # noqa
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user