mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
fix #845
This commit is contained in:
@@ -336,7 +336,7 @@ class HubApi:
|
||||
git_wrapper = GitCommandWrapper()
|
||||
try:
|
||||
repo = Repository(model_dir=tmp_dir, clone_from=model_id)
|
||||
branches = git_wrapper.get_remote_branches(tmp_dir)
|
||||
branches, _ = self.get_model_branches_and_tags(model_id=model_id, use_cookies=cookies)
|
||||
if revision not in branches:
|
||||
logger.info('Create new branch %s' % revision)
|
||||
git_wrapper.new_branch(tmp_dir, revision)
|
||||
|
||||
@@ -24,8 +24,8 @@ from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .errors import FileDownloadError, NotExistError
|
||||
from .utils.caching import ModelFileSystemCache
|
||||
from .utils.utils import (file_integrity_validation, get_cache_dir,
|
||||
get_endpoint, model_id_to_group_owner_name)
|
||||
from .utils.utils import (file_integrity_validation, get_endpoint,
|
||||
get_model_cache_dir, model_id_to_group_owner_name)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -75,7 +75,7 @@ def model_file_download(
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = get_cache_dir()
|
||||
cache_dir = get_model_cache_dir()
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
temporary_cache_dir = os.path.join(cache_dir, 'temp')
|
||||
|
||||
@@ -15,7 +15,7 @@ from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
|
||||
from .file_download import (get_file_download_url, http_get_file,
|
||||
parallel_download)
|
||||
from .utils.caching import ModelFileSystemCache
|
||||
from .utils.utils import (file_integrity_validation, get_cache_dir,
|
||||
from .utils.utils import (file_integrity_validation, get_model_cache_dir,
|
||||
model_id_to_group_owner_name)
|
||||
|
||||
logger = get_logger()
|
||||
@@ -65,7 +65,7 @@ def snapshot_download(model_id: str,
|
||||
"""
|
||||
|
||||
if cache_dir is None:
|
||||
cache_dir = get_cache_dir()
|
||||
cache_dir = get_model_cache_dir()
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
temporary_cache_dir = os.path.join(cache_dir, 'temp')
|
||||
|
||||
@@ -12,7 +12,7 @@ from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
|
||||
MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG,
|
||||
MODELSCOPE_URL_SCHEME)
|
||||
from modelscope.hub.errors import FileIntegrityError
|
||||
from modelscope.utils.file_utils import get_default_cache_dir
|
||||
from modelscope.utils.file_utils import get_default_modelscope_cache_dir
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
@@ -28,7 +28,7 @@ def model_id_to_group_owner_name(model_id):
|
||||
return group_or_owner, name
|
||||
|
||||
|
||||
def get_cache_dir(model_id: Optional[str] = None):
|
||||
def get_model_cache_dir(model_id: Optional[str] = None):
|
||||
"""cache dir precedence:
|
||||
function parameter > environment > ~/.cache/modelscope/hub
|
||||
|
||||
@@ -38,9 +38,9 @@ def get_cache_dir(model_id: Optional[str] = None):
|
||||
Returns:
|
||||
str: the model_id dir if model_id not None, otherwise cache root dir.
|
||||
"""
|
||||
default_cache_dir = get_default_cache_dir()
|
||||
base_path = os.getenv('MODELSCOPE_CACHE',
|
||||
os.path.join(default_cache_dir, 'hub'))
|
||||
get_default_modelscope_cache_dir())
|
||||
base_path = os.path.join(base_path, 'hub')
|
||||
return base_path if model_id is None else os.path.join(
|
||||
base_path, model_id + '/')
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from modelscope.metainfo import (CustomDatasets, Heads, Hooks, LR_Schedulers,
|
||||
Metrics, Models, Optimizers, Pipelines,
|
||||
Preprocessors, TaskModels, Trainers)
|
||||
from modelscope.utils.constant import Fields, Tasks
|
||||
from modelscope.utils.file_utils import get_default_cache_dir
|
||||
from modelscope.utils.file_utils import get_default_modelscope_cache_dir
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.registry import default_group
|
||||
|
||||
@@ -29,7 +29,7 @@ p = Path(__file__)
|
||||
# get the path of package 'modelscope'
|
||||
SKIP_FUNCTION_SCANNING = True
|
||||
MODELSCOPE_PATH = p.resolve().parents[1]
|
||||
INDEXER_FILE_DIR = get_default_cache_dir()
|
||||
INDEXER_FILE_DIR = get_default_modelscope_cache_dir()
|
||||
REGISTER_MODULE = 'register_module'
|
||||
IGNORED_PACKAGES = ['modelscope', '.']
|
||||
SCAN_SUB_FOLDERS = [
|
||||
|
||||
@@ -11,7 +11,7 @@ from urllib.parse import urlparse
|
||||
import numpy as np
|
||||
|
||||
from modelscope.fileio.file import HTTPStorage
|
||||
from modelscope.hub.utils.utils import get_cache_dir
|
||||
from modelscope.hub.utils.utils import get_model_cache_dir
|
||||
from modelscope.utils.hub import snapshot_download
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -334,7 +334,7 @@ def update_local_model(model_config, model_path, extra_args):
|
||||
model_revision = extra_args['update_model']
|
||||
if model_config.__contains__('model'):
|
||||
model_name = model_config['model']
|
||||
dst_dir_root = get_cache_dir()
|
||||
dst_dir_root = get_model_cache_dir()
|
||||
if isinstance(model_path, str) and os.path.exists(
|
||||
model_path) and not model_path.startswith(dst_dir_root):
|
||||
try:
|
||||
|
||||
@@ -5,14 +5,13 @@ from pathlib import Path
|
||||
|
||||
# Cache location
|
||||
from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT
|
||||
from modelscope.utils.file_utils import get_default_modelscope_cache_dir
|
||||
|
||||
DEFAULT_CACHE_HOME = Path.home().joinpath('.cache')
|
||||
CACHE_HOME = os.getenv('CACHE_HOME', DEFAULT_CACHE_HOME)
|
||||
DEFAULT_MS_CACHE_HOME = os.path.join(CACHE_HOME, 'modelscope', 'hub')
|
||||
DEFAULT_MS_CACHE_HOME = get_default_modelscope_cache_dir()
|
||||
MS_CACHE_HOME = os.path.expanduser(
|
||||
os.getenv('MS_CACHE_HOME', DEFAULT_MS_CACHE_HOME))
|
||||
os.getenv('MODELSCOPE_CACHE', DEFAULT_MS_CACHE_HOME))
|
||||
|
||||
DEFAULT_MS_DATASETS_CACHE = os.path.join(MS_CACHE_HOME, 'datasets')
|
||||
DEFAULT_MS_DATASETS_CACHE = os.path.join(MS_CACHE_HOME, 'hub', 'datasets')
|
||||
MS_DATASETS_CACHE = Path(
|
||||
os.getenv('MS_DATASETS_CACHE', DEFAULT_MS_DATASETS_CACHE))
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import json
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.hub.utils.utils import get_cache_dir
|
||||
from modelscope.hub.utils.utils import get_model_cache_dir
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile
|
||||
|
||||
@@ -31,7 +31,7 @@ def func_receive_dict_inputs(func):
|
||||
return False
|
||||
|
||||
|
||||
def get_default_cache_dir():
|
||||
def get_default_modelscope_cache_dir():
|
||||
"""
|
||||
default base dir: '~/.cache/modelscope'
|
||||
"""
|
||||
|
||||
@@ -20,14 +20,14 @@ import pkg_resources
|
||||
from modelscope.fileio.file import LocalStorage
|
||||
from modelscope.utils.ast_utils import FilesAstScanning
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
||||
from modelscope.utils.file_utils import get_default_cache_dir
|
||||
from modelscope.utils.file_utils import get_default_modelscope_cache_dir
|
||||
from modelscope.utils.hub import read_config, snapshot_download
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
storage = LocalStorage()
|
||||
|
||||
MODELSCOPE_FILE_DIR = get_default_cache_dir()
|
||||
MODELSCOPE_FILE_DIR = get_default_modelscope_cache_dir()
|
||||
MODELSCOPE_DYNAMIC_MODULE = 'modelscope_modules'
|
||||
BASE_MODULE_DIR = os.path.join(MODELSCOPE_FILE_DIR, MODELSCOPE_DYNAMIC_MODULE)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.hub.utils.utils import get_cache_dir
|
||||
from modelscope.hub.utils.utils import get_model_cache_dir
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile
|
||||
@@ -20,7 +20,7 @@ class ModelJsonTest:
|
||||
|
||||
def test_single(self, model_id: str, model_revision=None):
|
||||
# get model_revision & task info
|
||||
cache_root = get_cache_dir()
|
||||
cache_root = get_model_cache_dir()
|
||||
configuration_file = os.path.join(cache_root, model_id,
|
||||
ModelFile.CONFIGURATION)
|
||||
if not model_revision:
|
||||
|
||||
@@ -12,7 +12,7 @@ from utils.source_file_analyzer import (get_all_register_modules,
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.hub.utils.utils import (get_cache_dir,
|
||||
from modelscope.hub.utils.utils import (get_model_cache_dir,
|
||||
model_id_to_group_owner_name)
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile
|
||||
@@ -36,7 +36,7 @@ def get_models_info(groups: list) -> dict:
|
||||
if len(models) >= total_count:
|
||||
break
|
||||
page += 1
|
||||
cache_root = get_cache_dir()
|
||||
cache_root = get_model_cache_dir()
|
||||
models_info = {} # key model id, value model info
|
||||
for model_info in models:
|
||||
model_id = '%s/%s' % (group, model_info['Name'])
|
||||
|
||||
@@ -6,7 +6,7 @@ import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from modelscope.hub.utils.utils import get_cache_dir
|
||||
from modelscope.hub.utils.utils import get_model_cache_dir
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import build_trainer
|
||||
@@ -57,7 +57,7 @@ class TestImageDefrcnFewShotTrainer(unittest.TestCase):
|
||||
cfg.model.roi_heads.freeze_feat = False
|
||||
cfg.model.roi_heads.cls_dropout = False
|
||||
cfg.model.weights = os.path.join(
|
||||
get_cache_dir(), self.model_id,
|
||||
get_model_cache_dir(), self.model_id,
|
||||
'ImageNetPretrained/MSRA/R-101.pkl')
|
||||
|
||||
cfg.datasets.root = self.data_dir
|
||||
|
||||
Reference in New Issue
Block a user