mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
Merge remote-tracking branch 'upstream/master' into release/1.34
This commit is contained in:
@@ -514,7 +514,8 @@ class HubApi:
|
||||
def get_endpoint_for_read(self,
|
||||
repo_id: str,
|
||||
*,
|
||||
repo_type: Optional[str] = None) -> str:
|
||||
repo_type: Optional[str] = None,
|
||||
token: Optional[str] = None) -> str:
|
||||
"""Get proper endpoint for read operation (such as download, list etc.)
|
||||
1. If user has set MODELSCOPE_DOMAIN, construct endpoint with user-specified domain.
|
||||
If the repo does not exist on that endpoint, throw 404 error, otherwise return the endpoint.
|
||||
@@ -529,7 +530,7 @@ class HubApi:
|
||||
if s is not None and s.strip() != '':
|
||||
endpoint = MODELSCOPE_URL_SCHEME + s
|
||||
try:
|
||||
self.repo_exists(repo_id=repo_id, repo_type=repo_type, endpoint=endpoint, re_raise=True)
|
||||
self.repo_exists(repo_id=repo_id, repo_type=repo_type, endpoint=endpoint, re_raise=True, token=token)
|
||||
except Exception:
|
||||
logger.error(f'Repo {repo_id} does not exist on {endpoint}.')
|
||||
raise
|
||||
@@ -538,13 +539,13 @@ class HubApi:
|
||||
check_cn_first = not is_env_true(MODELSCOPE_PREFER_AI_SITE)
|
||||
prefer_endpoint = get_endpoint(cn_site=check_cn_first)
|
||||
if not self.repo_exists(
|
||||
repo_id, repo_type=repo_type, endpoint=prefer_endpoint):
|
||||
repo_id, repo_type=repo_type, endpoint=prefer_endpoint, token=token):
|
||||
alternative_endpoint = get_endpoint(cn_site=(not check_cn_first))
|
||||
logger.warning(f'Repo {repo_id} not exists on {prefer_endpoint}, '
|
||||
f'will try on alternative endpoint {alternative_endpoint}.')
|
||||
try:
|
||||
self.repo_exists(
|
||||
repo_id, repo_type=repo_type, endpoint=alternative_endpoint, re_raise=True)
|
||||
repo_id, repo_type=repo_type, endpoint=alternative_endpoint, re_raise=True, token=token)
|
||||
except Exception:
|
||||
logger.error(f'Repo {repo_id} not exists on either {prefer_endpoint} or {alternative_endpoint}')
|
||||
raise
|
||||
@@ -856,20 +857,21 @@ class HubApi:
|
||||
ignore_file_pattern = [ignore_file_pattern]
|
||||
if visibility is None or license is None:
|
||||
raise InvalidParameter('Visibility and License cannot be empty for new model.')
|
||||
if not self.repo_exists(model_id):
|
||||
if not self.repo_exists(model_id, token=token):
|
||||
logger.info('Creating new model [%s]' % model_id)
|
||||
self.create_model(
|
||||
model_id=model_id,
|
||||
visibility=visibility,
|
||||
license=license,
|
||||
chinese_name=chinese_name,
|
||||
original_model_id=original_model_id)
|
||||
original_model_id=original_model_id,
|
||||
token=token)
|
||||
tmp_dir = os.path.join(model_dir, TEMPORARY_FOLDER_NAME) # make temporary folder
|
||||
git_wrapper = GitCommandWrapper()
|
||||
logger.info(f'Pushing folder {model_dir} as model {model_id}.')
|
||||
logger.info(f'Total folder size {folder_size}, this may take a while depending on actual pushing size...')
|
||||
try:
|
||||
repo = Repository(model_dir=tmp_dir, clone_from=model_id)
|
||||
repo = Repository(model_dir=tmp_dir, clone_from=model_id, auth_token=token)
|
||||
branches = git_wrapper.get_remote_branches(tmp_dir)
|
||||
if revision not in branches:
|
||||
logger.info('Creating new branch %s' % revision)
|
||||
@@ -2050,7 +2052,7 @@ class HubApi:
|
||||
if create_default_config:
|
||||
with tempfile.TemporaryDirectory() as temp_cache_dir:
|
||||
from modelscope.hub.repository import Repository
|
||||
repo = Repository(temp_cache_dir, repo_id)
|
||||
repo = Repository(temp_cache_dir, repo_id, auth_token=token)
|
||||
default_config = {
|
||||
'framework': 'pytorch',
|
||||
'task': 'text-generation',
|
||||
@@ -2916,6 +2918,7 @@ class HubApi:
|
||||
page_number=page_number,
|
||||
page_size=page_size,
|
||||
endpoint=endpoint,
|
||||
token=token,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Get dataset: {repo_id} file list failed, message: {str(e)}')
|
||||
@@ -3005,7 +3008,7 @@ class HubApi:
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
||||
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
model_info = self.get_model(model_id=repo_id)
|
||||
model_info = self.get_model(model_id=repo_id, token=token)
|
||||
path = f'{self.endpoint}/api/v1/models/{repo_id}'
|
||||
tasks = model_info.get('Tasks')
|
||||
model_tasks = ''
|
||||
@@ -3038,6 +3041,7 @@ class HubApi:
|
||||
dataset_idx, _ = self.get_dataset_id_and_type(
|
||||
dataset_name=repo_id_parts[1],
|
||||
namespace=repo_id_parts[0],
|
||||
token=token
|
||||
)
|
||||
|
||||
path = f'{self.endpoint}/api/v1/datasets/{dataset_idx}'
|
||||
@@ -3088,6 +3092,8 @@ class ModelScopeConfig:
|
||||
if os.path.exists(cookies_path):
|
||||
with open(cookies_path, 'rb') as f:
|
||||
cookies = pickle.load(f)
|
||||
if not cookies:
|
||||
return None
|
||||
for cookie in cookies:
|
||||
if cookie.name == 'm_session_id' and cookie.is_expired() and \
|
||||
not ModelScopeConfig.cookie_expired_warning:
|
||||
|
||||
@@ -221,7 +221,8 @@ def _repo_file_download(
|
||||
if cookies is None:
|
||||
cookies = _api.get_cookies()
|
||||
repo_files = []
|
||||
endpoint = _api.get_endpoint_for_read(repo_id=repo_id, repo_type=repo_type)
|
||||
endpoint = _api.get_endpoint_for_read(
|
||||
repo_id=repo_id, repo_type=repo_type, token=token)
|
||||
file_to_download_meta = None
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
revision = _api.get_valid_revision(
|
||||
@@ -263,7 +264,8 @@ def _repo_file_download(
|
||||
recursive=True,
|
||||
page_number=page_number,
|
||||
page_size=page_size,
|
||||
endpoint=endpoint)
|
||||
endpoint=endpoint,
|
||||
token=token)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Get dataset: {repo_id} file list failed, error: {e}')
|
||||
|
||||
@@ -56,7 +56,7 @@ class CredentialProviderWrapper(CredentialsProvider):
|
||||
oss_config = self.api.get_dataset_access_config_session(
|
||||
dataset_name=self.dataset_name,
|
||||
namespace=self.namespace,
|
||||
check_cookie=True,
|
||||
check_cookie=False,
|
||||
revision=self.revision)
|
||||
|
||||
return Credentials(
|
||||
@@ -83,7 +83,7 @@ class OssUtilities:
|
||||
oss_config = self.api.get_dataset_access_config_session(
|
||||
dataset_name=self.dataset_name,
|
||||
namespace=self.namespace,
|
||||
check_cookie=True,
|
||||
check_cookie=False,
|
||||
revision=self.revision)
|
||||
|
||||
if os.getenv('ENABLE_DATASET_ACCELERATION') == 'True':
|
||||
|
||||
Reference in New Issue
Block a user