support third_party key in pipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12940228

* support third_party key in pipeline
This commit is contained in:
lanjinpeng.ljp
2023-06-20 19:25:24 +08:00
committed by wenmeng.zwm
parent fa7562fd96
commit bc5c16aa10

View File

@@ -7,7 +7,8 @@ from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE, Pipelines
from modelscope.models.base import Model
from modelscope.utils.config import ConfigDict, check_config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke,
ThirdParty)
from modelscope.utils.hub import read_config
from modelscope.utils.plugins import (register_modelhub_repo,
register_plugins_repo)
@@ -18,7 +19,7 @@ from .util import is_official_hub_path
PIPELINES = Registry('pipelines')
def normalize_model_input(model, model_revision):
def normalize_model_input(model, model_revision, third_party=None):
""" normalize the input model, to ensure that a model str is a valid local path: in other words,
for model represented by a model id, the model shall be downloaded locally
"""
@@ -26,19 +27,21 @@ def normalize_model_input(model, model_revision):
# skip revision download if model is a local directory
if not os.path.exists(model):
# note that if there is already a local copy, snapshot_download will check and skip downloading
user_agent = {Invoke.KEY: Invoke.PIPELINE}
if third_party is not None:
user_agent[ThirdParty.KEY] = third_party
model = snapshot_download(
model,
revision=model_revision,
user_agent={Invoke.KEY: Invoke.PIPELINE})
model, revision=model_revision, user_agent=user_agent)
elif isinstance(model, list) and isinstance(model[0], str):
for idx in range(len(model)):
if is_official_hub_path(
model[idx],
model_revision) and not os.path.exists(model[idx]):
user_agent = {Invoke.KEY: Invoke.PIPELINE}
if third_party is not None:
user_agent[ThirdParty.KEY] = third_party
model[idx] = snapshot_download(
model[idx],
revision=model_revision,
user_agent={Invoke.KEY: Invoke.PIPELINE})
model[idx], revision=model_revision, user_agent=user_agent)
return model
@@ -97,7 +100,11 @@ def pipeline(task: str = None,
if task is None and pipeline_name is None:
raise ValueError('task or pipeline_name is required')
model = normalize_model_input(model, model_revision)
third_party = kwargs.get(ThirdParty.KEY)
if third_party is not None:
kwargs.pop(ThirdParty.KEY)
model = normalize_model_input(
model, model_revision, third_party=third_party)
pipeline_props = {'type': pipeline_name}
if pipeline_name is None:
# get default pipeline for this task