Support trust_remote_code for pipeline and model (#1333)

This commit is contained in:
tastelikefeet
2025-05-13 22:52:57 +08:00
committed by GitHub
parent c30bfeb777
commit 75d54927e1
18 changed files with 116 additions and 34 deletions

View File

@@ -30,10 +30,23 @@ class Model(ABC):
device_name = kwargs.get('device', 'gpu')
verify_device(device_name)
self._device_name = device_name
self.trust_remote_code = kwargs.get('trust_remote_code', False)
def __call__(self, *args, **kwargs) -> Dict[str, Any]:
return self.postprocess(self.forward(*args, **kwargs))
def check_trust_remote_code(self, info_str: Optional[str] = None):
"""Check trust_remote_code if the model needs to import extra libs
Args:
info_str(str): The info showed to user if trust_remote_code is `False`.
"""
info_str = info_str or (
'This model requires `trust_remote_code` to be `True` because it needs to '
'import extra libs or execute the code in the model repo, setting this to true '
'means you trust the files in it.')
assert self.trust_remote_code, info_str
@abstractmethod
def forward(self, *args, **kwargs) -> Dict[str, Any]:
"""

View File

@@ -342,6 +342,8 @@ class ControlLDM(LatentDiffusion, Model):
self.control_key = control_key
self.only_mid_control = only_mid_control
self.control_scales = [1.0] * 13
self.trust_remote_code = kwargs.get('trust_remote_code', False)
self.check_trust_remote_code()
@torch.no_grad()
def get_input(self, batch, k, bs=None, *args, **kwargs):

View File

@@ -68,6 +68,7 @@ class ImageViewTransform(TorchModel):
self.model = None
self.model = load_model_from_config(
self.model, config, ckpt, device=self.device)
self.check_trust_remote_code()
def forward(self, model_path, x, y):
pred_results = _infer(self.model, model_path, x, y, self.device)

View File

@@ -29,7 +29,7 @@ class SingleStageDetector(TorchModel):
init model by cfg
"""
super().__init__(model_dir, *args, **kwargs)
self.check_trust_remote_code()
config_path = osp.join(model_dir, self.config_name)
config = parse_config(config_path)
self.cfg = config

View File

@@ -33,6 +33,7 @@ class DROEstimation(TorchModel):
def __init__(self, model_dir: str, **kwargs):
"""str -- model file root."""
super().__init__(model_dir, **kwargs)
self.check_trust_remote_code()
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)

View File

