mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 12:09:22 +01:00
support third_party key in pipeline
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12940228 * support third_party key in pipeline
This commit is contained in:
committed by
wenmeng.zwm
parent
fa7562fd96
commit
bc5c16aa10
@@ -7,7 +7,8 @@ from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE, Pipelines
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.utils.config import ConfigDict, check_config
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke,
|
||||
ThirdParty)
|
||||
from modelscope.utils.hub import read_config
|
||||
from modelscope.utils.plugins import (register_modelhub_repo,
|
||||
register_plugins_repo)
|
||||
@@ -18,7 +19,7 @@ from .util import is_official_hub_path
|
||||
PIPELINES = Registry('pipelines')
|
||||
|
||||
|
||||
def normalize_model_input(model, model_revision):
|
||||
def normalize_model_input(model, model_revision, third_party=None):
|
||||
""" normalize the input model, to ensure that a model str is a valid local path: in other words,
|
||||
for model represented by a model id, the model shall be downloaded locally
|
||||
"""
|
||||
@@ -26,19 +27,21 @@ def normalize_model_input(model, model_revision):
|
||||
# skip revision download if model is a local directory
|
||||
if not os.path.exists(model):
|
||||
# note that if there is already a local copy, snapshot_download will check and skip downloading
|
||||
user_agent = {Invoke.KEY: Invoke.PIPELINE}
|
||||
if third_party is not None:
|
||||
user_agent[ThirdParty.KEY] = third_party
|
||||
model = snapshot_download(
|
||||
model,
|
||||
revision=model_revision,
|
||||
user_agent={Invoke.KEY: Invoke.PIPELINE})
|
||||
model, revision=model_revision, user_agent=user_agent)
|
||||
elif isinstance(model, list) and isinstance(model[0], str):
|
||||
for idx in range(len(model)):
|
||||
if is_official_hub_path(
|
||||
model[idx],
|
||||
model_revision) and not os.path.exists(model[idx]):
|
||||
user_agent = {Invoke.KEY: Invoke.PIPELINE}
|
||||
if third_party is not None:
|
||||
user_agent[ThirdParty.KEY] = third_party
|
||||
model[idx] = snapshot_download(
|
||||
model[idx],
|
||||
revision=model_revision,
|
||||
user_agent={Invoke.KEY: Invoke.PIPELINE})
|
||||
model[idx], revision=model_revision, user_agent=user_agent)
|
||||
return model
|
||||
|
||||
|
||||
@@ -97,7 +100,11 @@ def pipeline(task: str = None,
|
||||
if task is None and pipeline_name is None:
|
||||
raise ValueError('task or pipeline_name is required')
|
||||
|
||||
model = normalize_model_input(model, model_revision)
|
||||
third_party = kwargs.get(ThirdParty.KEY)
|
||||
if third_party is not None:
|
||||
kwargs.pop(ThirdParty.KEY)
|
||||
model = normalize_model_input(
|
||||
model, model_revision, third_party=third_party)
|
||||
pipeline_props = {'type': pipeline_name}
|
||||
if pipeline_name is None:
|
||||
# get default pipeline for this task
|
||||
|
||||
Reference in New Issue
Block a user