ok Merge branch 'master' of github.com:modelscope/modelscope into merge_release_1.23

This commit is contained in:
xingjun.wxj
2025-03-11 11:13:26 +08:00

View File

@@ -1,13 +1,8 @@
import os import os
from typing import Optional, Union from typing import Optional, Union
import torch
from transformers import Pipeline as PipelineHF
from transformers import PreTrainedModel, TFPreTrainedModel, pipeline
from transformers.pipelines import check_task, get_task
from modelscope.hub import snapshot_download from modelscope.hub import snapshot_download
from modelscope.utils.hf_util.patcher import _patch_pretrained_class, patch_hub from modelscope.utils.hf_util.patcher import _patch_pretrained_class
def _get_hf_device(device): def _get_hf_device(device):
@@ -21,6 +16,7 @@ def _get_hf_device(device):
def _get_hf_pipeline_class(task, model): def _get_hf_pipeline_class(task, model):
from transformers.pipelines import check_task, get_task
if not task: if not task:
task = get_task(model) task = get_task(model)
normalized_task, targeted_task, task_options = check_task(task) normalized_task, targeted_task, task_options = check_task(task)
@@ -35,7 +31,9 @@ def hf_pipeline(
framework: Optional[str] = None, framework: Optional[str] = None,
device: Optional[Union[int, str, 'torch.device']] = None, device: Optional[Union[int, str, 'torch.device']] = None,
**kwargs, **kwargs,
) -> PipelineHF: ) -> 'transformers.Pipeline':
from transformers import pipeline
if isinstance(model, str): if isinstance(model, str):
if not os.path.exists(model): if not os.path.exists(model):
model = snapshot_download(model) model = snapshot_download(model)