add support for cv_rdevos_video-object-segmentation

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11066863
This commit is contained in:
hooks.hl
2022-12-21 07:37:27 +08:00
committed by yingda.chen
parent 9a67e0bb48
commit 0d205c8322
40 changed files with 1869 additions and 3 deletions

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3044047d7452f27ad5391ecfc83f1a366a5a82e3eb8c8c151a6bf02cbb37c046
size 5366

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:401ebc9ffc8ba5f47ef31f96ebd6a2e7f82b976c49e9a36f1915f5b9202f3d38
size 45113

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:24a6702646a071988aa3a77a40d92cfc2c3ada00ecef5afeaf9c475272f30af1
size 50254

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:37071e9ad7c91af44b2ed9b6d8e3ea42c69f284352d4431ef680b4de01521261
size 51068

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a019761196033e0571c40ccadfdc2eff454e40207e16c3de66a1e322ec727227
size 50164

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8cd500e4f54ac5b85c9eb2a8b7d8706a061ce7ff80502d29850e641cd59a9e9f
size 51056

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fd943f6fb068f7ece7a009f729ed356dd09295ab4d19b0da858256b0259bf467
size 51580

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:554e1725877410cd7ba7d24e151a8b7572eb88ed0d8760a70b15fc20843cda76
size 50718

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a912a376204cdc39e128a0b1d777bc73b88f1f6b6e9d16a28755cf4a89fcfe87
size 51000

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1e83de2980b89b971c3894e102a27a0684307a9e0c841775ce9419f2f5ab43d7
size 50267

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3cd21580132bcbc401cd443e0b0c5bdf6f3fcbfec91e0e767cabc4f6480185a9
size 50328

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a1effa23c00b45e8b516dfa64e376756b9600fe94c43be7b2ae6331e5041474
size 50885

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0f16febf7ef8739393f16cab88e2da00aa3b1a60674ef33cbec7c3dd75584ed
size 51075

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc1ffd892d12d82c10331039bc998822c9a601a1b80fb60d367f5532e3b5958f
size 49797

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cae709a0b5216f469591173beecd05d1e7786409dec6b41323751312c21ffbe8
size 51894

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f2f650e5cc18dcb058da04822181d4a25eb85329d6daf1da395c88e9f537760a
size 52597

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f88d10b8ebc1f8b6da16faf5032a1b49ba1dab94854319637fb390a201af614
size 51468

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc914fbab93a722b2d263e36f713427ae07d82b7aa52e2d189f7517fa6c4bf1b
size 51824

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:79f27be589e2ac046d52bcc0d927e7aef8c2e1562dbb07d63aa0b7a50e0d6b90
size 52512

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2db022e12a83b900afc12493474b2b078f373d864d64f8babbf3c8cd50b62978
size 50119

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:402b7a8fcef26bbba5eebfc5fcc141b3a5a1a5fd06ad61ee6d3150c3beeb3824
size 50961

View File

@@ -58,6 +58,7 @@ class Models(object):
product_segmentation = 'product-segmentation'
image_body_reshaping = 'image-body-reshaping'
video_human_matting = 'video-human-matting'
video_object_segmentation = 'video-object-segmentation'
# EasyCV models
yolox = 'YOLOX'
@@ -243,6 +244,7 @@ class Pipelines(object):
image_body_reshaping = 'flow-based-body-reshaping'
referring_video_object_segmentation = 'referring-video-object-segmentation'
video_human_matting = 'video-human-matting'
video_object_segmentation = 'video-object-segmentation'
# nlp tasks
automatic_post_editing = 'automatic-post-editing'

View File

@@ -14,6 +14,7 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
object_detection, product_retrieval_embedding,
realtime_object_detection, referring_video_object_segmentation,
salient_detection, shop_segmentation, super_resolution,
video_single_object_tracking, video_summarization, virual_tryon)
video_object_segmentation, video_single_object_tracking,
video_summarization, virual_tryon)
# yapf: enable

View File

@@ -0,0 +1,20 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    # Real import only for static type checkers; at runtime the symbol is
    # resolved lazily through LazyImportModule below.
    from .model import VideoObjectSegmentation
else:
    # Map submodule name -> exported symbols so the heavy torch-based model
    # module is only imported on first attribute access.
    _import_structure = {'model': ['VideoObjectSegmentation']}

    import sys

    # Replace this module object with a lazy proxy.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1,29 @@
# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022
# under MIT License
import torch
import torch.nn.functional as F
# Soft aggregation from STM
def aggregate(prob, keep_bg=False):
# Caclulate the probability of the background.
background_prob = torch.prod(1 - prob, dim=0, keepdim=True)
# Concatenate the probabilities of background and foreground objects.
new_prob = torch.cat([background_prob, prob], 0).clamp(1e-7, 1 - 1e-7)
# logit function
logits = torch.log((new_prob / (1 - new_prob)))
if keep_bg:
return F.softmax(logits, dim=0)
else:
return F.softmax(logits, dim=0)[1:]
if __name__ == '__main__':
    # Quick manual sanity check: aggregate a random single-object
    # probability map and print the input/output distributions.
    prob = torch.randn(size=(1, 2, 1, 1))
    prob = torch.sigmoid(prob)
    new = aggregate(prob, keep_bg=True)
    print(prob)
    print(new)

View File

@@ -0,0 +1,123 @@
# The implementation is modified from CBAM
# under the MIT License
# https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicConv(nn.Module):
    """Thin wrapper around ``nn.Conv2d`` that also records the number of
    output channels (kept for CBAM compatibility)."""

    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)

    def forward(self, x):
        return self.conv(x)
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension."""

    def forward(self, x):
        return x.view(x.shape[0], -1)
class ChannelGate(nn.Module):
    """CBAM channel-attention gate.

    Pools the input spatially (avg and/or max), feeds each pooled vector
    through a shared bottleneck MLP, sums the results and uses the sigmoid
    of the sum as a per-channel scale.
    """

    def __init__(self,
                 gate_channels,
                 reduction_ratio=16,
                 pool_types=('avg', 'max')):
        # NOTE: default changed from a mutable list to a tuple; callers only
        # iterate it, so behavior is unchanged and the shared-mutable-default
        # pitfall is avoided.
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        # Shared bottleneck MLP applied to every pooled descriptor.
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels))
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                # Global average pool to a (B, C, 1, 1) descriptor.
                avg_pool = F.avg_pool2d(
                    x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                # Global max pool to a (B, C, 1, 1) descriptor.
                max_pool = F.max_pool2d(
                    x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)
            else:
                # Previously an unknown pool type silently reused (or left
                # unbound) channel_att_raw; fail loudly instead.
                raise ValueError(
                    'unsupported pool type: {}'.format(pool_type))
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        # Per-channel sigmoid scale, broadcast back over H and W.
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(
            3).expand_as(x)
        return x * scale
class ChannelPool(nn.Module):
    """Compress the channel axis to two maps: channel-wise max and mean."""

    def forward(self, x):
        max_map = torch.max(x, 1)[0].unsqueeze(1)
        mean_map = torch.mean(x, 1).unsqueeze(1)
        return torch.cat((max_map, mean_map), dim=1)
class SpatialGate(nn.Module):
    """CBAM spatial-attention gate: a 7x7 conv over the channel-pooled
    (max, mean) maps yields a per-pixel sigmoid scale."""

    def __init__(self):
        super().__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(
            2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2)

    def forward(self, x):
        pooled = self.compress(x)
        attention = self.spatial(pooled)
        # (B, 1, H, W) sigmoid map broadcasts across the channel dim.
        return x * torch.sigmoid(attention)
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate followed by an
    optional spatial gate."""

    def __init__(self,
                 gate_channels,
                 reduction_ratio=16,
                 pool_types=('avg', 'max'),
                 no_spatial=False):
        # NOTE: default changed from a mutable list to a tuple to avoid the
        # shared-mutable-default pitfall; iterating callers are unaffected.
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio,
                                       pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out

View File

