Files
modelscope/modelscope/pipelines/builder.py
luyan defb668def tmp
2025-01-21 10:32:59 +08:00

291 lines
12 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict, List, Optional, Union
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE
from modelscope.models.base import Model
from modelscope.utils.config import ConfigDict, check_config
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, Tasks,
ThirdParty)
from modelscope.utils.hub import read_config
from modelscope.utils.plugins import (register_modelhub_repo,
register_plugins_repo)
from modelscope.utils.registry import Registry, build_from_cfg
from modelscope.utils.logger import get_logger
from modelscope.utils.import_utils import is_transformers_available
from .base import Pipeline
from .util import is_official_hub_path
PIPELINES = Registry('pipelines')
logger = get_logger()
def normalize_model_input(model,
model_revision,
third_party=None,
ignore_file_pattern=None):
""" normalize the input model, to ensure that a model str is a valid local path: in other words,
for model represented by a model id, the model shall be downloaded locally
"""
if isinstance(model, str) and is_official_hub_path(model, model_revision):
# skip revision download if model is a local directory
if not os.path.exists(model):
# note that if there is already a local copy, snapshot_download will check and skip downloading
user_agent = {Invoke.KEY: Invoke.PIPELINE}
if third_party is not None:
user_agent[ThirdParty.KEY] = third_party
model = snapshot_download(
model,
revision=model_revision,
user_agent=user_agent,
ignore_file_pattern=ignore_file_pattern)
elif isinstance(model, list) and isinstance(model[0], str):
for idx in range(len(model)):
if is_official_hub_path(
model[idx],
model_revision) and not os.path.exists(model[idx]):
user_agent = {Invoke.KEY: Invoke.PIPELINE}
if third_party is not None:
user_agent[ThirdParty.KEY] = third_party
model[idx] = snapshot_download(
model[idx], revision=model_revision, user_agent=user_agent)
return model
def build_pipeline(cfg: ConfigDict,
task_name: str = None,
default_args: dict = None):
""" build pipeline given model config dict.
Args:
cfg (:obj:`ConfigDict`): config dict for model object.
task_name (str, optional): task name, refer to
:obj:`Tasks` for more details.
default_args (dict, optional): Default initialization arguments.
"""
return build_from_cfg(
cfg, PIPELINES, group_key=task_name, default_args=default_args)
def pipeline(task: str = None,
model: Union[str, List[str], Model, List[Model]] = None,
preprocessor=None,
config_file: str = None,
pipeline_name: str = None,
framework: str = None,
device: str = 'gpu',
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
ignore_file_pattern: List[str] = None,
**kwargs) -> Pipeline:
""" Factory method to build an obj:`Pipeline`.
Args:
task (str): Task name defining which pipeline will be returned.
model (str or List[str] or obj:`Model` or obj:list[`Model`]): (list of) model name or model object.
preprocessor: preprocessor object.
config_file (str, optional): path to config file.
pipeline_name (str, optional): pipeline class name or alias name.
framework (str, optional): framework type.
model_revision: revision of model(s) if getting from model hub, for multiple models, expecting
all models to have the same revision
device (str, optional): whether to use gpu or cpu is used to do inference.
ignore_file_pattern(`str` or `List`, *optional*, default to `None`):
Any file pattern to be ignored in downloading, like exact file names or file extensions.
Return:
pipeline (obj:`Pipeline`): pipeline object for certain task.
Examples:
>>> # Using default model for a task
>>> p = pipeline('image-classification')
>>> # Using pipeline with a model name
>>> p = pipeline('text-classification', model='damo/distilbert-base-uncased')
>>> # Using pipeline with a model object
>>> resnet = Model.from_pretrained('Resnet')
>>> p = pipeline('image-classification', model=resnet)
>>> # Using pipeline with a list of model names
>>> p = pipeline('audio-kws', model=['damo/audio-tts', 'damo/auto-tts2'])
"""
if task is None and pipeline_name is None:
raise ValueError('task or pipeline_name is required')
pipeline_props = None
if pipeline_name is None:
# get default pipeline for this task
if isinstance(model, str) \
or (isinstance(model, list) and isinstance(model[0], str)):
if is_official_hub_path(model, revision=model_revision):
# read config file from hub and parse
cfg = read_config(
model, revision=model_revision) if isinstance(
model, str) else read_config(
model[0], revision=model_revision)
if cfg:
pipeline_name = cfg.safe_get('pipeline',
{}).get('type', None)
if pipeline_name is None:
prefer_llm_pipeline = kwargs.get('external_engine_for_llm')
# if not specified in both args and configuration.json, prefer llm pipeline for aforementioned tasks
if task is not None and task.lower() in [
Tasks.text_generation, Tasks.chat
]:
if prefer_llm_pipeline is None:
prefer_llm_pipeline = True
# for llm pipeline, if llm_framework is not specified, default to swift instead
# TODO: port the swift infer based on transformer into ModelScope
if prefer_llm_pipeline:
if kwargs.get('llm_framework') is None:
kwargs['llm_framework'] = 'swift'
pipeline_name = external_engine_for_llm_checker(
model, model_revision, kwargs)
if pipeline_name is None or pipeline_name != 'llm':
third_party = kwargs.get(ThirdParty.KEY)
if third_party is not None:
kwargs.pop(ThirdParty.KEY)
model = normalize_model_input(
model,
model_revision,
third_party=third_party,
ignore_file_pattern=ignore_file_pattern)
register_plugins_repo(cfg.safe_get('plugins'))
register_modelhub_repo(model,
cfg.get('allow_remote', False))
if pipeline_name:
pipeline_props = {'type': pipeline_name}
else:
try:
check_config(cfg)
pipeline_props = cfg.pipeline
except AssertionError as e:
logger.info(str(e))
elif model is not None:
# get pipeline info from Model object
first_model = model[0] if isinstance(model, list) else model
if not hasattr(first_model, 'pipeline'):
# model is instantiated by user, we should parse config again
cfg = read_config(first_model.model_dir)
check_config(cfg)
first_model.pipeline = cfg.pipeline
pipeline_props = first_model.pipeline
else:
pipeline_name, default_model_repo = get_default_pipeline_info(task)
model = normalize_model_input(default_model_repo, model_revision)
pipeline_props = {'type': pipeline_name}
else:
pipeline_props = {'type': pipeline_name}
if not pipeline_props and is_transformers_available():
from modelscope.utils.hf_util import hf_pipeline
return hf_pipeline(task=task,
model=model,
framework=framework,
device=device,
**kwargs)
pipeline_props['model'] = model
pipeline_props['device'] = device
cfg = ConfigDict(pipeline_props)
clear_llm_info(kwargs, pipeline_name)
if kwargs:
cfg.update(kwargs)
if preprocessor is not None:
cfg.preprocessor = preprocessor
return build_pipeline(cfg, task_name=task)
def add_default_pipeline_info(task: str,
model_name: str,
modelhub_name: str = None,
overwrite: bool = False):
""" Add default model for a task.
Args:
task (str): task name.
model_name (str): model_name.
modelhub_name (str): name for default modelhub.
overwrite (bool): overwrite default info.
"""
if not overwrite:
assert task not in DEFAULT_MODEL_FOR_PIPELINE, \
f'task {task} already has default model.'
DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name)
def get_default_pipeline_info(task):
""" Get default info for certain task.
Args:
task (str): task name.
Return:
A tuple: first element is pipeline name(model_name), second element
is modelhub name.
"""
if task not in DEFAULT_MODEL_FOR_PIPELINE:
# support pipeline which does not register default model
pipeline_name = list(PIPELINES.modules[task].keys())[0]
default_model = None
else:
pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
return pipeline_name, default_model
def external_engine_for_llm_checker(model: Union[str, List[str], Model,
List[Model]],
revision: Optional[str],
kwargs: Dict[str, Any]) -> Optional[str]:
from .nlp.llm_pipeline import ModelTypeHelper, LLMAdapterRegistry
from ..hub.check_model import get_model_id_from_cache
from swift.llm import get_model_info_meta
if isinstance(model, list):
model = model[0]
if not isinstance(model, str):
model = model.model_dir
if kwargs.get('llm_framework') == 'swift':
# check if swift supports
if os.path.exists(model):
model_id = get_model_id_from_cache(model)
else:
model_id = model
try:
info = get_model_info_meta(model_id)
model_type = info[0].model_type
except Exception as e:
logger.warning(
f'Cannot using llm_framework with {model_id}, '
f'ignoring llm_framework={self.llm_framework} : {e}')
model_type = None
if model_type:
return 'llm'
model_type = ModelTypeHelper.get(
model, revision, with_adapter=True, split='-', use_cache=True)
if LLMAdapterRegistry.contains(model_type):
return 'llm'
def clear_llm_info(kwargs: Dict, pipeline_name: str):
from modelscope.utils.model_type_helper import ModelTypeHelper
kwargs.pop('external_engine_for_llm', None)
if pipeline_name != 'llm':
kwargs.pop('llm_framework', None)
ModelTypeHelper.clear_cache()