mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
* fix text-gen: read pipeline type from configuration.json first --------- Co-authored-by: suluyan <suluyan.sly@alibaba-inc.com>
158 lines
5.7 KiB
Python
158 lines
5.7 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import os
|
|
import os.path as osp
|
|
from typing import List, Optional, Union
|
|
|
|
from requests import HTTPError
|
|
|
|
from modelscope.hub.constants import Licenses, ModelVisibility
|
|
from modelscope.hub.file_download import model_file_download
|
|
from modelscope.hub.snapshot_download import snapshot_download
|
|
from modelscope.utils.config import Config
|
|
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
|
|
ModelFile)
|
|
from .logger import get_logger
|
|
|
|
# Module-level logger shared by the helper functions defined in this file.
logger = get_logger()
|
|
|
|
|
|
def create_model_if_not_exist(
        api,
        model_id: str,
        chinese_name: str,
        visibility: Optional[int] = ModelVisibility.PUBLIC,
        license: Optional[str] = Licenses.APACHE_V2):
    """Create ``model_id`` on the hub through ``api`` unless it already exists.

    Args:
        api: Hub API client; it must provide ``repo_exists(model_id)`` and
            ``create_model(...)``.
        model_id (str): Repository id of the model to create.
        chinese_name (str): Display name forwarded to ``api.create_model``.
        visibility: Visibility setting forwarded to ``api.create_model``;
            defaults to ``ModelVisibility.PUBLIC``.
        license: License forwarded to ``api.create_model``; defaults to
            ``Licenses.APACHE_V2``.

    Returns:
        bool: False when the repository already existed (nothing is created),
        True after ``api.create_model`` has been called.
    """
    already_there = api.repo_exists(model_id)
    if already_there:
        logger.info(f'model {model_id} already exists, skip creation.')
        return False

    creation_kwargs = dict(
        model_id=model_id,
        visibility=visibility,
        license=license,
        chinese_name=chinese_name,
    )
    api.create_model(**creation_kwargs)
    logger.info(f'model {model_id} successfully created.')
    return True
|
|
|
|
|
|
def read_config(model_id_or_path: str,
                revision: Optional[str] = DEFAULT_MODEL_REVISION):
    """Read config from hub or local path.

    Args:
        model_id_or_path (str): A local directory containing the
            configuration file, a path to the configuration file itself, or
            (when nothing exists at that path) a model repo name on the hub.
        revision: Revision of the model; only used when the configuration
            file has to be downloaded from the hub.

    Returns:
        config (:obj:`Config`): The parsed config object, or None when
        ``model_id_or_path`` exists locally but is neither a regular file nor
        a directory.
    """
    if os.path.isdir(model_id_or_path):
        cfg_file = os.path.join(model_id_or_path, ModelFile.CONFIGURATION)
    elif os.path.isfile(model_id_or_path):
        cfg_file = model_id_or_path
    elif os.path.exists(model_id_or_path):
        # Present on disk, but neither a regular file nor a directory.
        return None
    else:
        # Nothing local at this path: treat it as a hub model id.
        cfg_file = model_file_download(
            model_id_or_path, ModelFile.CONFIGURATION, revision=revision)
    return Config.from_file(cfg_file)
|
|
|
|
|
|
def auto_load(model: Union[str, List[str]]):
    """Resolve one or more model ids/paths to local paths.

    Every entry that does not already exist on disk is passed to
    ``snapshot_download`` and replaced by its return value; entries that
    exist locally are returned unchanged.

    Args:
        model: A single model id or local path, or an iterable of them.

    Returns:
        The resolved path when ``model`` is a ``str``; otherwise a new
        ``list`` of resolved paths in the original order.
    """

    def _resolve(item):
        # Local paths win; anything else is fetched from the hub.
        return item if osp.exists(item) else snapshot_download(item)

    if isinstance(model, str):
        return _resolve(model)
    return [_resolve(item) for item in model]