@@ -0,0 +1,62 @@
# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022
# under MIT License
import torch
import torch.nn as nn
from modelscope.models.cv.video_object_segmentation.modules import (
KeyEncoder, KeyProjection, MemCrompress, ValueEncoder)
from modelscope.models.cv.video_object_segmentation.network import Decoder
class RDE_VOS(nn.Module):
    """Evaluation-time RDE-VOS network: key/value encoders, a compressed
    memory, and a decoder that segments query frames against the memory."""

    def __init__(self, repeat=0):
        super().__init__()
        self.key_encoder = KeyEncoder()
        self.value_encoder = ValueEncoder()

        # Project the 1/16-scale backbone feature into the 64-d key space.
        self.key_proj = KeyProjection(1024, keydim=64)
        # Slim f16 down for later use by the decoder.
        self.key_comp = nn.Conv2d(1024, 512, kernel_size=3, padding=1)

        self.decoder = Decoder()
        self.mem_compress = MemCrompress(repeat=repeat)

    def encode_value(self, frame, kf16, masks):
        """Extract memory values for one frame carrying k object masks."""
        k, _, h, w = masks.shape

        # One copy of the frame per object.
        frame = frame.view(1, 3, h, w).repeat(k, 1, 1, 1)

        if k == 1:
            # Single object: there are no "other" objects.
            others = torch.zeros_like(masks)
        else:
            # For each object i, the summed mask of all other objects.
            rows = []
            for i in range(k):
                rest = masks[[j for j in range(k) if i != j]]
                rows.append(torch.sum(rest, dim=0, keepdim=True))
            others = torch.cat(rows, 0)

        f16 = self.value_encoder(frame, kf16.repeat(k, 1, 1, 1), masks, others)
        return f16.unsqueeze(2)

    def encode_key(self, frame):
        """Return (key, compressed f16, f16, f8, f4) for a frame."""
        f16, f8, f4 = self.key_encoder(frame)
        return self.key_proj(f16), self.key_comp(f16), f16, f8, f4

    def segment_with_query(self, mem_bank, qf8, qf4, qk16, qv16):
        """Read memory with the query key and decode a sigmoid mask."""
        k = mem_bank.num_objects
        readout_mem = mem_bank.match_memory(qk16)
        qv16 = torch.cat([readout_mem, qv16.expand(k, -1, -1, -1)], 1)
        return torch.sigmoid(self.decoder(qv16, qf8, qf4))

View File

@@ -0,0 +1,128 @@
# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022
# under MIT License
import torch
import torch.nn.functional as F
from modelscope.models.cv.video_object_segmentation.aggregate import aggregate
from modelscope.models.cv.video_object_segmentation.inference_memory_bank import \
MemoryBank
def pad_divide_by(in_img, d, in_size=None):
    """Zero-pad ``in_img`` so its spatial size is a multiple of ``d``.

    The padding is split as evenly as possible between the two sides of
    each axis (extra pixel goes to the high side).

    Returns:
        (padded tensor, (left, right, top, bottom) pads in F.pad order).
    """
    if in_size is None:
        h, w = in_img.shape[-2:]
    else:
        h, w = in_size
    # Round each dimension up to the next multiple of d.
    new_h = h if h % d == 0 else h + d - h % d
    new_w = w if w % d == 0 else w + d - w % d
    lh = (new_h - h) // 2
    uh = (new_h - h) - lh
    lw = (new_w - w) // 2
    uw = (new_w - w) - lw
    pad_array = (int(lw), int(uw), int(lh), int(uh))
    return F.pad(in_img, pad_array), pad_array
class InferenceCore:
    """Propagates a first-frame annotation through a padded video using the
    RDE-VOS network and a compressed memory bank."""

    def __init__(self,
                 prop_net,
                 is_cuda,
                 images,
                 num_objects,
                 top_k=20,
                 mem_every=5,
                 include_last=False):
        """
        Args:
            prop_net: RDE_VOS network used for encoding and segmentation.
            is_cuda: whether to keep the probability volume on GPU.
            images: (1, T, 3, H, W) video tensor.
            num_objects: number of annotated objects k.
            top_k: top-k used by the memory bank's softmax read-out.
            mem_every: add a permanent memory frame every ``mem_every`` frames.
            include_last: also keep the previous frame as temporary memory.
        """
        self.prop_net = prop_net
        self.is_cuda = is_cuda
        self.mem_every = mem_every
        self.include_last = include_last

        # True dimensions
        t = images.shape[1]
        h, w = images.shape[-2:]

        # Pad each side to a multiple of 16 (the network stride).
        images, self.pad = pad_divide_by(images, 16)
        # Padded dimensions
        nh, nw = images.shape[-2:]

        self.images = images
        self.device = 'cuda' if self.is_cuda else 'cpu'

        self.k = num_objects

        # Probability volume including background; channels are not
        # guaranteed to sum to 1 until aggregate() is applied per frame.
        self.prob = torch.zeros((self.k + 1, t, 1, nh, nw),
                                dtype=torch.float32,
                                device=self.device)
        self.prob[0] = 1e-7

        self.t, self.h, self.w = t, h, w
        self.nh, self.nw = nh, nw
        self.kh = self.nh // 16
        self.kw = self.nw // 16

        # Compressed memory bank: ground-truth frame + last frame +
        # recurrently compressed history.
        self.mem_bank = MemoryBank(
            prop_net.mem_compress,
            k=self.k,
            top_k=top_k,
            mode='two-frames-compress')

    def encode_key(self, idx):
        """Encode frame ``idx``; returns (k16, f16_thin, f16, f8, f4)."""
        return self.prop_net.encode_key(self.images[:, idx])

    def do_pass(self, first_k, first_v, idx, end_idx):
        """Propagate masks forward from frame ``idx`` up to ``end_idx``
        (exclusive); returns ``end_idx``.

        (A stray ``global tt1, tt2, tt3, tt4`` debug leftover was removed;
        the names were never assigned or read.)
        """
        self.mem_bank.add_memory(first_k, first_v)
        closest_ti = end_idx

        # Note that we never reach closest_ti, just the frame before it.
        this_range = range(idx + 1, closest_ti)
        end = closest_ti - 1

        for ti in this_range:
            k16, qv16, qf16, qf8, qf4 = self.encode_key(ti)
            out_mask = self.prop_net.segment_with_query(
                self.mem_bank, qf8, qf4, k16, qv16)

            out_mask = aggregate(out_mask, keep_bg=True)
            self.prob[:, ti] = out_mask

            if ti != end:
                is_mem_frame = ((ti % self.mem_every) == 0)
                if self.include_last or is_mem_frame:
                    # Store this frame; non-memory frames only go into the
                    # temporary ("last frame") slot of the bank.
                    prev_value = self.prop_net.encode_value(
                        self.images[:, ti], qf16, out_mask[1:])
                    prev_key = k16.unsqueeze(2)
                    self.mem_bank.add_memory(
                        prev_key, prev_value, is_temp=not is_mem_frame)

        return closest_ti

    def interact(self, mask, frame_idx, end_idx):
        """Inject the ground-truth ``mask`` at ``frame_idx``, then
        propagate it forward to ``end_idx``."""
        mask, _ = pad_divide_by(mask, 16)
        self.prob[:, frame_idx] = aggregate(mask, keep_bg=True)

        # Key/value pair for the annotated frame.
        first_k, _, qf16, _, _ = self.encode_key(frame_idx)
        first_v = self.prop_net.encode_value(self.images[:, frame_idx], qf16,
                                             self.prob[1:, frame_idx])
        first_k = first_k.unsqueeze(2)

        # Propagate
        self.do_pass(first_k, first_v, frame_idx, end_idx)

View File

@@ -0,0 +1,255 @@
# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022
# under MIT License
import math
import torch
def softmax_w_top(x, top):
    """In-place top-k softmax along dim 1.

    Keeps only the ``top`` largest entries per column, softmax-normalizes
    them and zeroes everything else. ``x`` is modified in place and also
    returned.
    """
    values, indices = torch.topk(x, k=top, dim=1)
    weights = values.exp_()
    weights.div_(weights.sum(dim=1, keepdim=True))
    # Scatter the normalized top-k weights back into x.  B * THW * HW
    x.zero_()
    x.scatter_(1, indices, weights)
    return x
