mirror of https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00

add self supervised depth completion. (#711)

* add self supervised depth completion.
* update.
* fix the problem of key inconsistency.
* delete args parser.
* rename metrics to test_metrics.

This commit is contained in:
@@ -132,6 +132,7 @@ class Models(object):
     image_control_3d_portrait = 'image-control-3d-portrait'
     rife = 'rife'
     anydoor = 'anydoor'
+    self_supervised_depth_completion = 'self-supervised-depth-completion'

     # nlp models
     bert = 'bert'
@@ -469,6 +470,7 @@ class Pipelines(object):
     rife_video_frame_interpolation = 'rife-video-frame-interpolation'
     anydoor = 'anydoor'
     image_to_3d = 'image-to-3d'
+    self_supervised_depth_completion = 'self-supervised-depth-completion'

     # nlp tasks
     automatic_post_editing = 'automatic-post-editing'
@@ -959,7 +961,10 @@ DEFAULT_MODEL_FOR_PIPELINE = {
         'damo/cv_image-view-transform'),
     Tasks.image_control_3d_portrait: (
         Pipelines.image_control_3d_portrait,
-        'damo/cv_vit_image-control-3d-portrait-synthesis')
+        'damo/cv_vit_image-control-3d-portrait-synthesis'),
+    Tasks.self_supervised_depth_completion: (
+        Pipelines.self_supervised_depth_completion,
+        'damo/self-supervised-depth-completion')
 }

@@ -982,6 +987,7 @@ class CVTrainers(object):
     nerf_recon_4k = 'nerf-recon-4k'
     action_detection = 'action-detection'
     vision_efficient_tuning = 'vision-efficient-tuning'
+    self_supervised_depth_completion = 'self-supervised-depth-completion'


 class NLPTrainers(object):
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .self_supervised_depth_completion import SelfSupervisedDepthCompletion
else:
    _import_structure = {
        'selfsuperviseddepthcompletion': ['SelfSupervisedDepthCompletion'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
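Reviewer note: a minimal sketch of how this lazy-import pattern behaves from the caller's side; the exact import path below is an assumption based on the directory layout in this diff.

# Hypothetical usage sketch: the submodule is only imported on first
# attribute access, because LazyImportModule replaces this package in
# sys.modules.
from modelscope.models.cv import self_supervised_depth_completion as ssdc

model_cls = ssdc.SelfSupervisedDepthCompletion  # triggers the real import here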
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn

from modelscope.utils.logger import get_logger

logger = get_logger()

loss_names = ['l1', 'l2']


class MaskedMSELoss(nn.Module):

    def __init__(self):
        super(MaskedMSELoss, self).__init__()

    def forward(self, pred, target):
        assert pred.dim() == target.dim(), 'inconsistent dimensions'
        valid_mask = (target > 0).detach()
        diff = target - pred
        diff = diff[valid_mask]
        self.loss = (diff**2).mean()
        return self.loss


class MaskedL1Loss(nn.Module):

    def __init__(self):
        super(MaskedL1Loss, self).__init__()

    def forward(self, pred, target, weight=None):
        assert pred.dim() == target.dim(), 'inconsistent dimensions'
        valid_mask = (target > 0).detach()
        diff = target - pred
        diff = diff[valid_mask]
        self.loss = diff.abs().mean()
        return self.loss


class PhotometricLoss(nn.Module):

    def __init__(self):
        super(PhotometricLoss, self).__init__()

    def forward(self, target, recon, mask=None):

        assert recon.dim() == 4, \
            'expected recon dimension to be 4, but instead got {}.'.format(recon.dim())
        assert target.dim() == 4, \
            'expected target dimension to be 4, but instead got {}.'.format(target.dim())
        assert recon.size() == target.size(), \
            'expected recon and target to have the same size, but got {} and {}'.format(
                recon.size(), target.size())
        diff = (target - recon).abs()
        diff = torch.sum(diff, 1)  # sum along the color channel

        # compare only pixels that are not black
        valid_mask = (torch.sum(recon, 1) > 0).float() * (torch.sum(target, 1) > 0).float()
        if mask is not None:
            valid_mask = valid_mask * torch.squeeze(mask).float()
        valid_mask = valid_mask.byte().detach()
        if valid_mask.numel() > 0:
            diff = diff[valid_mask]
            if diff.nelement() > 0:
                self.loss = diff.mean()
            else:
                logger.info(
                    'warning: diff.nelement()==0 in PhotometricLoss (this is expected '
                    'during early stage of training, try larger batch size).')
                self.loss = 0
        else:
            logger.info('warning: 0 valid pixel in PhotometricLoss')
            self.loss = 0
        return self.loss


class SmoothnessLoss(nn.Module):

    def __init__(self):
        super(SmoothnessLoss, self).__init__()

    def forward(self, depth):

        def second_derivative(x):
            assert x.dim() == 4, \
                'expected 4-dimensional data, but instead got {}'.format(x.dim())
            horizontal = 2 * x[:, :, 1:-1, 1:-1] - x[:, :, 1:-1, :-2] - x[:, :, 1:-1, 2:]
            vertical = 2 * x[:, :, 1:-1, 1:-1] - x[:, :, :-2, 1:-1] - x[:, :, 2:, 1:-1]
            der_2nd = horizontal.abs() + vertical.abs()
            return der_2nd.mean()

        self.loss = second_derivative(depth)
        return self.loss
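Reviewer note: a minimal sketch of how these masked losses are typically combined in self-supervised depth completion; the shapes, weights and sparsity simulation below are illustrative assumptions, not values from this diff.

import torch

# Illustrative shapes: batch of 2, 1-channel depth maps at the KITTI crop size.
pred = torch.rand(2, 1, 352, 1216)
target = torch.rand(2, 1, 352, 1216)
target[target < 0.5] = 0  # simulate sparse supervision: zeros mean "no measurement"

depth_criterion = MaskedL1Loss()
smoothness_criterion = SmoothnessLoss()

# Only pixels with target > 0 contribute to the depth term; the smoothness
# term regularizes the prediction everywhere (0.1 is an assumed weight).
loss = depth_criterion(pred, target) + 0.1 * smoothness_criterion(pred)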
@@ -0,0 +1,344 @@
import glob
import os
import os.path
from random import choice

import cv2
import numpy as np
import torch.utils.data as data
from numpy import linalg as LA
from PIL import Image

from modelscope.models.cv.self_supervised_depth_completion.dataloaders import \
    transforms
from modelscope.models.cv.self_supervised_depth_completion.dataloaders.pose_estimator import \
    get_pose_pnp

input_options = ['d', 'rgb', 'rgbd', 'g', 'gd']

def load_calib(args):
    """
    Temporarily hardcoding the calibration matrix using the calib file from 2011_09_26.
    """
    calib = open(os.path.join(args.data_folder, 'calib_cam_to_cam.txt'), 'r')
    lines = calib.readlines()
    P_rect_line = lines[25]

    Proj_str = P_rect_line.split(':')[1].split(' ')[1:]
    Proj = np.reshape(np.array([float(p) for p in Proj_str]),
                      (3, 4)).astype(np.float32)
    K = Proj[:3, :3]  # camera matrix

    # note: we will take the center crop of the images during augmentation;
    # that changes the optical centers, but not the focal lengths
    K[0, 2] = K[0, 2] - 13  # from width = 1242 to 1216, with a 13-pixel cut on both sides
    K[1, 2] = K[1, 2] - 11.5  # from height = 375 to 352, with an 11.5-pixel cut on both sides
    return K

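Reviewer note: since load_calib shifts the principal point to account for cropping, here is a quick sketch of the pinhole relation these intrinsics encode; the numeric values are assumed, KITTI-like values for illustration only.

import numpy as np

# Illustrative intrinsics in the style returned by load_calib (assumed values).
K = np.array([[721.5, 0.0, 609.6 - 13],
              [0.0, 721.5, 172.9 - 11.5],
              [0.0, 0.0, 1.0]], dtype=np.float32)

# Project a 3D camera-frame point (x, y, z) to pixel coordinates (u, v).
xyz = np.array([2.0, 1.0, 10.0])
uvw = K @ xyz
u, v = uvw[0] / uvw[2], uvw[1] / uvw[2]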
def get_paths_and_transform(split, args):
    assert (args.use_d or args.use_rgb
            or args.use_g), 'no proper input selected'

    if split == 'train':
        transform = train_transform
        glob_d = os.path.join(
            args.data_folder,
            'data_depth_velodyne/train/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png'
        )
        glob_gt = os.path.join(
            args.data_folder,
            'data_depth_annotated/train/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png'
        )

        def get_rgb_paths(p):
            ps = p.split('/')
            pnew = '/'.join([args.data_folder] + ['data_rgb'] + ps[-6:-4]
                            + ps[-2:-1] + ['data'] + ps[-1:])
            return pnew
    elif split == 'val':
        if args.val == 'full':
            transform = val_transform
            glob_d = os.path.join(
                args.data_folder,
                'data_depth_velodyne/val/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png'
            )
            glob_gt = os.path.join(
                args.data_folder,
                'data_depth_annotated/val/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png'
            )

            def get_rgb_paths(p):
                ps = p.split('/')
                pnew = '/'.join(ps[:-7] + ['data_rgb'] + ps[-6:-4] + ps[-2:-1]
                                + ['data'] + ps[-1:])
                return pnew
        elif args.val == 'select':
            transform = no_transform
            glob_d = os.path.join(
                args.data_folder,
                'depth_selection/val_selection_cropped/velodyne_raw/*.png')
            glob_gt = os.path.join(
                args.data_folder,
                'depth_selection/val_selection_cropped/groundtruth_depth/*.png'
            )

            def get_rgb_paths(p):
                return p.replace('groundtruth_depth', 'image')
    elif split == 'test_completion':
        transform = no_transform
        glob_d = os.path.join(
            args.data_folder,
            'depth_selection/test_depth_completion_anonymous/velodyne_raw/*.png'
        )
        glob_gt = None  # "test_depth_completion_anonymous/"
        glob_rgb = os.path.join(
            args.data_folder,
            'depth_selection/test_depth_completion_anonymous/image/*.png')
    elif split == 'test_prediction':
        transform = no_transform
        glob_d = None
        glob_gt = None  # "test_depth_completion_anonymous/"
        glob_rgb = os.path.join(
            args.data_folder,
            'depth_selection/test_depth_prediction_anonymous/image/*.png')
    else:
        raise ValueError('Unrecognized split ' + str(split))

    if glob_gt is not None:
        # train, val-full or val-select
        paths_d = sorted(glob.glob(glob_d))
        paths_gt = sorted(glob.glob(glob_gt))
        paths_rgb = [get_rgb_paths(p) for p in paths_gt]
    else:
        # test only has d or rgb
        paths_rgb = sorted(glob.glob(glob_rgb))
        paths_gt = [None] * len(paths_rgb)
        if split == 'test_prediction':
            paths_d = [None] * len(
                paths_rgb)  # test_prediction has no sparse depth
        else:
            paths_d = sorted(glob.glob(glob_d))

    if len(paths_d) == 0 and len(paths_rgb) == 0 and len(paths_gt) == 0:
        raise RuntimeError('Found 0 images under {}'.format(glob_gt))
    if len(paths_d) == 0 and args.use_d:
        raise RuntimeError('Requested sparse depth but none was found')
    if len(paths_rgb) == 0 and args.use_rgb:
        raise RuntimeError('Requested rgb images but none was found')
    if len(paths_rgb) == 0 and args.use_g:
        raise RuntimeError('Requested gray images but no rgb was found')
    if len(paths_rgb) != len(paths_d) or len(paths_rgb) != len(paths_gt):
        raise RuntimeError('Produced different sizes for datasets')

    paths = {'rgb': paths_rgb, 'd': paths_d, 'gt': paths_gt}
    return paths, transform

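Reviewer note: to make the train-split path mapping above concrete, a sketch of what get_rgb_paths computes for a typical annotated-depth path; the date/drive names and data_folder are placeholders.

# Hypothetical ground-truth path under data_folder = '/data/kitti':
p = ('/data/kitti/data_depth_annotated/train/2011_09_26_drive_0001_sync/'
     'proj_depth/groundtruth/image_02/0000000005.png')
ps = p.split('/')
# [data_folder] + ['data_rgb'] + ps[-6:-4] + ps[-2:-1] + ['data'] + ps[-1:]
pnew = '/'.join(['/data/kitti'] + ['data_rgb'] + ps[-6:-4] + ps[-2:-1]
                + ['data'] + ps[-1:])
# -> /data/kitti/data_rgb/train/2011_09_26_drive_0001_sync/image_02/data/0000000005.png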
def rgb_read(filename):
    assert os.path.exists(filename), 'file not found: {}'.format(filename)
    img_file = Image.open(filename)
    # rgb_png = np.array(img_file, dtype=float) / 255.0  # scale pixels to the range [0,1]
    rgb_png = np.array(img_file, dtype='uint8')  # in the range [0,255]
    img_file.close()
    return rgb_png


def depth_read(filename):
    # Loads depth map D from a png file and returns it as a numpy array;
    # for details see the KITTI depth devkit readme.txt.
    assert os.path.exists(filename), 'file not found: {}'.format(filename)
    img_file = Image.open(filename)
    depth_png = np.array(img_file, dtype=int)
    img_file.close()
    # make sure we have a proper 16bit depth map here, not 8bit!
    assert np.max(depth_png) > 255, \
        'np.max(depth_png)={}, path={}'.format(np.max(depth_png), filename)

    depth = depth_png.astype(float) / 256.
    # depth[depth_png == 0] = -1.
    depth = np.expand_dims(depth, -1)
    return depth

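Reviewer note: a small round-trip sketch of the KITTI 16-bit depth encoding that depth_read reverses; the depth values below are made up.

import numpy as np

# KITTI stores depth in metres as uint16 = depth * 256; zero means "no reading".
depth_m = np.array([[0.0, 3.5], [12.25, 80.0]])
depth_png = (depth_m * 256.0).astype(np.uint16)

decoded = depth_png.astype(float) / 256.0  # what depth_read computes
assert np.allclose(decoded, depth_m)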
oheight, owidth = 352, 1216


def drop_depth_measurements(depth, prob_keep):
    mask = np.random.binomial(1, prob_keep, depth.shape)
    depth *= mask
    return depth

def train_transform(rgb, sparse, target, rgb_near, args):
    # s = np.random.uniform(1.0, 1.5)  # random scaling
    # angle = np.random.uniform(-5.0, 5.0)  # random rotation degrees
    do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip

    transform_geometric = transforms.Compose([
        # transforms.Rotate(angle),
        # transforms.Resize(s),
        transforms.BottomCrop((oheight, owidth)),
        transforms.HorizontalFlip(do_flip)
    ])
    if sparse is not None:
        sparse = transform_geometric(sparse)
    target = transform_geometric(target)
    if rgb is not None:
        brightness = np.random.uniform(
            max(0, 1 - args.jitter), 1 + args.jitter)
        contrast = np.random.uniform(max(0, 1 - args.jitter), 1 + args.jitter)
        saturation = np.random.uniform(
            max(0, 1 - args.jitter), 1 + args.jitter)
        transform_rgb = transforms.Compose([
            transforms.ColorJitter(brightness, contrast, saturation, 0),
            transform_geometric
        ])
        rgb = transform_rgb(rgb)
        if rgb_near is not None:
            rgb_near = transform_rgb(rgb_near)
    # sparse = drop_depth_measurements(sparse, 0.9)

    return rgb, sparse, target, rgb_near


def val_transform(rgb, sparse, target, rgb_near, args):
    transform = transforms.Compose([
        transforms.BottomCrop((oheight, owidth)),
    ])
    if rgb is not None:
        rgb = transform(rgb)
    if sparse is not None:
        sparse = transform(sparse)
    if target is not None:
        target = transform(target)
    if rgb_near is not None:
        rgb_near = transform(rgb_near)
    return rgb, sparse, target, rgb_near


def no_transform(rgb, sparse, target, rgb_near, args):
    return rgb, sparse, target, rgb_near

to_tensor = transforms.ToTensor()


def to_float_tensor(x):
    return to_tensor(x).float()

def handle_gray(rgb, args):
    if rgb is None:
        return None, None
    if not args.use_g:
        return rgb, None
    else:
        img = np.array(Image.fromarray(rgb).convert('L'))
        img = np.expand_dims(img, -1)
        if not args.use_rgb:
            rgb_ret = None
        else:
            rgb_ret = rgb
        return rgb_ret, img

def get_rgb_near(path, args):
    assert path is not None, 'path is None'

    def extract_frame_id(filename):
        head, tail = os.path.split(filename)
        number_string = tail[0:tail.find('.')]
        number = int(number_string)
        return head, number

    def get_nearby_filename(filename, new_id):
        head, _ = os.path.split(filename)
        new_filename = os.path.join(head, '%010d.png' % new_id)
        return new_filename

    head, number = extract_frame_id(path)
    count = 0
    max_frame_diff = 3
    candidates = [
        i - max_frame_diff for i in range(max_frame_diff * 2 + 1)
        if i - max_frame_diff != 0
    ]
    while True:
        random_offset = choice(candidates)
        path_near = get_nearby_filename(path, number + random_offset)
        if os.path.exists(path_near):
            break
        assert count < 20, 'cannot find a nearby frame in 20 trials for {}'.format(
            path)
        count += 1

    return rgb_read(path_near)

class KittiDepth(data.Dataset):
    """A data loader for the KITTI dataset."""

    def __init__(self, split, args):
        self.args = args
        self.split = split
        paths, transform = get_paths_and_transform(split, args)
        self.paths = paths
        self.transform = transform
        self.K = load_calib(args)
        self.threshold_translation = 0.1

    def __getraw__(self, index):
        rgb = rgb_read(self.paths['rgb'][index]) if \
            (self.paths['rgb'][index] is not None and (self.args.use_rgb or self.args.use_g)) else None
        sparse = depth_read(self.paths['d'][index]) if \
            (self.paths['d'][index] is not None and self.args.use_d) else None
        target = depth_read(self.paths['gt'][index]) if \
            self.paths['gt'][index] is not None else None
        rgb_near = get_rgb_near(self.paths['rgb'][index], self.args) if \
            self.split == 'train' and self.args.use_pose else None
        return rgb, sparse, target, rgb_near

    def __getitem__(self, index):
        rgb, sparse, target, rgb_near = self.__getraw__(index)
        rgb, sparse, target, rgb_near = self.transform(rgb, sparse, target,
                                                       rgb_near, self.args)
        r_mat, t_vec = None, None
        if self.split == 'train' and self.args.use_pose:
            success, r_vec, t_vec = get_pose_pnp(rgb, rgb_near, sparse, self.K)
            # discard if translation is too small
            success = success and LA.norm(t_vec) > self.threshold_translation
            if success:
                r_mat, _ = cv2.Rodrigues(r_vec)
            else:
                # return the same image and no motion when PnP fails
                rgb_near = rgb
                t_vec = np.zeros((3, 1))
                r_mat = np.eye(3)

        rgb, gray = handle_gray(rgb, self.args)
        candidates = {
            'rgb': rgb,
            'd': sparse,
            'gt': target,
            'g': gray,
            'r_mat': r_mat,
            't_vec': t_vec,
            'rgb_near': rgb_near
        }
        items = {
            key: to_float_tensor(val)
            for key, val in candidates.items() if val is not None
        }

        return items

    def __len__(self):
        return len(self.paths['gt'])
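Reviewer note: a minimal sketch of how this dataset is typically consumed; the argument object is a hypothetical stand-in for the trainer's parsed config, and data_folder must point at a real KITTI depth-completion layout or the path checks above will raise.

from types import SimpleNamespace

import torch.utils.data

# Hypothetical args; the real trainer supplies these fields from its config.
args = SimpleNamespace(
    data_folder='/path/to/kitti', use_d=True, use_rgb=True, use_g=False,
    use_pose=False, jitter=0.1, val='select')

dataset = KittiDepth('train', args)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
batch = next(iter(loader))  # dict with 'rgb', 'd', 'gt' float tensors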
@@ -0,0 +1,102 @@
import cv2
import numpy as np


def rgb2gray(rgb):
    return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])


def convert_2d_to_3d(u, v, z, K):
    v0 = K[1][2]
    u0 = K[0][2]
    fy = K[1][1]
    fx = K[0][0]
    x = (u - u0) * z / fx
    y = (v - v0) * z / fy
    return (x, y, z)


def feature_match(img1, img2):
    r'''Find features on both images and match them pairwise.'''
    max_n_features = 1000
    # max_n_features = 500
    use_flann = False  # better not use flann

    detector = cv2.xfeatures2d.SIFT_create(max_n_features)

    # find the keypoints and descriptors with SIFT
    kp1, des1 = detector.detectAndCompute(img1, None)
    kp2, des2 = detector.detectAndCompute(img2, None)
    if (des1 is None) or (des2 is None):
        return [], []
    des1 = des1.astype(np.float32)
    des2 = des2.astype(np.float32)

    if use_flann:
        # FLANN parameters
        FLANN_INDEX_KDTREE = 0
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        matches = flann.knnMatch(des1, des2, k=2)
    else:
        matcher = cv2.DescriptorMatcher().create('BruteForce')
        matches = matcher.knnMatch(des1, des2, k=2)

    good = []
    pts1 = []
    pts2 = []
    # ratio test as per Lowe's paper
    for i, (m, n) in enumerate(matches):
        if m.distance < 0.8 * n.distance:
            good.append(m)
            pts2.append(kp2[m.trainIdx].pt)
            pts1.append(kp1[m.queryIdx].pt)

    pts1 = np.int32(pts1)
    pts2 = np.int32(pts2)
    return pts1, pts2


def get_pose_pnp(rgb_curr, rgb_near, depth_curr, K):
    gray_curr = rgb2gray(rgb_curr).astype(np.uint8)
    gray_near = rgb2gray(rgb_near).astype(np.uint8)
    height, width = gray_curr.shape

    pts2d_curr, pts2d_near = feature_match(gray_curr,
                                           gray_near)  # feature matching

    # dilation of depth
    kernel = np.ones((4, 4), np.uint8)
    depth_curr_dilated = cv2.dilate(depth_curr, kernel)

    # extract 3d pts
    pts3d_curr = []
    pts2d_near_filtered = [
    ]  # keep only feature points with depth in the current frame
    for i, pt2d in enumerate(pts2d_curr):
        # print(pt2d)
        u, v = pt2d[0], pt2d[1]
        z = depth_curr_dilated[v, u]
        if z > 0:
            xyz_curr = convert_2d_to_3d(u, v, z, K)
            pts3d_curr.append(xyz_curr)
            pts2d_near_filtered.append(pts2d_near[i])

    # the minimal number of points accepted by solvePnP is 4:
    if len(pts3d_curr) >= 4 and len(pts2d_near_filtered) >= 4:
        pts3d_curr = np.expand_dims(
            np.array(pts3d_curr).astype(np.float32), axis=1)
        pts2d_near_filtered = np.expand_dims(
            np.array(pts2d_near_filtered).astype(np.float32), axis=1)

        # ransac
        ret = cv2.solvePnPRansac(
            pts3d_curr, pts2d_near_filtered, K, distCoeffs=None)
        success = ret[0]
        rotation_vector = ret[1]
        translation_vector = ret[2]
        return (success, rotation_vector, translation_vector)
    else:
        return (0, None, None)
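Reviewer note: to make the pose convention concrete, a sketch of how the PnP result is turned into a rotation matrix and checked for sufficient motion, mirroring KittiDepth.__getitem__; the inputs are placeholders, and cv2.xfeatures2d requires an opencv-contrib build.

import cv2
import numpy as np
from numpy import linalg as LA

# Placeholder inputs: two RGB frames plus sparse depth for the current frame.
rgb_curr = np.zeros((352, 1216, 3), dtype=np.uint8)
rgb_near = np.zeros((352, 1216, 3), dtype=np.uint8)
sparse = np.zeros((352, 1216), dtype=np.float32)
K = np.eye(3, dtype=np.float32)

success, r_vec, t_vec = get_pose_pnp(rgb_curr, rgb_near, sparse, K)
if success and LA.norm(t_vec) > 0.1:  # discard near-zero translations
    r_mat, _ = cv2.Rodrigues(r_vec)  # axis-angle -> 3x3 rotation matrix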
@@ -0,0 +1,617 @@
from __future__ import division

import numbers
import types

import numpy as np
import scipy.ndimage.interpolation as itpl
import skimage.transform
import torch
from PIL import Image, ImageEnhance

try:
    import accimage
except ImportError:
    accimage = None


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3

def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non-negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image: Brightness adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img


def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non-negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image: Contrast adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image: Saturation adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to the original image mode.

    `hue_factor` is the amount of shift in the H channel and must be in the
    interval `[-0.5, 0.5]`.

    See https://en.wikipedia.org/wiki/Hue for more details on Hue.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of the hue channel
            in HSV space in the positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image: Hue adjusted image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(
            'hue_factor is not in [-0.5, 0.5]. Got {}'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img


def adjust_gamma(img, gamma, gain=1):
    """Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

        I_out = 255 * gain * ((I_in / 255) ** gamma)

    See https://en.wikipedia.org/wiki/Gamma_correction for more details.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        gamma (float): Non-negative real number. gamma larger than 1 makes the
            shadows darker, while gamma smaller than 1 makes dark regions
            lighter.
        gain (float): The constant multiplier.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    input_mode = img.mode
    img = img.convert('RGB')

    np_img = np.array(img, dtype=np.float32)
    np_img = 255 * gain * ((np_img / 255)**gamma)
    np_img = np.uint8(np.clip(np_img, 0, 255))

    img = Image.fromarray(np_img, 'RGB').convert(input_mode)
    return img

class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):
    """Convert a ``numpy.ndarray`` to tensor.

    Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
    """

    def __call__(self, img):
        """Convert a ``numpy.ndarray`` to tensor.

        Args:
            img (numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))

        if isinstance(img, np.ndarray):
            # handle numpy array
            if img.ndim == 3:
                img = torch.from_numpy(img.transpose((2, 0, 1)).copy())
            elif img.ndim == 2:
                img = torch.from_numpy(img.copy())
            else:
                raise RuntimeError(
                    'img should be ndarray with 2 or 3 dimensions. Got {}'.
                    format(img.ndim))

            return img

class NormalizeNumpyArray(object):
    """Normalize a ``numpy.ndarray`` with mean and standard deviation.

    Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``numpy.ndarray`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray): Image of size (H, W, C) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))
        # TODO: make efficient
        for i in range(3):
            img[:, :, i] = (img[:, :, i] - self.mean[i]) / self.std[i]
        return img


class NormalizeTensor(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        if not _is_tensor_image(tensor):
            raise TypeError('tensor is not a torch image.')
        # TODO: make efficient
        for t, m, s in zip(tensor, self.mean, self.std):
            t.sub_(m).div_(s)
        return tensor

class Rotate(object):
    """Rotates the given ``numpy.ndarray``.

    Args:
        angle (float): The rotation angle in degrees.
    """

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Image to be rotated.

        Returns:
            img (numpy.ndarray (C x H x W)): Rotated image.
        """

        # order=0 means nearest-neighbor type interpolation
        return skimage.transform.rotate(img, self.angle, resize=False, order=0)


class Resize(object):
    """Rescale the given ``numpy.ndarray`` by the given factor.

    Args:
        size (float): Desired rescaling factor.
        interpolation (str, optional): Desired interpolation. Default is
            ``'nearest'``.
    """

    def __init__(self, size, interpolation='nearest'):
        assert isinstance(size, float)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Image to be scaled.

        Returns:
            img (numpy.ndarray (C x H x W)): Rescaled image.
        """
        if img.ndim == 3 or img.ndim == 2:
            return skimage.transform.rescale(img, self.size, order=0)
        else:
            raise RuntimeError(
                'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
                    img.ndim))

class CenterCrop(object):
    """Crops the given ``numpy.ndarray`` at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of a sequence like (h, w), a square crop (size, size)
            is made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for center crop.

        Args:
            img (numpy.ndarray (C x H x W)): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
        """
        h = img.shape[0]
        w = img.shape[1]
        th, tw = output_size
        i = int(round((h - th) / 2.))
        j = int(round((w - tw) / 2.))

        # # randomized cropping
        # i = np.random.randint(i-3, i+4)
        # j = np.random.randint(j-3, j+4)

        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Image to be cropped.

        Returns:
            img (numpy.ndarray (C x H x W)): Cropped image.
        """
        i, j, h, w = self.get_params(img, self.size)
        # i: upper pixel coordinate; j: left pixel coordinate;
        # h: height of the cropped image; w: width of the cropped image.
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))
        if img.ndim == 3:
            return img[i:i + h, j:j + w, :]
        elif img.ndim == 2:
            return img[i:i + h, j:j + w]
        else:
            raise RuntimeError(
                'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
                    img.ndim))

class BottomCrop(object):
    """Crops the given ``numpy.ndarray`` at the bottom.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of a sequence like (h, w), a square crop (size, size)
            is made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for bottom crop.

        Args:
            img (numpy.ndarray (C x H x W)): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for bottom crop.
        """
        h = img.shape[0]
        w = img.shape[1]
        th, tw = output_size
        i = h - th
        j = int(round((w - tw) / 2.))

        # randomized left and right cropping
        # i = np.random.randint(i-3, i+4)
        # j = np.random.randint(j-1, j+1)

        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Image to be cropped.

        Returns:
            img (numpy.ndarray (C x H x W)): Cropped image.
        """
        i, j, h, w = self.get_params(img, self.size)
        # i: upper pixel coordinate; j: left pixel coordinate;
        # h: height of the cropped image; w: width of the cropped image.
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))
        if img.ndim == 3:
            return img[i:i + h, j:j + w, :]
        elif img.ndim == 2:
            return img[i:i + h, j:j + w]
        else:
            raise RuntimeError(
                'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
                    img.ndim))

class Crop(object):
    """Crops the given ``numpy.ndarray`` to the given rectangle.

    Args:
        crop (tuple): Crop rectangle given as (x_l, x_r, y_b, y_t) pixel
            coordinates.
    """

    def __init__(self, crop):
        self.crop = crop

    @staticmethod
    def get_params(img, crop):
        """Validate the crop rectangle against the image.

        Args:
            img (numpy.ndarray (C x H x W)): Image to be cropped.
            crop (tuple): Crop rectangle (x_l, x_r, y_b, y_t).

        Returns:
            tuple: params (x_l, x_r, y_b, y_t) to be used for cropping.
        """
        x_l, x_r, y_b, y_t = crop
        h = img.shape[0]
        w = img.shape[1]
        assert x_l >= 0 and x_l < w
        assert x_r >= 0 and x_r < w
        assert y_b >= 0 and y_b < h
        assert y_t >= 0 and y_t < h
        assert x_l < x_r and y_b < y_t

        return x_l, x_r, y_b, y_t

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Image to be cropped.

        Returns:
            img (numpy.ndarray (C x H x W)): Cropped image.
        """
        x_l, x_r, y_b, y_t = self.get_params(img, self.crop)
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))
        if img.ndim == 3:
            return img[y_b:y_t, x_l:x_r, :]
        elif img.ndim == 2:
            return img[y_b:y_t, x_l:x_r]
        else:
            raise RuntimeError(
                'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
                    img.ndim))

class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)


class HorizontalFlip(object):
    """Horizontally flip the given ``numpy.ndarray``.

    Args:
        do_flip (boolean): whether or not to do the horizontal flip.
    """

    def __init__(self, do_flip):
        self.do_flip = do_flip

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Image to be flipped.

        Returns:
            img (numpy.ndarray (C x H x W)): Flipped image.
        """
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))

        if self.do_flip:
            return np.fliplr(img)
        else:
            return img

class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >= 0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        transforms = []
        transforms.append(
            Lambda(lambda img: adjust_brightness(img, brightness)))
        transforms.append(Lambda(lambda img: adjust_contrast(img, contrast)))
        transforms.append(
            Lambda(lambda img: adjust_saturation(img, saturation)))
        transforms.append(Lambda(lambda img: adjust_hue(img, hue)))
        np.random.shuffle(transforms)
        self.transform = Compose(transforms)

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Input image.

        Returns:
            img (numpy.ndarray (C x H x W)): Color jittered image.
        """
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))

        pil = Image.fromarray(img)
        return np.array(self.transform(pil))
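Reviewer note: a short sketch of composing these numpy-based transforms the way train_transform does; the input shape and flip flag are illustrative.

import numpy as np

# Illustrative 375x1242 RGB frame, cropped to the KITTI 352x1216 window
# and flipped, mirroring the geometric pipeline in train_transform.
img = np.random.randint(0, 256, (375, 1242, 3), dtype=np.uint8)

pipeline = Compose([
    BottomCrop((352, 1216)),
    HorizontalFlip(do_flip=True),
])
out = pipeline(img)  # (352, 1216, 3) ndarray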
modelscope/models/cv/self_supervised_depth_completion/helper.py (new file, 269 lines)
@@ -0,0 +1,269 @@
import csv
import os
import shutil
import time

import torch

from modelscope.models.cv.self_supervised_depth_completion import vis_utils
from modelscope.models.cv.self_supervised_depth_completion.metrics import \
    Result

fieldnames = [
    'epoch', 'rmse', 'photo', 'mae', 'irmse', 'imae', 'mse', 'absrel', 'lg10',
    'silog', 'squared_rel', 'delta1', 'delta2', 'delta3', 'data_time',
    'gpu_time'
]


class logger:

    def __init__(self, args, prepare=True):
        self.args = args
        output_directory = get_folder_name(args)
        self.output_directory = output_directory
        self.best_result = Result()
        self.best_result.set_to_worst()

        if not prepare:
            return
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)
        self.train_csv = os.path.join(output_directory, 'train.csv')
        self.val_csv = os.path.join(output_directory, 'val.csv')
        self.best_txt = os.path.join(output_directory, 'best.txt')

        # backup the source code
        if args.resume == '':
            print('=> creating source code backup ...')
            backup_directory = os.path.join(output_directory, 'code_backup')
            self.backup_directory = backup_directory
            # backup_source_code(backup_directory)
            # create new csv files with only the header
            with open(self.train_csv, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
            with open(self.val_csv, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
            print('=> finished creating source code backup.')

    def conditional_print(self, split, i, epoch, lr, n_set, blk_avg_meter,
                          avg_meter):
        if (i + 1) % self.args.print_freq == 0:
            avg = avg_meter.average()
            blk_avg = blk_avg_meter.average()
            print('=> output: {}'.format(self.output_directory))
            print(
                '{split} Epoch: {0} [{1}/{2}]\tlr={lr} '
                't_Data={blk_avg.data_time:.3f}({average.data_time:.3f}) '
                't_GPU={blk_avg.gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                'RMSE={blk_avg.rmse:.2f}({average.rmse:.2f}) '
                'MAE={blk_avg.mae:.2f}({average.mae:.2f}) '
                'iRMSE={blk_avg.irmse:.2f}({average.irmse:.2f}) '
                'iMAE={blk_avg.imae:.2f}({average.imae:.2f})\n\t'
                'silog={blk_avg.silog:.2f}({average.silog:.2f}) '
                'squared_rel={blk_avg.squared_rel:.2f}({average.squared_rel:.2f}) '
                'Delta1={blk_avg.delta1:.3f}({average.delta1:.3f}) '
                'REL={blk_avg.absrel:.3f}({average.absrel:.3f})\n\t'
                'Lg10={blk_avg.lg10:.3f}({average.lg10:.3f}) '
                'Photometric={blk_avg.photometric:.3f}({average.photometric:.3f}) '
                .format(
                    epoch,
                    i + 1,
                    n_set,
                    lr=lr,
                    blk_avg=blk_avg,
                    average=avg,
                    split=split.capitalize()))
            blk_avg_meter.reset()

    def conditional_save_info(self, split, average_meter, epoch):
        avg = average_meter.average()
        if split == 'train':
            csvfile_name = self.train_csv
        elif split == 'val':
            csvfile_name = self.val_csv
        elif split == 'eval':
            eval_filename = os.path.join(self.output_directory, 'eval.txt')
            self.save_single_txt(eval_filename, avg, epoch)
            return avg
        elif 'test' in split:
            return avg
        else:
            raise ValueError('wrong split provided to logger')
        with open(csvfile_name, 'a') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow({
                'epoch': epoch,
                'rmse': avg.rmse,
                'photo': avg.photometric,
                'mae': avg.mae,
                'irmse': avg.irmse,
                'imae': avg.imae,
                'mse': avg.mse,
                'silog': avg.silog,
                'squared_rel': avg.squared_rel,
                'absrel': avg.absrel,
                'lg10': avg.lg10,
                'delta1': avg.delta1,
                'delta2': avg.delta2,
                'delta3': avg.delta3,
                'gpu_time': avg.gpu_time,
                'data_time': avg.data_time
            })
        return avg

    def save_single_txt(self, filename, result, epoch):
        with open(filename, 'w') as txtfile:
            txtfile.write(
                ('rank_metric={}\n' + 'epoch={}\n' + 'rmse={:.3f}\n'
                 + 'mae={:.3f}\n' + 'silog={:.3f}\n' + 'squared_rel={:.3f}\n'
                 + 'irmse={:.3f}\n' + 'imae={:.3f}\n' + 'mse={:.3f}\n'
                 + 'absrel={:.3f}\n' + 'lg10={:.3f}\n'
                 + 'delta1={:.3f}\n' + 't_gpu={:.4f}').format(
                     self.args.rank_metric, epoch, result.rmse, result.mae,
                     result.silog, result.squared_rel, result.irmse,
                     result.imae, result.mse, result.absrel, result.lg10,
                     result.delta1, result.gpu_time))

    def save_best_txt(self, result, epoch):
        self.save_single_txt(self.best_txt, result, epoch)

    def _get_img_comparison_name(self, mode, epoch, is_best=False):
        if mode == 'eval':
            return self.output_directory + '/comparison_eval.png'
        if mode == 'val':
            if is_best:
                return self.output_directory + '/comparison_best.png'
            else:
                return self.output_directory + '/comparison_' + str(
                    epoch) + '.png'

    def conditional_save_img_comparison(self, mode, i, ele, pred, epoch):
        # save 8 images for visualization
        if mode == 'val' or mode == 'eval':
            skip = 100
            if i == 0:
                self.img_merge = vis_utils.merge_into_row(ele, pred)
            elif i % skip == 0 and i < 8 * skip:
                row = vis_utils.merge_into_row(ele, pred)
                self.img_merge = vis_utils.add_row(self.img_merge, row)
            elif i == 8 * skip:
                filename = self._get_img_comparison_name(mode, epoch)
                vis_utils.save_image(self.img_merge, filename)
        return self.img_merge

    def save_img_comparison_as_best(self, mode, epoch):
        if mode == 'val':
            filename = self._get_img_comparison_name(mode, epoch, is_best=True)
            vis_utils.save_image(self.img_merge, filename)

    def get_ranking_error(self, result):
        return getattr(result, self.args.rank_metric)

    def rank_conditional_save_best(self, mode, result, epoch):
        error = self.get_ranking_error(result)
        best_error = self.get_ranking_error(self.best_result)
        is_best = error < best_error
        if is_best and mode == 'val':
            self.old_best_result = self.best_result
            self.best_result = result
            self.save_best_txt(result, epoch)
        return is_best

    def conditional_save_pred(self, mode, i, pred, epoch):
        if ('test' in mode or mode == 'eval') and self.args.save_pred:

            # save images for visualization/testing
            image_folder = os.path.join(self.output_directory,
                                        mode + '_output')
            if not os.path.exists(image_folder):
                os.makedirs(image_folder)
            img = torch.squeeze(pred.data.cpu()).numpy()
            filename = os.path.join(image_folder, '{0:010d}.png'.format(i))
            vis_utils.save_depth_as_uint16png(img, filename)

    def conditional_summarize(self, mode, avg, is_best):
        print('\n*\nSummary of ', mode, 'round')
        print(''
              'RMSE={average.rmse:.3f}\n'
              'MAE={average.mae:.3f}\n'
              'Photo={average.photometric:.3f}\n'
              'iRMSE={average.irmse:.3f}\n'
              'iMAE={average.imae:.3f}\n'
              'squared_rel={average.squared_rel}\n'
              'silog={average.silog}\n'
              'Delta1={average.delta1:.3f}\n'
              'REL={average.absrel:.3f}\n'
              'Lg10={average.lg10:.3f}\n'
              't_GPU={time:.3f}'.format(average=avg, time=avg.gpu_time))
        if is_best and mode == 'val':
            print('New best model by %s (was %.3f)' %
                  (self.args.rank_metric,
                   self.get_ranking_error(self.old_best_result)))
        elif mode == 'val':
            print('(best %s is %.3f)' %
                  (self.args.rank_metric,
                   self.get_ranking_error(self.best_result)))
        print('*\n')

ignore_hidden = shutil.ignore_patterns('.', '..', '.git*', '*pycache*',
|
||||||
|
'*build', '*.fuse*', '*_drive_*')
|
||||||
|
|
||||||
|
|
||||||
|
def backup_source_code(backup_directory):
|
||||||
|
if os.path.exists(backup_directory):
|
||||||
|
shutil.rmtree(backup_directory)
|
||||||
|
shutil.copytree('.', backup_directory, ignore=ignore_hidden)
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_learning_rate(lr_init, optimizer, epoch):
|
||||||
|
"""Sets the learning rate to the initial LR decayed by 10 every 5 epochs"""
|
||||||
|
lr = lr_init * (0.1**(epoch // 5))
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
param_group['lr'] = lr
|
||||||
|
return lr
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(state, is_best, epoch, output_directory):
|
||||||
|
checkpoint_filename = os.path.join(output_directory,
|
||||||
|
'checkpoint-' + str(epoch) + '.pth.tar')
|
||||||
|
torch.save(state, checkpoint_filename)
|
||||||
|
if is_best:
|
||||||
|
best_filename = os.path.join(output_directory, 'model_best.pth.tar')
|
||||||
|
shutil.copyfile(checkpoint_filename, best_filename)
|
||||||
|
if epoch > 0:
|
||||||
|
prev_checkpoint_filename = os.path.join(
|
||||||
|
output_directory, 'checkpoint-' + str(epoch - 1) + '.pth.tar')
|
||||||
|
if os.path.exists(prev_checkpoint_filename):
|
||||||
|
os.remove(prev_checkpoint_filename)
|
||||||
|
|
||||||
|
|
||||||
|
def get_folder_name(args):
    # current_time = time.strftime('%Y-%m-%d@%H-%M')
    # if args.use_pose:
    #     prefix = 'mode={}.w1={}.w2={}.'.format(args.train_mode, args.w1,
    #                                            args.w2)
    # else:
    #     prefix = 'mode={}.'.format(args.train_mode)
    # return os.path.join(
    #     args.result,
    #     prefix + 'input={}.resnet{}.criterion={}.lr={}.bs={}.wd={}.'
    #     'pretrained={}.jitter={}.time={}'.format(
    #         args.input, args.layers, args.criterion, args.lr,
    #         args.batch_size, args.weight_decay, args.pretrained,
    #         args.jitter, current_time))
    return os.path.join(args.result, 'test')
avgpool = torch.nn.AvgPool2d(kernel_size=2, stride=2).cuda()


def multiscale(img):
    img1 = avgpool(img)
    img2 = avgpool(img1)
    img3 = avgpool(img2)
    img4 = avgpool(img3)
    img5 = avgpool(img4)
    return img5, img4, img3, img2, img1
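Each `avgpool` call halves the spatial resolution, so `multiscale` returns a five-level pyramid ordered coarsest to finest (1/32 down to 1/2 of the input). A shape sketch, assuming a CUDA device since `avgpool` above is built with `.cuda()`:

```python
import torch

img = torch.randn(1, 3, 352, 1216, device='cuda')  # illustrative KITTI-sized crop
for level in multiscale(img):
    print(tuple(level.shape))
# (1, 3, 11, 38), (1, 3, 22, 76), (1, 3, 44, 152), (1, 3, 88, 304), (1, 3, 176, 608)
```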
141
modelscope/models/cv/self_supervised_depth_completion/inverse_warp.py
Normal file
@@ -0,0 +1,141 @@
import torch
import torch.nn.functional as F

from modelscope.utils.logger import get_logger

logger = get_logger()
class Intrinsics:
    """Intrinsics"""

    def __init__(self, width, height, fu, fv, cu=0, cv=0):
        self.height, self.width = height, width
        self.fu, self.fv = fu, fv  # fu, fv: focal length along the horizontal and vertical axes

        # cu, cv: optical center along the horizontal and vertical axes
        self.cu = cu if cu > 0 else (width - 1) / 2.0
        self.cv = cv if cv > 0 else (height - 1) / 2.0

        # U, V represent the homogeneous horizontal and vertical coordinates in the pixel space
        self.U = torch.arange(start=0, end=width).expand(height, width).float()
        self.V = torch.arange(
            start=0, end=height).expand(width, height).t().float()

        # X_cam, Y_cam represent the homogeneous x, y coordinates (assuming depth z=1) in the camera coordinate system
        self.X_cam = (self.U - self.cu) / self.fu
        self.Y_cam = (self.V - self.cv) / self.fv

        self.is_cuda = False

    def cuda(self):
        self.X_cam.data = self.X_cam.data.cuda()
        self.Y_cam.data = self.Y_cam.data.cuda()
        self.is_cuda = True
        return self

    def scale(self, height, width):
        # return a new set of corresponding intrinsic parameters for the scaled image
        ratio_u = float(width) / self.width
        ratio_v = float(height) / self.height
        fu = ratio_u * self.fu
        fv = ratio_v * self.fv
        cu = ratio_u * self.cu
        cv = ratio_v * self.cv
        new_intrinsics = Intrinsics(width, height, fu, fv, cu, cv)
        if self.is_cuda:
            new_intrinsics.cuda()
        return new_intrinsics

    def __print__(self):
        logger.info(
            'size=({},{})\nfocal length=({},{})\noptical center=({},{})'.
            format(self.height, self.width, self.fv, self.fu, self.cv,
                   self.cu))
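A small usage sketch: `scale` keeps the projection consistent when the image is resized, since focal lengths and the optical center scale linearly with resolution (the numbers below are illustrative, not from a real KITTI calibration file):

```python
intr = Intrinsics(width=1216, height=352, fu=721.5, fv=721.5, cu=609.6, cv=172.9)
half = intr.scale(height=176, width=608)
print(half.fu, half.cu)  # 360.75 304.8 -- both halved along with the resolution
```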
def image_to_pointcloud(depth, intrinsics):
    assert depth.dim() == 4
    assert depth.size(1) == 1

    X = depth * intrinsics.X_cam
    Y = depth * intrinsics.Y_cam
    return torch.cat((X, Y, depth), dim=1)
def pointcloud_to_image(pointcloud, intrinsics):
    assert pointcloud.dim() == 4

    batch_size = pointcloud.size(0)
    X = pointcloud[:, 0, :, :]  # .view(batch_size, -1)
    Y = pointcloud[:, 1, :, :]  # .view(batch_size, -1)
    Z = pointcloud[:, 2, :, :].clamp(min=1e-3)  # .view(batch_size, -1)

    # compute pixel coordinates
    U_proj = intrinsics.fu * X / Z + intrinsics.cu  # horizontal pixel coordinate
    V_proj = intrinsics.fv * Y / Z + intrinsics.cv  # vertical pixel coordinate

    # normalization to [-1, 1], required by torch.nn.functional.grid_sample
    w = intrinsics.width
    h = intrinsics.height
    U_proj_normalized = (2 * U_proj / (w - 1) - 1).view(batch_size, -1)
    V_proj_normalized = (2 * V_proj / (h - 1) - 1).view(batch_size, -1)

    # This was important since PyTorch didn't do as it claimed for points out of boundary
    # See https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py
    # Might not be necessary any more
    U_proj_mask = ((U_proj_normalized > 1) + (U_proj_normalized < -1)).detach()
    U_proj_normalized[U_proj_mask] = 2
    V_proj_mask = ((V_proj_normalized > 1) + (V_proj_normalized < -1)).detach()
    V_proj_normalized[V_proj_mask] = 2

    pixel_coords = torch.stack([U_proj_normalized, V_proj_normalized],
                               dim=2)  # [B, H*W, 2]
    return pixel_coords.view(batch_size, intrinsics.height, intrinsics.width,
                             2)
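The normalization step above follows the `grid_sample` convention: sampling locations live in [-1, 1], where -1 and +1 are the image borders, so pixel `u` in `[0, w-1]` maps to `2u/(w-1) - 1`; points pushed to 2 then fall out of bounds and sample to zero under the default padding. A quick check of the mapping:

```python
# u = 0 -> -1.0 (left border); u = (w - 1) / 2 -> 0.0; u = w - 1 -> 1.0 (right border)
w = 608
for u in (0, (w - 1) / 2, w - 1):
    print(2 * u / (w - 1) - 1)
```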
def batch_multiply(batch_scalar, batch_matrix):
    # input: batch_scalar of size b, batch_matrix of size b * 3 * 3
    # output: batch_matrix of size b * 3 * 3
    batch_size = batch_scalar.size(0)
    output = batch_matrix.clone()
    for i in range(batch_size):
        output[i] = batch_scalar[i] * batch_matrix[i]
    return output
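Since the loop just rescales each 3x3 matrix by its scalar, a behavior-identical vectorized form (a sketch, assuming a 1-D `batch_scalar`) is:

```python
def batch_multiply_vectorized(batch_scalar, batch_matrix):
    # broadcast (b,) against (b, 3, 3) without the Python-level loop
    return batch_scalar.view(-1, 1, 1) * batch_matrix
```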
def transform_curr_to_near(pointcloud_curr, r_mat, t_vec, intrinsics):
    # translation and rotmat represent the transformation from tgt pose to src pose
    batch_size = pointcloud_curr.size(0)
    XYZ_ = torch.bmm(r_mat, pointcloud_curr.view(batch_size, 3, -1))

    X = (XYZ_[:, 0, :] + t_vec[:, 0].unsqueeze(1)).view(
        -1, 1, intrinsics.height, intrinsics.width)
    Y = (XYZ_[:, 1, :] + t_vec[:, 1].unsqueeze(1)).view(
        -1, 1, intrinsics.height, intrinsics.width)
    Z = (XYZ_[:, 2, :] + t_vec[:, 2].unsqueeze(1)).view(
        -1, 1, intrinsics.height, intrinsics.width)

    pointcloud_near = torch.cat((X, Y, Z), dim=1)

    return pointcloud_near
def homography_from(rgb_near, depth_curr, r_mat, t_vec, intrinsics):
    # inverse warp the RGB image from the nearby frame to the current frame

    # to ensure dimension consistency
    r_mat = r_mat.view(-1, 3, 3)
    t_vec = t_vec.view(-1, 3)

    # compute source pixel coordinates
    pointcloud_curr = image_to_pointcloud(depth_curr, intrinsics)
    pointcloud_near = transform_curr_to_near(pointcloud_curr, r_mat, t_vec,
                                             intrinsics)
    pixel_coords_near = pointcloud_to_image(pointcloud_near, intrinsics)

    # the warping
    warped = F.grid_sample(rgb_near, pixel_coords_near)

    return warped
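A minimal shape-level sketch of the warping entry point (illustrative tensors only; in the actual training loop `r_mat`/`t_vec` come from a pose estimate and the warped image feeds a photometric loss against the current frame):

```python
import torch

intr = Intrinsics(width=608, height=176, fu=360.75, fv=360.75, cu=304.8, cv=88.0)
rgb_near = torch.rand(1, 3, 176, 608)          # nearby RGB frame
depth_curr = torch.rand(1, 1, 176, 608) + 1.0  # predicted depth, strictly > 0
r_mat = torch.eye(3).unsqueeze(0)              # identity pose -> near-identity warp
t_vec = torch.zeros(1, 3)
warped = homography_from(rgb_near, depth_curr, r_mat, t_vec, intr)
print(warped.shape)  # torch.Size([1, 3, 176, 608])
```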
181
modelscope/models/cv/self_supervised_depth_completion/metrics.py
Normal file
@@ -0,0 +1,181 @@
import math

import numpy as np
import torch

lg_e_10 = math.log(10)


def log10(x):
    """Return a new tensor with the base-10 logarithm of the elements of x."""
    return torch.log(x) / lg_e_10
class Result(object):
    """Result"""

    def __init__(self):
        self.irmse = 0
        self.imae = 0
        self.mse = 0
        self.rmse = 0
        self.mae = 0
        self.absrel = 0
        self.squared_rel = 0
        self.lg10 = 0
        self.delta1 = 0
        self.delta2 = 0
        self.delta3 = 0
        self.data_time = 0
        self.gpu_time = 0
        self.silog = 0  # Scale invariant logarithmic error [log(m)*100]
        self.photometric = 0

    def set_to_worst(self):
        self.irmse = np.inf
        self.imae = np.inf
        self.mse = np.inf
        self.rmse = np.inf
        self.mae = np.inf
        self.absrel = np.inf
        self.squared_rel = np.inf
        self.lg10 = np.inf
        self.silog = np.inf
        self.delta1 = 0
        self.delta2 = 0
        self.delta3 = 0
        self.data_time = 0
        self.gpu_time = 0

    def update(self,
               irmse,
               imae,
               mse,
               rmse,
               mae,
               absrel,
               squared_rel,
               lg10,
               delta1,
               delta2,
               delta3,
               gpu_time,
               data_time,
               silog,
               photometric=0):
        """update"""
        self.irmse = irmse
        self.imae = imae
        self.mse = mse
        self.rmse = rmse
        self.mae = mae
        self.absrel = absrel
        self.squared_rel = squared_rel
        self.lg10 = lg10
        self.delta1 = delta1
        self.delta2 = delta2
        self.delta3 = delta3
        self.data_time = data_time
        self.gpu_time = gpu_time
        self.silog = silog
        self.photometric = photometric
    def evaluate(self, output, target, photometric=0):
        """evaluate"""
        valid_mask = target > 0.1

        # convert from meters to mm
        output_mm = 1e3 * output[valid_mask]
        target_mm = 1e3 * target[valid_mask]

        abs_diff = (output_mm - target_mm).abs()

        self.mse = float((torch.pow(abs_diff, 2)).mean())
        self.rmse = math.sqrt(self.mse)
        self.mae = float(abs_diff.mean())
        self.lg10 = float((log10(output_mm) - log10(target_mm)).abs().mean())
        self.absrel = float((abs_diff / target_mm).mean())
        self.squared_rel = float(((abs_diff / target_mm)**2).mean())

        maxRatio = torch.max(output_mm / target_mm, target_mm / output_mm)
        self.delta1 = float((maxRatio < 1.25).float().mean())
        self.delta2 = float((maxRatio < 1.25**2).float().mean())
        self.delta3 = float((maxRatio < 1.25**3).float().mean())
        self.data_time = 0
        self.gpu_time = 0

        # silog uses meters
        err_log = torch.log(target[valid_mask]) - torch.log(output[valid_mask])
        normalized_squared_log = (err_log**2).mean()
        log_mean = err_log.mean()
        self.silog = math.sqrt(normalized_squared_log
                               - log_mean * log_mean) * 100

        # convert from meters to km
        inv_output_km = (1e-3 * output[valid_mask])**(-1)
        inv_target_km = (1e-3 * target[valid_mask])**(-1)
        abs_inv_diff = (inv_output_km - inv_target_km).abs()
        self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean())
        self.imae = float(abs_inv_diff.mean())

        self.photometric = float(photometric)
class AverageMeter(object):
    """AverageMeter"""

    def __init__(self):
        self.reset()

    def reset(self):
        """reset"""
        self.count = 0.0
        self.sum_irmse = 0
        self.sum_imae = 0
        self.sum_mse = 0
        self.sum_rmse = 0
        self.sum_mae = 0
        self.sum_absrel = 0
        self.sum_squared_rel = 0
        self.sum_lg10 = 0
        self.sum_delta1 = 0
        self.sum_delta2 = 0
        self.sum_delta3 = 0
        self.sum_data_time = 0
        self.sum_gpu_time = 0
        self.sum_photometric = 0
        self.sum_silog = 0

    def update(self, result, gpu_time, data_time, n=1):
        """update"""
        self.count += n
        self.sum_irmse += n * result.irmse
        self.sum_imae += n * result.imae
        self.sum_mse += n * result.mse
        self.sum_rmse += n * result.rmse
        self.sum_mae += n * result.mae
        self.sum_absrel += n * result.absrel
        self.sum_squared_rel += n * result.squared_rel
        self.sum_lg10 += n * result.lg10
        self.sum_delta1 += n * result.delta1
        self.sum_delta2 += n * result.delta2
        self.sum_delta3 += n * result.delta3
        self.sum_data_time += n * data_time
        self.sum_gpu_time += n * gpu_time
        self.sum_silog += n * result.silog
        self.sum_photometric += n * result.photometric

    def average(self):
        """average"""
        avg = Result()
        if self.count > 0:
            avg.update(
                self.sum_irmse / self.count, self.sum_imae / self.count,
                self.sum_mse / self.count, self.sum_rmse / self.count,
                self.sum_mae / self.count, self.sum_absrel / self.count,
                self.sum_squared_rel / self.count, self.sum_lg10 / self.count,
                self.sum_delta1 / self.count, self.sum_delta2 / self.count,
                self.sum_delta3 / self.count, self.sum_gpu_time / self.count,
                self.sum_data_time / self.count, self.sum_silog / self.count,
                self.sum_photometric / self.count)
        return avg
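An illustrative end-to-end use of the two classes above, with dummy depths in meters (zeros are treated as missing ground truth by the `> 0.1` validity mask):

```python
import torch

target = torch.tensor([[[[2.0, 0.0], [4.0, 8.0]]]])  # 0.0 = no ground truth
output = torch.tensor([[[[2.2, 1.0], [3.6, 8.0]]]])

result = Result()
result.evaluate(output, target)  # only the three valid pixels contribute
meter = AverageMeter()
meter.update(result, gpu_time=0.01, data_time=0.0, n=1)
avg = meter.average()
print(round(avg.rmse, 1), round(avg.delta1, 2))  # 258.2 1.0 (RMSE is in mm)
```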
215
modelscope/models/cv/self_supervised_depth_completion/model.py
Normal file
@@ -0,0 +1,215 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet


def init_weights(m):
    """init_weights"""
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 1e-3)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.ConvTranspose2d):
        m.weight.data.normal_(0, 1e-3)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()


def conv_bn_relu(in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bn=True,
                 relu=True):
    """conv_bn_relu"""
    bias = not bn
    layers = []
    layers.append(
        nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding,
            bias=bias))
    if bn:
        layers.append(nn.BatchNorm2d(out_channels))
    if relu:
        layers.append(nn.LeakyReLU(0.2, inplace=True))
    layers = nn.Sequential(*layers)

    # initialize the weights
    for m in layers.modules():
        init_weights(m)

    return layers


def convt_bn_relu(in_channels,
                  out_channels,
                  kernel_size,
                  stride=1,
                  padding=0,
                  output_padding=0,
                  bn=True,
                  relu=True):
    """convt_bn_relu"""
    bias = not bn
    layers = []
    layers.append(
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            bias=bias))
    if bn:
        layers.append(nn.BatchNorm2d(out_channels))
    if relu:
        layers.append(nn.LeakyReLU(0.2, inplace=True))
    layers = nn.Sequential(*layers)

    # initialize the weights
    for m in layers.modules():
        init_weights(m)

    return layers
class DepthCompletionNet(nn.Module):
    """DepthCompletionNet"""

    def __init__(self, args):
        assert (
            args.layers in [18, 34, 50, 101, 152]
        ), 'Only layers 18, 34, 50, 101, and 152 are defined, but got {}'.format(
            args.layers)
        super(DepthCompletionNet, self).__init__()
        self.modality = args.input
        if 'd' in self.modality:
            channels = 64 // len(self.modality)
            self.conv1_d = conv_bn_relu(
                1, channels, kernel_size=3, stride=1, padding=1)
        if 'rgb' in self.modality:
            channels = 64 * 3 // len(self.modality)
            self.conv1_img = conv_bn_relu(
                3, channels, kernel_size=3, stride=1, padding=1)
        elif 'g' in self.modality:
            channels = 64 // len(self.modality)
            self.conv1_img = conv_bn_relu(
                1, channels, kernel_size=3, stride=1, padding=1)

        pretrained_model = resnet.__dict__['resnet{}'.format(args.layers)](
            pretrained=args.pretrained)
        if not args.pretrained:
            pretrained_model.apply(init_weights)
        # self.maxpool = pretrained_model._modules['maxpool']
        self.conv2 = pretrained_model._modules['layer1']
        self.conv3 = pretrained_model._modules['layer2']
        self.conv4 = pretrained_model._modules['layer3']
        self.conv5 = pretrained_model._modules['layer4']
        del pretrained_model  # clear memory

        # define number of intermediate channels
        if args.layers <= 34:
            num_channels = 512
        elif args.layers >= 50:
            num_channels = 2048
        self.conv6 = conv_bn_relu(
            num_channels, 512, kernel_size=3, stride=2, padding=1)
        # decoding layers
        kernel_size = 3
        stride = 2
        self.convt5 = convt_bn_relu(
            in_channels=512,
            out_channels=256,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            output_padding=1)
        self.convt4 = convt_bn_relu(
            in_channels=768,
            out_channels=128,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            output_padding=1)
        self.convt3 = convt_bn_relu(
            in_channels=(256 + 128),
            out_channels=64,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            output_padding=1)
        self.convt2 = convt_bn_relu(
            in_channels=(128 + 64),
            out_channels=64,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            output_padding=1)
        self.convt1 = convt_bn_relu(
            in_channels=128,
            out_channels=64,
            kernel_size=kernel_size,
            stride=1,
            padding=1)
        self.convtf = conv_bn_relu(
            in_channels=128,
            out_channels=1,
            kernel_size=1,
            stride=1,
            bn=False,
            relu=False)
    def forward(self, x):
        """forward"""
        # first layer
        if 'd' in self.modality:
            conv1_d = self.conv1_d(x['d'])
        if 'rgb' in self.modality:
            conv1_img = self.conv1_img(x['rgb'])
        elif 'g' in self.modality:
            conv1_img = self.conv1_img(x['g'])

        if self.modality == 'rgbd' or self.modality == 'gd':
            conv1 = torch.cat((conv1_d, conv1_img), 1)
        else:
            conv1 = conv1_d if (self.modality == 'd') else conv1_img

        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)  # batchsize * ? * 176 * 608
        conv4 = self.conv4(conv3)  # batchsize * ? * 88 * 304
        conv5 = self.conv5(conv4)  # batchsize * ? * 44 * 152
        conv6 = self.conv6(conv5)  # batchsize * ? * 22 * 76

        # decoder
        convt5 = self.convt5(conv6)
        y = torch.cat((convt5, conv5), 1)

        convt4 = self.convt4(y)
        y = torch.cat((convt4, conv4), 1)

        convt3 = self.convt3(y)
        y = torch.cat((convt3, conv3), 1)

        convt2 = self.convt2(y)
        y = torch.cat((convt2, conv2), 1)

        convt1 = self.convt1(y)
        y = torch.cat((convt1, conv1), 1)

        y = self.convtf(y)

        if self.training:
            return 100 * y
        else:
            # the minimum range of Velodyne is around 3 feet ~= 0.9m
            min_distance = 0.9
            return F.relu(100 * y - min_distance) + min_distance
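A forward-pass shape sketch for the network above (illustrative: `args` is a stand-in namespace carrying just the fields the constructor reads, and the input dict matches the `'gd'` modality that the default `ArgsList` below selects):

```python
import types

import torch

args = types.SimpleNamespace(layers=18, input='gd', pretrained=False)
net = DepthCompletionNet(args).eval()
batch = {
    'g': torch.rand(1, 1, 352, 1216),  # grayscale image
    'd': torch.rand(1, 1, 352, 1216),  # sparse depth
}
with torch.no_grad():
    pred = net(batch)
print(pred.shape)                # torch.Size([1, 1, 352, 1216])
print(float(pred.min()) >= 0.9)  # True: eval output is clamped to >= 0.9 m
```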
225
modelscope/models/cv/self_supervised_depth_completion/self_supervised_depth_completion.py
Normal file
@@ -0,0 +1,225 @@
# import argparse
import os
import sys
import time
# import mmcv
from argparse import ArgumentParser
# import torchvision
from os import makedirs

import cv2
import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
from tqdm import tqdm

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.self_supervised_depth_completion import (criteria,
                                                                   helper)
from modelscope.models.cv.self_supervised_depth_completion.dataloaders.kitti_loader import (
    KittiDepth, input_options, load_calib, oheight, owidth)
from modelscope.models.cv.self_supervised_depth_completion.inverse_warp import (
    Intrinsics, homography_from)
from modelscope.models.cv.self_supervised_depth_completion.metrics import (
    AverageMeter, Result)
from modelscope.models.cv.self_supervised_depth_completion.model import \
    DepthCompletionNet
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# from modelscope.utils.config import Config

m_logger = get_logger()
class ArgsList:
    """ArgsList Class"""

    def __init__(self) -> None:
        self.workers = 4
        self.epochs = 11
        self.start_epoch = 0
        self.criterion = 'l2'
        self.batch_size = 1
        self.learning_rate = 1e-5
        self.weight_decay = 0
        self.print_freq = 10
        self.resume = ''
        self.data_folder = '../data'
        self.input = 'gd'
        self.layers = 34
        self.pretrained = True
        self.val = 'select'
        self.jitter = 0.1
        self.rank_metric = 'rmse'
        self.evaluate = ''
        self.cpu = False
@MODELS.register_module(
    Tasks.self_supervised_depth_completion,
    module_name=Models.self_supervised_depth_completion)
class SelfSupervisedDepthCompletion(TorchModel):
    """SelfSupervisedDepthCompletion Class"""

    def __init__(self, model_dir: str, **kwargs):
        """model_dir (str): the model file root."""
        super().__init__(model_dir, **kwargs)

        args = ArgsList()
        # define loss functions
        self.depth_criterion = criteria.MaskedMSELoss()
        self.photometric_criterion = criteria.PhotometricLoss()
        self.smoothness_criterion = criteria.SmoothnessLoss()

        # args.use_pose = ('photo' in args.train_mode)
        args.use_pose = True
        # args.pretrained = not args.no_pretrained
        args.use_rgb = ('rgb' in args.input) or args.use_pose
        args.use_d = 'd' in args.input
        args.use_g = 'g' in args.input

        args.evaluate = os.path.join(self.model_dir, 'model_best.pth')

        if args.use_pose:
            args.w1, args.w2 = 0.1, 0.1
        else:
            args.w1, args.w2 = 0, 0

        self.cuda = torch.cuda.is_available() and not args.cpu
        if self.cuda:
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        print("=> using '{}' for computation.".format(self.device))

        args_new = args
        self.checkpoint = None  # keep later `is not None` checks safe
        if os.path.isfile(args.evaluate):
            print(
                "=> loading checkpoint '{}' ... ".format(args.evaluate),
                end='')
            self.checkpoint = torch.load(
                args.evaluate, map_location=self.device)
            args = self.checkpoint['args']
            args.val = args_new.val
            print('Completed.')
        else:
            print("No model found at '{}'".format(args.evaluate))
            return

        print('=> creating model and optimizer ... ', end='')
        model = DepthCompletionNet(args).to(self.device)
        model_named_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        optimizer = torch.optim.Adam(
            model_named_params, lr=args.lr, weight_decay=args.weight_decay)
        print('completed.')
        if self.checkpoint is not None:
            model.load_state_dict(self.checkpoint['model'])
            optimizer.load_state_dict(self.checkpoint['optimizer'])
            print('=> checkpoint state loaded.')

        model = torch.nn.DataParallel(model)

        self.model = model
        self.args = args
    def iterate(self, mode, args, loader, model, optimizer, logger, epoch):
        """iterate data"""
        block_average_meter = AverageMeter()
        average_meter = AverageMeter()
        meters = [block_average_meter, average_meter]
        merged_img = None
        # switch to appropriate mode
        assert mode in ['train', 'val', 'eval', 'test_prediction', 'test_completion'], \
            'unsupported mode: {}'.format(mode)
        model.eval()
        lr = 0

        for i, batch_data in enumerate(loader):
            start = time.time()
            batch_data = {
                key: val.to(self.device)
                for key, val in batch_data.items() if val is not None
            }
            gt = batch_data[
                'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None
            data_time = time.time() - start

            start = time.time()
            pred = model(batch_data)
            photometric_loss = 0
            gpu_time = time.time() - start

            # measure accuracy and record loss
            with torch.no_grad():
                mini_batch_size = next(iter(batch_data.values())).size(0)
                result = Result()
                if mode != 'test_prediction' and mode != 'test_completion':
                    result.evaluate(pred.data, gt.data, photometric_loss)
                [
                    m.update(result, gpu_time, data_time, mini_batch_size)
                    for m in meters
                ]
                logger.conditional_print(mode, i, epoch, lr, len(loader),
                                         block_average_meter, average_meter)
                merged_img = logger.conditional_save_img_comparison(
                    mode, i, batch_data, pred, epoch)
                merged_img = cv2.cvtColor(merged_img, cv2.COLOR_RGB2BGR)
                logger.conditional_save_pred(mode, i, pred, epoch)

        avg = logger.conditional_save_info(mode, average_meter, epoch)
        is_best = logger.rank_conditional_save_best(mode, avg, epoch)
        logger.save_img_comparison_as_best(mode, epoch)
        logger.conditional_summarize(mode, avg, is_best)

        return avg, is_best, merged_img
    def forward(self, source_dir):
        """main function"""

        args = self.args
        args.data_folder = source_dir
        args.result = os.path.join(args.data_folder, 'results')
        if args.use_pose:
            # hard-coded KITTI camera intrinsics
            K = load_calib(args)
            fu, fv = float(K[0, 0]), float(K[1, 1])
            cu, cv = float(K[0, 2]), float(K[1, 2])
            kitti_intrinsics = Intrinsics(owidth, oheight, fu, fv, cu, cv)
            if self.cuda:
                kitti_intrinsics = kitti_intrinsics.cuda()

        # Data loading code
        print('=> creating data loaders ... ')
        val_dataset = KittiDepth('val', self.args)
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=2,
            pin_memory=True)  # set batch size to be 1 for validation
        print('\t==> val_loader size:{}'.format(len(val_loader)))

        # create backups and results folder
        logger = helper.logger(self.args)
        if self.checkpoint is not None:
            logger.best_result = self.checkpoint['best_result']

        print('=> starting model evaluation ...')
        result, is_best, merged_img = self.iterate('val', self.args,
                                                   val_loader, self.model,
                                                   None, logger,
                                                   self.checkpoint['epoch'])
        return merged_img
119
modelscope/models/cv/self_supervised_depth_completion/vis_utils.py
Normal file
@@ -0,0 +1,119 @@
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

if not ('DISPLAY' in os.environ):
    import matplotlib as mpl
    mpl.use('Agg')

cmap = plt.cm.jet


def depth_colorize(depth):
    depth = (depth - np.min(depth)) / (np.max(depth) - np.min(depth))
    depth = 255 * cmap(depth)[:, :, :3]  # H, W, C
    return depth.astype('uint8')
def merge_into_row(ele, pred):

    def preprocess_depth(x):
        y = np.squeeze(x.data.cpu().numpy())
        return depth_colorize(y)

    # if gray, transform to rgb
    img_list = []
    if 'rgb' in ele:
        rgb = np.squeeze(ele['rgb'][0, ...].data.cpu().numpy())
        rgb = np.transpose(rgb, (1, 2, 0))
        img_list.append(rgb)
    elif 'g' in ele:
        g = np.squeeze(ele['g'][0, ...].data.cpu().numpy())
        g = np.array(Image.fromarray(g).convert('RGB'))
        img_list.append(g)
    if 'd' in ele:
        img_list.append(preprocess_depth(ele['d'][0, ...]))
    img_list.append(preprocess_depth(pred[0, ...]))
    if 'gt' in ele:
        img_list.append(preprocess_depth(ele['gt'][0, ...]))

    img_merge = np.hstack(img_list)
    return img_merge.astype('uint8')


def add_row(img_merge, row):
    return np.vstack([img_merge, row])


def save_image(img_merge, filename):
    image_to_write = cv2.cvtColor(img_merge, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, image_to_write)
def save_depth_as_uint16png(img, filename):
    img = (img * 256).astype('uint16')
    cv2.imwrite(filename, img)


if ('DISPLAY' in os.environ):
    f, axarr = plt.subplots(4, 1)
    plt.tight_layout()
    plt.ion()
def display_warping(rgb_tgt, pred_tgt, warped):

    def preprocess(rgb_tgt, pred_tgt, warped):
        rgb_tgt = 255 * np.transpose(
            np.squeeze(rgb_tgt.data.cpu().numpy()), (1, 2, 0))  # H, W, C
        # depth = np.squeeze(depth.cpu().numpy())
        # depth = depth_colorize(depth)

        # convert to log-scale
        pred_tgt = np.squeeze(pred_tgt.data.cpu().numpy())
        # pred_tgt[pred_tgt <= 0] = 0.9  # remove negative predictions
        # pred_tgt = np.log10(pred_tgt)

        pred_tgt = depth_colorize(pred_tgt)

        warped = 255 * np.transpose(
            np.squeeze(warped.data.cpu().numpy()), (1, 2, 0))  # H, W, C
        recon_err = np.absolute(
            warped.astype('float') - rgb_tgt.astype('float')) * (
                warped > 0)
        recon_err = recon_err[:, :, 0] + recon_err[:, :, 1] + recon_err[:, :, 2]
        recon_err = depth_colorize(recon_err)
        return rgb_tgt.astype('uint8'), warped.astype(
            'uint8'), recon_err, pred_tgt

    rgb_tgt, warped, recon_err, pred_tgt = preprocess(rgb_tgt, pred_tgt,
                                                      warped)

    # 1st column
    # column = 0
    axarr[0].imshow(rgb_tgt)
    axarr[0].axis('off')
    axarr[0].axis('equal')
    # axarr[0, column].set_title('rgb_tgt')

    axarr[1].imshow(warped)
    axarr[1].axis('off')
    axarr[1].axis('equal')
    # axarr[1, column].set_title('warped')

    axarr[2].imshow(recon_err, 'hot')
    axarr[2].axis('off')
    axarr[2].axis('equal')
    # axarr[2, column].set_title('recon_err error')

    axarr[3].imshow(pred_tgt, 'hot')
    axarr[3].axis('off')
    axarr[3].axis('equal')
    # axarr[3, column].set_title('pred_tgt')

    # plt.show()
    plt.pause(0.001)
@@ -774,6 +774,7 @@ TASK_OUTPUTS = {
    Tasks.surface_recon_common: [OutputKeys.OUTPUT],
    Tasks.video_colorization: [OutputKeys.OUTPUT_VIDEO],
    Tasks.image_control_3d_portrait: [OutputKeys.OUTPUT],
    Tasks.self_supervised_depth_completion: [OutputKeys.OUTPUT_IMG],

    # image quality assessment degradation result for single image
    # {
@@ -121,6 +121,8 @@ if TYPE_CHECKING:
    from .image_local_feature_matching_pipeline import ImageLocalFeatureMatchingPipeline
    from .rife_video_frame_interpolation_pipeline import RIFEVideoFrameInterpolationPipeline
    from .anydoor_pipeline import AnydoorPipeline
    from .self_supervised_depth_completion_pipeline import SelfSupervisedDepthCompletionPipeline

else:
    _import_structure = {
        'action_recognition_pipeline': ['ActionRecognitionPipeline'],
@@ -303,6 +305,9 @@ else:
            'RIFEVideoFrameInterpolationPipeline'
        ],
        'anydoor_pipeline': ['AnydoorPipeline'],
        'self_supervised_depth_completion_pipeline': [
            'SelfSupervisedDepthCompletionPipeline'
        ],
    }

import sys
59
modelscope/pipelines/cv/self_supervised_depth_completion_pipeline.py
Normal file
@@ -0,0 +1,59 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.self_supervised_depth_completion,
    module_name=Pipelines.self_supervised_depth_completion)
class SelfSupervisedDepthCompletionPipeline(Pipeline):
    """Self-Supervised Depth Completion Pipeline.

    Example:

    ```python
    >>> from modelscope.pipelines import pipeline
    >>> model_id = 'Damo_XR_Lab/Self_Supervised_Depth_Completion'
    >>> data_dir = MsDataset.load(
            'KITTI_Depth_Dataset',
            namespace='Damo_XR_Lab',
            split='test',
            download_mode=DownloadMode.FORCE_REDOWNLOAD
        ).config_kwargs['split_config']['test']
    >>> source_dir = os.path.join(data_dir, 'selected_data')
    >>> self_supervised_depth_completion = pipeline(
            Tasks.self_supervised_depth_completion,
            'Damo_XR_Lab/Self_Supervised_Depth_Completion')
    >>> result = self_supervised_depth_completion({
            'model_dir': model_id,
            'source_dir': source_dir
        })
    >>> cv2.imwrite('result.jpg', result[OutputKeys.OUTPUT])
    ```
    """

    def __init__(self, model: str, **kwargs):
        super().__init__(model=model, **kwargs)
        logger.info('load model done')

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """preprocess, not used at present"""
        return inputs

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """forward"""
        source_dir = inputs['source_dir']
        result = self.model.forward(source_dir)
        return {OutputKeys.OUTPUT: result}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """postprocess, not used at present"""
        return inputs
@@ -170,6 +170,7 @@ class CVTasks(object):
    human3d_render = 'human3d-render'
    human3d_animation = 'human3d-animation'
    image_control_3d_portrait = 'image-control-3d-portrait'
    self_supervised_depth_completion = 'self-supervised-depth-completion'

    # 3d generation
    image_to_3d = 'image-to-3d'
@@ -3812,5 +3812,18 @@
                }
            }
        }
    },
    "self-supervised-depth-completion": {
        "input": {},
        "parameters": {},
        "output": {
            "type": "object",
            "properties": {
                "output_img": {
                    "type": "string",
                    "description": "The base64 encoded image."
                }
            }
        }
    },
}
54
tests/pipelines/test_self_supervised_depth_completion.py
Normal file
@@ -0,0 +1,54 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest

import cv2
import torch

from modelscope import get_logger
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.msdatasets import MsDataset
from modelscope.outputs.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import DownloadMode, Tasks
from modelscope.utils.test_utils import test_level

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
logger = get_logger()


class SelfSupervisedDepthCompletionTest(unittest.TestCase):
    """class SelfSupervisedDepthCompletionTest"""

    def setUp(self) -> None:
        self.model_id = 'Damo_XR_Lab/Self_Supervised_Depth_Completion'
        data_dir = MsDataset.load(
            'KITTI_Depth_Dataset',
            namespace='Damo_XR_Lab',
            split='test',
            download_mode=DownloadMode.FORCE_REDOWNLOAD
        ).config_kwargs['split_config']['test']
        self.source_dir = os.path.join(data_dir, 'selected_data')
        logger.info(data_dir)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    @unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest only')
    def test_run(self):
        """test running evaluation"""
        snapshot_path = snapshot_download(self.model_id)
        logger.info('snapshot_path: %s', snapshot_path)
        self_supervised_depth_completion = pipeline(
            task=Tasks.self_supervised_depth_completion,
            model=self.model_id
            # ,config_file = os.path.join(modelPath, "configuration.json")
        )

        result = self_supervised_depth_completion(
            dict(model_dir=snapshot_path, source_dir=self.source_dir))
        cv2.imwrite('result.jpg', result[OutputKeys.OUTPUT])
        logger.info(
            'self-supervised-depth-completion_damo.test_run_modelhub done')


if __name__ == '__main__':
    unittest.main()
0
tests/test_metrics/__init__.py
Normal file