rm import

This commit is contained in:
suluyan
2025-03-10 16:29:02 +08:00
parent 30d8995cd9
commit c45bc6d19f

View File

@@ -1,10 +1,6 @@
import os import os
from typing import Optional, Union from typing import Optional, Union
from transformers import Pipeline as PipelineHF
from transformers import PreTrainedModel, TFPreTrainedModel, pipeline
from transformers.pipelines import check_task, get_task
from modelscope.hub import snapshot_download from modelscope.hub import snapshot_download
from modelscope.utils.hf_util.patcher import _patch_pretrained_class from modelscope.utils.hf_util.patcher import _patch_pretrained_class
@@ -20,6 +16,7 @@ def _get_hf_device(device):
def _get_hf_pipeline_class(task, model): def _get_hf_pipeline_class(task, model):
from transformers.pipelines import check_task, get_task
if not task: if not task:
task = get_task(model) task = get_task(model)
normalized_task, targeted_task, task_options = check_task(task) normalized_task, targeted_task, task_options = check_task(task)
@@ -34,7 +31,9 @@ def hf_pipeline(
framework: Optional[str] = None, framework: Optional[str] = None,
device: Optional[Union[int, str, 'torch.device']] = None, device: Optional[Union[int, str, 'torch.device']] = None,
**kwargs, **kwargs,
) -> PipelineHF: ) -> 'transformers.Pipeline':
from transformers import pipeline
if isinstance(model, str): if isinstance(model, str):
if not os.path.exists(model): if not os.path.exists(model):
model = snapshot_download(model) model = snapshot_download(model)