def make_gaussian(y_idx, x_idx, height, width, sigma=7):
    """Build one Gaussian bump per (y, x) centre over the unrolled
    height*width grid.

    Args:
        y_idx, x_idx: centre coordinates; expected (1, N) index tensors
            (batch size 1 — the broadcast below assumes it).
        height, width: grid size.
        sigma: Gaussian standard deviation.

    Returns:
        (N, height*width) tensor of Gaussian weights.

    NOTE: the grid used to be hard-coded to CUDA; it now follows the
    device of ``y_idx`` so the function also works on CPU, with identical
    results on CUDA inputs.
    """
    device = y_idx.device
    yv, xv = torch.meshgrid([torch.arange(0, height), torch.arange(0, width)])
    yv = yv.reshape(height * width).unsqueeze(0).float().to(device)
    xv = xv.reshape(height * width).unsqueeze(0).float().to(device)

    # One row per centre.
    y_idx = y_idx.transpose(0, 1)
    x_idx = x_idx.transpose(0, 1)

    g = torch.exp(-((yv - y_idx)**2 + (xv - x_idx)**2) / (2 * sigma**2))
    return g
def kmn(x, top=None, gauss=None):
    """Kernelized-memory-style normalization of an affinity over dim 1.

    Variants:
      * top set, gauss set:   Gaussian-weighted exp, then top-k, in-place.
      * top set, gauss unset: plain top-k softmax, in-place.
      * top unset, gauss set: dense Gaussian-weighted softmax.

    NOTE(review): with both ``top`` and ``gauss`` unset, ``output`` is
    never assigned and this raises NameError. The only in-file call site
    (in MemoryBank._global_matching) is commented out, so behavior is
    kept as-is.
    """
    if top is not None:
        if gauss is not None:
            # Numerically-stable exp weighted by the Gaussian prior.
            maxes = torch.max(x, dim=1, keepdim=True)[0]
            x_exp = torch.exp(x - maxes) * gauss
            x_exp, indices = torch.topk(x_exp, k=top, dim=1)
        else:
            values, indices = torch.topk(x, k=top, dim=1)
            # NOTE(review): ``values[:, 0]`` broadcasts along dim 0, not
            # dim 1 — this only lines up for batch size 1; confirm before
            # using with larger batches.
            x_exp = torch.exp(values - values[:, 0])

        x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
        x_exp /= x_exp_sum
        # Non-top-k entries of x are zeroed in place.
        x.zero_().scatter_(1, indices, x_exp)  # B * THW * HW

        output = x
    else:
        maxes = torch.max(x, dim=1, keepdim=True)[0]
        if gauss is not None:
            x_exp = torch.exp(x - maxes) * gauss
            x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
            x_exp /= x_exp_sum
            output = x_exp

    return output
class MemoryBank:
    """Key/value memory used during RDE-VOS evaluation.

    Depending on ``mode`` the bank keeps some combination of:
      * ``gt_k``/``gt_v``     — the first (ground-truth) frame,
      * ``temp_k``/``temp_v`` — the most recent ("last") frame,
      * ``mem_k``/``mem_v``   — a history recurrently squeezed by the
        supplied ``compress`` module.
    """

    def __init__(self, compress, k, top_k=20, mode='stm'):
        # top_k: k used by the top-k softmax read-out in _global_matching.
        self.top_k = top_k

        # Key/value channel counts; set on the first add_memory call.
        self.CK = None
        self.CV = None

        self.mem_k = None
        self.mem_v = None

        self.num_objects = k
        # Sigma for the (currently disabled) kernelized read-out.
        self.km = 5.6
        self.compress = compress
        self.init_mode(mode)

    def init_mode(self, mode):
        """
        stm, two-frames, gt, last, compress, gt-compress,
        last-compress, two-frames-compress
        """
        # Flags default to None (falsy) and are switched on per mode.
        self.is_compress = None
        self.use_gt = None
        self.use_last = None
        self.stm = None
        print('mode is {}'.format(mode))
        if mode == 'stm':
            self.stm = True
        elif mode == 'two-frames':
            self.use_gt = True
            self.use_last = True
        elif mode == 'gt':
            self.use_gt = True
        elif mode == 'last':
            self.use_last = True
        elif mode == 'compress':
            self.is_compress = True
        elif mode == 'gt-compress':
            self.use_gt = True
            self.is_compress = True
        elif mode == 'last-compress':
            self.use_last = True
            self.is_compress = True
        elif mode == 'two-frames-compress':
            self.use_gt = True
            self.use_last = True
            self.is_compress = True
        else:
            raise RuntimeError('check mode!')

    def _global_matching(self, mk, qk, H, W):
        """Top-k-sparsified affinity between memory keys and query keys."""
        # NE means number of elements -- typically T*H*W
        mk = mk.flatten(start_dim=2)
        qk = qk.flatten(start_dim=2)

        B, CK, NE = mk.shape

        # Negative squared distance up to the constant query term.
        a = mk.pow(2).sum(1).unsqueeze(2)
        b = 2 * (mk.transpose(1, 2) @ qk)
        # We don't actually need this, will update paper later
        # c = qk.pow(2).expand(B, -1, -1).sum(1).unsqueeze(1)

        affinity = (-a + b) / math.sqrt(CK)  # B, NE, HW

        # (Disabled) kernelized memory read-out, kept for reference:
        # if self.km is not None:
        #     # Make a bunch of Gaussian distributions
        #     argmax_idx = affinity.max(2)[1]
        #     y_idx, x_idx = argmax_idx//W, argmax_idx%W
        #     g = make_gaussian(y_idx, x_idx, H, W, sigma=self.km)
        #     g = g.view(B, NE, H*W)
        #     affinity = kmn(affinity, top=20, gauss=g)  # B, THW, HW
        affinity = softmax_w_top(affinity, top=self.top_k)  # B, THW, HW

        return affinity

    def _readout(self, affinity, mv):
        # Weighted sum of memory values under the affinity.
        return torch.bmm(mv, affinity)

    def match_memory(self, qk):
        """Read memory values for a query key of shape (1, CK, h, w).

        Which keys/values participate depends on the mode flags; the
        ground-truth frame is concatenated twice, doubling its weight in
        the affinity.
        NOTE(review): assumes add_memory() has been called at least once —
        otherwise self.temp_k / self.mem_k do not exist yet.
        """
        k = self.num_objects
        _, _, h, w = qk.shape

        qk = qk.flatten(start_dim=2)

        # use gt+last+mem
        if self.temp_k is not None and self.is_compress and self.use_last and \
                self.use_gt:
            # print("mode: gt+last+mem")
            mk = torch.cat([self.mem_k, self.temp_k, self.gt_k, self.gt_k], 2)
            # mv = torch.cat([self.mem_v, self.temp_v], 2)
            # Value tensors can disagree in rank (5-D vs 6-D) depending on
            # whether the compressor has already run; fall back to the
            # unsqueezed layout when the plain concat fails.
            try:
                mv = torch.cat([self.mem_v, self.temp_v, self.gt_v, self.gt_v],
                               2)
            except Exception:
                mv = torch.cat([
                    self.mem_v,
                    self.temp_v.unsqueeze(0),
                    self.gt_v.unsqueeze(0),
                    self.gt_v.unsqueeze(0)
                ], 3)
        # use gt+last
        elif self.temp_k is not None and self.use_last and self.use_gt:
            # print("mode: gt+last")
            mk = torch.cat([self.temp_k, self.gt_k, self.gt_k], 2)
            # mv = torch.cat([self.mem_v, self.temp_v], 2)
            try:
                mv = torch.cat([self.temp_v, self.gt_v, self.gt_v], 2)
            except Exception:
                mv = torch.cat([
                    self.temp_v.unsqueeze(0),
                    self.gt_v.unsqueeze(0),
                    self.gt_v.unsqueeze(0)
                ], 3)
        # use last+mem
        elif self.temp_k is not None and self.is_compress and self.use_last:
            # print("mode: last+mem")
            mk = torch.cat([self.mem_k, self.temp_k], 2)
            try:
                mv = torch.cat([self.mem_v, self.temp_v], 2)
            except Exception:
                mv = torch.cat([self.mem_v, self.temp_v.unsqueeze(0)], 3)
        # use gt+mem
        elif self.is_compress and self.use_gt:
            # print("mode: gt+mem")
            # mk = self.mem_k
            # mv = self.mem_v
            mk = torch.cat([self.mem_k, self.gt_k], 2)
            try:
                mv = torch.cat([self.mem_v, self.gt_v], 2)
            except Exception:
                mv = torch.cat([self.mem_v, self.gt_v.unsqueeze(0)], 3)
        # use last
        elif self.temp_k is not None and self.use_last:
            # print("mode: last")
            mk = self.temp_k
            mv = self.temp_v
        # use gt or only use our embedding
        else:
            # print("mode: gt")
            # use nothing
            mk = self.mem_k
            mv = self.mem_v

        affinity = self._global_matching(mk, qk, h, w)

        # Normalize the value rank before the batched read-out.
        if len(mv.shape) == 6:
            mv = mv.squeeze(0)
        mv = mv.flatten(start_dim=2)

        # One affinity shared by all objects; values are read per object.
        readout_mem = self._readout(affinity.expand(k, -1, -1), mv)

        return readout_mem.view(k, self.CV, h, w)

    def add_memory(self, key, value, is_temp=False):
        # Temp is for "last frame"
        # Not always used
        # But can always be flushed
        self.temp_k = None
        self.temp_v = None

        if self.mem_k is None:
            # First frame, just shove it in
            self.mem_k = key  # gt
            self.mem_v = value
            self.CK = key.shape[1]
            self.CV = value.shape[1]
            # Keep the ground-truth frame separately as well.
            self.gt_k = key
            self.gt_v = value
        elif self.is_compress:
            # compress the two frames
            if len(self.mem_v.shape) == 5:
                self.mem_v = self.mem_v.unsqueeze(0)
            k = torch.cat([self.mem_k, key], 2)  # [1, 64, 2, 30, 57]
            v = torch.cat([self.mem_v, value.unsqueeze(0)],
                          3)  # [1, 2, 512, 2, 30, 57]
            self.mem_k, self.mem_v = self.compress(k, v)

            # check if use last frame
            if self.use_last:
                self.temp_k = key  # [1, 64, 1, 30, 57]
                self.temp_v = value  # [2, 512, 1, 30, 57]
        elif self.stm:
            # stm style: append without compression
            # print("stm", self.mem_k.shape)
            self.mem_k = torch.cat([self.mem_k, key], 2)
            self.mem_v = torch.cat([self.mem_v, value], 2)
        else:
            # no compress
            # check if use last frame
            if self.use_last:
                self.temp_k = key
                self.temp_v = value

