diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index e64b2f67..a2e91d7e 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -4,17 +4,11 @@ import fnmatch import os import re import uuid +from contextlib import nullcontext from http.cookiejar import CookieJar from pathlib import Path from typing import Dict, List, Optional, Type, Union -from modelscope.hub.api import HubApi, ModelScopeConfig -from modelscope.hub.errors import InvalidParameter -from modelscope.hub.file_download import (create_temporary_directory_and_cache, - download_file, get_file_download_url) -from modelscope.hub.utils.caching import ModelFileSystemCache -from modelscope.hub.utils.utils import (get_model_masked_directory, - model_id_to_group_owner_name) from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DEFAULT_MODEL_REVISION, INTRA_CLOUD_ACCELERATION, @@ -23,7 +17,15 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, from modelscope.utils.file_utils import get_modelscope_cache_dir from modelscope.utils.logger import get_logger from modelscope.utils.thread_utils import thread_executor +from .api import HubApi, ModelScopeConfig from .callback import ProgressCallback +from .errors import InvalidParameter +from .file_download import (create_temporary_directory_and_cache, + download_file, get_file_download_url) +from .utils.caching import ModelFileSystemCache +from .utils.utils import (get_model_masked_directory, + model_id_to_group_owner_name, strtobool, + weak_file_lock) logger = get_logger() @@ -43,6 +45,7 @@ def snapshot_download( max_workers: int = 8, repo_id: str = None, repo_type: Optional[str] = REPO_TYPE_MODEL, + enable_file_lock: Optional[bool] = None, progress_callbacks: List[Type[ProgressCallback]] = None, ) -> str: """Download all files of a repo. @@ -79,6 +82,9 @@ def snapshot_download( If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern. For hugging-face compatibility. max_workers (`int`): The maximum number of workers to download files, default 8. + enable_file_lock (`bool`): Enable file lock, this is useful in multiprocessing downloading, default `True`. + If you find something wrong with file lock and have a problem modifying your code, + change `MODELSCOPE_HUB_FILE_LOCK` env to `false`. progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`): progress callbacks to track the download progress. Raises: @@ -109,21 +115,35 @@ def snapshot_download( if revision is None: revision = DEFAULT_DATASET_REVISION if repo_type == REPO_TYPE_DATASET else DEFAULT_MODEL_REVISION - return _snapshot_download( - repo_id, - repo_type=repo_type, - revision=revision, - cache_dir=cache_dir, - user_agent=user_agent, - local_files_only=local_files_only, - cookies=cookies, - ignore_file_pattern=ignore_file_pattern, - allow_file_pattern=allow_file_pattern, - local_dir=local_dir, - ignore_patterns=ignore_patterns, - allow_patterns=allow_patterns, - max_workers=max_workers, - progress_callbacks=progress_callbacks) + if enable_file_lock is None: + enable_file_lock = strtobool( + os.environ.get('MODELSCOPE_HUB_FILE_LOCK', 'true')) + + if enable_file_lock: + system_cache = cache_dir if cache_dir is not None else get_modelscope_cache_dir( + ) + os.makedirs(os.path.join(system_cache, '.lock'), exist_ok=True) + lock_file = os.path.join(system_cache, '.lock', + repo_id.replace('/', '___')) + context = weak_file_lock(lock_file) + else: + context = nullcontext() + with context: + return _snapshot_download( + repo_id, + repo_type=repo_type, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + local_files_only=local_files_only, + cookies=cookies, + ignore_file_pattern=ignore_file_pattern, + allow_file_pattern=allow_file_pattern, + local_dir=local_dir, + ignore_patterns=ignore_patterns, + allow_patterns=allow_patterns, + max_workers=max_workers, + progress_callbacks=progress_callbacks) def dataset_snapshot_download( @@ -138,6 +158,7 @@ def dataset_snapshot_download( allow_file_pattern: Optional[Union[str, List[str]]] = None, allow_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None, + enable_file_lock: Optional[bool] = None, max_workers: int = 8, ) -> str: """Download raw files of a dataset. @@ -171,6 +192,9 @@ def dataset_snapshot_download( ignore_patterns (`str` or `List`, *optional*, default to `None`): If provided, files matching any of the patterns are not downloaded, priority over ignore_file_pattern. For hugging-face compatibility. + enable_file_lock (`bool`): Enable file lock, this is useful in multiprocessing downloading, default `True`. + If you find something wrong with file lock and have a problem modifying your code, + change `MODELSCOPE_HUB_FILE_LOCK` env to `false`. max_workers (`int`): The maximum number of workers to download files, default 8. Raises: ValueError: the value details. @@ -187,20 +211,34 @@ def dataset_snapshot_download( - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if some parameter value is invalid """ - return _snapshot_download( - dataset_id, - repo_type=REPO_TYPE_DATASET, - revision=revision, - cache_dir=cache_dir, - user_agent=user_agent, - local_files_only=local_files_only, - cookies=cookies, - ignore_file_pattern=ignore_file_pattern, - allow_file_pattern=allow_file_pattern, - local_dir=local_dir, - ignore_patterns=ignore_patterns, - allow_patterns=allow_patterns, - max_workers=max_workers) + if enable_file_lock is None: + enable_file_lock = strtobool( + os.environ.get('MODELSCOPE_HUB_FILE_LOCK', 'true')) + + if enable_file_lock: + system_cache = cache_dir if cache_dir is not None else get_modelscope_cache_dir( + ) + os.makedirs(os.path.join(system_cache, '.lock'), exist_ok=True) + lock_file = os.path.join(system_cache, '.lock', + dataset_id.replace('/', '___')) + context = weak_file_lock(lock_file) + else: + context = nullcontext() + with context: + return _snapshot_download( + dataset_id, + repo_type=REPO_TYPE_DATASET, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + local_files_only=local_files_only, + cookies=cookies, + ignore_file_pattern=ignore_file_pattern, + allow_file_pattern=allow_file_pattern, + local_dir=local_dir, + ignore_patterns=ignore_patterns, + allow_patterns=allow_patterns, + max_workers=max_workers) def _snapshot_download( diff --git a/modelscope/hub/utils/utils.py b/modelscope/hub/utils/utils.py index 3c5ee67a..28bcdbf2 100644 --- a/modelscope/hub/utils/utils.py +++ b/modelscope/hub/utils/utils.py @@ -1,12 +1,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import contextlib import hashlib import os import sys import time from datetime import datetime from pathlib import Path -from typing import List, Optional, Union +from typing import Generator, List, Optional, Union + +from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, DEFAULT_MODELSCOPE_GROUP, @@ -242,3 +245,57 @@ def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: for row in rows: lines.append(row_format.format(*row)) return '\n'.join(lines) + + +# Part of the code borrowed from the awesome work of huggingface_hub/transformers +def strtobool(val): + val = val.lower() + if val in {'y', 'yes', 't', 'true', 'on', '1'}: + return 1 + if val in {'n', 'no', 'f', 'false', 'off', '0'}: + return 0 + raise ValueError(f'invalid truth value {val!r}') + + +@contextlib.contextmanager +def weak_file_lock(lock_file: Union[str, Path], + *, + timeout: Optional[float] = None + ) -> Generator[BaseFileLock, None, None]: + default_interval = 60 + lock = FileLock(lock_file, timeout=default_interval) + start_time = time.time() + + while True: + elapsed_time = time.time() - start_time + if timeout is not None and elapsed_time >= timeout: + raise Timeout(str(lock_file)) + + try: + lock.acquire( + timeout=min(default_interval, timeout - elapsed_time) + if timeout else default_interval) # noqa + except Timeout: + logger.info( + f'Still waiting to acquire lock on {lock_file} (elapsed: {time.time() - start_time:.1f} seconds)' + ) + except NotImplementedError as e: + if 'use SoftFileLock instead' in str(e): + logger.warning( + 'FileSystem does not appear to support flock. Falling back to SoftFileLock for %s', + lock_file) + lock = SoftFileLock(lock_file, timeout=default_interval) + continue + else: + break + + try: + yield lock + finally: + try: + lock.release() + except OSError: + try: + Path(lock_file).unlink() + except OSError: + pass diff --git a/tests/hub/test_download_file_lock.py b/tests/hub/test_download_file_lock.py new file mode 100644 index 00000000..4a6fe803 --- /dev/null +++ b/tests/hub/test_download_file_lock.py @@ -0,0 +1,79 @@ +import hashlib +import multiprocessing +import os +import tempfile +import unittest + +from modelscope import snapshot_download + + +def download_model(model_name, cache_dir, enable_lock): + if not enable_lock: + os.environ['MODELSCOPE_HUB_FILE_LOCK'] = 'false' + snapshot_download(model_name, cache_dir=cache_dir) + + +class FileLockDownloadingTest(unittest.TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + self.temp_dir.cleanup() + + def test_multi_processing_file_lock(self): + + models = [ + 'iic/nlp_bert_relation-extraction_chinese-base', + 'iic/nlp_bert_relation-extraction_chinese-base', + 'iic/nlp_bert_relation-extraction_chinese-base', + ] + args_list = [(model, self.temp_dir.name, True) for model in models] + + with multiprocessing.Pool(processes=3) as pool: + pool.starmap(download_model, args_list) + + def get_file_sha256(file_path): + sha256_hash = hashlib.sha256() + with open(file_path, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b''): + sha256_hash.update(chunk) + return sha256_hash.hexdigest() + + tensor_file = os.path.join( + self.temp_dir.name, 'iic', + 'nlp_bert_relation-extraction_chinese-base', 'pytorch_model.bin') + sha256 = '2b623d2c06c8101c1283657d35bc22d69bcc10f62ded0ba6d0606e4130f9c8af' + self.assertTrue(get_file_sha256(tensor_file) == sha256) + + def test_multi_processing_disabled(self): + try: + models = [ + 'iic/nlp_bert_backbone_base_std', + 'iic/nlp_bert_backbone_base_std', + 'iic/nlp_bert_backbone_base_std', + ] + args_list = [(model, self.temp_dir.name, False) + for model in models] + + with multiprocessing.Pool(processes=3) as pool: + pool.starmap(download_model, args_list) + + def get_file_sha256(file_path): + sha256_hash = hashlib.sha256() + with open(file_path, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b''): + sha256_hash.update(chunk) + return sha256_hash.hexdigest() + + tensor_file = os.path.join(self.temp_dir.name, 'iic', + 'nlp_bert_backbone_base_std', + 'pytorch_model.bin') + sha256 = 'c6a293a8091f7eaa1ac7ecf88fd6f4cc00f6957188b2730d34faa787f15d3caa' + self.assertTrue(get_file_sha256(tensor_file) != sha256) + except Exception: # noqa + pass + + +if __name__ == '__main__': + unittest.main()