mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
118 lines
4.1 KiB
Python
118 lines
4.1 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import os
|
|
from typing import Dict, Optional, Union
|
|
from urllib.parse import urlparse
|
|
|
|
from modelscope.hub.api import HubApi, ModelScopeConfig
|
|
from modelscope.hub.constants import FILE_HASH
|
|
from modelscope.hub.git import GitCommandWrapper
|
|
from modelscope.hub.utils.caching import ModelFileSystemCache
|
|
from modelscope.hub.utils.utils import compute_hash
|
|
from modelscope.utils.logger import get_logger
|
|
|
|
logger = get_logger()
|
|
|
|
|
|
def get_model_id_from_cache(model_root_path: str) -> str:
    """Resolve the model id that a local model directory belongs to.

    Two on-disk layouts are distinguished by the presence of a ``.git``
    entry directly under ``model_root_path``:

    * with ``.git``: the id is the path component of the repository's
      remote URL, with a trailing ``.git`` suffix and the first path
      character (the leading ``/`` of an http(s) URL) removed;
    * without ``.git`` (snapshot_download layout): the id is whatever
      ``ModelFileSystemCache`` reports for that directory.

    Args:
        model_root_path: Root directory of the local model files.

    Returns:
        The model id derived as described above.
    """
    is_git_checkout = os.path.exists(os.path.join(model_root_path, '.git'))
    if not is_git_checkout:
        # snapshot_download: the file-system cache knows its own model id.
        return ModelFileSystemCache(model_root_path).get_model_id()

    # download with git: derive the id from the remote URL.
    remote_url = GitCommandWrapper().get_repo_remote_url(model_root_path)
    suffix = '.git'
    if remote_url.endswith(suffix):
        remote_url = remote_url[:-len(suffix)]
    # Drop the first character of the URL path (the leading '/').
    return urlparse(remote_url).path[1:]
|
|
|
|
|
|
def check_local_model_is_latest(
    model_root_path: str,
    user_agent: Optional[Union[Dict, str]] = None,
):
    """Check local model repo is latest.

    Check local model repo is same as hub latest version. The check is
    best-effort: it only emits an info log for the first file found to
    differ, and any ordinary failure (network error, missing cache
    metadata, ...) is logged at debug level and otherwise ignored.
    ``KeyboardInterrupt`` and ``SystemExit`` are *not* swallowed.

    Args:
        model_root_path: Root directory of the local model files.
        user_agent: Extra user-agent information forwarded to
            ``ModelScopeConfig.get_user_agent``.

    Returns:
        None. The outcome is reported through the module logger only.
    """
    try:
        model_id = get_model_id_from_cache(model_root_path)
        # NOTE(review): '___' presumably stands in for '.' in cached
        # directory names -- confirm against the cache layout.
        model_id = model_id.replace('___', '.')
        # make headers
        headers = {
            'user-agent':
            ModelScopeConfig.get_user_agent(user_agent=user_agent, )
        }
        cookies = ModelScopeConfig.get_cookies()

        # The 'Snapshot' header is omitted when CI_TEST is set in the
        # environment.
        if 'CI_TEST' in os.environ:
            snapshot_header = headers
        else:
            snapshot_header = {**headers, 'Snapshot': 'True'}

        # Short timeout so this advisory check stays cheap.
        _api = HubApi(timeout=0.5)
        try:
            _, revisions = _api.get_model_branches_and_tags(
                model_id=model_id, use_cookies=cookies)
            latest_revision = revisions[0] if revisions else 'master'
        except Exception:
            # Could not list revisions: fall back to the default branch.
            latest_revision = 'master'

        model_files = _api.get_model_files(
            model_id=model_id,
            revision=latest_revision,
            recursive=True,
            headers=snapshot_header,
            use_cookies=cookies,
        )

        model_cache = None
        # download via non-git method
        if not os.path.exists(os.path.join(model_root_path, '.git')):
            model_cache = ModelFileSystemCache(model_root_path)

        for model_file in model_files:
            if model_file['Type'] == 'tree':
                continue

            # check model_file updated
            if model_cache is not None:
                is_current = model_cache.exists(model_file)
            elif FILE_HASH in model_file:
                local_path = os.path.join(model_root_path,
                                          model_file['Path'])
                # A file that is absent locally cannot match the hub copy;
                # report it instead of letting the hash computation raise
                # and abort the whole check without any message.
                is_current = (
                    os.path.isfile(local_path)
                    and compute_hash(local_path) == model_file[FILE_HASH])
            else:
                # No hash available from the hub: nothing to compare.
                is_current = True

            if not is_current:
                logger.info(
                    f'Model file {model_file["Name"]} is different from the latest version `{latest_revision}`,'
                    'This is because you are using an older version or the file is updated manually.'
                )
                # One report is enough; stop at the first difference.
                break
    except Exception as e:
        # Best-effort check: never let it break the caller, but leave a
        # trace for debugging.
        logger.debug(
            'Skipped checking whether local model at %s is latest: %s',
            model_root_path, e)
|
|
|
|
|
|
def check_model_is_id(model_id: str, token: Optional[str] = None):
    """Tell whether ``model_id`` names a model that the hub can return.

    Args:
        model_id: Candidate model id. ``None`` and anything that exists as
            a local filesystem path are rejected without contacting the
            hub.
        token: Passed to ``HubApi.login`` before the lookup.

    Returns:
        ``True`` if ``HubApi.get_model`` succeeds for ``model_id``;
        ``False`` if ``model_id`` is ``None``, is an existing local path,
        or the lookup raises an ``Exception``.
    """
    if model_id is None:
        return False
    if os.path.exists(model_id):
        # An existing local path is treated as a directory/file, not an id.
        return False

    hub_api = HubApi()
    # Login happens outside the try block, so its errors reach the caller.
    hub_api.login(token)
    try:
        hub_api.get_model(model_id=model_id)
    except Exception:
        return False
    else:
        return True
|