add cv_pointnet2_sceneflow-estimation_general

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11201880
This commit is contained in:
dadong.gxd
2022-12-29 08:09:57 +08:00
committed by yingda.chen
parent f7a7504782
commit 42557b0867
16 changed files with 1596 additions and 5 deletions

.gitattributes vendored
View File

@@ -7,3 +7,4 @@
*.pickle filter=lfs diff=lfs merge=lfs -text
*.avi filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9fa9f5c8a49d457a7b6f4239e438699e60541e7602e8b3b66da9f7b6d55096ab
size 1735856

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86618feded6ae9fbcc772b9a7da17bad7d8b9c68ae0d505a239d110a3a0a7bf4
size 1735856

View File

@@ -62,6 +62,7 @@ class Models(object):
video_human_matting = 'video-human-matting'
video_object_segmentation = 'video-object-segmentation'
real_basicvsr = 'real-basicvsr'
rcp_sceneflow_estimation = 'rcp-sceneflow-estimation'
# EasyCV models
yolox = 'YOLOX'
@@ -253,6 +254,7 @@ class Pipelines(object):
video_human_matting = 'video-human-matting'
video_object_segmentation = 'video-object-segmentation'
video_super_resolution = 'realbasicvsr-video-super-resolution'
pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'
# nlp tasks
automatic_post_editing = 'automatic-post-editing'

View File

@@ -11,10 +11,11 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
image_reid_person, image_semantic_segmentation,
image_to_image_generation, image_to_image_translation,
language_guided_video_summarization, movie_scene_segmentation,
object_detection, product_retrieval_embedding,
realtime_object_detection, referring_video_object_segmentation,
salient_detection, shop_segmentation, super_resolution,
video_object_segmentation, video_single_object_tracking,
video_summarization, video_super_resolution, virual_tryon)
object_detection, pointcloud_sceneflow_estimation,
product_retrieval_embedding, realtime_object_detection,
referring_video_object_segmentation, salient_detection,
shop_segmentation, super_resolution, video_object_segmentation,
video_single_object_tracking, video_summarization,
video_super_resolution, virual_tryon)
# yapf: enable

View File

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .rcp_model import SceneFlowEstimation
else:
_import_structure = {
'rcp_model': ['SceneFlowEstimation'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,446 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import pointnet2_utils as pointutils
RADIUS = 2.5
def index_points_group(points, knn_idx):
"""
Input:
points: input points data, [B, N, C]
knn_idx: sample index data, [B, N, K]
Return:
new_points: indexed points data, [B, N, K, C]
"""
points_flipped = points.permute(0, 2, 1).contiguous()
new_points = pointutils.grouping_operation(points_flipped,
knn_idx.int()).permute(
0, 2, 3, 1)
return new_points
def curvature(pc, nsample=10, radius=RADIUS):
# pc: B 3 N
assert pc.shape[1] == 3
pc = pc.permute(0, 2, 1)
dist, kidx = pointutils.knn(nsample, pc.contiguous(),
pc.contiguous()) # (B, N, 10)
if radius is not None:
tmp_idx = kidx[:, :, 0].unsqueeze(2).repeat(1, 1,
nsample).to(kidx.device)
kidx[dist > radius] = tmp_idx[dist > radius]
grouped_pc = index_points_group(pc, kidx) # B N 10 3
pc_curvature = torch.sum(grouped_pc - pc.unsqueeze(2), dim=2) / 9.0
return pc_curvature # B N 3
class PointNetSetAbstractionRatio(nn.Module):
def __init__(self,
ratio,
radius,
nsample,
in_channel,
mlp,
group_all,
return_fps=False,
use_xyz=True,
use_act=True,
act=F.relu,
mean_aggr=False,
use_instance_norm=False):
super(PointNetSetAbstractionRatio, self).__init__()
self.ratio = ratio
self.radius = radius
self.nsample = nsample
self.group_all = group_all
self.use_xyz = use_xyz
self.use_act = use_act
self.mean_aggr = mean_aggr
self.act = act
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = (in_channel + 3) if use_xyz else in_channel
for out_channel in mlp:
self.mlp_convs.append(
nn.Conv2d(last_channel, out_channel, 1, bias=False))
if use_instance_norm:
self.mlp_bns.append(
nn.InstanceNorm2d(out_channel, affine=True))
else:
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
if group_all:
self.queryandgroup = pointutils.GroupAll(self.use_xyz)
else:
self.queryandgroup = pointutils.QueryAndGroup(
radius, nsample, self.use_xyz)
self.return_fps = return_fps
def forward(self, xyz, points, fps_idx=None):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points: sample points feature data, [B, D', S]
"""
B, C, N = xyz.shape
npoint = int(N * self.ratio)
xyz = xyz.contiguous()
xyz_t = xyz.permute(0, 2, 1).contiguous()
if (self.group_all is False) and (npoint != -1):
if fps_idx is None:
fps_idx = pointutils.furthest_point_sample(xyz_t,
npoint) # [B, N]
new_xyz = pointutils.gather_operation(xyz, fps_idx) # [B, C, N]
else:
new_xyz = xyz
new_points, _ = self.queryandgroup(xyz_t,
new_xyz.transpose(2,
1).contiguous(),
points) # [B, 3+C, N, S]
# new_xyz: sampled points position data, [B, C, npoint]
# new_points: sampled points data, [B, C+D, npoint, nsample]
for i, conv in enumerate(self.mlp_convs):
if self.use_act:
bn = self.mlp_bns[i]
new_points = self.act(bn(conv(new_points)))
else:
new_points = conv(new_points)
if self.mean_aggr:
new_points = torch.mean(new_points, -1)
else:
new_points = torch.max(new_points, -1)[0]
if self.return_fps:
return new_xyz, new_points, fps_idx
else:
return new_xyz, new_points
class PointNetSetAbstraction(nn.Module):
def __init__(self,
npoint,
radius,
nsample,
in_channel,
mlp,
group_all,
return_fps=False,
use_xyz=True,
use_act=True,
act=F.relu,
mean_aggr=False,
use_instance_norm=False):
super(PointNetSetAbstraction, self).__init__()
self.npoint = npoint
self.radius = radius
self.nsample = nsample
self.group_all = group_all
self.use_xyz = use_xyz
self.use_act = use_act
self.mean_aggr = mean_aggr
self.act = act
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = (in_channel + 3) if use_xyz else in_channel
for out_channel in mlp:
self.mlp_convs.append(
nn.Conv2d(last_channel, out_channel, 1, bias=False))
if use_instance_norm:
self.mlp_bns.append(
nn.InstanceNorm2d(out_channel, affine=True))
else:
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
if group_all:
self.queryandgroup = pointutils.GroupAll(self.use_xyz)
else:
self.queryandgroup = pointutils.QueryAndGroup(
radius, nsample, self.use_xyz)
self.return_fps = return_fps
def forward(self, xyz, points, fps_idx=None):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points: sampled points feature data, [B, D', S]
"""
# device = xyz.device
B, C, N = xyz.shape
xyz = xyz.contiguous()
xyz_t = xyz.permute(0, 2, 1).contiguous()
if (self.group_all is False) and (self.npoint != -1):
if fps_idx is None:
fps_idx = pointutils.furthest_point_sample(
xyz_t, self.npoint) # [B, N]
new_xyz = pointutils.gather_operation(xyz, fps_idx) # [B, C, N]
else:
new_xyz = xyz
new_points, _ = self.queryandgroup(xyz_t,
new_xyz.transpose(2,
1).contiguous(),
points) # [B, 3+C, N, S]
# new_xyz: sampled points position data, [B, C, npoint]
# new_points: sampled points data, [B, C+D, npoint, nsample]
for i, conv in enumerate(self.mlp_convs):
if self.use_act:
bn = self.mlp_bns[i]
new_points = self.act(bn(conv(new_points)))
else:
new_points = conv(new_points)
if self.mean_aggr:
new_points = torch.mean(new_points, -1)
else:
new_points = torch.max(new_points, -1)[0]
if self.return_fps:
return new_xyz, new_points, fps_idx
else:
return new_xyz, new_points
class PointNetFeaturePropogation(nn.Module):
def __init__(self, in_channel, mlp, learn_mask=False, nsample=3):
super(PointNetFeaturePropogation, self).__init__()
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
self.apply_mlp = mlp is not None
last_channel = in_channel
self.nsample = nsample
if self.apply_mlp:
for out_channel in mlp:
self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm1d(out_channel))
last_channel = out_channel
if learn_mask:
self.queryandgroup = pointutils.QueryAndGroup(
None, 9, use_xyz=True)
last_channel = (128 + 3)
for out_channel in [32, 1]:
self.mlp_convs.append(
nn.Conv2d(last_channel, out_channel, 1, bias=False))
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
def forward(self, pos1, pos2, feature1, feature2, hidden=None):
"""
Input:
pos1: input points position data, [B, C, N]
pos2: sampled input points position data, [B, C, S]
feature1: input points data, [B, D, N]
feature2: input points data, [B, D, S]
Return:
feat_new: upsampled points data, [B, D', N]
"""
pos1_t = pos1.permute(0, 2, 1).contiguous()
pos2_t = pos2.permute(0, 2, 1).contiguous()
B, C, N = pos1.shape
if hidden is None:
if self.nsample == 3:
dists, idx = pointutils.three_nn(pos1_t, pos2_t)
else:
dists, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)
dists[dists < 1e-10] = 1e-10
weight = 1.0 / dists
weight = weight / torch.sum(weight, -1, keepdim=True) # [B,N,3]
interpolated_feat = torch.sum(
pointutils.grouping_operation(feature2, idx)
* weight.view(B, 1, N, self.nsample),
dim=-1) # [B,C,N,3]
else:
dist, idx = pointutils.knn(9, pos1_t, pos2_t)
new_feat, _ = self.queryandgroup(pos2_t, pos1_t,
hidden) # [B, 3+C, N, 9]
for i, conv in enumerate(self.mlp_convs):
new_feat = conv(new_feat)
weight = torch.softmax(new_feat, dim=-1) # [B, 1, N, 9]
interpolated_feat = torch.sum(
pointutils.grouping_operation(feature2, idx) * weight,
dim=-1) # [B, C, N]
if feature1 is not None:
feat_new = torch.cat([interpolated_feat, feature1], 1)
else:
feat_new = interpolated_feat
if self.apply_mlp:
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
feat_new = F.relu(bn(conv(feat_new)))
return feat_new
class Sinkhorn(nn.Module):
def __init__(self):
super().__init__()
def forward(self, corr, epsilon, gamma, max_iter):
# Early return if no iteration
if max_iter == 0:
return corr
# Init. of Sinkhorn algorithm
power = gamma / (gamma + epsilon)
a = (
torch.ones((corr.shape[0], corr.shape[1], 1),
device=corr.device,
dtype=corr.dtype) / corr.shape[1])
prob1 = (
torch.ones((corr.shape[0], corr.shape[1], 1),
device=corr.device,
dtype=corr.dtype) / corr.shape[1])
prob2 = (
torch.ones((corr.shape[0], corr.shape[2], 1),
device=corr.device,
dtype=corr.dtype) / corr.shape[2])
# Sinkhorn algorithm
for _ in range(max_iter):
# Update b
KTa = torch.bmm(corr.transpose(1, 2), a)
b = torch.pow(prob2 / (KTa + 1e-8), power)
# Update a
Kb = torch.bmm(corr, b)
a = torch.pow(prob1 / (Kb + 1e-8), power)
# Transportation map
T = torch.mul(torch.mul(a, corr), b.transpose(1, 2))
return T
class PointWiseOptimLayer(nn.Module):
def __init__(self, nsample, radius, in_channel, mlp, use_curvature=True):
super().__init__()
self.nsample = nsample
self.radius = radius
self.use_curvature = use_curvature
self.pos_embed = nn.Sequential(
nn.Conv1d(3, 32, 1), nn.ReLU(inplace=True), nn.Conv1d(32, 64, 1))
self.qk_net = nn.Sequential(
nn.Conv1d(in_channel + 64, in_channel + 64, 1),
nn.ReLU(inplace=True),
nn.Conv1d(in_channel + 64, in_channel + 64, 1))
if self.use_curvature:
self.curvate_net = nn.Sequential(
nn.Conv1d(3, 32, 1), nn.ReLU(inplace=True),
nn.Conv1d(32, 32, 1))
self.mlp_conv = nn.Conv1d(
in_channel + 64 + 32, mlp[-1], 1, bias=True)
else:
self.mlp_conv = nn.Conv1d(in_channel + 64, mlp[-1], 1, bias=True)
def forward(self,
pos1,
pos2,
feature1,
feature2,
nsample,
radius=None,
pos1_raw=None,
return_score=False):
"""
Input:
pos1: (batch_size, 3, npoint)
pos2: (batch_size, 3, npoint)
feature1: (batch_size, channel, npoint)
feature2: (batch_size, channel, npoint)
Output:
pos1: (batch_size, 3, npoint)
cost: (batch_size, channel, npoint)
"""
pos1_t = pos1.permute(0, 2, 1).contiguous()
pos2_t = pos2.permute(0, 2, 1).contiguous()
self.nsample = nsample
self.radius = radius
dist, idx = pointutils.knn(self.nsample, pos1_t, pos2_t) # [B, N, K]
if self.radius is not None:
tmp_idx = idx[:, :,
0].unsqueeze(2).repeat(1, 1,
self.nsample).to(idx.device)
idx[dist > self.radius] = tmp_idx[dist > self.radius]
pos1_embed_norm = self.pos_embed(pos1)
pos2_embed_norm = self.pos_embed(pos2) # [B, C1, N]
feat1_w_pos = torch.cat([feature1, pos1_embed_norm], dim=1)
feat2_w_pos = torch.cat([feature2, pos2_embed_norm],
dim=1) # [B, C1+C2, N]
feat1_w_pos = self.qk_net(feat1_w_pos)
feat2_w_pos = self.qk_net(feat2_w_pos) # [B, C1+C2, N]
feat2_grouped = pointutils.grouping_operation(feat2_w_pos,
idx) # [B, C1+C2, N, S]
score = torch.softmax(
feat1_w_pos.unsqueeze(-1) * feat2_grouped * 1.
/ math.sqrt(feat1_w_pos.shape[1]),
dim=-1) # [B, C1+C2, N, S]
cost = (score * (feat1_w_pos.unsqueeze(-1) - feat2_grouped)**2).sum(
dim=-1) # [B, C1+C2, N]
if self.use_curvature:
curvate1_raw = curvature(pos1_raw).permute(0, 2, 1) # [B, 3, N]
curvate1 = curvature(pos1).permute(0, 2, 1) # [B, 3, N]
curvate_cost = self.curvate_net(curvate1_raw) - self.curvate_net(
curvate1)
curvate_cost = curvate_cost**2
cost = self.mlp_conv(torch.cat([cost, curvate_cost],
dim=1)) # [B, C, N]
else:
cost = self.mlp_conv(cost) # [B, C, N]
if return_score:
pos2_grouped = pointutils.grouping_operation(pos2,
idx) # [B, 3, N, S]
# [B, N, K]
index = (dist > self.radius).sum(
dim=2, keepdim=True).float() > (dist.shape[2] - 0.1
) # [B, N, 1]
index = index.unsqueeze(1).repeat(1, score.shape[1], 1,
dist.shape[2])  # [B, C, N, K]
score_tmp = score.clone()
score_tmp[index] = 0.0
score = score_tmp
return pos1, cost, score, pos2_grouped
else:
return pos1, cost
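The set-abstraction docstrings above describe channel-first tensor layouts; below is a minimal shape sketch with illustrative sizes, assuming a CUDA device, the compiled pointnet2_cuda extension, and an assumed import path for the module above.
# Hypothetical shape check for PointNetSetAbstraction (sizes are illustrative).
import torch
from modelscope.models.cv.pointcloud_sceneflow_estimation.common import \
    PointNetSetAbstraction  # assumed import path

sa = PointNetSetAbstraction(
    npoint=1024, radius=None, nsample=16, in_channel=3,
    mlp=[32, 32, 64], group_all=False).cuda()
xyz = torch.rand(2, 3, 4096).cuda()    # input positions [B, C, N]
feats = torch.rand(2, 3, 4096).cuda()  # input features  [B, D, N]
new_xyz, new_feats = sa(xyz, feats)    # [2, 3, 1024], [2, 64, 1024]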

View File

@@ -0,0 +1,360 @@
# The implementation is adopted from Pointnet2.PyTorch, open-sourced under the MIT license,
# made publicly available at https://github.com/sshaoshuai/Pointnet2.PyTorch
from typing import Tuple
import pointnet2_cuda as pointnet2
import torch
import torch.nn as nn
from torch.autograd import Function, Variable
class FurthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
"""
Uses iterative furthest point sampling to select a set of npoint features that have the largest
minimum distance
:param ctx:
:param xyz: (B, N, 3) where N > npoint
:param npoint: int, number of features in the sampled set
:return:
output: (B, npoint) tensor containing the set
"""
assert xyz.is_contiguous()
B, N, _ = xyz.size()
output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp,
output)
return output
@staticmethod
def backward(xyz, a=None):
return None, None
furthest_point_sample = FurthestPointSampling.apply
class GatherOperation(Function):
@staticmethod
def forward(ctx, features: torch.Tensor,
idx: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param features: (B, C, N)
:param idx: (B, npoint) index tensor of the features to gather
:return:
output: (B, C, npoint)
"""
assert features.is_contiguous()
assert idx.is_contiguous()
B, npoint = idx.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, npoint)
pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
ctx.for_backwards = (idx, C, N)
return output
@staticmethod
def backward(ctx, grad_out):
idx, C, N = ctx.for_backwards
B, npoint = idx.size()
grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data,
idx, grad_features.data)
return grad_features, None
gather_operation = GatherOperation.apply
class KNN(Function):
@staticmethod
def forward(ctx, k: int, unknown: torch.Tensor,
known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Find the k nearest neighbors of unknown in known
:param ctx:
:param unknown: (B, N, 3)
:param known: (B, M, 3)
:return:
dist: (B, N, k) l2 distance to the k nearest neighbors
idx: (B, N, k) index of the k nearest neighbors
"""
assert unknown.is_contiguous()
assert known.is_contiguous()
B, N, _ = unknown.size()
m = known.size(1)
dist2 = torch.cuda.FloatTensor(B, N, k)
idx = torch.cuda.IntTensor(B, N, k)
pointnet2.knn_wrapper(B, N, m, k, unknown, known, dist2, idx)
return torch.sqrt(dist2), idx
@staticmethod
def backward(ctx, a=None, b=None):
return None, None, None
knn = KNN.apply
class ThreeNN(Function):
@staticmethod
def forward(ctx, unknown: torch.Tensor,
known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Find the three nearest neighbors of unknown in known
:param ctx:
:param unknown: (B, N, 3)
:param known: (B, M, 3)
:return:
dist: (B, N, 3) l2 distance to the three nearest neighbors
idx: (B, N, 3) index of 3 nearest neighbors
"""
assert unknown.is_contiguous()
assert known.is_contiguous()
B, N, _ = unknown.size()
m = known.size(1)
dist2 = torch.cuda.FloatTensor(B, N, 3)
idx = torch.cuda.IntTensor(B, N, 3)
pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
return torch.sqrt(dist2), idx
@staticmethod
def backward(ctx, a=None, b=None):
return None, None
three_nn = ThreeNN.apply
class ThreeInterpolate(Function):
@staticmethod
def forward(ctx, features: torch.Tensor, idx: torch.Tensor,
weight: torch.Tensor) -> torch.Tensor:
"""
Performs weighted linear interpolation on 3 features
:param ctx:
:param features: (B, C, M) Feature descriptors to be interpolated from
:param idx: (B, n, 3) three nearest neighbors of the target features in features
:param weight: (B, n, 3) weights
:return:
output: (B, C, N) tensor of the interpolated features
"""
assert features.is_contiguous()
assert idx.is_contiguous()
assert weight.is_contiguous()
B, c, m = features.size()
n = idx.size(1)
ctx.three_interpolate_for_backward = (idx, weight, m)
output = torch.cuda.FloatTensor(B, c, n)
pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight,
output)
return output
@staticmethod
def backward(
ctx, grad_out: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
:param ctx:
:param grad_out: (B, C, N) tensor with gradients of outputs
:return:
grad_features: (B, C, M) tensor with gradients of features
None:
None:
"""
idx, weight, m = ctx.three_interpolate_for_backward
B, c, n = grad_out.size()
grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data,
idx, weight,
grad_features.data)
return grad_features, None, None
three_interpolate = ThreeInterpolate.apply
class GroupingOperation(Function):
@staticmethod
def forward(ctx, features: torch.Tensor,
idx: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param features: (B, C, N) tensor of features to group
:param idx: (B, npoint, nsample) tensor containing the indices of features to group with
:return:
output: (B, C, npoint, nsample) tensor
"""
assert features.is_contiguous()
assert idx.is_contiguous()
idx = idx.int()
B, nfeatures, nsample = idx.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features,
idx, output)
ctx.for_backwards = (idx, N)
return output
@staticmethod
def backward(ctx,
grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param ctx:
:param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
:return:
grad_features: (B, C, N) gradient of the features
"""
idx, N = ctx.for_backwards
B, C, npoint, nsample = grad_out.size()
grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample,
grad_out_data, idx,
grad_features.data)
return grad_features, None
grouping_operation = GroupingOperation.apply
class BallQuery(Function):
@staticmethod
def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor,
new_xyz: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param radius: float, radius of the balls
:param nsample: int, maximum number of features in the balls
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: (B, npoint, 3) centers of the ball query
:return:
idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
"""
assert new_xyz.is_contiguous()
assert xyz.is_contiguous()
B, N, _ = xyz.size()
npoint = new_xyz.size(1)
idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz,
xyz, idx)
return idx
@staticmethod
def backward(ctx, a=None):
return None, None, None, None
ball_query = BallQuery.apply
class QueryAndGroup(nn.Module):
def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
"""
:param radius: float, radius of ball
:param nsample: int, maximum number of features to gather in the ball
:param use_xyz:
"""
super().__init__()
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
def forward(self,
xyz: torch.Tensor,
new_xyz: torch.Tensor,
features: torch.Tensor = None) -> Tuple[torch.Tensor]:
"""
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: (B, npoint, 3) centroids
:param features: (B, C, N) descriptors of the features
:return:
new_features: (B, 3 + C, npoint, nsample)
"""
# idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
B, N, C = new_xyz.shape
dist, idx = knn(self.nsample, new_xyz, xyz)
if self.radius is not None:
tmp_idx = idx[:, :,
0].unsqueeze(2).repeat(1, 1,
self.nsample).to(idx.device)
idx[dist > self.radius] = tmp_idx[dist > self.radius]
xyz_trans = xyz.transpose(1, 2).contiguous()
grouped_xyz = grouping_operation(xyz_trans,
idx) # (B, 3, npoint, nsample)
grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
if features is not None:
grouped_features = grouping_operation(features, idx)
if self.use_xyz:
new_features = torch.cat([grouped_xyz, grouped_features],
dim=1) # (B, C + 3, npoint, nsample)
else:
new_features = grouped_features
else:
assert self.use_xyz, 'Cannot have features=None and use_xyz=False!'
new_features = grouped_xyz
return new_features, grouped_xyz
class GroupAll(nn.Module):
def __init__(self, use_xyz: bool = True):
super().__init__()
self.use_xyz = use_xyz
def forward(self,
xyz: torch.Tensor,
new_xyz: torch.Tensor,
features: torch.Tensor = None):
"""
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: ignored
:param features: (B, C, N) descriptors of the features
:return:
new_features: (B, C + 3, 1, N)
"""
grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
if features is not None:
grouped_features = features.unsqueeze(2)
if self.use_xyz:
new_features = torch.cat([grouped_xyz, grouped_features],
dim=1) # (B, 3 + C, 1, N)
else:
new_features = grouped_features
else:
new_features = grouped_xyz
return new_features, grouped_xyz
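The knn-plus-radius pattern used in QueryAndGroup (and again in curvature and PointWiseOptimLayer above) replaces any neighbor farther than `radius` with the query's nearest neighbor instead of dropping it; a small self-contained sketch of that fallback, with made-up values:
# Pure-PyTorch illustration of the radius fallback (values are made up).
import torch

dist = torch.tensor([[[0.1, 0.4, 3.0]]])  # [B, N, K] distances from knn
idx = torch.tensor([[[7, 2, 9]]])         # [B, N, K] neighbor indices
radius = 2.5
tmp_idx = idx[:, :, 0].unsqueeze(2).repeat(1, 1, idx.shape[2])
idx[dist > radius] = tmp_idx[dist > radius]
print(idx)  # tensor([[[7, 2, 7]]]) -- the out-of-range neighbor collapses onto the nearest one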

View File

@@ -0,0 +1,64 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import numpy as np
import torch
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .sf_rcp import SF_RCP
logger = get_logger()
@MODELS.register_module(
Tasks.pointcloud_sceneflow_estimation,
module_name=Models.rcp_sceneflow_estimation)
class SceneFlowEstimation(TorchModel):
def __init__(self, model_dir: str, **kwargs):
"""str -- model file root."""
super().__init__(model_dir, **kwargs)
assert torch.cuda.is_available(
), 'current model only supports running on GPU'
# build model
self.model = SF_RCP(
npoint=8192,
use_instance_norm=False,
model_name='SF_RCP',
use_curvature=True)
# load model
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
logger.info(f'load ckpt from:{model_path}')
checkpoint = torch.load(model_path, map_location='cpu')
self.model.load_state_dict(checkpoint)
self.model.cuda()
self.model.eval()
def forward(self, Inputs):
return self.model(Inputs['pcd1'], Inputs['pcd2'], Inputs['pcd1'],
Inputs['pcd2'])[-1]
def postprocess(self, Inputs):
output = Inputs['output']
results = {OutputKeys.OUTPUT: output.detach().cpu().numpy()[0]}
return results
def inference(self, data):
results = self.forward(data)
return results
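A minimal usage sketch of the model wrapper above, with a hypothetical local path and random tensors; it assumes a CUDA device, a downloaded checkpoint containing ModelFile.TORCH_MODEL_FILE, and the assumed import path shown in the comment.
import torch
from modelscope.models.cv.pointcloud_sceneflow_estimation import \
    SceneFlowEstimation  # assumed import path

# `model_dir` is a hypothetical local path to the downloaded model files.
model = SceneFlowEstimation(model_dir='/path/to/cv_pointnet2_sceneflow-estimation_general')
inputs = {
    'pcd1': torch.rand(1, 8192, 3).cuda(),  # source point cloud [B, N, 3]
    'pcd2': torch.rand(1, 8192, 3).cuda(),  # target point cloud [B, N, 3]
}
flow = model.inference(inputs)                # torch tensor, [1, N, 3]
result = model.postprocess({'output': flow})  # {OutputKeys.OUTPUT: (N, 3) ndarray}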

View File

@@ -0,0 +1,523 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
from .common import (PointNetFeaturePropogation, PointNetSetAbstraction,
PointWiseOptimLayer, Sinkhorn)
class FeatureMatching(nn.Module):
def __init__(self, npoint, use_instance_norm, supporth_th, feature_norm,
max_iter):
super(FeatureMatching, self).__init__()
self.support_th = supporth_th**2 # 10m
self.feature_norm = feature_norm
self.max_iter = max_iter
# Mass regularisation
self.gamma = torch.nn.Parameter(torch.zeros(1))
# Entropic regularisation
self.epsilon = torch.nn.Parameter(torch.zeros(1))
self.sinkhorn = Sinkhorn()
self.extract_glob = FeatureExtractionGlobal(npoint, use_instance_norm)
# upsample flow
self.fp0 = PointNetFeaturePropogation(in_channel=3, mlp=[])
self.sa1 = PointNetSetAbstraction(
npoint=int(npoint / 16),
radius=None,
nsample=16,
in_channel=3,
mlp=[32, 32, 64],
group_all=False,
use_instance_norm=use_instance_norm)
self.fp1 = PointNetFeaturePropogation(in_channel=64, mlp=[])
self.sa2 = PointNetSetAbstraction(
npoint=int(npoint / 8),
radius=None,
nsample=16,
in_channel=64,
mlp=[64, 64, 128],
group_all=False,
use_instance_norm=use_instance_norm)
self.fp2 = PointNetFeaturePropogation(in_channel=128, mlp=[])
self.flow_regressor = FlowRegressor(npoint, use_instance_norm)
self.flow_up_sample = PointNetFeaturePropogation(in_channel=3, mlp=[])
def upsample_flow(self, pc1_l, pc1_l_glob, flow_inp):
"""
flow_inp: [B, N, 3]
return: [B, 3, N]
"""
flow_inp = flow_inp.permute(0, 2, 1).contiguous() # [B, 3, N]
flow_feat = self.fp0(pc1_l_glob['s16'], pc1_l_glob['s32'], None,
flow_inp)
_, corr_feats_l2 = self.sa1(pc1_l_glob['s16'], flow_feat)
flow_feat = self.fp1(pc1_l_glob['s8'], pc1_l_glob['s16'], None,
corr_feats_l2)
_, flow_feat = self.sa2(pc1_l_glob['s8'], flow_feat)
flow_feat = self.fp2(pc1_l['s4'], pc1_l_glob['s8'], None, flow_feat)
flow, flow_lr = self.flow_regressor(pc1_l, flow_feat)
flow_up = self.flow_up_sample(pc1_l['s1'], pc1_l_glob['s32'], None,
flow_inp)
flow_lr_up = self.flow_up_sample(pc1_l['s4'], pc1_l_glob['s32'], None,
flow_inp)
flow, flow_lr = flow + flow_up, flow_lr + flow_lr_up
return flow, flow_lr
def calc_feats_corr(self, pcloud1, pcloud2, feature1, feature2, norm):
"""
pcloud1, pcloud2: [B, N, 3]
feature1, feature2: [B, N, C]
"""
if norm:
feature1 = feature1 / torch.sqrt(
torch.sum(feature1**2, -1, keepdim=True) + 1e-6)
feature2 = feature2 / torch.sqrt(
torch.sum(feature2**2, -1, keepdim=True) + 1e-6)
corr_mat = torch.bmm(feature1,
feature2.transpose(1, 2)) # [B, N1, N2]
else:
corr_mat = torch.bmm(feature1, feature2.transpose(
1, 2)) / feature1.shape[2]**.5 # [B, N1, N2]
if self.support_th is not None:
distance_matrix = torch.sum(
pcloud1**2, -1, keepdim=True) # [B, N1, 1]
distance_matrix = distance_matrix + torch.sum(
pcloud2**2, -1, keepdim=True).transpose(1, 2) # [B, N1, N2]
distance_matrix = distance_matrix - 2 * torch.bmm(
pcloud1, pcloud2.transpose(1, 2)) # [B, N1, N2]
support = (distance_matrix < self.support_th) # [B, N1, N2]
support = support.float()
else:
support = torch.ones_like(corr_mat)
return corr_mat, support
def calc_corr_mat(self, pcloud1, pcloud2, feature1, feature2):
"""
pcloud1, pcloud2: [B, N, 3]
feature1, feature2: [B, N, C]
corr_mat: [B, N1, N2]
"""
epsilon = torch.exp(self.epsilon) + 0.03
corr_mat, support = self.calc_feats_corr(
pcloud1, pcloud2, feature1, feature2, norm=self.feature_norm)
C = 1.0 - corr_mat
corr_mat = torch.exp(-C / epsilon) * support
return corr_mat
def get_flow_init(self, pcloud1, pcloud2, feats1, feats2):
"""
pcloud1, pcloud2: [B, 3, N]
feats1, feats2: [B, C, N]
"""
corr_mat = self.calc_corr_mat(
pcloud1.permute(0, 2, 1), pcloud2.permute(0, 2, 1),
feats1.permute(0, 2, 1), feats2.permute(0, 2, 1))
corr_mat = self.sinkhorn(corr_mat,
torch.exp(self.epsilon) + 0.03, self.gamma,
self.max_iter)
row_sum = corr_mat.sum(-1, keepdim=True) # [B, N1, 1]
flow_init = (corr_mat @ pcloud2.permute(0, 2, 1).contiguous()) / (
row_sum + 1e-6) - pcloud1.permute(0, 2,
1).contiguous() # [B, N1, 3]
return flow_init
def forward(self, pc1_l, pc2_l, feats1, feats2):
"""
pc1_l, pc2_l: dict([B, 3, N])
feats1, feats2: [B, C, N]
"""
pc1_l_glob, feats1_glob = self.extract_glob(pc1_l['s4'], feats1)
pc2_l_glob, feats2_glob = self.extract_glob(pc2_l['s4'], feats2)
flow_init_s32 = self.get_flow_init(pc1_l_glob['s32'],
pc2_l_glob['s32'], feats1_glob,
feats2_glob)
flow_init, flow_init_s4 = self.upsample_flow(pc1_l, pc1_l_glob,
flow_init_s32)
return flow_init, flow_init_s4
class FlowRegressor(nn.Module):
def __init__(self, npoint, use_instance_norm, input_dim=128, nsample=32):
super(FlowRegressor, self).__init__()
self.sa1 = PointNetSetAbstraction(
npoint=int(npoint / 4),
radius=None,
nsample=nsample,
in_channel=input_dim,
mlp=[input_dim, input_dim],
group_all=False,
use_instance_norm=use_instance_norm)
self.sa2 = PointNetSetAbstraction(
npoint=int(npoint / 4),
radius=None,
nsample=nsample,
in_channel=input_dim,
mlp=[input_dim, input_dim],
group_all=False,
use_instance_norm=use_instance_norm)
self.fc = nn.Sequential(
nn.Linear(input_dim, input_dim), nn.ReLU(inplace=True),
nn.Linear(input_dim, 3))
self.up_sample = PointNetFeaturePropogation(in_channel=3, mlp=[])
def forward(self, pc1_l, feats):
"""
pc1_l: dict([B, 3, N])
feats: [B, C, N]
return: [B, 3, N]
"""
_, x = self.sa1(pc1_l['s4'], feats)
_, x = self.sa2(pc1_l['s4'], x)
x = x.permute(0, 2, 1).contiguous() # [B, N, C]
x = self.fc(x)
flow_lr = x.permute(0, 2, 1).contiguous() # [B, 3, N]
flow = self.up_sample(pc1_l['s1'], pc1_l['s4'], None,
flow_lr) # [B, 3, N]
return flow, flow_lr
class FeatureExtractionGlobal(nn.Module):
def __init__(self, npoint, use_instance_norm):
super(FeatureExtractionGlobal, self).__init__()
self.sa1 = PointNetSetAbstraction(
npoint=int(npoint / 8),
radius=None,
nsample=32,
in_channel=64,
mlp=[128, 128, 128],
group_all=False,
use_instance_norm=use_instance_norm)
self.sa2 = PointNetSetAbstraction(
npoint=int(npoint / 16),
radius=None,
nsample=24,
in_channel=128,
mlp=[128, 128, 128],
group_all=False,
use_instance_norm=use_instance_norm)
self.sa3 = PointNetSetAbstraction(
npoint=int(npoint / 32),
radius=None,
nsample=16,
in_channel=128,
mlp=[256, 256, 256],
group_all=False,
use_instance_norm=use_instance_norm)
def forward(self, pc, feature):
pc_l1, feat_l1 = self.sa1(pc, feature)
pc_l2, feat_l2 = self.sa2(pc_l1, feat_l1)
pc_l3, feat_l3 = self.sa3(pc_l2, feat_l2)
pc_l = dict(s8=pc_l1, s16=pc_l2, s32=pc_l3)
return pc_l, feat_l3
class FeatureExtraction(nn.Module):
def __init__(self, npoint, use_instance_norm):
super(FeatureExtraction, self).__init__()
self.sa1 = PointNetSetAbstraction(
npoint=int(npoint / 2),
radius=None,
nsample=32,
in_channel=3,
mlp=[32, 32, 32],
group_all=False,
return_fps=True,
use_instance_norm=use_instance_norm)
self.sa2 = PointNetSetAbstraction(
npoint=int(npoint / 4),
radius=None,
nsample=32,
in_channel=32,
mlp=[64, 64, 64],
group_all=False,
return_fps=True,
use_instance_norm=use_instance_norm)
def forward(self, pc, feature, fps_idx=None):
"""
pc: [B, 3, N]
feature: [B, 3, N]
"""
fps_idx1 = fps_idx['s2'] if fps_idx is not None else None
pc_l1, feat_l1, fps_idx1 = self.sa1(pc, feature, fps_idx=fps_idx1)
fps_idx2 = fps_idx['s4'] if fps_idx is not None else None
pc_l2, feat_l2, fps_idx2 = self.sa2(pc_l1, feat_l1, fps_idx=fps_idx2)
pc_l = dict(s1=pc, s2=pc_l1, s4=pc_l2)
fps_idx = dict(s2=fps_idx1, s4=fps_idx2)
return pc_l, feat_l2, fps_idx
class HiddenInitNet(nn.Module):
def __init__(self, npoint, use_instance_norm):
super(HiddenInitNet, self).__init__()
self.sa1 = PointNetSetAbstraction(
npoint=int(npoint / 4),
radius=None,
nsample=8,
in_channel=64,
mlp=[128, 128, 128],
group_all=False,
use_instance_norm=use_instance_norm)
self.sa2 = PointNetSetAbstraction(
npoint=int(npoint / 4),
radius=None,
nsample=8,
in_channel=128,
mlp=[128],
group_all=False,
use_act=False,
use_instance_norm=use_instance_norm)
def forward(self, pc, feature):
_, feat_l1 = self.sa1(pc, feature)
_, feat_l2 = self.sa2(pc, feat_l1)
h_init = torch.tanh(feat_l2)
return h_init
class GRUReg(nn.Module):
def __init__(self, npoint, hidden_dim, input_dim, use_instance_norm):
super().__init__()
in_ch = hidden_dim + input_dim
self.flow_proj = nn.ModuleList([
PointNetSetAbstraction(
npoint=int(npoint / 4),
radius=None,
nsample=16,
in_channel=3,
mlp=[32, 32, 32],
group_all=False,
use_instance_norm=use_instance_norm),
PointNetSetAbstraction(
npoint=int(npoint / 4),
radius=None,
nsample=8,
in_channel=32,
mlp=[16, 16, 16],
group_all=False,
use_instance_norm=use_instance_norm)
])
self.hidden_init_net = HiddenInitNet(npoint, use_instance_norm)
self.gru_layers = nn.ModuleList([
PointNetSetAbstraction(
npoint=int(npoint / 4),
radius=None,
nsample=4,
in_channel=in_ch,
mlp=[hidden_dim],
group_all=False,
use_act=False,
use_instance_norm=use_instance_norm),
PointNetSetAbstraction(
npoint=int(npoint / 4),
radius=None,
nsample=4,
in_channel=in_ch,
mlp=[hidden_dim],
group_all=False,
use_act=False,
use_instance_norm=use_instance_norm),
PointNetSetAbstraction(
npoint=int(npoint / 4),
radius=None,
nsample=4,
in_channel=in_ch,
mlp=[hidden_dim],
group_all=False,
use_act=False,
use_instance_norm=use_instance_norm)
])
def gru(self, h, gru_inp, pc):
hx = torch.cat([h, gru_inp], dim=1)
z = torch.sigmoid(self.gru_layers[0](pc, hx)[1])
r = torch.sigmoid(self.gru_layers[1](pc, hx)[1])
q = torch.tanh(self.gru_layers[2](pc, torch.cat([r * h, gru_inp],
dim=1))[1])
h = (1 - z) * h + z * q
return h
def get_gru_input(self, feats1_new, cost, flow, pc):
flow_feats = flow
for flow_conv in self.flow_proj:
_, flow_feats = flow_conv(pc, flow_feats)
gru_inp = torch.cat([feats1_new, cost, flow_feats, flow],
dim=1) # [64, 128, 16, 3]
return gru_inp
def forward(self, h, feats1_new, cost, flow_lr, pc1_l):
gru_inp = self.get_gru_input(feats1_new, cost, flow_lr, pc=pc1_l['s4'])
h = self.gru(h, gru_inp, pc1_l['s4'])
return h
class SF_RCP(nn.Module):
def __init__(self, npoint=8192, use_instance_norm=False, **kwargs):
super().__init__()
self.radius = kwargs.get('radius', 3.5)
self.nsample = kwargs.get('nsample', 6)
self.radius_min = kwargs.get('radius_min', 3.5)
self.nsample_min = kwargs.get('nsample_min', 6)
self.use_curvature = kwargs.get('use_curvature', True)
self.flow_ratio = kwargs.get('flow_ratio', 0.1)
self.init_max_iter = kwargs.get('init_max_iter', 0)
self.init_feature_norm = kwargs.get('init_feature_norm', True)
self.support_th = kwargs.get('support_th', 10)
self.feature_extraction = FeatureExtraction(npoint, use_instance_norm)
self.feature_matching = FeatureMatching(
npoint,
use_instance_norm,
supporth_th=self.support_th,
feature_norm=self.init_feature_norm,
max_iter=self.init_max_iter)
self.pointwise_optim_layer = PointWiseOptimLayer(
nsample=self.nsample,
radius=self.radius,
in_channel=64,
mlp=[128, 128, 128],
use_curvature=self.use_curvature)
self.gru = GRUReg(
npoint,
hidden_dim=128,
input_dim=128 + 64 + 16 + 3,
use_instance_norm=use_instance_norm)
self.flow_regressor = FlowRegressor(npoint, use_instance_norm)
def initialization(self, pc1_l, pc2_l, feats1, feats2):
"""
pc1: [B, 3, N]
pc2: [B, 3, N]
feature1: [B, 3, N]
feature2: [B, 3, N]
"""
flow, flow_lr = self.feature_matching(pc1_l, pc2_l, feats1, feats2)
return flow, flow_lr
def pointwise_optimization(self, pc1_l_new, pc2_l, feats1_new, feats2,
pc1_l, flow_lr, iter):
_, cost, score, pos2_grouped = self.pointwise_optim_layer(
pc1_l_new['s4'],
pc2_l['s4'],
feats1_new,
feats2,
nsample=max(self.nsample_min, self.nsample // (2**iter)),
radius=max(self.radius_min, self.radius / (2**iter)),
pos1_raw=pc1_l['s4'],
return_score=True)
# score: [B, C, N, S]
# pos2_grouped: [B, 3, N, S]
delta_flow_tmp = ((pos2_grouped - pc1_l_new['s4'].unsqueeze(-1))
* score.mean(dim=1, keepdim=True)).sum(
dim=-1) # [B, 3, N]
flow_lr = flow_lr + self.flow_ratio * delta_flow_tmp
return flow_lr, cost
def update_pos(self, pc, pc_lr, flow, flow_lr):
pc = pc + flow
pc_lr = pc_lr + flow_lr
return pc, pc_lr
def forward(self, pc1, pc2, feature1, feature2, iters=1):
"""
pc1: [B, N, 3]
pc2: [B, N, 3]
feature1: [B, N, 3]
feature2: [B, N, 3]
"""
# prepare
flow_predictions = []
pc1 = pc1.permute(0, 2, 1).contiguous() # B 3 N
pc2 = pc2.permute(0, 2, 1).contiguous() # B 3 N
feature1 = feature1.permute(0, 2, 1).contiguous() # B 3 N
feature2 = feature2.permute(0, 2, 1).contiguous() # B 3 N
# feature extraction
pc1_l, feats1, fps_idx1 = self.feature_extraction(pc1, feature1)
pc2_l, feats2, _ = self.feature_extraction(pc2, feature2)
# initialization, flow_lr_init(flow_low_resolution)
flow_init, flow_lr_init = self.initialization(pc1_l, pc2_l, feats1,
feats2)
flow_predictions.append(flow_init.permute(0, 2, 1))
# gru init hidden state
h = self.gru.hidden_init_net(pc1_l['s4'], feats1)
# update position
pc1_lr_raw = pc1_l['s4']
pc1_new, pc1_lr_new = self.update_pos(pc1, pc1_lr_raw, flow_init,
flow_lr_init)
# iterative optim
for iter in range(iters - 1):
pc1_new = pc1_new.detach()
pc1_lr_new = pc1_lr_new.detach()
flow_lr = pc1_lr_new - pc1_lr_raw
pc1_l_new, feats1_new, _ = self.feature_extraction(
pc1_new, pc1_new, fps_idx1)
# pointwise optimization to get updated flow_lr and cost
flow_lr_update, cost = self.pointwise_optimization(
pc1_l_new, pc2_l, feats1_new, feats2, pc1_l, flow_lr, iter)
flow_lr = flow_lr_update
# gru regularization
h = self.gru(h, feats1_new, cost, flow_lr, pc1_l)
# pred flow_lr
delta_flow, delta_flow_lr = self.flow_regressor(pc1_l, h)
pc1_new, pc1_lr_new = self.update_pos(pc1_new, pc1_lr_new,
delta_flow, delta_flow_lr)
flow = pc1_new - pc1
flow_predictions.append(flow.permute(0, 2, 1))
return flow_predictions
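An end-to-end sketch of calling SF_RCP directly (illustrative only; it assumes a CUDA device, the compiled pointnet2_cuda extension, an assumed import path, and randomly initialized weights, so the predicted flow is meaningful only in shape):
import torch
from modelscope.models.cv.pointcloud_sceneflow_estimation.sf_rcp import SF_RCP  # assumed import path

model = SF_RCP(npoint=8192, use_instance_norm=False, use_curvature=True).cuda().eval()
pcd1 = torch.rand(1, 8192, 3).cuda()  # source cloud [B, N, 3]
pcd2 = torch.rand(1, 8192, 3).cuda()  # target cloud [B, N, 3]
with torch.no_grad():
    # point coordinates double as the input features, mirroring rcp_model above
    flows = model(pcd1, pcd2, pcd1, pcd2, iters=1)
flow = flows[-1]  # final flow prediction, [1, 8192, 3]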

View File

@@ -50,6 +50,8 @@ class OutputKeys(object):
SCENE_NUM = 'scene_num'
SCENE_META_LIST = 'scene_meta_list'
SHOT_META_LIST = 'shot_meta_list'
PCD12 = 'pcd12'
PCD12_ALIGN = 'pcd12_align'
TASK_OUTPUTS = {

View File

@@ -70,6 +70,7 @@ if TYPE_CHECKING:
from .image_skychange_pipeline import ImageSkychangePipeline
from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline
from .video_super_resolution_pipeline import VideoSuperResolutionPipeline
from .pointcloud_sceneflow_estimation_pipeline import PointCloudSceneFlowEstimationPipeline
else:
_import_structure = {
@@ -162,6 +163,9 @@ else:
'VideoObjectSegmentationPipeline'
],
'video_super_resolution_pipeline': ['VideoSuperResolutionPipeline'],
'pointcloud_sceneflow_estimation_pipeline': [
'PointCloudSceneFlowEstimationPipeline'
]
}
import sys

View File

@@ -0,0 +1,114 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union
import numpy as np
import torch
from plyfile import PlyData, PlyElement
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import depth_to_color
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.pointcloud_sceneflow_estimation,
module_name=Pipelines.pointcloud_sceneflow_estimation)
class PointCloudSceneFlowEstimationPipeline(Pipeline):
def __init__(self, model: str, **kwargs):
"""
use `model` to create a point cloud scene flow estimation pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
logger.info('pointcloud sceneflow estimation pipeline init')
def check_input_pcd(self, pcd):
assert pcd.ndim == 2, 'pcd.ndim must equal 2'
assert pcd.shape[1] == 3, 'pcd.shape[1] must equal 3'
def preprocess(self, input: Input) -> Dict[str, Any]:
assert isinstance(input, tuple), 'only support tuple input'
assert isinstance(input[0], str) and isinstance(
input[1], str), 'only support tuple input with str type'
pcd1_file, pcd2_file = input
logger.info(f'input pcd file:{pcd1_file}, \n {pcd2_file}')
pcd1 = np.load(pcd1_file)
pcd2 = np.load(pcd2_file)
self.check_input_pcd(pcd1)
self.check_input_pcd(pcd2)
pcd1_torch = torch.from_numpy(pcd1).float().unsqueeze(0).cuda()
pcd2_torch = torch.from_numpy(pcd2).float().unsqueeze(0).cuda()
data = {
'pcd1': pcd1_torch,
'pcd2': pcd2_torch,
'pcd1_ori': pcd1,
'pcd2_ori': pcd2
}
return data
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
results = {}
output = self.model.inference(input)
results['output'] = output
results['pcd1_ori'] = input['pcd1_ori']
results['pcd2_ori'] = input['pcd2_ori']
return results
def save_ply_data(self, pcd1, pcd2):
vertexs = np.concatenate([pcd1, pcd2], axis=0)
color1 = np.array([[255, 0, 0]], dtype=np.uint8)
color2 = np.array([[0, 255, 0]], dtype=np.uint8)
color1 = np.tile(color1, (pcd1.shape[0], 1))
color2 = np.tile(color2, (pcd2.shape[0], 1))
vertex_colors = np.concatenate([color1, color2], axis=0)
vertexs = np.array([tuple(v) for v in vertexs],
dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
vertex_colors = np.array([tuple(v) for v in vertex_colors],
dtype=[('red', 'u1'), ('green', 'u1'),
('blue', 'u1')])
vertex_all = np.empty(
len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr)
for prop in vertexs.dtype.names:
vertex_all[prop] = vertexs[prop]
for prop in vertex_colors.dtype.names:
vertex_all[prop] = vertex_colors[prop]
el = PlyElement.describe(vertex_all, 'vertex')
ply_data = PlyData([el])
# .write(save_name)
return ply_data
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
results = self.model.postprocess(inputs)
flow = results[OutputKeys.OUTPUT]
pcd1 = inputs['pcd1_ori']
pcd2 = inputs['pcd2_ori']
if isinstance(pcd1, torch.Tensor):
pcd1 = pcd1.cpu().numpy()
if isinstance(pcd2, torch.Tensor):
pcd2 = pcd2.cpu().numpy()
if isinstance(flow, torch.Tensor):
flow = flow.cpu().numpy()
outputs = {
OutputKeys.OUTPUT: flow,
OutputKeys.PCD12: self.save_ply_data(pcd1, pcd2),
OutputKeys.PCD12_ALIGN: self.save_ply_data(pcd1 + flow, pcd2),
}
return outputs

View File

@@ -104,6 +104,9 @@ class CVTasks(object):
video_summarization = 'video-summarization'
image_reid_person = 'image-reid-person'
# pointcloud task
pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'
class NLPTasks(object):
# nlp tasks

View File

@@ -24,6 +24,7 @@ onnxruntime>=1.10
opencv-python
pai-easycv>=0.6.3.9
pandas
plyfile>=0.7.4
psutil
regex
scikit-image>=0.19.3

View File

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import numpy as np
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class PointCloudSceneFlowEstimationTest(unittest.TestCase,
DemoCompatibilityCheck):
def setUp(self) -> None:
self.task = 'pointcloud-sceneflow-estimation'
self.model_id = 'damo/cv_pointnet2_sceneflow-estimation_general'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_pointcloud_sceneflow_estimation(self):
input_location = ('data/test/pointclouds/flyingthings_pcd1.npy',
'data/test/pointclouds/flyingthings_pcd2.npy')
estimator = pipeline(
Tasks.pointcloud_sceneflow_estimation, model=self.model_id)
result = estimator(input_location)
flow = result[OutputKeys.OUTPUT]
pcd12 = result[OutputKeys.PCD12]
pcd12_align = result[OutputKeys.PCD12_ALIGN]
print(f'pred flow shape:{flow.shape}')
np.save('flow.npy', flow)
# visualization
pcd12.write('pcd12.ply')
pcd12_align.write('pcd12_align.ply')
print('test_pointcloud_sceneflow_estimation DONE')
if __name__ == '__main__':
unittest.main()