tmp
@@ -54,7 +54,7 @@ if TYPE_CHECKING:
         AutoModelForMaskedLM, AutoTokenizer, AutoModelForMaskGeneration,
         AutoModelForPreTraining, AutoModelForTextEncoding,
         AutoImageProcessor, BatchFeature, Qwen2VLForConditionalGeneration,
-        T5EncoderModel)
+        T5EncoderModel, hf_pipeline)
 else:
     print(
         'transformer is not installed, please install it if you want to use related modules'
@@ -131,7 +131,8 @@ else:
         'AutoModelForMaskedLM', 'AutoTokenizer',
         'AutoModelForMaskGeneration', 'AutoModelForPreTraining',
         'AutoModelForTextEncoding', 'AutoImageProcessor', 'BatchFeature',
-        'Qwen2VLForConditionalGeneration', 'T5EncoderModel'
+        'Qwen2VLForConditionalGeneration', 'T5EncoderModel',
+        'hf_pipeline'
     ]

     import sys
@@ -13,10 +13,14 @@ from modelscope.utils.hub import read_config
 from modelscope.utils.plugins import (register_modelhub_repo,
                                       register_plugins_repo)
 from modelscope.utils.registry import Registry, build_from_cfg
+from modelscope.utils.logger import get_logger
+from modelscope.utils.import_utils import is_transformers_available

 from .base import Pipeline
 from .util import is_official_hub_path

 PIPELINES = Registry('pipelines')
+logger = get_logger()
+

 def normalize_model_input(model,
@@ -109,6 +113,7 @@ def pipeline(task: str = None,
     if task is None and pipeline_name is None:
         raise ValueError('task or pipeline_name is required')

+    pipeline_props = None
     if pipeline_name is None:
         # get default pipeline for this task
         if isinstance(model, str) \
@@ -157,8 +162,11 @@ def pipeline(task: str = None,
             if pipeline_name:
                 pipeline_props = {'type': pipeline_name}
             else:
-                check_config(cfg)
-                pipeline_props = cfg.pipeline
+                try:
+                    check_config(cfg)
+                    pipeline_props = cfg.pipeline
+                except AssertionError as e:
+                    logger.info(str(e))

     elif model is not None:
         # get pipeline info from Model object
@@ -176,6 +184,14 @@ def pipeline(task: str = None,
         else:
             pipeline_props = {'type': pipeline_name}

+    if not pipeline_props and is_transformers_available():
+        from modelscope.utils.hf_util import hf_pipeline
+        return hf_pipeline(task=task,
+                           model=model,
+                           framework=framework,
+                           device=device,
+                           **kwargs)
+
     pipeline_props['model'] = model
     pipeline_props['device'] = device
     cfg = ConfigDict(pipeline_props)
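Taken together, the hunks above make the ModelScope pipeline() builder fall back to the Hugging Face transformers pipeline whenever no ModelScope pipeline configuration can be resolved and transformers is installed. A minimal usage sketch of that new behavior follows; it is not part of the commit, and the model id is an arbitrary Hugging Face checkpoint assumed to carry no ModelScope pipeline configuration, so the call is expected to be served by hf_pipeline rather than a registered ModelScope pipeline.

# Illustrative sketch only, not part of this commit. Assumes transformers is
# installed and that the model below has no ModelScope pipeline config, so
# pipeline_props stays None and the hf_pipeline fallback handles the call.
from modelscope.pipelines import pipeline

pipe = pipeline(
    task='text-classification',
    model='distilbert-base-uncased-finetuned-sst-2-english')  # hypothetical example model
print(pipe('ModelScope can now delegate unknown pipelines to transformers.'))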
@@ -63,6 +63,7 @@ from transformers import (PretrainedConfig, PreTrainedModel,
                           PreTrainedTokenizerBase)
 from transformers import T5EncoderModel as T5EncoderModelHF
 from transformers import __version__ as transformers_version
+from transformers import pipeline as hf_pipeline

 from modelscope import snapshot_download
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
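The last hunk simply re-exports transformers.pipeline under the name hf_pipeline, so the fallback above (and any code importing it from modelscope.utils.hf_util) gets the unchanged transformers API. A small sketch, again only illustrative:

# Sketch: hf_pipeline is an alias of transformers.pipeline, so it accepts the
# same arguments; calling it without a model lets transformers pick a default.
from modelscope.utils.hf_util import hf_pipeline

sentiment = hf_pipeline('sentiment-analysis')
print(sentiment('The alias keeps the familiar transformers interface.'))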