mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-15 15:57:42 +01:00)
add self supervised depth completion. (#711)

* add self supervised depth completion.
* update.
* fix the problem of key inconsistency.
* delete args parser.
* rename metrics to test_metrics.
@@ -132,6 +132,7 @@ class Models(object):
     image_control_3d_portrait = 'image-control-3d-portrait'
     rife = 'rife'
     anydoor = 'anydoor'
+    self_supervised_depth_completion = 'self-supervised-depth-completion'

     # nlp models
     bert = 'bert'
@@ -469,6 +470,7 @@ class Pipelines(object):
     rife_video_frame_interpolation = 'rife-video-frame-interpolation'
     anydoor = 'anydoor'
     image_to_3d = 'image-to-3d'
+    self_supervised_depth_completion = 'self-supervised-depth-completion'

     # nlp tasks
     automatic_post_editing = 'automatic-post-editing'
@@ -959,7 +961,10 @@ DEFAULT_MODEL_FOR_PIPELINE = {
         'damo/cv_image-view-transform'),
     Tasks.image_control_3d_portrait: (
         Pipelines.image_control_3d_portrait,
-        'damo/cv_vit_image-control-3d-portrait-synthesis')
+        'damo/cv_vit_image-control-3d-portrait-synthesis'),
+    Tasks.self_supervised_depth_completion: (
+        Pipelines.self_supervised_depth_completion,
+        'damo/self-supervised-depth-completion')
 }

@@ -982,6 +987,7 @@ class CVTrainers(object):
     nerf_recon_4k = 'nerf-recon-4k'
     action_detection = 'action-detection'
     vision_efficient_tuning = 'vision-efficient-tuning'
+    self_supervised_depth_completion = 'self-supervised-depth-completion'


 class NLPTrainers(object):
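The entries above register the model, pipeline, trainer and default checkpoint under the shared 'self-supervised-depth-completion' key. A minimal, hedged sketch of how the registered pipeline would typically be invoked (only the task name and model id come from this diff; the input dictionary key is an assumption):

# Hedged sketch, not part of this commit.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

depth_completion = pipeline(
    Tasks.self_supervised_depth_completion,
    model='damo/self-supervised-depth-completion')
result = depth_completion({'source_dir': '/path/to/kitti_sample'})  # input key is hypothetical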
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .self_supervised_depth_completion import SelfSupervisedDepthCompletion
else:
    _import_structure = {
        'selfsuperviseddepthcompletion': ['SelfSupervisedDepthCompletion'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn

from modelscope.utils.logger import get_logger

logger = get_logger()

loss_names = ['l1', 'l2']


class MaskedMSELoss(nn.Module):

    def __init__(self):
        super(MaskedMSELoss, self).__init__()

    def forward(self, pred, target):
        assert pred.dim() == target.dim(), 'inconsistent dimensions'
        valid_mask = (target > 0).detach()
        diff = target - pred
        diff = diff[valid_mask]
        self.loss = (diff**2).mean()
        return self.loss


class MaskedL1Loss(nn.Module):

    def __init__(self):
        super(MaskedL1Loss, self).__init__()

    def forward(self, pred, target, weight=None):
        assert pred.dim() == target.dim(), 'inconsistent dimensions'
        valid_mask = (target > 0).detach()
        diff = target - pred
        diff = diff[valid_mask]
        self.loss = diff.abs().mean()
        return self.loss


class PhotometricLoss(nn.Module):

    def __init__(self):
        super(PhotometricLoss, self).__init__()

    def forward(self, target, recon, mask=None):

        assert recon.dim(
        ) == 4, 'expected recon dimension to be 4, but instead got {}.'.format(
            recon.dim())
        assert target.dim(
        ) == 4, 'expected target dimension to be 4, but instead got {}.'.format(
            target.dim())
        assert recon.size() == target.size(), 'expected recon and target to have the same size, but got {} and {} '\
            .format(recon.size(), target.size())
        diff = (target - recon).abs()
        diff = torch.sum(diff, 1)  # sum along the color channel

        # compare only pixels that are not black
        valid_mask = (torch.sum(recon, 1) > 0).float() * (torch.sum(target, 1)
                                                          > 0).float()
        if mask is not None:
            valid_mask = valid_mask * torch.squeeze(mask).float()
        valid_mask = valid_mask.byte().detach()
        if valid_mask.numel() > 0:
            diff = diff[valid_mask]
            if diff.nelement() > 0:
                self.loss = diff.mean()
            else:
                logger.info(
                    'warning: diff.nelement()==0 in PhotometricLoss (this is expected during early stage of training, \
                    try larger batch size).')
                self.loss = 0
        else:
            logger.info('warning: 0 valid pixel in PhotometricLoss')
            self.loss = 0
        return self.loss


class SmoothnessLoss(nn.Module):

    def __init__(self):
        super(SmoothnessLoss, self).__init__()

    def forward(self, depth):

        def second_derivative(x):
            assert x.dim(
            ) == 4, 'expected 4-dimensional data, but instead got {}'.format(
                x.dim())
            horizontal = 2 * x[:, :, 1:-1, 1:-1] - x[:, :,
                                                     1:-1, :-2] - x[:, :, 1:-1,
                                                                    2:]
            vertical = 2 * x[:, :, 1:-1, 1:-1] - x[:, :, :-2,
                                                   1:-1] - x[:, :, 2:, 1:-1]
            der_2nd = horizontal.abs() + vertical.abs()
            return der_2nd.mean()

        self.loss = second_derivative(depth)
        return self.loss
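A hedged sketch (not part of the diff) showing how the masked depth loss and the smoothness term defined above can be combined; the 0.1 weight is an illustrative placeholder, not a value from this commit:

# Hedged usage sketch of the criteria above.
import torch

depth_criterion = MaskedMSELoss()
smoothness_criterion = SmoothnessLoss()

pred = torch.rand(2, 1, 352, 1216)           # predicted dense depth
sparse_target = torch.rand(2, 1, 352, 1216)  # sparse supervision, zeros = invalid
loss = depth_criterion(pred, sparse_target) + 0.1 * smoothness_criterion(pred)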
@@ -0,0 +1,344 @@
import glob
import os
import os.path
from random import choice

import cv2
import numpy as np
import torch.utils.data as data
from numpy import linalg as LA
from PIL import Image

from modelscope.models.cv.self_supervised_depth_completion.dataloaders import \
    transforms
from modelscope.models.cv.self_supervised_depth_completion.dataloaders.pose_estimator import \
    get_pose_pnp

input_options = ['d', 'rgb', 'rgbd', 'g', 'gd']


def load_calib(args):
    """
    Temporarily hardcoding the calibration matrix using calib file from 2011_09_26
    """
    calib = open(os.path.join(args.data_folder, 'calib_cam_to_cam.txt'), 'r')
    lines = calib.readlines()
    P_rect_line = lines[25]

    Proj_str = P_rect_line.split(':')[1].split(' ')[1:]
    Proj = np.reshape(np.array([float(p) for p in Proj_str]),
                      (3, 4)).astype(np.float32)
    K = Proj[:3, :3]  # camera matrix

    # note: we will take the center crop of the images during augmentation
    # that changes the optical centers, but not focal lengths
    K[0, 2] = K[
        0,
        2] - 13  # from width = 1242 to 1216, with a 13-pixel cut on both sides
    K[1, 2] = K[
        1,
        2] - 11.5  # from width = 375 to 352, with a 11.5-pixel cut on both sides
    return K


def get_paths_and_transform(split, args):
    assert (args.use_d or args.use_rgb
            or args.use_g), 'no proper input selected'

    if split == 'train':
        transform = train_transform
        glob_d = os.path.join(
            args.data_folder,
            'data_depth_velodyne/train/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png'
        )
        glob_gt = os.path.join(
            args.data_folder,
            'data_depth_annotated/train/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png'
        )

        def get_rgb_paths(p):
            ps = p.split('/')
            pnew = '/'.join([args.data_folder] + ['data_rgb'] + ps[-6:-4]
                            + ps[-2:-1] + ['data'] + ps[-1:])
            return pnew
    elif split == 'val':
        if args.val == 'full':
            transform = val_transform
            glob_d = os.path.join(
                args.data_folder,
                'data_depth_velodyne/val/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png'
            )
            glob_gt = os.path.join(
                args.data_folder,
                'data_depth_annotated/val/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png'
            )

            def get_rgb_paths(p):
                ps = p.split('/')
                pnew = '/'.join(ps[:-7] + ['data_rgb '] + ps[-6:-4] + ps[-2:-1]
                                + ['data'] + ps[-1:])
                return pnew
        elif args.val == 'select':
            transform = no_transform
            glob_d = os.path.join(
                args.data_folder,
                'depth_selection/val_selection_cropped/velodyne_raw/*.png')
            glob_gt = os.path.join(
                args.data_folder,
                'depth_selection/val_selection_cropped/groundtruth_depth/*.png'
            )

            def get_rgb_paths(p):
                return p.replace('groundtruth_depth', 'image')
    elif split == 'test_completion':
        transform = no_transform
        glob_d = os.path.join(
            args.data_folder,
            'depth_selection/test_depth_completion_anonymous/velodyne_raw/*.png'
        )
        glob_gt = None  # "test_depth_completion_anonymous/"
        glob_rgb = os.path.join(
            args.data_folder,
            'depth_selection/test_depth_completion_anonymous/image/*.png')
    elif split == 'test_prediction':
        transform = no_transform
        glob_d = None
        glob_gt = None  # "test_depth_completion_anonymous/"
        glob_rgb = os.path.join(
            args.data_folder,
            'depth_selection/test_depth_prediction_anonymous/image/*.png')
    else:
        raise ValueError('Unrecognized split ' + str(split))

    if glob_gt is not None:
        # train or val-full or val-select
        paths_d = sorted(glob.glob(glob_d))
        paths_gt = sorted(glob.glob(glob_gt))
        paths_rgb = [get_rgb_paths(p) for p in paths_gt]
    else:
        # test only has d or rgb
        paths_rgb = sorted(glob.glob(glob_rgb))
        paths_gt = [None] * len(paths_rgb)
        if split == 'test_prediction':
            paths_d = [None] * len(
                paths_rgb)  # test_prediction has no sparse depth
        else:
            paths_d = sorted(glob.glob(glob_d))

    if len(paths_d) == 0 and len(paths_rgb) == 0 and len(paths_gt) == 0:
        raise (RuntimeError('Found 0 images under {}'.format(glob_gt)))
    if len(paths_d) == 0 and args.use_d:
        raise (RuntimeError('Requested sparse depth but none was found'))
    if len(paths_rgb) == 0 and args.use_rgb:
        raise (RuntimeError('Requested rgb images but none was found'))
    if len(paths_rgb) == 0 and args.use_g:
        raise (RuntimeError('Requested gray images but no rgb was found'))
    if len(paths_rgb) != len(paths_d) or len(paths_rgb) != len(paths_gt):
        raise (RuntimeError('Produced different sizes for datasets'))

    paths = {'rgb': paths_rgb, 'd': paths_d, 'gt': paths_gt}
    return paths, transform


def rgb_read(filename):
    assert os.path.exists(filename), 'file not found: {}'.format(filename)
    img_file = Image.open(filename)
    # rgb_png = np.array(img_file, dtype=float) / 255.0 # scale pixels to the range [0,1]
    rgb_png = np.array(img_file, dtype='uint8')  # in the range [0,255]
    img_file.close()
    return rgb_png


def depth_read(filename):
    # loads depth map D from png file
    # and returns it as a numpy array,
    # for details see readme.txt
    assert os.path.exists(filename), 'file not found: {}'.format(filename)
    img_file = Image.open(filename)
    depth_png = np.array(img_file, dtype=int)
    img_file.close()
    # make sure we have a proper 16bit depth map here.. not 8bit!
    assert np.max(depth_png) > 255, \
        'np.max(depth_png)={}, path={}'.format(np.max(depth_png), filename)

    depth = depth_png.astype(float) / 256.
    # depth[depth_png == 0] = -1.
    depth = np.expand_dims(depth, -1)
    return depth


oheight, owidth = 352, 1216


def drop_depth_measurements(depth, prob_keep):
    mask = np.random.binomial(1, prob_keep, depth.shape)
    depth *= mask
    return depth


def train_transform(rgb, sparse, target, rgb_near, args):
    # s = np.random.uniform(1.0, 1.5) # random scaling
    # angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
    do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip

    transform_geometric = transforms.Compose([
        # transforms.Rotate(angle),
        # transforms.Resize(s),
        transforms.BottomCrop((oheight, owidth)),
        transforms.HorizontalFlip(do_flip)
    ])
    if sparse is not None:
        sparse = transform_geometric(sparse)
    target = transform_geometric(target)
    if rgb is not None:
        brightness = np.random.uniform(
            max(0, 1 - args.jitter), 1 + args.jitter)
        contrast = np.random.uniform(max(0, 1 - args.jitter), 1 + args.jitter)
        saturation = np.random.uniform(
            max(0, 1 - args.jitter), 1 + args.jitter)
        transform_rgb = transforms.Compose([
            transforms.ColorJitter(brightness, contrast, saturation, 0),
            transform_geometric
        ])
        rgb = transform_rgb(rgb)
        if rgb_near is not None:
            rgb_near = transform_rgb(rgb_near)
    # sparse = drop_depth_measurements(sparse, 0.9)

    return rgb, sparse, target, rgb_near


def val_transform(rgb, sparse, target, rgb_near, args):
    transform = transforms.Compose([
        transforms.BottomCrop((oheight, owidth)),
    ])
    if rgb is not None:
        rgb = transform(rgb)
    if sparse is not None:
        sparse = transform(sparse)
    if target is not None:
        target = transform(target)
    if rgb_near is not None:
        rgb_near = transform(rgb_near)
    return rgb, sparse, target, rgb_near


def no_transform(rgb, sparse, target, rgb_near, args):
    return rgb, sparse, target, rgb_near


to_tensor = transforms.ToTensor()


def to_float_tensor(x):
    return to_tensor(x).float()


def handle_gray(rgb, args):
    if rgb is None:
        return None, None
    if not args.use_g:
        return rgb, None
    else:
        img = np.array(Image.fromarray(rgb).convert('L'))
        img = np.expand_dims(img, -1)
        if not args.use_rgb:
            rgb_ret = None
        else:
            rgb_ret = rgb
        return rgb_ret, img


def get_rgb_near(path, args):
    assert path is not None, 'path is None'

    def extract_frame_id(filename):
        head, tail = os.path.split(filename)
        number_string = tail[0:tail.find('.')]
        number = int(number_string)
        return head, number

    def get_nearby_filename(filename, new_id):
        head, _ = os.path.split(filename)
        new_filename = os.path.join(head, '%010d.png' % new_id)
        return new_filename

    head, number = extract_frame_id(path)
    count = 0
    max_frame_diff = 3
    candidates = [
        i - max_frame_diff for i in range(max_frame_diff * 2 + 1)
        if i - max_frame_diff != 0
    ]
    while True:
        random_offset = choice(candidates)
        path_near = get_nearby_filename(path, number + random_offset)
        if os.path.exists(path_near):
            break
        assert count < 20, 'cannot find a nearby frame in 20 trials for {}'.format(
            path)
        count += 1

    return rgb_read(path_near)


class KittiDepth(data.Dataset):
    """A data loader for the Kitti dataset
    """

    def __init__(self, split, args):
        self.args = args
        self.split = split
        paths, transform = get_paths_and_transform(split, args)
        self.paths = paths
        self.transform = transform
        self.K = load_calib(args)
        self.threshold_translation = 0.1

    def __getraw__(self, index):
        rgb = rgb_read(self.paths['rgb'][index]) if \
            (self.paths['rgb'][index] is not None and (self.args.use_rgb or self.args.use_g)) else None
        sparse = depth_read(self.paths['d'][index]) if \
            (self.paths['d'][index] is not None and self.args.use_d) else None
        target = depth_read(self.paths['gt'][index]) if \
            self.paths['gt'][index] is not None else None
        rgb_near = get_rgb_near(self.paths['rgb'][index], self.args) if \
            self.split == 'train' and self.args.use_pose else None
        return rgb, sparse, target, rgb_near

    def __getitem__(self, index):
        rgb, sparse, target, rgb_near = self.__getraw__(index)
        rgb, sparse, target, rgb_near = self.transform(rgb, sparse, target,
                                                       rgb_near, self.args)
        r_mat, t_vec = None, None
        if self.split == 'train' and self.args.use_pose:
            success, r_vec, t_vec = get_pose_pnp(rgb, rgb_near, sparse, self.K)
            # discard if translation is too small
            success = success and LA.norm(t_vec) > self.threshold_translation
            if success:
                r_mat, _ = cv2.Rodrigues(r_vec)
            else:
                # return the same image and no motion when PnP fails
                rgb_near = rgb
                t_vec = np.zeros((3, 1))
                r_mat = np.eye(3)

        rgb, gray = handle_gray(rgb, self.args)
        candidates = {
            'rgb': rgb,
            'd': sparse,
            'gt': target,
            'g': gray,
            'r_mat': r_mat,
            't_vec': t_vec,
            'rgb_near': rgb_near
        }
        items = {
            key: to_float_tensor(val)
            for key, val in candidates.items() if val is not None
        }

        return items

    def __len__(self):
        return len(self.paths['gt'])
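A hedged sketch (not part of the diff) of constructing the loader; the argparse.Namespace fields mirror the attributes read by load_calib, get_paths_and_transform and the transforms above, and the data folder is a placeholder:

# Hedged usage sketch of KittiDepth.
import argparse
import torch.utils.data

args = argparse.Namespace(
    data_folder='/path/to/kitti_depth',  # placeholder path
    use_d=True, use_rgb=True, use_g=False, use_pose=False,
    val='select', jitter=0.1)
dataset = KittiDepth('val', args)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)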
@@ -0,0 +1,102 @@
import cv2
import numpy as np


def rgb2gray(rgb):
    return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])


def convert_2d_to_3d(u, v, z, K):
    v0 = K[1][2]
    u0 = K[0][2]
    fy = K[1][1]
    fx = K[0][0]
    x = (u - u0) * z / fx
    y = (v - v0) * z / fy
    return (x, y, z)


def feature_match(img1, img2):
    r''' Find features on both images and match them pairwise
    '''
    max_n_features = 1000
    # max_n_features = 500
    use_flann = False  # better not use flann

    detector = cv2.xfeatures2d.SIFT_create(max_n_features)

    # find the keypoints and descriptors with SIFT
    kp1, des1 = detector.detectAndCompute(img1, None)
    kp2, des2 = detector.detectAndCompute(img2, None)
    if (des1 is None) or (des2 is None):
        return [], []
    des1 = des1.astype(np.float32)
    des2 = des2.astype(np.float32)

    if use_flann:
        # FLANN parameters
        FLANN_INDEX_KDTREE = 0
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        matches = flann.knnMatch(des1, des2, k=2)
    else:
        matcher = cv2.DescriptorMatcher().create('BruteForce')
        matches = matcher.knnMatch(des1, des2, k=2)

    good = []
    pts1 = []
    pts2 = []
    # ratio test as per Lowe's paper
    for i, (m, n) in enumerate(matches):
        if m.distance < 0.8 * n.distance:
            good.append(m)
            pts2.append(kp2[m.trainIdx].pt)
            pts1.append(kp1[m.queryIdx].pt)

    pts1 = np.int32(pts1)
    pts2 = np.int32(pts2)
    return pts1, pts2


def get_pose_pnp(rgb_curr, rgb_near, depth_curr, K):
    gray_curr = rgb2gray(rgb_curr).astype(np.uint8)
    gray_near = rgb2gray(rgb_near).astype(np.uint8)
    height, width = gray_curr.shape

    pts2d_curr, pts2d_near = feature_match(gray_curr,
                                           gray_near)  # feature matching

    # dilation of depth
    kernel = np.ones((4, 4), np.uint8)
    depth_curr_dilated = cv2.dilate(depth_curr, kernel)

    # extract 3d pts
    pts3d_curr = []
    pts2d_near_filtered = [
    ]  # keep only feature points with depth in the current frame
    for i, pt2d in enumerate(pts2d_curr):
        # print(pt2d)
        u, v = pt2d[0], pt2d[1]
        z = depth_curr_dilated[v, u]
        if z > 0:
            xyz_curr = convert_2d_to_3d(u, v, z, K)
            pts3d_curr.append(xyz_curr)
            pts2d_near_filtered.append(pts2d_near[i])

    # the minimal number of points accepted by solvePnP is 4:
    if len(pts3d_curr) >= 4 and len(pts2d_near_filtered) >= 4:
        pts3d_curr = np.expand_dims(
            np.array(pts3d_curr).astype(np.float32), axis=1)
        pts2d_near_filtered = np.expand_dims(
            np.array(pts2d_near_filtered).astype(np.float32), axis=1)

        # ransac
        ret = cv2.solvePnPRansac(
            pts3d_curr, pts2d_near_filtered, K, distCoeffs=None)
        success = ret[0]
        rotation_vector = ret[1]
        translation_vector = ret[2]
        return (success, rotation_vector, translation_vector)
    else:
        return (0, None, None)
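convert_2d_to_3d above is the standard pinhole back-projection, x = (u - u0) * z / fx and y = (v - v0) * z / fy. A hedged numeric check (the intrinsics are illustrative, roughly KITTI-like; not values from this diff):

# Hedged numeric sketch of the back-projection.
import numpy as np

K = np.array([[721.5, 0.0, 596.5],
              [0.0, 721.5, 149.5],
              [0.0, 0.0, 1.0]])
u, v, z = 700, 200, 10.0  # pixel (u, v) with 10 m depth
x, y, z = convert_2d_to_3d(u, v, z, K)
# x = (700 - 596.5) * 10 / 721.5 ~ 1.43 m, y = (200 - 149.5) * 10 / 721.5 ~ 0.70 m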
@@ -0,0 +1,617 @@
from __future__ import division
import numbers
import types

import numpy as np
import scipy.ndimage.interpolation as itpl
import skimage.transform
import torch
from PIL import Image, ImageEnhance

try:
    import accimage
except ImportError:
    accimage = None


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image: Brightness adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img


def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image: Contrast adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image: Saturation adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See https://en.wikipedia.org/wiki/Hue for more details on Hue.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image: Hue adjusted image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(
            'hue_factor is not in [-0.5, 0.5]. Got {}'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition take cares of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img


def adjust_gamma(img, gamma, gain=1):
    """Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

        I_out = 255 * gain * ((I_in / 255) ** gamma)

    See https://en.wikipedia.org/wiki/Gamma_correction for more details.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        gamma (float): Non negative real number. gamma larger than 1 make the
            shadows darker, while gamma smaller than 1 make dark regions
            lighter.
        gain (float): The constant multiplier.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    input_mode = img.mode
    img = img.convert('RGB')

    np_img = np.array(img, dtype=np.float32)
    np_img = 255 * gain * ((np_img / 255)**gamma)
    np_img = np.uint8(np.clip(np_img, 0, 255))

    img = Image.fromarray(np_img, 'RGB').convert(input_mode)
    return img


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):
    """Convert a ``numpy.ndarray`` to tensor.

    Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
    """

    def __call__(self, img):
        """Convert a ``numpy.ndarray`` to tensor.

        Args:
            img (numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))

        if isinstance(img, np.ndarray):
            # handle numpy array
            if img.ndim == 3:
                img = torch.from_numpy(img.transpose((2, 0, 1)).copy())
            elif img.ndim == 2:
                img = torch.from_numpy(img.copy())
            else:
                raise RuntimeError(
                    'img should be ndarray with 2 or 3 dimensions. Got {}'.
                    format(img.ndim))

            return img


class NormalizeNumpyArray(object):
    """Normalize a ``numpy.ndarray`` with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``numpy.ndarray`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray): Image of size (H, W, C) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))
        # TODO: make efficient
        # print(img.shape)
        for i in range(3):
            img[:, :, i] = (img[:, :, i] - self.mean[i]) / self.std[i]
        return img


class NormalizeTensor(object):
    """Normalize an tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        if not _is_tensor_image(tensor):
            raise TypeError('tensor is not a torch image.')
        # TODO: make efficient
        for t, m, s in zip(tensor, self.mean, self.std):
            t.sub_(m).div_(s)
        return tensor


class Rotate(object):
    """Rotates the given ``numpy.ndarray``.

    Args:
        angle (float): The rotation angle in degrees.
    """

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Image to be rotated.

        Returns:
            img (numpy.ndarray (C x H x W)): Rotated image.
        """

        # order=0 means nearest-neighbor type interpolation
        return skimage.transform.rotate(img, self.angle, resize=False, order=0)


class Resize(object):
    """Resize the the given ``numpy.ndarray`` to the given size.
    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation='nearest'):
        assert isinstance(size, float)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Image to be scaled.
        Returns:
            img (numpy.ndarray (C x H x W)): Rescaled image.
        """
        if img.ndim == 3:
            return skimage.transform.rescale(img, self.size, order=0)
        elif img.ndim == 2:
            return skimage.transform.rescale(img, self.size, order=0)
        else:
            RuntimeError(
                'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
                    img.ndim))


class CenterCrop(object):
    """Crops the given ``numpy.ndarray`` at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for center crop.

        Args:
            img (numpy.ndarray (C x H x W)): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
        """
        h = img.shape[0]
        w = img.shape[1]
        th, tw = output_size
        i = int(round((h - th) / 2.))
        j = int(round((w - tw) / 2.))

        # # randomized cropping
        # i = np.random.randint(i-3, i+4)
        # j = np.random.randint(j-3, j+4)

        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Image to be cropped.

        Returns:
            img (numpy.ndarray (C x H x W)): Cropped image.
        """
        i, j, h, w = self.get_params(img, self.size)
        """
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        """
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))
        if img.ndim == 3:
            return img[i:i + h, j:j + w, :]
        elif img.ndim == 2:
            return img[i:i + h, j:j + w]
        else:
            raise RuntimeError(
                'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
                    img.ndim))


class BottomCrop(object):
    """Crops the given ``numpy.ndarray`` at the bottom.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for bottom crop.

        Args:
            img (numpy.ndarray (C x H x W)): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for bottom crop.
        """
        h = img.shape[0]
        w = img.shape[1]
        th, tw = output_size
        i = h - th
        j = int(round((w - tw) / 2.))

        # randomized left and right cropping
        # i = np.random.randint(i-3, i+4)
        # j = np.random.randint(j-1, j+1)

        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Image to be cropped.

        Returns:
            img (numpy.ndarray (C x H x W)): Cropped image.
        """
        i, j, h, w = self.get_params(img, self.size)
        """
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        """
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))
        if img.ndim == 3:
            return img[i:i + h, j:j + w, :]
        elif img.ndim == 2:
            return img[i:i + h, j:j + w]
        else:
            raise RuntimeError(
                'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
                    img.ndim))


class Crop(object):
    """Crops the given ``numpy.ndarray`` at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, crop):
        self.crop = crop

    @staticmethod
    def get_params(img, crop):
        """Get parameters for ``crop`` for center crop.

        Args:
            img (numpy.ndarray (C x H x W)): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
        """
        x_l, x_r, y_b, y_t = crop
        h = img.shape[0]
        w = img.shape[1]
        assert x_l >= 0 and x_l < w
        assert x_r >= 0 and x_r < w
        assert y_b >= 0 and y_b < h
        assert y_t >= 0 and y_t < h
        assert x_l < x_r and y_b < y_t

        return x_l, x_r, y_b, y_t

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Image to be cropped.

        Returns:
            img (numpy.ndarray (C x H x W)): Cropped image.
        """
        x_l, x_r, y_b, y_t = self.get_params(img, self.crop)
        """
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        """
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))
        if img.ndim == 3:
            return img[y_b:y_t, x_l:x_r, :]
        elif img.ndim == 2:
            return img[y_b:y_t, x_l:x_r]
        else:
            raise RuntimeError(
                'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
                    img.ndim))


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)


class HorizontalFlip(object):
    """Horizontally flip the given ``numpy.ndarray``.

    Args:
        do_flip (boolean): whether or not do horizontal flip.

    """

    def __init__(self, do_flip):
        self.do_flip = do_flip

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Image to be flipped.

        Returns:
            img (numpy.ndarray (C x H x W)): flipped image.
        """
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))

        if self.do_flip:
            return np.fliplr(img)
        else:
            return img


class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue(float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >=0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        transforms = []
        transforms.append(
            Lambda(lambda img: adjust_brightness(img, brightness)))
        transforms.append(Lambda(lambda img: adjust_contrast(img, contrast)))
        transforms.append(
            Lambda(lambda img: adjust_saturation(img, saturation)))
        transforms.append(Lambda(lambda img: adjust_hue(img, hue)))
        np.random.shuffle(transforms)
        self.transform = Compose(transforms)

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray (C x H x W)): Input image.

        Returns:
            img (numpy.ndarray (C x H x W)): Color jittered image.
        """
        if not (_is_numpy_image(img)):
            raise TypeError('img should be ndarray. Got {}'.format(type(img)))

        pil = Image.fromarray(img)
        return np.array(self.transform(pil))
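A hedged sketch (not part of the diff) exercising the geometric transforms the same way train_transform in the KITTI loader composes them, on a dummy full-resolution frame:

# Hedged usage sketch: BottomCrop + HorizontalFlip on a dummy 375 x 1242 RGB frame.
import numpy as np

dummy_rgb = np.zeros((375, 1242, 3), dtype='uint8')
geometric = Compose([BottomCrop((352, 1216)), HorizontalFlip(True)])
out = geometric(dummy_rgb)
print(out.shape)  # (352, 1216, 3)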
modelscope/models/cv/self_supervised_depth_completion/helper.py (new file, 269 lines)
@@ -0,0 +1,269 @@
import csv
import os
import shutil
import time

import torch

from modelscope.models.cv.self_supervised_depth_completion import vis_utils
from modelscope.models.cv.self_supervised_depth_completion.metrics import \
    Result

fieldnames = [
    'epoch', 'rmse', 'photo', 'mae', 'irmse', 'imae', 'mse', 'absrel', 'lg10',
    'silog', 'squared_rel', 'delta1', 'delta2', 'delta3', 'data_time',
    'gpu_time'
]


class logger:

    def __init__(self, args, prepare=True):
        self.args = args
        output_directory = get_folder_name(args)
        self.output_directory = output_directory
        self.best_result = Result()
        self.best_result.set_to_worst()

        if not prepare:
            return
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)
        self.train_csv = os.path.join(output_directory, 'train.csv')
        self.val_csv = os.path.join(output_directory, 'val.csv')
        self.best_txt = os.path.join(output_directory, 'best.txt')

        # backup the source code
        if args.resume == '':
            print('=> creating source code backup ...')
            backup_directory = os.path.join(output_directory, 'code_backup')
            self.backup_directory = backup_directory
            # backup_source_code(backup_directory)
            # create new csv files with only header
            with open(self.train_csv, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
            with open(self.val_csv, 'w') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer.writeheader()
            print('=> finished creating source code backup.')

    def conditional_print(self, split, i, epoch, lr, n_set, blk_avg_meter,
                          avg_meter):
        if (i + 1) % self.args.print_freq == 0:
            avg = avg_meter.average()
            blk_avg = blk_avg_meter.average()
            print('=> output: {}'.format(self.output_directory))
            print(
                '{split} Epoch: {0} [{1}/{2}]\tlr={lr} '
                't_Data={blk_avg.data_time:.3f}({average.data_time:.3f}) '
                't_GPU={blk_avg.gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                'RMSE={blk_avg.rmse:.2f}({average.rmse:.2f}) '
                'MAE={blk_avg.mae:.2f}({average.mae:.2f}) '
                'iRMSE={blk_avg.irmse:.2f}({average.irmse:.2f}) '
                'iMAE={blk_avg.imae:.2f}({average.imae:.2f})\n\t'
                'silog={blk_avg.silog:.2f}({average.silog:.2f}) '
                'squared_rel={blk_avg.squared_rel:.2f}({average.squared_rel:.2f}) '
                'Delta1={blk_avg.delta1:.3f}({average.delta1:.3f}) '
                'REL={blk_avg.absrel:.3f}({average.absrel:.3f})\n\t'
                'Lg10={blk_avg.lg10:.3f}({average.lg10:.3f}) '
                'Photometric={blk_avg.photometric:.3f}({average.photometric:.3f}) '
                .format(
                    epoch,
                    i + 1,
                    n_set,
                    lr=lr,
                    blk_avg=blk_avg,
                    average=avg,
                    split=split.capitalize()))
            blk_avg_meter.reset()

    def conditional_save_info(self, split, average_meter, epoch):
        avg = average_meter.average()
        if split == 'train':
            csvfile_name = self.train_csv
        elif split == 'val':
            csvfile_name = self.val_csv
        elif split == 'eval':
            eval_filename = os.path.join(self.output_directory, 'eval.txt')
            self.save_single_txt(eval_filename, avg, epoch)
            return avg
        elif 'test' in split:
            return avg
        else:
            raise ValueError('wrong split provided to logger')
        with open(csvfile_name, 'a') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow({
                'epoch': epoch,
                'rmse': avg.rmse,
                'photo': avg.photometric,
                'mae': avg.mae,
                'irmse': avg.irmse,
                'imae': avg.imae,
                'mse': avg.mse,
                'silog': avg.silog,
                'squared_rel': avg.squared_rel,
                'absrel': avg.absrel,
                'lg10': avg.lg10,
                'delta1': avg.delta1,
                'delta2': avg.delta2,
                'delta3': avg.delta3,
                'gpu_time': avg.gpu_time,
                'data_time': avg.data_time
            })
        return avg

    def save_single_txt(self, filename, result, epoch):
        with open(filename, 'w') as txtfile:
            txtfile.write(
                ('rank_metric={}\n' + 'epoch={}\n' + 'rmse={:.3f}\n'
                 + 'mae={:.3f}\n' + 'silog={:.3f}\n' + 'squared_rel={:.3f}\n'
                 + 'irmse={:.3f}\n' + 'imae={:.3f}\n' + 'mse={:.3f}\n'
                 + 'absrel={:.3f}\n' + 'lg10={:.3f}\n'
                 + 'delta1={:.3f}\n' + 't_gpu={:.4f}').format(
                     self.args.rank_metric, epoch, result.rmse, result.mae,
                     result.silog, result.squared_rel, result.irmse,
                     result.imae, result.mse, result.absrel, result.lg10,
                     result.delta1, result.gpu_time))

    def save_best_txt(self, result, epoch):
        self.save_single_txt(self.best_txt, result, epoch)

    def _get_img_comparison_name(self, mode, epoch, is_best=False):
        if mode == 'eval':
            return self.output_directory + '/comparison_eval.png'
        if mode == 'val':
            if is_best:
                return self.output_directory + '/comparison_best.png'
            else:
                return self.output_directory + '/comparison_' + str(
                    epoch) + '.png'

    def conditional_save_img_comparison(self, mode, i, ele, pred, epoch):
        # save 8 images for visualization
        if mode == 'val' or mode == 'eval':
            skip = 100
            if i == 0:
                self.img_merge = vis_utils.merge_into_row(ele, pred)
            elif i % skip == 0 and i < 8 * skip:
                row = vis_utils.merge_into_row(ele, pred)
                self.img_merge = vis_utils.add_row(self.img_merge, row)
            elif i == 8 * skip:
                filename = self._get_img_comparison_name(mode, epoch)
                vis_utils.save_image(self.img_merge, filename)
        return self.img_merge

    def save_img_comparison_as_best(self, mode, epoch):
        if mode == 'val':
            filename = self._get_img_comparison_name(mode, epoch, is_best=True)
            vis_utils.save_image(self.img_merge, filename)

    def get_ranking_error(self, result):
        return getattr(result, self.args.rank_metric)

    def rank_conditional_save_best(self, mode, result, epoch):
        error = self.get_ranking_error(result)
        best_error = self.get_ranking_error(self.best_result)
        is_best = error < best_error
        if is_best and mode == 'val':
            self.old_best_result = self.best_result
            self.best_result = result
            self.save_best_txt(result, epoch)
        return is_best

    def conditional_save_pred(self, mode, i, pred, epoch):
        if ('test' in mode or mode == 'eval') and self.args.save_pred:

            # save images for visualization/ testing
            image_folder = os.path.join(self.output_directory,
                                        mode + '_output')
            if not os.path.exists(image_folder):
                os.makedirs(image_folder)
            img = torch.squeeze(pred.data.cpu()).numpy()
            filename = os.path.join(image_folder, '{0:010d}.png'.format(i))
            vis_utils.save_depth_as_uint16png(img, filename)

    def conditional_summarize(self, mode, avg, is_best):
        print('\n*\nSummary of ', mode, 'round')
        print(''
              'RMSE={average.rmse:.3f}\n'
              'MAE={average.mae:.3f}\n'
              'Photo={average.photometric:.3f}\n'
              'iRMSE={average.irmse:.3f}\n'
              'iMAE={average.imae:.3f}\n'
              'squared_rel={average.squared_rel}\n'
              'silog={average.silog}\n'
              'Delta1={average.delta1:.3f}\n'
              'REL={average.absrel:.3f}\n'
              'Lg10={average.lg10:.3f}\n'
              't_GPU={time:.3f}'.format(average=avg, time=avg.gpu_time))
        if is_best and mode == 'val':
            print('New best model by %s (was %.3f)' %
                  (self.args.rank_metric,
                   self.get_ranking_error(self.old_best_result)))
        elif mode == 'val':
            print('(best %s is %.3f)' %
                  (self.args.rank_metric,
                   self.get_ranking_error(self.best_result)))
        print('*\n')


ignore_hidden = shutil.ignore_patterns('.', '..', '.git*', '*pycache*',
                                       '*build', '*.fuse*', '*_drive_*')


def backup_source_code(backup_directory):
    if os.path.exists(backup_directory):
        shutil.rmtree(backup_directory)
    shutil.copytree('.', backup_directory, ignore=ignore_hidden)


def adjust_learning_rate(lr_init, optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 5 epochs"""
    lr = lr_init * (0.1**(epoch // 5))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def save_checkpoint(state, is_best, epoch, output_directory):
    checkpoint_filename = os.path.join(output_directory,
                                       'checkpoint-' + str(epoch) + '.pth.tar')
    torch.save(state, checkpoint_filename)
    if is_best:
        best_filename = os.path.join(output_directory, 'model_best.pth.tar')
        shutil.copyfile(checkpoint_filename, best_filename)
    if epoch > 0:
        prev_checkpoint_filename = os.path.join(
            output_directory, 'checkpoint-' + str(epoch - 1) + '.pth.tar')
        if os.path.exists(prev_checkpoint_filename):
            os.remove(prev_checkpoint_filename)


def get_folder_name(args):
    # current_time = time.strftime('%Y-%m-%d@%H-%M')
    # if args.use_pose:
    #     prefix = 'mode={}.w1={}.w2={}.'.format(args.train_mode, args.w1,
    #                                            args.w2)
    # else:
    #     prefix = 'mode={}.'.format(args.train_mode)
    # return os.path.join(args.result,
    #     prefix + 'input={}.resnet{}.criterion={}.lr={}.bs={}.wd={}.pretrained={}.jitter={}.time={}'.
    #     format(args.input, args.layers, args.criterion, \
    #         args.lr, args.batch_size, args.weight_decay, \
    #         args.pretrained, args.jitter, current_time
    #     ))
    return os.path.join(args.result, 'test')


avgpool = torch.nn.AvgPool2d(kernel_size=2, stride=2).cuda()


def multiscale(img):
    img1 = avgpool(img)
    img2 = avgpool(img1)
    img3 = avgpool(img2)
    img4 = avgpool(img3)
    img5 = avgpool(img4)
    return img5, img4, img3, img2, img1
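adjust_learning_rate above implements the step schedule lr = lr_init * 0.1 ** (epoch // 5). A hedged illustration (the optimizer and initial rate are placeholders, not values from this commit):

# Hedged sketch: epochs 0-4 keep 1e-4, epochs 5-9 use 1e-5, epochs 10-14 use 1e-6, ...
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-4)
for epoch in range(11):
    lr = adjust_learning_rate(1e-4, optimizer, epoch)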
@@ -0,0 +1,141 @@
import torch
import torch.nn.functional as F

from modelscope.utils.logger import get_logger

logger = get_logger()


class Intrinsics:
    """Intrinsics"""

    def __init__(self, width, height, fu, fv, cu=0, cv=0):
        self.height, self.width = height, width
        self.fu, self.fv = fu, fv  # fu, fv: focal length along the horizontal and vertical axes

        # cu, cv: optical center along the horizontal and vertical axes
        self.cu = cu if cu > 0 else (width - 1) / 2.0
        self.cv = cv if cv > 0 else (height - 1) / 2.0

        # U, V represent the homogeneous horizontal and vertical coordinates in the pixel space
        self.U = torch.arange(start=0, end=width).expand(height, width).float()
        self.V = torch.arange(
            start=0, end=height).expand(width, height).t().float()

        # X_cam, Y_cam represent the homogeneous x, y coordinates (assuming depth z=1) in the camera coordinate system
        self.X_cam = (self.U - self.cu) / self.fu
        self.Y_cam = (self.V - self.cv) / self.fv

        self.is_cuda = False

    def cuda(self):
        self.X_cam.data = self.X_cam.data.cuda()
        self.Y_cam.data = self.Y_cam.data.cuda()
        self.is_cuda = True
        return self

    def scale(self, height, width):
        # return a new set of corresponding intrinsic parameters for the scaled image
        ratio_u = float(width) / self.width
        ratio_v = float(height) / self.height
        fu = ratio_u * self.fu
        fv = ratio_v * self.fv
        cu = ratio_u * self.cu
        cv = ratio_v * self.cv
        new_intrinsics = Intrinsics(width, height, fu, fv, cu, cv)
        if self.is_cuda:
            new_intrinsics.cuda()
        return new_intrinsics

    def __print__(self):
        logger.info(
            'size=({},{})\nfocal length=({},{})\noptical center=({},{})'.
            format(self.height, self.width, self.fv, self.fu, self.cv,
                   self.cu))


def image_to_pointcloud(depth, intrinsics):
    assert depth.dim() == 4
    assert depth.size(1) == 1

    X = depth * intrinsics.X_cam
    Y = depth * intrinsics.Y_cam
    return torch.cat((X, Y, depth), dim=1)


def pointcloud_to_image(pointcloud, intrinsics):
    assert pointcloud.dim() == 4

    batch_size = pointcloud.size(0)
    X = pointcloud[:, 0, :, :]  # .view(batch_size, -1)
    Y = pointcloud[:, 1, :, :]  # .view(batch_size, -1)
    Z = pointcloud[:, 2, :, :].clamp(min=1e-3)  # .view(batch_size, -1)

    # compute pixel coordinates
    U_proj = intrinsics.fu * X / Z + intrinsics.cu  # horizontal pixel coordinate
    V_proj = intrinsics.fv * Y / Z + intrinsics.cv  # vertical pixel coordinate

    # normalization to [-1, 1], required by torch.nn.functional.grid_sample
    w = intrinsics.width
    h = intrinsics.height
    U_proj_normalized = (2 * U_proj / (w - 1) - 1).view(batch_size, -1)
    V_proj_normalized = (2 * V_proj / (h - 1) - 1).view(batch_size, -1)

    # This was important since PyTorch didn't do as it claimed for points out of boundary
    # See https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py
    # Might not be necessary any more
    U_proj_mask = ((U_proj_normalized > 1) + (U_proj_normalized < -1)).detach()
    U_proj_normalized[U_proj_mask] = 2
    V_proj_mask = ((V_proj_normalized > 1) + (V_proj_normalized < -1)).detach()
    V_proj_normalized[V_proj_mask] = 2

    pixel_coords = torch.stack([U_proj_normalized, V_proj_normalized],
                               dim=2)  # [B, H*W, 2]
    return pixel_coords.view(batch_size, intrinsics.height, intrinsics.width,
                             2)


def batch_multiply(batch_scalar, batch_matrix):
    # input: batch_scalar of size b, batch_matrix of size b * 3 * 3
    # output: batch_matrix of size b * 3 * 3
    batch_size = batch_scalar.size(0)
    output = batch_matrix.clone()
    for i in range(batch_size):
        output[i] = batch_scalar[i] * batch_matrix[i]
    return output


def transform_curr_to_near(pointcloud_curr, r_mat, t_vec, intrinsics):
    # translation and rotmat represent the transformation from tgt pose to src pose
    batch_size = pointcloud_curr.size(0)
    XYZ_ = torch.bmm(r_mat, pointcloud_curr.view(batch_size, 3, -1))

    X = (XYZ_[:, 0, :] + t_vec[:, 0].unsqueeze(1)).view(
        -1, 1, intrinsics.height, intrinsics.width)
    Y = (XYZ_[:, 1, :] + t_vec[:, 1].unsqueeze(1)).view(
        -1, 1, intrinsics.height, intrinsics.width)
    Z = (XYZ_[:, 2, :] + t_vec[:, 2].unsqueeze(1)).view(
        -1, 1, intrinsics.height, intrinsics.width)

    pointcloud_near = torch.cat((X, Y, Z), dim=1)

    return pointcloud_near


def homography_from(rgb_near, depth_curr, r_mat, t_vec, intrinsics):
    # inverse warp the RGB image from the nearby frame to the current frame

    # to ensure dimension consistency
    r_mat = r_mat.view(-1, 3, 3)
    t_vec = t_vec.view(-1, 3)

    # compute source pixel coordinate
    pointcloud_curr = image_to_pointcloud(depth_curr, intrinsics)
    pointcloud_near = transform_curr_to_near(pointcloud_curr, r_mat, t_vec,
                                             intrinsics)
    pixel_coords_near = pointcloud_to_image(pointcloud_near, intrinsics)

    # the warping
    warped = F.grid_sample(rgb_near, pixel_coords_near)

    return warped
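A hedged sketch (not part of the diff) of the inverse warp on dummy tensors; with an identity rotation and zero translation the warp reduces to resampling the nearby image at the current pixel grid:

# Hedged usage sketch of homography_from with an identity pose (CPU, so Intrinsics.cuda() is skipped).
import torch

h, w = 352, 1216
intrinsics = Intrinsics(w, h, fu=721.5, fv=721.5, cu=(w - 1) / 2.0, cv=(h - 1) / 2.0)
rgb_near = torch.rand(1, 3, h, w)
depth_curr = torch.ones(1, 1, h, w) * 10.0  # 10 m everywhere
r_mat = torch.eye(3).view(1, 3, 3)
t_vec = torch.zeros(1, 3)
warped = homography_from(rgb_near, depth_curr, r_mat, t_vec, intrinsics)
print(warped.shape)  # torch.Size([1, 3, 352, 1216])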
modelscope/models/cv/self_supervised_depth_completion/metrics.py (new file, 181 lines)
@@ -0,0 +1,181 @@
import math

import numpy as np
import torch

lg_e_10 = math.log(10)


def log10(x):
    """Return a new tensor with the base-10 logarithm of the elements of x."""
    return torch.log(x) / lg_e_10


class Result(object):
    """Result"""

    def __init__(self):
        self.irmse = 0
        self.imae = 0
        self.mse = 0
        self.rmse = 0
        self.mae = 0
        self.absrel = 0
        self.squared_rel = 0
        self.lg10 = 0
        self.delta1 = 0
        self.delta2 = 0
        self.delta3 = 0
        self.data_time = 0
        self.gpu_time = 0
        self.silog = 0  # Scale invariant logarithmic error [log(m)*100]
        self.photometric = 0

    def set_to_worst(self):
        self.irmse = np.inf
        self.imae = np.inf
        self.mse = np.inf
        self.rmse = np.inf
        self.mae = np.inf
        self.absrel = np.inf
        self.squared_rel = np.inf
        self.lg10 = np.inf
        self.silog = np.inf
        self.delta1 = 0
        self.delta2 = 0
        self.delta3 = 0
        self.data_time = 0
        self.gpu_time = 0

    def update(self,
               irmse,
               imae,
               mse,
               rmse,
               mae,
               absrel,
               squared_rel,
               lg10,
               delta1,
               delta2,
               delta3,
               gpu_time,
               data_time,
               silog,
               photometric=0):
        """update"""
        self.irmse = irmse
        self.imae = imae
        self.mse = mse
        self.rmse = rmse
        self.mae = mae
        self.absrel = absrel
        self.squared_rel = squared_rel
        self.lg10 = lg10
        self.delta1 = delta1
        self.delta2 = delta2
        self.delta3 = delta3
        self.data_time = data_time
        self.gpu_time = gpu_time
        self.silog = silog
        self.photometric = photometric

    def evaluate(self, output, target, photometric=0):
        """evaluate"""
        valid_mask = target > 0.1

        # convert from meters to mm
        output_mm = 1e3 * output[valid_mask]
        target_mm = 1e3 * target[valid_mask]

        abs_diff = (output_mm - target_mm).abs()

        self.mse = float((torch.pow(abs_diff, 2)).mean())
        self.rmse = math.sqrt(self.mse)
        self.mae = float(abs_diff.mean())
        self.lg10 = float((log10(output_mm) - log10(target_mm)).abs().mean())
        self.absrel = float((abs_diff / target_mm).mean())
        self.squared_rel = float(((abs_diff / target_mm)**2).mean())

        maxRatio = torch.max(output_mm / target_mm, target_mm / output_mm)
        self.delta1 = float((maxRatio < 1.25).float().mean())
        self.delta2 = float((maxRatio < 1.25**2).float().mean())
        self.delta3 = float((maxRatio < 1.25**3).float().mean())
        self.data_time = 0
        self.gpu_time = 0

        # silog uses meters
        err_log = torch.log(target[valid_mask]) - torch.log(output[valid_mask])
        normalized_squared_log = (err_log**2).mean()
        log_mean = err_log.mean()
        self.silog = math.sqrt(normalized_squared_log
                               - log_mean * log_mean) * 100

        # convert from meters to km
        inv_output_km = (1e-3 * output[valid_mask])**(-1)
        inv_target_km = (1e-3 * target[valid_mask])**(-1)
        abs_inv_diff = (inv_output_km - inv_target_km).abs()
        self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean())
        self.imae = float(abs_inv_diff.mean())

        self.photometric = float(photometric)


class AverageMeter(object):
    """AverageMeter"""

    def __init__(self):
        self.reset()

    def reset(self):
        """reset"""
        self.count = 0.0
        self.sum_irmse = 0
        self.sum_imae = 0
        self.sum_mse = 0
        self.sum_rmse = 0
        self.sum_mae = 0
        self.sum_absrel = 0
        self.sum_squared_rel = 0
        self.sum_lg10 = 0
        self.sum_delta1 = 0
        self.sum_delta2 = 0
        self.sum_delta3 = 0
        self.sum_data_time = 0
        self.sum_gpu_time = 0
        self.sum_photometric = 0
        self.sum_silog = 0

    def update(self, result, gpu_time, data_time, n=1):
        """update"""
        self.count += n
        self.sum_irmse += n * result.irmse
        self.sum_imae += n * result.imae
        self.sum_mse += n * result.mse
        self.sum_rmse += n * result.rmse
        self.sum_mae += n * result.mae
        self.sum_absrel += n * result.absrel
        self.sum_squared_rel += n * result.squared_rel
        self.sum_lg10 += n * result.lg10
        self.sum_delta1 += n * result.delta1
        self.sum_delta2 += n * result.delta2
        self.sum_delta3 += n * result.delta3
        self.sum_data_time += n * data_time
        self.sum_gpu_time += n * gpu_time
        self.sum_silog += n * result.silog
        self.sum_photometric += n * result.photometric

    def average(self):
        """average"""
        avg = Result()
        if self.count > 0:
            avg.update(
                self.sum_irmse / self.count, self.sum_imae / self.count,
                self.sum_mse / self.count, self.sum_rmse / self.count,
                self.sum_mae / self.count, self.sum_absrel / self.count,
                self.sum_squared_rel / self.count, self.sum_lg10 / self.count,
                self.sum_delta1 / self.count, self.sum_delta2 / self.count,
                self.sum_delta3 / self.count, self.sum_gpu_time / self.count,
                self.sum_data_time / self.count, self.sum_silog / self.count,
                self.sum_photometric / self.count)
        return avg
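Not part of the file above: a minimal sketch of how Result and AverageMeter are meant to be combined during evaluation. The tensors are synthetic placeholders in meters, used only to exercise the API.

import torch
from modelscope.models.cv.self_supervised_depth_completion.metrics import (
    AverageMeter, Result)

pred = torch.rand(1, 1, 4, 4) + 0.5   # fake depth prediction, meters
gt = torch.rand(1, 1, 4, 4) + 0.5     # fake ground-truth depth, meters

result = Result()
result.evaluate(pred, gt)             # per-batch metrics

meter = AverageMeter()
meter.update(result, gpu_time=0.01, data_time=0.005, n=pred.size(0))
avg = meter.average()                 # running averages over all updates
print(avg.rmse, avg.mae, avg.delta1)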
215
modelscope/models/cv/self_supervised_depth_completion/model.py
Normal file
@@ -0,0 +1,215 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet


def init_weights(m):
    """init_weights"""
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 1e-3)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.ConvTranspose2d):
        m.weight.data.normal_(0, 1e-3)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()


def conv_bn_relu(in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 bn=True,
                 relu=True):
    """conv_bn_relu"""
    bias = not bn
    layers = []
    layers.append(
        nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding,
            bias=bias))
    if bn:
        layers.append(nn.BatchNorm2d(out_channels))
    if relu:
        layers.append(nn.LeakyReLU(0.2, inplace=True))
    layers = nn.Sequential(*layers)

    # initialize the weights
    for m in layers.modules():
        init_weights(m)

    return layers


def convt_bn_relu(in_channels,
                  out_channels,
                  kernel_size,
                  stride=1,
                  padding=0,
                  output_padding=0,
                  bn=True,
                  relu=True):
    """convt_bn_relu"""
    bias = not bn
    layers = []
    layers.append(
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            bias=bias))
    if bn:
        layers.append(nn.BatchNorm2d(out_channels))
    if relu:
        layers.append(nn.LeakyReLU(0.2, inplace=True))
    layers = nn.Sequential(*layers)

    # initialize the weights
    for m in layers.modules():
        init_weights(m)

    return layers


class DepthCompletionNet(nn.Module):
    """DepthCompletionNet"""

    def __init__(self, args):
        assert (
            args.layers in [18, 34, 50, 101, 152]
        ), f'Only layers 18, 34, 50, 101, and 152 are defined, but got {args.layers}'
        super(DepthCompletionNet, self).__init__()
        self.modality = args.input

        if 'd' in self.modality:
            channels = 64 // len(self.modality)
            self.conv1_d = conv_bn_relu(
                1, channels, kernel_size=3, stride=1, padding=1)
        if 'rgb' in self.modality:
            channels = 64 * 3 // len(self.modality)
            self.conv1_img = conv_bn_relu(
                3, channels, kernel_size=3, stride=1, padding=1)
        elif 'g' in self.modality:
            channels = 64 // len(self.modality)
            self.conv1_img = conv_bn_relu(
                1, channels, kernel_size=3, stride=1, padding=1)

        pretrained_model = resnet.__dict__['resnet{}'.format(args.layers)](
            pretrained=args.pretrained)
        if not args.pretrained:
            pretrained_model.apply(init_weights)
        # self.maxpool = pretrained_model._modules['maxpool']
        self.conv2 = pretrained_model._modules['layer1']
        self.conv3 = pretrained_model._modules['layer2']
        self.conv4 = pretrained_model._modules['layer3']
        self.conv5 = pretrained_model._modules['layer4']
        del pretrained_model  # clear memory

        # define number of intermediate channels
        if args.layers <= 34:
            num_channels = 512
        elif args.layers >= 50:
            num_channels = 2048
        self.conv6 = conv_bn_relu(
            num_channels, 512, kernel_size=3, stride=2, padding=1)

        # decoding layers
        kernel_size = 3
        stride = 2
        self.convt5 = convt_bn_relu(
            in_channels=512,
            out_channels=256,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            output_padding=1)
        self.convt4 = convt_bn_relu(
            in_channels=768,
            out_channels=128,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            output_padding=1)
        self.convt3 = convt_bn_relu(
            in_channels=(256 + 128),
            out_channels=64,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            output_padding=1)
        self.convt2 = convt_bn_relu(
            in_channels=(128 + 64),
            out_channels=64,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            output_padding=1)
        self.convt1 = convt_bn_relu(
            in_channels=128,
            out_channels=64,
            kernel_size=kernel_size,
            stride=1,
            padding=1)
        self.convtf = conv_bn_relu(
            in_channels=128,
            out_channels=1,
            kernel_size=1,
            stride=1,
            bn=False,
            relu=False)

    def forward(self, x):
        """forward"""
        # first layer
        if 'd' in self.modality:
            conv1_d = self.conv1_d(x['d'])
        if 'rgb' in self.modality:
            conv1_img = self.conv1_img(x['rgb'])
        elif 'g' in self.modality:
            conv1_img = self.conv1_img(x['g'])

        if self.modality == 'rgbd' or self.modality == 'gd':
            conv1 = torch.cat((conv1_d, conv1_img), 1)
        else:
            conv1 = conv1_d if (self.modality == 'd') else conv1_img

        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)  # batchsize * ? * 176 * 608
        conv4 = self.conv4(conv3)  # batchsize * ? * 88 * 304
        conv5 = self.conv5(conv4)  # batchsize * ? * 44 * 152
        conv6 = self.conv6(conv5)  # batchsize * ? * 22 * 76

        # decoder
        convt5 = self.convt5(conv6)
        y = torch.cat((convt5, conv5), 1)

        convt4 = self.convt4(y)
        y = torch.cat((convt4, conv4), 1)

        convt3 = self.convt3(y)
        y = torch.cat((convt3, conv3), 1)

        convt2 = self.convt2(y)
        y = torch.cat((convt2, conv2), 1)

        convt1 = self.convt1(y)
        y = torch.cat((convt1, conv1), 1)

        y = self.convtf(y)

        if self.training:
            return 100 * y
        else:
            min_distance = 0.9
            return F.relu(
                100 * y - min_distance
            ) + min_distance  # the minimum range of Velodyne is around 3 feet ~= 0.9m
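Not part of the file above: a hedged sketch of constructing DepthCompletionNet directly. SimpleNamespace stands in for the checkpoint args object, the toy input size is illustrative (height and width just need to be divisible by 16), and it assumes a torchvision version that still accepts the pretrained flag.

import torch
from types import SimpleNamespace
from modelscope.models.cv.self_supervised_depth_completion.model import \
    DepthCompletionNet

args = SimpleNamespace(layers=18, input='gd', pretrained=False)
net = DepthCompletionNet(args).eval()

# gray image + sparse depth, both 1-channel, toy resolution 64 x 256
batch = {'g': torch.rand(1, 1, 64, 256), 'd': torch.rand(1, 1, 64, 256)}
with torch.no_grad():
    depth = net(batch)  # (1, 1, 64, 256) dense depth, clamped to >= 0.9 m in eval mode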
@@ -0,0 +1,225 @@
# import argparse
import os
import sys
import time
# import mmcv
from argparse import ArgumentParser
# import torchvision
from os import makedirs

import cv2
import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
from tqdm import tqdm

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.self_supervised_depth_completion import (criteria,
                                                                   helper)
from modelscope.models.cv.self_supervised_depth_completion.dataloaders.kitti_loader import (
    KittiDepth, input_options, load_calib, oheight, owidth)
from modelscope.models.cv.self_supervised_depth_completion.inverse_warp import (
    Intrinsics, homography_from)
from modelscope.models.cv.self_supervised_depth_completion.metrics import (
    AverageMeter, Result)
from modelscope.models.cv.self_supervised_depth_completion.model import \
    DepthCompletionNet
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# from modelscope.utils.config import Config

m_logger = get_logger()


class ArgsList():
    """ArgsList Class"""

    def __init__(self) -> None:
        self.workers = 4
        self.epochs = 11
        self.start_epoch = 0
        self.criterion = 'l2'
        self.batch_size = 1
        self.learning_rate = 1e-5
        self.weight_decay = 0
        self.print_freq = 10
        self.resume = ''
        self.data_folder = '../data'
        self.input = 'gd'
        self.layers = 34
        self.pretrained = True
        self.val = 'select'
        self.jitter = 0.1
        self.rank_metric = 'rmse'
        self.evaluate = ''
        self.cpu = False


@MODELS.register_module(
    Tasks.self_supervised_depth_completion,
    module_name=Models.self_supervised_depth_completion)
class SelfSupervisedDepthCompletion(TorchModel):
    """SelfSupervisedDepthCompletion Class"""

    def __init__(self, model_dir: str, **kwargs):
        """model_dir (str): model file root."""
        super().__init__(model_dir, **kwargs)

        args = ArgsList()
        # define loss functions
        self.depth_criterion = criteria.MaskedMSELoss()
        self.photometric_criterion = criteria.PhotometricLoss()
        self.smoothness_criterion = criteria.SmoothnessLoss()

        # args.use_pose = ('photo' in args.train_mode)
        args.use_pose = True
        # args.pretrained = not args.no_pretrained
        args.use_rgb = ('rgb' in args.input) or args.use_pose
        args.use_d = 'd' in args.input
        args.use_g = 'g' in args.input

        args.evaluate = os.path.join(self.model_dir, 'model_best.pth')

        if args.use_pose:
            args.w1, args.w2 = 0.1, 0.1
        else:
            args.w1, args.w2 = 0, 0

        self.cuda = torch.cuda.is_available() and not args.cpu
        if self.cuda:
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        print("=> using '{}' for computation.".format(self.device))

        args_new = args
        self.checkpoint = None  # keep the attribute defined even if no checkpoint is found
        if os.path.isfile(args.evaluate):
            print(
                "=> loading checkpoint '{}' ... ".format(args.evaluate),
                end='')
            self.checkpoint = torch.load(
                args.evaluate, map_location=self.device)
            args = self.checkpoint['args']
            args.val = args_new.val
            print('Completed.')
        else:
            print("No model found at '{}'".format(args.evaluate))
            return

        print('=> creating model and optimizer ... ', end='')
        model = DepthCompletionNet(args).to(self.device)
        model_named_params = [
            p for _, p in model.named_parameters() if p.requires_grad
        ]
        optimizer = torch.optim.Adam(
            model_named_params, lr=args.lr, weight_decay=args.weight_decay)
        print('completed.')
        if self.checkpoint is not None:
            model.load_state_dict(self.checkpoint['model'])
            optimizer.load_state_dict(self.checkpoint['optimizer'])
            print('=> checkpoint state loaded.')

        model = torch.nn.DataParallel(model)

        self.model = model
        self.args = args

    def iterate(self, mode, args, loader, model, optimizer, logger, epoch):
        """iterate data"""
        block_average_meter = AverageMeter()
        average_meter = AverageMeter()
        meters = [block_average_meter, average_meter]
        merged_img = None
        # switch to appropriate mode
        assert mode in ['train', 'val', 'eval', 'test_prediction', 'test_completion'], \
            'unsupported mode: {}'.format(mode)
        model.eval()
        lr = 0

        for i, batch_data in enumerate(loader):
            start = time.time()
            batch_data = {
                key: val.to(self.device)
                for key, val in batch_data.items() if val is not None
            }
            gt = batch_data[
                'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None
            data_time = time.time() - start

            start = time.time()
            pred = model(batch_data)
            photometric_loss = 0
            gpu_time = time.time() - start

            # measure accuracy and record loss
            with torch.no_grad():
                mini_batch_size = next(iter(batch_data.values())).size(0)
                result = Result()
                if mode != 'test_prediction' and mode != 'test_completion':
                    result.evaluate(pred.data, gt.data, photometric_loss)
                [
                    m.update(result, gpu_time, data_time, mini_batch_size)
                    for m in meters
                ]
                logger.conditional_print(mode, i, epoch, lr, len(loader),
                                         block_average_meter, average_meter)
                merged_img = logger.conditional_save_img_comparison(
                    mode, i, batch_data, pred, epoch)
                merged_img = cv2.cvtColor(merged_img, cv2.COLOR_RGB2BGR)
                logger.conditional_save_pred(mode, i, pred, epoch)

        avg = logger.conditional_save_info(mode, average_meter, epoch)
        is_best = logger.rank_conditional_save_best(mode, avg, epoch)
        logger.save_img_comparison_as_best(mode, epoch)
        logger.conditional_summarize(mode, avg, is_best)

        return avg, is_best, merged_img

    def forward(self, source_dir):
        """main function"""

        args = self.args
        args.data_folder = source_dir
        args.result = os.path.join(args.data_folder, 'results')
        if args.use_pose:
            # hard-coded KITTI camera intrinsics
            K = load_calib(args)
            fu, fv = float(K[0, 0]), float(K[1, 1])
            cu, cv = float(K[0, 2]), float(K[1, 2])
            kitti_intrinsics = Intrinsics(owidth, oheight, fu, fv, cu, cv)
            if self.cuda:
                kitti_intrinsics = kitti_intrinsics.cuda()

        # Data loading code
        print('=> creating data loaders ... ')
        val_dataset = KittiDepth('val', self.args)
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=2,
            pin_memory=True)  # set batch size to be 1 for validation
        print('\t==> val_loader size:{}'.format(len(val_loader)))

        # create backups and results folder
        logger = helper.logger(self.args)
        if self.checkpoint is not None:
            logger.best_result = self.checkpoint['best_result']

        print('=> starting model evaluation ...')
        result, is_best, merged_img = self.iterate('val', self.args,
                                                   val_loader, self.model,
                                                   None, logger,
                                                   self.checkpoint['epoch'])
        return merged_img
@@ -0,0 +1,119 @@
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

if not ('DISPLAY' in os.environ):
    import matplotlib as mpl
    mpl.use('Agg')

cmap = plt.cm.jet


def depth_colorize(depth):
    depth = (depth - np.min(depth)) / (np.max(depth) - np.min(depth))
    depth = 255 * cmap(depth)[:, :, :3]  # H, W, C
    return depth.astype('uint8')


def merge_into_row(ele, pred):

    def preprocess_depth(x):
        y = np.squeeze(x.data.cpu().numpy())
        return depth_colorize(y)

    # if the input is gray, transform it to rgb
    img_list = []
    if 'rgb' in ele:
        rgb = np.squeeze(ele['rgb'][0, ...].data.cpu().numpy())
        rgb = np.transpose(rgb, (1, 2, 0))
        img_list.append(rgb)
    elif 'g' in ele:
        g = np.squeeze(ele['g'][0, ...].data.cpu().numpy())
        g = np.array(Image.fromarray(g).convert('RGB'))
        img_list.append(g)
    if 'd' in ele:
        img_list.append(preprocess_depth(ele['d'][0, ...]))
    img_list.append(preprocess_depth(pred[0, ...]))
    if 'gt' in ele:
        img_list.append(preprocess_depth(ele['gt'][0, ...]))

    img_merge = np.hstack(img_list)
    return img_merge.astype('uint8')


def add_row(img_merge, row):
    return np.vstack([img_merge, row])


def save_image(img_merge, filename):
    image_to_write = cv2.cvtColor(img_merge, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, image_to_write)


def save_depth_as_uint16png(img, filename):
    img = (img * 256).astype('uint16')
    cv2.imwrite(filename, img)


if ('DISPLAY' in os.environ):
    f, axarr = plt.subplots(4, 1)
    plt.tight_layout()
    plt.ion()


def display_warping(rgb_tgt, pred_tgt, warped):

    def preprocess(rgb_tgt, pred_tgt, warped):
        rgb_tgt = 255 * np.transpose(
            np.squeeze(rgb_tgt.data.cpu().numpy()), (1, 2, 0))  # H, W, C
        # depth = np.squeeze(depth.cpu().numpy())
        # depth = depth_colorize(depth)

        # convert to log-scale
        pred_tgt = np.squeeze(pred_tgt.data.cpu().numpy())
        # pred_tgt[pred_tgt<=0] = 0.9  # remove negative predictions
        # pred_tgt = np.log10(pred_tgt)

        pred_tgt = depth_colorize(pred_tgt)

        warped = 255 * np.transpose(
            np.squeeze(warped.data.cpu().numpy()), (1, 2, 0))  # H, W, C
        recon_err = np.absolute(
            warped.astype('float') - rgb_tgt.astype('float')) * (
                warped > 0)
        recon_err = recon_err[:, :, 0] + recon_err[:, :, 1] + recon_err[:, :,
                                                                        2]
        recon_err = depth_colorize(recon_err)
        return rgb_tgt.astype('uint8'), warped.astype(
            'uint8'), recon_err, pred_tgt

    rgb_tgt, warped, recon_err, pred_tgt = preprocess(rgb_tgt, pred_tgt,
                                                      warped)

    # 1st column
    # column = 0
    axarr[0].imshow(rgb_tgt)
    axarr[0].axis('off')
    axarr[0].axis('equal')
    # axarr[0, column].set_title('rgb_tgt')

    axarr[1].imshow(warped)
    axarr[1].axis('off')
    axarr[1].axis('equal')
    # axarr[1, column].set_title('warped')

    axarr[2].imshow(recon_err, 'hot')
    axarr[2].axis('off')
    axarr[2].axis('equal')
    # axarr[2, column].set_title('recon_err error')

    axarr[3].imshow(pred_tgt, 'hot')
    axarr[3].axis('off')
    axarr[3].axis('equal')
    # axarr[3, column].set_title('pred_tgt')

    # plt.show()
    plt.pause(0.001)
@@ -774,6 +774,7 @@ TASK_OUTPUTS = {
    Tasks.surface_recon_common: [OutputKeys.OUTPUT],
    Tasks.video_colorization: [OutputKeys.OUTPUT_VIDEO],
    Tasks.image_control_3d_portrait: [OutputKeys.OUTPUT],
    Tasks.self_supervised_depth_completion: [OutputKeys.OUTPUT_IMG],

    # image quality assessment degradation result for single image
    # {
@@ -121,6 +121,8 @@ if TYPE_CHECKING:
    from .image_local_feature_matching_pipeline import ImageLocalFeatureMatchingPipeline
    from .rife_video_frame_interpolation_pipeline import RIFEVideoFrameInterpolationPipeline
    from .anydoor_pipeline import AnydoorPipeline
    from .self_supervised_depth_completion_pipeline import SelfSupervisedDepthCompletionPipeline

else:
    _import_structure = {
        'action_recognition_pipeline': ['ActionRecognitionPipeline'],
@@ -303,6 +305,9 @@ else:
            'RIFEVideoFrameInterpolationPipeline'
        ],
        'anydoor_pipeline': ['AnydoorPipeline'],
        'self_supervised_depth_completion_pipeline': [
            'SelfSupervisedDepthCompletionPipeline'
        ],
    }

    import sys
@@ -0,0 +1,59 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.self_supervised_depth_completion,
    module_name=Pipelines.self_supervised_depth_completion)
class SelfSupervisedDepthCompletionPipeline(Pipeline):
    """Self-Supervised Depth Completion Pipeline

    Example:

    ```python
    >>> from modelscope.pipelines import pipeline
    >>> model_id = 'Damo_XR_Lab/Self_Supervised_Depth_Completion'
    >>> data_dir = MsDataset.load(
    >>>     'KITTI_Depth_Dataset',
    >>>     namespace='Damo_XR_Lab',
    >>>     split='test',
    >>>     download_mode=DownloadMode.FORCE_REDOWNLOAD
    >>> ).config_kwargs['split_config']['test']
    >>> source_dir = os.path.join(data_dir, 'selected_data')
    >>> self_supervised_depth_completion = pipeline(
    >>>     Tasks.self_supervised_depth_completion,
    >>>     'Damo_XR_Lab/Self_Supervised_Depth_Completion')
    >>> result = self_supervised_depth_completion({
    >>>     'model_dir': model_id,
    >>>     'source_dir': source_dir
    >>> })
    >>> cv2.imwrite('result.jpg', result[OutputKeys.OUTPUT])
    ```
    """

    def __init__(self, model: str, **kwargs):
        super().__init__(model=model, **kwargs)
        logger.info('load model done')

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """preprocess, not used at present"""
        return inputs

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """forward"""
        source_dir = inputs['source_dir']
        result = self.model.forward(source_dir)
        return {OutputKeys.OUTPUT: result}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """postprocess, not used at present"""
        return inputs
@@ -170,6 +170,7 @@ class CVTasks(object):
    human3d_render = 'human3d-render'
    human3d_animation = 'human3d-animation'
    image_control_3d_portrait = 'image-control-3d-portrait'
    self_supervised_depth_completion = 'self-supervised-depth-completion'

    # 3d generation
    image_to_3d = 'image-to-3d'
@@ -3812,5 +3812,18 @@
                    }
                }
            }
        }
    },
    "self-supervised-depth-completion": {
        "input": {},
        "parameters": {},
        "output": {
            "type": "object",
            "properties": {
                "output_img": {
                    "type": "string",
                    "description": "The base64 encoded image."
                }
            }
        }
    },
}
54
tests/pipelines/test_self_supervised_depth_completion.py
Normal file
@@ -0,0 +1,54 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest

import cv2
import torch

from modelscope import get_logger
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.msdatasets import MsDataset
from modelscope.outputs.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import DownloadMode, Tasks
from modelscope.utils.test_utils import test_level

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
logger = get_logger()


class SelfSupervisedDepthCompletionTest(unittest.TestCase):
    """class SelfSupervisedDepthCompletionTest"""

    def setUp(self) -> None:
        self.model_id = 'Damo_XR_Lab/Self_Supervised_Depth_Completion'
        data_dir = MsDataset.load(
            'KITTI_Depth_Dataset',
            namespace='Damo_XR_Lab',
            split='test',
            download_mode=DownloadMode.FORCE_REDOWNLOAD
        ).config_kwargs['split_config']['test']
        self.source_dir = os.path.join(data_dir, 'selected_data')
        logger.info(data_dir)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    @unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest only')
    def test_run(self):
        """test running evaluation"""
        snapshot_path = snapshot_download(self.model_id)
        logger.info('snapshot_path: %s', snapshot_path)
        self_supervised_depth_completion = pipeline(
            task=Tasks.self_supervised_depth_completion,
            model=self.model_id
            # ,config_file = os.path.join(modelPath, "configuration.json")
        )

        result = self_supervised_depth_completion(
            dict(model_dir=snapshot_path, source_dir=self.source_dir))
        cv2.imwrite('result.jpg', result[OutputKeys.OUTPUT])
        logger.info(
            'self-supervised-depth-completion_damo.test_run_modelhub done')


if __name__ == '__main__':
    unittest.main()
0
tests/test_metrics/__init__.py
Normal file