mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 20:49:37 +01:00
103 lines
3.2 KiB
Python
103 lines
3.2 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import hashlib
|
|
import os
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
import requests
|
|
|
|
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
|
|
DEFAULT_MODELSCOPE_GROUP,
|
|
MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG,
|
|
MODELSCOPE_URL_SCHEME)
|
|
from modelscope.hub.errors import FileIntegrityError
|
|
from modelscope.utils.file_utils import get_default_cache_dir
|
|
from modelscope.utils.logger import get_logger
|
|
|
|
# Module-level logger; used below to report file integrity check failures.
logger = get_logger()
|
|
|
|
|
|
def model_id_to_group_owner_name(model_id):
    """Split a model id into a ``(group_or_owner, name)`` pair.

    Args:
        model_id (str): Either ``<group_or_owner><separator><name>`` or a
            bare model name.

    Returns:
        tuple: ``(group_or_owner, name)``. When the id contains
        ``MODEL_ID_SEPARATOR`` the first and second separator-delimited
        pieces are returned (any further pieces are ignored). Otherwise the
        group falls back to ``DEFAULT_MODELSCOPE_GROUP`` and the whole id is
        used as the name.
    """
    # Bare name: no explicit owner, so use the default group.
    if MODEL_ID_SEPARATOR not in model_id:
        return DEFAULT_MODELSCOPE_GROUP, model_id

    pieces = model_id.split(MODEL_ID_SEPARATOR)
    return pieces[0], pieces[1]
|
|
|
|
|
|
def get_cache_dir(model_id: Optional[str] = None):
    """Return the hub cache root, or the cache directory of a single model.

    The cache root is the value of the ``MODELSCOPE_CACHE`` environment
    variable when it is set; otherwise it is the ``hub`` sub-directory of the
    directory returned by ``get_default_cache_dir()``.

    Args:
        model_id (Optional[str]): When given, the returned path is the cache
            root joined with ``model_id`` followed by a trailing ``/``.

    Returns:
        str: The cache root if ``model_id`` is None, else the model's
        directory under that root.
    """
    fallback_root = os.path.join(get_default_cache_dir(), 'hub')
    cache_root = os.getenv('MODELSCOPE_CACHE', fallback_root)
    if model_id is None:
        return cache_root
    return os.path.join(cache_root, model_id + '/')
|
|
|
|
|
|
def get_release_datetime():
    """Return the SDK release time as an integer Unix timestamp.

    If the environment contains the variable named by
    ``MODELSCOPE_SDK_DEBUG``, the current time is used. Otherwise
    ``modelscope.version.__release_datetime__`` is parsed with the format
    ``%Y-%m-%d %H:%M:%S``.

    Returns:
        int: The chosen moment's timestamp, rounded to whole seconds.
    """
    if MODELSCOPE_SDK_DEBUG in os.environ:
        moment = datetime.now()
    else:
        # Imported here rather than at module level, as in the original.
        from modelscope import version
        moment = datetime.strptime(version.__release_datetime__,
                                   '%Y-%m-%d %H:%M:%S')
    return int(round(moment.timestamp()))
|
|
|
|
|
|
def get_endpoint():
    """Return the base URL of the hub service.

    The domain is read from the ``MODELSCOPE_DOMAIN`` environment variable,
    falling back to ``DEFAULT_MODELSCOPE_DOMAIN``, and is prefixed with
    ``MODELSCOPE_URL_SCHEME``.

    Returns:
        str: The URL scheme concatenated with the domain.
    """
    domain = os.environ.get('MODELSCOPE_DOMAIN', DEFAULT_MODELSCOPE_DOMAIN)
    endpoint = MODELSCOPE_URL_SCHEME + domain
    return endpoint
|
|
|
|
|
|
def compute_hash(file_path):
    """Return the hex-encoded SHA-256 digest of a file's contents.

    The file is opened in binary mode and consumed in 64 KiB chunks, so the
    whole file is never held in memory at once.

    Args:
        file_path: Path of the file to hash.

    Returns:
        str: The hexadecimal SHA-256 digest.
    """
    chunk_size = 64 * 1024  # 64k buffer size
    digest = hashlib.sha256()
    with open(file_path, 'rb') as stream:
        # iter() with a sentinel stops at the first empty read (end of file).
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
|
|
|
|
|
|
def file_integrity_validation(file_path, expected_sha256):
    """Check a file against its expected SHA-256 and delete it on mismatch.

    Args:
        file_path (str): The file to validate.
        expected_sha256 (str): The expected SHA-256 hex digest.

    Raises:
        FileIntegrityError: If the file's digest differs from
            ``expected_sha256``. The file is removed and the failure is
            logged before the error is raised.
    """
    actual_sha256 = compute_hash(file_path)
    if actual_sha256 == expected_sha256:
        return

    # Drop the bad file before reporting the failure.
    os.remove(file_path)
    msg = 'File %s integrity check failed, the download may be incomplete, please try again.' % file_path
    logger.error(msg)
    raise FileIntegrityError(msg)
|
|
|
|
|
|
def create_library_statistics(method: str, name: str, cn_name: Optional[str]):
    """Report a library usage record to the hub statistics endpoint.

    This is a best-effort call: it never raises. ``Method``, ``Name`` and
    ``CnName`` are sent as query parameters of a POST request to
    ``<endpoint>/api/v1/statistics/library``.

    Args:
        method (str): Value sent as the ``Method`` query parameter.
        name (str): Value sent as the ``Name`` query parameter.
        cn_name (Optional[str]): Value sent as the ``CnName`` query parameter.

    Returns:
        None
    """
    try:
        # NOTE(review): imported lazily, presumably to avoid a circular
        # import with modelscope.hub.api -- confirm.
        from modelscope.hub.api import ModelScopeConfig
        url = f'{get_endpoint()}/api/v1/statistics/library'
        headers = {'user-agent': ModelScopeConfig.get_user_agent()}
        params = {'Method': method, 'Name': name, 'CnName': cn_name}
        # Bound the wait: without a timeout, requests can block indefinitely
        # on an unresponsive server and stall the caller for mere telemetry.
        response = requests.post(
            url, params=params, headers=headers, timeout=10)
        response.raise_for_status()
    except Exception as e:
        # Statistics must never break the caller; keep a trace for debugging
        # instead of swallowing the failure silently.
        logger.debug(f'Failed to report library statistics: {e}')
|