Files
modelscope/modelscope/hub/utils/utils.py
2025-08-29 17:23:46 +08:00

383 lines
13 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import contextlib
import hashlib
import os
import sys
import time
import zoneinfo
from datetime import datetime
from pathlib import Path
from typing import Generator, List, Optional, Union
from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
DEFAULT_MODELSCOPE_GROUP,
DEFAULT_MODELSCOPE_INTL_DOMAIN,
MODEL_ID_SEPARATOR, MODELSCOPE_DOMAIN,
MODELSCOPE_SDK_DEBUG,
MODELSCOPE_URL_SCHEME)
from modelscope.hub.errors import FileIntegrityError
from modelscope.utils.logger import get_logger
logger = get_logger()
def model_id_to_group_owner_name(model_id):
    """Split a model id into its group/owner part and its name part.

    Args:
        model_id (str): Model id, either ``<group_or_owner><sep><name>`` or
            a bare ``<name>``, where ``<sep>`` is ``MODEL_ID_SEPARATOR``.

    Returns:
        tuple: ``(group_or_owner, name)``. When the id carries no separator
        the group falls back to ``DEFAULT_MODELSCOPE_GROUP``. Only the first
        two separator-delimited segments are used; later ones are ignored.
    """
    if MODEL_ID_SEPARATOR not in model_id:
        return DEFAULT_MODELSCOPE_GROUP, model_id
    segments = model_id.split(MODEL_ID_SEPARATOR)
    return segments[0], segments[1]
def is_env_true(var_name):
    """Tell whether an environment variable is set to ``true``.

    Args:
        var_name (str): Name of the environment variable to inspect.

    Returns:
        bool: True only when the variable's value equals ``'true'`` after
        surrounding whitespace is stripped and case is ignored. An unset
        variable, or any other value, yields False.
    """
    raw_value = os.environ.get(var_name, '')
    normalized = raw_value.strip().lower()
    return normalized == 'true'
def get_domain(cn_site=True):
    """Pick the hub domain to talk to.

    Args:
        cn_site (bool): Selects which built-in default is used when no
            override is configured.

    Returns:
        str: The value of the environment variable named by
        ``MODELSCOPE_DOMAIN`` when it is set and non-empty; otherwise
        ``DEFAULT_MODELSCOPE_DOMAIN`` if ``cn_site`` is truthy, else
        ``DEFAULT_MODELSCOPE_INTL_DOMAIN``.
    """
    # An unset variable and an empty one are both treated as "no override".
    override = os.getenv(MODELSCOPE_DOMAIN)
    if override:
        return override
    return (DEFAULT_MODELSCOPE_DOMAIN
            if cn_site else DEFAULT_MODELSCOPE_INTL_DOMAIN)
def convert_patterns(raw_input: Union[str, List[str]]):
    """Normalize pattern input into a flat list of trimmed strings.

    Args:
        raw_input: Either one string or a list. Every string may hold
            several comma-separated patterns.

    Returns:
        list or None: For a string, its comma-separated pieces with
        surrounding whitespace removed. For a list, the same expansion
        applied to each string element in order; non-string elements are
        skipped. ``None`` for any other input type.
    """

    def _split_csv(text):
        # Splitting a string without commas yields just that one string.
        return [piece.strip() for piece in text.split(',')]

    if isinstance(raw_input, str):
        return _split_csv(raw_input)
    if isinstance(raw_input, list):
        return [
            piece for item in raw_input if isinstance(item, str)
            for piece in _split_csv(item)
        ]
    return None
def get_model_masked_directory(directory, model_id):
    """Build the masked on-disk directory for a model.

    During model download every ``'.'`` in the model id is converted to
    ``'___'`` to produce the actual physical (masked) storage directory.

    Args:
        directory (str): A model directory path. Its last two components
            are dropped and replaced by the masked model id.
        model_id (str): The model id to mask and append.

    Returns:
        str: The directory in which the model files are actually located.
    """
    # The path is split textually, so use the separator of the platform.
    separator = '\\' if sys.platform.startswith('win') else '/'
    root = directory.rsplit(separator, 2)[0]
    masked_id = model_id.replace('.', '___')
    return os.path.join(root, masked_id)
