add cv_pointnet2_sceneflow-estimation_general
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11201880
1
.gitattributes
vendored
@@ -7,3 +7,4 @@
*.pickle filter=lfs diff=lfs merge=lfs -text
*.avi filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
3
data/test/pointclouds/flyingthings_pcd1.npy
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9fa9f5c8a49d457a7b6f4239e438699e60541e7602e8b3b66da9f7b6d55096ab
size 1735856
3
data/test/pointclouds/flyingthings_pcd2.npy
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86618feded6ae9fbcc772b9a7da17bad7d8b9c68ae0d505a239d110a3a0a7bf4
size 1735856
modelscope/metainfo.py
@@ -62,6 +62,7 @@ class Models(object):
    video_human_matting = 'video-human-matting'
    video_object_segmentation = 'video-object-segmentation'
    real_basicvsr = 'real-basicvsr'
    rcp_sceneflow_estimation = 'rcp-sceneflow-estimation'

    # EasyCV models
    yolox = 'YOLOX'
@@ -253,6 +254,7 @@ class Pipelines(object):
    video_human_matting = 'video-human-matting'
    video_object_segmentation = 'video-object-segmentation'
    video_super_resolution = 'realbasicvsr-video-super-resolution'
    pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'

    # nlp tasks
    automatic_post_editing = 'automatic-post-editing'
modelscope/models/cv/__init__.py
@@ -11,10 +11,11 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
               image_reid_person, image_semantic_segmentation,
               image_to_image_generation, image_to_image_translation,
               language_guided_video_summarization, movie_scene_segmentation,
               object_detection, product_retrieval_embedding,
               realtime_object_detection, referring_video_object_segmentation,
               salient_detection, shop_segmentation, super_resolution,
               video_object_segmentation, video_single_object_tracking,
               video_summarization, video_super_resolution, virual_tryon)
               object_detection, pointcloud_sceneflow_estimation,
               product_retrieval_embedding, realtime_object_detection,
               referring_video_object_segmentation, salient_detection,
               shop_segmentation, super_resolution, video_object_segmentation,
               video_single_object_tracking, video_summarization,
               video_super_resolution, virual_tryon)

# yapf: enable
22
modelscope/models/cv/pointcloud_sceneflow_estimation/__init__.py
Normal file
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .rcp_model import SceneFlowEstimation

else:
    _import_structure = {
        'rcp_model': ['SceneFlowEstimation'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
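This `__init__.py` follows the repository's lazy-import pattern, so importing the package does not yet pull in torch or the compiled CUDA ops. A minimal sketch of the intended behavior, assuming `LazyImportModule` resolves attributes from `_import_structure` the way it does in the other modelscope packages:

import modelscope.models.cv.pointcloud_sceneflow_estimation as sf_pkg

# Nothing heavy has been imported at this point; sf_pkg is a LazyImportModule.
# The first attribute access triggers the real import of rcp_model (and,
# transitively, torch and the pointnet2_cuda extension).
SceneFlowEstimation = sf_pkg.SceneFlowEstimation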
446
modelscope/models/cv/pointcloud_sceneflow_estimation/common.py
Normal file
@@ -0,0 +1,446 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import pointnet2_utils as pointutils

RADIUS = 2.5


def index_points_group(points, knn_idx):
    """
    Input:
        points: input points data, [B, N, C]
        knn_idx: sample index data, [B, N, K]
    Return:
        new_points: indexed points data, [B, N, K, C]
    """
    points_flipped = points.permute(0, 2, 1).contiguous()
    new_points = pointutils.grouping_operation(points_flipped,
                                               knn_idx.int()).permute(
                                                   0, 2, 3, 1)

    return new_points


def curvature(pc, nsample=10, radius=RADIUS):
    # pc: B 3 N
    assert pc.shape[1] == 3
    pc = pc.permute(0, 2, 1)

    dist, kidx = pointutils.knn(nsample, pc.contiguous(),
                                pc.contiguous())  # (B, N, 10)

    if radius is not None:
        tmp_idx = kidx[:, :, 0].unsqueeze(2).repeat(1, 1,
                                                    nsample).to(kidx.device)
        kidx[dist > radius] = tmp_idx[dist > radius]

    grouped_pc = index_points_group(pc, kidx)  # B N 10 3
    pc_curvature = torch.sum(grouped_pc - pc.unsqueeze(2), dim=2) / 9.0
    return pc_curvature  # B N 3


class PointNetSetAbstractionRatio(nn.Module):

    def __init__(self,
                 ratio,
                 radius,
                 nsample,
                 in_channel,
                 mlp,
                 group_all,
                 return_fps=False,
                 use_xyz=True,
                 use_act=True,
                 act=F.relu,
                 mean_aggr=False,
                 use_instance_norm=False):
        super(PointNetSetAbstractionRatio, self).__init__()
        self.ratio = ratio
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all
        self.use_xyz = use_xyz
        self.use_act = use_act
        self.mean_aggr = mean_aggr
        self.act = act
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = (in_channel + 3) if use_xyz else in_channel
        for out_channel in mlp:
            self.mlp_convs.append(
                nn.Conv2d(last_channel, out_channel, 1, bias=False))
            if use_instance_norm:
                self.mlp_bns.append(
                    nn.InstanceNorm2d(out_channel, affine=True))
            else:
                self.mlp_bns.append(nn.BatchNorm2d(out_channel))

            last_channel = out_channel

        if group_all:
            self.queryandgroup = pointutils.GroupAll(self.use_xyz)
        else:
            self.queryandgroup = pointutils.QueryAndGroup(
                radius, nsample, self.use_xyz)
        self.return_fps = return_fps

    def forward(self, xyz, points, fps_idx=None):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points: sample points feature data, [B, D', S]
        """
        B, C, N = xyz.shape
        npoint = int(N * self.ratio)

        xyz = xyz.contiguous()
        xyz_t = xyz.permute(0, 2, 1).contiguous()

        if (self.group_all is False) and (npoint != -1):
            if fps_idx is None:
                fps_idx = pointutils.furthest_point_sample(xyz_t,
                                                           npoint)  # [B, N]
            new_xyz = pointutils.gather_operation(xyz, fps_idx)  # [B, C, N]
        else:
            new_xyz = xyz
        new_points, _ = self.queryandgroup(xyz_t,
                                           new_xyz.transpose(2,
                                                             1).contiguous(),
                                           points)  # [B, 3+C, N, S]

        # new_xyz: sampled points position data, [B, C, npoint]
        # new_points: sampled points data, [B, C+D, npoint, nsample]
        for i, conv in enumerate(self.mlp_convs):
            if self.use_act:
                bn = self.mlp_bns[i]
                new_points = self.act(bn(conv(new_points)))
            else:
                new_points = conv(new_points)

        if self.mean_aggr:
            new_points = torch.mean(new_points, -1)
        else:
            new_points = torch.max(new_points, -1)[0]

        if self.return_fps:
            return new_xyz, new_points, fps_idx
        else:
            return new_xyz, new_points


class PointNetSetAbstraction(nn.Module):

    def __init__(self,
                 npoint,
                 radius,
                 nsample,
                 in_channel,
                 mlp,
                 group_all,
                 return_fps=False,
                 use_xyz=True,
                 use_act=True,
                 act=F.relu,
                 mean_aggr=False,
                 use_instance_norm=False):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all
        self.use_xyz = use_xyz
        self.use_act = use_act
        self.mean_aggr = mean_aggr
        self.act = act
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = (in_channel + 3) if use_xyz else in_channel
        for out_channel in mlp:
            self.mlp_convs.append(
                nn.Conv2d(last_channel, out_channel, 1, bias=False))
            if use_instance_norm:
                self.mlp_bns.append(
                    nn.InstanceNorm2d(out_channel, affine=True))
            else:
                self.mlp_bns.append(nn.BatchNorm2d(out_channel))

            last_channel = out_channel

        if group_all:
            self.queryandgroup = pointutils.GroupAll(self.use_xyz)
        else:
            self.queryandgroup = pointutils.QueryAndGroup(
                radius, nsample, self.use_xyz)
        self.return_fps = return_fps

    def forward(self, xyz, points, fps_idx=None):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, S, C]
            new_points: sample points feature data, [B, S, D']
        """
        # device = xyz.device
        B, C, N = xyz.shape
        xyz = xyz.contiguous()
        xyz_t = xyz.permute(0, 2, 1).contiguous()

        if (self.group_all is False) and (self.npoint != -1):
            if fps_idx is None:
                fps_idx = pointutils.furthest_point_sample(
                    xyz_t, self.npoint)  # [B, N]
            new_xyz = pointutils.gather_operation(xyz, fps_idx)  # [B, C, N]
        else:
            new_xyz = xyz
        new_points, _ = self.queryandgroup(xyz_t,
                                           new_xyz.transpose(2,
                                                             1).contiguous(),
                                           points)  # [B, 3+C, N, S]

        # new_xyz: sampled points position data, [B, C, npoint]
        # new_points: sampled points data, [B, C+D, npoint, nsample]
        for i, conv in enumerate(self.mlp_convs):
            if self.use_act:
                bn = self.mlp_bns[i]
                new_points = self.act(bn(conv(new_points)))
            else:
                new_points = conv(new_points)

        if self.mean_aggr:
            new_points = torch.mean(new_points, -1)
        else:
            new_points = torch.max(new_points, -1)[0]

        if self.return_fps:
            return new_xyz, new_points, fps_idx
        else:
            return new_xyz, new_points


class PointNetFeaturePropogation(nn.Module):

    def __init__(self, in_channel, mlp, learn_mask=False, nsample=3):
        super(PointNetFeaturePropogation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        self.apply_mlp = mlp is not None
        last_channel = in_channel
        self.nsample = nsample
        if self.apply_mlp:
            for out_channel in mlp:
                self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
                self.mlp_bns.append(nn.BatchNorm1d(out_channel))
                last_channel = out_channel

        if learn_mask:
            self.queryandgroup = pointutils.QueryAndGroup(
                None, 9, use_xyz=True)
            last_channel = (128 + 3)
            for out_channel in [32, 1]:
                self.mlp_convs.append(
                    nn.Conv2d(last_channel, out_channel, 1, bias=False))
                self.mlp_bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel

    def forward(self, pos1, pos2, feature1, feature2, hidden=None):
        """
        Input:
            pos1: input points position data, [B, C, N]
            pos2: sampled input points position data, [B, C, S]
            feature1: input points data, [B, D, N]
            feature2: input points data, [B, D, S]
        Return:
            feat_new: upsampled points data, [B, D', N]
        """
        pos1_t = pos1.permute(0, 2, 1).contiguous()
        pos2_t = pos2.permute(0, 2, 1).contiguous()
        B, C, N = pos1.shape

        if hidden is None:
            if self.nsample == 3:
                dists, idx = pointutils.three_nn(pos1_t, pos2_t)
            else:
                dists, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)
            dists[dists < 1e-10] = 1e-10
            weight = 1.0 / dists
            weight = weight / torch.sum(weight, -1, keepdim=True)  # [B,N,3]
            interpolated_feat = torch.sum(
                pointutils.grouping_operation(feature2, idx)
                * weight.view(B, 1, N, self.nsample),
                dim=-1)  # [B,C,N,3]
        else:
            dist, idx = pointutils.knn(9, pos1_t, pos2_t)

            new_feat, _ = self.queryandgroup(pos2_t, pos1_t,
                                             hidden)  # [B, 3+C, N, 9]

            for i, conv in enumerate(self.mlp_convs):
                new_feat = conv(new_feat)
            weight = torch.softmax(new_feat, dim=-1)  # [B, 1, N, 9]
            interpolated_feat = torch.sum(
                pointutils.grouping_operation(feature2, idx) * weight,
                dim=-1)  # [B, C, N]

        if feature1 is not None:
            feat_new = torch.cat([interpolated_feat, feature1], 1)
        else:
            feat_new = interpolated_feat

        if self.apply_mlp:
            for i, conv in enumerate(self.mlp_convs):
                bn = self.mlp_bns[i]
                feat_new = F.relu(bn(conv(feat_new)))
        return feat_new


class Sinkhorn(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, corr, epsilon, gamma, max_iter):
        # Early return if no iteration
        if max_iter == 0:
            return corr

        # Init. of Sinkhorn algorithm
        power = gamma / (gamma + epsilon)
        a = (
            torch.ones((corr.shape[0], corr.shape[1], 1),
                       device=corr.device,
                       dtype=corr.dtype) / corr.shape[1])
        prob1 = (
            torch.ones((corr.shape[0], corr.shape[1], 1),
                       device=corr.device,
                       dtype=corr.dtype) / corr.shape[1])
        prob2 = (
            torch.ones((corr.shape[0], corr.shape[2], 1),
                       device=corr.device,
                       dtype=corr.dtype) / corr.shape[2])

        # Sinkhorn algorithm
        for _ in range(max_iter):
            # Update b
            KTa = torch.bmm(corr.transpose(1, 2), a)
            b = torch.pow(prob2 / (KTa + 1e-8), power)
            # Update a
            Kb = torch.bmm(corr, b)
            a = torch.pow(prob1 / (Kb + 1e-8), power)

        # Transportation map
        T = torch.mul(torch.mul(a, corr), b.transpose(1, 2))

        return T


class PointWiseOptimLayer(nn.Module):

    def __init__(self, nsample, radius, in_channel, mlp, use_curvature=True):
        super().__init__()
        self.nsample = nsample
        self.radius = radius
        self.use_curvature = use_curvature

        self.pos_embed = nn.Sequential(
            nn.Conv1d(3, 32, 1), nn.ReLU(inplace=True), nn.Conv1d(32, 64, 1))

        self.qk_net = nn.Sequential(
            nn.Conv1d(in_channel + 64, in_channel + 64, 1),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channel + 64, in_channel + 64, 1))
        if self.use_curvature:
            self.curvate_net = nn.Sequential(
                nn.Conv1d(3, 32, 1), nn.ReLU(inplace=True),
                nn.Conv1d(32, 32, 1))
            self.mlp_conv = nn.Conv1d(
                in_channel + 64 + 32, mlp[-1], 1, bias=True)
        else:
            self.mlp_conv = nn.Conv1d(in_channel + 64, mlp[-1], 1, bias=True)

    def forward(self,
                pos1,
                pos2,
                feature1,
                feature2,
                nsample,
                radius=None,
                pos1_raw=None,
                return_score=False):
        """
        Input:
            pos1: (batch_size, 3, npoint)
            pos2: (batch_size, 3, npoint)
            feature1: (batch_size, channel, npoint)
            feature2: (batch_size, channel, npoint)
        Output:
            pos1: (batch_size, 3, npoint)
            cost: (batch_size, channel, npoint)
        """
        pos1_t = pos1.permute(0, 2, 1).contiguous()
        pos2_t = pos2.permute(0, 2, 1).contiguous()
        self.nsample = nsample
        self.radius = radius

        dist, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)  # [B, N, K]
        if self.radius is not None:
            tmp_idx = idx[:, :,
                          0].unsqueeze(2).repeat(1, 1,
                                                 self.nsample).to(idx.device)
            idx[dist > self.radius] = tmp_idx[dist > self.radius]

        pos1_embed_norm = self.pos_embed(pos1)
        pos2_embed_norm = self.pos_embed(pos2)  # [B, C1, N]

        feat1_w_pos = torch.cat([feature1, pos1_embed_norm], dim=1)
        feat2_w_pos = torch.cat([feature2, pos2_embed_norm],
                                dim=1)  # [B, C1+C2, N]

        feat1_w_pos = self.qk_net(feat1_w_pos)
        feat2_w_pos = self.qk_net(feat2_w_pos)  # [B, C1+C2, N]

        feat2_grouped = pointutils.grouping_operation(feat2_w_pos,
                                                      idx)  # [B, C1+C2, N, S]

        score = torch.softmax(
            feat1_w_pos.unsqueeze(-1) * feat2_grouped * 1.
            / math.sqrt(feat1_w_pos.shape[1]),
            dim=-1)  # [B, C1+C2, N, S]
        cost = (score * (feat1_w_pos.unsqueeze(-1) - feat2_grouped)**2).sum(
            dim=-1)  # [B, C1+C2, N]

        if self.use_curvature:
            curvate1_raw = curvature(pos1_raw).permute(0, 2, 1)  # [B, 3, N]
            curvate1 = curvature(pos1).permute(0, 2, 1)  # [B, 3, N]
            curvate_cost = self.curvate_net(curvate1_raw) - self.curvate_net(
                curvate1)
            curvate_cost = curvate_cost**2
            cost = self.mlp_conv(torch.cat([cost, curvate_cost],
                                           dim=1))  # [B, C, N]
        else:
            cost = self.mlp_conv(cost)  # [B, C, N]

        if return_score:
            pos2_grouped = pointutils.grouping_operation(pos2,
                                                         idx)  # [B, 3, N, S]
            # [B, N, K]
            index = (dist > self.radius).sum(
                dim=2, keepdim=True).float() > (dist.shape[2] - 0.1
                                                )  # [B, N, 1]
            index = index.unsqueeze(1).repeat(1, score.shape[1], 1,
                                              dist.shape[2])  # [B, N, K]
            score_tmp = score.clone()
            score_tmp[index] = 0.0
            score = score_tmp
            return pos1, cost, score, pos2_grouped
        else:
            return pos1, cost
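The `Sinkhorn` module above is self-contained torch code, so its entropic optimal-transport iteration can be exercised on its own. A minimal sketch (the shapes and hyperparameter values here are illustrative only; inside the model, `epsilon` and `gamma` are learned parameters of `FeatureMatching` in sf_rcp.py, and importing common.py itself additionally requires the compiled `pointnet2_cuda` extension):

import torch

from modelscope.models.cv.pointcloud_sceneflow_estimation.common import Sinkhorn

# Hypothetical non-negative correlation matrix between clouds of 128 and 160 points.
corr = torch.rand(2, 128, 160)  # [B, N1, N2]
T = Sinkhorn()(corr, epsilon=0.03, gamma=0.5, max_iter=10)
print(T.shape)  # torch.Size([2, 128, 160])
# Each row of T softly assigns a point of cloud 1 to points of cloud 2;
# FeatureMatching.get_flow_init turns such a map into an initial flow estimate.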
360
modelscope/models/cv/pointcloud_sceneflow_estimation/pointnet2_utils.py
Normal file
@@ -0,0 +1,360 @@
# The implementation is adapted from PointNet2, open-sourced under MIT license,
# made publicly available at https://github.com/sshaoshuai/Pointnet2.PyTorch
from typing import Tuple

import pointnet2_cuda as pointnet2
import torch
import torch.nn as nn
from torch.autograd import Function, Variable


class FurthestPointSampling(Function):

    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        """
        Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance
        :param ctx:
        :param xyz: (B, N, 3) where N > npoint
        :param npoint: int, number of features in the sampled set
        :return:
            output: (B, npoint) tensor containing the set
        """
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        output = torch.cuda.IntTensor(B, npoint)
        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)

        pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp,
                                                  output)
        return output

    @staticmethod
    def backward(xyz, a=None):
        return None, None


furthest_point_sample = FurthestPointSampling.apply


class GatherOperation(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor,
                idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param features: (B, C, N)
        :param idx: (B, npoint) index tensor of the features to gather
        :return:
            output: (B, C, npoint)
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        B, npoint = idx.size()
        _, C, N = features.size()
        output = torch.cuda.FloatTensor(B, C, npoint)

        pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)

        ctx.for_backwards = (idx, C, N)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        idx, C, N = ctx.for_backwards
        B, npoint = idx.size()

        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
        grad_out_data = grad_out.data.contiguous()
        pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data,
                                             idx, grad_features.data)
        return grad_features, None


gather_operation = GatherOperation.apply


class KNN(Function):

    @staticmethod
    def forward(ctx, k: int, unknown: torch.Tensor,
                known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find the k nearest neighbors of unknown in known
        :param ctx:
        :param unknown: (B, N, 3)
        :param known: (B, M, 3)
        :return:
            dist: (B, N, k) l2 distance to the k nearest neighbors
            idx: (B, N, k) index of the k nearest neighbors
        """
        assert unknown.is_contiguous()
        assert known.is_contiguous()

        B, N, _ = unknown.size()
        m = known.size(1)
        dist2 = torch.cuda.FloatTensor(B, N, k)
        idx = torch.cuda.IntTensor(B, N, k)

        pointnet2.knn_wrapper(B, N, m, k, unknown, known, dist2, idx)
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        return None, None, None


knn = KNN.apply


class ThreeNN(Function):

    @staticmethod
    def forward(ctx, unknown: torch.Tensor,
                known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find the three nearest neighbors of unknown in known
        :param ctx:
        :param unknown: (B, N, 3)
        :param known: (B, M, 3)
        :return:
            dist: (B, N, 3) l2 distance to the three nearest neighbors
            idx: (B, N, 3) index of 3 nearest neighbors
        """
        assert unknown.is_contiguous()
        assert known.is_contiguous()

        B, N, _ = unknown.size()
        m = known.size(1)
        dist2 = torch.cuda.FloatTensor(B, N, 3)
        idx = torch.cuda.IntTensor(B, N, 3)

        pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        return None, None


three_nn = ThreeNN.apply


class ThreeInterpolate(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor,
                weight: torch.Tensor) -> torch.Tensor:
        """
        Performs weighted linear interpolation on 3 features
        :param ctx:
        :param features: (B, C, M) Features descriptors to be interpolated from
        :param idx: (B, n, 3) three nearest neighbors of the target features in features
        :param weight: (B, n, 3) weights
        :return:
            output: (B, C, N) tensor of the interpolated features
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        assert weight.is_contiguous()

        B, c, m = features.size()
        n = idx.size(1)
        ctx.three_interpolate_for_backward = (idx, weight, m)
        output = torch.cuda.FloatTensor(B, c, n)

        pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight,
                                            output)
        return output

    @staticmethod
    def backward(
            ctx, grad_out: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param grad_out: (B, C, N) tensor with gradients of outputs
        :return:
            grad_features: (B, C, M) tensor with gradients of features
            None:
            None:
        """
        idx, weight, m = ctx.three_interpolate_for_backward
        B, c, n = grad_out.size()

        grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
        grad_out_data = grad_out.data.contiguous()

        pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data,
                                                 idx, weight,
                                                 grad_features.data)
        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply


class GroupingOperation(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor,
                idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param features: (B, C, N) tensor of features to group
        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
        :return:
            output: (B, C, npoint, nsample) tensor
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        idx = idx.int()
        B, nfeatures, nsample = idx.size()
        _, C, N = features.size()
        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)

        pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features,
                                       idx, output)

        ctx.for_backwards = (idx, N)
        return output

    @staticmethod
    def backward(ctx,
                 grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
        :return:
            grad_features: (B, C, N) gradient of the features
        """
        idx, N = ctx.for_backwards

        B, C, npoint, nsample = grad_out.size()
        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())

        grad_out_data = grad_out.data.contiguous()
        pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample,
                                            grad_out_data, idx,
                                            grad_features.data)
        return grad_features, None


grouping_operation = GroupingOperation.apply


class BallQuery(Function):

    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor,
                new_xyz: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param radius: float, radius of the balls
        :param nsample: int, maximum number of features in the balls
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centers of the ball query
        :return:
            idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        assert new_xyz.is_contiguous()
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        npoint = new_xyz.size(1)
        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()

        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz,
                                     xyz, idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None, None, None


ball_query = BallQuery.apply


class QueryAndGroup(nn.Module):

    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
        """
        :param radius: float, radius of ball
        :param nsample: int, maximum number of features to gather in the ball
        :param use_xyz:
        """
        super().__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self,
                xyz: torch.Tensor,
                new_xyz: torch.Tensor,
                features: torch.Tensor = None) -> Tuple[torch.Tensor]:
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centroids
        :param features: (B, C, N) descriptors of the features
        :return:
            new_features: (B, 3 + C, npoint, nsample)
        """
        # idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        B, N, C = new_xyz.shape
        dist, idx = knn(self.nsample, new_xyz, xyz)
        if self.radius is not None:
            tmp_idx = idx[:, :,
                          0].unsqueeze(2).repeat(1, 1,
                                                 self.nsample).to(idx.device)
            idx[dist > self.radius] = tmp_idx[dist > self.radius]
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans,
                                         idx)  # (B, 3, npoint, nsample)
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is not None:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                new_features = torch.cat([grouped_xyz, grouped_features],
                                         dim=1)  # (B, C + 3, npoint, nsample)
            else:
                new_features = grouped_features
        else:
            assert self.use_xyz, 'Cannot have features=None and use_xyz=False at the same time!'
            new_features = grouped_xyz

        return new_features, grouped_xyz


class GroupAll(nn.Module):

    def __init__(self, use_xyz: bool = True):
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self,
                xyz: torch.Tensor,
                new_xyz: torch.Tensor,
                features: torch.Tensor = None):
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: ignored
        :param features: (B, C, N) descriptors of the features
        :return:
            new_features: (B, C + 3, 1, N)
        """
        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is not None:
            grouped_features = features.unsqueeze(2)
            if self.use_xyz:
                new_features = torch.cat([grouped_xyz, grouped_features],
                                         dim=1)  # (B, 3 + C, 1, N)
            else:
                new_features = grouped_features
        else:
            new_features = grouped_xyz

        return new_features, grouped_xyz
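All wrappers in this file allocate `torch.cuda` tensors and call into the compiled `pointnet2_cuda` extension, so they run only on a GPU with that extension built. A small usage sketch of the main primitives (tensor sizes are illustrative, not values prescribed by the model):

import torch

import modelscope.models.cv.pointcloud_sceneflow_estimation.pointnet2_utils as pointutils

xyz = torch.rand(2, 4096, 3).cuda().contiguous()     # [B, N, 3] point coordinates
feats = torch.rand(2, 64, 4096).cuda().contiguous()  # [B, C, N] per-point features

dist, idx = pointutils.knn(16, xyz, xyz)             # [B, N, 16] neighbor distances / indices
grouped = pointutils.grouping_operation(feats, idx)  # [B, 64, N, 16] gathered features

fps_idx = pointutils.furthest_point_sample(xyz, 1024)   # [B, 1024] sampled point indices
centroids = pointutils.gather_operation(
    xyz.transpose(1, 2).contiguous(), fps_idx)           # [B, 3, 1024] sampled coordinates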
63
modelscope/models/cv/pointcloud_sceneflow_estimation/rcp_model.py
Normal file
@@ -0,0 +1,63 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp

import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .sf_rcp import SF_RCP

logger = get_logger()


@MODELS.register_module(
    Tasks.pointcloud_sceneflow_estimation,
    module_name=Models.rcp_sceneflow_estimation)
class SceneFlowEstimation(TorchModel):

    def __init__(self, model_dir: str, **kwargs):
        """model_dir -- model file root."""
        super().__init__(model_dir, **kwargs)

        assert torch.cuda.is_available(
        ), 'current model only supports running on GPU'

        # build model
        self.model = SF_RCP(
            npoint=8192,
            use_instance_norm=False,
            model_name='SF_RCP',
            use_curvature=True)

        # load model
        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)

        logger.info(f'load ckpt from:{model_path}')

        checkpoint = torch.load(model_path, map_location='cpu')

        self.model.load_state_dict(checkpoint)
        self.model.cuda()
        self.model.eval()

    def forward(self, Inputs):

        return self.model(Inputs['pcd1'], Inputs['pcd2'], Inputs['pcd1'],
                          Inputs['pcd2'])[-1]

    def postprocess(self, Inputs):
        output = Inputs['output']

        results = {OutputKeys.OUTPUT: output.detach().cpu().numpy()[0]}

        return results

    def inference(self, data):
        results = self.forward(data)

        return results
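`SceneFlowEstimation` is a thin wrapper that builds `SF_RCP` with fixed hyperparameters, loads the checkpoint, and exposes `inference`/`postprocess` for the pipeline. A sketch of direct use outside the pipeline (the model directory path is hypothetical; it must contain the checkpoint named by `ModelFile.TORCH_MODEL_FILE`, and a CUDA device plus the pointnet2_cuda extension are required):

import torch

from modelscope.models.cv.pointcloud_sceneflow_estimation.rcp_model import SceneFlowEstimation
from modelscope.outputs import OutputKeys

model = SceneFlowEstimation('/path/to/model_dir')  # hypothetical local directory

pcd1 = torch.rand(1, 8192, 3).cuda()  # [B, N, 3] source cloud
pcd2 = torch.rand(1, 8192, 3).cuda()  # [B, N, 3] target cloud
with torch.no_grad():
    flow = model.inference({'pcd1': pcd1, 'pcd2': pcd2})  # [B, N, 3] flow tensor
results = model.postprocess({'output': flow})
print(results[OutputKeys.OUTPUT].shape)  # (8192, 3) per-point flow as numpy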
523
modelscope/models/cv/pointcloud_sceneflow_estimation/sf_rcp.py
Normal file
@@ -0,0 +1,523 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .common import (PointNetFeaturePropogation, PointNetSetAbstraction,
                     PointWiseOptimLayer, Sinkhorn)


class FeatureMatching(nn.Module):

    def __init__(self, npoint, use_instance_norm, support_th, feature_norm,
                 max_iter):
        super(FeatureMatching, self).__init__()
        self.support_th = support_th**2  # 10m
        self.feature_norm = feature_norm
        self.max_iter = max_iter
        # Mass regularisation
        self.gamma = torch.nn.Parameter(torch.zeros(1))
        # Entropic regularisation
        self.epsilon = torch.nn.Parameter(torch.zeros(1))
        self.sinkhorn = Sinkhorn()

        self.extract_glob = FeatureExtractionGlobal(npoint, use_instance_norm)

        # upsample flow
        self.fp0 = PointNetFeaturePropogation(in_channel=3, mlp=[])
        self.sa1 = PointNetSetAbstraction(
            npoint=int(npoint / 16),
            radius=None,
            nsample=16,
            in_channel=3,
            mlp=[32, 32, 64],
            group_all=False,
            use_instance_norm=use_instance_norm)
        self.fp1 = PointNetFeaturePropogation(in_channel=64, mlp=[])
        self.sa2 = PointNetSetAbstraction(
            npoint=int(npoint / 8),
            radius=None,
            nsample=16,
            in_channel=64,
            mlp=[64, 64, 128],
            group_all=False,
            use_instance_norm=use_instance_norm)
        self.fp2 = PointNetFeaturePropogation(in_channel=128, mlp=[])

        self.flow_regressor = FlowRegressor(npoint, use_instance_norm)
        self.flow_up_sample = PointNetFeaturePropogation(in_channel=3, mlp=[])

    def upsample_flow(self, pc1_l, pc1_l_glob, flow_inp):
        """
        flow_inp: [B, N, 3]
        return: [B, 3, N]
        """

        flow_inp = flow_inp.permute(0, 2, 1).contiguous()  # [B, 3, N]

        flow_feat = self.fp0(pc1_l_glob['s16'], pc1_l_glob['s32'], None,
                             flow_inp)
        _, corr_feats_l2 = self.sa1(pc1_l_glob['s16'], flow_feat)

        flow_feat = self.fp1(pc1_l_glob['s8'], pc1_l_glob['s16'], None,
                             corr_feats_l2)
        _, flow_feat = self.sa2(pc1_l_glob['s8'], flow_feat)

        flow_feat = self.fp2(pc1_l['s4'], pc1_l_glob['s8'], None, flow_feat)

        flow, flow_lr = self.flow_regressor(pc1_l, flow_feat)

        flow_up = self.flow_up_sample(pc1_l['s1'], pc1_l_glob['s32'], None,
                                      flow_inp)
        flow_lr_up = self.flow_up_sample(pc1_l['s4'], pc1_l_glob['s32'], None,
                                         flow_inp)

        flow, flow_lr = flow + flow_up, flow_lr + flow_lr_up

        return flow, flow_lr

    def calc_feats_corr(self, pcloud1, pcloud2, feature1, feature2, norm):
        """
        pcloud1, pcloud2: [B, N, 3]
        feature1, feature2: [B, N, C]
        """
        if norm:
            feature1 = feature1 / torch.sqrt(
                torch.sum(feature1**2, -1, keepdim=True) + 1e-6)
            feature2 = feature2 / torch.sqrt(
                torch.sum(feature2**2, -1, keepdim=True) + 1e-6)
            corr_mat = torch.bmm(feature1,
                                 feature2.transpose(1, 2))  # [B, N1, N2]
        else:
            corr_mat = torch.bmm(feature1, feature2.transpose(
                1, 2)) / feature1.shape[2]**.5  # [B, N1, N2]

        if self.support_th is not None:
            distance_matrix = torch.sum(
                pcloud1**2, -1, keepdim=True)  # [B, N1, 1]
            distance_matrix = distance_matrix + torch.sum(
                pcloud2**2, -1, keepdim=True).transpose(1, 2)  # [B, N1, N2]
            distance_matrix = distance_matrix - 2 * torch.bmm(
                pcloud1, pcloud2.transpose(1, 2))  # [B, N1, N2]
            support = (distance_matrix < self.support_th)  # [B, N1, N2]
            support = support.float()
        else:
            support = torch.ones_like(corr_mat)
        return corr_mat, support

    def calc_corr_mat(self, pcloud1, pcloud2, feature1, feature2):
        """
        pcloud1, pcloud2: [B, N, 3]
        feature1, feature2: [B, N, C]
        corr_mat: [B, N1, N2]
        """
        epsilon = torch.exp(self.epsilon) + 0.03
        corr_mat, support = self.calc_feats_corr(
            pcloud1, pcloud2, feature1, feature2, norm=self.feature_norm)
        C = 1.0 - corr_mat
        corr_mat = torch.exp(-C / epsilon) * support
        return corr_mat

    def get_flow_init(self, pcloud1, pcloud2, feats1, feats2):
        """
        pcloud1, pcloud2: [B, 3, N]
        feats1, feats2: [B, C, N]
        """

        corr_mat = self.calc_corr_mat(
            pcloud1.permute(0, 2, 1), pcloud2.permute(0, 2, 1),
            feats1.permute(0, 2, 1), feats2.permute(0, 2, 1))

        corr_mat = self.sinkhorn(corr_mat,
                                 torch.exp(self.epsilon) + 0.03, self.gamma,
                                 self.max_iter)

        row_sum = corr_mat.sum(-1, keepdim=True)  # [B, N1, 1]
        flow_init = (corr_mat @ pcloud2.permute(0, 2, 1).contiguous()) / (
            row_sum + 1e-6) - pcloud1.permute(0, 2,
                                              1).contiguous()  # [B, N1, 3]

        return flow_init

    def forward(self, pc1_l, pc2_l, feats1, feats2):
        """
        pc1_l, pc2_l: dict([B, 3, N])
        feats1, feats2: [B, C, N]
        """
        pc1_l_glob, feats1_glob = self.extract_glob(pc1_l['s4'], feats1)
        pc2_l_glob, feats2_glob = self.extract_glob(pc2_l['s4'], feats2)

        flow_init_s32 = self.get_flow_init(pc1_l_glob['s32'],
                                           pc2_l_glob['s32'], feats1_glob,
                                           feats2_glob)

        flow_init, flow_init_s4 = self.upsample_flow(pc1_l, pc1_l_glob,
                                                     flow_init_s32)

        return flow_init, flow_init_s4


class FlowRegressor(nn.Module):

    def __init__(self, npoint, use_instance_norm, input_dim=128, nsample=32):
        super(FlowRegressor, self).__init__()
        self.sa1 = PointNetSetAbstraction(
            npoint=int(npoint / 4),
            radius=None,
            nsample=nsample,
            in_channel=input_dim,
            mlp=[input_dim, input_dim],
            group_all=False,
            use_instance_norm=use_instance_norm)
        self.sa2 = PointNetSetAbstraction(
            npoint=int(npoint / 4),
            radius=None,
            nsample=nsample,
            in_channel=input_dim,
            mlp=[input_dim, input_dim],
            group_all=False,
            use_instance_norm=use_instance_norm)

        self.fc = nn.Sequential(
            nn.Linear(input_dim, input_dim), nn.ReLU(inplace=True),
            nn.Linear(input_dim, 3))

        self.up_sample = PointNetFeaturePropogation(in_channel=3, mlp=[])

    def forward(self, pc1_l, feats):
        """
        pc1_l: dict([B, 3, N])
        feats: [B, C, N]
        return: [B, 3, N]
        """
        _, x = self.sa1(pc1_l['s4'], feats)
        _, x = self.sa2(pc1_l['s4'], x)
        x = x.permute(0, 2, 1).contiguous()  # [B, N, C]
        x = self.fc(x)
        flow_lr = x.permute(0, 2, 1).contiguous()  # [B, 3, N]

        flow = self.up_sample(pc1_l['s1'], pc1_l['s4'], None,
                              flow_lr)  # [B, 3, N]

        return flow, flow_lr


class FeatureExtractionGlobal(nn.Module):

    def __init__(self, npoint, use_instance_norm):
        super(FeatureExtractionGlobal, self).__init__()
        self.sa1 = PointNetSetAbstraction(
            npoint=int(npoint / 8),
            radius=None,
            nsample=32,
            in_channel=64,
            mlp=[128, 128, 128],
            group_all=False,
            use_instance_norm=use_instance_norm)
        self.sa2 = PointNetSetAbstraction(
            npoint=int(npoint / 16),
            radius=None,
            nsample=24,
            in_channel=128,
            mlp=[128, 128, 128],
            group_all=False,
            use_instance_norm=use_instance_norm)
        self.sa3 = PointNetSetAbstraction(
            npoint=int(npoint / 32),
            radius=None,
            nsample=16,
            in_channel=128,
            mlp=[256, 256, 256],
            group_all=False,
            use_instance_norm=use_instance_norm)

    def forward(self, pc, feature):
        pc_l1, feat_l1 = self.sa1(pc, feature)
        pc_l2, feat_l2 = self.sa2(pc_l1, feat_l1)
        pc_l3, feat_l3 = self.sa3(pc_l2, feat_l2)

        pc_l = dict(s8=pc_l1, s16=pc_l2, s32=pc_l3)
        return pc_l, feat_l3


class FeatureExtraction(nn.Module):

    def __init__(self, npoint, use_instance_norm):
        super(FeatureExtraction, self).__init__()
        self.sa1 = PointNetSetAbstraction(
            npoint=int(npoint / 2),
            radius=None,
            nsample=32,
            in_channel=3,
            mlp=[32, 32, 32],
            group_all=False,
            return_fps=True,
            use_instance_norm=use_instance_norm)
        self.sa2 = PointNetSetAbstraction(
            npoint=int(npoint / 4),
            radius=None,
            nsample=32,
            in_channel=32,
            mlp=[64, 64, 64],
            group_all=False,
            return_fps=True,
            use_instance_norm=use_instance_norm)

    def forward(self, pc, feature, fps_idx=None):
        """
        pc: [B, 3, N]
        feature: [B, 3, N]
        """
        fps_idx1 = fps_idx['s2'] if fps_idx is not None else None
        pc_l1, feat_l1, fps_idx1 = self.sa1(pc, feature, fps_idx=fps_idx1)
        fps_idx2 = fps_idx['s4'] if fps_idx is not None else None
        pc_l2, feat_l2, fps_idx2 = self.sa2(pc_l1, feat_l1, fps_idx=fps_idx2)
        pc_l = dict(s1=pc, s2=pc_l1, s4=pc_l2)
        fps_idx = dict(s2=fps_idx1, s4=fps_idx2)
        return pc_l, feat_l2, fps_idx


class HiddenInitNet(nn.Module):

    def __init__(self, npoint, use_instance_norm):
        super(HiddenInitNet, self).__init__()
        self.sa1 = PointNetSetAbstraction(
            npoint=int(npoint / 4),
            radius=None,
            nsample=8,
            in_channel=64,
            mlp=[128, 128, 128],
            group_all=False,
            use_instance_norm=use_instance_norm)
        self.sa2 = PointNetSetAbstraction(
            npoint=int(npoint / 4),
            radius=None,
            nsample=8,
            in_channel=128,
            mlp=[128],
            group_all=False,
            use_act=False,
            use_instance_norm=use_instance_norm)

    def forward(self, pc, feature):
        _, feat_l1 = self.sa1(pc, feature)
        _, feat_l2 = self.sa2(pc, feat_l1)

        h_init = torch.tanh(feat_l2)
        return h_init


class GRUReg(nn.Module):

    def __init__(self, npoint, hidden_dim, input_dim, use_instance_norm):
        super().__init__()
        in_ch = hidden_dim + input_dim

        self.flow_proj = nn.ModuleList([
            PointNetSetAbstraction(
                npoint=int(npoint / 4),
                radius=None,
                nsample=16,
                in_channel=3,
                mlp=[32, 32, 32],
                group_all=False,
                use_instance_norm=use_instance_norm),
            PointNetSetAbstraction(
                npoint=int(npoint / 4),
                radius=None,
                nsample=8,
                in_channel=32,
                mlp=[16, 16, 16],
                group_all=False,
                use_instance_norm=use_instance_norm)
        ])

        self.hidden_init_net = HiddenInitNet(npoint, use_instance_norm)

        self.gru_layers = nn.ModuleList([
            PointNetSetAbstraction(
                npoint=int(npoint / 4),
                radius=None,
                nsample=4,
                in_channel=in_ch,
                mlp=[hidden_dim],
                group_all=False,
                use_act=False,
                use_instance_norm=use_instance_norm),
            PointNetSetAbstraction(
                npoint=int(npoint / 4),
                radius=None,
                nsample=4,
                in_channel=in_ch,
                mlp=[hidden_dim],
                group_all=False,
                use_act=False,
                use_instance_norm=use_instance_norm),
            PointNetSetAbstraction(
                npoint=int(npoint / 4),
                radius=None,
                nsample=4,
                in_channel=in_ch,
                mlp=[hidden_dim],
                group_all=False,
                use_act=False,
                use_instance_norm=use_instance_norm)
        ])

    def gru(self, h, gru_inp, pc):
        hx = torch.cat([h, gru_inp], dim=1)
        z = torch.sigmoid(self.gru_layers[0](pc, hx)[1])
        r = torch.sigmoid(self.gru_layers[1](pc, hx)[1])
        q = torch.tanh(self.gru_layers[2](pc, torch.cat([r * h, gru_inp],
                                                        dim=1))[1])
        h = (1 - z) * h + z * q
        return h

    def get_gru_input(self, feats1_new, cost, flow, pc):
        flow_feats = flow
        for flow_conv in self.flow_proj:
            _, flow_feats = flow_conv(pc, flow_feats)

        gru_inp = torch.cat([feats1_new, cost, flow_feats, flow],
                            dim=1)  # [64, 128, 16, 3]

        return gru_inp

    def forward(self, h, feats1_new, cost, flow_lr, pc1_l):
        gru_inp = self.get_gru_input(feats1_new, cost, flow_lr, pc=pc1_l['s4'])

        h = self.gru(h, gru_inp, pc1_l['s4'])
        return h


class SF_RCP(nn.Module):

    def __init__(self, npoint=8192, use_instance_norm=False, **kwargs):
        super().__init__()
        self.radius = kwargs.get('radius', 3.5)
        self.nsample = kwargs.get('nsample', 6)
        self.radius_min = kwargs.get('radius_min', 3.5)
        self.nsample_min = kwargs.get('nsample_min', 6)
        self.use_curvature = kwargs.get('use_curvature', True)
        self.flow_ratio = kwargs.get('flow_ratio', 0.1)
        self.init_max_iter = kwargs.get('init_max_iter', 0)
        self.init_feature_norm = kwargs.get('init_feature_norm', True)
        self.support_th = kwargs.get('support_th', 10)

        self.feature_extraction = FeatureExtraction(npoint, use_instance_norm)
        self.feature_matching = FeatureMatching(
            npoint,
            use_instance_norm,
            support_th=self.support_th,
            feature_norm=self.init_feature_norm,
            max_iter=self.init_max_iter)

        self.pointwise_optim_layer = PointWiseOptimLayer(
            nsample=self.nsample,
            radius=self.radius,
            in_channel=64,
            mlp=[128, 128, 128],
            use_curvature=self.use_curvature)

        self.gru = GRUReg(
            npoint,
            hidden_dim=128,
            input_dim=128 + 64 + 16 + 3,
            use_instance_norm=use_instance_norm)

        self.flow_regressor = FlowRegressor(npoint, use_instance_norm)

    def initialization(self, pc1_l, pc2_l, feats1, feats2):
        """
        pc1: [B, 3, N]
        pc2: [B, 3, N]
        feature1: [B, 3, N]
        feature2: [B, 3, N]
        """
        flow, flow_lr = self.feature_matching(pc1_l, pc2_l, feats1, feats2)

        return flow, flow_lr

    def pointwise_optimization(self, pc1_l_new, pc2_l, feats1_new, feats2,
                               pc1_l, flow_lr, iter):
        _, cost, score, pos2_grouped = self.pointwise_optim_layer(
            pc1_l_new['s4'],
            pc2_l['s4'],
            feats1_new,
            feats2,
            nsample=max(self.nsample_min, self.nsample // (2**iter)),
            radius=max(self.radius_min, self.radius / (2**iter)),
            pos1_raw=pc1_l['s4'],
            return_score=True)

        # pc1_new_l_loc: [B, 3, N, S]
        # pos2_grouped: [B, C, N, S]
        delta_flow_tmp = ((pos2_grouped - pc1_l_new['s4'].unsqueeze(-1))
                          * score.mean(dim=1, keepdim=True)).sum(
                              dim=-1)  # [B, 3, N]
        flow_lr = flow_lr + self.flow_ratio * delta_flow_tmp

        return flow_lr, cost

    def update_pos(self, pc, pc_lr, flow, flow_lr):
        pc = pc + flow
        pc_lr = pc_lr + flow_lr
        return pc, pc_lr

    def forward(self, pc1, pc2, feature1, feature2, iters=1):
        """
        pc1: [B, N, 3]
        pc2: [B, N, 3]
        feature1: [B, N, 3]
        feature2: [B, N, 3]
        """
        # prepare
        flow_predictions = []
        pc1 = pc1.permute(0, 2, 1).contiguous()  # B 3 N
        pc2 = pc2.permute(0, 2, 1).contiguous()  # B 3 N
        feature1 = feature1.permute(0, 2, 1).contiguous()  # B 3 N
        feature2 = feature2.permute(0, 2, 1).contiguous()  # B 3 N

        # feature extraction
        pc1_l, feats1, fps_idx1 = self.feature_extraction(pc1, feature1)
        pc2_l, feats2, _ = self.feature_extraction(pc2, feature2)

        # initialization, flow_lr_init(flow_low_resolution)
        flow_init, flow_lr_init = self.initialization(pc1_l, pc2_l, feats1,
                                                      feats2)
        flow_predictions.append(flow_init.permute(0, 2, 1))

        # gru init hidden state
        h = self.gru.hidden_init_net(pc1_l['s4'], feats1)

        # update position
        pc1_lr_raw = pc1_l['s4']
        pc1_new, pc1_lr_new = self.update_pos(pc1, pc1_lr_raw, flow_init,
                                              flow_lr_init)

        # iterative optim
        for iter in range(iters - 1):
            pc1_new = pc1_new.detach()
            pc1_lr_new = pc1_lr_new.detach()
            flow_lr = pc1_lr_new - pc1_lr_raw

            pc1_l_new, feats1_new, _ = self.feature_extraction(
                pc1_new, pc1_new, fps_idx1)

            # pointwise optimization to get updated flow_lr and cost
            flow_lr_update, cost = self.pointwise_optimization(
                pc1_l_new, pc2_l, feats1_new, feats2, pc1_l, flow_lr, iter)
            flow_lr = flow_lr_update

            # gru regularization
            h = self.gru(h, feats1_new, cost, flow_lr, pc1_l)
            # pred flow_lr
            delta_flow, delta_flow_lr = self.flow_regressor(pc1_l, h)

            pc1_new, pc1_lr_new = self.update_pos(pc1_new, pc1_lr_new,
                                                  delta_flow, delta_flow_lr)

            flow = pc1_new - pc1
            flow_predictions.append(flow.permute(0, 2, 1))

        return flow_predictions
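`SF_RCP.forward` returns one matching-based initial estimate plus one refinement per extra iteration, each as a `[B, N, 3]` flow. A shape-level sketch with random weights (a CUDA device and the compiled pointnet2_cuda extension assumed; note that rcp_model.py above calls the model with the default `iters=1`, so only the initial estimate is used there):

import torch

from modelscope.models.cv.pointcloud_sceneflow_estimation.sf_rcp import SF_RCP

model = SF_RCP(npoint=8192, use_instance_norm=False).cuda().eval()

pc1 = torch.rand(1, 8192, 3).cuda()
pc2 = torch.rand(1, 8192, 3).cuda()
with torch.no_grad():
    # As in rcp_model.py, the raw coordinates double as the input features.
    preds = model(pc1, pc2, pc1, pc2, iters=3)
print(len(preds), preds[-1].shape)  # 3 torch.Size([1, 8192, 3])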
modelscope/outputs.py
@@ -50,6 +50,8 @@ class OutputKeys(object):
    SCENE_NUM = 'scene_num'
    SCENE_META_LIST = 'scene_meta_list'
    SHOT_META_LIST = 'shot_meta_list'
    PCD12 = 'pcd12'
    PCD12_ALIGN = 'pcd12_align'


TASK_OUTPUTS = {
modelscope/pipelines/cv/__init__.py
@@ -70,6 +70,7 @@ if TYPE_CHECKING:
    from .image_skychange_pipeline import ImageSkychangePipeline
    from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline
    from .video_super_resolution_pipeline import VideoSuperResolutionPipeline
    from .pointcloud_sceneflow_estimation_pipeline import PointCloudSceneFlowEstimationPipeline

else:
    _import_structure = {
@@ -162,6 +163,9 @@ else:
            'VideoObjectSegmentationPipeline'
        ],
        'video_super_resolution_pipeline': ['VideoSuperResolutionPipeline'],
        'pointcloud_sceneflow_estimation_pipeline': [
            'PointCloudSceneFlowEstimationPipeline'
        ]
    }

    import sys
114
modelscope/pipelines/cv/pointcloud_sceneflow_estimation_pipeline.py
Normal file
@@ -0,0 +1,114 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

import numpy as np
import torch
from plyfile import PlyData, PlyElement

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import depth_to_color
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.pointcloud_sceneflow_estimation,
    module_name=Pipelines.pointcloud_sceneflow_estimation)
class PointCloudSceneFlowEstimationPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a pointcloud sceneflow estimation pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)

        logger.info('pointcloud sceneflow estimation model, pipeline init')

    def check_input_pcd(self, pcd):
        assert pcd.ndim == 2, 'pcd ndim must equal to 2'
        assert pcd.shape[1] == 3, 'pcd.shape[1] must equal to 3'

    def preprocess(self, input: Input) -> Dict[str, Any]:
        assert isinstance(input, tuple), 'only support tuple input'
        assert isinstance(input[0], str) and isinstance(
            input[1], str), 'only support tuple input with str type'

        pcd1_file, pcd2_file = input
        logger.info(f'input pcd file:{pcd1_file}, \n {pcd2_file}')
        pcd1 = np.load(pcd1_file)
        pcd2 = np.load(pcd2_file)
        self.check_input_pcd(pcd1)
        self.check_input_pcd(pcd2)
        pcd1_torch = torch.from_numpy(pcd1).float().unsqueeze(0).cuda()
        pcd2_torch = torch.from_numpy(pcd2).float().unsqueeze(0).cuda()

        data = {
            'pcd1': pcd1_torch,
            'pcd2': pcd2_torch,
            'pcd1_ori': pcd1,
            'pcd2_ori': pcd2
        }

        return data

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        results = {}
        output = self.model.inference(input)
        results['output'] = output
        results['pcd1_ori'] = input['pcd1_ori']
        results['pcd2_ori'] = input['pcd2_ori']
        return results

    def save_ply_data(self, pcd1, pcd2):
        vertexs = np.concatenate([pcd1, pcd2], axis=0)
        color1 = np.array([[255, 0, 0]], dtype=np.uint8)
        color2 = np.array([[0, 255, 0]], dtype=np.uint8)
        color1 = np.tile(color1, (pcd1.shape[0], 1))
        color2 = np.tile(color2, (pcd2.shape[0], 1))
        vertex_colors = np.concatenate([color1, color2], axis=0)

        vertexs = np.array([tuple(v) for v in vertexs],
                           dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
        vertex_colors = np.array([tuple(v) for v in vertex_colors],
                                 dtype=[('red', 'u1'), ('green', 'u1'),
                                        ('blue', 'u1')])

        vertex_all = np.empty(
            len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr)
        for prop in vertexs.dtype.names:
            vertex_all[prop] = vertexs[prop]
        for prop in vertex_colors.dtype.names:
            vertex_all[prop] = vertex_colors[prop]

        el = PlyElement.describe(vertex_all, 'vertex')
        ply_data = PlyData([el])
        # .write(save_name)
        return ply_data

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        results = self.model.postprocess(inputs)
        flow = results[OutputKeys.OUTPUT]

        pcd1 = inputs['pcd1_ori']
        pcd2 = inputs['pcd2_ori']
        if isinstance(pcd1, torch.Tensor):
            pcd1 = pcd1.cpu().numpy()
        if isinstance(pcd2, torch.Tensor):
            pcd2 = pcd2.cpu().numpy()
        if isinstance(flow, torch.Tensor):
            flow = flow.cpu().numpy()

        outputs = {
            OutputKeys.OUTPUT: flow,
            OutputKeys.PCD12: self.save_ply_data(pcd1, pcd2),
            OutputKeys.PCD12_ALIGN: self.save_ply_data(pcd1 + flow, pcd2),
        }
        return outputs
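End to end, the new pipeline is driven through the standard `pipeline()` entry point; the sketch below mirrors the unit test added in this commit (GPU required):

import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

estimator = pipeline(
    Tasks.pointcloud_sceneflow_estimation,
    model='damo/cv_pointnet2_sceneflow-estimation_general')

# Input is a tuple of two .npy files, each holding an [N, 3] float array.
result = estimator(('data/test/pointclouds/flyingthings_pcd1.npy',
                    'data/test/pointclouds/flyingthings_pcd2.npy'))

flow = result[OutputKeys.OUTPUT]  # [N, 3] per-point scene flow
result[OutputKeys.PCD12].write('pcd12.ply')  # red: pcd1, green: pcd2
result[OutputKeys.PCD12_ALIGN].write('pcd12_align.ply')  # pcd1 + flow vs pcd2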
modelscope/utils/constant.py
@@ -104,6 +104,9 @@ class CVTasks(object):
    video_summarization = 'video-summarization'
    image_reid_person = 'image-reid-person'

    # pointcloud task
    pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'


class NLPTasks(object):
    # nlp tasks
@@ -24,6 +24,7 @@ onnxruntime>=1.10
opencv-python
pai-easycv>=0.6.3.9
pandas
plyfile>=0.7.4
psutil
regex
scikit-image>=0.19.3
42
tests/pipelines/test_pointcloud_sceneflow_estimation.py
Normal file
@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class PointCloudSceneFlowEstimationTest(unittest.TestCase,
                                        DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = 'pointcloud-sceneflow-estimation'
        self.model_id = 'damo/cv_pointnet2_sceneflow-estimation_general'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_pointcloud_sceneflow_estimation(self):
        input_location = ('data/test/pointclouds/flyingthings_pcd1.npy',
                          'data/test/pointclouds/flyingthings_pcd2.npy')
        estimator = pipeline(
            Tasks.pointcloud_sceneflow_estimation, model=self.model_id)
        result = estimator(input_location)
        flow = result[OutputKeys.OUTPUT]
        pcd12 = result[OutputKeys.PCD12]
        pcd12_align = result[OutputKeys.PCD12_ALIGN]

        print(f'pred flow shape:{flow.shape}')
        np.save('flow.npy', flow)
        # visualization
        pcd12.write('pcd12.ply')
        pcd12_align.write('pcd12_align.ply')

        print('test_pointcloud_sceneflow_estimation DONE')


if __name__ == '__main__':
    unittest.main()