mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 08:17:45 +01:00)
add cv_casmvs_multi-view-depth-esimation_general
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11204285
modelscope/metainfo.py
@@ -66,6 +66,7 @@ class Models(object):
     video_object_segmentation = 'video-object-segmentation'
     real_basicvsr = 'real-basicvsr'
     rcp_sceneflow_estimation = 'rcp-sceneflow-estimation'
+    image_casmvs_depth_estimation = 'image-casmvs-depth-estimation'
 
     # EasyCV models
     yolox = 'YOLOX'
@@ -262,6 +263,7 @@ class Pipelines(object):
     video_object_segmentation = 'video-object-segmentation'
     video_super_resolution = 'realbasicvsr-video-super-resolution'
     pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'
+    image_multi_view_depth_estimation = 'image-multi-view-depth-estimation'
 
     # nlp tasks
     automatic_post_editing = 'automatic-post-editing'
modelscope/models/cv/__init__.py
@@ -7,15 +7,15 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                face_generation, human_wholebody_keypoint, image_classification,
                image_color_enhance, image_colorization, image_denoise,
                image_inpainting, image_instance_segmentation,
-               image_panoptic_segmentation, image_portrait_enhancement,
-               image_reid_person, image_semantic_segmentation,
-               image_to_image_generation, image_to_image_translation,
-               language_guided_video_summarization, movie_scene_segmentation,
-               object_detection, pointcloud_sceneflow_estimation,
-               product_retrieval_embedding, realtime_object_detection,
-               referring_video_object_segmentation, salient_detection,
-               shop_segmentation, super_resolution, video_object_segmentation,
-               video_single_object_tracking, video_summarization,
-               video_super_resolution, virual_tryon)
+               image_mvs_depth_estimation, image_panoptic_segmentation,
+               image_portrait_enhancement, image_reid_person,
+               image_semantic_segmentation, image_to_image_generation,
+               image_to_image_translation, language_guided_video_summarization,
+               movie_scene_segmentation, object_detection,
+               pointcloud_sceneflow_estimation, product_retrieval_embedding,
+               realtime_object_detection, referring_video_object_segmentation,
+               salient_detection, shop_segmentation, super_resolution,
+               video_object_segmentation, video_single_object_tracking,
+               video_summarization, video_super_resolution, virual_tryon)
 
 # yapf: enable
modelscope/models/cv/image_mvs_depth_estimation/__init__.py (22 lines, new file)
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .casmvs_model import ImageMultiViewDepthEstimation

else:
    _import_structure = {
        'casmvs_model': ['ImageMultiViewDepthEstimation'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
modelscope/models/cv/image_mvs_depth_estimation/cas_mvsnet.py (221 lines, new file)
@@ -0,0 +1,221 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .module import (CostRegNet, FeatureNet, RefineNet, depth_regression,
                     get_depth_range_samples, homo_warping)

Align_Corners_Range = False


class DepthNet(nn.Module):

    def __init__(self):
        super(DepthNet, self).__init__()

    def forward(self,
                features,
                proj_matrices,
                depth_values,
                num_depth,
                cost_regularization,
                prob_volume_init=None):
        proj_matrices = torch.unbind(proj_matrices, 1)
        assert len(features) == len(
            proj_matrices
        ), 'Different number of images and projection matrices'
        assert depth_values.shape[
            1] == num_depth, 'depth_values.shape[1]:{} num_depth:{}'.format(
                depth_values.shape[1], num_depth)
        num_views = len(features)

        # step 1. feature extraction
        # in: images; out: 32-channel feature maps
        ref_feature, src_features = features[0], features[1:]
        ref_proj, src_projs = proj_matrices[0], proj_matrices[1:]

        # step 2. differentiable homography, build cost volume
        ref_volume = ref_feature.unsqueeze(2).repeat(1, 1, num_depth, 1, 1)
        volume_sum = ref_volume
        volume_sq_sum = ref_volume**2
        del ref_volume
        for src_fea, src_proj in zip(src_features, src_projs):
            # warped features
            src_proj_new = src_proj[:, 0].clone()
            src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3],
                                                   src_proj[:, 0, :3, :4])
            ref_proj_new = ref_proj[:, 0].clone()
            ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3],
                                                   ref_proj[:, 0, :3, :4])
            warped_volume = homo_warping(src_fea, src_proj_new, ref_proj_new,
                                         depth_values)
            if self.training:
                volume_sum = volume_sum + warped_volume
                volume_sq_sum = volume_sq_sum + warped_volume**2
            else:
                # TODO: this is only a temporary solution to save memory, better way?
                volume_sum += warped_volume
                volume_sq_sum += warped_volume.pow_(
                    2)  # the memory of warped_volume has been modified
            del warped_volume
        # aggregate multiple feature volumes by variance
        volume_variance = volume_sq_sum.div_(num_views).sub_(
            volume_sum.div_(num_views).pow_(2))

        # step 3. cost volume regularization
        cost_reg = cost_regularization(volume_variance)
        prob_volume_pre = cost_reg.squeeze(1)

        if prob_volume_init is not None:
            prob_volume_pre += prob_volume_init

        prob_volume = F.softmax(prob_volume_pre, dim=1)
        depth = depth_regression(prob_volume, depth_values=depth_values)

        with torch.no_grad():
            # photometric confidence
            prob_volume_sum4 = 4 * F.avg_pool3d(
                F.pad(prob_volume.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)),
                (4, 1, 1),
                stride=1,
                padding=0).squeeze(1)
            depth_index = depth_regression(
                prob_volume,
                depth_values=torch.arange(
                    num_depth, device=prob_volume.device,
                    dtype=torch.float)).long()
            depth_index = depth_index.clamp(min=0, max=num_depth - 1)
            photometric_confidence = torch.gather(
                prob_volume_sum4, 1, depth_index.unsqueeze(1)).squeeze(1)

        return {
            'depth': depth,
            'photometric_confidence': photometric_confidence
        }


class CascadeMVSNet(nn.Module):

    def __init__(self,
                 refine=False,
                 ndepths=[48, 32, 8],
                 depth_interals_ratio=[4, 2, 1],
                 share_cr=False,
                 grad_method='detach',
                 arch_mode='fpn',
                 cr_base_chs=[8, 8, 8]):
        super(CascadeMVSNet, self).__init__()
        self.refine = refine
        self.share_cr = share_cr
        self.ndepths = ndepths
        self.depth_interals_ratio = depth_interals_ratio
        self.grad_method = grad_method
        self.arch_mode = arch_mode
        self.cr_base_chs = cr_base_chs
        self.num_stage = len(ndepths)

        assert len(ndepths) == len(depth_interals_ratio)

        self.stage_infos = {
            'stage1': {
                'scale': 4.0,
            },
            'stage2': {
                'scale': 2.0,
            },
            'stage3': {
                'scale': 1.0,
            }
        }

        self.feature = FeatureNet(
            base_channels=8,
            stride=4,
            num_stage=self.num_stage,
            arch_mode=self.arch_mode)
        if self.share_cr:
            self.cost_regularization = CostRegNet(
                in_channels=self.feature.out_channels, base_channels=8)
        else:
            self.cost_regularization = nn.ModuleList([
                CostRegNet(
                    in_channels=self.feature.out_channels[i],
                    base_channels=self.cr_base_chs[i])
                for i in range(self.num_stage)
            ])
        if self.refine:
            self.refine_network = RefineNet()
        self.DepthNet = DepthNet()

    def forward(self, imgs, proj_matrices, depth_values):
        depth_min = float(depth_values[0, 0].cpu().numpy())
        depth_max = float(depth_values[0, -1].cpu().numpy())
        depth_interval = (depth_max - depth_min) / depth_values.size(1)

        # step 1. feature extraction
        features = []
        for nview_idx in range(imgs.size(1)):  # imgs shape (B, N, C, H, W)
            img = imgs[:, nview_idx]
            features.append(self.feature(img))

        outputs = {}
        depth, cur_depth = None, None
        for stage_idx in range(self.num_stage):
            # stage feature, proj_mats, scales
            features_stage = [
                feat['stage{}'.format(stage_idx + 1)] for feat in features
            ]
            proj_matrices_stage = proj_matrices['stage{}'.format(stage_idx
                                                                 + 1)]
            stage_scale = self.stage_infos['stage{}'.format(stage_idx
                                                            + 1)]['scale']

            if depth is not None:
                if self.grad_method == 'detach':
                    cur_depth = depth.detach()
                else:
                    cur_depth = depth
                cur_depth = F.interpolate(
                    cur_depth.unsqueeze(1), [img.shape[2], img.shape[3]],
                    mode='bilinear',
                    align_corners=Align_Corners_Range).squeeze(1)
            else:
                cur_depth = depth_values
            depth_range_samples = get_depth_range_samples(
                cur_depth=cur_depth,
                ndepth=self.ndepths[stage_idx],
                depth_inteval_pixel=self.depth_interals_ratio[stage_idx]
                * depth_interval,
                dtype=img[0].dtype,
                device=img[0].device,
                shape=[img.shape[0], img.shape[2], img.shape[3]],
                max_depth=depth_max,
                min_depth=depth_min)

            outputs_stage = self.DepthNet(
                features_stage,
                proj_matrices_stage,
                depth_values=F.interpolate(
                    depth_range_samples.unsqueeze(1), [
                        self.ndepths[stage_idx], img.shape[2]
                        // int(stage_scale), img.shape[3] // int(stage_scale)
                    ],
                    mode='trilinear',
                    align_corners=Align_Corners_Range).squeeze(1),
                num_depth=self.ndepths[stage_idx],
                cost_regularization=self.cost_regularization
                if self.share_cr else self.cost_regularization[stage_idx])

            depth = outputs_stage['depth']

            outputs['stage{}'.format(stage_idx + 1)] = outputs_stage
            outputs.update(outputs_stage)

        # depth map refinement
        if self.refine:
            refined_depth = self.refine_network(
                torch.cat((imgs[:, 0], depth), 1))
            outputs['refined_depth'] = refined_depth

        return outputs
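The two helpers doing the heavy lifting above, homo_warping and depth_regression, come from module.py, which is included only in part at the end of this commit. As a reading aid, here is a minimal sketch of the soft-argmin style depth regression the call sites assume: the regressed depth is the expectation of the depth hypotheses under the softmax probability volume, E[d] = sum_d p(d) * d. The function name and shape handling below are illustrative, not the committed implementation.

import torch


def depth_regression_sketch(prob_volume, depth_values):
    # prob_volume: (B, D, H, W), softmax over the D depth hypotheses
    # depth_values: (D,), (B, D) or (B, D, H, W) depth hypothesis values
    if depth_values.dim() <= 2:
        # broadcast hypothesis depths over the spatial dimensions
        depth_values = depth_values.view(*depth_values.shape, 1, 1)
    return torch.sum(prob_volume * depth_values, dim=1)  # (B, H, W)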
modelscope/models/cv/image_mvs_depth_estimation/casmvs_model.py (164 lines, new file)
@@ -0,0 +1,164 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp

import cv2
import numpy as np
import torch
from easydict import EasyDict as edict

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .cas_mvsnet import CascadeMVSNet
from .colmap2mvsnet import processing_single_scene
from .depth_filter import pcd_depth_filter
from .general_eval_dataset import MVSDataset, save_pfm
from .utils import (generate_pointcloud, numpy2torch, tensor2numpy, tocuda,
                    write_cam)

logger = get_logger()


@MODELS.register_module(
    Tasks.image_multi_view_depth_estimation,
    module_name=Models.image_casmvs_depth_estimation)
class ImageMultiViewDepthEstimation(TorchModel):

    def __init__(self, model_dir: str, **kwargs):
        """model_dir (str): the model file root directory."""
        super().__init__(model_dir, **kwargs)

        # build model
        self.model = CascadeMVSNet(
            refine=False,
            ndepths=[48, 32, 8],
            depth_interals_ratio=[float(d_i) for d_i in [4, 2, 1]],
            share_cr=False,
            cr_base_chs=[8, 8, 8],
            grad_method='detach')

        # load checkpoint file
        ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model {ckpt_path}')
        state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(state_dict['model'], strict=True)

        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'

        self.model.to(self.device)
        self.model.eval()
        logger.info(f'model init done! Device:{self.device}')

    def preprocess_make_pair(self, inputs):

        data = inputs['input_dir']
        casmvs_inp_dir = inputs['casmvs_inp_dir']

        args = edict()
        args.dense_folder = data
        args.save_folder = casmvs_inp_dir
        args.max_d = 192
        args.interval_scale = 1.06
        args.theta0 = 5
        args.sigma1 = 1
        args.sigma2 = 10
        args.model_ext = '.bin'

        logger.info('preprocess of making pair data start')
        processing_single_scene(args)
        logger.info('preprocess of making pair data done')

    def forward(self, inputs):
        test_dir = os.path.dirname(inputs['casmvs_inp_dir'])
        scene = os.path.basename(inputs['casmvs_inp_dir'])
        test_list = [scene]
        save_dir = inputs['casmvs_res_dir']

        logger.info('depth estimation start')

        test_dataset = MVSDataset(
            test_dir,
            test_list,
            'test',
            5,
            192,
            1.06,
            max_h=1200,
            max_w=1200,
            fix_res=False)

        with torch.no_grad():
            for batch_idx, sample in enumerate(test_dataset):
                sample = numpy2torch(sample)

                if self.device == 'cuda':
                    sample_cuda = tocuda(sample)
                else:
                    # fall back to the CPU tensors so the code below also
                    # works without CUDA
                    sample_cuda = sample

                proj_matrices_dict = sample_cuda['proj_matrices']
                proj_matrices_dict_new = {}
                for k, v in proj_matrices_dict.items():
                    proj_matrices_dict_new[k] = v.unsqueeze(0)

                outputs = self.model(sample_cuda['imgs'].unsqueeze(0),
                                     proj_matrices_dict_new,
                                     sample_cuda['depth_values'].unsqueeze(0))

                outputs = tensor2numpy(outputs)
                del sample_cuda
                filenames = [sample['filename']]
                cams = sample['proj_matrices']['stage{}'.format(3)].unsqueeze(
                    0).numpy()
                imgs = sample['imgs'].unsqueeze(0).numpy()

                # save depth maps and confidence maps
                for filename, cam, img, depth_est, photometric_confidence in zip(
                        filenames, cams, imgs, outputs['depth'],
                        outputs['photometric_confidence']):

                    img = img[0]  # ref view
                    cam = cam[0]  # ref cam
                    depth_filename = os.path.join(
                        save_dir, filename.format('depth_est', '.pfm'))
                    confidence_filename = os.path.join(
                        save_dir, filename.format('confidence', '.pfm'))
                    cam_filename = os.path.join(
                        save_dir, filename.format('cams', '_cam.txt'))
                    img_filename = os.path.join(
                        save_dir, filename.format('images', '.jpg'))
                    ply_filename = os.path.join(
                        save_dir, filename.format('ply_local', '.ply'))
                    os.makedirs(
                        depth_filename.rsplit('/', 1)[0], exist_ok=True)
                    os.makedirs(
                        confidence_filename.rsplit('/', 1)[0], exist_ok=True)
                    os.makedirs(cam_filename.rsplit('/', 1)[0], exist_ok=True)
                    os.makedirs(img_filename.rsplit('/', 1)[0], exist_ok=True)
                    os.makedirs(ply_filename.rsplit('/', 1)[0], exist_ok=True)
                    # save depth maps
                    save_pfm(depth_filename, depth_est)
                    # save confidence maps
                    save_pfm(confidence_filename, photometric_confidence)
                    # save cams, img
                    write_cam(cam_filename, cam)
                    img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0,
                                  255).astype(np.uint8)
                    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                    cv2.imwrite(img_filename, img_bgr)

        logger.info('depth estimation end')
        return inputs

    def postprocess(self, inputs):
        test_dir = os.path.dirname(inputs['casmvs_inp_dir'])
        scene = os.path.basename(inputs['casmvs_inp_dir'])
        logger.info('depth fusion start')
        pcd = pcd_depth_filter(
            scene, test_dir, inputs['casmvs_res_dir'], thres_view=4)
        logger.info('depth fusion end')
        return pcd
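For orientation, a model registered this way is normally reached through modelscope's generic pipeline API via the task name registered above. A rough usage sketch, assuming a pipeline class registered elsewhere in this PR wires preprocess_make_pair, forward and postprocess together; the model id string is a placeholder, not something defined in this commit.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# hypothetical hub id; substitute whatever model card ships the checkpoint
estimator = pipeline(
    Tasks.image_multi_view_depth_estimation,
    model='<org>/cv_casmvs_multi-view-depth-estimation_general')
# input: a scene directory with COLMAP 'images' and 'sparse' subfolders;
# the exact input/output keys depend on the registered pipeline class
result = estimator('path/to/scene_dir')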
modelscope/models/cv/image_mvs_depth_estimation/colmap2mvsnet.py (472 lines, new file)
@@ -0,0 +1,472 @@
# The implementation is borrowed from https://github.com/YoYo000/MVSNet. Model reading is provided by COLMAP.

from __future__ import print_function
import collections
import multiprocessing as mp
import os
import shutil
import struct
from functools import partial

import cv2
import numpy as np

# ============================ read_model.py ============================#
CameraModel = collections.namedtuple('CameraModel',
                                     ['model_id', 'model_name', 'num_params'])
Camera = collections.namedtuple('Camera',
                                ['id', 'model', 'width', 'height', 'params'])
BaseImage = collections.namedtuple(
    'Image', ['id', 'qvec', 'tvec', 'camera_id', 'name', 'xys', 'point3D_ids'])
Point3D = collections.namedtuple(
    'Point3D', ['id', 'xyz', 'rgb', 'error', 'image_ids', 'point2D_idxs'])


class Image(BaseImage):

    def qvec2rotmat(self):
        return qvec2rotmat(self.qvec)


CAMERA_MODELS = {
    CameraModel(model_id=0, model_name='SIMPLE_PINHOLE', num_params=3),
    CameraModel(model_id=1, model_name='PINHOLE', num_params=4),
    CameraModel(model_id=2, model_name='SIMPLE_RADIAL', num_params=4),
    CameraModel(model_id=3, model_name='RADIAL', num_params=5),
    CameraModel(model_id=4, model_name='OPENCV', num_params=8),
    CameraModel(model_id=5, model_name='OPENCV_FISHEYE', num_params=8),
    CameraModel(model_id=6, model_name='FULL_OPENCV', num_params=12),
    CameraModel(model_id=7, model_name='FOV', num_params=5),
    CameraModel(model_id=8, model_name='SIMPLE_RADIAL_FISHEYE', num_params=4),
    CameraModel(model_id=9, model_name='RADIAL_FISHEYE', num_params=5),
    CameraModel(model_id=10, model_name='THIN_PRISM_FISHEYE', num_params=12)
}
CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
                         for camera_model in CAMERA_MODELS])


def read_next_bytes(fid,
                    num_bytes,
                    format_char_sequence,
                    endian_character='<'):
    """Read and unpack the next bytes from a binary file.
    :param fid:
    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
    :param endian_character: Any of {@, =, <, >, !}
    :return: Tuple of read and unpacked values.
    """
    data = fid.read(num_bytes)
    return struct.unpack(endian_character + format_char_sequence, data)


def read_cameras_text(path):
    cameras = {}
    with open(path, 'r', encoding='utf-8') as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != '#':
                elems = line.split()
                camera_id = int(elems[0])
                model = elems[1]
                width = int(elems[2])
                height = int(elems[3])
                params = np.array(tuple(map(float, elems[4:])))
                cameras[camera_id] = Camera(
                    id=camera_id,
                    model=model,
                    width=width,
                    height=height,
                    params=params)
    return cameras


def read_cameras_binary(path_to_model_file):
    cameras = {}
    with open(path_to_model_file, 'rb') as fid:
        num_cameras = read_next_bytes(fid, 8, 'Q')[0]
        for camera_line_index in range(num_cameras):
            camera_properties = read_next_bytes(
                fid, num_bytes=24, format_char_sequence='iiQQ')
            camera_id = camera_properties[0]
            model_id = camera_properties[1]
            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
            width = camera_properties[2]
            height = camera_properties[3]
            num_params = CAMERA_MODEL_IDS[model_id].num_params
            params = read_next_bytes(
                fid,
                num_bytes=8 * num_params,
                format_char_sequence='d' * num_params)
            cameras[camera_id] = Camera(
                id=camera_id,
                model=model_name,
                width=width,
                height=height,
                params=np.array(params))
        assert len(cameras) == num_cameras
    return cameras


def read_images_text(path):
    images = {}
    with open(path, 'r', encoding='utf-8') as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != '#':
                elems = line.split()
                image_id = int(elems[0])
                qvec = np.array(tuple(map(float, elems[1:5])))
                tvec = np.array(tuple(map(float, elems[5:8])))
                camera_id = int(elems[8])
                image_name = elems[9]
                elems = fid.readline().split()
                xys = np.column_stack([
                    tuple(map(float, elems[0::3])),
                    tuple(map(float, elems[1::3]))
                ])
                point3D_ids = np.array(tuple(map(int, elems[2::3])))
                images[image_id] = Image(
                    id=image_id,
                    qvec=qvec,
                    tvec=tvec,
                    camera_id=camera_id,
                    name=image_name,
                    xys=xys,
                    point3D_ids=point3D_ids)
    return images


def read_images_binary(path_to_model_file):
    images = {}
    with open(path_to_model_file, 'rb') as fid:
        num_reg_images = read_next_bytes(fid, 8, 'Q')[0]
        for image_index in range(num_reg_images):
            binary_image_properties = read_next_bytes(
                fid, num_bytes=64, format_char_sequence='idddddddi')
            image_id = binary_image_properties[0]
            qvec = np.array(binary_image_properties[1:5])
            tvec = np.array(binary_image_properties[5:8])
            camera_id = binary_image_properties[8]
            image_name = ''
            current_char = read_next_bytes(fid, 1, 'c')[0]
            while current_char != b'\x00':  # look for the ASCII 0 entry
                image_name += current_char.decode('utf-8')
                current_char = read_next_bytes(fid, 1, 'c')[0]
            num_points2D = read_next_bytes(
                fid, num_bytes=8, format_char_sequence='Q')[0]
            x_y_id_s = read_next_bytes(
                fid,
                num_bytes=24 * num_points2D,
                format_char_sequence='ddq' * num_points2D)
            xys = np.column_stack([
                tuple(map(float, x_y_id_s[0::3])),
                tuple(map(float, x_y_id_s[1::3]))
            ])
            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
            images[image_id] = Image(
                id=image_id,
                qvec=qvec,
                tvec=tvec,
                camera_id=camera_id,
                name=image_name,
                xys=xys,
                point3D_ids=point3D_ids)
    return images


def read_points3D_text(path):
    points3D = {}
    with open(path, 'r', encoding='utf-8') as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != '#':
                elems = line.split()
                point3D_id = int(elems[0])
                xyz = np.array(tuple(map(float, elems[1:4])))
                rgb = np.array(tuple(map(int, elems[4:7])))
                error = float(elems[7])
                image_ids = np.array(tuple(map(int, elems[8::2])))
                point2D_idxs = np.array(tuple(map(int, elems[9::2])))
                points3D[point3D_id] = Point3D(
                    id=point3D_id,
                    xyz=xyz,
                    rgb=rgb,
                    error=error,
                    image_ids=image_ids,
                    point2D_idxs=point2D_idxs)
    return points3D


def read_points3d_binary(path_to_model_file):
    points3D = {}
    with open(path_to_model_file, 'rb') as fid:
        num_points = read_next_bytes(fid, 8, 'Q')[0]
        for point_line_index in range(num_points):
            binary_point_line_properties = read_next_bytes(
                fid, num_bytes=43, format_char_sequence='QdddBBBd')
            point3D_id = binary_point_line_properties[0]
            xyz = np.array(binary_point_line_properties[1:4])
            rgb = np.array(binary_point_line_properties[4:7])
            error = np.array(binary_point_line_properties[7])
            track_length = read_next_bytes(
                fid, num_bytes=8, format_char_sequence='Q')[0]
            track_elems = read_next_bytes(
                fid,
                num_bytes=8 * track_length,
                format_char_sequence='ii' * track_length)
            image_ids = np.array(tuple(map(int, track_elems[0::2])))
            point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
            points3D[point3D_id] = Point3D(
                id=point3D_id,
                xyz=xyz,
                rgb=rgb,
                error=error,
                image_ids=image_ids,
                point2D_idxs=point2D_idxs)
    return points3D


def read_model(path, ext):
    if ext == '.txt':
        cameras = read_cameras_text(os.path.join(path, 'cameras' + ext))
        images = read_images_text(os.path.join(path, 'images' + ext))
        points3D = read_points3D_text(os.path.join(path, 'points3D') + ext)
    else:
        cameras = read_cameras_binary(os.path.join(path, 'cameras' + ext))
        images = read_images_binary(os.path.join(path, 'images' + ext))
        points3D = read_points3d_binary(os.path.join(path, 'points3D') + ext)
    return cameras, images, points3D


def qvec2rotmat(qvec):
    return np.array([
        [
            1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
            2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
            2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]
        ],  # noqa
        [
            2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
            1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
            2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]
        ],  # noqa
        [
            2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
            2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
            1 - 2 * qvec[1]**2 - 2 * qvec[2]**2
        ]
    ])  # noqa


def rotmat2qvec(R):
    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
    K = np.array(
        [[Rxx - Ryy - Rzz, 0, 0, 0], [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
         [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
         [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0  # noqa
    eigvals, eigvecs = np.linalg.eigh(K)
    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
    if qvec[0] < 0:
        qvec *= -1
    return qvec


def calc_score(inputs, images, points3d, extrinsic, args):
    i, j = inputs
    id_i = images[i + 1].point3D_ids
    id_j = images[j + 1].point3D_ids
    id_intersect = [it for it in id_i if it in id_j]
    cam_center_i = -np.matmul(extrinsic[i + 1][:3, :3].transpose(),
                              extrinsic[i + 1][:3, 3:4])[:, 0]
    cam_center_j = -np.matmul(extrinsic[j + 1][:3, :3].transpose(),
                              extrinsic[j + 1][:3, 3:4])[:, 0]
    score = 0
    for pid in id_intersect:
        if pid == -1:
            continue
        p = points3d[pid].xyz
        theta = (180 / np.pi) * np.arccos(
            np.dot(cam_center_i - p, cam_center_j - p)
            / np.linalg.norm(cam_center_i - p)
            / np.linalg.norm(cam_center_j - p))
        tmp_value = (
            2 *  # noqa
            (args.sigma1 if theta <= args.theta0 else args.sigma2)**2)
        score += np.exp(-(theta - args.theta0) *  # noqa
                        (theta - args.theta0) / tmp_value)
    return i, j, score
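As implemented in calc_score above, pair selection follows the MVSNet view-selection heuristic: every 3-D point visible in both views contributes a Gaussian weight of its baseline angle theta, i.e. score(i, j) = sum_p exp(-(theta_p - theta0)^2 / (2 sigma^2)) with sigma = sigma1 for theta_p <= theta0 and sigma = sigma2 otherwise (theta0 = 5, sigma1 = 1, sigma2 = 10 as set in preprocess_make_pair). processing_single_scene below then keeps the ten highest-scoring source views per reference view when writing pair.txt.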
def processing_single_scene(args):

    image_dir = os.path.join(args.dense_folder, 'images')
    model_dir = os.path.join(args.dense_folder, 'sparse')
    cam_dir = os.path.join(args.save_folder, 'cams')
    image_converted_dir = os.path.join(args.save_folder, 'images_post')

    if os.path.exists(image_converted_dir):
        shutil.rmtree(image_converted_dir)
    os.makedirs(image_converted_dir)
    if os.path.exists(cam_dir):
        shutil.rmtree(cam_dir)

    cameras, images, points3d = read_model(model_dir, args.model_ext)
    num_images = len(list(images.items()))

    param_type = {
        'SIMPLE_PINHOLE': ['f', 'cx', 'cy'],
        'PINHOLE': ['fx', 'fy', 'cx', 'cy'],
        'SIMPLE_RADIAL': ['f', 'cx', 'cy', 'k'],
        'SIMPLE_RADIAL_FISHEYE': ['f', 'cx', 'cy', 'k'],
        'RADIAL': ['f', 'cx', 'cy', 'k1', 'k2'],
        'RADIAL_FISHEYE': ['f', 'cx', 'cy', 'k1', 'k2'],
        'OPENCV': ['fx', 'fy', 'cx', 'cy', 'k1', 'k2', 'p1', 'p2'],
        'OPENCV_FISHEYE': ['fx', 'fy', 'cx', 'cy', 'k1', 'k2', 'k3', 'k4'],
        'FULL_OPENCV': [
            'fx', 'fy', 'cx', 'cy', 'k1', 'k2', 'p1', 'p2', 'k3', 'k4', 'k5',
            'k6'
        ],
        'FOV': ['fx', 'fy', 'cx', 'cy', 'omega'],
        'THIN_PRISM_FISHEYE': [
            'fx', 'fy', 'cx', 'cy', 'k1', 'k2', 'p1', 'p2', 'k3', 'k4', 'sx1',
            'sy1'
        ]
    }

    # intrinsic
    intrinsic = {}
    for camera_id, cam in cameras.items():
        params_dict = {
            key: value
            for key, value in zip(param_type[cam.model], cam.params)
        }
        if 'f' in param_type[cam.model]:
            params_dict['fx'] = params_dict['f']
            params_dict['fy'] = params_dict['f']
        i = np.array([[params_dict['fx'], 0, params_dict['cx']],
                      [0, params_dict['fy'], params_dict['cy']], [0, 0, 1]])
        intrinsic[camera_id] = i

    new_images = {}
    for i, image_id in enumerate(sorted(images.keys())):
        new_images[i + 1] = images[image_id]
    images = new_images

    # extrinsic
    extrinsic = {}
    for image_id, image in images.items():
        e = np.zeros((4, 4))
        e[:3, :3] = qvec2rotmat(image.qvec)
        e[:3, 3] = image.tvec
        e[3, 3] = 1
        extrinsic[image_id] = e

    # depth range and interval
    depth_ranges = {}
    for i in range(num_images):
        zs = []
        for p3d_id in images[i + 1].point3D_ids:
            if p3d_id == -1:
                continue
            transformed = np.matmul(extrinsic[i + 1], [
                points3d[p3d_id].xyz[0], points3d[p3d_id].xyz[1],
                points3d[p3d_id].xyz[2], 1
            ])
            zs.append(transformed[2].item())
        zs_sorted = sorted(zs)
        # relaxed depth range
        max_ratio = 0.1
        min_ratio = 0.03
        num_max = max(5, int(len(zs) * max_ratio))
        num_min = max(1, int(len(zs) * min_ratio))
        depth_min = 1.0 * sum(zs_sorted[:num_min]) / len(zs_sorted[:num_min])
        depth_max = 1.0 * sum(zs_sorted[-num_max:]) / len(zs_sorted[-num_max:])
        if args.max_d == 0:
            image_int = intrinsic[images[i + 1].camera_id]
            image_ext = extrinsic[i + 1]
            image_r = image_ext[0:3, 0:3]
            image_t = image_ext[0:3, 3]
            p1 = [image_int[0, 2], image_int[1, 2], 1]
            p2 = [image_int[0, 2] + 1, image_int[1, 2], 1]
            P1 = np.matmul(np.linalg.inv(image_int), p1) * depth_min
            P1 = np.matmul(np.linalg.inv(image_r), (P1 - image_t))
            P2 = np.matmul(np.linalg.inv(image_int), p2) * depth_min
            P2 = np.matmul(np.linalg.inv(image_r), (P2 - image_t))
            depth_num = (1 / depth_min - 1 / depth_max) / (
                1 / depth_min - 1 / (depth_min + np.linalg.norm(P2 - P1)))
        else:
            depth_num = args.max_d
        depth_interval = (depth_max - depth_min) / (depth_num
                                                    - 1) / args.interval_scale
        depth_ranges[i + 1] = (depth_min, depth_interval, depth_num, depth_max)

    # view selection
    score = np.zeros((len(images), len(images)))
    queue = []
    for i in range(len(images)):
        for j in range(i + 1, len(images)):
            queue.append((i, j))

    p = mp.Pool(processes=mp.cpu_count())
    func = partial(
        calc_score,
        images=images,
        points3d=points3d,
        args=args,
        extrinsic=extrinsic)
    result = p.map(func, queue)
    p.close()  # release the worker processes once scoring is done
    for i, j, s in result:
        score[i, j] = s
        score[j, i] = s
    view_sel = []
    for i in range(len(images)):
        sorted_score = np.argsort(score[i])[::-1]
        view_sel.append([(k, score[i, k]) for k in sorted_score[:10]])

    # write
    os.makedirs(cam_dir, exist_ok=True)

    for i in range(num_images):
        with open(os.path.join(cam_dir, '%08d_cam.txt' % i), 'w') as f:
            f.write('extrinsic\n')
            for j in range(4):
                for k in range(4):
                    f.write(str(extrinsic[i + 1][j, k]) + ' ')
                f.write('\n')
            f.write('\nintrinsic\n')
            for j in range(3):
                for k in range(3):
                    f.write(
                        str(intrinsic[images[i + 1].camera_id][j, k]) + ' ')
                f.write('\n')
            f.write('\n%f %f %f %f\n' %
                    (depth_ranges[i + 1][0], depth_ranges[i + 1][1],
                     depth_ranges[i + 1][2], depth_ranges[i + 1][3]))
    with open(os.path.join(args.save_folder, 'pair.txt'), 'w') as f:
        f.write('%d\n' % len(images))
        for i, sorted_score in enumerate(view_sel):
            f.write('%d\n%d ' % (i, len(sorted_score)))
            for image_id, s in sorted_score:
                f.write('%d %f ' % (image_id, s))
            f.write('\n')

    # convert to jpg
    for i in range(num_images):
        img_path = os.path.join(image_dir, images[i + 1].name)
        if not img_path.endswith('.jpg'):
            img = cv2.imread(img_path)
            cv2.imwrite(os.path.join(image_converted_dir, '%08d.jpg' % i), img)
        else:
            shutil.copyfile(
                os.path.join(image_dir, images[i + 1].name),
                os.path.join(image_converted_dir, '%08d.jpg' % i))
modelscope/models/cv/image_mvs_depth_estimation/depth_filter.py (249 lines, new file)
@@ -0,0 +1,249 @@
# The implementation here is modified based on https://github.com/xy-guo/MVSNet_pytorch
import os

import cv2
import numpy as np
from PIL import Image
from plyfile import PlyData, PlyElement

from .general_eval_dataset import read_pfm


# read intrinsics and extrinsics
def read_camera_parameters(filename):
    with open(filename) as f:
        lines = f.readlines()
        lines = [line.rstrip() for line in lines]
    # extrinsics: line [1,5), 4x4 matrix
    extrinsics = np.fromstring(
        ' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
    # intrinsics: line [7-10), 3x3 matrix
    intrinsics = np.fromstring(
        ' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
    # assume the feature is 1/4 of the original image size
    # intrinsics[:2, :] /= 4
    return intrinsics, extrinsics
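For reference, the cams/%08d_cam.txt layout parsed here is exactly what processing_single_scene in colmap2mvsnet.py writes: an 'extrinsic' header followed by the 4x4 world-to-camera matrix (lines 1-4), a blank line, an 'intrinsic' header followed by the 3x3 calibration matrix (lines 7-9), and a final depth line (line 11), which general_eval_dataset.py reads as depth_min, depth_interval and, optionally, depth_num and depth_max:

extrinsic
r11 r12 r13 t1
r21 r22 r23 t2
r31 r32 r33 t3
0.0 0.0 0.0 1.0

intrinsic
fx 0.0 cx
0.0 fy cy
0.0 0.0 1.0

depth_min depth_interval depth_num depth_max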
# read an image
def read_img(filename):
    img = Image.open(filename)
    # scale 0~255 to 0~1
    np_img = np.array(img, dtype=np.float32) / 255.
    return np_img


# read a binary mask
def read_mask(filename):
    return read_img(filename) > 0.5


# save a binary mask
def save_mask(filename, mask):
    assert mask.dtype == np.bool_
    mask = mask.astype(np.uint8) * 255
    Image.fromarray(mask).save(filename)


# read a pair file, [(ref_view1, [src_view1-1, ...]), (ref_view2, [src_view2-1, ...]), ...]
def read_pair_file(filename):
    data = []
    with open(filename) as f:
        num_viewpoint = int(f.readline())
        # 49 viewpoints
        for view_idx in range(num_viewpoint):
            ref_view = int(f.readline().rstrip())
            src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
            if len(src_views) > 0:
                data.append((ref_view, src_views))
    return data


# project the reference point cloud into the source view, then project back
def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src,
                         intrinsics_src, extrinsics_src):
    width, height = depth_ref.shape[1], depth_ref.shape[0]
    # step1. project reference pixels to the source view
    # reference view x, y
    x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))
    x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1])
    # reference 3D space
    xyz_ref = np.matmul(
        np.linalg.inv(intrinsics_ref),
        np.vstack(
            (x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1]))
    # source 3D space
    xyz_src = np.matmul(
        np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)),
        np.vstack((xyz_ref, np.ones_like(x_ref))))[:3]
    # source view x, y
    K_xyz_src = np.matmul(intrinsics_src, xyz_src)
    xy_src = K_xyz_src[:2] / K_xyz_src[2:3]

    # step2. reproject the source view points with source view depth estimation
    # find the depth estimation of the source view
    x_src = xy_src[0].reshape([height, width]).astype(np.float32)
    y_src = xy_src[1].reshape([height, width]).astype(np.float32)
    sampled_depth_src = cv2.remap(
        depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR)

    # source 3D space
    # NOTE that we should use the sampled source-view depth here to project back
    xyz_src = np.matmul(
        np.linalg.inv(intrinsics_src),
        np.vstack(
            (xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1]))
    # reference 3D space
    xyz_reprojected = np.matmul(
        np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)),
        np.vstack((xyz_src, np.ones_like(x_ref))))[:3]
    # source view x, y, depth
    depth_reprojected = xyz_reprojected[2].reshape([height,
                                                    width]).astype(np.float32)
    K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected)
    xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3]
    x_reprojected = xy_reprojected[0].reshape([height,
                                               width]).astype(np.float32)
    y_reprojected = xy_reprojected[1].reshape([height,
                                               width]).astype(np.float32)

    return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src
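In matrix form, reproject_with_depth chains two pinhole projections: a reference pixel (x_ref, y_ref) with depth d is lifted to X_ref = d * K_ref^-1 [x_ref, y_ref, 1]^T, mapped into the source frame via T_src T_ref^-1, projected with K_src, and then the source depth sampled at that location is pushed back through the inverse chain. check_geometric_consistency below accepts a pixel when this round trip lands within 1 px of where it started and the reprojected depth agrees with the reference depth to within 1%.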
def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref,
                                depth_src, intrinsics_src, extrinsics_src):
    width, height = depth_ref.shape[1], depth_ref.shape[0]
    x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))
    depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(
        depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src,
        extrinsics_src)
    # check |p_reproj-p_1| < 1
    dist = np.sqrt((x2d_reprojected - x_ref)**2 + (y2d_reprojected - y_ref)**2)

    # check |d_reproj-d_1| / d_1 < 0.01
    depth_diff = np.abs(depth_reprojected - depth_ref)
    relative_depth_diff = depth_diff / depth_ref

    mask = np.logical_and(dist < 1, relative_depth_diff < 0.01)
    depth_reprojected[~mask] = 0

    return mask, depth_reprojected, x2d_src, y2d_src


def filter_depth(pair_folder, scan_folder, out_folder, thres_view):
    # the pair file
    pair_file = os.path.join(pair_folder, 'pair.txt')
    # for the final point cloud
    vertexs = []
    vertex_colors = []

    pair_data = read_pair_file(pair_file)

    # for each reference view and the corresponding source views
    for ref_view, src_views in pair_data:
        # src_views = src_views[:args.num_view]
        # load the camera parameters
        ref_intrinsics, ref_extrinsics = read_camera_parameters(
            os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(ref_view)))
        # load the reference image
        ref_img = read_img(
            os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view)))
        # load the estimated depth of the reference view
        ref_depth_est = read_pfm(
            os.path.join(out_folder,
                         'depth_est/{:0>8}.pfm'.format(ref_view)))[0]
        # load the photometric mask of the reference view
        confidence = read_pfm(
            os.path.join(out_folder,
                         'confidence/{:0>8}.pfm'.format(ref_view)))[0]
        photo_mask = confidence > 0.9

        all_srcview_depth_ests = []
        all_srcview_x = []
        all_srcview_y = []
        all_srcview_geomask = []

        # compute the geometric mask
        geo_mask_sum = 0
        for src_view in src_views:
            # camera parameters of the source view
            src_intrinsics, src_extrinsics = read_camera_parameters(
                os.path.join(scan_folder,
                             'cams/{:0>8}_cam.txt'.format(src_view)))
            # the estimated depth of the source view
            src_depth_est = read_pfm(
                os.path.join(out_folder,
                             'depth_est/{:0>8}.pfm'.format(src_view)))[0]

            geo_mask, depth_reprojected, x2d_src, y2d_src = check_geometric_consistency(
                ref_depth_est, ref_intrinsics, ref_extrinsics, src_depth_est,
                src_intrinsics, src_extrinsics)
            geo_mask_sum += geo_mask.astype(np.int32)
            all_srcview_depth_ests.append(depth_reprojected)
            all_srcview_x.append(x2d_src)
            all_srcview_y.append(y2d_src)
            all_srcview_geomask.append(geo_mask)

        depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (
            geo_mask_sum + 1)
        # keep pixels matched in at least thres_view source views
        geo_mask = geo_mask_sum >= thres_view
        final_mask = np.logical_and(photo_mask, geo_mask)

        os.makedirs(os.path.join(out_folder, 'mask'), exist_ok=True)
        save_mask(
            os.path.join(out_folder, 'mask/{:0>8}_photo.png'.format(ref_view)),
            photo_mask)
        save_mask(
            os.path.join(out_folder, 'mask/{:0>8}_geo.png'.format(ref_view)),
            geo_mask)
        save_mask(
            os.path.join(out_folder, 'mask/{:0>8}_final.png'.format(ref_view)),
            final_mask)

        height, width = depth_est_averaged.shape[:2]
        x, y = np.meshgrid(np.arange(0, width), np.arange(0, height))
        valid_points = final_mask
        x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[
            valid_points]

        color = ref_img[valid_points]

        xyz_ref = np.matmul(
            np.linalg.inv(ref_intrinsics),
            np.vstack((x, y, np.ones_like(x))) * depth)
        xyz_world = np.matmul(
            np.linalg.inv(ref_extrinsics), np.vstack(
                (xyz_ref, np.ones_like(x))))[:3]
        vertexs.append(xyz_world.transpose((1, 0)))
        vertex_colors.append((color * 255).astype(np.uint8))

    vertexs = np.concatenate(vertexs, axis=0)
    vertex_colors = np.concatenate(vertex_colors, axis=0)
    vertexs = np.array([tuple(v) for v in vertexs],
                       dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
    vertex_colors = np.array([tuple(v) for v in vertex_colors],
                             dtype=[('red', 'u1'), ('green', 'u1'),
                                    ('blue', 'u1')])

    vertex_all = np.empty(
        len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr)
    for prop in vertexs.dtype.names:
        vertex_all[prop] = vertexs[prop]
    for prop in vertex_colors.dtype.names:
        vertex_all[prop] = vertex_colors[prop]

    el = PlyElement.describe(vertex_all, 'vertex')
    # PlyData([el]).write(plyfilename)
    pcd = PlyData([el])

    return pcd


def pcd_depth_filter(scene, test_dir, save_dir, thres_view):
    old_scene_folder = os.path.join(test_dir, scene)
    new_scene_folder = os.path.join(save_dir, scene)
    out_folder = os.path.join(save_dir, scene)
    pcd = filter_depth(old_scene_folder, new_scene_folder, out_folder,
                       thres_view)
    return pcd
modelscope/models/cv/image_mvs_depth_estimation/general_eval_dataset.py (284 lines, new file)
@@ -0,0 +1,284 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import re
import sys

import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset


def read_pfm(filename):
    file = open(filename, 'rb')
    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().decode('utf-8').rstrip()
    if header == 'PF':
        color = True
    elif header == 'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0:  # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>'  # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    file.close()
    return data, scale


def save_pfm(filename, image, scale=1):
    file = open(filename, 'wb')
    color = None

    image = np.flipud(image)

    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    if len(image.shape) == 3 and image.shape[2] == 3:  # color image
        color = True
    elif len(image.shape) == 2 or len(
            image.shape) == 3 and image.shape[2] == 1:  # greyscale
        color = False
    else:
        raise Exception(
            'Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8'))
    file.write('{} {}\n'.format(image.shape[1],
                                image.shape[0]).encode('utf-8'))

    endian = image.dtype.byteorder

    if endian == '<' or endian == '=' and sys.byteorder == 'little':
        scale = -scale

    file.write(('%f\n' % scale).encode('utf-8'))

    image.tofile(file)
    file.close()
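A quick, illustrative sanity check of the PFM round trip implemented above (the path is a placeholder): save_pfm flips the image vertically before writing and read_pfm flips it back, so a write followed by a read should reproduce the input exactly.

import numpy as np

depth = np.random.rand(64, 80).astype(np.float32)
save_pfm('/tmp/depth.pfm', depth)
loaded, scale = read_pfm('/tmp/depth.pfm')
assert loaded.shape == depth.shape and np.allclose(loaded, depth)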
S_H, S_W = 0, 0


class MVSDataset(Dataset):

    def __init__(self,
                 datapath,
                 listfile,
                 mode,
                 nviews,
                 ndepths=192,
                 interval_scale=1.06,
                 **kwargs):
        super(MVSDataset, self).__init__()
        self.datapath = datapath
        self.listfile = listfile
        self.mode = mode
        self.nviews = nviews
        self.ndepths = ndepths
        self.interval_scale = interval_scale
        self.max_h, self.max_w = kwargs['max_h'], kwargs['max_w']
        self.fix_res = kwargs.get(
            'fix_res', False)  # whether to fix the resolution of input image.
        self.fix_wh = False

        assert self.mode == 'test'
        self.metas = self.build_list()

    def build_list(self):
        metas = []
        scans = self.listfile

        interval_scale_dict = {}
        # scans
        for scan in scans:
            # determine the interval scale of each scene. default is 1.06
            if isinstance(self.interval_scale, float):
                interval_scale_dict[scan] = self.interval_scale
            else:
                interval_scale_dict[scan] = self.interval_scale[scan]

            pair_file = '{}/pair.txt'.format(scan)
            # read the pair file
            with open(os.path.join(self.datapath, pair_file)) as f:
                num_viewpoint = int(f.readline())
                # viewpoints
                for view_idx in range(num_viewpoint):
                    ref_view = int(f.readline().rstrip())
                    src_views = [
                        int(x) for x in f.readline().rstrip().split()[1::2]
                    ]
                    # filter by no src view and fill to nviews
                    if len(src_views) > 0:
                        if len(src_views) < self.nviews:
                            src_views += [src_views[0]] * (
                                self.nviews - len(src_views))
                        metas.append((scan, ref_view, src_views, scan))

        self.interval_scale = interval_scale_dict
        return metas

    def __len__(self):
        return len(self.metas)

    def read_cam_file(self, filename, interval_scale):
        with open(filename) as f:
            lines = f.readlines()
            lines = [line.rstrip() for line in lines]
        # extrinsics: line [1,5), 4x4 matrix
        extrinsics = np.fromstring(
            ' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
        # intrinsics: line [7-10), 3x3 matrix
        intrinsics = np.fromstring(
            ' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
        intrinsics[:2, :] /= 4.0
        # depth_min & depth_interval: line 11
        depth_min = float(lines[11].split()[0])
        depth_interval = float(lines[11].split()[1])

        if len(lines[11].split()) >= 3:
            num_depth = lines[11].split()[2]
            depth_max = depth_min + int(float(num_depth)) * depth_interval
            depth_interval = (depth_max - depth_min) / self.ndepths

        depth_interval *= interval_scale

        return intrinsics, extrinsics, depth_min, depth_interval

    def read_img(self, filename):
        img = Image.open(filename)
        # scale 0~255 to 0~1
        np_img = np.array(img, dtype=np.float32) / 255.

        return np_img

    def read_depth(self, filename):
        # read pfm depth file
        return np.array(read_pfm(filename)[0], dtype=np.float32)

    def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=32):
        h, w = img.shape[:2]
        if h > max_h or w > max_w:
            scale = 1.0 * max_h / h
            if scale * w > max_w:
                scale = 1.0 * max_w / w
            new_w, new_h = scale * w // base * base, scale * h // base * base
        else:
            new_w, new_h = 1.0 * w // base * base, 1.0 * h // base * base

        scale_w = 1.0 * new_w / w
        scale_h = 1.0 * new_h / h
        intrinsics[0, :] *= scale_w
        intrinsics[1, :] *= scale_h

        img = cv2.resize(img, (int(new_w), int(new_h)))

        return img, intrinsics

    def __getitem__(self, idx):
        global S_H, S_W
        meta = self.metas[idx]
        scan, ref_view, src_views, scene_name = meta
        # use only the reference view and first nviews-1 source views
        view_ids = [ref_view] + src_views[:self.nviews - 1]

        imgs = []
        depth_values = None
        proj_matrices = []

        for i, vid in enumerate(view_ids):
            img_filename = os.path.join(
                self.datapath, '{}/images_post/{:0>8}.jpg'.format(scan, vid))
            if not os.path.exists(img_filename):
                img_filename = os.path.join(
                    self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid))

            proj_mat_filename = os.path.join(
                self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid))

            img = self.read_img(img_filename)
            intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(
                proj_mat_filename,
                interval_scale=self.interval_scale[scene_name])
            # scale input
            img, intrinsics = self.scale_mvs_input(img, intrinsics, self.max_w,
                                                   self.max_h)

            if self.fix_res:
                # using the same standard height or width in entire scene.
                S_H, S_W = img.shape[:2]
                self.fix_res = False
                self.fix_wh = True

            if i == 0:
                if not self.fix_wh:
                    # using the same standard height or width in each nviews.
                    S_H, S_W = img.shape[:2]

            # resize to standard height or width
            c_h, c_w = img.shape[:2]
            if (c_h != S_H) or (c_w != S_W):
                scale_h = 1.0 * S_H / c_h
                scale_w = 1.0 * S_W / c_w
                img = cv2.resize(img, (S_W, S_H))
                intrinsics[0, :] *= scale_w
                intrinsics[1, :] *= scale_h

            imgs.append(img)
            # extrinsics, intrinsics
            proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32)
            proj_mat[0, :4, :4] = extrinsics
            proj_mat[1, :3, :3] = intrinsics
            proj_matrices.append(proj_mat)

            if i == 0:  # reference view
                depth_values = np.arange(
                    depth_min,
                    depth_interval * (self.ndepths - 0.5) + depth_min,
                    depth_interval,
                    dtype=np.float32)

        # all
        imgs = np.stack(imgs).transpose([0, 3, 1, 2])
        proj_matrices = np.stack(proj_matrices)

        stage2_pjmats = proj_matrices.copy()
        stage2_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2
        stage3_pjmats = proj_matrices.copy()
        stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4

        proj_matrices_ms = {
            'stage1': proj_matrices,
            'stage2': stage2_pjmats,
            'stage3': stage3_pjmats
        }

        return {
            'imgs': imgs,
            'proj_matrices': proj_matrices_ms,
            'depth_values': depth_values,
            'filename': scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + '{}'
        }
678
modelscope/models/cv/image_mvs_depth_estimation/module.py
Normal file
678
modelscope/models/cv/image_mvs_depth_estimation/module.py
Normal file
@@ -0,0 +1,678 @@
|
||||
# The implementation here is modified based on https://github.com/xy-guo/MVSNet_pytorch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def init_bn(module):
|
||||
if module.weight is not None:
|
||||
nn.init.ones_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
return
|
||||
|
||||
|
||||
def init_uniform(module, init_method):
|
||||
if module.weight is not None:
|
||||
if init_method == 'kaiming':
|
||||
nn.init.kaiming_uniform_(module.weight)
|
||||
elif init_method == 'xavier':
|
||||
nn.init.xavier_uniform_(module.weight)
|
||||
return
|
||||
|
||||
|
||||
class Conv2d(nn.Module):
|
||||
"""Applies a 2D convolution (optionally with batch normalization and relu activation)
|
||||
over an input signal composed of several input planes.
|
||||
|
||||
Attributes:
|
||||
conv (nn.Module): convolution module
|
||||
bn (nn.Module): batch normalization module
|
||||
relu (bool): whether to activate by relu
|
||||
|
||||
Notes:
|
||||
Default momentum for batch normalization is set to be 0.01,
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
relu=True,
|
||||
bn=True,
|
||||
bn_momentum=0.1,
|
||||
init_method='xavier',
|
||||
**kwargs):
|
||||
super(Conv2d, self).__init__()
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
bias=(not bn),
|
||||
**kwargs)
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.bn = nn.BatchNorm2d(
|
||||
out_channels, momentum=bn_momentum) if bn else None
|
||||
self.relu = relu
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
if self.bn is not None:
|
||||
x = self.bn(x)
|
||||
if self.relu:
|
||||
x = F.relu(x, inplace=True)
|
||||
return x
|
||||
|
||||
def init_weights(self, init_method):
|
||||
"""default initialization"""
|
||||
init_uniform(self.conv, init_method)
|
||||
if self.bn is not None:
|
||||
init_bn(self.bn)
|
||||
|
||||
|
||||
class Deconv2d(nn.Module):
    """Applies a 2D deconvolution (optionally with batch normalization and ReLU activation)
    over an input signal composed of several input planes.

    Attributes:
        conv (nn.Module): convolution module
        bn (nn.Module): batch normalization module
        relu (bool): whether to apply ReLU activation

    Notes:
        The default batch normalization momentum is 0.1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 relu=True,
                 bn=True,
                 bn_momentum=0.1,
                 init_method='xavier',
                 **kwargs):
        super(Deconv2d, self).__init__()
        self.out_channels = out_channels
        assert stride in [1, 2]
        self.stride = stride

        self.conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            bias=(not bn),
            **kwargs)
        self.bn = nn.BatchNorm2d(
            out_channels, momentum=bn_momentum) if bn else None
        self.relu = relu

    def forward(self, x):
        y = self.conv(x)
        if self.stride == 2:
            # crop the output to exactly twice the input resolution
            h, w = list(x.size())[2:]
            y = y[:, :, :2 * h, :2 * w].contiguous()
        # note: operate on y throughout so the result is returned even when
        # bn or relu is disabled (the original assigned to x and could
        # return the un-deconvolved input)
        if self.bn is not None:
            y = self.bn(y)
        if self.relu:
            y = F.relu(y, inplace=True)
        return y

    def init_weights(self, init_method):
        """default initialization"""
        init_uniform(self.conv, init_method)
        if self.bn is not None:
            init_bn(self.bn)


class Conv3d(nn.Module):
    """Applies a 3D convolution (optionally with batch normalization and ReLU activation)
    over an input signal composed of several input planes.

    Attributes:
        conv (nn.Module): convolution module
        bn (nn.Module): batch normalization module
        relu (bool): whether to apply ReLU activation

    Notes:
        The default batch normalization momentum is 0.1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 relu=True,
                 bn=True,
                 bn_momentum=0.1,
                 init_method='xavier',
                 **kwargs):
        super(Conv3d, self).__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        assert stride in [1, 2]
        self.stride = stride

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            bias=(not bn),
            **kwargs)
        self.bn = nn.BatchNorm3d(
            out_channels, momentum=bn_momentum) if bn else None
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x

    def init_weights(self, init_method):
        """default initialization"""
        init_uniform(self.conv, init_method)
        if self.bn is not None:
            init_bn(self.bn)


class Deconv3d(nn.Module):
    """Applies a 3D deconvolution (optionally with batch normalization and ReLU activation)
    over an input signal composed of several input planes.

    Attributes:
        conv (nn.Module): convolution module
        bn (nn.Module): batch normalization module
        relu (bool): whether to apply ReLU activation

    Notes:
        The default batch normalization momentum is 0.1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 relu=True,
                 bn=True,
                 bn_momentum=0.1,
                 init_method='xavier',
                 **kwargs):
        super(Deconv3d, self).__init__()
        self.out_channels = out_channels
        assert stride in [1, 2]
        self.stride = stride

        self.conv = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            bias=(not bn),
            **kwargs)
        self.bn = nn.BatchNorm3d(
            out_channels, momentum=bn_momentum) if bn else None
        self.relu = relu

    def forward(self, x):
        y = self.conv(x)
        # operate on y throughout so the result is returned even when bn or
        # relu is disabled (the original assigned to x and could return the
        # untouched input)
        if self.bn is not None:
            y = self.bn(y)
        if self.relu:
            y = F.relu(y, inplace=True)
        return y

    def init_weights(self, init_method):
        """default initialization"""
        init_uniform(self.conv, init_method)
        if self.bn is not None:
            init_bn(self.bn)


class ConvBnReLU(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 pad=1):
        super(ConvBnReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=pad,
            bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)


class ConvBn(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 pad=1):
        super(ConvBn, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=pad,
            bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))


def homo_warping(src_fea, src_proj, ref_proj, depth_values):
    """
    src_fea: [B, C, H, W]
    src_proj: [B, 4, 4]
    ref_proj: [B, 4, 4]
    depth_values: [B, Ndepth] or [B, Ndepth, H, W]
    out: [B, C, Ndepth, H, W]
    """
    batch, channels = src_fea.shape[0], src_fea.shape[1]
    num_depth = depth_values.shape[1]
    height, width = src_fea.shape[2], src_fea.shape[3]

    with torch.no_grad():
        proj = torch.matmul(src_proj, torch.inverse(ref_proj))
        rot = proj[:, :3, :3]  # [B,3,3]
        trans = proj[:, :3, 3:4]  # [B,3,1]

        y, x = torch.meshgrid([
            torch.arange(
                0, height, dtype=torch.float32, device=src_fea.device),
            torch.arange(0, width, dtype=torch.float32, device=src_fea.device)
        ])
        y, x = y.contiguous(), x.contiguous()
        y, x = y.view(height * width), x.view(height * width)
        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]
        xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)  # [B, 3, H*W]
        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]
        rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(
            1, 1, num_depth, 1) * depth_values.view(
                batch, 1, num_depth, -1)  # [B, 3, Ndepth, H*W]
        proj_xyz = rot_depth_xyz + trans.view(
            batch, 3, 1, 1)  # [B, 3, Ndepth, H*W]
        proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]  # [B, 2, Ndepth, H*W]
        proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1
        proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
        proj_xy = torch.stack((proj_x_normalized, proj_y_normalized),
                              dim=3)  # [B, Ndepth, H*W, 2]
        grid = proj_xy

    warped_src_fea = F.grid_sample(
        src_fea,
        grid.view(batch, num_depth * height, width, 2),
        mode='bilinear',
        padding_mode='zeros',
        align_corners=True)
    warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height,
                                         width)

    return warped_src_fea


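# Illustrative usage sketch (not part of the original file): homo_warping
# sweeps a source feature map over a set of fronto-parallel depth planes in
# the reference view. All shapes and values below are assumptions chosen
# purely for demonstration.
#
#   src_fea = torch.rand(1, 32, 128, 160)                       # [B, C, H, W]
#   src_proj = torch.eye(4).unsqueeze(0)                        # [B, 4, 4]
#   ref_proj = torch.eye(4).unsqueeze(0)                        # [B, 4, 4]
#   depth_values = torch.linspace(425.0, 935.0, 48).view(1, 48)
#   warped = homo_warping(src_fea, src_proj, ref_proj, depth_values)
#   assert warped.shape == (1, 32, 48, 128, 160)                # [B, C, D, H, W]

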
class DeConv2dFuse(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 relu=True,
                 bn=True,
                 bn_momentum=0.1):
        super(DeConv2dFuse, self).__init__()

        self.deconv = Deconv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=2,
            padding=1,
            output_padding=1,
            bn=True,
            relu=relu,
            bn_momentum=bn_momentum)

        self.conv = Conv2d(
            2 * out_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=1,
            bn=bn,
            relu=relu,
            bn_momentum=bn_momentum)

    def forward(self, x_pre, x):
        x = self.deconv(x)
        x = torch.cat((x, x_pre), dim=1)
        x = self.conv(x)
        return x


class FeatureNet(nn.Module):

    def __init__(self, base_channels, num_stage=3, stride=4, arch_mode='unet'):
        super(FeatureNet, self).__init__()
        assert arch_mode in [
            'unet', 'fpn'
        ], f"mode must be in 'unet' or 'fpn', but got: {arch_mode}"
        self.arch_mode = arch_mode
        self.stride = stride
        self.base_channels = base_channels
        self.num_stage = num_stage

        self.conv0 = nn.Sequential(
            Conv2d(3, base_channels, 3, 1, padding=1),
            Conv2d(base_channels, base_channels, 3, 1, padding=1),
        )

        self.conv1 = nn.Sequential(
            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
        )

        self.conv2 = nn.Sequential(
            Conv2d(
                base_channels * 2, base_channels * 4, 5, stride=2, padding=2),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
        )

        self.out1 = nn.Conv2d(
            base_channels * 4, base_channels * 4, 1, bias=False)
        self.out_channels = [4 * base_channels]

        if self.arch_mode == 'unet':
            if num_stage == 3:
                self.deconv1 = DeConv2dFuse(base_channels * 4,
                                            base_channels * 2, 3)
                self.deconv2 = DeConv2dFuse(base_channels * 2, base_channels,
                                            3)

                self.out2 = nn.Conv2d(
                    base_channels * 2, base_channels * 2, 1, bias=False)
                self.out3 = nn.Conv2d(
                    base_channels, base_channels, 1, bias=False)
                self.out_channels.append(2 * base_channels)
                self.out_channels.append(base_channels)

            elif num_stage == 2:
                self.deconv1 = DeConv2dFuse(base_channels * 4,
                                            base_channels * 2, 3)

                self.out2 = nn.Conv2d(
                    base_channels * 2, base_channels * 2, 1, bias=False)
                self.out_channels.append(2 * base_channels)
        elif self.arch_mode == 'fpn':
            final_chs = base_channels * 4
            if num_stage == 3:
                self.inner1 = nn.Conv2d(
                    base_channels * 2, final_chs, 1, bias=True)
                self.inner2 = nn.Conv2d(
                    base_channels * 1, final_chs, 1, bias=True)

                self.out2 = nn.Conv2d(
                    final_chs, base_channels * 2, 3, padding=1, bias=False)
                self.out3 = nn.Conv2d(
                    final_chs, base_channels, 3, padding=1, bias=False)
                self.out_channels.append(base_channels * 2)
                self.out_channels.append(base_channels)

            elif num_stage == 2:
                self.inner1 = nn.Conv2d(
                    base_channels * 2, final_chs, 1, bias=True)

                self.out2 = nn.Conv2d(
                    final_chs, base_channels, 3, padding=1, bias=False)
                self.out_channels.append(base_channels)

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)

        intra_feat = conv2
        outputs = {}
        out = self.out1(intra_feat)
        outputs['stage1'] = out
        if self.arch_mode == 'unet':
            if self.num_stage == 3:
                intra_feat = self.deconv1(conv1, intra_feat)
                out = self.out2(intra_feat)
                outputs['stage2'] = out

                intra_feat = self.deconv2(conv0, intra_feat)
                out = self.out3(intra_feat)
                outputs['stage3'] = out

            elif self.num_stage == 2:
                intra_feat = self.deconv1(conv1, intra_feat)
                out = self.out2(intra_feat)
                outputs['stage2'] = out

        elif self.arch_mode == 'fpn':
            if self.num_stage == 3:
                intra_feat = F.interpolate(
                    intra_feat, scale_factor=2,
                    mode='nearest') + self.inner1(conv1)
                out = self.out2(intra_feat)
                outputs['stage2'] = out

                intra_feat = F.interpolate(
                    intra_feat, scale_factor=2,
                    mode='nearest') + self.inner2(conv0)
                out = self.out3(intra_feat)
                outputs['stage3'] = out

            elif self.num_stage == 2:
                intra_feat = F.interpolate(
                    intra_feat, scale_factor=2,
                    mode='nearest') + self.inner1(conv1)
                out = self.out2(intra_feat)
                outputs['stage2'] = out

        return outputs


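# Illustrative note (not part of the original file): with arch_mode='fpn',
# num_stage=3 and base_channels=8, an input of [B, 3, H, W] yields
#   outputs['stage1']: [B, 32, H/4, W/4]
#   outputs['stage2']: [B, 16, H/2, W/2]
#   outputs['stage3']: [B, 8, H, W]
# i.e. coarse-to-fine feature maps, one per cascade stage.

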
class CostRegNet(nn.Module):

    def __init__(self, in_channels, base_channels):
        super(CostRegNet, self).__init__()
        self.conv0 = Conv3d(in_channels, base_channels, padding=1)

        self.conv1 = Conv3d(
            base_channels, base_channels * 2, stride=2, padding=1)
        self.conv2 = Conv3d(base_channels * 2, base_channels * 2, padding=1)

        self.conv3 = Conv3d(
            base_channels * 2, base_channels * 4, stride=2, padding=1)
        self.conv4 = Conv3d(base_channels * 4, base_channels * 4, padding=1)

        self.conv5 = Conv3d(
            base_channels * 4, base_channels * 8, stride=2, padding=1)
        self.conv6 = Conv3d(base_channels * 8, base_channels * 8, padding=1)

        self.conv7 = Deconv3d(
            base_channels * 8,
            base_channels * 4,
            stride=2,
            padding=1,
            output_padding=1)

        self.conv9 = Deconv3d(
            base_channels * 4,
            base_channels * 2,
            stride=2,
            padding=1,
            output_padding=1)

        self.conv11 = Deconv3d(
            base_channels * 2,
            base_channels * 1,
            stride=2,
            padding=1,
            output_padding=1)

        self.prob = nn.Conv3d(
            base_channels, 1, 3, stride=1, padding=1, bias=False)

    def forward(self, x):
        conv0 = self.conv0(x)
        conv2 = self.conv2(self.conv1(conv0))
        conv4 = self.conv4(self.conv3(conv2))
        x = self.conv6(self.conv5(conv4))
        x = conv4 + self.conv7(x)
        x = conv2 + self.conv9(x)
        x = conv0 + self.conv11(x)
        x = self.prob(x)
        return x


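# Illustrative note (not part of the original file): CostRegNet is a 3D
# U-Net over the cost volume. An input of [B, in_channels, D, H, W] is
# encoded down to 1/8 resolution and decoded back with additive skip
# connections, ending in a single-channel volume [B, 1, D, H, W] of
# per-depth-hypothesis scores (typically squeezed and softmax-normalized
# along D before depth regression).

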
class RefineNet(nn.Module):

    def __init__(self):
        super(RefineNet, self).__init__()
        self.conv1 = ConvBnReLU(4, 32)
        self.conv2 = ConvBnReLU(32, 32)
        self.conv3 = ConvBnReLU(32, 32)
        self.res = ConvBnReLU(32, 1)

    def forward(self, img, depth_init):
        # torch.cat, not F.cat: torch.nn.functional has no cat
        concat = torch.cat((img, depth_init), dim=1)
        depth_residual = self.res(self.conv3(self.conv2(self.conv1(concat))))
        depth_refined = depth_init + depth_residual
        return depth_refined


def depth_regression(p, depth_values):
    if depth_values.dim() <= 2:
        depth_values = depth_values.view(*depth_values.shape, 1, 1)
    depth = torch.sum(p * depth_values, 1)

    return depth


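# Illustrative usage sketch (not part of the original file): given a
# softmax-normalized probability volume p of shape [B, D, H, W],
# depth_regression returns the per-pixel expectation over the D depth
# hypotheses. The shapes below are assumptions for demonstration.
#
#   prob_volume = torch.softmax(torch.rand(1, 48, 128, 160), dim=1)
#   depth_values = torch.linspace(425.0, 935.0, 48).view(1, 48)
#   depth = depth_regression(prob_volume, depth_values)  # [1, 128, 160]

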
def cas_mvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
    depth_loss_weights = kwargs.get('dlossw', None)

    total_loss = torch.tensor(
        0.0,
        dtype=torch.float32,
        device=mask_ms['stage1'].device,
        requires_grad=False)

    for (stage_inputs, stage_key) in [(inputs[k], k) for k in inputs.keys()
                                      if 'stage' in k]:
        depth_est = stage_inputs['depth']
        depth_gt = depth_gt_ms[stage_key]
        mask = mask_ms[stage_key]
        mask = mask > 0.5

        depth_loss = F.smooth_l1_loss(
            depth_est[mask], depth_gt[mask], reduction='mean')

        if depth_loss_weights is not None:
            stage_idx = int(stage_key.replace('stage', '')) - 1
            total_loss += depth_loss_weights[stage_idx] * depth_loss
        else:
            total_loss += 1.0 * depth_loss

    return total_loss, depth_loss


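# Illustrative usage sketch (not part of the original file): `inputs` holds
# one sub-dict per cascade stage (keys 'stage1', 'stage2', ..., each with a
# 'depth' estimate), and the optional dlossw keyword weights the per-stage
# smooth-L1 terms, e.g.
#
#   total_loss, last_stage_loss = cas_mvsnet_loss(
#       outputs, depth_gt_ms, mask_ms, dlossw=[0.5, 1.0, 2.0])

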
def get_cur_depth_range_samples(cur_depth,
                                ndepth,
                                depth_inteval_pixel,
                                shape,
                                max_depth=192.0,
                                min_depth=0.0):
    """
    shape: (B, H, W)
    cur_depth: (B, H, W)
    return depth_range_samples: (B, D, H, W)
    """
    cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel)  # (B, H, W)
    cur_depth_max = (cur_depth + ndepth / 2 * depth_inteval_pixel)

    assert cur_depth.shape == torch.Size(
        shape), 'cur_depth:{}, input shape:{}'.format(cur_depth.shape, shape)
    new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1)  # (B, H, W)

    depth_range_samples = cur_depth_min.unsqueeze(1) + (
        torch.arange(
            0,
            ndepth,
            device=cur_depth.device,
            dtype=cur_depth.dtype,
            requires_grad=False).reshape(1, -1, 1, 1)
        * new_interval.unsqueeze(1))

    return depth_range_samples


def get_depth_range_samples(cur_depth,
                            ndepth,
                            depth_inteval_pixel,
                            device,
                            dtype,
                            shape,
                            max_depth=192.0,
                            min_depth=0.0):
    """
    shape: (B, H, W)
    cur_depth: (B, H, W) or (B, D)
    return depth_range_samples: (B, D, H, W)
    """
    if cur_depth.dim() == 2:
        cur_depth_min = cur_depth[:, 0]  # (B,)
        cur_depth_max = cur_depth[:, -1]
        new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1)  # (B,)

        depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(
            0, ndepth, device=device, dtype=dtype,
            requires_grad=False).reshape(1, -1) * new_interval.unsqueeze(1)
        )  # (B, D)

        depth_range_samples = depth_range_samples.unsqueeze(-1).unsqueeze(
            -1).repeat(1, 1, shape[1], shape[2])  # (B, D, H, W)

    else:
        depth_range_samples = get_cur_depth_range_samples(
            cur_depth, ndepth, depth_inteval_pixel, shape, max_depth,
            min_depth)

    return depth_range_samples


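# Illustrative usage sketch (not part of the original file): at the first
# cascade stage cur_depth is a (B, D) tensor of global depth hypotheses and
# the sweep is uniform between its endpoints; at later stages it is the
# (B, H, W) depth map from the previous stage, and the sampling narrows
# around it per pixel. Values below are assumptions for demonstration.
#
#   cur_depth = torch.tensor([[425.0, 935.0]])  # (B, 2) range endpoints
#   samples = get_depth_range_samples(
#       cur_depth, ndepth=48, depth_inteval_pixel=4.0,
#       device=cur_depth.device, dtype=cur_depth.dtype, shape=(1, 128, 160))
#   # samples: (1, 48, 128, 160)

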
118
modelscope/models/cv/image_mvs_depth_estimation/utils.py
Normal file
@@ -0,0 +1,118 @@
# The implementation here is modified based on https://github.com/xy-guo/MVSNet_pytorch
import random

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.utils as vutils


# convert a function into recursive style to handle nested dict/list/tuple variables
def make_recursive_func(func):

    def wrapper(vars):
        if isinstance(vars, list):
            return [wrapper(x) for x in vars]
        elif isinstance(vars, tuple):
            return tuple([wrapper(x) for x in vars])
        elif isinstance(vars, dict):
            return {k: wrapper(v) for k, v in vars.items()}
        else:
            return func(vars)

    return wrapper


@make_recursive_func
def tensor2numpy(vars):
    if isinstance(vars, np.ndarray):
        return vars
    elif isinstance(vars, torch.Tensor):
        return vars.detach().cpu().numpy().copy()
    else:
        raise NotImplementedError(
            'invalid input type {} for tensor2numpy'.format(type(vars)))


@make_recursive_func
def numpy2torch(vars):
    if isinstance(vars, np.ndarray):
        return torch.from_numpy(vars)
    elif isinstance(vars, torch.Tensor):
        return vars
    elif isinstance(vars, str):
        return vars
    else:
        raise NotImplementedError(
            'invalid input type {} for numpy2torch'.format(type(vars)))


@make_recursive_func
def tocuda(vars):
    if isinstance(vars, torch.Tensor):
        return vars.to(torch.device('cuda'))
    elif isinstance(vars, str):
        return vars
    else:
        raise NotImplementedError(
            'invalid input type {} for tocuda'.format(type(vars)))


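# Illustrative usage sketch (not part of the original file): because the
# helpers above are wrapped with make_recursive_func, they recurse through
# nested dict/list/tuple containers, so a whole sample can be converted in
# one call. The sample layout below is an assumption for demonstration.
#
#   sample = {'imgs': torch.rand(1, 5, 3, 512, 640),
#             'depth_values': torch.rand(1, 48),
#             'filename': 'scan1/{}/00000000{}'}
#   sample_cuda = tocuda(sample)  # tensors moved to GPU, strings passed through
#   imgs_np = tensor2numpy(sample_cuda['imgs'])

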
def generate_pointcloud(rgb, depth, ply_file, intr, scale=1.0):
    """
    Generate a colored point cloud in PLY format from a color and a depth image.

    Input:
    rgb -- color image as an (H, W, 3) array
    depth -- depth image as an (H, W) array, aligned with rgb
    ply_file -- output filename of the ply file
    intr -- 3x3 camera intrinsic matrix
    scale -- divisor applied to raw depth values
    """
    fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2]
    points = []
    for v in range(rgb.shape[0]):
        for u in range(rgb.shape[1]):
            color = rgb[v, u]  # rgb.getpixel((u, v))
            Z = depth[v, u] / scale
            if Z == 0:
                continue
            X = (u - cx) * Z / fx
            Y = (v - cy) * Z / fy
            points.append('%f %f %f %d %d %d 0\n' %
                          (X, Y, Z, color[0], color[1], color[2]))
    file = open(ply_file, 'w')
    file.write('''ply
format ascii 1.0
element vertex %d
property float x
property float y
property float z
property uchar red
property uchar green
property uchar blue
property uchar alpha
end_header
%s
''' % (len(points), ''.join(points)))
    file.close()


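# Illustrative usage sketch (not part of the original file): rgb is expected
# as an (H, W, 3) uint8 array, depth as an (H, W) array of the same
# resolution, and intr as the 3x3 pinhole intrinsics; pixels with zero depth
# are skipped. The intrinsic values below are placeholders.
#
#   intr = np.array([[525.0, 0.0, 320.0],
#                    [0.0, 525.0, 240.0],
#                    [0.0, 0.0, 1.0]])
#   generate_pointcloud(rgb, depth, 'out.ply', intr, scale=1.0)

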
def write_cam(file, cam):
    f = open(file, 'w')
    f.write('extrinsic\n')
    for i in range(0, 4):
        for j in range(0, 4):
            f.write(str(cam[0][i][j]) + ' ')
        f.write('\n')
    f.write('\n')

    f.write('intrinsic\n')
    for i in range(0, 3):
        for j in range(0, 3):
            f.write(str(cam[1][i][j]) + ' ')
        f.write('\n')

    f.write('\n' + str(cam[1][3][0]) + ' ' + str(cam[1][3][1]) + ' '
            + str(cam[1][3][2]) + ' ' + str(cam[1][3][3]) + '\n')

    f.close()
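# Illustrative note (not part of the original file): `cam` is expected as a
# pair of matrices -- cam[0] the 4x4 extrinsic and cam[1] a 4x4 block whose
# top-left 3x3 is the intrinsic; the fourth row of cam[1] carries the depth
# sampling parameters written on the trailing line (in the common MVSNet
# convention: depth_min, depth_interval, depth_num, depth_max).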
@@ -245,6 +245,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.video_object_segmentation:
    (Pipelines.video_object_segmentation,
     'damo/cv_rdevos_video-object-segmentation'),
    Tasks.image_multi_view_depth_estimation:
    (Pipelines.image_multi_view_depth_estimation,
     'damo/cv_casmvs_multi-view-depth-estimation_general'),
}
@@ -73,6 +73,7 @@ if TYPE_CHECKING:
    from .video_super_resolution_pipeline import VideoSuperResolutionPipeline
    from .pointcloud_sceneflow_estimation_pipeline import PointCloudSceneFlowEstimationPipeline
    from .maskdino_instance_segmentation_pipeline import MaskDINOInstanceSegmentationPipeline
    from .image_mvs_depth_estimation_pipeline import ImageMultiViewDepthEstimationPipeline

else:
    _import_structure = {
@@ -171,7 +172,10 @@ else:
        ],
        'maskdino_instance_segmentation_pipeline': [
            'MaskDINOInstanceSegmentationPipeline'
        ],
        'image_mvs_depth_estimation_pipeline': [
            'ImageMultiViewDepthEstimationPipeline'
        ],
    }

    import sys

80
modelscope/pipelines/cv/image_mvs_depth_estimation_pipeline.py
Normal file
@@ -0,0 +1,80 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
from tempfile import TemporaryDirectory
from typing import Any, Dict, Union

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import depth_to_color
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.image_multi_view_depth_estimation,
    module_name=Pipelines.image_multi_view_depth_estimation)
class ImageMultiViewDepthEstimationPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create an image multi-view depth estimation pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        self.tmp_dir = None
        logger.info('pipeline init done')

    def check_input(self, input_dir):
        assert os.path.exists(
            input_dir), f'input dir:{input_dir} does not exist'
        sub_dirs = os.listdir(input_dir)
        assert 'images' in sub_dirs, "must contain 'images' folder"
        assert 'sparse' in sub_dirs, "must contain 'sparse' folder"
        files = os.listdir(os.path.join(input_dir, 'sparse'))
        assert 'cameras.bin' in files, "'sparse' folder must contain 'cameras.bin'"
        assert 'images.bin' in files, "'sparse' folder must contain 'images.bin'"
        assert 'points3D.bin' in files, "'sparse' folder must contain 'points3D.bin'"

    def preprocess(self, input: Input) -> Dict[str, Any]:
        assert isinstance(input, str), 'input must be str'
        self.check_input(input)
        self.tmp_dir = TemporaryDirectory()

        casmvs_inp_dir = os.path.join(self.tmp_dir.name, 'casmvs_inp_dir')
        casmvs_res_dir = os.path.join(self.tmp_dir.name, 'casmvs_res_dir')
        os.makedirs(casmvs_inp_dir, exist_ok=True)
        os.makedirs(casmvs_res_dir, exist_ok=True)

        input_dict = {
            'input_dir': input,
            'casmvs_inp_dir': casmvs_inp_dir,
            'casmvs_res_dir': casmvs_res_dir
        }

        self.model.preprocess_make_pair(input_dict)

        return input_dict

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        results = self.model.forward(input)
        return results

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        pcd = self.model.postprocess(inputs)

        # clear tmp dir
        if self.tmp_dir is not None:
            self.tmp_dir.cleanup()

        outputs = {
            OutputKeys.OUTPUT: pcd,
        }

        return outputs
@@ -107,6 +107,8 @@ class CVTasks(object):

    # pointcloud task
    pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'
    # image multi-view depth estimation
    image_multi_view_depth_estimation = 'image-multi-view-depth-estimation'

    # domain specific object detection
    domain_specific_object_detection = 'domain-specific-object-detection'

34
tests/pipelines/test_image_mvs_depth_estimation.py
Normal file
@@ -0,0 +1,34 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class ImageMVSDepthEstimationTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = 'image-multi-view-depth-estimation'
        self.model_id = 'damo/cv_casmvs_multi-view-depth-estimation_general'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_image_mvs_depth_estimation(self):
        estimator = pipeline(
            Tasks.image_multi_view_depth_estimation, model=self.model_id)
        model_dir = snapshot_download(self.model_id)
        input_location = os.path.join(model_dir, 'test_data')

        result = estimator(input_location)
        pcd = result[OutputKeys.OUTPUT]
        pcd.write('./pcd_fusion.ply')
        print('test_image_mvs_depth_estimation DONE')


if __name__ == '__main__':
    unittest.main()