refactor cache path

This commit is contained in:
mulin.lyh
2024-05-23 15:08:17 +08:00
parent 7fbc411605
commit 49fe172c5f
12 changed files with 55 additions and 48 deletions

View File

@@ -267,6 +267,8 @@ class HubApi:
This function must be called before calling HubApi's login with a valid token
which can be obtained from ModelScope's website.
If any error, please upload via git commands.
Args:
model_id (str):
The model id to be uploaded, caller must have write permission for it.
@@ -336,7 +338,7 @@ class HubApi:
git_wrapper = GitCommandWrapper()
try:
repo = Repository(model_dir=tmp_dir, clone_from=model_id)
branches, _ = self.get_model_branches_and_tags(model_id=model_id, use_cookies=cookies)
branches = git_wrapper.get_remote_branches(tmp_dir)
if revision not in branches:
logger.info('Create new branch %s' % revision)
git_wrapper.new_branch(tmp_dir, revision)

View File

@@ -21,11 +21,12 @@ from modelscope.hub.constants import (
API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.utils.logger import get_logger
from .errors import FileDownloadError, NotExistError
from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_endpoint,
get_model_cache_dir, model_id_to_group_owner_name)
model_id_to_group_owner_name)
logger = get_logger()
@@ -75,7 +76,7 @@ def model_file_download(
if some parameter value is invalid
"""
if cache_dir is None:
cache_dir = get_model_cache_dir()
cache_dir = get_model_cache_root()
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
temporary_cache_dir = os.path.join(cache_dir, 'temp')

View File

@@ -45,8 +45,9 @@ class GitCommandWrapper(metaclass=Singleton):
logger.debug(' '.join(args))
git_env = os.environ.copy()
git_env['GIT_TERMINAL_PROMPT'] = '0'
command = [self.git_path, *args]
response = subprocess.run(
[self.git_path, *args],
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=git_env,
@@ -55,10 +56,11 @@ class GitCommandWrapper(metaclass=Singleton):
response.check_returncode()
return response
except subprocess.CalledProcessError as error:
logger.error('There are error run git command.')
raise GitError(
'stdout: %s, stderr: %s' %
(response.stdout.decode('utf8'), error.stderr.decode('utf8')))
output = 'stdout: %s, stderr: %s' % (
response.stdout.decode('utf8'), error.stderr.decode('utf8'))
logger.error('Running git command: %s failed, output: %s.' %
(command, output))
raise GitError(output)
def config_auth_token(self, repo_dir, auth_token):
url = self.get_repo_remote_url(repo_dir)

View File

@@ -9,13 +9,14 @@ from typing import Dict, List, Optional, Union
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.utils.logger import get_logger
from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
from .file_download import (get_file_download_url, http_get_file,
parallel_download)
from .utils.caching import ModelFileSystemCache
from .utils.utils import (file_integrity_validation, get_model_cache_dir,
from .utils.utils import (file_integrity_validation,
model_id_to_group_owner_name)
logger = get_logger()
@@ -65,7 +66,7 @@ def snapshot_download(model_id: str,
"""
if cache_dir is None:
cache_dir = get_model_cache_dir()
cache_dir = get_model_cache_root()
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
temporary_cache_dir = os.path.join(cache_dir, 'temp')

View File

@@ -28,23 +28,6 @@ def model_id_to_group_owner_name(model_id):
return group_or_owner, name
def get_model_cache_dir(model_id: Optional[str] = None):
"""cache dir precedence:
function parameter > environment > ~/.cache/modelscope/hub
Args:
model_id (str, optional): The model id.
Returns:
str: the model_id dir if model_id not None, otherwise cache root dir.
"""
base_path = os.getenv('MODELSCOPE_CACHE',
get_default_modelscope_cache_dir())
base_path = os.path.join(base_path, 'hub')
return base_path if model_id is None else os.path.join(
base_path, model_id + '/')
def get_release_datetime():
if MODELSCOPE_SDK_DEBUG in os.environ:
rt = int(round(datetime.now().timestamp()))

View File

@@ -14,11 +14,12 @@ import gast
import json
from modelscope.fileio.file import LocalStorage
# do not delete
from modelscope.metainfo import (CustomDatasets, Heads, Hooks, LR_Schedulers,
Metrics, Models, Optimizers, Pipelines,
Preprocessors, TaskModels, Trainers)
from modelscope.utils.constant import Fields, Tasks
from modelscope.utils.file_utils import get_default_modelscope_cache_dir
from modelscope.utils.file_utils import get_modelscope_cache_dir
from modelscope.utils.logger import get_logger
from modelscope.utils.registry import default_group
@@ -29,7 +30,7 @@ p = Path(__file__)
# get the path of package 'modelscope'
SKIP_FUNCTION_SCANNING = True
MODELSCOPE_PATH = p.resolve().parents[1]
INDEXER_FILE_DIR = get_default_modelscope_cache_dir()
INDEXER_FILE_DIR = get_modelscope_cache_dir()
REGISTER_MODULE = 'register_module'
IGNORED_PACKAGES = ['modelscope', '.']
SCAN_SUB_FOLDERS = [

View File

@@ -1,7 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import re
import shutil
import struct
import sys
import tempfile
@@ -11,7 +10,7 @@ from urllib.parse import urlparse
import numpy as np
from modelscope.fileio.file import HTTPStorage
from modelscope.hub.utils.utils import get_model_cache_dir
from modelscope.utils.file_utils import get_model_cache_root
from modelscope.utils.hub import snapshot_download
from modelscope.utils.logger import get_logger
@@ -334,7 +333,7 @@ def update_local_model(model_config, model_path, extra_args):
model_revision = extra_args['update_model']
if model_config.__contains__('model'):
model_name = model_config['model']
dst_dir_root = get_model_cache_dir()
dst_dir_root = get_model_cache_root()
if isinstance(model_path, str) and os.path.exists(
model_path) and not model_path.startswith(dst_dir_root):
try:

View File

@@ -1,13 +1,9 @@
import argparse
import os
import traceback
from typing import List, Union
import json
from modelscope.hub.api import HubApi
from modelscope.hub.file_download import model_file_download
from modelscope.hub.utils.utils import get_model_cache_dir
from modelscope.pipelines import pipeline
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile

View File

@@ -39,7 +39,7 @@ def get_default_modelscope_cache_dir():
return default_cache_dir
def get_modelscope_cache_dir():
def get_modelscope_cache_dir() -> str:
"""Get modelscope cache dir, default location or
setting with MODELSCOPE_CACHE
@@ -49,6 +49,30 @@ def get_modelscope_cache_dir():
return os.getenv('MODELSCOPE_CACHE', get_default_modelscope_cache_dir())
def get_model_cache_root() -> str:
"""Get model cache root path.
Returns:
str: the modelscope cache root.
"""
return os.path.join(get_modelscope_cache_dir(), 'hub')
def get_model_cache_dir(model_id: str) -> str:
"""cache dir precedence:
function parameter > environment > ~/.cache/modelscope/hub/model_id
Args:
model_id (str, optional): The model id.
Returns:
str: the model_id dir if model_id not None, otherwise cache root dir.
"""
root_path = get_model_cache_root()
return root_path if model_id is None else os.path.join(
root_path, model_id + '/')
def read_file(path):
with open(path, 'r') as f:

View File

@@ -4,10 +4,10 @@ import json
from modelscope.hub.api import HubApi
from modelscope.hub.file_download import model_file_download
from modelscope.hub.utils.utils import get_model_cache_dir
from modelscope.pipelines import pipeline
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.file_utils import get_model_cache_dir
from modelscope.utils.input_output import (
call_pipeline_with_json, get_pipeline_information_by_pipeline,
get_task_input_examples, pipeline_output_to_service_base64_output)
@@ -20,9 +20,8 @@ class ModelJsonTest:
def test_single(self, model_id: str, model_revision=None):
# get model_revision & task info
cache_root = get_model_cache_dir()
configuration_file = os.path.join(cache_root, model_id,
ModelFile.CONFIGURATION)
configuration_file = os.path.join(
get_model_cache_dir(model_id), ModelFile.CONFIGURATION)
if not model_revision:
model_revision = self.api.list_model_revisions(
model_id=model_id)[0]

View File

@@ -12,10 +12,10 @@ from utils.source_file_analyzer import (get_all_register_modules,
from modelscope.hub.api import HubApi
from modelscope.hub.file_download import model_file_download
from modelscope.hub.utils.utils import (get_model_cache_dir,
model_id_to_group_owner_name)
from modelscope.hub.utils.utils import model_id_to_group_owner_name
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.file_utils import get_model_cache_dir
from modelscope.utils.logger import get_logger
logger = get_logger()
@@ -36,12 +36,11 @@ def get_models_info(groups: list) -> dict:
if len(models) >= total_count:
break
page += 1
cache_root = get_model_cache_dir()
models_info = {} # key model id, value model info
for model_info in models:
model_id = '%s/%s' % (group, model_info['Name'])
configuration_file = os.path.join(cache_root, model_id,
ModelFile.CONFIGURATION)
configuration_file = os.path.join(
get_model_cache_dir(model_id), ModelFile.CONFIGURATION)
if not os.path.exists(configuration_file):
try:
model_revisions = api.list_model_revisions(model_id=model_id)

View File

@@ -6,11 +6,11 @@ import sys
import tempfile
import unittest
from modelscope.hub.utils.utils import get_model_cache_dir
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode
from modelscope.utils.file_utils import get_model_cache_dir
from modelscope.utils.test_utils import test_level
@@ -57,7 +57,7 @@ class TestImageDefrcnFewShotTrainer(unittest.TestCase):
cfg.model.roi_heads.freeze_feat = False
cfg.model.roi_heads.cls_dropout = False
cfg.model.weights = os.path.join(
get_model_cache_dir(), self.model_id,
get_model_cache_dir(self.model_id),
'ImageNetPretrained/MSRA/R-101.pkl')
cfg.datasets.root = self.data_dir