View File

@@ -0,0 +1,229 @@
# The implementation is modified from torchvision
# under BSD-3-Clause License
# https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
import math
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.utils import model_zoo
def load_weights_sequential(target, source_state, extra_chan=1):
    """Load ``source_state`` into ``target``, widening mismatched conv
    weights with ``extra_chan`` extra input channels.

    The extra channels (used to feed masks into an RGB-pretrained
    backbone) are orthogonally initialized and concatenated onto the
    pretrained kernel along the input-channel axis.
    """
    new_dict = OrderedDict()
    for name, param in target.state_dict().items():
        if 'num_batches_tracked' in name:
            continue
        if name not in source_state:
            continue
        src = source_state[name]
        if param.shape != src.shape:
            # Target conv expects extra input channels: append an
            # orthogonally-initialized block for them.
            out_c, _, kh, kw = param.shape
            pads = torch.zeros((out_c, extra_chan, kh, kw),
                               device=src.device)
            nn.init.orthogonal_(pads)
            src = torch.cat([src, pads], 1)
        new_dict[name] = src
    target.load_state_dict(new_dict, strict=False)
# torchvision pretrained checkpoints used to initialize the backbones.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}
def conv3x3(in_planes, out_planes, stride=1, dilation=1, bias=True):
    """3x3 convolution whose padding equals the dilation, so the spatial
    size is preserved at stride 1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, dilation=dilation, bias=bias)
class BasicBlock(nn.Module):
    """Two-conv residual block (ResNet-18/34 style)."""

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 dilation=1,
                 bias=True):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(
            inplanes, planes, stride=stride, dilation=dilation, bias=bias)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(
            planes, planes, stride=1, dilation=dilation, bias=bias)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Identity shortcut, or a projection when shapes change.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck (ResNet-50 style) with a 4x
    channel expansion on the output."""

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 dilation=1,
                 bias=True):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            dilation=dilation,
            padding=dilation,
            bias=bias)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=bias)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Identity shortcut, or a projection when shapes change.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)
class ResNet(nn.Module):
    """ResNet backbone whose stem accepts ``3 + extra_chan`` input
    channels (the extra channels carry mask inputs).

    Only the stages are built here — no classifier head and no
    ``forward``; callers use conv1/bn1/.../layer4 directly.
    """

    def __init__(self, block, layers=(3, 4, 23, 3), extra_chan=1, bias=True):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(
            3 + extra_chan, 64, kernel_size=7, stride=2, padding=3, bias=bias)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], bias=bias)
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, bias=bias)
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, bias=bias)
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, bias=bias)

        # He-style init for convs, constant init for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
                try:
                    m.bias.data.zero_()
                except Exception:
                    # Bias-free convs simply have no bias tensor.
                    pass
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self,
                    block,
                    planes,
                    blocks,
                    stride=1,
                    dilation=1,
                    bias=True):
        """Stack ``blocks`` residual blocks; the first may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Project the shortcut when the shape changes.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=bias),
                nn.BatchNorm2d(planes * block.expansion),
            )

        stages = [
            block(self.inplanes, planes, stride, downsample, dilation, bias)
        ]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            stages.append(
                block(self.inplanes, planes, dilation=dilation, bias=bias))
        return nn.Sequential(*stages)
def resnet18(pretrained=True, extra_chan=0):
    """ResNet-18 with ``extra_chan`` additional input channels.

    Pretrained weights are only fetched inside an initialized distributed
    run; each rank caches the download in its own directory.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], extra_chan=extra_chan)
    if pretrained and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        state = model_zoo.load_url(
            model_urls['resnet18'],
            model_dir='pretrain/resnet18-{}'.format(rank))
        load_weights_sequential(model, state, extra_chan)
    return model
