diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index dd39453c..ca1431ea 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -7,7 +7,8 @@ from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE, Pipelines from modelscope.models.base import Model from modelscope.utils.config import ConfigDict, check_config -from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, + ThirdParty) from modelscope.utils.hub import read_config from modelscope.utils.plugins import (register_modelhub_repo, register_plugins_repo) @@ -18,7 +19,7 @@ from .util import is_official_hub_path PIPELINES = Registry('pipelines') -def normalize_model_input(model, model_revision): +def normalize_model_input(model, model_revision, third_party=None): """ normalize the input model, to ensure that a model str is a valid local path: in other words, for model represented by a model id, the model shall be downloaded locally """ @@ -26,19 +27,21 @@ def normalize_model_input(model, model_revision): # skip revision download if model is a local directory if not os.path.exists(model): # note that if there is already a local copy, snapshot_download will check and skip downloading + user_agent = {Invoke.KEY: Invoke.PIPELINE} + if third_party is not None: + user_agent[ThirdParty.KEY] = third_party model = snapshot_download( - model, - revision=model_revision, - user_agent={Invoke.KEY: Invoke.PIPELINE}) + model, revision=model_revision, user_agent=user_agent) elif isinstance(model, list) and isinstance(model[0], str): for idx in range(len(model)): if is_official_hub_path( model[idx], model_revision) and not os.path.exists(model[idx]): + user_agent = {Invoke.KEY: Invoke.PIPELINE} + if third_party is not None: + user_agent[ThirdParty.KEY] = third_party model[idx] = snapshot_download( - model[idx], - revision=model_revision, - user_agent={Invoke.KEY: Invoke.PIPELINE}) + model[idx], revision=model_revision, user_agent=user_agent) return model @@ -97,7 +100,11 @@ def pipeline(task: str = None, if task is None and pipeline_name is None: raise ValueError('task or pipeline_name is required') - model = normalize_model_input(model, model_revision) + third_party = kwargs.get(ThirdParty.KEY) + if third_party is not None: + kwargs.pop(ThirdParty.KEY) + model = normalize_model_input( + model, model_revision, third_party=third_party) pipeline_props = {'type': pipeline_name} if pipeline_name is None: # get default pipeline for this task