mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
rm import
This commit is contained in:
@@ -1,10 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from transformers import Pipeline as PipelineHF
|
|
||||||
from transformers import PreTrainedModel, TFPreTrainedModel, pipeline
|
|
||||||
from transformers.pipelines import check_task, get_task
|
|
||||||
|
|
||||||
from modelscope.hub import snapshot_download
|
from modelscope.hub import snapshot_download
|
||||||
from modelscope.utils.hf_util.patcher import _patch_pretrained_class
|
from modelscope.utils.hf_util.patcher import _patch_pretrained_class
|
||||||
|
|
||||||
@@ -20,6 +16,7 @@ def _get_hf_device(device):
|
|||||||
|
|
||||||
|
|
||||||
def _get_hf_pipeline_class(task, model):
|
def _get_hf_pipeline_class(task, model):
|
||||||
|
from transformers.pipelines import check_task, get_task
|
||||||
if not task:
|
if not task:
|
||||||
task = get_task(model)
|
task = get_task(model)
|
||||||
normalized_task, targeted_task, task_options = check_task(task)
|
normalized_task, targeted_task, task_options = check_task(task)
|
||||||
@@ -34,7 +31,9 @@ def hf_pipeline(
|
|||||||
framework: Optional[str] = None,
|
framework: Optional[str] = None,
|
||||||
device: Optional[Union[int, str, 'torch.device']] = None,
|
device: Optional[Union[int, str, 'torch.device']] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> PipelineHF:
|
) -> 'transformers.Pipeline':
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
if isinstance(model, str):
|
if isinstance(model, str):
|
||||||
if not os.path.exists(model):
|
if not os.path.exists(model):
|
||||||
model = snapshot_download(model)
|
model = snapshot_download(model)
|
||||||
|
|||||||
Reference in New Issue
Block a user