fix text-gen: read pipeline type from configuration.json first

This commit is contained in:
suluyana
2025-01-10 15:11:01 +08:00
parent ba15de012f
commit 0cb31263ab
2 changed files with 38 additions and 30 deletions

View File

@@ -108,30 +108,7 @@ def pipeline(task: str = None,
"""
if task is None and pipeline_name is None:
raise ValueError('task or pipeline_name is required')
prefer_llm_pipeline = kwargs.get('external_engine_for_llm')
if task is not None and task.lower() in [
Tasks.text_generation, Tasks.chat
]:
# if not specified, prefer llm pipeline for aforementioned tasks
if prefer_llm_pipeline is None:
prefer_llm_pipeline = True
# for llm pipeline, if llm_framework is not specified, default to swift instead
# TODO: port the swift infer based on transformer into ModelScope
if prefer_llm_pipeline and kwargs.get('llm_framework') is None:
kwargs['llm_framework'] = 'swift'
third_party = kwargs.get(ThirdParty.KEY)
if third_party is not None:
kwargs.pop(ThirdParty.KEY)
if pipeline_name is None and prefer_llm_pipeline:
pipeline_name = external_engine_for_llm_checker(
model, model_revision, kwargs)
if pipeline_name is None:
model = normalize_model_input(
model,
model_revision,
third_party=third_party,
ignore_file_pattern=ignore_file_pattern)
pipeline_props = {'type': pipeline_name}
if pipeline_name is None:
# get default pipeline for this task
if isinstance(model, str) \
@@ -142,16 +119,45 @@ def pipeline(task: str = None,
model, revision=model_revision) if isinstance(
model, str) else read_config(
model[0], revision=model_revision)
register_plugins_repo(cfg.safe_get('plugins'))
register_modelhub_repo(model, cfg.get('allow_remote', False))
pipeline_name = external_engine_for_llm_checker(
model, model_revision,
kwargs) if prefer_llm_pipeline else None
if pipeline_name is not None:
if cfg:
pipeline_name = cfg.safe_get('pipeline', {}).get('type', None)
if pipeline_name is None:
prefer_llm_pipeline = kwargs.get('external_engine_for_llm')
# if not specified in both args and configuration.json, prefer llm pipeline for aforementioned tasks
if task is not None and task.lower() in [
Tasks.text_generation, Tasks.chat
]:
if prefer_llm_pipeline is None:
prefer_llm_pipeline = True
# for llm pipeline, if llm_framework is not specified, default to swift instead
# TODO: port the swift infer based on transformer into ModelScope
if prefer_llm_pipeline:
if kwargs.get('llm_framework') is None:
kwargs['llm_framework'] = 'swift'
pipeline_name = external_engine_for_llm_checker(
model, model_revision, kwargs)
if pipeline_name is None or pipeline_name != 'llm':
third_party = kwargs.get(ThirdParty.KEY)
if third_party is not None:
kwargs.pop(ThirdParty.KEY)
model = normalize_model_input(
model,
model_revision,
third_party=third_party,
ignore_file_pattern=ignore_file_pattern)
register_plugins_repo(cfg.safe_get('plugins'))
register_modelhub_repo(model, cfg.get('allow_remote', False))
if pipeline_name:
pipeline_props = {'type': pipeline_name}
else:
check_config(cfg)
pipeline_props = cfg.pipeline
elif model is not None:
# get pipeline info from Model object
first_model = model[0] if isinstance(model, list) else model

View File

@@ -54,6 +54,8 @@ def read_config(model_id_or_path: str,
local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION)
elif os.path.isfile(model_id_or_path):
local_path = model_id_or_path
else:
return None
return Config.from_file(local_path)