From 4d65fe0cfa40ab112768184d69138ca7ae8e36a8 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 10 Jan 2025 15:47:44 +0800 Subject: [PATCH] fix cache --- modelscope/hub/snapshot_download.py | 10 +++++----- modelscope/utils/config_ds.py | 7 +++---- modelscope/utils/file_utils.py | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 4510280b..bc69470a 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -21,6 +21,7 @@ from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DEFAULT_MODEL_REVISION, REPO_TYPE_DATASET, REPO_TYPE_MODEL, REPO_TYPE_SUPPORT) +from modelscope.utils.file_utils import get_modelscope_cache_dir from modelscope.utils.logger import get_logger from modelscope.utils.thread_utils import thread_executor @@ -204,9 +205,8 @@ def _snapshot_download( temporary_cache_dir, cache = create_temporary_directory_and_cache( repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type) - system_cache = cache_dir if cache_dir is not None else os.getenv( - 'MODELSCOPE_CACHE', - Path.home().joinpath('.cache', 'modelscope', 'hub')) + system_cache = cache_dir if cache_dir is not None else get_modelscope_cache_dir( + ) if local_files_only: if len(cache.cached_files) == 0: raise ValueError( @@ -233,7 +233,7 @@ def _snapshot_download( if repo_type == REPO_TYPE_MODEL: directory = os.path.abspath( local_dir) if local_dir is not None else os.path.join( - system_cache, repo_id) + system_cache, 'hub', repo_id) print(f'Downloading Model to directory: {directory}') revision_detail = _api.get_valid_revision_detail( repo_id, revision=revision, cookies=cookies) @@ -294,7 +294,7 @@ def _snapshot_download( elif repo_type == REPO_TYPE_DATASET: directory = os.path.abspath( local_dir) if local_dir else os.path.join( - system_cache, 'datasets', repo_id) + system_cache, 'hub', 'datasets', repo_id) print(f'Downloading Dataset to directory: {directory}') group_or_owner, name = model_id_to_group_owner_name(repo_id) diff --git a/modelscope/utils/config_ds.py b/modelscope/utils/config_ds.py index 72a25887..87551e1b 100644 --- a/modelscope/utils/config_ds.py +++ b/modelscope/utils/config_ds.py @@ -5,13 +5,12 @@ from pathlib import Path # Cache location from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT -from modelscope.utils.file_utils import get_modelscope_cache_dir +from modelscope.utils.file_utils import get_dataset_cache_root, get_modelscope_cache_dir MS_CACHE_HOME = get_modelscope_cache_dir() -DEFAULT_MS_DATASETS_CACHE = os.path.join(MS_CACHE_HOME, 'hub', 'datasets') -MS_DATASETS_CACHE = Path( - os.getenv('MS_DATASETS_CACHE', DEFAULT_MS_DATASETS_CACHE)) +# NOTE: remove `MS_DATASETS_CACHE` env, default is `{MODELSCOPE_CACHE}/hub/datasets` +MS_DATASETS_CACHE = get_dataset_cache_root() DOWNLOADED_DATASETS_DIR = 'downloads' DEFAULT_DOWNLOADED_DATASETS_PATH = os.path.join(MS_DATASETS_CACHE, diff --git a/modelscope/utils/file_utils.py b/modelscope/utils/file_utils.py index c00e8d26..a334f89e 100644 --- a/modelscope/utils/file_utils.py +++ b/modelscope/utils/file_utils.py @@ -64,7 +64,7 @@ def get_dataset_cache_root() -> str: Returns: str: the modelscope dataset raw file cache root. """ - return os.path.join(get_modelscope_cache_dir(), 'datasets') + return os.path.join(get_modelscope_cache_dir(), 'hub', 'datasets') def get_dataset_cache_dir(dataset_id: str) -> str: