add cv_casmvs_multi-view-depth-estimation_general

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11204285
dadong.gxd
2023-01-03 08:24:41 +08:00
committed by yingda.chen
parent 4698051fa5
commit 01c498cd14
15 changed files with 2344 additions and 11 deletions

View File

@@ -66,6 +66,7 @@ class Models(object):
video_object_segmentation = 'video-object-segmentation'
real_basicvsr = 'real-basicvsr'
rcp_sceneflow_estimation = 'rcp-sceneflow-estimation'
image_casmvs_depth_estimation = 'image-casmvs-depth-estimation'
# EasyCV models
yolox = 'YOLOX'
@@ -262,6 +263,7 @@ class Pipelines(object):
video_object_segmentation = 'video-object-segmentation'
video_super_resolution = 'realbasicvsr-video-super-resolution'
pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'
image_multi_view_depth_estimation = 'image-multi-view-depth-estimation'
# nlp tasks
automatic_post_editing = 'automatic-post-editing'
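
A hedged usage sketch (not part of this diff): the registry names added above let the new pipeline be built by task. The hub model id and the input payload below are placeholders/assumptions; the exact input contract is defined by the pipeline and preprocessor files added elsewhere in this commit.

# Hedged usage sketch; model id and input path are placeholders.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

casmvs = pipeline(
    Tasks.image_multi_view_depth_estimation,
    model='damo/cv_casmvs_multi-view-depth-estimation_general')
result = casmvs('/path/to/scene')  # scene folder with images/ and a COLMAP sparse/ model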

View File

@@ -7,15 +7,15 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
face_generation, human_wholebody_keypoint, image_classification,
image_color_enhance, image_colorization, image_denoise,
image_inpainting, image_instance_segmentation,
image_panoptic_segmentation, image_portrait_enhancement,
image_reid_person, image_semantic_segmentation,
image_to_image_generation, image_to_image_translation,
language_guided_video_summarization, movie_scene_segmentation,
object_detection, pointcloud_sceneflow_estimation,
product_retrieval_embedding, realtime_object_detection,
referring_video_object_segmentation, salient_detection,
shop_segmentation, super_resolution, video_object_segmentation,
video_single_object_tracking, video_summarization,
video_super_resolution, virual_tryon)
image_mvs_depth_estimation, image_panoptic_segmentation,
image_portrait_enhancement, image_reid_person,
image_semantic_segmentation, image_to_image_generation,
image_to_image_translation, language_guided_video_summarization,
movie_scene_segmentation, object_detection,
pointcloud_sceneflow_estimation, product_retrieval_embedding,
realtime_object_detection, referring_video_object_segmentation,
salient_detection, shop_segmentation, super_resolution,
video_object_segmentation, video_single_object_tracking,
video_summarization, video_super_resolution, virual_tryon)
# yapf: enable

View File

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .casmvs_model import ImageMultiViewDepthEstimation
else:
_import_structure = {
'casmvs_model': ['ImageMultiViewDepthEstimation'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
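
A small usage note, assuming the package path matches the import added to modelscope/models/cv/__init__.py above: the lazy module keeps plain package imports cheap, and the heavy dependencies (torch, cv2) are only pulled in when the class is first accessed.

# Hedged illustration of the lazy import above: the real submodule is only
# materialized when ImageMultiViewDepthEstimation is first accessed.
from modelscope.models.cv import image_mvs_depth_estimation
model_cls = image_mvs_depth_estimation.ImageMultiViewDepthEstimation  # triggers the actual import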

View File

@@ -0,0 +1,221 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
from .module import (CostRegNet, FeatureNet, RefineNet, depth_regression,
get_depth_range_samples, homo_warping)
Align_Corners_Range = False
class DepthNet(nn.Module):
def __init__(self):
super(DepthNet, self).__init__()
def forward(self,
features,
proj_matrices,
depth_values,
num_depth,
cost_regularization,
prob_volume_init=None):
proj_matrices = torch.unbind(proj_matrices, 1)
assert len(features) == len(
proj_matrices
), 'Different number of images and projection matrices'
        assert depth_values.shape[
            1] == num_depth, 'depth_values.shape[1]:{} num_depth:{}'.format(
                depth_values.shape[1], num_depth)
num_views = len(features)
# step 1. feature extraction
# in: images; out: 32-channel feature maps
ref_feature, src_features = features[0], features[1:]
ref_proj, src_projs = proj_matrices[0], proj_matrices[1:]
# step 2. differentiable homograph, build cost volume
ref_volume = ref_feature.unsqueeze(2).repeat(1, 1, num_depth, 1, 1)
volume_sum = ref_volume
volume_sq_sum = ref_volume**2
del ref_volume
for src_fea, src_proj in zip(src_features, src_projs):
# warpped features
src_proj_new = src_proj[:, 0].clone()
src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3],
src_proj[:, 0, :3, :4])
ref_proj_new = ref_proj[:, 0].clone()
ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3],
ref_proj[:, 0, :3, :4])
warped_volume = homo_warping(src_fea, src_proj_new, ref_proj_new,
depth_values)
if self.training:
volume_sum = volume_sum + warped_volume
volume_sq_sum = volume_sq_sum + warped_volume**2
else:
# TODO: this is only a temporary solution to save memory, better way?
volume_sum += warped_volume
volume_sq_sum += warped_volume.pow_(
2) # the memory of warped_volume has been modified
del warped_volume
# aggregate multiple feature volumes by variance
volume_variance = volume_sq_sum.div_(num_views).sub_(
volume_sum.div_(num_views).pow_(2))
# step 3. cost volume regularization
cost_reg = cost_regularization(volume_variance)
prob_volume_pre = cost_reg.squeeze(1)
if prob_volume_init is not None:
prob_volume_pre += prob_volume_init
prob_volume = F.softmax(prob_volume_pre, dim=1)
depth = depth_regression(prob_volume, depth_values=depth_values)
with torch.no_grad():
# photometric confidence
prob_volume_sum4 = 4 * F.avg_pool3d(
F.pad(prob_volume.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)),
(4, 1, 1),
stride=1,
padding=0).squeeze(1)
depth_index = depth_regression(
prob_volume,
depth_values=torch.arange(
num_depth, device=prob_volume.device,
dtype=torch.float)).long()
depth_index = depth_index.clamp(min=0, max=num_depth - 1)
photometric_confidence = torch.gather(
prob_volume_sum4, 1, depth_index.unsqueeze(1)).squeeze(1)
return {
'depth': depth,
'photometric_confidence': photometric_confidence
}
class CascadeMVSNet(nn.Module):
def __init__(self,
refine=False,
ndepths=[48, 32, 8],
depth_interals_ratio=[4, 2, 1],
share_cr=False,
grad_method='detach',
arch_mode='fpn',
cr_base_chs=[8, 8, 8]):
super(CascadeMVSNet, self).__init__()
self.refine = refine
self.share_cr = share_cr
self.ndepths = ndepths
self.depth_interals_ratio = depth_interals_ratio
self.grad_method = grad_method
self.arch_mode = arch_mode
self.cr_base_chs = cr_base_chs
self.num_stage = len(ndepths)
assert len(ndepths) == len(depth_interals_ratio)
self.stage_infos = {
'stage1': {
'scale': 4.0,
},
'stage2': {
'scale': 2.0,
},
'stage3': {
'scale': 1.0,
}
}
self.feature = FeatureNet(
base_channels=8,
stride=4,
num_stage=self.num_stage,
arch_mode=self.arch_mode)
if self.share_cr:
self.cost_regularization = CostRegNet(
in_channels=self.feature.out_channels, base_channels=8)
else:
self.cost_regularization = nn.ModuleList([
CostRegNet(
in_channels=self.feature.out_channels[i],
base_channels=self.cr_base_chs[i])
for i in range(self.num_stage)
])
if self.refine:
self.refine_network = RefineNet()
self.DepthNet = DepthNet()
def forward(self, imgs, proj_matrices, depth_values):
depth_min = float(depth_values[0, 0].cpu().numpy())
depth_max = float(depth_values[0, -1].cpu().numpy())
depth_interval = (depth_max - depth_min) / depth_values.size(1)
# step 1. feature extraction
features = []
for nview_idx in range(imgs.size(1)): # imgs shape (B, N, C, H, W)
img = imgs[:, nview_idx]
features.append(self.feature(img))
outputs = {}
depth, cur_depth = None, None
for stage_idx in range(self.num_stage):
# stage feature, proj_mats, scales
features_stage = [
feat['stage{}'.format(stage_idx + 1)] for feat in features
]
proj_matrices_stage = proj_matrices['stage{}'.format(stage_idx
+ 1)]
stage_scale = self.stage_infos['stage{}'.format(stage_idx
+ 1)]['scale']
if depth is not None:
if self.grad_method == 'detach':
cur_depth = depth.detach()
else:
cur_depth = depth
cur_depth = F.interpolate(
cur_depth.unsqueeze(1), [img.shape[2], img.shape[3]],
mode='bilinear',
align_corners=Align_Corners_Range).squeeze(1)
else:
cur_depth = depth_values
depth_range_samples = get_depth_range_samples(
cur_depth=cur_depth,
ndepth=self.ndepths[stage_idx],
depth_inteval_pixel=self.depth_interals_ratio[stage_idx]
* depth_interval,
dtype=img[0].dtype,
device=img[0].device,
shape=[img.shape[0], img.shape[2], img.shape[3]],
max_depth=depth_max,
min_depth=depth_min)
outputs_stage = self.DepthNet(
features_stage,
proj_matrices_stage,
depth_values=F.interpolate(
depth_range_samples.unsqueeze(1), [
self.ndepths[stage_idx], img.shape[2]
// int(stage_scale), img.shape[3] // int(stage_scale)
],
mode='trilinear',
align_corners=Align_Corners_Range).squeeze(1),
num_depth=self.ndepths[stage_idx],
cost_regularization=self.cost_regularization
if self.share_cr else self.cost_regularization[stage_idx])
depth = outputs_stage['depth']
outputs['stage{}'.format(stage_idx + 1)] = outputs_stage
outputs.update(outputs_stage)
# depth map refinement
if self.refine:
refined_depth = self.refine_network(
torch.cat((imgs[:, 0], depth), 1))
outputs['refined_depth'] = refined_depth
return outputs
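
A minimal CPU smoke-test sketch of the cascade's input/output contract, assuming this file lands at modelscope/models/cv/image_mvs_depth_estimation/cas_mvsnet.py. The camera matrices are dummy identities and the DTU-like depth range is only illustrative; each stage's projection matrices store the 4x4 extrinsics at index 0 and the intrinsics in the top-left 3x3 of index 1.

# Hedged shape check with dummy inputs (3 views, 64x64 RGB, 192 depth hypotheses).
import torch
from modelscope.models.cv.image_mvs_depth_estimation.cas_mvsnet import CascadeMVSNet

B, N, H, W, D = 1, 3, 64, 64, 192
imgs = torch.rand(B, N, 3, H, W)
depth_values = torch.linspace(425.0, 935.0, D).unsqueeze(0)  # (B, D) hypotheses

K = torch.tensor([[100.0, 0.0, W / 2], [0.0, 100.0, H / 2], [0.0, 0.0, 1.0]])
proj = torch.zeros(B, N, 2, 4, 4)
proj[:, :, 0] = torch.eye(4)       # extrinsics (identity for the sketch)
proj[:, :, 1, :3, :3] = K          # intrinsics; scaled per stage by the real dataset
proj_matrices = {'stage1': proj, 'stage2': proj.clone(), 'stage3': proj.clone()}

net = CascadeMVSNet(refine=False, ndepths=[48, 32, 8], depth_interals_ratio=[4, 2, 1])
with torch.no_grad():
    out = net(imgs, proj_matrices, depth_values)
print(out['stage1']['depth'].shape, out['stage3']['depth'].shape)  # (1, 16, 16) (1, 64, 64)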

View File

@@ -0,0 +1,164 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import cv2
import numpy as np
import torch
from easydict import EasyDict as edict
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .cas_mvsnet import CascadeMVSNet
from .colmap2mvsnet import processing_single_scene
from .depth_filter import pcd_depth_filter
from .general_eval_dataset import MVSDataset, save_pfm
from .utils import (generate_pointcloud, numpy2torch, tensor2numpy, tocuda,
write_cam)
logger = get_logger()
@MODELS.register_module(
Tasks.image_multi_view_depth_estimation,
module_name=Models.image_casmvs_depth_estimation)
class ImageMultiViewDepthEstimation(TorchModel):
def __init__(self, model_dir: str, **kwargs):
"""str -- model file root."""
super().__init__(model_dir, **kwargs)
# build model
self.model = CascadeMVSNet(
refine=False,
ndepths=[48, 32, 8],
depth_interals_ratio=[float(d_i) for d_i in [4, 2, 1]],
share_cr=False,
cr_base_chs=[8, 8, 8],
grad_method='detach')
# load checkpoint file
ckpt_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
logger.info(f'loading model {ckpt_path}')
state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
self.model.load_state_dict(state_dict['model'], strict=True)
if torch.cuda.is_available():
self.device = 'cuda'
else:
self.device = 'cpu'
self.model.to(self.device)
self.model.eval()
logger.info(f'model init done! Device:{self.device}')
def preprocess_make_pair(self, inputs):
data = inputs['input_dir']
casmvs_inp_dir = inputs['casmvs_inp_dir']
args = edict()
args.dense_folder = data
args.save_folder = casmvs_inp_dir
args.max_d = 192
args.interval_scale = 1.06
args.theta0 = 5
args.sigma1 = 1
args.sigma2 = 10
args.model_ext = '.bin'
logger.info('preprocess of making pair data start')
processing_single_scene(args)
logger.info('preprocess of making pair data done')
def forward(self, inputs):
test_dir = os.path.dirname(inputs['casmvs_inp_dir'])
scene = os.path.basename(inputs['casmvs_inp_dir'])
test_list = [scene]
save_dir = inputs['casmvs_res_dir']
logger.info('depth estimation start')
test_dataset = MVSDataset(
test_dir,
test_list,
'test',
5,
192,
1.06,
max_h=1200,
max_w=1200,
fix_res=False)
with torch.no_grad():
for batch_idx, sample in enumerate(test_dataset):
sample = numpy2torch(sample)
                sample_cuda = tocuda(sample) if self.device == 'cuda' else sample
proj_matrices_dict = sample_cuda['proj_matrices']
proj_matrices_dict_new = {}
for k, v in proj_matrices_dict.items():
proj_matrices_dict_new[k] = v.unsqueeze(0)
outputs = self.model(sample_cuda['imgs'].unsqueeze(0),
proj_matrices_dict_new,
sample_cuda['depth_values'].unsqueeze(0))
outputs = tensor2numpy(outputs)
del sample_cuda
filenames = [sample['filename']]
cams = sample['proj_matrices']['stage{}'.format(3)].unsqueeze(
0).numpy()
imgs = sample['imgs'].unsqueeze(0).numpy()
# save depth maps and confidence maps
for filename, cam, img, depth_est, photometric_confidence in zip(
filenames, cams, imgs, outputs['depth'],
outputs['photometric_confidence']):
img = img[0] # ref view
cam = cam[0] # ref cam
depth_filename = os.path.join(
save_dir, filename.format('depth_est', '.pfm'))
confidence_filename = os.path.join(
save_dir, filename.format('confidence', '.pfm'))
cam_filename = os.path.join(
save_dir, filename.format('cams', '_cam.txt'))
img_filename = os.path.join(
save_dir, filename.format('images', '.jpg'))
ply_filename = os.path.join(
save_dir, filename.format('ply_local', '.ply'))
os.makedirs(
depth_filename.rsplit('/', 1)[0], exist_ok=True)
os.makedirs(
confidence_filename.rsplit('/', 1)[0], exist_ok=True)
os.makedirs(cam_filename.rsplit('/', 1)[0], exist_ok=True)
os.makedirs(img_filename.rsplit('/', 1)[0], exist_ok=True)
os.makedirs(ply_filename.rsplit('/', 1)[0], exist_ok=True)
# save depth maps
save_pfm(depth_filename, depth_est)
# save confidence maps
save_pfm(confidence_filename, photometric_confidence)
# save cams, img
write_cam(cam_filename, cam)
img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0,
255).astype(np.uint8)
img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
cv2.imwrite(img_filename, img_bgr)
logger.info('depth estimation end')
return inputs
def postprocess(self, inputs):
test_dir = os.path.dirname(inputs['casmvs_inp_dir'])
scene = os.path.basename(inputs['casmvs_inp_dir'])
logger.info('depth fusion start')
pcd = pcd_depth_filter(
scene, test_dir, inputs['casmvs_res_dir'], thres_view=4)
logger.info('depth fusion end')
return pcd
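
A hedged end-to-end sketch of driving this model directly; the pipeline wrapper added elsewhere in this commit normally handles directory setup, and all paths below are placeholders. The three calls mirror the methods above: COLMAP conversion, per-view depth inference, and depth fusion.

# Hedged sketch, not part of this diff; paths are placeholders.
import os
from modelscope.models.cv.image_mvs_depth_estimation.casmvs_model import \
    ImageMultiViewDepthEstimation

model = ImageMultiViewDepthEstimation(model_dir='/path/to/model_dir')  # dir holding the checkpoint
inputs = {
    'input_dir': '/path/to/scene',          # expects images/ and a COLMAP sparse/ model
    'casmvs_inp_dir': '/tmp/casmvs/scene',  # cams/, images_post/ and pair.txt are written here
    'casmvs_res_dir': '/tmp/casmvs_out',    # depth_est/, confidence/, cams/, mask/ ... go here
}
os.makedirs(inputs['casmvs_inp_dir'], exist_ok=True)
os.makedirs(inputs['casmvs_res_dir'], exist_ok=True)

model.preprocess_make_pair(inputs)  # COLMAP sparse model -> MVSNet-style cams + pair.txt
model.forward(inputs)               # per-view depth/confidence maps saved as PFM files
pcd = model.postprocess(inputs)     # consistency filtering -> fused colored point cloud
pcd.write('/tmp/scene_fused.ply')   # plyfile.PlyData accepts a path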

View File

@@ -0,0 +1,472 @@
# The implementation is borrowed from https://github.com/YoYo000/MVSNet. The model-reading utilities below are adapted from COLMAP.
from __future__ import print_function
import collections
import multiprocessing as mp
import os
import shutil
import struct
from functools import partial
import cv2
import numpy as np
# ============================ read_model.py ============================#
CameraModel = collections.namedtuple('CameraModel',
['model_id', 'model_name', 'num_params'])
Camera = collections.namedtuple('Camera',
['id', 'model', 'width', 'height', 'params'])
BaseImage = collections.namedtuple(
'Image', ['id', 'qvec', 'tvec', 'camera_id', 'name', 'xys', 'point3D_ids'])
Point3D = collections.namedtuple(
'Point3D', ['id', 'xyz', 'rgb', 'error', 'image_ids', 'point2D_idxs'])
class Image(BaseImage):
def qvec2rotmat(self):
return qvec2rotmat(self.qvec)
CAMERA_MODELS = {
CameraModel(model_id=0, model_name='SIMPLE_PINHOLE', num_params=3),
CameraModel(model_id=1, model_name='PINHOLE', num_params=4),
CameraModel(model_id=2, model_name='SIMPLE_RADIAL', num_params=4),
CameraModel(model_id=3, model_name='RADIAL', num_params=5),
CameraModel(model_id=4, model_name='OPENCV', num_params=8),
CameraModel(model_id=5, model_name='OPENCV_FISHEYE', num_params=8),
CameraModel(model_id=6, model_name='FULL_OPENCV', num_params=12),
CameraModel(model_id=7, model_name='FOV', num_params=5),
CameraModel(model_id=8, model_name='SIMPLE_RADIAL_FISHEYE', num_params=4),
CameraModel(model_id=9, model_name='RADIAL_FISHEYE', num_params=5),
CameraModel(model_id=10, model_name='THIN_PRISM_FISHEYE', num_params=12)
}
CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
for camera_model in CAMERA_MODELS])
def read_next_bytes(fid,
num_bytes,
format_char_sequence,
endian_character='<'):
"""Read and unpack the next bytes from a binary file.
    :param fid: file handle opened in binary read mode.
:param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
:param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
:param endian_character: Any of {@, =, <, >, !}
:return: Tuple of read and unpacked values.
"""
data = fid.read(num_bytes)
return struct.unpack(endian_character + format_char_sequence, data)
def read_cameras_text(path):
cameras = {}
with open(path, 'r', encoding='utf-8') as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != '#':
elems = line.split()
camera_id = int(elems[0])
model = elems[1]
width = int(elems[2])
height = int(elems[3])
params = np.array(tuple(map(float, elems[4:])))
cameras[camera_id] = Camera(
id=camera_id,
model=model,
width=width,
height=height,
params=params)
return cameras
def read_cameras_binary(path_to_model_file):
cameras = {}
with open(path_to_model_file, 'rb') as fid:
num_cameras = read_next_bytes(fid, 8, 'Q')[0]
for camera_line_index in range(num_cameras):
camera_properties = read_next_bytes(
fid, num_bytes=24, format_char_sequence='iiQQ')
camera_id = camera_properties[0]
model_id = camera_properties[1]
model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
width = camera_properties[2]
height = camera_properties[3]
num_params = CAMERA_MODEL_IDS[model_id].num_params
params = read_next_bytes(
fid,
num_bytes=8 * num_params,
format_char_sequence='d' * num_params)
cameras[camera_id] = Camera(
id=camera_id,
model=model_name,
width=width,
height=height,
params=np.array(params))
assert len(cameras) == num_cameras
return cameras
def read_images_text(path):
images = {}
with open(path, 'r', encoding='utf-8') as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != '#':
elems = line.split()
image_id = int(elems[0])
qvec = np.array(tuple(map(float, elems[1:5])))
tvec = np.array(tuple(map(float, elems[5:8])))
camera_id = int(elems[8])
image_name = elems[9]
elems = fid.readline().split()
xys = np.column_stack([
tuple(map(float, elems[0::3])),
tuple(map(float, elems[1::3]))
])
point3D_ids = np.array(tuple(map(int, elems[2::3])))
images[image_id] = Image(
id=image_id,
qvec=qvec,
tvec=tvec,
camera_id=camera_id,
name=image_name,
xys=xys,
point3D_ids=point3D_ids)
return images
def read_images_binary(path_to_model_file):
images = {}
with open(path_to_model_file, 'rb') as fid:
num_reg_images = read_next_bytes(fid, 8, 'Q')[0]
for image_index in range(num_reg_images):
binary_image_properties = read_next_bytes(
fid, num_bytes=64, format_char_sequence='idddddddi')
image_id = binary_image_properties[0]
qvec = np.array(binary_image_properties[1:5])
tvec = np.array(binary_image_properties[5:8])
camera_id = binary_image_properties[8]
image_name = ''
current_char = read_next_bytes(fid, 1, 'c')[0]
while current_char != b'\x00': # look for the ASCII 0 entry
image_name += current_char.decode('utf-8')
current_char = read_next_bytes(fid, 1, 'c')[0]
num_points2D = read_next_bytes(
fid, num_bytes=8, format_char_sequence='Q')[0]
x_y_id_s = read_next_bytes(
fid,
num_bytes=24 * num_points2D,
format_char_sequence='ddq' * num_points2D)
xys = np.column_stack([
tuple(map(float, x_y_id_s[0::3])),
tuple(map(float, x_y_id_s[1::3]))
])
point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
images[image_id] = Image(
id=image_id,
qvec=qvec,
tvec=tvec,
camera_id=camera_id,
name=image_name,
xys=xys,
point3D_ids=point3D_ids)
return images
def read_points3D_text(path):
points3D = {}
with open(path, 'r', encoding='utf-8') as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != '#':
elems = line.split()
point3D_id = int(elems[0])
xyz = np.array(tuple(map(float, elems[1:4])))
rgb = np.array(tuple(map(int, elems[4:7])))
error = float(elems[7])
image_ids = np.array(tuple(map(int, elems[8::2])))
point2D_idxs = np.array(tuple(map(int, elems[9::2])))
points3D[point3D_id] = Point3D(
id=point3D_id,
xyz=xyz,
rgb=rgb,
error=error,
image_ids=image_ids,
point2D_idxs=point2D_idxs)
return points3D
def read_points3d_binary(path_to_model_file):
points3D = {}
with open(path_to_model_file, 'rb') as fid:
num_points = read_next_bytes(fid, 8, 'Q')[0]
for point_line_index in range(num_points):
binary_point_line_properties = read_next_bytes(
fid, num_bytes=43, format_char_sequence='QdddBBBd')
point3D_id = binary_point_line_properties[0]
xyz = np.array(binary_point_line_properties[1:4])
rgb = np.array(binary_point_line_properties[4:7])
error = np.array(binary_point_line_properties[7])
track_length = read_next_bytes(
fid, num_bytes=8, format_char_sequence='Q')[0]
track_elems = read_next_bytes(
fid,
num_bytes=8 * track_length,
format_char_sequence='ii' * track_length)
image_ids = np.array(tuple(map(int, track_elems[0::2])))
point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
points3D[point3D_id] = Point3D(
id=point3D_id,
xyz=xyz,
rgb=rgb,
error=error,
image_ids=image_ids,
point2D_idxs=point2D_idxs)
return points3D
def read_model(path, ext):
if ext == '.txt':
cameras = read_cameras_text(os.path.join(path, 'cameras' + ext))
images = read_images_text(os.path.join(path, 'images' + ext))
points3D = read_points3D_text(os.path.join(path, 'points3D') + ext)
else:
cameras = read_cameras_binary(os.path.join(path, 'cameras' + ext))
images = read_images_binary(os.path.join(path, 'images' + ext))
points3D = read_points3d_binary(os.path.join(path, 'points3D') + ext)
return cameras, images, points3D
def qvec2rotmat(qvec):
return np.array([
[
1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]
], # noqa
[
2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]
], # noqa
[
2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
1 - 2 * qvec[1]**2 - 2 * qvec[2]**2
]
]) # noqa
def rotmat2qvec(R):
Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
K = np.array(
[[Rxx - Ryy - Rzz, 0, 0, 0], [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
[Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
[Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 # noqa
eigvals, eigvecs = np.linalg.eigh(K)
qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
if qvec[0] < 0:
qvec *= -1
return qvec
def calc_score(inputs, images, points3d, extrinsic, args):
i, j = inputs
id_i = images[i + 1].point3D_ids
id_j = images[j + 1].point3D_ids
id_intersect = [it for it in id_i if it in id_j]
cam_center_i = -np.matmul(extrinsic[i + 1][:3, :3].transpose(),
extrinsic[i + 1][:3, 3:4])[:, 0]
cam_center_j = -np.matmul(extrinsic[j + 1][:3, :3].transpose(),
extrinsic[j + 1][:3, 3:4])[:, 0]
score = 0
for pid in id_intersect:
if pid == -1:
continue
p = points3d[pid].xyz
theta = (180 / np.pi) * np.arccos(
np.dot(cam_center_i - p, cam_center_j - p)
/ np.linalg.norm(cam_center_i - p)
/ np.linalg.norm(cam_center_j - p))
tmp_value = (
2 * # noqa
(args.sigma1 if theta <= args.theta0 else args.sigma2)**2)
score += np.exp(-(theta - args.theta0) * # noqa
(theta - args.theta0) / tmp_value)
return i, j, score
def processing_single_scene(args):
image_dir = os.path.join(args.dense_folder, 'images')
model_dir = os.path.join(args.dense_folder, 'sparse')
cam_dir = os.path.join(args.save_folder, 'cams')
image_converted_dir = os.path.join(args.save_folder, 'images_post')
if os.path.exists(image_converted_dir):
shutil.rmtree(image_converted_dir)
os.makedirs(image_converted_dir)
if os.path.exists(cam_dir):
shutil.rmtree(cam_dir)
cameras, images, points3d = read_model(model_dir, args.model_ext)
num_images = len(list(images.items()))
param_type = {
'SIMPLE_PINHOLE': ['f', 'cx', 'cy'],
'PINHOLE': ['fx', 'fy', 'cx', 'cy'],
'SIMPLE_RADIAL': ['f', 'cx', 'cy', 'k'],
'SIMPLE_RADIAL_FISHEYE': ['f', 'cx', 'cy', 'k'],
'RADIAL': ['f', 'cx', 'cy', 'k1', 'k2'],
'RADIAL_FISHEYE': ['f', 'cx', 'cy', 'k1', 'k2'],
'OPENCV': ['fx', 'fy', 'cx', 'cy', 'k1', 'k2', 'p1', 'p2'],
'OPENCV_FISHEYE': ['fx', 'fy', 'cx', 'cy', 'k1', 'k2', 'k3', 'k4'],
'FULL_OPENCV': [
'fx', 'fy', 'cx', 'cy', 'k1', 'k2', 'p1', 'p2', 'k3', 'k4', 'k5',
'k6'
],
'FOV': ['fx', 'fy', 'cx', 'cy', 'omega'],
'THIN_PRISM_FISHEYE': [
'fx', 'fy', 'cx', 'cy', 'k1', 'k2', 'p1', 'p2', 'k3', 'k4', 'sx1',
'sy1'
]
}
# intrinsic
intrinsic = {}
for camera_id, cam in cameras.items():
params_dict = {
key: value
for key, value in zip(param_type[cam.model], cam.params)
}
if 'f' in param_type[cam.model]:
params_dict['fx'] = params_dict['f']
params_dict['fy'] = params_dict['f']
i = np.array([[params_dict['fx'], 0, params_dict['cx']],
[0, params_dict['fy'], params_dict['cy']], [0, 0, 1]])
intrinsic[camera_id] = i
new_images = {}
for i, image_id in enumerate(sorted(images.keys())):
new_images[i + 1] = images[image_id]
images = new_images
# extrinsic
extrinsic = {}
for image_id, image in images.items():
e = np.zeros((4, 4))
e[:3, :3] = qvec2rotmat(image.qvec)
e[:3, 3] = image.tvec
e[3, 3] = 1
extrinsic[image_id] = e
# depth range and interval
depth_ranges = {}
for i in range(num_images):
zs = []
for p3d_id in images[i + 1].point3D_ids:
if p3d_id == -1:
continue
transformed = np.matmul(extrinsic[i + 1], [
points3d[p3d_id].xyz[0], points3d[p3d_id].xyz[1],
points3d[p3d_id].xyz[2], 1
])
            zs.append(float(transformed[2]))
zs_sorted = sorted(zs)
# relaxed depth range
max_ratio = 0.1
min_ratio = 0.03
num_max = max(5, int(len(zs) * max_ratio))
num_min = max(1, int(len(zs) * min_ratio))
depth_min = 1.0 * sum(zs_sorted[:num_min]) / len(zs_sorted[:num_min])
depth_max = 1.0 * sum(zs_sorted[-num_max:]) / len(zs_sorted[-num_max:])
if args.max_d == 0:
image_int = intrinsic[images[i + 1].camera_id]
image_ext = extrinsic[i + 1]
image_r = image_ext[0:3, 0:3]
image_t = image_ext[0:3, 3]
p1 = [image_int[0, 2], image_int[1, 2], 1]
p2 = [image_int[0, 2] + 1, image_int[1, 2], 1]
P1 = np.matmul(np.linalg.inv(image_int), p1) * depth_min
P1 = np.matmul(np.linalg.inv(image_r), (P1 - image_t))
P2 = np.matmul(np.linalg.inv(image_int), p2) * depth_min
P2 = np.matmul(np.linalg.inv(image_r), (P2 - image_t))
depth_num = (1 / depth_min - 1 / depth_max) / (
1 / depth_min - 1 / (depth_min + np.linalg.norm(P2 - P1)))
else:
depth_num = args.max_d
depth_interval = (depth_max - depth_min) / (depth_num
- 1) / args.interval_scale
depth_ranges[i + 1] = (depth_min, depth_interval, depth_num, depth_max)
# view selection
score = np.zeros((len(images), len(images)))
queue = []
for i in range(len(images)):
for j in range(i + 1, len(images)):
queue.append((i, j))
p = mp.Pool(processes=mp.cpu_count())
func = partial(
calc_score,
images=images,
points3d=points3d,
args=args,
extrinsic=extrinsic)
result = p.map(func, queue)
for i, j, s in result:
score[i, j] = s
score[j, i] = s
view_sel = []
for i in range(len(images)):
sorted_score = np.argsort(score[i])[::-1]
view_sel.append([(k, score[i, k]) for k in sorted_score[:10]])
# write
os.makedirs(cam_dir, exist_ok=True)
for i in range(num_images):
with open(os.path.join(cam_dir, '%08d_cam.txt' % i), 'w') as f:
f.write('extrinsic\n')
for j in range(4):
for k in range(4):
f.write(str(extrinsic[i + 1][j, k]) + ' ')
f.write('\n')
f.write('\nintrinsic\n')
for j in range(3):
for k in range(3):
f.write(
str(intrinsic[images[i + 1].camera_id][j, k]) + ' ')
f.write('\n')
f.write('\n%f %f %f %f\n' %
(depth_ranges[i + 1][0], depth_ranges[i + 1][1],
depth_ranges[i + 1][2], depth_ranges[i + 1][3]))
with open(os.path.join(args.save_folder, 'pair.txt'), 'w') as f:
f.write('%d\n' % len(images))
for i, sorted_score in enumerate(view_sel):
f.write('%d\n%d ' % (i, len(sorted_score)))
for image_id, s in sorted_score:
f.write('%d %f ' % (image_id, s))
f.write('\n')
# convert to jpg
for i in range(num_images):
img_path = os.path.join(image_dir, images[i + 1].name)
if not img_path.endswith('.jpg'):
img = cv2.imread(img_path)
cv2.imwrite(os.path.join(image_converted_dir, '%08d.jpg' % i), img)
else:
shutil.copyfile(
os.path.join(image_dir, images[i + 1].name),
os.path.join(image_converted_dir, '%08d.jpg' % i))
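
For orientation, the pairwise view-selection score computed by calc_score above reduces, per 3D point shared by images i and j, to a piecewise Gaussian over the triangulation angle theta (in degrees) at that point. A tiny standalone check with the defaults passed in by preprocess_make_pair (theta0=5, sigma1=1, sigma2=10):

# s(i, j) += exp(-(theta - theta0)^2 / (2 * sigma^2)), sigma = sigma1 if theta <= theta0 else sigma2
import numpy as np

def pair_score_term(theta_deg, theta0=5.0, sigma1=1.0, sigma2=10.0):
    sigma = sigma1 if theta_deg <= theta0 else sigma2
    return np.exp(-(theta_deg - theta0) ** 2 / (2 * sigma ** 2))

print(pair_score_term(5.0))   # 1.0       -> triangulation angles near theta0 score highest
print(pair_score_term(0.5))   # ~4.0e-05  -> near-degenerate baselines are penalized sharply
print(pair_score_term(45.0))  # ~3.4e-04  -> wide baselines decay more gently (sigma2 > sigma1)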

View File

@@ -0,0 +1,249 @@
# The implementation here is modified based on https://github.com/xy-guo/MVSNet_pytorch
import os
import cv2
import numpy as np
from PIL import Image
from plyfile import PlyData, PlyElement
from .general_eval_dataset import read_pfm
# read intrinsics and extrinsics
def read_camera_parameters(filename):
with open(filename) as f:
lines = f.readlines()
lines = [line.rstrip() for line in lines]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(
' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
# intrinsics: line [7-10), 3x3 matrix
intrinsics = np.fromstring(
' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
# assume the feature is 1/4 of the original image size
# intrinsics[:2, :] /= 4
return intrinsics, extrinsics
# read an image
def read_img(filename):
img = Image.open(filename)
# scale 0~255 to 0~1
np_img = np.array(img, dtype=np.float32) / 255.
return np_img
# read a binary mask
def read_mask(filename):
return read_img(filename) > 0.5
# save a binary mask
def save_mask(filename, mask):
    assert mask.dtype == np.bool_
mask = mask.astype(np.uint8) * 255
Image.fromarray(mask).save(filename)
# read a pair file, [(ref_view1, [src_view1-1, ...]), (ref_view2, [src_view2-1, ...]), ...]
def read_pair_file(filename):
data = []
with open(filename) as f:
num_viewpoint = int(f.readline())
# 49 viewpoints
for view_idx in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
if len(src_views) > 0:
data.append((ref_view, src_views))
return data
# project the reference point cloud into the source view, then project back
def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src,
intrinsics_src, extrinsics_src):
width, height = depth_ref.shape[1], depth_ref.shape[0]
# step1. project reference pixels to the source view
# reference view x, y
x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))
x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1])
# reference 3D space
xyz_ref = np.matmul(
np.linalg.inv(intrinsics_ref),
np.vstack(
(x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1]))
# source 3D space
xyz_src = np.matmul(
np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)),
np.vstack((xyz_ref, np.ones_like(x_ref))))[:3]
# source view x, y
K_xyz_src = np.matmul(intrinsics_src, xyz_src)
xy_src = K_xyz_src[:2] / K_xyz_src[2:3]
# step2. reproject the source view points with source view depth estimation
# find the depth estimation of the source view
x_src = xy_src[0].reshape([height, width]).astype(np.float32)
y_src = xy_src[1].reshape([height, width]).astype(np.float32)
sampled_depth_src = cv2.remap(
depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR)
# source 3D space
    # NOTE that we should use the sampled source-view depth here to project back
xyz_src = np.matmul(
np.linalg.inv(intrinsics_src),
np.vstack(
(xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1]))
# reference 3D space
xyz_reprojected = np.matmul(
np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)),
np.vstack((xyz_src, np.ones_like(x_ref))))[:3]
# source view x, y, depth
depth_reprojected = xyz_reprojected[2].reshape([height,
width]).astype(np.float32)
K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected)
xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3]
x_reprojected = xy_reprojected[0].reshape([height,
width]).astype(np.float32)
y_reprojected = xy_reprojected[1].reshape([height,
width]).astype(np.float32)
return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src
def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref,
depth_src, intrinsics_src, extrinsics_src):
width, height = depth_ref.shape[1], depth_ref.shape[0]
x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))
depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(
depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src,
extrinsics_src)
# check |p_reproj-p_1| < 1
dist = np.sqrt((x2d_reprojected - x_ref)**2 + (y2d_reprojected - y_ref)**2)
# check |d_reproj-d_1| / d_1 < 0.01
depth_diff = np.abs(depth_reprojected - depth_ref)
relative_depth_diff = depth_diff / depth_ref
mask = np.logical_and(dist < 1, relative_depth_diff < 0.01)
depth_reprojected[~mask] = 0
return mask, depth_reprojected, x2d_src, y2d_src
def filter_depth(pair_folder, scan_folder, out_folder, thres_view):
# the pair file
pair_file = os.path.join(pair_folder, 'pair.txt')
# for the final point cloud
vertexs = []
vertex_colors = []
pair_data = read_pair_file(pair_file)
# for each reference view and the corresponding source views
for ref_view, src_views in pair_data:
# src_views = src_views[:args.num_view]
# load the camera parameters
ref_intrinsics, ref_extrinsics = read_camera_parameters(
os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(ref_view)))
# load the reference image
ref_img = read_img(
os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view)))
# load the estimated depth of the reference view
ref_depth_est = read_pfm(
os.path.join(out_folder,
'depth_est/{:0>8}.pfm'.format(ref_view)))[0]
# load the photometric mask of the reference view
confidence = read_pfm(
os.path.join(out_folder,
'confidence/{:0>8}.pfm'.format(ref_view)))[0]
photo_mask = confidence > 0.9
all_srcview_depth_ests = []
all_srcview_x = []
all_srcview_y = []
all_srcview_geomask = []
# compute the geometric mask
geo_mask_sum = 0
for src_view in src_views:
# camera parameters of the source view
src_intrinsics, src_extrinsics = read_camera_parameters(
os.path.join(scan_folder,
'cams/{:0>8}_cam.txt'.format(src_view)))
# the estimated depth of the source view
src_depth_est = read_pfm(
os.path.join(out_folder,
'depth_est/{:0>8}.pfm'.format(src_view)))[0]
geo_mask, depth_reprojected, x2d_src, y2d_src = check_geometric_consistency(
ref_depth_est, ref_intrinsics, ref_extrinsics, src_depth_est,
src_intrinsics, src_extrinsics)
geo_mask_sum += geo_mask.astype(np.int32)
all_srcview_depth_ests.append(depth_reprojected)
all_srcview_x.append(x2d_src)
all_srcview_y.append(y2d_src)
all_srcview_geomask.append(geo_mask)
depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (
geo_mask_sum + 1)
        # keep pixels that are consistent in at least thres_view source views
geo_mask = geo_mask_sum >= thres_view
final_mask = np.logical_and(photo_mask, geo_mask)
os.makedirs(os.path.join(out_folder, 'mask'), exist_ok=True)
save_mask(
os.path.join(out_folder, 'mask/{:0>8}_photo.png'.format(ref_view)),
photo_mask)
save_mask(
os.path.join(out_folder, 'mask/{:0>8}_geo.png'.format(ref_view)),
geo_mask)
save_mask(
os.path.join(out_folder, 'mask/{:0>8}_final.png'.format(ref_view)),
final_mask)
height, width = depth_est_averaged.shape[:2]
x, y = np.meshgrid(np.arange(0, width), np.arange(0, height))
valid_points = final_mask
x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[
valid_points]
color = ref_img[valid_points]
xyz_ref = np.matmul(
np.linalg.inv(ref_intrinsics),
np.vstack((x, y, np.ones_like(x))) * depth)
xyz_world = np.matmul(
np.linalg.inv(ref_extrinsics), np.vstack(
(xyz_ref, np.ones_like(x))))[:3]
vertexs.append(xyz_world.transpose((1, 0)))
vertex_colors.append((color * 255).astype(np.uint8))
vertexs = np.concatenate(vertexs, axis=0)
vertex_colors = np.concatenate(vertex_colors, axis=0)
vertexs = np.array([tuple(v) for v in vertexs],
dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
vertex_colors = np.array([tuple(v) for v in vertex_colors],
dtype=[('red', 'u1'), ('green', 'u1'),
('blue', 'u1')])
vertex_all = np.empty(
len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr)
for prop in vertexs.dtype.names:
vertex_all[prop] = vertexs[prop]
for prop in vertex_colors.dtype.names:
vertex_all[prop] = vertex_colors[prop]
el = PlyElement.describe(vertex_all, 'vertex')
# PlyData([el]).write(plyfilename)
pcd = PlyData([el])
return pcd
def pcd_depth_filter(scene, test_dir, save_dir, thres_view):
old_scene_folder = os.path.join(test_dir, scene)
new_scene_folder = os.path.join(save_dir, scene)
out_folder = os.path.join(save_dir, scene)
pcd = filter_depth(old_scene_folder, new_scene_folder, out_folder,
thres_view)
return pcd
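
A hedged standalone usage of the fusion step; the directory layout must match what the model's forward pass writes (cams/, images/, depth_est/ and confidence/ under <save_dir>/<scene>, with pair.txt under <test_dir>/<scene>), and the paths below are placeholders.

# A pixel survives if its confidence exceeds 0.9 and it is geometrically consistent
# (reprojection error < 1 px, relative depth difference < 1%) in >= thres_view source views.
from modelscope.models.cv.image_mvs_depth_estimation.depth_filter import pcd_depth_filter

pcd = pcd_depth_filter(
    scene='scene', test_dir='/tmp/casmvs', save_dir='/tmp/casmvs_out', thres_view=4)
pcd.write('/tmp/scene_fused.ply')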

View File

@@ -0,0 +1,284 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import re
import sys
import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
def read_pfm(filename):
file = open(filename, 'rb')
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().decode('utf-8').rstrip()
if header == 'PF':
color = True
elif header == 'Pf':
color = False
else:
raise Exception('Not a PFM file.')
dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = '<'
scale = -scale
else:
endian = '>' # big-endian
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
file.close()
return data, scale
def save_pfm(filename, image, scale=1):
file = open(filename, 'wb')
color = None
image = np.flipud(image)
if image.dtype.name != 'float32':
raise Exception('Image dtype must be float32.')
if len(image.shape) == 3 and image.shape[2] == 3: # color image
color = True
elif len(image.shape) == 2 or len(
image.shape) == 3 and image.shape[2] == 1: # greyscale
color = False
else:
raise Exception(
'Image must have H x W x 3, H x W x 1 or H x W dimensions.')
file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8'))
file.write('{} {}\n'.format(image.shape[1],
image.shape[0]).encode('utf-8'))
endian = image.dtype.byteorder
if endian == '<' or endian == '=' and sys.byteorder == 'little':
scale = -scale
file.write(('%f\n' % scale).encode('utf-8'))
image.tofile(file)
file.close()
S_H, S_W = 0, 0
class MVSDataset(Dataset):
def __init__(self,
datapath,
listfile,
mode,
nviews,
ndepths=192,
interval_scale=1.06,
**kwargs):
super(MVSDataset, self).__init__()
self.datapath = datapath
self.listfile = listfile
self.mode = mode
self.nviews = nviews
self.ndepths = ndepths
self.interval_scale = interval_scale
self.max_h, self.max_w = kwargs['max_h'], kwargs['max_w']
self.fix_res = kwargs.get(
'fix_res', False) # whether to fix the resolution of input image.
self.fix_wh = False
assert self.mode == 'test'
self.metas = self.build_list()
def build_list(self):
metas = []
scans = self.listfile
interval_scale_dict = {}
# scans
for scan in scans:
# determine the interval scale of each scene. default is 1.06
if isinstance(self.interval_scale, float):
interval_scale_dict[scan] = self.interval_scale
else:
interval_scale_dict[scan] = self.interval_scale[scan]
pair_file = '{}/pair.txt'.format(scan)
# read the pair file
with open(os.path.join(self.datapath, pair_file)) as f:
num_viewpoint = int(f.readline())
# viewpoints
for view_idx in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [
int(x) for x in f.readline().rstrip().split()[1::2]
]
# filter by no src view and fill to nviews
if len(src_views) > 0:
if len(src_views) < self.nviews:
src_views += [src_views[0]] * (
self.nviews - len(src_views))
metas.append((scan, ref_view, src_views, scan))
self.interval_scale = interval_scale_dict
return metas
def __len__(self):
return len(self.metas)
def read_cam_file(self, filename, interval_scale):
with open(filename) as f:
lines = f.readlines()
lines = [line.rstrip() for line in lines]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(
' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
# intrinsics: line [7-10), 3x3 matrix
intrinsics = np.fromstring(
' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
intrinsics[:2, :] /= 4.0
# depth_min & depth_interval: line 11
depth_min = float(lines[11].split()[0])
depth_interval = float(lines[11].split()[1])
if len(lines[11].split()) >= 3:
num_depth = lines[11].split()[2]
depth_max = depth_min + int(float(num_depth)) * depth_interval
depth_interval = (depth_max - depth_min) / self.ndepths
depth_interval *= interval_scale
return intrinsics, extrinsics, depth_min, depth_interval
def read_img(self, filename):
img = Image.open(filename)
# scale 0~255 to 0~1
np_img = np.array(img, dtype=np.float32) / 255.
return np_img
def read_depth(self, filename):
# read pfm depth file
return np.array(read_pfm(filename)[0], dtype=np.float32)
def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=32):
h, w = img.shape[:2]
if h > max_h or w > max_w:
scale = 1.0 * max_h / h
if scale * w > max_w:
scale = 1.0 * max_w / w
new_w, new_h = scale * w // base * base, scale * h // base * base
else:
new_w, new_h = 1.0 * w // base * base, 1.0 * h // base * base
scale_w = 1.0 * new_w / w
scale_h = 1.0 * new_h / h
intrinsics[0, :] *= scale_w
intrinsics[1, :] *= scale_h
img = cv2.resize(img, (int(new_w), int(new_h)))
return img, intrinsics
def __getitem__(self, idx):
global S_H, S_W
meta = self.metas[idx]
scan, ref_view, src_views, scene_name = meta
# use only the reference view and first nviews-1 source views
view_ids = [ref_view] + src_views[:self.nviews - 1]
imgs = []
depth_values = None
proj_matrices = []
for i, vid in enumerate(view_ids):
img_filename = os.path.join(
self.datapath, '{}/images_post/{:0>8}.jpg'.format(scan, vid))
if not os.path.exists(img_filename):
img_filename = os.path.join(
self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid))
proj_mat_filename = os.path.join(
self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid))
img = self.read_img(img_filename)
intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(
proj_mat_filename,
interval_scale=self.interval_scale[scene_name])
# scale input
img, intrinsics = self.scale_mvs_input(img, intrinsics, self.max_w,
self.max_h)
if self.fix_res:
# using the same standard height or width in entire scene.
S_H, S_W = img.shape[:2]
self.fix_res = False
self.fix_wh = True
if i == 0:
if not self.fix_wh:
# using the same standard height or width in each nviews.
S_H, S_W = img.shape[:2]
# resize to standard height or width
c_h, c_w = img.shape[:2]
if (c_h != S_H) or (c_w != S_W):
scale_h = 1.0 * S_H / c_h
scale_w = 1.0 * S_W / c_w
img = cv2.resize(img, (S_W, S_H))
intrinsics[0, :] *= scale_w
intrinsics[1, :] *= scale_h
imgs.append(img)
# extrinsics, intrinsics
proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) #
proj_mat[0, :4, :4] = extrinsics
proj_mat[1, :3, :3] = intrinsics
proj_matrices.append(proj_mat)
if i == 0: # reference view
depth_values = np.arange(
depth_min,
depth_interval * (self.ndepths - 0.5) + depth_min,
depth_interval,
dtype=np.float32)
# all
imgs = np.stack(imgs).transpose([0, 3, 1, 2])
proj_matrices = np.stack(proj_matrices)
stage2_pjmats = proj_matrices.copy()
stage2_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2
stage3_pjmats = proj_matrices.copy()
stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4
proj_matrices_ms = {
'stage1': proj_matrices,
'stage2': stage2_pjmats,
'stage3': stage3_pjmats
}
return {
'imgs': imgs,
'proj_matrices': proj_matrices_ms,
'depth_values': depth_values,
'filename': scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + '{}'
}
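
A hedged shape check of one dataset sample, assuming a scene already converted by the COLMAP preprocessing above (datapath/<scan>/ containing images_post/, cams/ and pair.txt); the path is a placeholder.

# Hedged sketch mirroring the call made by the model's forward pass.
from modelscope.models.cv.image_mvs_depth_estimation.general_eval_dataset import MVSDataset

ds = MVSDataset('/tmp/casmvs', ['scene'], 'test', 5, 192, 1.06,
                max_h=1200, max_w=1200, fix_res=False)
sample = ds[0]
print(sample['imgs'].shape)                     # (nviews, 3, H, W), float32 in [0, 1]
print(sample['proj_matrices']['stage1'].shape)  # (nviews, 2, 4, 4): [0]=extrinsics, [1]=intrinsics/4
print(sample['depth_values'].shape)             # (ndepths,) hypotheses from depth_min upward
print(sample['filename'])                       # 'scene/{}/000000xx{}' template used when saving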

View File

@@ -0,0 +1,678 @@
# The implementation here is modified based on https://github.com/xy-guo/MVSNet_pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
def init_bn(module):
if module.weight is not None:
nn.init.ones_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
return
def init_uniform(module, init_method):
if module.weight is not None:
if init_method == 'kaiming':
nn.init.kaiming_uniform_(module.weight)
elif init_method == 'xavier':
nn.init.xavier_uniform_(module.weight)
return
class Conv2d(nn.Module):
"""Applies a 2D convolution (optionally with batch normalization and relu activation)
over an input signal composed of several input planes.
Attributes:
conv (nn.Module): convolution module
bn (nn.Module): batch normalization module
relu (bool): whether to activate by relu
Notes:
        Default momentum for batch normalization is set to 0.1.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
relu=True,
bn=True,
bn_momentum=0.1,
init_method='xavier',
**kwargs):
super(Conv2d, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
bias=(not bn),
**kwargs)
self.kernel_size = kernel_size
self.stride = stride
self.bn = nn.BatchNorm2d(
out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu:
x = F.relu(x, inplace=True)
return x
def init_weights(self, init_method):
"""default initialization"""
init_uniform(self.conv, init_method)
if self.bn is not None:
init_bn(self.bn)
class Deconv2d(nn.Module):
"""Applies a 2D deconvolution (optionally with batch normalization and relu activation)
over an input signal composed of several input planes.
Attributes:
conv (nn.Module): convolution module
bn (nn.Module): batch normalization module
relu (bool): whether to activate by relu
Notes:
        Default momentum for batch normalization is set to 0.1.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
relu=True,
bn=True,
bn_momentum=0.1,
init_method='xavier',
**kwargs):
super(Deconv2d, self).__init__()
self.out_channels = out_channels
assert stride in [1, 2]
self.stride = stride
self.conv = nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
bias=(not bn),
**kwargs)
self.bn = nn.BatchNorm2d(
out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
    def forward(self, x):
        y = self.conv(x)
        if self.stride == 2:
            h, w = list(x.size())[2:]
            y = y[:, :, :2 * h, :2 * w].contiguous()
        if self.bn is not None:
            y = self.bn(y)
        if self.relu:
            y = F.relu(y, inplace=True)
        return y
def init_weights(self, init_method):
"""default initialization"""
init_uniform(self.conv, init_method)
if self.bn is not None:
init_bn(self.bn)
class Conv3d(nn.Module):
"""Applies a 3D convolution (optionally with batch normalization and relu activation)
over an input signal composed of several input planes.
Attributes:
conv (nn.Module): convolution module
bn (nn.Module): batch normalization module
relu (bool): whether to activate by relu
Notes:
        Default momentum for batch normalization is set to 0.1.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
relu=True,
bn=True,
bn_momentum=0.1,
init_method='xavier',
**kwargs):
super(Conv3d, self).__init__()
self.out_channels = out_channels
self.kernel_size = kernel_size
assert stride in [1, 2]
self.stride = stride
self.conv = nn.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
bias=(not bn),
**kwargs)
self.bn = nn.BatchNorm3d(
out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu:
x = F.relu(x, inplace=True)
return x
def init_weights(self, init_method):
"""default initialization"""
init_uniform(self.conv, init_method)
if self.bn is not None:
init_bn(self.bn)
class Deconv3d(nn.Module):
"""Applies a 3D deconvolution (optionally with batch normalization and relu activation)
over an input signal composed of several input planes.
Attributes:
conv (nn.Module): convolution module
bn (nn.Module): batch normalization module
relu (bool): whether to activate by relu
Notes:
        Default momentum for batch normalization is set to 0.1.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
relu=True,
bn=True,
bn_momentum=0.1,
init_method='xavier',
**kwargs):
super(Deconv3d, self).__init__()
self.out_channels = out_channels
assert stride in [1, 2]
self.stride = stride
self.conv = nn.ConvTranspose3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
bias=(not bn),
**kwargs)
self.bn = nn.BatchNorm3d(
out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
    def forward(self, x):
        y = self.conv(x)
        if self.bn is not None:
            y = self.bn(y)
        if self.relu:
            y = F.relu(y, inplace=True)
        return y
def init_weights(self, init_method):
"""default initialization"""
init_uniform(self.conv, init_method)
if self.bn is not None:
init_bn(self.bn)
class ConvBnReLU(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
pad=1):
super(ConvBnReLU, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=pad,
bias=False)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
return F.relu(self.bn(self.conv(x)), inplace=True)
class ConvBn(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
pad=1):
super(ConvBn, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=pad,
bias=False)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
return self.bn(self.conv(x))
def homo_warping(src_fea, src_proj, ref_proj, depth_values):
"""
src_fea: [B, C, H, W]
src_proj: [B, 4, 4]
ref_proj: [B, 4, 4]
    depth_values: [B, Ndepth] or [B, Ndepth, H, W]
out: [B, C, Ndepth, H, W]
"""
batch, channels = src_fea.shape[0], src_fea.shape[1]
num_depth = depth_values.shape[1]
height, width = src_fea.shape[2], src_fea.shape[3]
with torch.no_grad():
proj = torch.matmul(src_proj, torch.inverse(ref_proj))
rot = proj[:, :3, :3] # [B,3,3]
trans = proj[:, :3, 3:4] # [B,3,1]
y, x = torch.meshgrid([
torch.arange(
0, height, dtype=torch.float32, device=src_fea.device),
torch.arange(0, width, dtype=torch.float32, device=src_fea.device)
])
y, x = y.contiguous(), x.contiguous()
y, x = y.view(height * width), x.view(height * width)
xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W]
xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W]
rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W]
rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(
1, 1, num_depth, 1) * depth_values.view(batch, 1, num_depth,
-1) # [B, 3, Ndepth, H*W]
proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1,
1) # [B, 3, Ndepth, H*W]
proj_xy = proj_xyz[:, :
2, :, :] / proj_xyz[:, 2:
3, :, :] # [B, 2, Ndepth, H*W]
proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1
proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
proj_xy = torch.stack((proj_x_normalized, proj_y_normalized),
dim=3) # [B, Ndepth, H*W, 2]
grid = proj_xy
warped_src_fea = F.grid_sample(
src_fea,
grid.view(batch, num_depth * height, width, 2),
mode='bilinear',
padding_mode='zeros',
align_corners=True)
warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height,
width)
return warped_src_fea
class DeConv2dFuse(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
relu=True,
bn=True,
bn_momentum=0.1):
super(DeConv2dFuse, self).__init__()
self.deconv = Deconv2d(
in_channels,
out_channels,
kernel_size,
stride=2,
padding=1,
output_padding=1,
bn=True,
relu=relu,
bn_momentum=bn_momentum)
self.conv = Conv2d(
2 * out_channels,
out_channels,
kernel_size,
stride=1,
padding=1,
bn=bn,
relu=relu,
bn_momentum=bn_momentum)
def forward(self, x_pre, x):
x = self.deconv(x)
x = torch.cat((x, x_pre), dim=1)
x = self.conv(x)
return x
class FeatureNet(nn.Module):
def __init__(self, base_channels, num_stage=3, stride=4, arch_mode='unet'):
super(FeatureNet, self).__init__()
assert arch_mode in [
'unet', 'fpn'
], f"mode must be in 'unet' or 'fpn', but get:{arch_mode}"
self.arch_mode = arch_mode
self.stride = stride
self.base_channels = base_channels
self.num_stage = num_stage
self.conv0 = nn.Sequential(
Conv2d(3, base_channels, 3, 1, padding=1),
Conv2d(base_channels, base_channels, 3, 1, padding=1),
)
self.conv1 = nn.Sequential(
Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
)
self.conv2 = nn.Sequential(
Conv2d(
base_channels * 2, base_channels * 4, 5, stride=2, padding=2),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
)
self.out1 = nn.Conv2d(
base_channels * 4, base_channels * 4, 1, bias=False)
self.out_channels = [4 * base_channels]
if self.arch_mode == 'unet':
if num_stage == 3:
self.deconv1 = DeConv2dFuse(base_channels * 4,
base_channels * 2, 3)
self.deconv2 = DeConv2dFuse(base_channels * 2, base_channels,
3)
self.out2 = nn.Conv2d(
base_channels * 2, base_channels * 2, 1, bias=False)
self.out3 = nn.Conv2d(
base_channels, base_channels, 1, bias=False)
self.out_channels.append(2 * base_channels)
self.out_channels.append(base_channels)
elif num_stage == 2:
self.deconv1 = DeConv2dFuse(base_channels * 4,
base_channels * 2, 3)
self.out2 = nn.Conv2d(
base_channels * 2, base_channels * 2, 1, bias=False)
self.out_channels.append(2 * base_channels)
elif self.arch_mode == 'fpn':
final_chs = base_channels * 4
if num_stage == 3:
self.inner1 = nn.Conv2d(
base_channels * 2, final_chs, 1, bias=True)
self.inner2 = nn.Conv2d(
base_channels * 1, final_chs, 1, bias=True)
self.out2 = nn.Conv2d(
final_chs, base_channels * 2, 3, padding=1, bias=False)
self.out3 = nn.Conv2d(
final_chs, base_channels, 3, padding=1, bias=False)
self.out_channels.append(base_channels * 2)
self.out_channels.append(base_channels)
elif num_stage == 2:
self.inner1 = nn.Conv2d(
base_channels * 2, final_chs, 1, bias=True)
self.out2 = nn.Conv2d(
final_chs, base_channels, 3, padding=1, bias=False)
self.out_channels.append(base_channels)
def forward(self, x):
conv0 = self.conv0(x)
conv1 = self.conv1(conv0)
conv2 = self.conv2(conv1)
intra_feat = conv2
outputs = {}
out = self.out1(intra_feat)
outputs['stage1'] = out
if self.arch_mode == 'unet':
if self.num_stage == 3:
intra_feat = self.deconv1(conv1, intra_feat)
out = self.out2(intra_feat)
outputs['stage2'] = out
intra_feat = self.deconv2(conv0, intra_feat)
out = self.out3(intra_feat)
outputs['stage3'] = out
elif self.num_stage == 2:
intra_feat = self.deconv1(conv1, intra_feat)
out = self.out2(intra_feat)
outputs['stage2'] = out
elif self.arch_mode == 'fpn':
if self.num_stage == 3:
intra_feat = F.interpolate(
intra_feat, scale_factor=2,
mode='nearest') + self.inner1(conv1)
out = self.out2(intra_feat)
outputs['stage2'] = out
intra_feat = F.interpolate(
intra_feat, scale_factor=2,
mode='nearest') + self.inner2(conv0)
out = self.out3(intra_feat)
outputs['stage3'] = out
elif self.num_stage == 2:
intra_feat = F.interpolate(
intra_feat, scale_factor=2,
mode='nearest') + self.inner1(conv1)
out = self.out2(intra_feat)
outputs['stage2'] = out
return outputs
class CostRegNet(nn.Module):
def __init__(self, in_channels, base_channels):
super(CostRegNet, self).__init__()
self.conv0 = Conv3d(in_channels, base_channels, padding=1)
self.conv1 = Conv3d(
base_channels, base_channels * 2, stride=2, padding=1)
self.conv2 = Conv3d(base_channels * 2, base_channels * 2, padding=1)
self.conv3 = Conv3d(
base_channels * 2, base_channels * 4, stride=2, padding=1)
self.conv4 = Conv3d(base_channels * 4, base_channels * 4, padding=1)
self.conv5 = Conv3d(
base_channels * 4, base_channels * 8, stride=2, padding=1)
self.conv6 = Conv3d(base_channels * 8, base_channels * 8, padding=1)
self.conv7 = Deconv3d(
base_channels * 8,
base_channels * 4,
stride=2,
padding=1,
output_padding=1)
self.conv9 = Deconv3d(
base_channels * 4,
base_channels * 2,
stride=2,
padding=1,
output_padding=1)
self.conv11 = Deconv3d(
base_channels * 2,
base_channels * 1,
stride=2,
padding=1,
output_padding=1)
self.prob = nn.Conv3d(
base_channels, 1, 3, stride=1, padding=1, bias=False)
def forward(self, x):
conv0 = self.conv0(x)
conv2 = self.conv2(self.conv1(conv0))
conv4 = self.conv4(self.conv3(conv2))
x = self.conv6(self.conv5(conv4))
x = conv4 + self.conv7(x)
x = conv2 + self.conv9(x)
x = conv0 + self.conv11(x)
x = self.prob(x)
return x
class RefineNet(nn.Module):
def __init__(self):
super(RefineNet, self).__init__()
self.conv1 = ConvBnReLU(4, 32)
self.conv2 = ConvBnReLU(32, 32)
self.conv3 = ConvBnReLU(32, 32)
self.res = ConvBnReLU(32, 1)
def forward(self, img, depth_init):
        concat = torch.cat((img, depth_init), dim=1)
depth_residual = self.res(self.conv3(self.conv2(self.conv1(concat))))
depth_refined = depth_init + depth_residual
return depth_refined
def depth_regression(p, depth_values):
if depth_values.dim() <= 2:
depth_values = depth_values.view(*depth_values.shape, 1, 1)
depth = torch.sum(p * depth_values, 1)
return depth
def cas_mvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
depth_loss_weights = kwargs.get('dlossw', None)
total_loss = torch.tensor(
0.0,
dtype=torch.float32,
device=mask_ms['stage1'].device,
requires_grad=False)
for (stage_inputs, stage_key) in [(inputs[k], k) for k in inputs.keys()
if 'stage' in k]:
depth_est = stage_inputs['depth']
depth_gt = depth_gt_ms[stage_key]
mask = mask_ms[stage_key]
mask = mask > 0.5
depth_loss = F.smooth_l1_loss(
depth_est[mask], depth_gt[mask], reduction='mean')
if depth_loss_weights is not None:
stage_idx = int(stage_key.replace('stage', '')) - 1
total_loss += depth_loss_weights[stage_idx] * depth_loss
else:
total_loss += 1.0 * depth_loss
return total_loss, depth_loss
def get_cur_depth_range_samples(cur_depth,
ndepth,
depth_inteval_pixel,
shape,
max_depth=192.0,
min_depth=0.0):
"""
shape, (B, H, W)
cur_depth: (B, H, W)
return depth_range_values: (B, D, H, W)
"""
cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel) # (B, H, W)
cur_depth_max = (cur_depth + ndepth / 2 * depth_inteval_pixel)
assert cur_depth.shape == torch.Size(
shape), 'cur_depth:{}, input shape:{}'.format(cur_depth.shape, shape)
new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1) # (B, H, W)
depth_range_samples = cur_depth_min.unsqueeze(1) + (
torch.arange(
0,
ndepth,
device=cur_depth.device,
dtype=cur_depth.dtype,
requires_grad=False).reshape(1, -1, 1, 1)
* new_interval.unsqueeze(1))
return depth_range_samples
def get_depth_range_samples(cur_depth,
ndepth,
depth_inteval_pixel,
device,
dtype,
shape,
max_depth=192.0,
min_depth=0.0):
"""
shape: (B, H, W)
cur_depth: (B, H, W) or (B, D)
return depth_range_samples: (B, D, H, W)
"""
if cur_depth.dim() == 2:
cur_depth_min = cur_depth[:, 0] # (B,)
cur_depth_max = cur_depth[:, -1]
new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1) # (B, )
depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(
0, ndepth, device=device, dtype=dtype,
requires_grad=False).reshape(1, -1) * new_interval.unsqueeze(1)
) # noqa # (B, D)
depth_range_samples = depth_range_samples.unsqueeze(-1).unsqueeze(
-1).repeat(1, 1, shape[1], shape[2]) # (B, D, H, W)
else:
depth_range_samples = get_cur_depth_range_samples(
cur_depth, ndepth, depth_inteval_pixel, shape, max_depth,
min_depth)
return depth_range_samples
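# Illustrative usage, not part of the original file: at the coarsest stage cur_depth is a
# (B, D) tensor of global depth values and the samples are tiled over H x W; at finer
# stages cur_depth is the (B, H, W) prediction of the previous stage and 'ndepth' new
# hypotheses are re-sampled per pixel around it, e.g. (values here are only examples):
#   samples = get_depth_range_samples(prev_depth, ndepth=32,
#                                     depth_inteval_pixel=2.65,
#                                     device=prev_depth.device,
#                                     dtype=prev_depth.dtype,
#                                     shape=prev_depth.shape)  # (B, 32, H, W)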

View File

@@ -0,0 +1,118 @@
# The implementation here is modified based on https://github.com/xy-guo/MVSNet_pytorch
import random
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
# Convert a function into a recursive one that handles nested dict/list/tuple variables
def make_recursive_func(func):
def wrapper(vars):
if isinstance(vars, list):
return [wrapper(x) for x in vars]
elif isinstance(vars, tuple):
return tuple([wrapper(x) for x in vars])
elif isinstance(vars, dict):
return {k: wrapper(v) for k, v in vars.items()}
else:
return func(vars)
return wrapper
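# Illustrative example, not part of the original file: decorating a scalar converter lets it
# walk arbitrarily nested containers, e.g.
#   to_float = make_recursive_func(float)
#   to_float({'a': [1, 2], 'b': (3,)})  # -> {'a': [1.0, 2.0], 'b': (3.0,)}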
@make_recursive_func
def tensor2numpy(vars):
if isinstance(vars, np.ndarray):
return vars
elif isinstance(vars, torch.Tensor):
return vars.detach().cpu().numpy().copy()
else:
raise NotImplementedError(
'invalid input type {} for tensor2numpy'.format(type(vars)))
@make_recursive_func
def numpy2torch(vars):
if isinstance(vars, np.ndarray):
return torch.from_numpy(vars)
elif isinstance(vars, torch.Tensor):
return vars
elif isinstance(vars, str):
return vars
else:
raise NotImplementedError(
'invalid input type {} for numpy2torch'.format(type(vars)))
@make_recursive_func
def tocuda(vars):
if isinstance(vars, torch.Tensor):
return vars.to(torch.device('cuda'))
elif isinstance(vars, str):
return vars
else:
raise NotImplementedError(
'invalid input type {} for tocuda'.format(type(vars)))
def generate_pointcloud(rgb, depth, ply_file, intr, scale=1.0):
"""
Generate a colored point cloud in PLY format from a color image and a depth map.
Input:
rgb -- color image as an (H, W, 3) array
depth -- depth map as an (H, W) array
ply_file -- output filename of the ply file
intr -- 3x3 camera intrinsic matrix
scale -- scale factor applied to the depth values
"""
fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2]
points = []
for v in range(rgb.shape[0]):
for u in range(rgb.shape[1]):
color = rgb[v, u] # rgb.getpixel((u, v))
Z = depth[v, u] / scale
if Z == 0:
continue
X = (u - cx) * Z / fx
Y = (v - cy) * Z / fy
points.append('%f %f %f %d %d %d 0\n' %
(X, Y, Z, color[0], color[1], color[2]))
file = open(ply_file, 'w')
file.write('''ply
format ascii 1.0
element vertex %d
property float x
property float y
property float z
property uchar red
property uchar green
property uchar blue
property uchar alpha
end_header
%s
''' % (len(points), ''.join(points)))
file.close()
def write_cam(file, cam):
f = open(file, 'w')
f.write('extrinsic\n')
for i in range(0, 4):
for j in range(0, 4):
f.write(str(cam[0][i][j]) + ' ')
f.write('\n')
f.write('\n')
f.write('intrinsic\n')
for i in range(0, 3):
for j in range(0, 3):
f.write(str(cam[1][i][j]) + ' ')
f.write('\n')
f.write('\n' + str(cam[1][3][0]) + ' ' + str(cam[1][3][1]) + ' '
+ str(cam[1][3][2]) + ' ' + str(cam[1][3][3]) + '\n')
f.close()

View File

@@ -245,6 +245,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.video_object_segmentation:
(Pipelines.video_object_segmentation,
'damo/cv_rdevos_video-object-segmentation'),
Tasks.image_multi_view_depth_estimation:
(Pipelines.image_multi_view_depth_estimation,
'damo/cv_casmvs_multi-view-depth-estimation_general'),
}

View File

@@ -73,6 +73,7 @@ if TYPE_CHECKING:
from .video_super_resolution_pipeline import VideoSuperResolutionPipeline
from .pointcloud_sceneflow_estimation_pipeline import PointCloudSceneFlowEstimationPipeline
from .maskdino_instance_segmentation_pipeline import MaskDINOInstanceSegmentationPipeline
from .image_mvs_depth_estimation_pipeline import ImageMultiViewDepthEstimationPipeline
else:
_import_structure = {
@@ -171,7 +172,10 @@ else:
],
'maskdino_instance_segmentation_pipeline': [
'MaskDINOInstanceSegmentationPipeline'
]
],
'image_mvs_depth_estimation_pipeline': [
'ImageMultiViewDepthEstimationPipeline'
],
}
import sys

View File

@@ -0,0 +1,80 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
from tempfile import TemporaryDirectory
from typing import Any, Dict, Union
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import depth_to_color
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.image_multi_view_depth_estimation,
module_name=Pipelines.image_multi_view_depth_estimation)
class ImageMultiViewDepthEstimationPipeline(Pipeline):
def __init__(self, model: str, **kwargs):
"""
use `model` to create an image multi-view depth estimation pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
self.tmp_dir = None
logger.info('pipeline init done')
def check_input(self, input_dir):
assert os.path.exists(
input_dir), f'input dir:{input_dir} does not exist'
sub_dirs = os.listdir(input_dir)
assert 'images' in sub_dirs, "must contain 'images' folder"
assert 'sparse' in sub_dirs, "must contain 'sparse' folder"
files = os.listdir(os.path.join(input_dir, 'sparse'))
assert 'cameras.bin' in files, "'sparse' folder must contain 'cameras.bin'"
assert 'images.bin' in files, "'sparse' folder must contain 'images.bin'"
assert 'points3D.bin' in files, "'sparse' folder must contain 'points3D.bin'"
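# Expected input layout (COLMAP-style sparse reconstruction, as implied by the checks above):
#   <input_dir>/
#     images/              source view images
#     sparse/
#       cameras.bin
#       images.bin
#       points3D.bin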
def preprocess(self, input: Input) -> Dict[str, Any]:
assert isinstance(input, str), 'input must be str'
self.check_input(input)
self.tmp_dir = TemporaryDirectory()
casmvs_inp_dir = os.path.join(self.tmp_dir.name, 'casmvs_inp_dir')
casmvs_res_dir = os.path.join(self.tmp_dir.name, 'casmvs_res_dir')
os.makedirs(casmvs_inp_dir, exist_ok=True)
os.makedirs(casmvs_res_dir, exist_ok=True)
input_dict = {
'input_dir': input,
'casmvs_inp_dir': casmvs_inp_dir,
'casmvs_res_dir': casmvs_res_dir
}
self.model.preprocess_make_pair(input_dict)
return input_dict
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
results = self.model.forward(input)
return results
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
pcd = self.model.postprocess(inputs)
# clear tmp dir
if self.tmp_dir is not None:
self.tmp_dir.cleanup()
outputs = {
OutputKeys.OUTPUT: pcd,
}
return outputs

View File

@@ -107,6 +107,8 @@ class CVTasks(object):
# pointcloud task
pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'
# image multi-view depth estimation
image_multi_view_depth_estimation = 'image-multi-view-depth-estimation'
# domain specific object detection
domain_specific_object_detection = 'domain-specific-object-detection'

View File

@@ -0,0 +1,34 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class ImageMVSDepthEstimationTest(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self) -> None:
self.task = 'image-multi-view-depth-estimation'
self.model_id = 'damo/cv_casmvs_multi-view-depth-estimation_general'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_image_mvs_depth_estimation(self):
estimator = pipeline(
Tasks.image_multi_view_depth_estimation,
model='damo/cv_casmvs_multi-view-depth-estimation_general')
model_dir = snapshot_download(self.model_id)
input_location = os.path.join(model_dir, 'test_data')
result = estimator(input_location)
pcd = result[OutputKeys.OUTPUT]
pcd.write('./pcd_fusion.ply')
print('test_image_mvs_depth_estimation DONE')
if __name__ == '__main__':
unittest.main()