def resnet50(pretrained=True, extra_chan=0):
    """ResNet-50 (bias-free convs) with ``extra_chan`` additional input
    channels.

    Pretrained weights are only fetched inside an initialized distributed
    run; each rank caches the download in its own directory.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], extra_chan=extra_chan, bias=False)
    if pretrained and torch.distributed.is_initialized():
        local_rank = torch.distributed.get_rank()
        load_weights_sequential(
            model,
            model_zoo.load_url(
                model_urls['resnet50'],
                model_dir='pretrain/resnet50-{}'.format(local_rank)),
            extra_chan)
        # Debug print removed: it called torch.distributed.get_rank(),
        # which raises whenever the default process group is not
        # initialized (the common single-process inference case).
    return model

View File

@@ -0,0 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from typing import Any, Dict
import torch
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.video_object_segmentation.eval_network import RDE_VOS
from modelscope.utils.constant import ModelFile, Tasks
@MODELS.register_module(
    Tasks.video_object_segmentation,
    module_name=Models.video_object_segmentation)
class VideoObjectSegmentation(TorchModel):
    """ModelScope wrapper around the RDE-VOS evaluation network."""

    def __init__(self, model_dir: str, *args, **kwargs):
        """Load RDE-VOS weights from ``model_dir`` and switch to eval mode.

        Args:
            model_dir: directory containing the torch checkpoint file.
        """
        super().__init__(model_dir, *args, **kwargs)
        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        # Load on CPU; the pipeline decides device placement later.
        params = torch.load(model_path, map_location='cpu')
        self.model = RDE_VOS()
        self.model.load_state_dict(params, strict=True)
        self.model.eval()

    def forward(self, inputs: Dict[str, Any]):
        """Delegate straight to the wrapped RDE_VOS network."""
        return self.model(inputs)

View File

@@ -0,0 +1,523 @@
# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022
# under MIT License
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from modelscope.models.cv.video_object_segmentation import cbam, mod_resnet
class ResBlock(nn.Module):
    """Pre-activation residual block; a 3x3 conv projects the shortcut
    when the channel count changes."""

    def __init__(self, indim, outdim=None):
        super(ResBlock, self).__init__()
        outdim = indim if outdim is None else outdim
        # Project the skip connection only if the channels differ.
        self.downsample = None if indim == outdim else nn.Conv2d(
            indim, outdim, kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1)

    def forward(self, x):
        residual = self.conv2(F.relu(self.conv1(F.relu(x))))
        skip = x if self.downsample is None else self.downsample(x)
        return skip + residual
class FeatureFusionBlock(nn.Module):
    """Fuse a feature map with the key feature f16: concat, residual
    block, CBAM-attention residual, residual block."""

    def __init__(self, indim, outdim):
        super().__init__()
        self.block1 = ResBlock(indim, outdim)
        self.attention = cbam.CBAM(outdim)
        self.block2 = ResBlock(outdim, outdim)

    def forward(self, x, f16):
        fused = self.block1(torch.cat([x, f16], 1))
        attended = self.attention(fused)
        return self.block2(fused + attended)
# Single object version, used only in static image pretraining
# See model.py (load_network) for the modification procedure
class ValueEncoderSO(nn.Module):
    """Single-object value encoder: a ResNet-18 trunk over (image, mask)
    fused with the key feature at 1/16 scale."""

    def __init__(self):
        super().__init__()
        trunk = mod_resnet.resnet18(pretrained=False, extra_chan=1)
        self.conv1 = trunk.conv1
        self.bn1 = trunk.bn1
        self.relu = trunk.relu  # 1/2, 64
        self.maxpool = trunk.maxpool

        self.layer1 = trunk.layer1  # 1/4, 64
        self.layer2 = trunk.layer2  # 1/8, 128
        self.layer3 = trunk.layer3  # 1/16, 256

        self.fuser = FeatureFusionBlock(1024 + 256, 512)

    def forward(self, image, key_f16, mask):
        # key_f16 is the feature from the key encoder.
        x = self.relu(self.bn1(self.conv1(torch.cat([image, mask], 1))))
        x = self.maxpool(x)  # 1/4, 64
        x = self.layer1(x)  # 1/4, 64
        x = self.layer2(x)  # 1/8, 128
        x = self.layer3(x)  # 1/16, 256
        return self.fuser(x, key_f16)
# Multiple objects version, used in other times
class ValueEncoder(nn.Module):
    """Multi-object value encoder: like ValueEncoderSO, but the trunk also
    sees the union of the other objects' masks (2 extra channels)."""

    def __init__(self):
        super().__init__()
        trunk = mod_resnet.resnet18(pretrained=False, extra_chan=2)
        self.conv1 = trunk.conv1
        self.bn1 = trunk.bn1
        self.relu = trunk.relu  # 1/2, 64
        self.maxpool = trunk.maxpool

        self.layer1 = trunk.layer1  # 1/4, 64
        self.layer2 = trunk.layer2  # 1/8, 128
        self.layer3 = trunk.layer3  # 1/16, 256

        self.fuser = FeatureFusionBlock(1024 + 256, 512)

    def forward(self, image, key_f16, mask, other_masks):
        # key_f16 is the feature from the key encoder.
        x = self.relu(self.bn1(self.conv1(
            torch.cat([image, mask, other_masks], 1))))
        x = self.maxpool(x)  # 1/4, 64
        x = self.layer1(x)  # 1/4, 64
        x = self.layer2(x)  # 1/8, 128
        x = self.layer3(x)  # 1/16, 256
        return self.fuser(x, key_f16)
# from retrying import retry
# @retry(stop_max_attempt_number=5)
class KeyEncoder(nn.Module):
    """Query/key encoder: an unmodified ResNet-50 trunk returning features
    at 1/16, 1/8 and 1/4 resolution."""

    def __init__(self):
        super().__init__()
        trunk = models.resnet50(pretrained=False)
        self.conv1 = trunk.conv1
        self.bn1 = trunk.bn1
        self.relu = trunk.relu  # 1/2, 64
        self.maxpool = trunk.maxpool

        self.res2 = trunk.layer1  # 1/4, 256
        self.layer2 = trunk.layer2  # 1/8, 512
        self.layer3 = trunk.layer3  # 1/16, 1024

    def forward(self, f):
        x = self.relu(self.bn1(self.conv1(f)))
        x = self.maxpool(x)  # 1/4, 64
        f4 = self.res2(x)  # 1/4, 256
        f8 = self.layer2(f4)  # 1/8, 512
        f16 = self.layer3(f8)  # 1/16, 1024
        return f16, f8, f4
class UpsampleBlock(nn.Module):
    """Bilinearly upsample ``up_f`` by ``scale_factor``, add a projected
    skip feature and refine with a residual block."""

    def __init__(self, skip_c, up_c, out_c, scale_factor=2):
        super().__init__()
        self.skip_conv = nn.Conv2d(skip_c, up_c, kernel_size=3, padding=1)
        self.out_conv = ResBlock(up_c, out_c)
        self.scale_factor = scale_factor

    def forward(self, skip_f, up_f):
        upsampled = F.interpolate(
            up_f,
            scale_factor=self.scale_factor,
            mode='bilinear',
            align_corners=False)
        return self.out_conv(self.skip_conv(skip_f) + upsampled)
class KeyProjection(nn.Module):
    """3x3 conv projecting backbone features into the key space, with an
    orthogonal weight and zero bias initialization."""

    def __init__(self, indim, keydim):
        super().__init__()
        proj = nn.Conv2d(indim, keydim, kernel_size=3, padding=1)
        nn.init.orthogonal_(proj.weight.data)
        nn.init.zeros_(proj.bias.data)
        self.key_proj = proj

    def forward(self, x):
        return self.key_proj(x)
