Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-17 00:37:43 +01:00)
fix #845

Co-authored-by: mulin.lyh <mulin.lyh@taobao.com>
@@ -267,6 +267,8 @@ class HubApi:
         This function must be called before calling HubApi's login with a valid token
         which can be obtained from ModelScope's website.

+        If any error, please upload via git commands.
+
         Args:
             model_id (str):
                 The model id to be uploaded, caller must have write permission for it.
@@ -21,11 +21,12 @@ from modelscope.hub.constants import (
     API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
     MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION
+from modelscope.utils.file_utils import get_model_cache_root
 from modelscope.utils.logger import get_logger
 from .errors import FileDownloadError, NotExistError
 from .utils.caching import ModelFileSystemCache
-from .utils.utils import (file_integrity_validation, get_cache_dir,
-                          get_endpoint, model_id_to_group_owner_name)
+from .utils.utils import (file_integrity_validation, get_endpoint,
+                          model_id_to_group_owner_name)

 logger = get_logger()

@@ -75,7 +76,7 @@ def model_file_download(
             if some parameter value is invalid
     """
     if cache_dir is None:
-        cache_dir = get_cache_dir()
+        cache_dir = get_model_cache_root()
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
     temporary_cache_dir = os.path.join(cache_dir, 'temp')
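With this change, model_file_download no longer resolves its default cache directory through the hub-level get_cache_dir, but through get_model_cache_root from modelscope.utils.file_utils. A minimal sketch of that resolution, assuming a modelscope build that contains this commit (the helper name resolve_cache_dir is illustrative, not part of the change):

# Illustrative sketch; only get_model_cache_root comes from this commit.
import os
from pathlib import Path

from modelscope.utils.file_utils import get_model_cache_root


def resolve_cache_dir(cache_dir=None):
    """Mirror the cache_dir handling shown above: an explicit argument wins,
    otherwise fall back to the shared model cache root."""
    if cache_dir is None:
        cache_dir = get_model_cache_root()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    return cache_dir, os.path.join(cache_dir, 'temp')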
@@ -45,8 +45,9 @@ class GitCommandWrapper(metaclass=Singleton):
         logger.debug(' '.join(args))
         git_env = os.environ.copy()
         git_env['GIT_TERMINAL_PROMPT'] = '0'
+        command = [self.git_path, *args]
         response = subprocess.run(
-            [self.git_path, *args],
+            command,
             stdout=subprocess.PIPE,
             stderr=subprocess.PIPE,
             env=git_env,
@@ -55,10 +56,11 @@ class GitCommandWrapper(metaclass=Singleton):
             response.check_returncode()
             return response
         except subprocess.CalledProcessError as error:
-            logger.error('There are error run git command.')
-            raise GitError(
-                'stdout: %s, stderr: %s' %
-                (response.stdout.decode('utf8'), error.stderr.decode('utf8')))
+            output = 'stdout: %s, stderr: %s' % (
+                response.stdout.decode('utf8'), error.stderr.decode('utf8'))
+            logger.error('Running git command: %s failed, output: %s.' %
+                         (command, output))
+            raise GitError(output)

     def config_auth_token(self, repo_dir, auth_token):
         url = self.get_repo_remote_url(repo_dir)
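The wrapper now builds the command list once, logs the failing command together with the captured stdout/stderr, and raises with that combined output instead of a bare message. A standalone sketch of the same pattern, with RuntimeError standing in for the repository's GitError:

# Standalone sketch of the error-handling pattern above (GitError replaced by RuntimeError).
import os
import subprocess


def run_git(git_path, *args):
    git_env = os.environ.copy()
    git_env['GIT_TERMINAL_PROMPT'] = '0'  # never block on interactive credential prompts
    command = [git_path, *args]
    response = subprocess.run(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=git_env)
    try:
        response.check_returncode()
        return response
    except subprocess.CalledProcessError as error:
        output = 'stdout: %s, stderr: %s' % (
            response.stdout.decode('utf8'), error.stderr.decode('utf8'))
        raise RuntimeError('Running git command: %s failed, output: %s.' %
                           (command, output))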
@@ -9,13 +9,14 @@ from typing import Dict, List, Optional, Union

 from modelscope.hub.api import HubApi, ModelScopeConfig
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION
+from modelscope.utils.file_utils import get_model_cache_root
 from modelscope.utils.logger import get_logger
 from .constants import (FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
                         MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB)
 from .file_download import (get_file_download_url, http_get_file,
                             parallel_download)
 from .utils.caching import ModelFileSystemCache
-from .utils.utils import (file_integrity_validation, get_cache_dir,
+from .utils.utils import (file_integrity_validation,
                           model_id_to_group_owner_name)

 logger = get_logger()

@@ -65,7 +66,7 @@ def snapshot_download(model_id: str,
     """

     if cache_dir is None:
-        cache_dir = get_cache_dir()
+        cache_dir = get_model_cache_root()
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
     temporary_cache_dir = os.path.join(cache_dir, 'temp')
@@ -12,7 +12,7 @@ from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
                                       MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG,
                                       MODELSCOPE_URL_SCHEME)
 from modelscope.hub.errors import FileIntegrityError
-from modelscope.utils.file_utils import get_default_cache_dir
+from modelscope.utils.file_utils import get_default_modelscope_cache_dir
 from modelscope.utils.logger import get_logger

 logger = get_logger()
@@ -28,23 +28,6 @@ def model_id_to_group_owner_name(model_id):
     return group_or_owner, name


-def get_cache_dir(model_id: Optional[str] = None):
-    """cache dir precedence:
-        function parameter > environment > ~/.cache/modelscope/hub
-
-    Args:
-        model_id (str, optional): The model id.
-
-    Returns:
-        str: the model_id dir if model_id not None, otherwise cache root dir.
-    """
-    default_cache_dir = get_default_cache_dir()
-    base_path = os.getenv('MODELSCOPE_CACHE',
-                          os.path.join(default_cache_dir, 'hub'))
-    return base_path if model_id is None else os.path.join(
-        base_path, model_id + '/')
-
-
 def get_release_datetime():
     if MODELSCOPE_SDK_DEBUG in os.environ:
         rt = int(round(datetime.now().timestamp()))
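Callers of the removed modelscope.hub.utils.utils.get_cache_dir migrate to the helpers added to modelscope.utils.file_utils later in this commit. A hedged migration sketch, assuming a build that contains those helpers (the model id is just an example taken from the tests touched by this commit):

# old: from modelscope.hub.utils.utils import get_cache_dir
# new:
from modelscope.utils.file_utils import (get_model_cache_dir,
                                         get_model_cache_root)

cache_root = get_model_cache_root()  # was: get_cache_dir()
model_dir = get_model_cache_dir('damo/ofa_mmspeech_pretrain_base_zh')  # was: get_cache_dir(model_id)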
@@ -14,11 +14,12 @@ import gast
 import json

 from modelscope.fileio.file import LocalStorage
+# do not delete
 from modelscope.metainfo import (CustomDatasets, Heads, Hooks, LR_Schedulers,
                                  Metrics, Models, Optimizers, Pipelines,
                                  Preprocessors, TaskModels, Trainers)
 from modelscope.utils.constant import Fields, Tasks
-from modelscope.utils.file_utils import get_default_cache_dir
+from modelscope.utils.file_utils import get_modelscope_cache_dir
 from modelscope.utils.logger import get_logger
 from modelscope.utils.registry import default_group

@@ -29,7 +30,7 @@ p = Path(__file__)
 # get the path of package 'modelscope'
 SKIP_FUNCTION_SCANNING = True
 MODELSCOPE_PATH = p.resolve().parents[1]
-INDEXER_FILE_DIR = get_default_cache_dir()
+INDEXER_FILE_DIR = get_modelscope_cache_dir()
 REGISTER_MODULE = 'register_module'
 IGNORED_PACKAGES = ['modelscope', '.']
 SCAN_SUB_FOLDERS = [
@@ -1,7 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 import re
-import shutil
 import struct
 import sys
 import tempfile
@@ -11,7 +10,7 @@ from urllib.parse import urlparse
 import numpy as np

 from modelscope.fileio.file import HTTPStorage
-from modelscope.hub.utils.utils import get_cache_dir
+from modelscope.utils.file_utils import get_model_cache_root
 from modelscope.utils.hub import snapshot_download
 from modelscope.utils.logger import get_logger

@@ -334,7 +333,7 @@ def update_local_model(model_config, model_path, extra_args):
     model_revision = extra_args['update_model']
     if model_config.__contains__('model'):
         model_name = model_config['model']
-        dst_dir_root = get_cache_dir()
+        dst_dir_root = get_model_cache_root()
         if isinstance(model_path, str) and os.path.exists(
                 model_path) and not model_path.startswith(dst_dir_root):
             try:
@@ -5,14 +5,11 @@ from pathlib import Path

 # Cache location
 from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT
+from modelscope.utils.file_utils import get_modelscope_cache_dir

-DEFAULT_CACHE_HOME = Path.home().joinpath('.cache')
-CACHE_HOME = os.getenv('CACHE_HOME', DEFAULT_CACHE_HOME)
-DEFAULT_MS_CACHE_HOME = os.path.join(CACHE_HOME, 'modelscope', 'hub')
-MS_CACHE_HOME = os.path.expanduser(
-    os.getenv('MS_CACHE_HOME', DEFAULT_MS_CACHE_HOME))
+MS_CACHE_HOME = get_modelscope_cache_dir()

-DEFAULT_MS_DATASETS_CACHE = os.path.join(MS_CACHE_HOME, 'datasets')
+DEFAULT_MS_DATASETS_CACHE = os.path.join(MS_CACHE_HOME, 'hub', 'datasets')
 MS_DATASETS_CACHE = Path(
     os.getenv('MS_DATASETS_CACHE', DEFAULT_MS_DATASETS_CACHE))

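The datasets cache is now derived from the single ModelScope cache root instead of its own CACHE_HOME/MS_CACHE_HOME chain, so by default it lands under ~/.cache/modelscope/hub/datasets and follows MODELSCOPE_CACHE, while MS_DATASETS_CACHE still overrides it. A standalone sketch of the resolution order, assuming the default base dir stated in the file_utils docstring ('~/.cache/modelscope'):

# Sketch of the resolution order shown above; no modelscope import needed.
import os
from pathlib import Path

default_root = os.path.expanduser('~/.cache/modelscope')     # ~ get_default_modelscope_cache_dir()
ms_cache_home = os.getenv('MODELSCOPE_CACHE', default_root)  # ~ get_modelscope_cache_dir()
default_datasets_cache = os.path.join(ms_cache_home, 'hub', 'datasets')
MS_DATASETS_CACHE = Path(os.getenv('MS_DATASETS_CACHE', default_datasets_cache))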
@@ -1,13 +1,9 @@
 import argparse
-import os
 import traceback
 from typing import List, Union

-import json
-
 from modelscope.hub.api import HubApi
 from modelscope.hub.file_download import model_file_download
-from modelscope.hub.utils.utils import get_cache_dir
 from modelscope.pipelines import pipeline
 from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile
@@ -31,7 +31,7 @@ def func_receive_dict_inputs(func):
     return False


-def get_default_cache_dir():
+def get_default_modelscope_cache_dir():
     """
     default base dir: '~/.cache/modelscope'
     """
@@ -39,6 +39,40 @@ def get_default_cache_dir():
     return default_cache_dir


+def get_modelscope_cache_dir() -> str:
+    """Get modelscope cache dir, default location or
+    setting with MODELSCOPE_CACHE
+
+    Returns:
+        str: the modelscope cache root.
+    """
+    return os.getenv('MODELSCOPE_CACHE', get_default_modelscope_cache_dir())
+
+
+def get_model_cache_root() -> str:
+    """Get model cache root path.
+
+    Returns:
+        str: the modelscope cache root.
+    """
+    return os.path.join(get_modelscope_cache_dir(), 'hub')
+
+
+def get_model_cache_dir(model_id: str) -> str:
+    """cache dir precedence:
+        function parameter > environment > ~/.cache/modelscope/hub/model_id
+
+    Args:
+        model_id (str, optional): The model id.
+
+    Returns:
+        str: the model_id dir if model_id not None, otherwise cache root dir.
+    """
+    root_path = get_model_cache_root()
+    return root_path if model_id is None else os.path.join(
+        root_path, model_id + '/')
+
+
 def read_file(path):

     with open(path, 'r') as f:
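Taken together, the three new helpers form one resolution chain: MODELSCOPE_CACHE (or ~/.cache/modelscope) for the cache root, <root>/hub for models, and <root>/hub/<model_id>/ for an individual model. A short usage sketch, assuming a modelscope build that contains this commit; the /data/ms_cache path is a hypothetical override:

# Usage sketch; /data/ms_cache is an illustrative path, not a project default.
import os

from modelscope.utils.file_utils import (get_model_cache_dir,
                                         get_model_cache_root,
                                         get_modelscope_cache_dir)

os.environ['MODELSCOPE_CACHE'] = '/data/ms_cache'  # optional override, read at call time
print(get_modelscope_cache_dir())  # /data/ms_cache
print(get_model_cache_root())      # /data/ms_cache/hub
print(get_model_cache_dir('damo/ofa_mmspeech_pretrain_base_zh'))
# /data/ms_cache/hub/damo/ofa_mmspeech_pretrain_base_zh/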
@@ -20,14 +20,14 @@ import pkg_resources
 from modelscope.fileio.file import LocalStorage
 from modelscope.utils.ast_utils import FilesAstScanning
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION
-from modelscope.utils.file_utils import get_default_cache_dir
+from modelscope.utils.file_utils import get_modelscope_cache_dir
 from modelscope.utils.hub import read_config, snapshot_download
 from modelscope.utils.logger import get_logger

 logger = get_logger()
 storage = LocalStorage()

-MODELSCOPE_FILE_DIR = get_default_cache_dir()
+MODELSCOPE_FILE_DIR = get_modelscope_cache_dir()
 MODELSCOPE_DYNAMIC_MODULE = 'modelscope_modules'
 BASE_MODULE_DIR = os.path.join(MODELSCOPE_FILE_DIR, MODELSCOPE_DYNAMIC_MODULE)

@@ -1,5 +1,5 @@
 # Make sure to modify __release_datetime__ to release time when making official release.
-__version__ = '1.9.4'
+__version__ = '2.0.0'
 # default release datetime for branches under active development is set
 # to be a time far-far-away-into-the-future
 __release_datetime__ = '2099-09-06 00:00:00'
@@ -4,10 +4,10 @@ import json

 from modelscope.hub.api import HubApi
 from modelscope.hub.file_download import model_file_download
-from modelscope.hub.utils.utils import get_cache_dir
 from modelscope.pipelines import pipeline
 from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile
+from modelscope.utils.file_utils import get_model_cache_dir
 from modelscope.utils.input_output import (
     call_pipeline_with_json, get_pipeline_information_by_pipeline,
     get_task_input_examples, pipeline_output_to_service_base64_output)
@@ -20,9 +20,8 @@ class ModelJsonTest:

     def test_single(self, model_id: str, model_revision=None):
         # get model_revision & task info
-        cache_root = get_cache_dir()
-        configuration_file = os.path.join(cache_root, model_id,
-                                          ModelFile.CONFIGURATION)
+        configuration_file = os.path.join(
+            get_model_cache_dir(model_id), ModelFile.CONFIGURATION)
         if not model_revision:
             model_revision = self.api.list_model_revisions(
                 model_id=model_id)[0]
@@ -316,7 +316,9 @@ class OfaTasksTest(unittest.TestCase):
         result[OutputKeys.OUTPUT_IMG].save('result.png')
         print(f'Output written to {osp.abspath("result.png")}')

-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skipUnless(
+        test_level() >= 1,
+        'skip test in current test level, model has no text2phone_dict.txt')
     def test_run_with_asr_with_name(self):
         model = 'damo/ofa_mmspeech_pretrain_base_zh'
         ofa_pipe = pipeline(Tasks.auto_speech_recognition, model=model)
@@ -12,10 +12,10 @@ from utils.source_file_analyzer import (get_all_register_modules,

 from modelscope.hub.api import HubApi
 from modelscope.hub.file_download import model_file_download
-from modelscope.hub.utils.utils import (get_cache_dir,
-                                        model_id_to_group_owner_name)
+from modelscope.hub.utils.utils import model_id_to_group_owner_name
 from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile
+from modelscope.utils.file_utils import get_model_cache_dir
 from modelscope.utils.logger import get_logger

 logger = get_logger()
@@ -36,12 +36,11 @@ def get_models_info(groups: list) -> dict:
         if len(models) >= total_count:
             break
         page += 1
-    cache_root = get_cache_dir()
     models_info = {}  # key model id, value model info
     for model_info in models:
         model_id = '%s/%s' % (group, model_info['Name'])
-        configuration_file = os.path.join(cache_root, model_id,
-                                          ModelFile.CONFIGURATION)
+        configuration_file = os.path.join(
+            get_model_cache_dir(model_id), ModelFile.CONFIGURATION)
         if not os.path.exists(configuration_file):
             try:
                 model_revisions = api.list_model_revisions(model_id=model_id)
@@ -6,11 +6,11 @@ import sys
 import tempfile
 import unittest

-from modelscope.hub.utils.utils import get_cache_dir
 from modelscope.metainfo import Trainers
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import DownloadMode
+from modelscope.utils.file_utils import get_model_cache_dir
 from modelscope.utils.test_utils import test_level


@@ -57,7 +57,7 @@ class TestImageDefrcnFewShotTrainer(unittest.TestCase):
         cfg.model.roi_heads.freeze_feat = False
         cfg.model.roi_heads.cls_dropout = False
         cfg.model.weights = os.path.join(
-            get_cache_dir(), self.model_id,
+            get_model_cache_dir(self.model_id),
             'ImageNetPretrained/MSRA/R-101.pkl')

         cfg.datasets.root = self.data_dir