def convert_readable_size(size_bytes: int) -> str:
    """Render a byte count as a short human-readable string.

    Units are powers of 1024, from ``B`` up to ``YB``.

    Args:
        size_bytes (int): Number of bytes; must not be negative.

    Returns:
        str: ``'0B'`` for zero, otherwise ``'<value> <unit>'`` with the
        value rounded to two decimals (for example ``'1.5 KB'``). Values
        too large for the unit table are expressed in ``YB``.

    Raises:
        ValueError: If ``size_bytes`` is negative.
    """
    import math

    if size_bytes == 0:
        return '0B'
    if size_bytes < 0:
        raise ValueError(
            f'size_bytes must be non-negative, got {size_bytes}')
    size_name = ('B', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB')
    i = int(math.floor(math.log(size_bytes, 1024)))
    # Clamp the unit index to the table: without this, sizes of 1024**9 and
    # above raise IndexError and fractional sizes below 1 wrap around to 'YB'.
    i = min(max(i, 0), len(size_name) - 1)
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return f'{s} {size_name[i]}'
def get_folder_size(folder_path: str) -> int:
    """Add up the sizes of all files found under a folder.

    Args:
        folder_path (str): Folder to scan recursively.

    Returns:
        int: Sum of ``st_size`` over every entry below ``folder_path`` for
        which ``is_file()`` is true.
    """
    entries = Path(folder_path).rglob('*')
    return sum(entry.stat().st_size for entry in entries if entry.is_file())
def get_readable_folder_size(folder_path: str) -> str:
    """Return a readable string describing the size of a folder (MB, GB etc.).

    Args:
        folder_path (str): Folder whose total file size is reported.

    Returns:
        str: The result of ``get_folder_size`` formatted by
        ``convert_readable_size``.
    """
    total_bytes = get_folder_size(folder_path=folder_path)
    return convert_readable_size(total_bytes)
def get_cache_dir(model_id: Optional[str] = None):
    """Resolve the hub cache directory, optionally for one model.

    The cache root is the value of the ``MODELSCOPE_CACHE`` environment
    variable when it is set, otherwise ``~/.cache/modelscope/hub``.

    Args:
        model_id (str, optional): The model id.

    Returns:
        str: The cache root when ``model_id`` is None, otherwise the root
        joined with ``model_id`` followed by a trailing ``'/'``.
    """
    # The fallback is computed up front, so the home directory is resolved
    # even when the environment variable ends up taking precedence.
    fallback_root = os.path.join(
        Path.home().joinpath('.cache', 'modelscope'), 'hub')
    cache_root = os.getenv('MODELSCOPE_CACHE', fallback_root)
    if model_id is None:
        return cache_root
    return os.path.join(cache_root, model_id + '/')
def get_release_datetime():
    """Return the SDK release time as an integer timestamp in seconds.

    When the environment variable named by ``MODELSCOPE_SDK_DEBUG`` is
    present (whatever its value), the current time is used instead of the
    packaged release time.

    Returns:
        int: The rounded ``timestamp()`` of either ``datetime.now()`` or
        ``modelscope.version.__release_datetime__`` parsed with the format
        ``'%Y-%m-%d %H:%M:%S'``.
    """
    if MODELSCOPE_SDK_DEBUG in os.environ:
        moment = datetime.now()
    else:
        # Imported lazily; only needed outside debug mode.
        from modelscope import version
        moment = datetime.strptime(version.__release_datetime__,
                                   '%Y-%m-%d %H:%M:%S')
    return int(round(moment.timestamp()))
def get_endpoint(cn_site=True):
    """Build the hub endpoint URL.

    Args:
        cn_site (bool): Forwarded to ``get_domain`` to choose the domain.

    Returns:
        str: ``MODELSCOPE_URL_SCHEME`` followed by the selected domain.
    """
    domain = get_domain(cn_site)
    return MODELSCOPE_URL_SCHEME + domain
def compute_hash(file_path):
    """Compute the SHA-256 digest of a file.

    Args:
        file_path (str): Path of the file to hash.

    Returns:
        str: Hexadecimal SHA-256 digest of the file's bytes.
    """
    chunk_size = 64 * 1024  # read in 64k pieces instead of the whole file
    digest = hashlib.sha256()
    with open(file_path, 'rb') as stream:
        # read() returns b'' at end of file, which stops the iteration.
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
def file_integrity_validation(file_path, expected_sha256):
    """Check a file against its expected hash; delete it on mismatch.

    Args:
        file_path (str): The file to validate.
        expected_sha256 (str): The expected sha256 hash, compared verbatim
            with the hex digest returned by ``compute_hash``.

    Raises:
        FileIntegrityError: If the hash of ``file_path`` differs from
            ``expected_sha256``. The file is removed before raising.
    """
    actual_sha256 = compute_hash(file_path)
    if actual_sha256 == expected_sha256:
        return

    # Remove the bad file first so it is not left behind for later use.
    os.remove(file_path)
    msg = 'File %s integrity check failed, expected sha256 signature is %s, actual is %s, the download may be incomplete, please try again.' % (  # noqa E501
        file_path, expected_sha256, actual_sha256)
    logger.error(msg)
    raise FileIntegrityError(msg)
def add_content_to_file(repo,
                        file_name: str,
                        patterns: List[str],
                        commit_message: Optional[str] = None,
                        ignore_push_error=False) -> None:
    """Append missing patterns to a file in a repo, then push the repo.

    Args:
        repo: Object exposing a ``model_dir`` attribute (the local checkout
            directory) and a ``push(commit_message)`` method.
        file_name (str): File inside ``repo.model_dir`` to update; it is
            created when it does not exist yet.
        patterns (List[str]): Patterns to add, one per line. A single
            string is accepted and treated as a one-element list.
        commit_message (str, optional): Message passed to ``repo.push``.
            Defaults to a message naming the first pattern and the file.
        ignore_push_error (bool): When True, any exception raised by
            ``repo.push`` is suppressed; otherwise it propagates.
    """
    if isinstance(patterns, str):
        patterns = [patterns]
    if commit_message is None:
        commit_message = f'Add `{patterns[0]}` patterns to {file_name}'

    # Load what is currently on disk (empty when the file is absent).
    target = os.path.join(repo.model_dir, file_name)
    existing = ''
    if os.path.exists(target):
        with open(target, 'r', encoding='utf-8') as reader:
            existing = reader.read()

    # A pattern counts as present when it occurs anywhere in the text,
    # not only as a full line of its own.
    updated = existing
    for pattern in patterns:
        if pattern in updated:
            continue
        if updated and not updated.endswith('\n'):
            updated += '\n'
        updated += f'{pattern}\n'

    # Rewrite the file only when something was actually added.
    if updated != existing:
        with open(target, 'w', encoding='utf-8') as writer:
            logger.debug(f'Writing {file_name} file. Content: {updated}')
            writer.write(updated)

    # The push is attempted whether or not the file was rewritten.
    try:
        repo.push(commit_message)
    except Exception:
        if not ignore_push_error:
            raise
# Each entry is (unit label, unit length in seconds, largest count that is
# still reported in that unit). A limit of None marks the last, open-ended
# unit.
_TIMESINCE_CHUNKS = (
    ('second', 1, 60),
    ('minute', 60, 60),
    ('hour', 3600, 24),
    ('day', 86400, 6),
    ('week', 604800, 6),
    ('month', 2592000, 11),  # a month is counted as 30 days
    ('year', 31536000, None),  # a year is counted as 365 days
)


def format_timesince(ts: float) -> str:
    """Describe how long ago a timestamp was, relative to the current time.

    Args:
        ts (float): Timestamp in seconds, compared against ``time.time()``.

    Returns:
        str: ``'a few seconds ago'`` when less than 20 seconds have passed,
        otherwise ``'<count> <unit>[s] ago'`` using the first unit of
        ``_TIMESINCE_CHUNKS`` whose rounded count fits within its limit
        (years when no earlier unit fits).
    """
    elapsed = time.time() - ts
    if elapsed < 20:
        return 'a few seconds ago'

    # Walk from the smallest unit upwards; when no unit matches, the values
    # from the last table entry (years) are the ones left in place.
    for unit, unit_seconds, limit in _TIMESINCE_CHUNKS:  # noqa: B007
        count = round(elapsed / unit_seconds)
        if limit is not None and count <= limit:
            break
    plural = 's' if count > 1 else ''
    return f'{count} {unit}{plural} ago'
def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
    """Lay out rows and headers as a plain-text table.

    Inspired by:
    - stackoverflow.com/a/8356620/593036
    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data

    Args:
        rows: Table body, one list of cells per row.
        headers: Column titles.

    Returns:
        str: The header line, a line of dashes and one line per row,
        joined by newlines. Each column is padded to the width of its
        widest cell (header included) and followed by a single space.
    """
    columns = zip(*rows, headers)
    col_widths = [
        max(len(str(cell)) for cell in column) for column in columns
    ]
    template = ('{{:{}}} ' * len(headers)).format(*col_widths)
    rule = ['-' * width for width in col_widths]
    table = [headers, rule, *rows]
    return '\n'.join(template.format(*line) for line in table)
# Part of the code borrowed from the awesome work of huggingface_hub/transformers
def strtobool(val):
    """Translate a textual truth value into ``1`` or ``0``.

    Args:
        val (str): Text to interpret; matching ignores case.

    Returns:
        int: ``1`` for y/yes/t/true/on/1 and ``0`` for n/no/f/false/off/0.

    Raises:
        ValueError: If the lower-cased text is none of the accepted words.
    """
    lowered = val.lower()
    verdicts = {
        **dict.fromkeys(('y', 'yes', 't', 'true', 'on', '1'), 1),
        **dict.fromkeys(('n', 'no', 'f', 'false', 'off', '0'), 0),
    }
    verdict = verdicts.get(lowered)
    if verdict is None:
        raise ValueError(f'invalid truth value {lowered!r}')
    return verdict
@contextlib.contextmanager
def weak_file_lock(lock_file: Union[str, Path],
                   *,
                   timeout: Optional[float] = None
                   ) -> Generator[BaseFileLock, None, None]:
    """Hold a file lock for the duration of a ``with`` block.

    The lock is requested in slices of at most 60 seconds so that a message
    can be logged after each unsuccessful slice while waiting continues.
    When acquiring fails with a ``NotImplementedError`` whose message
    suggests using ``SoftFileLock``, the lock is replaced by a
    ``SoftFileLock`` on the same path and acquisition is retried.

    Args:
        lock_file: Path of the lock file.
        timeout: Maximum total number of seconds to wait for the lock.
            ``None`` waits indefinitely.

    Yields:
        BaseFileLock: The acquired lock.

    Raises:
        Timeout: If the lock is not acquired within ``timeout`` seconds.
        NotImplementedError: If acquiring fails with a
            ``NotImplementedError`` that does not point to ``SoftFileLock``.
    """
    default_interval = 60
    lock = FileLock(lock_file, timeout=default_interval)
    start_time = time.time()

    while True:
        elapsed_time = time.time() - start_time
        if timeout is not None and elapsed_time >= timeout:
            raise Timeout(str(lock_file))

        # Never wait longer than what is left of the caller's budget.
        if timeout is not None:
            wait_for = min(default_interval, timeout - elapsed_time)
        else:
            wait_for = default_interval

        try:
            lock.acquire(timeout=wait_for)
        except Timeout:
            logger.info(
                f'Still waiting to acquire lock on {lock_file} (elapsed: {time.time() - start_time:.1f} seconds)'
            )
        except NotImplementedError as e:
            if 'use SoftFileLock instead' not in str(e):
                # Retrying cannot fix an unrelated NotImplementedError;
                # swallowing it would spin in this loop without pause.
                raise
            logger.warning(
                'FileSystem does not appear to support flock. Falling back to SoftFileLock for %s',
                lock_file)
            lock = SoftFileLock(lock_file, timeout=default_interval)
            continue
        else:
            break

    try:
        yield lock
    finally:
        try:
            lock.release()
        except OSError:
            # Best effort: when releasing fails, try to remove the lock
            # file directly and ignore a failure of that as well.
            try:
                Path(lock_file).unlink()
            except OSError:
                pass
def convert_timestamp(time_stamp: Union[int, str, datetime],
                      time_zone: str = 'Asia/Shanghai') -> Optional[datetime]:
    """Convert a UNIX/string timestamp to a timezone-aware datetime object.

    Args:
        time_stamp: UNIX timestamp (int), date-time string, or datetime
            object. Accepted string layouts are
            ``YYYY-MM-DDTHH:MM:SS[.fraction]Z`` (UTC) and
            ``YYYY-MM-DD[T| ]HH:MM:SS[.ffffff]`` (no zone marker).
        time_zone: Zone name passed to ``zoneinfo.ZoneInfo``
            (default: 'Asia/Shanghai'). ``Z`` strings and UNIX timestamps
            are read as UTC and converted into this zone; strings without
            ``Z`` are taken as wall-clock time in this zone.

    Returns:
        ``None`` for any falsy input (``None``, ``''`` and ``0``). A
        datetime input is returned unchanged, so it stays naive if it was
        naive. Otherwise a timezone-aware datetime.

    Raises:
        ValueError: If a string cannot be parsed or an int cannot be
            represented as a datetime.
        TypeError: If ``time_stamp`` is of an unsupported type.
    """
    if not time_stamp:
        return None

    # Handle datetime objects first: they are passed through untouched.
    if isinstance(time_stamp, datetime):
        return time_stamp

    if isinstance(time_stamp, str):
        try:
            if time_stamp.endswith('Z'):
                # strptime's %f accepts at most 6 digits, so bring the
                # fractional part to exactly 6 digits. The caller's text is
                # kept in time_stamp so error messages show the real input.
                if '.' not in time_stamp:
                    # No fractional seconds (e.g., "2024-11-16T00:27:02Z")
                    normalized = time_stamp[:-1] + '.000000Z'
                else:
                    # Has fractional seconds
                    # (e.g., "2022-08-19T07:19:38.123456789Z")
                    base, fraction = time_stamp[:-1].split('.')
                    # Truncate or pad to 6 digits
                    fraction = fraction[:6].ljust(6, '0')
                    normalized = f'{base}.{fraction}Z'
                dt = datetime.strptime(normalized,
                                       '%Y-%m-%dT%H:%M:%S.%fZ').replace(
                                           tzinfo=zoneinfo.ZoneInfo('UTC'))
                if time_zone != 'UTC':
                    dt = dt.astimezone(zoneinfo.ZoneInfo(time_zone))
                return dt

            # No 'Z' suffix: try the common ISO-like layouts in turn.
            formats = [
                '%Y-%m-%dT%H:%M:%S.%f',  # With microseconds
                '%Y-%m-%dT%H:%M:%S',  # Without microseconds
                '%Y-%m-%d %H:%M:%S.%f',  # Space separator with microseconds
                '%Y-%m-%d %H:%M:%S',  # Space separator without microseconds
            ]
            for fmt in formats:
                try:
                    return datetime.strptime(
                        time_stamp,
                        fmt).replace(tzinfo=zoneinfo.ZoneInfo(time_zone))
                except ValueError:
                    continue
            raise ValueError(
                f"Unsupported timestamp format: '{time_stamp}'")
        except ValueError as e:
            raise ValueError(
                f"Cannot parse '{time_stamp}' as a datetime. Expected formats: "
                f"'YYYY-MM-DDTHH:MM:SS[.ffffff]Z' (UTC) or 'YYYY-MM-DDTHH:MM:SS[.ffffff]' (local)"
            ) from e
    elif isinstance(time_stamp, int):
        try:
            # UNIX timestamps are always in UTC, then convert to target timezone
            return datetime.fromtimestamp(
                time_stamp, tz=zoneinfo.ZoneInfo('UTC')).astimezone(
                    zoneinfo.ZoneInfo(time_zone))
        except (ValueError, OverflowError, OSError) as e:
            # fromtimestamp() signals out-of-range values with OverflowError
            # as well, so it is wrapped like the other failures.
            raise ValueError(
                f"Cannot convert '{time_stamp}' to datetime. Ensure it's a valid UNIX timestamp."
            ) from e
    else:
        raise TypeError(
            f"Unsupported type '{type(time_stamp)}'. Expected int, str, or datetime."
        )