mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00
Support trust_remote_code for pipeline and model (#1333)
This commit is contained in:
@@ -30,10 +30,23 @@ class Model(ABC):
|
||||
device_name = kwargs.get('device', 'gpu')
|
||||
verify_device(device_name)
|
||||
self._device_name = device_name
|
||||
self.trust_remote_code = kwargs.get('trust_remote_code', False)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Dict[str, Any]:
|
||||
return self.postprocess(self.forward(*args, **kwargs))
|
||||
|
||||
def check_trust_remote_code(self, info_str: Optional[str] = None):
|
||||
"""Check trust_remote_code if the model needs to import extra libs
|
||||
|
||||
Args:
|
||||
info_str(str): The info showed to user if trust_remote_code is `False`.
|
||||
"""
|
||||
info_str = info_str or (
|
||||
'This model requires `trust_remote_code` to be `True` because it needs to '
|
||||
'import extra libs or execute the code in the model repo, setting this to true '
|
||||
'means you trust the files in it.')
|
||||
assert self.trust_remote_code, info_str
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -342,6 +342,8 @@ class ControlLDM(LatentDiffusion, Model):
|
||||
self.control_key = control_key
|
||||
self.only_mid_control = only_mid_control
|
||||
self.control_scales = [1.0] * 13
|
||||
self.trust_remote_code = kwargs.get('trust_remote_code', False)
|
||||
self.check_trust_remote_code()
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(self, batch, k, bs=None, *args, **kwargs):
|
||||
|
||||
@@ -68,6 +68,7 @@ class ImageViewTransform(TorchModel):
|
||||
self.model = None
|
||||
self.model = load_model_from_config(
|
||||
self.model, config, ckpt, device=self.device)
|
||||
self.check_trust_remote_code()
|
||||
|
||||
def forward(self, model_path, x, y):
|
||||
pred_results = _infer(self.model, model_path, x, y, self.device)
|
||||
|
||||
@@ -29,7 +29,7 @@ class SingleStageDetector(TorchModel):
|
||||
init model by cfg
|
||||
"""
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
|
||||
self.check_trust_remote_code()
|
||||
config_path = osp.join(model_dir, self.config_name)
|
||||
config = parse_config(config_path)
|
||||
self.cfg = config
|
||||
|
||||
@@ -33,6 +33,7 @@ class DROEstimation(TorchModel):
|
||||
def __init__(self, model_dir: str, **kwargs):
|
||||
"""str -- model file root."""
|
||||
super().__init__(model_dir, **kwargs)
|
||||
self.check_trust_remote_code()
|
||||
|
||||
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
|
||||
|
||||
|
||||
@@ -67,6 +67,10 @@ class LinearAECPipeline(Pipeline):
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
self.check_trust_remote_code(
|
||||
'This pipeline requires `trust_remote_code=True` to load the module defined'
|
||||
' in the `dey_mini.yaml`, setting this to True means you trust the code and files'
|
||||
' listed in this model repo.')
|
||||
|
||||
self.use_cuda = torch.cuda.is_available()
|
||||
with open(
|
||||
|
||||
@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from multiprocessing import Pool
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, Generator, List, Mapping, Union
|
||||
from typing import Any, Dict, Generator, List, Mapping, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
@@ -45,6 +45,8 @@ class Pipeline(ABC):
|
||||
"""
|
||||
|
||||
def initiate_single_model(self, model, **kwargs):
|
||||
if self.trust_remote_code:
|
||||
kwargs['trust_remote_code'] = True
|
||||
if isinstance(model, str):
|
||||
logger.info(f'initiate model from {model}')
|
||||
if isinstance(model, str) and is_official_hub_path(model):
|
||||
@@ -95,6 +97,7 @@ class Pipeline(ABC):
|
||||
self.device_map = device_map
|
||||
verify_device(device)
|
||||
self.device_name = device
|
||||
self.trust_remote_code = kwargs.get('trust_remote_code', False)
|
||||
|
||||
if not isinstance(model, List):
|
||||
self.model = self.initiate_single_model(model, **kwargs)
|
||||
@@ -133,6 +136,18 @@ class Pipeline(ABC):
|
||||
self._compile = kwargs.get('compile', False)
|
||||
self._compile_options = kwargs.get('compile_options', {})
|
||||
|
||||
def check_trust_remote_code(self, info_str: Optional[str] = None):
|
||||
"""Check trust_remote_code if the pipeline needs to import extra libs
|
||||
|
||||
Args:
|
||||
info_str(str): The info showed to user if trust_remote_code is `False`.
|
||||
"""
|
||||
info_str = info_str or (
|
||||
'This pipeline requires `trust_remote_code` to be `True` because it needs to '
|
||||
'import extra libs or execute the code in the model repo, setting this to true '
|
||||
'means you trust the files in it.')
|
||||
assert self.trust_remote_code, info_str
|
||||
|
||||
def prepare_model(self):
|
||||
""" Place model on certain device for pytorch models before first inference
|
||||
"""
|
||||
|
||||
@@ -63,7 +63,8 @@ class PedestrainAttributeRecognitionPipeline(Pipeline):
|
||||
self.human_detect_model_id = 'damo/cv_tinynas_human-detection_damoyolo'
|
||||
self.human_detector = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model=self.human_detect_model_id)
|
||||
model=self.human_detect_model_id,
|
||||
trust_remote_code=kwargs.get('trust_remote_code', False))
|
||||
|
||||
def get_labels(self, outputs, thres=0.5):
|
||||
gender = outputs[0][0:1]
|
||||
|
||||
@@ -189,6 +189,11 @@ class DiscoDiffusionPipeline(DiffusersPipeline):
|
||||
"""
|
||||
|
||||
super().__init__(model, device, **kwargs)
|
||||
self.trust_remote_code = kwargs.get('trust_remote_code', False)
|
||||
self.check_trust_remote_code(
|
||||
'This pipeline requires `trust_remote_code=True` to load the module defined'
|
||||
' in `model_index.json`, setting this to True means you trust the code and files'
|
||||
' listed in this model repo.')
|
||||
|
||||
model_path = model
|
||||
|
||||
@@ -204,6 +209,12 @@ class DiscoDiffusionPipeline(DiffusersPipeline):
|
||||
if model_config['use_fp16']:
|
||||
self.unet.convert_to_fp16()
|
||||
|
||||
self.trust_remote_code = kwargs.get('trust_remote_code', False)
|
||||
self.check_trust_remote_code(
|
||||
'This pipeline requires import modules listed in `model_index.json`, '
|
||||
'please add `trust_remote_code=True` if you trust this model repo.'
|
||||
)
|
||||
|
||||
with open(
|
||||
os.path.join(model_path, 'model_index.json'),
|
||||
'r',
|
||||
|
||||
@@ -23,7 +23,7 @@ class TestExportObjectDetectionDamoyolo(unittest.TestCase):
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_export_object_detection_damoyolo(self):
|
||||
|
||||
model = Model.from_pretrained(self.model_id)
|
||||
model = Model.from_pretrained(self.model_id, trust_remote_code=True)
|
||||
Exporter.from_model(model).export_onnx(
|
||||
input_shape=(1, 3, 640, 640), output_dir=self.tmp_dir)
|
||||
|
||||
@@ -31,7 +31,7 @@ class TestExportObjectDetectionDamoyolo(unittest.TestCase):
|
||||
def test_export_domain_specific_object_detection_damoyolo(self):
|
||||
|
||||
model_id = 'damo/cv_tinynas_human-detection_damoyolo'
|
||||
model = Model.from_pretrained(model_id)
|
||||
model = Model.from_pretrained(model_id, trust_remote_code=True)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
Exporter.from_model(model).export_onnx(
|
||||
input_shape=(1, 3, 640, 640), output_dir=tmp_dir)
|
||||
|
||||
@@ -22,7 +22,7 @@ class AnydoorTest(unittest.TestCase):
|
||||
save_path = 'data/test/images/image_anydoor_gen.png'
|
||||
|
||||
anydoor_pipline: AnydoorPipeline = pipeline(
|
||||
self.task, model=self.model_id)
|
||||
self.task, model=self.model_id, trust_remote_code=True)
|
||||
out = anydoor_pipline((ref_image, ref_mask, bg_image, bg_mask))
|
||||
image = out['output_img']
|
||||
image.save(save_path)
|
||||
|
||||
@@ -21,7 +21,10 @@ class DiscoGuidedDiffusionTest(unittest.TestCase):
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run(self):
|
||||
diffusers_pipeline = pipeline(
|
||||
task=self.task, model=self.model_id1, model_revision='v1.0')
|
||||
task=self.task,
|
||||
model=self.model_id1,
|
||||
model_revision='v1.0',
|
||||
trust_remote_code=True)
|
||||
output = diffusers_pipeline({
|
||||
'text': self.test_input1,
|
||||
'height': 256,
|
||||
@@ -31,7 +34,10 @@ class DiscoGuidedDiffusionTest(unittest.TestCase):
|
||||
print('Image saved to output1.png')
|
||||
|
||||
diffusers_pipeline = pipeline(
|
||||
task=self.task, model=self.model_id2, model_revision='v1.0')
|
||||
task=self.task,
|
||||
model=self.model_id2,
|
||||
model_revision='v1.0',
|
||||
trust_remote_code=True)
|
||||
output = diffusers_pipeline({
|
||||
'text': self.test_input2,
|
||||
'height': 256,
|
||||
|
||||
@@ -16,7 +16,7 @@ class ImageControl3dPortraitTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.model_id = 'damo/cv_vit_image-control-3d-portrait-synthesis'
|
||||
self.test_image = 'data/test/images/image_control_3d_portrait.jpg'
|
||||
self.test_image = '/mnt/nas3/yzhao/1.jpg'
|
||||
self.save_dir = 'exp'
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
|
||||
|
||||
@@ -36,12 +36,16 @@ class ImageViewTransformTest(unittest.TestCase):
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_modelhub(self):
|
||||
image_view_transform = pipeline(
|
||||
Tasks.image_view_transform, model=self.model_id, revision='v1.0.3')
|
||||
Tasks.image_view_transform,
|
||||
model=self.model_id,
|
||||
revision='v1.0.3',
|
||||
trust_remote_code=True)
|
||||
self.pipeline_inference(image_view_transform, self.input)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_modelhub_default_model(self):
|
||||
image_view_transform = pipeline(Tasks.image_view_transform)
|
||||
image_view_transform = pipeline(
|
||||
Tasks.image_view_transform, trust_remote_code=True)
|
||||
self.pipeline_inference(image_view_transform, self.input)
|
||||
|
||||
|
||||
|
||||
@@ -26,14 +26,14 @@ class PedestrianAttributeRecognitionTest(unittest.TestCase):
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_modelhub_with_image_file(self):
|
||||
pedestrian_attribute_recognition = pipeline(
|
||||
self.task, model=self.model_id)
|
||||
self.task, model=self.model_id, trust_remote_code=True)
|
||||
self.pipeline_inference(pedestrian_attribute_recognition,
|
||||
self.test_image)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_modelhub_with_image_input(self):
|
||||
pedestrian_attribute_recognition = pipeline(
|
||||
self.task, model=self.model_id)
|
||||
self.task, model=self.model_id, trust_remote_code=True)
|
||||
self.pipeline_inference(pedestrian_attribute_recognition,
|
||||
Image.open(self.test_image))
|
||||
|
||||
|
||||
@@ -19,7 +19,9 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_airdet(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.image_object_detection, model='damo/cv_tinynas_detection')
|
||||
Tasks.image_object_detection,
|
||||
model='damo/cv_tinynas_detection',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection(
|
||||
'data/test/images/image_detection.jpg')
|
||||
print('airdet', result)
|
||||
@@ -28,7 +30,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_run_damoyolo(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.image_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection(
|
||||
'data/test/images/image_detection.jpg')
|
||||
print('damoyolo-s', result)
|
||||
@@ -37,7 +40,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_run_damoyolo_m(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.image_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo-m')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo-m',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection(
|
||||
'data/test/images/image_detection.jpg')
|
||||
print('damoyolo-m', result)
|
||||
@@ -46,7 +50,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_run_damoyolo_t(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.image_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo-t')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo-t',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection(
|
||||
'data/test/images/image_detection.jpg')
|
||||
print('damoyolo-t', result)
|
||||
@@ -56,7 +61,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
test_image = 'data/test/images/image_detection.jpg'
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.image_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo-m')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo-m',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection(test_image)
|
||||
tinynas_object_detection.show_result(test_image, result,
|
||||
'demo_ret.jpg')
|
||||
@@ -65,7 +71,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_human_detection_damoyolo(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_human-detection_damoyolo')
|
||||
model='damo/cv_tinynas_human-detection_damoyolo',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection(
|
||||
'data/test/images/image_detection.jpg')
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
@@ -76,7 +83,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_human_detection_damoyolo_with_image(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_human-detection_damoyolo')
|
||||
model='damo/cv_tinynas_human-detection_damoyolo',
|
||||
trust_remote_code=True)
|
||||
img = Image.open('data/test/images/image_detection.jpg')
|
||||
result = tinynas_object_detection(img)
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
@@ -87,7 +95,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_facemask_detection_damoyolo(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_facemask')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_facemask',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection(
|
||||
'data/test/images/image_detection.jpg')
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
@@ -98,7 +107,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_facemask_detection_damoyolo_with_image(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_facemask')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_facemask',
|
||||
trust_remote_code=True)
|
||||
img = Image.open('data/test/images/image_detection.jpg')
|
||||
result = tinynas_object_detection(img)
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
@@ -109,7 +119,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_safetyhat_detection_damoyolo(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_safety-helmet')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_safety-helmet',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection(
|
||||
'data/test/images/image_safetyhat.jpg')
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
@@ -120,7 +131,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_safetyhat_detection_damoyolo_with_image(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_safety-helmet')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_safety-helmet',
|
||||
trust_remote_code=True)
|
||||
img = Image.open('data/test/images/image_safetyhat.jpg')
|
||||
result = tinynas_object_detection(img)
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
@@ -131,7 +143,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_cigarette_detection_damoyolo(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_cigarette')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_cigarette',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection('data/test/images/image_smoke.jpg')
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
OutputKeys.LABELS in result) and (OutputKeys.BOXES in result)
|
||||
@@ -141,7 +154,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_cigarette_detection_damoyolo_with_image(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_cigarette')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_cigarette',
|
||||
trust_remote_code=True)
|
||||
img = Image.open('data/test/images/image_smoke.jpg')
|
||||
result = tinynas_object_detection(img)
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
@@ -152,7 +166,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_phone_detection_damoyolo(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_phone')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_phone',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection('data/test/images/image_phone.jpg')
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
OutputKeys.LABELS in result) and (OutputKeys.BOXES in result)
|
||||
@@ -162,7 +177,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_phone_detection_damoyolo_with_image(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_phone')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_phone',
|
||||
trust_remote_code=True)
|
||||
img = Image.open('data/test/images/image_phone.jpg')
|
||||
result = tinynas_object_detection(img)
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
@@ -173,7 +189,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_head_detection_damoyolo(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_head-detection_damoyolo')
|
||||
model='damo/cv_tinynas_head-detection_damoyolo',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection(
|
||||
'data/test/images/image_detection.jpg')
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
@@ -184,7 +201,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_head_detection_damoyolo_with_image(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_head-detection_damoyolo')
|
||||
model='damo/cv_tinynas_head-detection_damoyolo',
|
||||
trust_remote_code=True)
|
||||
img = Image.open('data/test/images/image_detection.jpg')
|
||||
result = tinynas_object_detection(img)
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
@@ -195,7 +213,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_smokefire_detection_damoyolo(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_smokefire')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_smokefire',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection(
|
||||
'data/test/images/image_smokefire_detection.jpg')
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
@@ -206,7 +225,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
def test_smokefire_detection_damoyolo_with_image(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_smokefire')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_smokefire',
|
||||
trust_remote_code=True)
|
||||
img = Image.open('data/test/images/image_smokefire_detection.jpg')
|
||||
result = tinynas_object_detection(img)
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
|
||||
@@ -20,7 +20,8 @@ class TrafficSignDetectionTest(unittest.TestCase):
|
||||
def test_traffic_sign_detection_damoyolo(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.domain_specific_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_traffic_sign')
|
||||
model='damo/cv_tinynas_object-detection_damoyolo_traffic_sign',
|
||||
trust_remote_code=True)
|
||||
result = tinynas_object_detection(
|
||||
'data/test/images/image_traffic_sign.jpg')
|
||||
assert result and (OutputKeys.SCORES in result) and (
|
||||
|
||||
@@ -17,7 +17,10 @@ class VideoDepthEstimationTest(unittest.TestCase):
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_video_depth_estimation(self):
|
||||
input_location = 'data/test/videos/video_depth_estimation.mp4'
|
||||
estimator = pipeline(Tasks.video_depth_estimation, model=self.model_id)
|
||||
estimator = pipeline(
|
||||
Tasks.video_depth_estimation,
|
||||
model=self.model_id,
|
||||
trust_remote_code=True)
|
||||
result = estimator(input_location)
|
||||
show_video_depth_estimation_result(result[OutputKeys.DEPTHS_COLOR],
|
||||
'out.mp4')
|
||||
|
||||
Reference in New Issue
Block a user