|
|
|
|
|
|
def get_model_type(model_dir):
    """Get the model type from the configuration.

    If ``ModelFile.CONFIGURATION`` (configuration.json) exists in
    ``model_dir``, the type is taken from its ``model`` section, in this
    priority order:

    1. ``model.backbone.type`` when a ``backbone`` entry is present;
    2. ``model.model_type`` when it is present and ``model.type`` is not;
    3. ``model.type`` otherwise.

    Only when that file is absent is the top-level ``model_type`` field of
    ``config.json`` used instead.

    Any exception raised while locating or parsing the files is logged and
    swallowed.

    Args:
        model_dir: The local model dir to use.

    Returns:
        The model type string, or None if nothing is found or parsing failed.
    """
    try:
        ms_config_path = osp.join(model_dir, ModelFile.CONFIGURATION)
        plain_config_path = osp.join(model_dir, 'config.json')

        if not osp.isfile(ms_config_path):
            # No configuration.json: fall back to config.json if present.
            if not osp.isfile(plain_config_path):
                return None
            return getattr(
                Config.from_file(plain_config_path), 'model_type', None)

        model_cfg = Config.from_file(ms_config_path).model
        if hasattr(model_cfg, 'backbone'):
            return model_cfg.backbone.type
        # 'model_type' is only honoured when no explicit 'type' is given.
        if hasattr(model_cfg, 'model_type') and not hasattr(
                model_cfg, 'type'):
            return model_cfg.model_type
        return model_cfg.type
    except Exception as e:
        logger.error(f'parse config file failed with error: {e}')
|
|
|
|
|
|
def _label2id_from_section(section):
    """Return ``section.label2id``, else the inverse of ``section.id2label``, else None."""
    if hasattr(section, 'label2id'):
        return section.label2id
    if hasattr(section, 'id2label'):
        return {label: idx for idx, label in section.id2label.items()}
    return None


def parse_label_mapping(model_dir):
    """Get the label mapping from the model dir.

    Sources are tried in order, and the first one that yields a mapping wins:

    1. The label mapping file (``ModelFile.LABEL_MAPPING``, label_mapping.json),
       read as a ``{label: id}`` JSON object.
    2. configuration.json (``ModelFile.CONFIGURATION``): ``label2id`` or the
       inverse of ``id2label`` from the ``model`` section, then from the
       ``preprocessor`` section.
    3. config.json: top-level ``label2id`` or the inverse of ``id2label``.

    Missing files are skipped instead of raising. A directory that ships only
    a config.json therefore still reaches step 3.

    Args:
        model_dir: The local model dir to use.

    Returns:
        The label2id mapping with every id converted by ``int()``, or None if
        no source provides one.
    """
    import json

    label2id = None

    # 1. Dedicated label mapping file.
    label_path = os.path.join(model_dir, ModelFile.LABEL_MAPPING)
    if os.path.exists(label_path):
        with open(label_path, encoding='utf-8') as f:
            label2id = dict(json.load(f))

    # 2. configuration.json: the model section first, then the preprocessor
    # section. Guarded by exists() so a missing file falls through to step 3
    # instead of raising from Config.from_file.
    if label2id is None:
        configuration_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
        if os.path.exists(configuration_path):
            config = Config.from_file(configuration_path)
            for field in (ConfigFields.model, ConfigFields.preprocessor):
                if hasattr(config, field):
                    label2id = _label2id_from_section(config[field])
                    if label2id is not None:
                        break

    # 3. config.json: top-level label2id / id2label.
    plain_config_path = os.path.join(model_dir, 'config.json')
    if label2id is None and os.path.exists(plain_config_path):
        label2id = _label2id_from_section(Config.from_file(plain_config_path))

    if label2id is not None:
        # JSON object keys/values may come back as strings; normalise ids.
        label2id = {label: int(idx) for label, idx in label2id.items()}
    return label2id
|