class _NonLocalBlockND(nn.Module):
    """Dimension-generic non-local (self-attention) block.

    Computes ``z = W(attention(theta(x), phi(x)) @ g(x)) + x`` where the
    attention is the dot-product variant normalized by the number of
    positions (``f / N``, not a softmax).
    """

    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 dimension=3,
                 sub_sample=True,
                 bn_layer=False):
        """
        Args:
            in_channels: channels of the input feature map.
            inter_channels: bottleneck channels; defaults to
                ``in_channels // 2`` (at least 1).
            dimension: 1, 2 or 3 — selects Conv1d/2d/3d + matching pooling.
            sub_sample: if True, spatially subsample phi/g with max pooling.
            bn_layer: if True, append a norm layer to the output projection.
        """
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        # Pick the conv / pool / norm flavour matching the dimensionality.
        # NOTE(review): the 3-D case uses InstanceNorm3d where the other
        # cases use BatchNorm — kept as-is from the original.
        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.InstanceNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        # g: value projection.
        self.g = conv_nd(
            in_channels=self.in_channels,
            out_channels=self.inter_channels,
            kernel_size=1,
            stride=1,
            padding=0)

        # W: output projection back to in_channels (optionally normalized).
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(
                    in_channels=self.inter_channels,
                    out_channels=self.in_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0), bn(self.in_channels))
            # nn.init.constant_(self.W[1].weight, 0)
            # nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(
                in_channels=self.inter_channels,
                out_channels=self.in_channels,
                kernel_size=1,
                stride=1,
                padding=0)
            # nn.init.constant_(self.W.weight, 0)
            # nn.init.constant_(self.W.bias, 0)

        # theta / phi: query and key projections.
        self.theta = conv_nd(
            in_channels=self.in_channels,
            out_channels=self.inter_channels,
            kernel_size=1,
            stride=1,
            padding=0)
        self.phi = conv_nd(
            in_channels=self.in_channels,
            out_channels=self.inter_channels,
            kernel_size=1,
            stride=1,
            padding=0)

        # Subsampling is appended AFTER g/phi are built, so it wraps the
        # projections above — order matters here.
        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c, t, h, w)
        :return:
        '''

        batch_size = x.size(0)

        # Flatten spatial dims; g/phi may be subsampled relative to theta.
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # Dot-product attention normalized by the number of key positions.
        f = torch.matmul(theta_x, phi_x)
        N = f.size(-1)
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        # Residual connection around the projected attention output.
        W_y = self.W(y)
        z = W_y + x

        return z
class NONLocalBlock1D(_NonLocalBlockND):
    """Non-local block specialized for 1-D (b, c, t) inputs."""

    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 sub_sample=True,
                 bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=1,
            sub_sample=sub_sample,
            bn_layer=bn_layer)
class NONLocalBlock2D(_NonLocalBlockND):
    """Non-local attention over 2D feature maps (b, c, h, w).

    Note: unlike the 1D/3D variants, ``bn_layer`` defaults to False here.
    """

    def __init__(self, in_channels, inter_channels=None, sub_sample=True,
                 bn_layer=False):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer)
class NONLocalBlock3D(_NonLocalBlockND):
    """Non-local attention over 3D feature maps (b, c, t, h, w)."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True,
                 bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=3,
            sub_sample=sub_sample,
            bn_layer=bn_layer)
class _ASPPModule3D(nn.Module):
def __init__(self, inplanes, planes, kernel_size, padding, dilation):
super(_ASPPModule3D, self).__init__()
self.atrous_conv = nn.Conv3d(
inplanes,
planes,
kernel_size=kernel_size,
stride=1,
padding=padding,
dilation=dilation,
bias=False)
def forward(self, x):
x = self.atrous_conv(x)
return F.relu(x, inplace=True)
class ASPP3D(nn.Module):
    """3D atrous spatial pyramid pooling.

    Four parallel branches (one 1x1x1 plus three spatially-dilated
    (1, 3, 3) convs) are concatenated along channels and fused back to
    ``out_plane`` channels with a (1, 3, 3) convolution + ReLU.
    """

    def __init__(self, in_plane, out_plane, reduction=4):
        super().__init__()
        rates = [1, 2, 4, 6]
        mid_plane = out_plane // reduction

        def spatial_branch(rate):
            # Dilation is applied on the two spatial axes only; the time
            # axis keeps kernel 1 / dilation 1.
            return _ASPPModule3D(
                in_plane,
                mid_plane, (1, 3, 3),
                padding=(0, rate, rate),
                dilation=(1, rate, rate))

        self.aspp1 = _ASPPModule3D(
            in_plane, mid_plane, 1, padding=0, dilation=rates[0])
        self.aspp2 = spatial_branch(rates[1])
        self.aspp3 = spatial_branch(rates[2])
        self.aspp4 = spatial_branch(rates[3])
        self.conv1 = nn.Conv3d(
            mid_plane * 4,
            out_plane,
            kernel_size=(1, 3, 3),
            padding=(0, 1, 1),
            bias=False)

    def forward(self, x):
        """Run all branches on ``x``, concatenate, and fuse with a ReLU."""
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        fused = torch.cat(branches, dim=1)
        return F.relu(self.conv1(fused), inplace=True)
class SELayerS(nn.Module):
    """Squeeze-and-excitation over channels, jointly across 2 time steps.

    The squeeze pools each of the two frames separately (AdaptiveAvgPool3d
    output (2, 1, 1)), so the excitation MLP sees ``2 * channel`` inputs
    and emits one sigmoid gate per (channel, frame) pair.
    """

    def __init__(self, channel, reduction=16):
        super(SELayerS, self).__init__()
        # 2 is the (fixed) time axis length, doubling the descriptor size.
        squeezed = channel * 2
        self.avg_pool = nn.AdaptiveAvgPool3d((2, 1, 1))
        self.fc = nn.Sequential(
            nn.Linear(squeezed, squeezed // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed // reduction, squeezed, bias=False),
            nn.Sigmoid())

    def forward(self, x):
        b, c = x.size(0), x.size(1)
        gate = self.avg_pool(x).view(b, 2 * c)
        gate = self.fc(gate).view(b, c, 2, 1, 1)
        return x * gate.expand_as(x)
class SEBasicBlock(nn.Module):
    """Residual block with instance norm and temporal SE recalibration.

    Both convolutions use (1, 3, 3) kernels, so only spatial context is
    mixed; the ``SELayerS`` stage couples the two time steps.
    Layout adapted from https://github.com/moskomule/senet.pytorch
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 reduction=8):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv3d(
            inplanes, planes, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        self.in1 = nn.InstanceNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        self.in2 = nn.InstanceNorm3d(planes)
        self.ses = SELayerS(planes, reduction)
        self.in3 = nn.InstanceNorm3d(planes)
        # Optional projection applied to the identity path.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        out = self.in3(self.ses(out))
        out = out + identity
        return self.relu(out)
class SAM(nn.Module):
    """
    Spatio-temporal aggregation module (SAM)

    Pipeline: non-local attention -> optional stack of SE residual
    blocks -> 3D ASPP refinement with a residual connection.

    Args:
        indim: number of input channels.
        outdim: number of output channels; defaults to ``indim``. It must
            equal ``indim`` for the residual add in ``forward`` to be valid.
        repeat: number of ``SEBasicBlock``s to stack (0 disables them).
        norm: accepted for interface compatibility; currently unused.
    """

    def __init__(self, indim, outdim=None, repeat=0, norm=False):
        super(SAM, self).__init__()
        self.indim = indim
        self.repeat = repeat
        if outdim is None:
            outdim = indim
        # se_block only exists when repeat > 0; forward() guards on repeat.
        if repeat > 0:
            self.se_block = self.seRepeat(repeat)
        self.conv1 = ASPP3D(indim, outdim, reduction=4)  # norm is 4
        self.non_local = NONLocalBlock3D(indim, bn_layer=False)

    def seRepeat(self, repeat=2):
        """Return ``repeat`` chained SE residual blocks.

        ``nn.Sequential`` registers its children itself, so the extra
        ``nn.ModuleList`` wrapper of the original is unnecessary.
        """
        return nn.Sequential(
            *[SEBasicBlock(self.indim, self.indim) for _ in range(repeat)])

    def forward(self, x):
        x = self.non_local(x)
        if self.repeat > 0:
            x = self.se_block(x)
        # Residual connection around the ASPP refinement.
        return self.conv1(x) + x
class MemCrompress(nn.Module):
    """Compress a 2-frame memory bank (keys and values) into 1 frame.

    Each stream is refined by a SAM block, then squeezed along time with
    a (2, 3, 3) convolution (valid in time, 'same' in space).
    """

    def __init__(self, repeat=0, norm=True):
        super().__init__()
        self.key_encoder = SAM(64, 64, repeat=repeat, norm=norm)
        self.value_encoder = SAM(512, 512, repeat=repeat, norm=norm)
        self.compress_key = nn.Conv3d(
            64, 64, kernel_size=(2, 3, 3), padding=(0, 1, 1))
        self.compress_value = nn.Conv3d(
            512, 512, kernel_size=(2, 3, 3), padding=(0, 1, 1))

    def forward(self, key, value):
        """Fuse two memory frames into one.

        Args:
            key: (N, C, T, H, W) with T == 2, e.g. [4, 64, 2, 24, 24].
            value: (N, O, C, T, H, W), one value stream per object O.

        Returns:
            Tuple ``(k, v)`` of shapes (N, C, 1, H, W) and
            (N, O, C, 1, H, W).
        """
        N, O, C, T, H, W = value.shape
        # Fold objects into the batch so the 3D convs see 5D input,
        # e.g. [8, 512, 2, 24, 24].
        flat_value = value.flatten(start_dim=0, end_dim=1)
        k = self.compress_key(self.key_encoder(key))
        v = self.compress_value(self.value_encoder(flat_value))
        return k, v.view(N, O, C, 1, H, W)

View File

@@ -0,0 +1,174 @@
# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022
# under MIT License
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.cv.video_object_segmentation.modules import (
KeyEncoder, KeyProjection, MemCrompress, ResBlock, UpsampleBlock,
ValueEncoder, ValueEncoderSO)
class Decoder(nn.Module):
    """Mask decoder: fuse 1/16-scale features and upsample to full size."""

    def __init__(self):
        super().__init__()
        self.compress = ResBlock(1024, 512)
        self.up_16_8 = UpsampleBlock(512, 512, 256)  # 1/16 -> 1/8
        self.up_8_4 = UpsampleBlock(256, 256, 256)  # 1/8 -> 1/4
        self.pred = nn.Conv2d(
            256, 1, kernel_size=(3, 3), padding=(1, 1), stride=1)

    def forward(self, f16, f8, f4):
        """Decode a single-channel logit map at the input resolution."""
        feat = self.compress(f16)
        feat = self.up_16_8(f8, feat)
        feat = self.up_8_4(f4, feat)
        logits = self.pred(F.relu(feat))
        # 1/4 scale -> full resolution.
        return F.interpolate(
            logits, scale_factor=4, mode='bilinear', align_corners=False)
class MemoryReader(nn.Module):
    """Affinity-based read from a space-time memory (STM-style)."""

    def __init__(self):
        super().__init__()

    def get_affinity(self, mk, qk):
        """Softmax attention weights from memory keys to query keys.

        Args:
            mk: memory keys (B, CK, T, H, W).
            qk: query keys (B, CK, H, W).

        Returns:
            (B, THW, HW) attention summing to 1 over the memory axis.
        """
        B, CK, T, H, W = mk.shape
        mk = mk.flatten(start_dim=2)
        qk = qk.flatten(start_dim=2)

        # Negative squared distance up to a query-only term: the dropped
        # qk^2 term is constant per column and cancels in the softmax.
        mem_sq = mk.pow(2).sum(1).unsqueeze(2)
        dot = 2 * (mk.transpose(1, 2) @ qk)
        affinity = (dot - mem_sq) / math.sqrt(CK)  # B, THW, HW

        # Numerically stable softmax over the memory dimension, written
        # out explicitly to align with the evaluation-style code.
        maxes = torch.max(affinity, dim=1, keepdim=True)[0]
        x_exp = torch.exp(affinity - maxes)
        return x_exp / x_exp.sum(dim=1, keepdim=True)

    def readout(self, affinity, mv, qv):
        """Weighted sum of memory values, concatenated with query values.

        Args:
            affinity: (B, THW, HW) weights from :meth:`get_affinity`.
            mv: memory values (B, CV, T, H, W).
            qv: query values (B, CQ, H, W).

        Returns:
            (B, CV + CQ, H, W) read-out features.
        """
        B, CV, T, H, W = mv.shape
        mo = mv.view(B, CV, T * H * W)
        mem = torch.bmm(mo, affinity)  # Weighted-sum B, CV, HW
        mem = mem.view(B, CV, H, W)
        return torch.cat([mem, qv], dim=1)
class RDE_VOS(nn.Module):
    """Recurrent Dynamic Embedding VOS network (RDE-VOS, CVPR 2022).

    Bundles the key/value encoders, memory compression, memory reader
    and mask decoder; ``forward`` dispatches on a string ``mode`` so the
    inference loop can drive the model step by step.
    """

    def __init__(self, single_object, repeat=0, norm=False):
        super().__init__()
        # Single-object mode changes the value encoder and the segment()
        # output arity.
        self.single_object = single_object

        self.key_encoder = KeyEncoder()
        if single_object:
            self.value_encoder = ValueEncoderSO()
        else:
            self.value_encoder = ValueEncoder()

        # Compress memory bank
        self.mem_compress = MemCrompress(repeat=repeat, norm=norm)
        # self.mem_compress.train(True)

        # Projection from f16 feature space to key space
        self.key_proj = KeyProjection(1024, keydim=64)

        # Compress f16 a bit to use in decoding later on
        self.key_comp = nn.Conv2d(1024, 512, kernel_size=3, padding=1)

        self.memory = MemoryReader()
        self.decoder = Decoder()

    def aggregate(self, prob):
        """Merge per-object probabilities into (background + objects) logits.

        Background probability is the product of (1 - p_i) over objects;
        probabilities are clamped to (1e-7, 1 - 1e-7) before the logit
        transform to avoid infinities.
        """
        # During training, torch.prod work on channel dimension.
        new_prob = torch.cat([torch.prod(1 - prob, dim=1, keepdim=True), prob],
                             1).clamp(1e-7, 1 - 1e-7)
        logits = torch.log((new_prob / (1 - new_prob)))
        return logits

    def encode_key(self, frame):
        """Encode frames into keys and multi-scale features.

        Returns ``(k16, f16_thin, f16, f8, f4)``; k16 is laid out
        B*C*T*H*W, the feature maps B*T*C*H*W.
        """
        # input: b*t*c*h*w
        b, t = frame.shape[:2]

        # Batch and time are flattened together for the 2D backbone.
        f16, f8, f4 = self.key_encoder(frame.flatten(start_dim=0, end_dim=1))
        k16 = self.key_proj(f16)
        f16_thin = self.key_comp(f16)

        # B*C*T*H*W
        k16 = k16.view(b, t, *k16.shape[-3:]).transpose(1, 2).contiguous()

        # B*T*C*H*W
        f16_thin = f16_thin.view(b, t, *f16_thin.shape[-3:])
        f16 = f16.view(b, t, *f16.shape[-3:])
        f8 = f8.view(b, t, *f8.shape[-3:])
        f4 = f4.view(b, t, *f4.shape[-3:])

        return k16, f16_thin, f16, f8, f4

    def encode_value(self, frame, kf16, mask, other_mask=None):
        """Extract memory key/value for a frame.

        ``other_mask`` (the union of the other objects' masks) is only
        used by the multi-object value encoder.
        """
        # Extract memory key/value for a frame
        if self.single_object:
            f16 = self.value_encoder(frame, kf16, mask)
        else:
            f16 = self.value_encoder(frame, kf16, mask, other_mask)
        return f16.unsqueeze(2)  # B*512*T*H*W

    def segment(self, qk16, qv16, qf8, qf4, mk16, mv16, selector=None):
        """Segment the query frame by reading from memory.

        ``selector`` is required in multi-object mode (it gates the two
        object probabilities); it is unused in single-object mode.
        Returns logits, per-object probabilities and the read-out
        feature(s) (one tensor when single_object, else two).
        """
        # q - query, m - memory
        # qv16 is f16_thin above
        affinity = self.memory.get_affinity(mk16, qk16)

        if self.single_object:
            mv2qv = self.memory.readout(affinity, mv16, qv16)
            logits = self.decoder(mv2qv, qf8, qf4)
            prob = torch.sigmoid(logits)
        else:
            # Two objects are decoded independently with a shared decoder.
            mv2qv_o1 = self.memory.readout(affinity, mv16[:, 0], qv16)
            mv2qv_o2 = self.memory.readout(affinity, mv16[:, 1], qv16)
            logits = torch.cat([
                self.decoder(mv2qv_o1, qf8, qf4),
                self.decoder(mv2qv_o2, qf8, qf4),
            ], 1)
            prob = torch.sigmoid(logits)
            prob = prob * selector.unsqueeze(2).unsqueeze(2)

        # Background-aware aggregation; drop channel 0 (background).
        logits = self.aggregate(prob)
        prob = F.softmax(logits, dim=1)[:, 1:]

        if self.single_object:
            return logits, prob, mv2qv
        else:
            return logits, prob, mv2qv_o1, mv2qv_o2

    def memCrompress(self, key, value):
        """Compress a 2-frame memory bank into 1 frame (see MemCrompress)."""
        return self.mem_compress(key, value)

    def forward(self, mode, *args, **kwargs):
        """Dispatch to one of the sub-routines by string ``mode``.

        Supported modes: 'encode_key', 'encode_value', 'segment',
        'compress'; anything else raises NotImplementedError.
        """
        if mode == 'encode_key':
            return self.encode_key(*args, **kwargs)
        elif mode == 'encode_value':
            return self.encode_value(*args, **kwargs)
        elif mode == 'segment':
            return self.segment(*args, **kwargs)
        elif mode == 'compress':
            return self.memCrompress(*args, **kwargs)
        else:
            raise NotImplementedError

View File

@@ -839,7 +839,13 @@ TASK_OUTPUTS = {
# {
# 'scores': [0.1, 0.2, 0.3, ...]
# }
Tasks.translation_evaluation: [OutputKeys.SCORES]
Tasks.translation_evaluation: [OutputKeys.SCORES],
# video object segmentation result for a single video
# {
# "masks": [np.array # 3D array with shape [frame_num, height, width]]
# }
Tasks.video_object_segmentation: [OutputKeys.MASKS],
}

View File

@@ -226,6 +226,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.translation_evaluation:
(Pipelines.translation_evaluation,
'damo/nlp_unite_mup_translation_evaluation_multilingual_large'),
Tasks.video_object_segmentation:
(Pipelines.video_object_segmentation,
'damo/cv_rdevos_video-object-segmentation'),
}

View File

@@ -67,6 +67,7 @@ if TYPE_CHECKING:
from .hand_static_pipeline import HandStaticPipeline
from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline
from .language_guided_video_summarization_pipeline import LanguageGuidedVideoSummarizationPipeline
from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline
else:
_import_structure = {
@@ -153,7 +154,10 @@ else:
],
'language_guided_video_summarization_pipeline': [
'LanguageGuidedVideoSummarizationPipeline'
]
],
'video_object_segmentation_pipeline': [
'VideoObjectSegmentationPipeline'
],
}
import sys

View File

@@ -0,0 +1,151 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from modelscope.metainfo import Pipelines
from modelscope.models.cv.video_object_segmentation.inference_core import \
InferenceCore
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
# ImageNet mean/std normalization applied to every input frame.
im_normalization = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def unpad(img, pad):
    """Crop away padding previously added to a 4D tensor.

    Args:
        img: tensor of shape (B, C, H, W).
        pad: (left, right, top, bottom) pad widths (``F.pad`` order).

    Returns:
        ``img`` cropped back to its unpadded size (a view, no copy).
    """
    # Compute explicit end bounds instead of negative ones: with e.g.
    # pad = (0, 0, 2, 0) the original ``img[:, :, pad[2]:-pad[3], :]``
    # became a ``2:-0`` == ``2:0`` slice, i.e. an empty tensor.
    left, right, top, bottom = pad
    if top + bottom > 0:
        img = img[:, :, top:img.shape[2] - bottom, :]
    if left + right > 0:
        img = img[:, :, :, left:img.shape[3] - right]
    return img
def all_to_onehot(masks, labels):
    """Convert an integer label mask into stacked per-label binary masks.

    Args:
        masks: integer array of shape (H, W) or (T, H, W) holding label
            ids (0 is background).
        labels: sequence of label ids to extract.

    Returns:
        uint8 array of shape (len(labels), *masks.shape) where slice k
        is the indicator mask of ``labels[k]``.
    """
    # A single rank-generic allocation replaces the original duplicated
    # 2D/3D branches (and works for any mask rank).
    Ms = np.zeros((len(labels), ) + tuple(masks.shape), dtype=np.uint8)
    for k, label in enumerate(labels):
        Ms[k] = (masks == label).astype(np.uint8)
    return Ms
@PIPELINES.register_module(
    Tasks.video_object_segmentation,
    module_name=Pipelines.video_object_segmentation)
class VideoObjectSegmentationPipeline(Pipeline):
    """Semi-supervised video object segmentation pipeline (RDE-VOS).

    Takes a list of RGB frames plus a palette ('P' mode) mask for the
    first frame, and propagates the mask through the whole clip.
    """

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create video_object_segmentation pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        logger.info('load model done')

        # Frames: tensor conversion + ImageNet normalization + bicubic
        # resize to 480 (shorter side); masks use nearest-neighbor resize
        # so label ids are preserved.
        self.im_transform = transforms.Compose([
            transforms.ToTensor(),
            im_normalization,
            transforms.Resize(480, interpolation=Image.BICUBIC),
        ])
        self.mask_transform = transforms.Compose([
            transforms.Resize(480, interpolation=Image.NEAREST),
        ])

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Transform frames and the first-frame mask into model tensors.

        Expects ``input['images']`` (list of PIL images) and
        ``input['mask']`` (palette PIL image for frame 0).
        """
        self.images = input['images']
        self.mask = input['mask']
        frames = len(self.images)
        shape = np.shape(self.mask)

        info = {}
        info['name'] = 'maas_test_video'
        info['frames'] = frames
        info['size'] = shape  # Real sizes
        info['gt_obj'] = {}  # Frames with labelled objects

        images = []
        masks = []
        for i in range(frames):
            img = self.images[i]
            images.append(self.im_transform(img))

            # NOTE(review): the same first-frame annotation is re-read
            # and appended for every frame, so `masks` holds `frames`
            # copies of it; forward() only consumes msk[:, 0]. Also,
            # `palette` is loop-carried — with zero frames it would be
            # unbound at the result dict below.
            palette = self.mask.getpalette()
            masks.append(np.array(self.mask, dtype=np.uint8))

            this_labels = np.unique(masks[-1])
            this_labels = this_labels[this_labels != 0]
            info['gt_obj'][i] = this_labels

        images = torch.stack(images, 0)
        masks = np.stack(masks, 0)

        # One binary channel per non-background label.
        labels = np.unique(masks).astype(np.uint8)
        labels = labels[labels != 0]

        masks = torch.from_numpy(all_to_onehot(masks, labels)).float()
        # Resize to 480p
        masks = self.mask_transform(masks)
        masks = masks.unsqueeze(2)

        info['labels'] = labels

        result = {
            'rgb': images,
            'gt': masks,
            'info': info,
            'palette': np.array(palette),
        }
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Propagate the first-frame mask and return per-frame label maps.

        Returns ``{OutputKeys.MASKS: uint8 array (frame_num, H, W)}`` at
        the original input resolution.
        """
        rgb = input['rgb'].unsqueeze(0)
        msk = input['gt']
        info = input['info']
        k = len(info['labels'])
        size = info['size']
        is_cuda = rgb.is_cuda

        processor = InferenceCore(
            self.model.model, is_cuda, rgb, k, top_k=20, mem_every=4)
        # Seed with the frame-0 annotation and run over the whole clip.
        processor.interact(msk[:, 0], 0, rgb.shape[1])

        # Do unpad -> upsample to original size
        out_masks = torch.zeros((processor.t, 1, *size),
                                dtype=torch.uint8,
                                device='cuda' if is_cuda else 'cpu')
        for ti in range(processor.t):
            prob = unpad(processor.prob[:, ti], processor.pad)
            prob = F.interpolate(
                prob, tuple(size), mode='bilinear', align_corners=False)
            # argmax over the (background + objects) channel axis.
            out_masks[ti] = torch.argmax(prob, dim=0)

        if is_cuda:
            out_masks = (out_masks.detach().cpu().numpy()[:,
                                                          0]).astype(np.uint8)
        else:
            out_masks = (out_masks.detach().numpy()[:, 0]).astype(np.uint8)

        return {OutputKeys.MASKS: out_masks}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # forward() already emits the final output format.
        return inputs

View File

@@ -89,6 +89,7 @@ class CVTasks(object):
language_guided_video_summarization = 'language-guided-video-summarization'
# video segmentation
video_object_segmentation = 'video-object-segmentation'
referring_video_object_segmentation = 'referring-video-object-segmentation'
video_human_matting = 'video-human-matting'

View File

@@ -1,8 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from modelscope.outputs import OutputKeys
from modelscope.preprocessors.image import load_image
@@ -478,3 +481,12 @@ def depth_to_color(depth):
(depth.max() - depth) / depth.max()) * 2**8).astype(np.uint8)[:, :, :3]
depth_color = cv2.cvtColor(depth_color, cv2.COLOR_RGB2BGR)
return depth_color
def masks_visualization(masks, palette):
    """Render uint8 label masks as palettized PIL images.

    Args:
        masks: array of shape (frame_num, height, width) with label ids.
        palette: flat PIL palette sequence attached to every image.

    Returns:
        List of ``PIL.Image`` objects, one per frame.
    """
    vis_masks = []
    for frame in masks:
        img = Image.fromarray(frame)
        img.putpalette(palette)
        vis_masks.append(img)
    return vis_masks

View File

@@ -0,0 +1,51 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from PIL import Image
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import masks_visualization
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class VideoObjectSegmentationTest(unittest.TestCase, DemoCompatibilityCheck):
    """End-to-end test for the RDE-VOS video object segmentation pipeline."""

    def setUp(self) -> None:
        # Test clip layout: JPEGImages/ holds the RGB frames and
        # Annotations/00000.png is the first-frame palette mask.
        self.task = 'video-object-segmentation'
        self.model_id = 'damo/cv_rdevos_video-object-segmentation'
        self.input_location = 'data/test/videos/video_object_segmentation_test'
        self.images_dir = os.path.join(self.input_location, 'JPEGImages')
        self.mask_file = os.path.join(self.input_location, 'Annotations',
                                      '00000.png')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_video_object_segmentation(self):
        """Run the pipeline on the demo clip and save colorized masks."""
        # Frames are loaded in lexicographic order, matching frame order.
        input_images = []
        for image_file in sorted(os.listdir(self.images_dir)):
            img = Image.open(os.path.join(self.images_dir, image_file))\
                .convert('RGB')
            input_images.append(img)
        mask = Image.open(self.mask_file).convert('P')

        input = {'images': input_images, 'mask': mask}
        segmentor = pipeline(
            Tasks.video_object_segmentation, model=self.model_id)
        result = segmentor(input)
        out_masks = result[OutputKeys.MASKS]

        # Re-attach the first-frame palette so the saved PNGs are colorized.
        vis_masks = masks_visualization(out_masks, mask.getpalette())
        os.makedirs('test_result', exist_ok=True)
        for f, vis_mask in enumerate(vis_masks):
            vis_mask.save(os.path.join('test_result', '{:05d}.png'.format(f)))
        print('test_video_object_segmentation DONE')


if __name__ == '__main__':
    unittest.main()