diff --git a/modelscope/models/cv/object_detection_3d/depe/mmdet3d_plugin/models/detectors/petr3d.py b/modelscope/models/cv/object_detection_3d/depe/mmdet3d_plugin/models/detectors/petr3d.py index f33a8467..c8e0a4e1 100644 --- a/modelscope/models/cv/object_detection_3d/depe/mmdet3d_plugin/models/detectors/petr3d.py +++ b/modelscope/models/cv/object_detection_3d/depe/mmdet3d_plugin/models/detectors/petr3d.py @@ -4,6 +4,7 @@ https://github.com/megvii-research/PETR/blob/main/projects/mmdet3d_plugin/models """ import numpy as np import torch +from mmcv.parallel.data_container import DataContainer as DC from mmcv.runner import auto_fp16, force_fp32 from mmdet3d.core import bbox3d2result from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector @@ -185,7 +186,11 @@ class Petr3D(MVXTwoStageDetector): def simple_test(self, img_metas, img=None, rescale=False): """Test function without augmentaiton.""" if not torch.cuda.is_available() and img is not None: - img = img[0] + if isinstance(img, torch.Tensor): + img = img[0] + elif isinstance(img, DC) and isinstance(img_metas, DC): + img = img.data[0] + img_metas = img_metas.data[0] img_feats = self.extract_feat(img=img, img_metas=img_metas) bbox_list = [dict() for i in range(len(img_metas))] bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)