Files
modelscope/modelscope/hub/utils/utils.py
2022-11-01 18:04:48 +08:00

103 lines
3.2 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import hashlib
import os
from datetime import datetime
from typing import Optional
import requests
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
DEFAULT_MODELSCOPE_GROUP,
MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG,
MODELSCOPE_URL_SCHEME)
from modelscope.hub.errors import FileIntegrityError
from modelscope.utils.file_utils import get_default_cache_dir
from modelscope.utils.logger import get_logger
# Module-level logger used by the helpers in this file.
logger = get_logger()
def model_id_to_group_owner_name(model_id):
    """Split a model id into its group/owner part and its model name.

    Args:
        model_id (str): Model identifier, optionally prefixed with a group
            or owner followed by ``MODEL_ID_SEPARATOR``.

    Returns:
        tuple: ``(group_or_owner, name)``. When the separator is absent,
        ``DEFAULT_MODELSCOPE_GROUP`` is used as the group and the whole
        id is used as the name.
    """
    if MODEL_ID_SEPARATOR not in model_id:
        return DEFAULT_MODELSCOPE_GROUP, model_id

    # Only the first two segments are used; anything after a second
    # separator is ignored.
    segments = model_id.split(MODEL_ID_SEPARATOR)
    return segments[0], segments[1]
def get_cache_dir(model_id: Optional[str] = None):
    """Return the hub cache directory, optionally scoped to one model.

    The base directory is taken from the ``MODELSCOPE_CACHE`` environment
    variable when it is set; otherwise it is the ``hub`` sub-directory of
    the directory returned by ``get_default_cache_dir()``.

    Args:
        model_id (str, optional): When given, the returned path is the base
            directory joined with ``model_id`` plus a trailing ``'/'``.

    Returns:
        str: The base cache directory, or the per-model directory when
        ``model_id`` is not None.
    """
    fallback_root = os.path.join(get_default_cache_dir(), 'hub')
    cache_root = os.getenv('MODELSCOPE_CACHE', fallback_root)
    if model_id is None:
        return cache_root
    return os.path.join(cache_root, model_id + '/')
def get_release_datetime():
    """Return the SDK release time as an integer Unix timestamp.

    If the environment variable named by ``MODELSCOPE_SDK_DEBUG`` is
    present, the current time is returned instead of the release time.

    Returns:
        int: Timestamp in seconds, rounded to the nearest integer.
    """
    if MODELSCOPE_SDK_DEBUG in os.environ:
        return int(round(datetime.now().timestamp()))

    # Imported here rather than at module level, as in the original code.
    from modelscope import version
    released_at = datetime.strptime(version.__release_datetime__,
                                    '%Y-%m-%d %H:%M:%S')
    return int(round(released_at.timestamp()))
def get_endpoint():
    """Build the base URL of the ModelScope service.

    Returns:
        str: ``MODELSCOPE_URL_SCHEME`` followed by the value of the
        ``MODELSCOPE_DOMAIN`` environment variable, or by
        ``DEFAULT_MODELSCOPE_DOMAIN`` when that variable is unset.
    """
    domain = os.environ.get('MODELSCOPE_DOMAIN', DEFAULT_MODELSCOPE_DOMAIN)
    return MODELSCOPE_URL_SCHEME + domain
def compute_hash(file_path):
    """Compute the SHA-256 hex digest of a file.

    The file is read in 64 KiB chunks so that large files are never loaded
    into memory in one piece.

    Args:
        file_path (str): Path of the file to hash.

    Returns:
        str: The hexadecimal SHA-256 digest of the file contents.
    """
    chunk_size = 64 * 1024  # 64k buffer size
    digest = hashlib.sha256()
    with open(file_path, 'rb') as stream:
        # read() returns b'' at end of file, which stops the iteration.
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
def file_integrity_validation(file_path, expected_sha256):
    """Check a file against its expected SHA-256; delete it on mismatch.

    Args:
        file_path (str): The file to validate.
        expected_sha256 (str): The expected sha256 hex digest.

    Raises:
        FileIntegrityError: If the digest of ``file_path`` differs from
            ``expected_sha256``. The file is removed before raising.
    """
    actual_sha256 = compute_hash(file_path)
    if actual_sha256 == expected_sha256:
        return

    # Remove the bad file so it is not left on disk after the failure.
    os.remove(file_path)
    msg = ('File %s integrity check failed, the download may be '
           'incomplete, please try again.' % file_path)
    logger.error(msg)
    raise FileIntegrityError(msg)
def create_library_statistics(method: str, name: str, cn_name: Optional[str]):
    """Report a library-usage event to the hub statistics endpoint.

    Sends a POST request to ``<endpoint>/api/v1/statistics/library`` with
    the given values as query parameters. This is a best-effort call: any
    failure (import error, network error, non-2xx response) is logged at
    debug level and never propagated to the caller.

    Args:
        method (str): Sent as the ``Method`` query parameter.
        name (str): Sent as the ``Name`` query parameter.
        cn_name (str, optional): Sent as the ``CnName`` query parameter.

    Returns:
        None
    """
    try:
        # NOTE(review): imported lazily, presumably to avoid a circular
        # import with modelscope.hub.api -- confirm before hoisting.
        from modelscope.hub.api import ModelScopeConfig
        path = f'{get_endpoint()}/api/v1/statistics/library'
        headers = {'user-agent': ModelScopeConfig.get_user_agent()}
        params = {'Method': method, 'Name': name, 'CnName': cn_name}
        # requests has no default timeout; without one an unreachable
        # endpoint would block the caller indefinitely for a call that is
        # only telemetry.
        r = requests.post(path, params=params, headers=headers, timeout=10)
        r.raise_for_status()
    except Exception as e:
        # Statistics must never break the caller, but leave a trace for
        # debugging instead of discarding the error silently.
        logger.debug('Failed to report library statistics: %s', e)