@@ -67,6 +67,10 @@ class LinearAECPipeline(Pipeline):
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
self.check_trust_remote_code(
'This pipeline requires `trust_remote_code=True` to load the module defined'
' in the `dey_mini.yaml`, setting this to True means you trust the code and files'
' listed in this model repo.')
self.use_cuda = torch.cuda.is_available()
with open(

View File

@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from functools import partial
from multiprocessing import Pool
from threading import Lock
from typing import Any, Dict, Generator, List, Mapping, Union
from typing import Any, Dict, Generator, List, Mapping, Optional, Union
import numpy as np
from packaging import version
@@ -45,6 +45,8 @@ class Pipeline(ABC):
"""
def initiate_single_model(self, model, **kwargs):
if self.trust_remote_code:
kwargs['trust_remote_code'] = True
if isinstance(model, str):
logger.info(f'initiate model from {model}')
if isinstance(model, str) and is_official_hub_path(model):
@@ -95,6 +97,7 @@ class Pipeline(ABC):
self.device_map = device_map
verify_device(device)
self.device_name = device
self.trust_remote_code = kwargs.get('trust_remote_code', False)
if not isinstance(model, List):
self.model = self.initiate_single_model(model, **kwargs)
@@ -133,6 +136,18 @@ class Pipeline(ABC):
self._compile = kwargs.get('compile', False)
self._compile_options = kwargs.get('compile_options', {})
def check_trust_remote_code(self, info_str: Optional[str] = None):
"""Check trust_remote_code if the pipeline needs to import extra libs
Args:
info_str(str): The info showed to user if trust_remote_code is `False`.
"""
info_str = info_str or (
'This pipeline requires `trust_remote_code` to be `True` because it needs to '
'import extra libs or execute the code in the model repo, setting this to true '
'means you trust the files in it.')
assert self.trust_remote_code, info_str
def prepare_model(self):
""" Place model on certain device for pytorch models before first inference
"""

View File

@@ -63,7 +63,8 @@ class PedestrainAttributeRecognitionPipeline(Pipeline):
self.human_detect_model_id = 'damo/cv_tinynas_human-detection_damoyolo'
self.human_detector = pipeline(
Tasks.domain_specific_object_detection,
model=self.human_detect_model_id)
model=self.human_detect_model_id,
trust_remote_code=kwargs.get('trust_remote_code', False))
def get_labels(self, outputs, thres=0.5):
gender = outputs[0][0:1]

View File

@@ -189,6 +189,11 @@ class DiscoDiffusionPipeline(DiffusersPipeline):
"""
super().__init__(model, device, **kwargs)
self.trust_remote_code = kwargs.get('trust_remote_code', False)
self.check_trust_remote_code(
'This pipeline requires `trust_remote_code=True` to load the module defined'
' in `model_index.json`, setting this to True means you trust the code and files'
' listed in this model repo.')
model_path = model
@@ -204,6 +209,12 @@ class DiscoDiffusionPipeline(DiffusersPipeline):
if model_config['use_fp16']:
self.unet.convert_to_fp16()
self.trust_remote_code = kwargs.get('trust_remote_code', False)
self.check_trust_remote_code(
'This pipeline requires import modules listed in `model_index.json`, '
'please add `trust_remote_code=True` if you trust this model repo.'
)
with open(
os.path.join(model_path, 'model_index.json'),
'r',

View File

@@ -23,7 +23,7 @@ class TestExportObjectDetectionDamoyolo(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_export_object_detection_damoyolo(self):
model = Model.from_pretrained(self.model_id)
model = Model.from_pretrained(self.model_id, trust_remote_code=True)
Exporter.from_model(model).export_onnx(
input_shape=(1, 3, 640, 640), output_dir=self.tmp_dir)
@@ -31,7 +31,7 @@ class TestExportObjectDetectionDamoyolo(unittest.TestCase):
def test_export_domain_specific_object_detection_damoyolo(self):
model_id = 'damo/cv_tinynas_human-detection_damoyolo'
model = Model.from_pretrained(model_id)
model = Model.from_pretrained(model_id, trust_remote_code=True)
with tempfile.TemporaryDirectory() as tmp_dir:
Exporter.from_model(model).export_onnx(
input_shape=(1, 3, 640, 640), output_dir=tmp_dir)

View File

@@ -22,7 +22,7 @@ class AnydoorTest(unittest.TestCase):
save_path = 'data/test/images/image_anydoor_gen.png'
anydoor_pipline: AnydoorPipeline = pipeline(
self.task, model=self.model_id)
self.task, model=self.model_id, trust_remote_code=True)
out = anydoor_pipline((ref_image, ref_mask, bg_image, bg_mask))
image = out['output_img']
image.save(save_path)

View File

@@ -21,7 +21,10 @@ class DiscoGuidedDiffusionTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run(self):
diffusers_pipeline = pipeline(
task=self.task, model=self.model_id1, model_revision='v1.0')
task=self.task,
model=self.model_id1,
model_revision='v1.0',
trust_remote_code=True)
output = diffusers_pipeline({
'text': self.test_input1,
'height': 256,
@@ -31,7 +34,10 @@ class DiscoGuidedDiffusionTest(unittest.TestCase):
print('Image saved to output1.png')
diffusers_pipeline = pipeline(
task=self.task, model=self.model_id2, model_revision='v1.0')
task=self.task,
model=self.model_id2,
model_revision='v1.0',
trust_remote_code=True)
output = diffusers_pipeline({
'text': self.test_input2,
'height': 256,

View File

@@ -16,7 +16,7 @@ class ImageControl3dPortraitTest(unittest.TestCase):
def setUp(self) -> None:
self.model_id = 'damo/cv_vit_image-control-3d-portrait-synthesis'
self.test_image = 'data/test/images/image_control_3d_portrait.jpg'
self.test_image = '/mnt/nas3/yzhao/1.jpg'
self.save_dir = 'exp'
os.makedirs(self.save_dir, exist_ok=True)

View File

@@ -36,12 +36,16 @@ class ImageViewTransformTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
image_view_transform = pipeline(
Tasks.image_view_transform, model=self.model_id, revision='v1.0.3')
Tasks.image_view_transform,
model=self.model_id,
revision='v1.0.3',
trust_remote_code=True)
self.pipeline_inference(image_view_transform, self.input)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
image_view_transform = pipeline(Tasks.image_view_transform)
image_view_transform = pipeline(
Tasks.image_view_transform, trust_remote_code=True)
self.pipeline_inference(image_view_transform, self.input)

View File

@@ -26,14 +26,14 @@ class PedestrianAttributeRecognitionTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub_with_image_file(self):
pedestrian_attribute_recognition = pipeline(
self.task, model=self.model_id)
self.task, model=self.model_id, trust_remote_code=True)
self.pipeline_inference(pedestrian_attribute_recognition,
self.test_image)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub_with_image_input(self):
pedestrian_attribute_recognition = pipeline(
self.task, model=self.model_id)
self.task, model=self.model_id, trust_remote_code=True)
self.pipeline_inference(pedestrian_attribute_recognition,
Image.open(self.test_image))

View File

@@ -19,7 +19,9 @@ class TinynasObjectDetectionTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_airdet(self):
tinynas_object_detection = pipeline(
Tasks.image_object_detection, model='damo/cv_tinynas_detection')
Tasks.image_object_detection,
model='damo/cv_tinynas_detection',
trust_remote_code=True)
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
print('airdet', result)
@@ -28,7 +30,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_run_damoyolo(self):
tinynas_object_detection = pipeline(
Tasks.image_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo')
model='damo/cv_tinynas_object-detection_damoyolo',
trust_remote_code=True)
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
print('damoyolo-s', result)
@@ -37,7 +40,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_run_damoyolo_m(self):
tinynas_object_detection = pipeline(
Tasks.image_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo-m')
model='damo/cv_tinynas_object-detection_damoyolo-m',
trust_remote_code=True)
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
print('damoyolo-m', result)
@@ -46,7 +50,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_run_damoyolo_t(self):
tinynas_object_detection = pipeline(
Tasks.image_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo-t')
model='damo/cv_tinynas_object-detection_damoyolo-t',
trust_remote_code=True)
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
print('damoyolo-t', result)
@@ -56,7 +61,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
test_image = 'data/test/images/image_detection.jpg'
tinynas_object_detection = pipeline(
Tasks.image_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo-m')
model='damo/cv_tinynas_object-detection_damoyolo-m',
trust_remote_code=True)
result = tinynas_object_detection(test_image)
tinynas_object_detection.show_result(test_image, result,
'demo_ret.jpg')
@@ -65,7 +71,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_human_detection_damoyolo(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_human-detection_damoyolo')
model='damo/cv_tinynas_human-detection_damoyolo',
trust_remote_code=True)
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
assert result and (OutputKeys.SCORES in result) and (
@@ -76,7 +83,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_human_detection_damoyolo_with_image(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_human-detection_damoyolo')
model='damo/cv_tinynas_human-detection_damoyolo',
trust_remote_code=True)
img = Image.open('data/test/images/image_detection.jpg')
result = tinynas_object_detection(img)
assert result and (OutputKeys.SCORES in result) and (
@@ -87,7 +95,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_facemask_detection_damoyolo(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo_facemask')
model='damo/cv_tinynas_object-detection_damoyolo_facemask',
trust_remote_code=True)
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
assert result and (OutputKeys.SCORES in result) and (
@@ -98,7 +107,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_facemask_detection_damoyolo_with_image(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo_facemask')
model='damo/cv_tinynas_object-detection_damoyolo_facemask',
trust_remote_code=True)
img = Image.open('data/test/images/image_detection.jpg')
result = tinynas_object_detection(img)
assert result and (OutputKeys.SCORES in result) and (
@@ -109,7 +119,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_safetyhat_detection_damoyolo(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo_safety-helmet')
model='damo/cv_tinynas_object-detection_damoyolo_safety-helmet',
trust_remote_code=True)
result = tinynas_object_detection(
'data/test/images/image_safetyhat.jpg')
assert result and (OutputKeys.SCORES in result) and (
@@ -120,7 +131,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_safetyhat_detection_damoyolo_with_image(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo_safety-helmet')
model='damo/cv_tinynas_object-detection_damoyolo_safety-helmet',
trust_remote_code=True)
img = Image.open('data/test/images/image_safetyhat.jpg')
result = tinynas_object_detection(img)
assert result and (OutputKeys.SCORES in result) and (
@@ -131,7 +143,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_cigarette_detection_damoyolo(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo_cigarette')
model='damo/cv_tinynas_object-detection_damoyolo_cigarette',
trust_remote_code=True)
result = tinynas_object_detection('data/test/images/image_smoke.jpg')
assert result and (OutputKeys.SCORES in result) and (
OutputKeys.LABELS in result) and (OutputKeys.BOXES in result)
@@ -141,7 +154,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_cigarette_detection_damoyolo_with_image(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo_cigarette')
model='damo/cv_tinynas_object-detection_damoyolo_cigarette',
trust_remote_code=True)
img = Image.open('data/test/images/image_smoke.jpg')
result = tinynas_object_detection(img)
assert result and (OutputKeys.SCORES in result) and (
@@ -152,7 +166,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_phone_detection_damoyolo(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo_phone')
model='damo/cv_tinynas_object-detection_damoyolo_phone',
trust_remote_code=True)
result = tinynas_object_detection('data/test/images/image_phone.jpg')
assert result and (OutputKeys.SCORES in result) and (
OutputKeys.LABELS in result) and (OutputKeys.BOXES in result)
@@ -162,7 +177,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_phone_detection_damoyolo_with_image(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo_phone')
model='damo/cv_tinynas_object-detection_damoyolo_phone',
trust_remote_code=True)
img = Image.open('data/test/images/image_phone.jpg')
result = tinynas_object_detection(img)
assert result and (OutputKeys.SCORES in result) and (
@@ -173,7 +189,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_head_detection_damoyolo(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_head-detection_damoyolo')
model='damo/cv_tinynas_head-detection_damoyolo',
trust_remote_code=True)
result = tinynas_object_detection(
'data/test/images/image_detection.jpg')
assert result and (OutputKeys.SCORES in result) and (
@@ -184,7 +201,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_head_detection_damoyolo_with_image(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_head-detection_damoyolo')
model='damo/cv_tinynas_head-detection_damoyolo',
trust_remote_code=True)
img = Image.open('data/test/images/image_detection.jpg')
result = tinynas_object_detection(img)
assert result and (OutputKeys.SCORES in result) and (
@@ -195,7 +213,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_smokefire_detection_damoyolo(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo_smokefire')
model='damo/cv_tinynas_object-detection_damoyolo_smokefire',
trust_remote_code=True)
result = tinynas_object_detection(
'data/test/images/image_smokefire_detection.jpg')
assert result and (OutputKeys.SCORES in result) and (
@@ -206,7 +225,8 @@ class TinynasObjectDetectionTest(unittest.TestCase):
def test_smokefire_detection_damoyolo_with_image(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo_smokefire')
model='damo/cv_tinynas_object-detection_damoyolo_smokefire',
trust_remote_code=True)
img = Image.open('data/test/images/image_smokefire_detection.jpg')
result = tinynas_object_detection(img)
assert result and (OutputKeys.SCORES in result) and (

View File

@@ -20,7 +20,8 @@ class TrafficSignDetectionTest(unittest.TestCase):
def test_traffic_sign_detection_damoyolo(self):
tinynas_object_detection = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo_traffic_sign')
model='damo/cv_tinynas_object-detection_damoyolo_traffic_sign',
trust_remote_code=True)
result = tinynas_object_detection(
'data/test/images/image_traffic_sign.jpg')
assert result and (OutputKeys.SCORES in result) and (

View File

@@ -17,7 +17,10 @@ class VideoDepthEstimationTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_video_depth_estimation(self):
input_location = 'data/test/videos/video_depth_estimation.mp4'
estimator = pipeline(Tasks.video_depth_estimation, model=self.model_id)
estimator = pipeline(
Tasks.video_depth_estimation,
model=self.model_id,
trust_remote_code=True)
result = estimator(input_location)
show_video_depth_estimation_result(result[OutputKeys.DEPTHS_COLOR],
'out.mp4')