Files
modelscope/modelscope/utils/hub.py
suluyana 845a0bd5fe Fix/text gen (#1177)
* fix text-gen: read pipeline type from configuration.json first

---------

Co-authored-by: suluyan <suluyan.sly@alibaba-inc.com>
2025-01-13 10:05:42 +08:00

158 lines
5.7 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
from typing import List, Optional, Union
from requests import HTTPError
from modelscope.hub.constants import Licenses, ModelVisibility
from modelscope.hub.file_download import model_file_download
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.config import Config
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
ModelFile)
from .logger import get_logger
# Module-level logger used by the helper functions in this file.
logger = get_logger()
def create_model_if_not_exist(
        api,
        model_id: str,
        chinese_name: str,
        visibility: Optional[int] = ModelVisibility.PUBLIC,
        license: Optional[str] = Licenses.APACHE_V2):
    """Create a model repo through ``api`` unless it already exists.

    Args:
        api: Hub client object; ``repo_exists`` and ``create_model`` are
            called on it.
        model_id (str): Identifier of the model repo to check/create.
        chinese_name (str): Forwarded to ``api.create_model``.
        visibility (int, optional): Forwarded to ``api.create_model``.
        license (str, optional): Forwarded to ``api.create_model``.

    Returns:
        bool: ``True`` if the repo was created by this call, ``False`` if
        it was already present and creation was skipped.
    """
    # Guard clause: nothing to do when the repo is already on the hub.
    if api.repo_exists(model_id):
        logger.info(f'model {model_id} already exists, skip creation.')
        return False

    creation_kwargs = dict(
        model_id=model_id,
        visibility=visibility,
        license=license,
        chinese_name=chinese_name,
    )
    api.create_model(**creation_kwargs)
    logger.info(f'model {model_id} successfully created.')
    return True
def read_config(model_id_or_path: str,
                revision: Optional[str] = DEFAULT_MODEL_REVISION):
    """Load a ``Config`` from the hub or from the local filesystem.

    Args:
        model_id_or_path (str): Model repo name, local directory path, or
            path to a configuration file.
        revision: Revision of the model, used only when the configuration
            is fetched from the hub.

    Returns:
        config (:obj:`Config`): The parsed config object, or ``None`` when
        the path exists but is neither a directory nor a regular file.
    """
    target = model_id_or_path

    # Not present on disk: treat the argument as a hub repo name and
    # fetch only the configuration file.
    if not os.path.exists(target):
        downloaded = model_file_download(
            target, ModelFile.CONFIGURATION, revision=revision)
        return Config.from_file(downloaded)

    # Local model directory: the configuration file sits inside it.
    if os.path.isdir(target):
        return Config.from_file(os.path.join(target, ModelFile.CONFIGURATION))

    # The argument is itself the configuration file.
    if os.path.isfile(target):
        return Config.from_file(target)

    return None
def auto_load(model: Union[str, List[str]]):
    """Replace every non-local model entry with its ``snapshot_download`` result.

    Args:
        model: A single model string or a list of them. Entries that exist
            on the local filesystem are kept unchanged; every other entry
            is passed to ``snapshot_download``.

    Returns:
        The resolved entry for a string input, or a list of resolved
        entries (same order) otherwise.
    """

    def _resolve(entry):
        # Keep existing local paths as they are; fetch everything else.
        return entry if osp.exists(entry) else snapshot_download(entry)

    if isinstance(model, str):
        return _resolve(model)
    return [_resolve(entry) for entry in model]
def get_model_type(model_dir):
    """Get the model type from the configuration.

    Lookup order inside the configuration.json file:
        1. 'model.backbone.type' when 'model.backbone' is present;
        2. 'model.model_type' when it is present and 'model.type' is not;
        3. 'model.type' otherwise.
    If configuration.json is not a file, the top-level 'model_type' field
    of config.json is used instead.

    Args:
        model_dir: The local model dir to use.

    Returns:
        The model type string, or None if nothing is found or if reading
        the config raised (the error is logged, not propagated).
    """
    try:
        # Path building stays inside the try so that a bad ``model_dir``
        # is logged like any other parse failure.
        primary_path = osp.join(model_dir, ModelFile.CONFIGURATION)
        fallback_path = osp.join(model_dir, 'config.json')

        if osp.isfile(primary_path):
            model_cfg = Config.from_file(primary_path).model
            if hasattr(model_cfg, 'backbone'):
                return model_cfg.backbone.type
            # 'type' wins whenever it exists; 'model_type' is only used
            # when it is the sole candidate.
            if hasattr(model_cfg, 'type') or not hasattr(
                    model_cfg, 'model_type'):
                return model_cfg.type
            return model_cfg.model_type

        if osp.isfile(fallback_path):
            fallback_cfg = Config.from_file(fallback_path)
            return getattr(fallback_cfg, 'model_type', None)
    except Exception as e:
        logger.error(f'parse config file failed with error: {e}')
    return None
def parse_label_mapping(model_dir):
    """Get the label mapping from the model dir.

    The sources are tried in order, stopping at the first that yields a
    mapping:
        1. The label_mapping.json file.
        2. The configuration.json file: 'label2id' or (inverted) 'id2label'
           under the model section, then under the preprocessor section.
        3. The config.json file: top-level 'label2id' or (inverted)
           'id2label'.

    A source file that is absent is skipped, so a directory that only
    ships config.json still reaches step 3.

    Args:
        model_dir: The local model dir to use.

    Returns:
        The label2id mapping (label -> int id) if found, otherwise None.
    """
    import json

    label2id = None

    # 1. Dedicated label mapping file.
    label_path = os.path.join(model_dir, ModelFile.LABEL_MAPPING)
    if os.path.exists(label_path):
        with open(label_path, encoding='utf-8') as f:
            label_mapping = json.load(f)
        label2id = dict(label_mapping.items())

    # 2. configuration.json, guarded by an existence check like step 3.
    # NOTE(review): presumably Config.from_file fails on a missing file,
    # which used to prevent the config.json fallback from running - confirm.
    config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
    if label2id is None and os.path.exists(config_path):
        config = Config.from_file(config_path)
        # Precedence: model.label2id, model.id2label,
        # preprocessor.label2id, preprocessor.id2label; first hit wins.
        for section_name in (ConfigFields.model, ConfigFields.preprocessor):
            if not hasattr(config, section_name):
                continue
            section = config[section_name]
            if hasattr(section, 'label2id'):
                label2id = section.label2id
                break
            if hasattr(section, 'id2label'):
                label2id = {
                    label: idx
                    for idx, label in section.id2label.items()
                }
                break

    # 3. config.json.
    config_path = os.path.join(model_dir, 'config.json')
    if label2id is None and os.path.exists(config_path):
        config = Config.from_file(config_path)
        if hasattr(config, 'label2id'):
            label2id = config.label2id
        elif hasattr(config, 'id2label'):
            label2id = {
                label: idx
                for idx, label in config.id2label.items()
            }

    # Normalise ids to int (inverted id2label keys may be strings).
    if label2id is not None:
        label2id = {label: int(idx) for label, idx in label2id.items()}
    return label2id