# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import os.path as osp
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Tasks
from modelscope.models.builder import build_backbone, build_model
from modelscope.utils.automodel_utils import (can_load_by_ms,
                                              try_to_load_hf_model)
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
from modelscope.utils.device import verify_device
from modelscope.utils.logger import get_logger
from modelscope.utils.plugins import (register_modelhub_repo,
                                      register_plugins_repo)

logger = get_logger()

Tensor = Union['torch.Tensor', 'tf.Tensor']


class Model(ABC):
    """Base model interface."""

    def __init__(self, model_dir, *args, **kwargs):
        self.model_dir = model_dir
        device_name = kwargs.get('device', 'gpu')
        verify_device(device_name)
        self._device_name = device_name
        self.trust_remote_code = kwargs.get('trust_remote_code', False)

    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
        return self.postprocess(self.forward(*args, **kwargs))

    def check_trust_remote_code(self, info_str: Optional[str] = None):
        """Check trust_remote_code if the model needs to import extra libs.

        Args:
            info_str (str): The message shown to the user when
                `trust_remote_code` is `False`.
        """
        info_str = info_str or (
            'This model requires `trust_remote_code` to be `True` because it needs to '
            'import extra libs or execute the code in the model repo; setting this to True '
            'means you trust the files in it.')
        assert self.trust_remote_code, info_str

    @abstractmethod
    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        """Run the forward pass for a model.

        Returns:
            Dict[str, Any]: Output from the model forward pass.
        """
        pass

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Model-specific postprocessing that converts raw model output to
        standard model outputs.

        Args:
            inputs: Input data.

        Returns:
            Dict[str, Any]: A dict containing the outputs of the model; each
                output should have a standard output name.
        """
        return inputs

    @classmethod
    def _instantiate(cls, **kwargs):
        """Define the instantiation method of a model; the default is to call
        the constructor. Note that when the constructor of a task model does
        not itself load the model, a `load_model` method is added, and this
        method is overloaded accordingly.
        """
        return cls(**kwargs)

    @classmethod
    def from_pretrained(cls,
                        model_name_or_path: str,
                        revision: Optional[str] = DEFAULT_MODEL_REVISION,
                        cfg_dict: Config = None,
                        device: str = None,
                        trust_remote_code: Optional[bool] = False,
                        **kwargs):
        """Instantiate a model from a local directory or a remote model repo.
        Note that when loading from remote, the model revision can be
        specified.

        Args:
            model_name_or_path (str): A model dir or a model id to be loaded.
            revision (str, `optional`): The revision used when `model_name_or_path`
                is a model id of the remote hub. Defaults to `master`.
            cfg_dict (Config, `optional`): An optional model config. If provided,
                it will replace the config read from `model_name_or_path`.
            device (str, `optional`): The device on which to load the model.
            trust_remote_code (bool, `optional`): Whether to trust and allow the
                execution of remote code. Defaults to False.
            **kwargs:
                task (str, `optional`): The `Tasks` enumeration value to replace
                    the task value read from the config in `model_name_or_path`.
                    This is useful when the model to be loaded differs from the
                    model saved, for example, loading a `backbone` into a
                    `text-classification` model.
                    Other kwargs will be fed directly into the `model` key to
                    replace the default configs.
                use_hf (bool, `optional`): If True, initialize the model using
                    AutoModel or AutoModelFor* from Hugging Face. If False, load
                    the model in ModelScope mode. If None, the loading mode is
                    selected automatically.
                ignore_file_pattern (List[str], `optional`): Passed to
                    `snapshot_download`.
                device_map (str | Dict[str, str], `optional`): Passed to
                    AutoModel or AutoModelFor*.
                torch_dtype (torch.dtype, `optional`): Passed to AutoModel or
                    AutoModelFor*.
                config (PretrainedConfig, `optional`): Passed to AutoModel or
                    AutoModelFor*.

        Returns:
            A model instance.

        Examples:
            >>> from modelscope.models import Model
            >>> Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='text-classification')
        """
        prefetched = kwargs.get('model_prefetched')
        if prefetched is not None:
            kwargs.pop('model_prefetched')
        invoked_by = kwargs.get(Invoke.KEY)
        if invoked_by is not None:
            kwargs.pop(Invoke.KEY)
        else:
            invoked_by = Invoke.PRETRAINED

        ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)
        if osp.exists(model_name_or_path):
            local_model_dir = model_name_or_path
        else:
            if prefetched is True:
                raise RuntimeError(
                    'Expected the model to be pre-fetched locally, but it was not found.'
                )

            invoked_by = '%s/%s' % (Invoke.KEY, invoked_by)
            local_model_dir = snapshot_download(
                model_name_or_path,
                revision,
                user_agent=invoked_by,
                ignore_file_pattern=ignore_file_pattern)
        logger.info(f'initialize model from {local_model_dir}')

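        # Resolve the model config: an explicit cfg_dict wins; otherwise read
        # configuration.json from the (local or downloaded) model directory.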
        configuration_path = osp.join(local_model_dir, ModelFile.CONFIGURATION)
        cfg = None
        if cfg_dict is not None:
            cfg = cfg_dict
        elif os.path.exists(configuration_path):
            cfg = Config.from_file(configuration_path)
        task_name = getattr(cfg, 'task', None)
        if 'task' in kwargs:
            task_name = kwargs.pop('task')
        model_cfg = getattr(cfg, 'model', ConfigDict())
        if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
            model_cfg.type = model_cfg.model_type
        model_type = getattr(model_cfg, 'type', None)
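        # Normalize the device alias: 'gpu'/'gpu:N' map to torch's 'cuda'/'cuda:N'.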
        if isinstance(device, str) and device.startswith('gpu'):
            device = 'cuda' + device[3:]
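        # Select the loading backend: prefer ModelScope when the local repo can
        # be loaded natively; otherwise fall back to the Hugging Face auto classes.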
        use_hf = kwargs.pop('use_hf', None)
        if use_hf is None and can_load_by_ms(local_model_dir, task_name,
                                             model_type):
            use_hf = False
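        # Hugging Face path: if a HF model was constructed, optionally move it
        # to the requested device (unless device_map handles placement) and
        # return early.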
        model = None
        if use_hf in {True, None}:
            model = try_to_load_hf_model(local_model_dir, task_name, use_hf,
                                         **kwargs)
        if model is not None:
            device_map = kwargs.pop('device_map', None)
            if device_map is None and device is not None:
                model = model.to(device)
            return model
        # use ms
        if cfg is None:
            raise FileNotFoundError(
                f'`{ModelFile.CONFIGURATION}` file not found.')
        model_cfg.model_dir = local_model_dir

        # Security check: only allow execution of remote code or plugins if trust_remote_code is True
        plugins = cfg.safe_get('plugins')
        if plugins and not trust_remote_code:
            raise RuntimeError(
                'Detected a `plugins` field in the model configuration file, but '
                'trust_remote_code=True was not explicitly set.\n'
                'To prevent potential execution of malicious code, loading has been refused.\n'
                'If you trust this model repository, please pass trust_remote_code=True to from_pretrained.'
            )
        if plugins and trust_remote_code:
            logger.warning(
                'trust_remote_code=True is set. Code from the remote model repo will be '
                'invoked or plugins installed. Please make sure that you trust the external code.')
        register_modelhub_repo(local_model_dir, allow_remote=trust_remote_code)
        default_args = {}
        if trust_remote_code:
            default_args = {'trust_remote_code': trust_remote_code}
        register_plugins_repo(plugins)
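        # Merge any remaining kwargs into the model config before building.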
        for k, v in kwargs.items():
            model_cfg[k] = v
        if device is not None:
            model_cfg.device = device
        if task_name is Tasks.backbone:
            model_cfg.init_backbone = True
            model = build_backbone(model_cfg)
        else:
            model = build_model(
                model_cfg, task_name=task_name, default_args=default_args)

        # dynamically add pipeline info to model for pipeline inference
        if hasattr(cfg, 'pipeline'):
            model.pipeline = cfg.pipeline

        if not hasattr(model, 'cfg'):
            model.cfg = cfg

        model_cfg.pop('model_dir', None)
        model.name = model_name_or_path
        model.model_dir = local_model_dir
        return model

    def save_pretrained(self,
                        target_folder: Union[str, os.PathLike],
                        save_checkpoint_names: Union[str, List[str]] = None,
                        config: Optional[dict] = None,
                        **kwargs):
        """Save the pretrained model, its configuration and other related
        files to a directory, so that it can be re-loaded.

        Args:
            target_folder (Union[str, os.PathLike]):
                Directory to save to. Will be created if it doesn't exist.

            save_checkpoint_names (Union[str, List[str]]):
                The checkpoint names to be saved in the target_folder.

            config (Optional[dict], optional):
                The config for configuration.json; it may not be identical to
                model.config.
        """
        raise NotImplementedError(
            'save_pretrained must be implemented by the subclass.')