Files
modelscope/modelscope/pipelines/util.py
2023-09-28 00:54:41 +08:00

114 lines
3.8 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import List, Optional, Union
from modelscope.hub.api import HubApi
from modelscope.hub.file_download import model_file_download
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.logger import get_logger
logger = get_logger()
def is_config_has_model(cfg_file):
    """Tell whether the configuration file declares a model.

    Args:
        cfg_file: Location of the configuration file, handed to
            ``Config.from_file``.

    Returns:
        bool: True when the loaded config exposes a ``model`` or a
        ``model_type`` attribute. False when neither is present, or when
        loading/inspecting the config raises (the failure is logged).
    """
    try:
        parsed = Config.from_file(cfg_file)
        # Kept inside the try block so that a failure while probing the
        # config is reported the same way as a failure while loading it.
        declares_model = any(
            hasattr(parsed, attr) for attr in ('model', 'model_type'))
    except Exception as err:
        logger.error(f'parse config file {cfg_file} failed: {err}')
        return False
    return declares_model
def is_official_hub_path(path: Union[str, List],
                         revision: Optional[str] = DEFAULT_MODEL_REVISION):
    """Check whether ``path`` is an official hub name or a valid local
    path to an official hub directory.

    Args:
        path: A single model id / local directory, or an iterable of them.
        revision: Revision forwarded to ``HubApi().get_model`` for entries
            that do not exist on the local filesystem.

    Returns:
        bool: For an entry that exists locally, whether it contains the
        ``ModelFile.CONFIGURATION`` file. For an entry that does not exist
        locally, True when the hub lookup succeeds. For a collection, True
        only when every entry qualifies (an empty collection yields True
        because ``all([])`` is True).

    Raises:
        ValueError: If the hub lookup for a non-local entry fails (the
            original error is attached as ``__cause__``), or if a
            collection mixes qualifying and non-qualifying entries.
    """

    def is_official_hub_impl(candidate):
        if osp.exists(candidate):
            cfg_file = osp.join(candidate, ModelFile.CONFIGURATION)
            return osp.exists(cfg_file)
        try:
            HubApi().get_model(candidate, revision=revision)
        except Exception as e:
            # Chain explicitly so the hub error stays visible as the cause.
            raise ValueError(f'invalid model repo path {e}') from e
        return True

    if isinstance(path, str):
        return is_official_hub_impl(path)

    results = [is_official_hub_impl(m) for m in path]
    all_true = all(results)
    if any(results) and not all_true:
        raise ValueError(
            f'some model are hub address, some are not, model list: {path}')
    return all_true
def is_model(path: Union[str, List]):
    """Tell whether ``path`` is a model-hub location whose config declares
    a model.

    Args:
        path: A single model id / local directory, or an iterable of them.

    Returns:
        bool: Result for the single entry, or True only when every entry of
        the collection qualifies.

    Raises:
        ValueError: If a collection mixes qualifying and non-qualifying
            entries.
    """

    def _declares_model(location):
        if not osp.exists(location):
            # Not on disk: fetch the config files from the hub. Any failure
            # (download or parsing) is treated as "not a model".
            try:
                primary = model_file_download(location,
                                              ModelFile.CONFIGURATION)
                if is_config_has_model(primary):
                    return True
                # Remote fallback happens when the first config was fetched
                # but declares no model.
                secondary = model_file_download(location, ModelFile.CONFIG)
                return is_config_has_model(secondary)
            except Exception:
                return False

        # On disk: the first config file that exists decides the answer;
        # the second one is consulted only when the first is absent.
        primary = osp.join(location, ModelFile.CONFIGURATION)
        secondary = osp.join(location, ModelFile.CONFIG)
        for candidate in (primary, secondary):
            if osp.exists(candidate):
                return is_config_has_model(candidate)
        return False

    if isinstance(path, str):
        return _declares_model(path)

    flags = [_declares_model(entry) for entry in path]
    every = all(flags)
    if any(flags) and not every:
        raise ValueError(
            f'some models are hub address, some are not, model list: {path}')
    return every
def batch_process(model, data):
    """Collate a list of per-sample inputs into one batch.

    Only models whose class is named ``OfaForAllTasks`` are handled; for
    any other model nothing is built and None is returned.

    Args:
        model: Model instance; only its class name is inspected.
        data: List of per-sample dicts, each providing ``samples`` and
            ``net_input`` and optionally ``w_resize_ratios`` /
            ``h_resize_ratios``.

    Returns:
        dict or None: Batch with ``nsentences``, ``samples``, ``net_input``
        and whichever resize-ratio entries the first sample carries.
    """
    import torch

    if model.__class__.__name__ != 'OfaForAllTasks':
        return None

    # collate batch data due to the nested data structure
    assert isinstance(data, list)
    batch_data = {
        'nsentences': len(data),
        'samples': [item['samples'][0] for item in data],
        # The first sample defines which net_input fields are collated.
        'net_input': {
            field: torch.cat([item['net_input'][field] for item in data])
            for field in data[0]['net_input']
        },
    }
    # Resize ratios are optional; presence is decided by the first sample.
    for ratio_key in ('w_resize_ratios', 'h_resize_ratios'):
        if ratio_key in data[0]:
            batch_data[ratio_key] = torch.cat(
                [item[ratio_key] for item in data])
    return batch_data