mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
add support for cv_rdevos_video-object-segmentation
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11066863
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3044047d7452f27ad5391ecfc83f1a366a5a82e3eb8c8c151a6bf02cbb37c046
|
||||
size 5366
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:401ebc9ffc8ba5f47ef31f96ebd6a2e7f82b976c49e9a36f1915f5b9202f3d38
|
||||
size 45113
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:24a6702646a071988aa3a77a40d92cfc2c3ada00ecef5afeaf9c475272f30af1
|
||||
size 50254
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:37071e9ad7c91af44b2ed9b6d8e3ea42c69f284352d4431ef680b4de01521261
|
||||
size 51068
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a019761196033e0571c40ccadfdc2eff454e40207e16c3de66a1e322ec727227
|
||||
size 50164
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8cd500e4f54ac5b85c9eb2a8b7d8706a061ce7ff80502d29850e641cd59a9e9f
|
||||
size 51056
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fd943f6fb068f7ece7a009f729ed356dd09295ab4d19b0da858256b0259bf467
|
||||
size 51580
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:554e1725877410cd7ba7d24e151a8b7572eb88ed0d8760a70b15fc20843cda76
|
||||
size 50718
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a912a376204cdc39e128a0b1d777bc73b88f1f6b6e9d16a28755cf4a89fcfe87
|
||||
size 51000
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1e83de2980b89b971c3894e102a27a0684307a9e0c841775ce9419f2f5ab43d7
|
||||
size 50267
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3cd21580132bcbc401cd443e0b0c5bdf6f3fcbfec91e0e767cabc4f6480185a9
|
||||
size 50328
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2a1effa23c00b45e8b516dfa64e376756b9600fe94c43be7b2ae6331e5041474
|
||||
size 50885
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d0f16febf7ef8739393f16cab88e2da00aa3b1a60674ef33cbec7c3dd75584ed
|
||||
size 51075
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cc1ffd892d12d82c10331039bc998822c9a601a1b80fb60d367f5532e3b5958f
|
||||
size 49797
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cae709a0b5216f469591173beecd05d1e7786409dec6b41323751312c21ffbe8
|
||||
size 51894
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f2f650e5cc18dcb058da04822181d4a25eb85329d6daf1da395c88e9f537760a
|
||||
size 52597
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1f88d10b8ebc1f8b6da16faf5032a1b49ba1dab94854319637fb390a201af614
|
||||
size 51468
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cc914fbab93a722b2d263e36f713427ae07d82b7aa52e2d189f7517fa6c4bf1b
|
||||
size 51824
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:79f27be589e2ac046d52bcc0d927e7aef8c2e1562dbb07d63aa0b7a50e0d6b90
|
||||
size 52512
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2db022e12a83b900afc12493474b2b078f373d864d64f8babbf3c8cd50b62978
|
||||
size 50119
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:402b7a8fcef26bbba5eebfc5fcc141b3a5a1a5fd06ad61ee6d3150c3beeb3824
|
||||
size 50961
|
||||
@@ -58,6 +58,7 @@ class Models(object):
|
||||
product_segmentation = 'product-segmentation'
|
||||
image_body_reshaping = 'image-body-reshaping'
|
||||
video_human_matting = 'video-human-matting'
|
||||
video_object_segmentation = 'video-object-segmentation'
|
||||
|
||||
# EasyCV models
|
||||
yolox = 'YOLOX'
|
||||
@@ -243,6 +244,7 @@ class Pipelines(object):
|
||||
image_body_reshaping = 'flow-based-body-reshaping'
|
||||
referring_video_object_segmentation = 'referring-video-object-segmentation'
|
||||
video_human_matting = 'video-human-matting'
|
||||
video_object_segmentation = 'video-object-segmentation'
|
||||
|
||||
# nlp tasks
|
||||
automatic_post_editing = 'automatic-post-editing'
|
||||
|
||||
@@ -14,6 +14,7 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
|
||||
object_detection, product_retrieval_embedding,
|
||||
realtime_object_detection, referring_video_object_segmentation,
|
||||
salient_detection, shop_segmentation, super_resolution,
|
||||
video_single_object_tracking, video_summarization, virual_tryon)
|
||||
video_object_segmentation, video_single_object_tracking,
|
||||
video_summarization, virual_tryon)
|
||||
|
||||
# yapf: enable
|
||||
|
||||
20
modelscope/models/cv/video_object_segmentation/__init__.py
Normal file
20
modelscope/models/cv/video_object_segmentation/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import VideoObjectSegmentation
|
||||
|
||||
else:
|
||||
_import_structure = {'model': ['VideoObjectSegmentation']}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
29
modelscope/models/cv/video_object_segmentation/aggregate.py
Normal file
29
modelscope/models/cv/video_object_segmentation/aggregate.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022
|
||||
# under MIT License
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# Soft aggregation from STM
|
||||
def aggregate(prob, keep_bg=False):
|
||||
# Caclulate the probability of the background.
|
||||
background_prob = torch.prod(1 - prob, dim=0, keepdim=True)
|
||||
# Concatenate the probabilities of background and foreground objects.
|
||||
new_prob = torch.cat([background_prob, prob], 0).clamp(1e-7, 1 - 1e-7)
|
||||
|
||||
# logit function
|
||||
logits = torch.log((new_prob / (1 - new_prob)))
|
||||
|
||||
if keep_bg:
|
||||
return F.softmax(logits, dim=0)
|
||||
else:
|
||||
return F.softmax(logits, dim=0)[1:]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
prob = torch.randn(size=(1, 2, 1, 1))
|
||||
prob = torch.sigmoid(prob)
|
||||
new = aggregate(prob, keep_bg=True)
|
||||
print(prob)
|
||||
print(new)
|
||||
123
modelscope/models/cv/video_object_segmentation/cbam.py
Normal file
123
modelscope/models/cv/video_object_segmentation/cbam.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# The implementation is modified from CBAM
|
||||
# under the MIT License
|
||||
# https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class BasicConv(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True):
|
||||
super(BasicConv, self).__init__()
|
||||
self.out_channels = out_planes
|
||||
self.conv = nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
class ChannelGate(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
gate_channels,
|
||||
reduction_ratio=16,
|
||||
pool_types=['avg', 'max']):
|
||||
super(ChannelGate, self).__init__()
|
||||
self.gate_channels = gate_channels
|
||||
self.mlp = nn.Sequential(
|
||||
Flatten(),
|
||||
nn.Linear(gate_channels, gate_channels // reduction_ratio),
|
||||
nn.ReLU(),
|
||||
nn.Linear(gate_channels // reduction_ratio, gate_channels))
|
||||
self.pool_types = pool_types
|
||||
|
||||
def forward(self, x):
|
||||
channel_att_sum = None
|
||||
for pool_type in self.pool_types:
|
||||
if pool_type == 'avg':
|
||||
avg_pool = F.avg_pool2d(
|
||||
x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
|
||||
channel_att_raw = self.mlp(avg_pool)
|
||||
elif pool_type == 'max':
|
||||
max_pool = F.max_pool2d(
|
||||
x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
|
||||
channel_att_raw = self.mlp(max_pool)
|
||||
|
||||
if channel_att_sum is None:
|
||||
channel_att_sum = channel_att_raw
|
||||
else:
|
||||
channel_att_sum = channel_att_sum + channel_att_raw
|
||||
|
||||
scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(
|
||||
3).expand_as(x)
|
||||
return x * scale
|
||||
|
||||
|
||||
class ChannelPool(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return torch.cat(
|
||||
(torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)),
|
||||
dim=1)
|
||||
|
||||
|
||||
class SpatialGate(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(SpatialGate, self).__init__()
|
||||
kernel_size = 7
|
||||
self.compress = ChannelPool()
|
||||
self.spatial = BasicConv(
|
||||
2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
|
||||
|
||||
def forward(self, x):
|
||||
x_compress = self.compress(x)
|
||||
x_out = self.spatial(x_compress)
|
||||
scale = torch.sigmoid(x_out) # broadcasting
|
||||
return x * scale
|
||||
|
||||
|
||||
class CBAM(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
gate_channels,
|
||||
reduction_ratio=16,
|
||||
pool_types=['avg', 'max'],
|
||||
no_spatial=False):
|
||||
super(CBAM, self).__init__()
|
||||
self.ChannelGate = ChannelGate(gate_channels, reduction_ratio,
|
||||
pool_types)
|
||||
self.no_spatial = no_spatial
|
||||
if not no_spatial:
|
||||
self.SpatialGate = SpatialGate()
|
||||
|
||||
def forward(self, x):
|
||||
x_out = self.ChannelGate(x)
|
||||
if not self.no_spatial:
|
||||
x_out = self.SpatialGate(x_out)
|
||||
return x_out
|
||||
@@ -0,0 +1,62 @@
|
||||
# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022
|
||||
# under MIT License
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modelscope.models.cv.video_object_segmentation.modules import (
|
||||
KeyEncoder, KeyProjection, MemCrompress, ValueEncoder)
|
||||
from modelscope.models.cv.video_object_segmentation.network import Decoder
|
||||
|
||||
|
||||
class RDE_VOS(nn.Module):
|
||||
|
||||
def __init__(self, repeat=0):
|
||||
super().__init__()
|
||||
self.key_encoder = KeyEncoder()
|
||||
self.value_encoder = ValueEncoder()
|
||||
|
||||
# Projection from f16 feature space to key space
|
||||
self.key_proj = KeyProjection(1024, keydim=64)
|
||||
|
||||
# Compress f16 a bit to use in decoding later on
|
||||
self.key_comp = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
|
||||
|
||||
self.decoder = Decoder()
|
||||
self.mem_compress = MemCrompress(repeat=repeat)
|
||||
|
||||
def encode_value(self, frame, kf16, masks):
|
||||
k, _, h, w = masks.shape
|
||||
|
||||
# Extract memory key/value for a frame with multiple masks
|
||||
frame = frame.view(1, 3, h, w).repeat(k, 1, 1, 1)
|
||||
# Compute the "others" mask
|
||||
if k != 1:
|
||||
others = torch.cat([
|
||||
torch.sum(
|
||||
masks[[j for j in range(k) if i != j]],
|
||||
dim=0,
|
||||
keepdim=True) for i in range(k)
|
||||
], 0)
|
||||
else:
|
||||
others = torch.zeros_like(masks)
|
||||
|
||||
f16 = self.value_encoder(frame, kf16.repeat(k, 1, 1, 1), masks, others)
|
||||
|
||||
return f16.unsqueeze(2)
|
||||
|
||||
def encode_key(self, frame):
|
||||
f16, f8, f4 = self.key_encoder(frame)
|
||||
k16 = self.key_proj(f16)
|
||||
f16_thin = self.key_comp(f16)
|
||||
|
||||
return k16, f16_thin, f16, f8, f4
|
||||
|
||||
def segment_with_query(self, mem_bank, qf8, qf4, qk16, qv16):
|
||||
k = mem_bank.num_objects
|
||||
|
||||
readout_mem = mem_bank.match_memory(qk16)
|
||||
qv16 = qv16.expand(k, -1, -1, -1)
|
||||
qv16 = torch.cat([readout_mem, qv16], 1)
|
||||
|
||||
return torch.sigmoid(self.decoder(qv16, qf8, qf4))
|
||||
128
modelscope/models/cv/video_object_segmentation/inference_core.py
Normal file
128
modelscope/models/cv/video_object_segmentation/inference_core.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022
|
||||
# under MIT License
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.cv.video_object_segmentation.aggregate import aggregate
|
||||
from modelscope.models.cv.video_object_segmentation.inference_memory_bank import \
|
||||
MemoryBank
|
||||
|
||||
|
||||
def pad_divide_by(in_img, d, in_size=None):
|
||||
if in_size is None:
|
||||
h, w = in_img.shape[-2:]
|
||||
else:
|
||||
h, w = in_size
|
||||
|
||||
if h % d > 0:
|
||||
new_h = h + d - h % d
|
||||
else:
|
||||
new_h = h
|
||||
if w % d > 0:
|
||||
new_w = w + d - w % d
|
||||
else:
|
||||
new_w = w
|
||||
lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
|
||||
lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
|
||||
pad_array = (int(lw), int(uw), int(lh), int(uh))
|
||||
out = F.pad(in_img, pad_array)
|
||||
return out, pad_array
|
||||
|
||||
|
||||
class InferenceCore:
|
||||
|
||||
def __init__(self,
|
||||
prop_net,
|
||||
is_cuda,
|
||||
images,
|
||||
num_objects,
|
||||
top_k=20,
|
||||
mem_every=5,
|
||||
include_last=False):
|
||||
self.prop_net = prop_net
|
||||
self.is_cuda = is_cuda
|
||||
self.mem_every = mem_every
|
||||
self.include_last = include_last
|
||||
|
||||
# True dimensions
|
||||
t = images.shape[1]
|
||||
h, w = images.shape[-2:]
|
||||
|
||||
# Pad each side to multiple of 16
|
||||
images, self.pad = pad_divide_by(images, 16)
|
||||
# Padded dimensions
|
||||
nh, nw = images.shape[-2:]
|
||||
|
||||
self.images = images
|
||||
if self.is_cuda:
|
||||
self.device = 'cuda'
|
||||
else:
|
||||
self.device = 'cpu'
|
||||
|
||||
self.k = num_objects
|
||||
|
||||
# Background included, not always consistent (i.e. sum up to 1)
|
||||
self.prob = torch.zeros((self.k + 1, t, 1, nh, nw),
|
||||
dtype=torch.float32,
|
||||
device=self.device)
|
||||
self.prob[0] = 1e-7
|
||||
|
||||
self.t, self.h, self.w = t, h, w
|
||||
self.nh, self.nw = nh, nw
|
||||
self.kh = self.nh // 16
|
||||
self.kw = self.nw // 16
|
||||
|
||||
self.mem_bank = MemoryBank(
|
||||
prop_net.mem_compress,
|
||||
k=self.k,
|
||||
top_k=top_k,
|
||||
mode='two-frames-compress')
|
||||
# Compress memory bank
|
||||
|
||||
def encode_key(self, idx):
|
||||
|
||||
result = self.prop_net.encode_key(self.images[:, idx])
|
||||
return result
|
||||
|
||||
def do_pass(self, first_k, first_v, idx, end_idx):
|
||||
global tt1, tt2, tt3, tt4
|
||||
self.mem_bank.add_memory(first_k, first_v)
|
||||
closest_ti = end_idx
|
||||
|
||||
# Note that we never reach closest_ti, just the frame before it
|
||||
this_range = range(idx + 1, closest_ti)
|
||||
end = closest_ti - 1
|
||||
for ti in this_range:
|
||||
k16, qv16, qf16, qf8, qf4 = self.encode_key(ti)
|
||||
|
||||
out_mask = self.prop_net.segment_with_query(
|
||||
self.mem_bank, qf8, qf4, k16, qv16)
|
||||
|
||||
out_mask = aggregate(out_mask, keep_bg=True)
|
||||
self.prob[:, ti] = out_mask
|
||||
|
||||
if ti != end:
|
||||
is_mem_frame = ((ti % self.mem_every) == 0)
|
||||
if self.include_last or is_mem_frame:
|
||||
prev_value = self.prop_net.encode_value(
|
||||
self.images[:, ti], qf16, out_mask[1:])
|
||||
prev_key = k16.unsqueeze(2)
|
||||
self.mem_bank.add_memory(
|
||||
prev_key, prev_value, is_temp=not is_mem_frame)
|
||||
return closest_ti
|
||||
|
||||
def interact(self, mask, frame_idx, end_idx):
|
||||
|
||||
mask, _ = pad_divide_by(mask, 16)
|
||||
|
||||
self.prob[:, frame_idx] = aggregate(mask, keep_bg=True)
|
||||
|
||||
# KV pair for the interacting frame
|
||||
first_k, _, qf16, _, _ = self.encode_key(frame_idx)
|
||||
first_v = self.prop_net.encode_value(self.images[:, frame_idx], qf16,
|
||||
self.prob[1:, frame_idx])
|
||||
first_k = first_k.unsqueeze(2)
|
||||
|
||||
# Propagate
|
||||
self.do_pass(first_k, first_v, frame_idx, end_idx)
|
||||
@@ -0,0 +1,255 @@
|
||||
# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022
|
||||
# under MIT License
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def softmax_w_top(x, top):
|
||||
values, indices = torch.topk(x, k=top, dim=1)
|
||||
x_exp = values.exp_()
|
||||
|
||||
x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
|
||||
x.zero_().scatter_(1, indices, x_exp) # B * THW * HW
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def make_gaussian(y_idx, x_idx, height, width, sigma=7):
|
||||
yv, xv = torch.meshgrid([torch.arange(0, height), torch.arange(0, width)])
|
||||
|
||||
yv = yv.reshape(height * width).unsqueeze(0).float().cuda()
|
||||
xv = xv.reshape(height * width).unsqueeze(0).float().cuda()
|
||||
|
||||
y_idx = y_idx.transpose(0, 1)
|
||||
x_idx = x_idx.transpose(0, 1)
|
||||
|
||||
g = torch.exp(-((yv - y_idx)**2 + (xv - x_idx)**2) / (2 * sigma**2))
|
||||
|
||||
return g
|
||||
|
||||
|
||||
def kmn(x, top=None, gauss=None):
|
||||
if top is not None:
|
||||
if gauss is not None:
|
||||
maxes = torch.max(x, dim=1, keepdim=True)[0]
|
||||
x_exp = torch.exp(x - maxes) * gauss
|
||||
x_exp, indices = torch.topk(x_exp, k=top, dim=1)
|
||||
else:
|
||||
values, indices = torch.topk(x, k=top, dim=1)
|
||||
x_exp = torch.exp(values - values[:, 0])
|
||||
|
||||
x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
|
||||
x_exp /= x_exp_sum
|
||||
x.zero_().scatter_(1, indices, x_exp) # B * THW * HW
|
||||
|
||||
output = x
|
||||
else:
|
||||
maxes = torch.max(x, dim=1, keepdim=True)[0]
|
||||
if gauss is not None:
|
||||
x_exp = torch.exp(x - maxes) * gauss
|
||||
|
||||
x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
|
||||
x_exp /= x_exp_sum
|
||||
output = x_exp
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class MemoryBank:
|
||||
|
||||
def __init__(self, compress, k, top_k=20, mode='stm'):
|
||||
self.top_k = top_k
|
||||
|
||||
self.CK = None
|
||||
self.CV = None
|
||||
|
||||
self.mem_k = None
|
||||
self.mem_v = None
|
||||
|
||||
self.num_objects = k
|
||||
self.km = 5.6
|
||||
|
||||
self.compress = compress
|
||||
|
||||
self.init_mode(mode)
|
||||
|
||||
def init_mode(self, mode):
|
||||
"""
|
||||
stm, two-frames, gt, last, compress, gt-compress,
|
||||
last-compress, two-frames-compress
|
||||
"""
|
||||
self.is_compress = None
|
||||
self.use_gt = None
|
||||
self.use_last = None
|
||||
self.stm = None
|
||||
print('mode is {}'.format(mode))
|
||||
if mode == 'stm':
|
||||
self.stm = True
|
||||
elif mode == 'two-frames':
|
||||
self.use_gt = True
|
||||
self.use_last = True
|
||||
elif mode == 'gt':
|
||||
self.use_gt = True
|
||||
elif mode == 'last':
|
||||
self.use_last = True
|
||||
elif mode == 'compress':
|
||||
self.is_compress = True
|
||||
elif mode == 'gt-compress':
|
||||
self.use_gt = True
|
||||
self.is_compress = True
|
||||
elif mode == 'last-compress':
|
||||
self.use_last = True
|
||||
self.is_compress = True
|
||||
elif mode == 'two-frames-compress':
|
||||
self.use_gt = True
|
||||
self.use_last = True
|
||||
self.is_compress = True
|
||||
else:
|
||||
raise RuntimeError('check mode!')
|
||||
|
||||
def _global_matching(self, mk, qk, H, W):
|
||||
# NE means number of elements -- typically T*H*W
|
||||
mk = mk.flatten(start_dim=2)
|
||||
qk = qk.flatten(start_dim=2)
|
||||
B, CK, NE = mk.shape
|
||||
|
||||
a = mk.pow(2).sum(1).unsqueeze(2)
|
||||
b = 2 * (mk.transpose(1, 2) @ qk)
|
||||
# We don't actually need this, will update paper later
|
||||
# c = qk.pow(2).expand(B, -1, -1).sum(1).unsqueeze(1)
|
||||
|
||||
affinity = (-a + b) / math.sqrt(CK) # B, NE, HW
|
||||
# if self.km is not None:
|
||||
# # Make a bunch of Gaussian distributions
|
||||
# argmax_idx = affinity.max(2)[1]
|
||||
# y_idx, x_idx = argmax_idx//W, argmax_idx%W
|
||||
# g = make_gaussian(y_idx, x_idx, H, W, sigma=self.km)
|
||||
# g = g.view(B, NE, H*W)
|
||||
|
||||
# affinity = kmn(affinity, top=20, gauss=g) # B, THW, HW
|
||||
affinity = softmax_w_top(affinity, top=self.top_k) # B, THW, HW
|
||||
|
||||
return affinity
|
||||
|
||||
def _readout(self, affinity, mv):
|
||||
return torch.bmm(mv, affinity)
|
||||
|
||||
def match_memory(self, qk):
|
||||
k = self.num_objects
|
||||
_, _, h, w = qk.shape
|
||||
|
||||
qk = qk.flatten(start_dim=2)
|
||||
|
||||
# use gt+last+mem
|
||||
if self.temp_k is not None and self.is_compress and self.use_last and \
|
||||
self.use_gt:
|
||||
# print("mode: gt+last+mem")
|
||||
mk = torch.cat([self.mem_k, self.temp_k, self.gt_k, self.gt_k], 2)
|
||||
# mv = torch.cat([self.mem_v, self.temp_v], 2)
|
||||
try:
|
||||
mv = torch.cat([self.mem_v, self.temp_v, self.gt_v, self.gt_v],
|
||||
2)
|
||||
except Exception:
|
||||
mv = torch.cat([
|
||||
self.mem_v,
|
||||
self.temp_v.unsqueeze(0),
|
||||
self.gt_v.unsqueeze(0),
|
||||
self.gt_v.unsqueeze(0)
|
||||
], 3)
|
||||
# use gt+last
|
||||
elif self.temp_k is not None and self.use_last and self.use_gt:
|
||||
# print("mode: gt+last")
|
||||
mk = torch.cat([self.temp_k, self.gt_k, self.gt_k], 2)
|
||||
# mv = torch.cat([self.mem_v, self.temp_v], 2)
|
||||
try:
|
||||
mv = torch.cat([self.temp_v, self.gt_v, self.gt_v], 2)
|
||||
except Exception:
|
||||
mv = torch.cat([
|
||||
self.temp_v.unsqueeze(0),
|
||||
self.gt_v.unsqueeze(0),
|
||||
self.gt_v.unsqueeze(0)
|
||||
], 3)
|
||||
# use last+mem
|
||||
elif self.temp_k is not None and self.is_compress and self.use_last:
|
||||
# print("mode: last+mem")
|
||||
mk = torch.cat([self.mem_k, self.temp_k], 2)
|
||||
try:
|
||||
mv = torch.cat([self.mem_v, self.temp_v], 2)
|
||||
except Exception:
|
||||
mv = torch.cat([self.mem_v, self.temp_v.unsqueeze(0)], 3)
|
||||
# use gt+mem
|
||||
elif self.is_compress and self.use_gt:
|
||||
# print("mode: gt+mem")
|
||||
# mk = self.mem_k
|
||||
# mv = self.mem_v
|
||||
mk = torch.cat([self.mem_k, self.gt_k], 2)
|
||||
try:
|
||||
mv = torch.cat([self.mem_v, self.gt_v], 2)
|
||||
except Exception:
|
||||
mv = torch.cat([self.mem_v, self.gt_v.unsqueeze(0)], 3)
|
||||
# use last
|
||||
elif self.temp_k is not None and self.use_last:
|
||||
# print("mode: last")
|
||||
mk = self.temp_k
|
||||
mv = self.temp_v
|
||||
# use gt or only use our embedding
|
||||
else:
|
||||
# print("mode: gt")
|
||||
# use nothing
|
||||
mk = self.mem_k
|
||||
mv = self.mem_v
|
||||
|
||||
affinity = self._global_matching(mk, qk, h, w)
|
||||
if len(mv.shape) == 6:
|
||||
mv = mv.squeeze(0)
|
||||
mv = mv.flatten(start_dim=2)
|
||||
|
||||
# One affinity for all
|
||||
readout_mem = self._readout(affinity.expand(k, -1, -1), mv)
|
||||
|
||||
return readout_mem.view(k, self.CV, h, w)
|
||||
|
||||
def add_memory(self, key, value, is_temp=False):
|
||||
# Temp is for "last frame"
|
||||
# Not always used
|
||||
# But can always be flushed
|
||||
self.temp_k = None
|
||||
self.temp_v = None
|
||||
|
||||
if self.mem_k is None:
|
||||
# First frame, just shove it in
|
||||
self.mem_k = key # gt
|
||||
self.mem_v = value
|
||||
self.CK = key.shape[1]
|
||||
self.CV = value.shape[1]
|
||||
self.gt_k = key
|
||||
self.gt_v = value
|
||||
|
||||
elif self.is_compress:
|
||||
# compress the two frames
|
||||
if len(self.mem_v.shape) == 5:
|
||||
self.mem_v = self.mem_v.unsqueeze(0)
|
||||
k = torch.cat([self.mem_k, key], 2) # [1, 64, 2, 30, 57]
|
||||
v = torch.cat([self.mem_v, value.unsqueeze(0)],
|
||||
3) # [1, 2, 512, 2, 30, 57]
|
||||
self.mem_k, self.mem_v = self.compress(k, v)
|
||||
|
||||
# check if use last frame
|
||||
if self.use_last:
|
||||
self.temp_k = key # [1, 64, 1, 30, 57]
|
||||
self.temp_v = value # [2, 512, 1, 30, 57]
|
||||
|
||||
elif self.stm:
|
||||
# stm style
|
||||
# print("stm", self.mem_k.shape)
|
||||
self.mem_k = torch.cat([self.mem_k, key], 2)
|
||||
self.mem_v = torch.cat([self.mem_v, value], 2)
|
||||
|
||||
else:
|
||||
# no compress
|
||||
# check if use last frame
|
||||
if self.use_last:
|
||||
self.temp_k = key
|
||||
self.temp_v = value
|
||||
229
modelscope/models/cv/video_object_segmentation/mod_resnet.py
Normal file
229
modelscope/models/cv/video_object_segmentation/mod_resnet.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# The implementation is modified from torchvision
|
||||
# under BSD-3-Clause License
|
||||
# https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
|
||||
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils import model_zoo
|
||||
|
||||
|
||||
def load_weights_sequential(target, source_state, extra_chan=1):
|
||||
|
||||
new_dict = OrderedDict()
|
||||
|
||||
for k1, v1 in target.state_dict().items():
|
||||
if 'num_batches_tracked' not in k1:
|
||||
if k1 in source_state:
|
||||
tar_v = source_state[k1]
|
||||
|
||||
if v1.shape != tar_v.shape:
|
||||
# Init the new segmentation channel with zeros
|
||||
# print(v1.shape, tar_v.shape)
|
||||
c, _, w, h = v1.shape
|
||||
pads = torch.zeros((c, extra_chan, w, h),
|
||||
device=tar_v.device)
|
||||
nn.init.orthogonal_(pads)
|
||||
tar_v = torch.cat([tar_v, pads], 1)
|
||||
|
||||
new_dict[k1] = tar_v
|
||||
|
||||
target.load_state_dict(new_dict, strict=False)
|
||||
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
}
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, dilation=1, bias=True):
|
||||
return nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
bias=bias,
|
||||
dilation=dilation)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
downsample=None,
|
||||
dilation=1,
|
||||
bias=True):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(
|
||||
inplanes, planes, stride=stride, dilation=dilation, bias=bias)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(
|
||||
planes, planes, stride=1, dilation=dilation, bias=bias)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
downsample=None,
|
||||
dilation=1,
|
||||
bias=True):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=dilation,
|
||||
bias=bias)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=bias)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers=(3, 4, 23, 3), extra_chan=1, bias=True):
|
||||
self.inplanes = 64
|
||||
super(ResNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(
|
||||
3 + extra_chan, 64, kernel_size=7, stride=2, padding=3, bias=bias)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0], bias=bias)
|
||||
self.layer2 = self._make_layer(
|
||||
block, 128, layers[1], stride=2, bias=bias)
|
||||
self.layer3 = self._make_layer(
|
||||
block, 256, layers[2], stride=2, bias=bias)
|
||||
self.layer4 = self._make_layer(
|
||||
block, 512, layers[3], stride=2, bias=bias)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
try:
|
||||
m.bias.data.zero_()
|
||||
except Exception:
|
||||
pass
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self,
|
||||
block,
|
||||
planes,
|
||||
blocks,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
bias=True):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
self.inplanes,
|
||||
planes * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=bias),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = [
|
||||
block(self.inplanes, planes, stride, downsample, dilation, bias)
|
||||
]
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(
|
||||
block(self.inplanes, planes, dilation=dilation, bias=bias))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def resnet18(pretrained=True, extra_chan=0):
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2], extra_chan=extra_chan)
|
||||
if pretrained and torch.distributed.is_initialized():
|
||||
local_rank = torch.distributed.get_rank()
|
||||
load_weights_sequential(
|
||||
model,
|
||||
model_zoo.load_url(
|
||||
model_urls['resnet18'],
|
||||
model_dir='pretrain/resnet18-{}'.format(local_rank)),
|
||||
extra_chan)
|
||||
return model
|
||||
|
||||
|
||||
def resnet50(pretrained=True, extra_chan=0):
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], extra_chan=extra_chan, bias=False)
|
||||
|
||||
if pretrained and torch.distributed.is_initialized():
|
||||
local_rank = torch.distributed.get_rank()
|
||||
load_weights_sequential(
|
||||
model,
|
||||
model_zoo.load_url(
|
||||
model_urls['resnet50'],
|
||||
model_dir='pretrain/resnet50-{}'.format(local_rank)),
|
||||
extra_chan)
|
||||
print(torch.distributed.get_rank(), 'resnet 50 is loading...')
|
||||
return model
|
||||
29
modelscope/models/cv/video_object_segmentation/model.py
Normal file
29
modelscope/models/cv/video_object_segmentation/model.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os.path as osp
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.cv.video_object_segmentation.eval_network import RDE_VOS
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
|
||||
|
||||
@MODELS.register_module(
    Tasks.video_object_segmentation,
    module_name=Models.video_object_segmentation)
class VideoObjectSegmentation(TorchModel):
    """ModelScope wrapper around the RDE-VOS evaluation network."""

    def __init__(self, model_dir: str, *args, **kwargs):
        """Load pretrained RDE-VOS weights from `model_dir` in eval mode."""
        super().__init__(model_dir, *args, **kwargs)
        checkpoint_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        network = RDE_VOS()
        network.load_state_dict(state_dict, strict=True)
        network.eval()
        self.model = network

    def forward(self, inputs: Dict[str, Any]):
        """Delegate inference to the wrapped RDE-VOS network."""
        return self.model(inputs)
|
||||
523
modelscope/models/cv/video_object_segmentation/modules.py
Normal file
523
modelscope/models/cv/video_object_segmentation/modules.py
Normal file
@@ -0,0 +1,523 @@
|
||||
# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022
|
||||
# under MIT License
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import models
|
||||
|
||||
from modelscope.models.cv.video_object_segmentation import cbam, mod_resnet
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Pre-activation residual block with an optional 3x3 projection shortcut."""

    def __init__(self, indim, outdim=None):
        super(ResBlock, self).__init__()
        outdim = indim if outdim is None else outdim
        # Only project the shortcut when the channel count changes.
        self.downsample = None if indim == outdim else nn.Conv2d(
            indim, outdim, kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1)

    def forward(self, x):
        residual = self.conv2(F.relu(self.conv1(F.relu(x))))
        shortcut = x if self.downsample is None else self.downsample(x)
        return shortcut + residual
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
    """Fuse a feature with the key-encoder f16 feature using CBAM attention."""

    def __init__(self, indim, outdim):
        super().__init__()
        self.block1 = ResBlock(indim, outdim)
        self.attention = cbam.CBAM(outdim)
        self.block2 = ResBlock(outdim, outdim)

    def forward(self, x, f16):
        fused = self.block1(torch.cat([x, f16], 1))
        # Residual attention refinement before the second residual block.
        return self.block2(fused + self.attention(fused))
|
||||
|
||||
|
||||
# Single object version, used only in static image pretraining.
# See model.py (load_network) for the modification procedure.
class ValueEncoderSO(nn.Module):
    """Encode (image, mask) into a 1/16-scale memory value feature."""

    def __init__(self):
        super().__init__()
        # ResNet-18 trunk with one extra input channel for the object mask.
        backbone = mod_resnet.resnet18(pretrained=False, extra_chan=1)
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu  # 1/2, 64
        self.maxpool = backbone.maxpool
        self.layer1 = backbone.layer1  # 1/4, 64
        self.layer2 = backbone.layer2  # 1/8, 128
        self.layer3 = backbone.layer3  # 1/16, 256
        self.fuser = FeatureFusionBlock(1024 + 256, 512)

    def forward(self, image, key_f16, mask):
        # key_f16 is the 1/16 feature from the key encoder.
        x = torch.cat([image, mask], 1)
        for stage in (self.conv1, self.bn1, self.relu, self.maxpool,
                      self.layer1, self.layer2, self.layer3):
            x = stage(x)
        return self.fuser(x, key_f16)
|
||||
|
||||
|
||||
# Multiple objects version, used in other times.
class ValueEncoder(nn.Module):
    """Encode (image, mask, other-object masks) into a memory value feature."""

    def __init__(self):
        super().__init__()
        # ResNet-18 trunk with two extra input channels: this object's mask
        # plus the union of the remaining objects' masks.
        backbone = mod_resnet.resnet18(pretrained=False, extra_chan=2)
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu  # 1/2, 64
        self.maxpool = backbone.maxpool
        self.layer1 = backbone.layer1  # 1/4, 64
        self.layer2 = backbone.layer2  # 1/8, 128
        self.layer3 = backbone.layer3  # 1/16, 256
        self.fuser = FeatureFusionBlock(1024 + 256, 512)

    def forward(self, image, key_f16, mask, other_masks):
        # key_f16 is the 1/16 feature from the key encoder.
        x = torch.cat([image, mask, other_masks], 1)
        for stage in (self.conv1, self.bn1, self.relu, self.maxpool,
                      self.layer1, self.layer2, self.layer3):
            x = stage(x)
        return self.fuser(x, key_f16)
|
||||
|
||||
|
||||
class KeyEncoder(nn.Module):
    """ResNet-50 trunk producing 1/16, 1/8 and 1/4 scale query features."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(pretrained=False)
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu  # 1/2, 64
        self.maxpool = backbone.maxpool
        self.res2 = backbone.layer1  # 1/4, 256
        self.layer2 = backbone.layer2  # 1/8, 512
        self.layer3 = backbone.layer3  # 1/16, 1024

    def forward(self, f):
        x = self.maxpool(self.relu(self.bn1(self.conv1(f))))  # 1/4, 64
        f4 = self.res2(x)  # 1/4, 256
        f8 = self.layer2(f4)  # 1/8, 512
        f16 = self.layer3(f8)  # 1/16, 1024
        return f16, f8, f4
|
||||
|
||||
|
||||
class UpsampleBlock(nn.Module):
    """Merge a skip feature with a bilinearly upsampled coarser feature."""

    def __init__(self, skip_c, up_c, out_c, scale_factor=2):
        super().__init__()
        self.skip_conv = nn.Conv2d(skip_c, up_c, kernel_size=3, padding=1)
        self.out_conv = ResBlock(up_c, out_c)
        self.scale_factor = scale_factor

    def forward(self, skip_f, up_f):
        upsampled = F.interpolate(
            up_f,
            scale_factor=self.scale_factor,
            mode='bilinear',
            align_corners=False)
        # Project the skip path to up_c channels, add, then refine.
        return self.out_conv(self.skip_conv(skip_f) + upsampled)
|
||||
|
||||
|
||||
class KeyProjection(nn.Module):
    """3x3 convolution projecting backbone features into the key space."""

    def __init__(self, indim, keydim):
        super().__init__()
        self.key_proj = nn.Conv2d(indim, keydim, kernel_size=3, padding=1)
        # Orthogonal weights and zero bias give well-conditioned keys at init.
        nn.init.orthogonal_(self.key_proj.weight.data)
        nn.init.zeros_(self.key_proj.bias.data)

    def forward(self, x):
        return self.key_proj(x)
|
||||
|
||||
|
||||
class _NonLocalBlockND(nn.Module):
    """Generic non-local (self-attention) block for 1D/2D/3D feature maps.

    Computes z = W(attn(theta(x), phi(x)) @ g(x)) + x in the style of
    Wang et al., "Non-local Neural Networks". The attention weights here are
    dot products normalized by the number of positions (not a softmax).
    """

    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 dimension=3,
                 sub_sample=True,
                 bn_layer=False):
        """
        Args:
            in_channels: channels of the input feature map.
            inter_channels: channels of the embedded projections; defaults
                to in_channels // 2 (at least 1).
            dimension: 1, 2 or 3 — selects Conv1d/2d/3d and pooling ops.
            sub_sample: if True, max-pool phi and g to cut attention cost.
            bn_layer: if True, append a norm layer after the output conv W.
        """
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            # Guard against zero embedded channels for tiny inputs.
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            # Pool only spatially; keep the temporal axis intact.
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.InstanceNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        # g: "value" projection.
        self.g = conv_nd(
            in_channels=self.in_channels,
            out_channels=self.inter_channels,
            kernel_size=1,
            stride=1,
            padding=0)

        if bn_layer:
            # W: output projection back to in_channels followed by a norm.
            self.W = nn.Sequential(
                conv_nd(
                    in_channels=self.inter_channels,
                    out_channels=self.in_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0), bn(self.in_channels))
        else:
            self.W = conv_nd(
                in_channels=self.inter_channels,
                out_channels=self.in_channels,
                kernel_size=1,
                stride=1,
                padding=0)

        # theta / phi: "query" and "key" projections.
        self.theta = conv_nd(
            in_channels=self.in_channels,
            out_channels=self.inter_channels,
            kernel_size=1,
            stride=1,
            padding=0)

        self.phi = conv_nd(
            in_channels=self.in_channels,
            out_channels=self.inter_channels,
            kernel_size=1,
            stride=1,
            padding=0)

        if sub_sample:
            # Subsample keys and values only; queries keep full resolution,
            # so the output spatial size is unchanged.
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c, t, h, w) for dimension=3 (correspondingly fewer
            trailing axes for 1D/2D)
        :return: tensor of the same shape as x (residual added)
        '''

        batch_size = x.size(0)

        # Values: (b, inter, N') -> (b, N', inter).
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # Queries (b, N, inter) and keys (b, inter, N').
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        N = f.size(-1)
        # Mean normalization instead of softmax ("dot product" variant).
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        # Residual connection.
        z = W_y + x

        return z
|
||||
|
||||
|
||||
class NONLocalBlock1D(_NonLocalBlockND):
    """Non-local block specialized for 3D inputs (b, c, n)."""

    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 sub_sample=True,
                 bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=1,
            sub_sample=sub_sample,
            bn_layer=bn_layer)
|
||||
|
||||
|
||||
class NONLocalBlock2D(_NonLocalBlockND):
    """Non-local block specialized for 4D inputs (b, c, h, w)."""

    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 sub_sample=True,
                 bn_layer=False):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer)
|
||||
|
||||
|
||||
class NONLocalBlock3D(_NonLocalBlockND):
    """Non-local block specialized for 5D inputs (b, c, t, h, w)."""

    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 sub_sample=True,
                 bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=3,
            sub_sample=sub_sample,
            bn_layer=bn_layer)
|
||||
|
||||
|
||||
class _ASPPModule3D(nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes, kernel_size, padding, dilation):
|
||||
super(_ASPPModule3D, self).__init__()
|
||||
self.atrous_conv = nn.Conv3d(
|
||||
inplanes,
|
||||
planes,
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.atrous_conv(x)
|
||||
return F.relu(x, inplace=True)
|
||||
|
||||
|
||||
class ASPP3D(nn.Module):
    """3D atrous spatial pyramid pooling with dilation over spatial axes only."""

    def __init__(self, in_plane, out_plane, reduction=4):
        super().__init__()
        dilations = [1, 2, 4, 6]
        mid_plane = out_plane // reduction
        # Branch 1 is pointwise; branches 2-4 use 1x3x3 kernels with growing
        # spatial dilation and a temporal kernel size of 1.
        self.aspp1 = _ASPPModule3D(
            in_plane, mid_plane, 1, padding=0, dilation=dilations[0])
        self.aspp2 = self._spatial_branch(in_plane, mid_plane, dilations[1])
        self.aspp3 = self._spatial_branch(in_plane, mid_plane, dilations[2])
        self.aspp4 = self._spatial_branch(in_plane, mid_plane, dilations[3])
        self.conv1 = nn.Conv3d(
            mid_plane * 4,
            out_plane,
            kernel_size=(1, 3, 3),
            padding=(0, 1, 1),
            bias=False)

    @staticmethod
    def _spatial_branch(in_plane, mid_plane, d):
        # 1x3x3 atrous branch whose dilation d applies to H and W only.
        return _ASPPModule3D(
            in_plane,
            mid_plane, (1, 3, 3),
            padding=(0, d, d),
            dilation=(1, d, d))

    def forward(self, x):
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        return F.relu(self.conv1(torch.cat(branches, dim=1)), inplace=True)
|
||||
|
||||
|
||||
class SELayerS(nn.Module):
    """Squeeze-and-excitation jointly over channels and the 2 temporal slices."""

    def __init__(self, channel, reduction=16):
        super(SELayerS, self).__init__()
        channel = channel * 2  # 2 is the time-axis length this layer assumes
        self.avg_pool = nn.AdaptiveAvgPool3d((2, 1, 1))
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False), nn.Sigmoid())

    def forward(self, x):
        b, c = x.size(0), x.size(1)
        # Squeeze to (b, 2c), excite, then broadcast a per-(channel, time)
        # sigmoid gate back over the spatial axes.
        squeezed = self.avg_pool(x).view(b, 2 * c)
        gate = self.fc(squeezed).view(b, c, 2, 1, 1)
        return x * gate.expand_as(x)
|
||||
|
||||
|
||||
class SEBasicBlock(nn.Module):
    """Instance-normalized 3D residual block with temporal SE attention.

    Adapted from https://github.com/moskomule/senet.pytorch
    """
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 reduction=8):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv3d(
            inplanes, planes, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        self.in1 = nn.InstanceNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        self.in2 = nn.InstanceNorm3d(planes)
        self.ses = SELayerS(planes, reduction)
        self.in3 = nn.InstanceNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        out = self.in3(self.ses(out))
        # Identity shortcut unless an explicit downsample path is supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
|
||||
|
||||
|
||||
class SAM(nn.Module):
    """
    Spatio-temporal aggregation module (SAM): non-local attention,
    optional stacked SE blocks, then residual ASPP refinement.
    """

    def __init__(self, indim, outdim=None, repeat=0, norm=False):
        super(SAM, self).__init__()
        self.indim = indim
        self.repeat = repeat
        outdim = indim if outdim is None else outdim
        if repeat > 0:
            self.se_block = self.seRepeat(repeat)
        self.conv1 = ASPP3D(indim, outdim, reduction=4)  # norm is 4
        self.non_local = NONLocalBlock3D(indim, bn_layer=False)

    def seRepeat(self, repeat=2):
        # Stack `repeat` SE residual blocks into a single sequential module.
        blocks = [SEBasicBlock(self.indim, self.indim) for _ in range(repeat)]
        return nn.Sequential(*blocks)

    def forward(self, x):
        x = self.non_local(x)
        if self.repeat > 0:
            x = self.se_block(x)
        # ASPP refinement with a residual connection.
        return self.conv1(x) + x
|
||||
|
||||
|
||||
class MemCrompress(nn.Module):
    """Compress a 2-frame memory bank (keys and per-object values) to 1 frame."""

    def __init__(self, repeat=0, norm=True):
        super().__init__()

        self.key_encoder = SAM(64, 64, repeat=repeat, norm=norm)
        self.value_encoder = SAM(512, 512, repeat=repeat, norm=norm)

        # Temporal kernel of size 2 collapses the two memory frames into one.
        self.compress_key = nn.Conv3d(
            64, 64, kernel_size=(2, 3, 3), padding=(0, 1, 1))
        self.compress_value = nn.Conv3d(
            512, 512, kernel_size=(2, 3, 3), padding=(0, 1, 1))

    def forward(self, key, value):
        """Compress key (N, C, T, H, W) and value (N, O, C, T, H, W) to T=1."""
        N, O, C, T, H, W = value.shape
        # Fold the object axis into the batch so SAM sees a plain 5D tensor.
        flat_value = value.flatten(start_dim=0, end_dim=1)

        k = self.compress_key(self.key_encoder(key))
        v = self.compress_value(self.value_encoder(flat_value))
        v = v.view(N, O, C, 1, H, W)
        return k, v
|
||||
174
modelscope/models/cv/video_object_segmentation/network.py
Normal file
174
modelscope/models/cv/video_object_segmentation/network.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022
|
||||
# under MIT License
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.cv.video_object_segmentation.modules import (
|
||||
KeyEncoder, KeyProjection, MemCrompress, ResBlock, UpsampleBlock,
|
||||
ValueEncoder, ValueEncoderSO)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
    """Decode fused 1/16 features into a full-resolution single-channel logit."""

    def __init__(self):
        super().__init__()
        self.compress = ResBlock(1024, 512)
        self.up_16_8 = UpsampleBlock(512, 512, 256)  # 1/16 -> 1/8
        self.up_8_4 = UpsampleBlock(256, 256, 256)  # 1/8 -> 1/4
        self.pred = nn.Conv2d(
            256, 1, kernel_size=(3, 3), padding=(1, 1), stride=1)

    def forward(self, f16, f8, f4):
        feat = self.up_8_4(f4, self.up_16_8(f8, self.compress(f16)))
        logits = self.pred(F.relu(feat))
        # Upsample from 1/4 back to full resolution.
        return F.interpolate(
            logits, scale_factor=4, mode='bilinear', align_corners=False)
|
||||
|
||||
|
||||
class MemoryReader(nn.Module):
    """Attention read from a space-time memory bank (STM-style)."""

    def __init__(self):
        super().__init__()

    def get_affinity(self, mk, qk):
        """Return softmax attention weights of shape (B, THW, HW).

        mk: memory keys (B, CK, T, H, W); qk: query keys (B, CK, H, W).
        """
        B, CK, T, H, W = mk.shape
        mk = mk.flatten(start_dim=2)
        qk = qk.flatten(start_dim=2)

        # Negative squared distance up to a query-only term: -|mk|^2 + 2 mk.qk.
        # The dropped -|qk|^2 term is constant per column and cancels in the
        # softmax below.
        mem_sq = mk.pow(2).sum(1).unsqueeze(2)
        cross = 2 * (mk.transpose(1, 2) @ qk)
        affinity = (cross - mem_sq) / math.sqrt(CK)  # B, THW, HW

        # Numerically-stable softmax over the memory axis, matching the
        # evaluation-time implementation.
        maxes = torch.max(affinity, dim=1, keepdim=True)[0]
        exps = torch.exp(affinity - maxes)
        return exps / torch.sum(exps, dim=1, keepdim=True)

    def readout(self, affinity, mv, qv):
        """Weighted-sum memory values by affinity, then concat the query value."""
        B, CV, T, H, W = mv.shape
        mem = torch.bmm(mv.view(B, CV, T * H * W), affinity)  # B, CV, HW
        mem = mem.view(B, CV, H, W)
        return torch.cat([mem, qv], dim=1)
|
||||
|
||||
|
||||
class RDE_VOS(nn.Module):
    """Recurrent Dynamic Embedding VOS network (RDE-VOS, CVPR 2022).

    A mode-dispatching wrapper: forward(mode, ...) routes to encode_key /
    encode_value / segment / compress so the training loop can drive each
    sub-network independently.
    """

    def __init__(self, single_object, repeat=0, norm=False):
        """
        Args:
            single_object: if True, use the single-object value encoder
                (static-image pretraining); otherwise the two-object one.
            repeat: number of SE blocks inside the memory-compression SAMs.
            norm: normalization flag forwarded to MemCrompress.
        """
        super().__init__()
        self.single_object = single_object

        self.key_encoder = KeyEncoder()
        if single_object:
            self.value_encoder = ValueEncoderSO()
        else:
            self.value_encoder = ValueEncoder()
        # Compress memory bank
        self.mem_compress = MemCrompress(repeat=repeat, norm=norm)

        # Projection from f16 feature space to key space
        self.key_proj = KeyProjection(1024, keydim=64)

        # Compress f16 a bit to use in decoding later on
        self.key_comp = nn.Conv2d(1024, 512, kernel_size=3, padding=1)

        self.memory = MemoryReader()
        self.decoder = Decoder()

    def aggregate(self, prob):
        """Merge per-object probabilities with a derived background channel.

        Returns logits over (background + objects), clamped for stability.
        """
        # During training, torch.prod work on channel dimension: background
        # probability is the product of "not this object" over all objects.
        new_prob = torch.cat([torch.prod(1 - prob, dim=1, keepdim=True), prob],
                             1).clamp(1e-7, 1 - 1e-7)
        logits = torch.log((new_prob / (1 - new_prob)))
        return logits

    def encode_key(self, frame):
        """Encode frames into query keys and multi-scale features.

        Args:
            frame: tensor of shape b*t*c*h*w.

        Returns:
            (k16, f16_thin, f16, f8, f4); k16 is B*C*T*H*W, the features
            are B*T*C*H*W.
        """
        # input: b*t*c*h*w
        b, t = frame.shape[:2]

        # Fold time into the batch so the 2D backbone sees 4D input.
        f16, f8, f4 = self.key_encoder(frame.flatten(start_dim=0, end_dim=1))
        k16 = self.key_proj(f16)
        f16_thin = self.key_comp(f16)

        # B*C*T*H*W
        k16 = k16.view(b, t, *k16.shape[-3:]).transpose(1, 2).contiguous()

        # B*T*C*H*W
        f16_thin = f16_thin.view(b, t, *f16_thin.shape[-3:])
        f16 = f16.view(b, t, *f16.shape[-3:])
        f8 = f8.view(b, t, *f8.shape[-3:])
        f4 = f4.view(b, t, *f4.shape[-3:])

        return k16, f16_thin, f16, f8, f4

    def encode_value(self, frame, kf16, mask, other_mask=None):
        # Extract memory key/value for a frame; `other_mask` is only used by
        # the multi-object encoder.
        if self.single_object:
            f16 = self.value_encoder(frame, kf16, mask)
        else:
            f16 = self.value_encoder(frame, kf16, mask, other_mask)
        return f16.unsqueeze(2)  # B*512*T*H*W

    def segment(self, qk16, qv16, qf8, qf4, mk16, mv16, selector=None):
        """Read memory with the query key and decode per-object masks.

        Args:
            qk16: query key. qv16: query value (f16_thin from encode_key).
            qf8, qf4: skip features for the decoder.
            mk16, mv16: memory keys and values.
            selector: per-object validity weights (multi-object mode only).
        """
        # q - query, m - memory
        # qv16 is f16_thin above
        affinity = self.memory.get_affinity(mk16, qk16)

        if self.single_object:
            mv2qv = self.memory.readout(affinity, mv16, qv16)
            logits = self.decoder(mv2qv, qf8, qf4)
            prob = torch.sigmoid(logits)
        else:
            # Decode each of the two objects with the shared affinity.
            mv2qv_o1 = self.memory.readout(affinity, mv16[:, 0], qv16)
            mv2qv_o2 = self.memory.readout(affinity, mv16[:, 1], qv16)
            logits = torch.cat([
                self.decoder(mv2qv_o1, qf8, qf4),
                self.decoder(mv2qv_o2, qf8, qf4),
            ], 1)

            prob = torch.sigmoid(logits)
            # Zero out objects absent from this training sample.
            prob = prob * selector.unsqueeze(2).unsqueeze(2)

        logits = self.aggregate(prob)
        # Drop the background channel from the returned probabilities.
        prob = F.softmax(logits, dim=1)[:, 1:]

        if self.single_object:
            return logits, prob, mv2qv
        else:
            return logits, prob, mv2qv_o1, mv2qv_o2

    def memCrompress(self, key, value):
        # Delegate to the memory-compression module (see MemCrompress).
        return self.mem_compress(key, value)

    def forward(self, mode, *args, **kwargs):
        """Dispatch to a sub-computation selected by `mode`.

        Raises:
            NotImplementedError: for an unrecognized mode string.
        """
        if mode == 'encode_key':
            return self.encode_key(*args, **kwargs)
        elif mode == 'encode_value':
            return self.encode_value(*args, **kwargs)
        elif mode == 'segment':
            return self.segment(*args, **kwargs)
        elif mode == 'compress':
            return self.memCrompress(*args, **kwargs)
        else:
            raise NotImplementedError
|
||||
@@ -839,7 +839,13 @@ TASK_OUTPUTS = {
|
||||
# {
|
||||
# 'scores': [0.1, 0.2, 0.3, ...]
|
||||
# }
|
||||
Tasks.translation_evaluation: [OutputKeys.SCORES]
|
||||
Tasks.translation_evaluation: [OutputKeys.SCORES],
|
||||
|
||||
# video object segmentation result for a single video
|
||||
# {
|
||||
# "masks": [np.array # 3D array with shape [frame_num, height, width]]
|
||||
# }
|
||||
Tasks.video_object_segmentation: [OutputKeys.MASKS],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -226,6 +226,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Tasks.translation_evaluation:
|
||||
(Pipelines.translation_evaluation,
|
||||
'damo/nlp_unite_mup_translation_evaluation_multilingual_large'),
|
||||
Tasks.video_object_segmentation:
|
||||
(Pipelines.video_object_segmentation,
|
||||
'damo/cv_rdevos_video-object-segmentation'),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -67,6 +67,7 @@ if TYPE_CHECKING:
|
||||
from .hand_static_pipeline import HandStaticPipeline
|
||||
from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline
|
||||
from .language_guided_video_summarization_pipeline import LanguageGuidedVideoSummarizationPipeline
|
||||
from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
@@ -153,7 +154,10 @@ else:
|
||||
],
|
||||
'language_guided_video_summarization_pipeline': [
|
||||
'LanguageGuidedVideoSummarizationPipeline'
|
||||
]
|
||||
],
|
||||
'video_object_segmentation_pipeline': [
|
||||
'VideoObjectSegmentationPipeline'
|
||||
],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
151
modelscope/pipelines/cv/video_object_segmentation_pipeline.py
Normal file
151
modelscope/pipelines/cv/video_object_segmentation_pipeline.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.video_object_segmentation.inference_core import \
|
||||
InferenceCore
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
im_normalization = transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
|
||||
def unpad(img, pad):
    """Crop a padded (N, C, H, W) tensor back to its pre-padding size.

    Args:
        img: padded tensor.
        pad: (left, right, top, bottom) padding amounts, as used by F.pad.

    Returns:
        A view of `img` with the padding removed.
    """
    left, right, top, bottom = pad
    # Use explicit end indices: the original `[top:-bottom]` slicing produced
    # an EMPTY tensor whenever exactly one side of a pair was zero
    # (e.g. pad=(0, 0, 2, 0) gave img[:, :, 2:0, :]).
    if top + bottom > 0:
        img = img[:, :, top:img.shape[2] - bottom, :]
    if left + right > 0:
        img = img[:, :, :, left:img.shape[3] - right]
    return img
|
||||
|
||||
|
||||
def all_to_onehot(masks, labels):
    """Convert an integer label mask into stacked binary masks, one per label.

    Generalized over mask rank: the original duplicated near-identical
    branches for 2D (H, W) and 3D (T, H, W) inputs; building the output
    shape from `masks.shape` handles both (and any other rank) identically.

    Args:
        masks: integer array, typically (H, W) or (T, H, W).
        labels: iterable of label ids to extract.

    Returns:
        uint8 array of shape (len(labels), *masks.shape) with 0/1 entries.
    """
    Ms = np.zeros((len(labels), ) + masks.shape, dtype=np.uint8)

    for k, label in enumerate(labels):
        Ms[k] = (masks == label).astype(np.uint8)

    return Ms
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.video_object_segmentation,
    module_name=Pipelines.video_object_segmentation)
class VideoObjectSegmentationPipeline(Pipeline):
    """Semi-supervised video object segmentation with RDE-VOS.

    Takes a list of RGB frames plus a palettized first-frame annotation and
    propagates the annotated objects through the whole clip.
    """

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create video_object_segmentation pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """

        super().__init__(model=model, **kwargs)
        logger.info('load model done')
        # Normalize then resize the shorter side to 480 (bicubic for RGB).
        self.im_transform = transforms.Compose([
            transforms.ToTensor(),
            im_normalization,
            transforms.Resize(480, interpolation=Image.BICUBIC),
        ])
        # Nearest-neighbor for masks so label ids are never interpolated.
        self.mask_transform = transforms.Compose([
            transforms.Resize(480, interpolation=Image.NEAREST),
        ])

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Turn {'images': [PIL RGB], 'mask': PIL 'P' image} into model tensors.

        Returns a dict with stacked frame tensors ('rgb'), one-hot masks
        resized to 480p ('gt'), bookkeeping 'info', and the color 'palette'.
        """

        self.images = input['images']
        self.mask = input['mask']

        frames = len(self.images)
        shape = np.shape(self.mask)

        info = {}
        info['name'] = 'maas_test_video'
        info['frames'] = frames
        info['size'] = shape  # Real sizes
        info['gt_obj'] = {}  # Frames with labelled objects

        images = []
        masks = []
        for i in range(frames):
            img = self.images[i]
            images.append(self.im_transform(img))

            # NOTE(review): the same first-frame annotation is appended for
            # every frame index — presumably only the frame-0 entry is
            # consumed downstream (forward seeds propagation with msk[:, 0]);
            # confirm against InferenceCore before relying on later entries.
            palette = self.mask.getpalette()
            masks.append(np.array(self.mask, dtype=np.uint8))
            this_labels = np.unique(masks[-1])
            this_labels = this_labels[this_labels != 0]
            info['gt_obj'][i] = this_labels

        images = torch.stack(images, 0)
        masks = np.stack(masks, 0)

        # Background (label 0) is not treated as an object.
        labels = np.unique(masks).astype(np.uint8)
        labels = labels[labels != 0]

        masks = torch.from_numpy(all_to_onehot(masks, labels)).float()
        # Resize to 480p
        masks = self.mask_transform(masks)
        masks = masks.unsqueeze(2)

        info['labels'] = labels

        result = {
            'rgb': images,
            'gt': masks,
            'info': info,
            'palette': np.array(palette),
        }
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Propagate masks over all frames; return uint8 label maps."""
        rgb = input['rgb'].unsqueeze(0)
        msk = input['gt']
        info = input['info']
        k = len(info['labels'])
        size = info['size']

        is_cuda = rgb.is_cuda

        processor = InferenceCore(
            self.model.model, is_cuda, rgb, k, top_k=20, mem_every=4)
        # Seed propagation with the frame-0 annotation, then run to the end.
        processor.interact(msk[:, 0], 0, rgb.shape[1])

        # Do unpad -> upsample to original size
        out_masks = torch.zeros((processor.t, 1, *size),
                                dtype=torch.uint8,
                                device='cuda' if is_cuda else 'cpu')
        for ti in range(processor.t):
            prob = unpad(processor.prob[:, ti], processor.pad)
            prob = F.interpolate(
                prob, tuple(size), mode='bilinear', align_corners=False)
            # argmax over the (background + objects) channel axis.
            out_masks[ti] = torch.argmax(prob, dim=0)

        if is_cuda:
            out_masks = (out_masks.detach().cpu().numpy()[:,
                                                          0]).astype(np.uint8)
        else:
            out_masks = (out_masks.detach().numpy()[:, 0]).astype(np.uint8)

        return {OutputKeys.MASKS: out_masks}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # forward() already emits the final mask array; nothing to add here.
        return inputs
|
||||
@@ -89,6 +89,7 @@ class CVTasks(object):
|
||||
language_guided_video_summarization = 'language-guided-video-summarization'
|
||||
|
||||
# video segmentation
|
||||
video_object_segmentation = 'video-object-segmentation'
|
||||
referring_video_object_segmentation = 'referring-video-object-segmentation'
|
||||
video_human_matting = 'video-human-matting'
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.preprocessors.image import load_image
|
||||
@@ -478,3 +481,12 @@ def depth_to_color(depth):
|
||||
(depth.max() - depth) / depth.max()) * 2**8).astype(np.uint8)[:, :, :3]
|
||||
depth_color = cv2.cvtColor(depth_color, cv2.COLOR_RGB2BGR)
|
||||
return depth_color
|
||||
|
||||
|
||||
def masks_visualization(masks, palette):
    """Render uint8 label masks as palettized PIL images.

    Args:
        masks: (frame_num, height, width) uint8 array of per-pixel labels.
        palette: flat PIL palette sequence applied to every frame.

    Returns:
        list of PIL.Image, one palettized image per frame.
    """
    vis_masks = []
    # Iterating the array directly yields one (height, width) frame per step.
    for frame in masks:
        frame_img = Image.fromarray(frame)
        frame_img.putpalette(palette)
        vis_masks.append(frame_img)
    return vis_masks
|
||||
|
||||
51
tests/pipelines/test_video_object_segmentation.py
Normal file
51
tests/pipelines/test_video_object_segmentation.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import masks_visualization
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class VideoObjectSegmentationTest(unittest.TestCase, DemoCompatibilityCheck):
    """End-to-end smoke test for the RDE-VOS video-object-segmentation pipeline."""

    def setUp(self) -> None:
        self.task = 'video-object-segmentation'
        self.model_id = 'damo/cv_rdevos_video-object-segmentation'
        self.input_location = 'data/test/videos/video_object_segmentation_test'
        self.images_dir = os.path.join(self.input_location, 'JPEGImages')
        self.mask_file = os.path.join(self.input_location, 'Annotations',
                                      '00000.png')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_video_object_segmentation(self):
        # Frames must be fed in temporal order, hence the sorted listing.
        frame_names = sorted(os.listdir(self.images_dir))
        input_images = [
            Image.open(os.path.join(self.images_dir, name)).convert('RGB')
            for name in frame_names
        ]
        mask = Image.open(self.mask_file).convert('P')
        input = {'images': input_images, 'mask': mask}

        segmentor = pipeline(
            Tasks.video_object_segmentation, model=self.model_id)
        result = segmentor(input)
        out_masks = result[OutputKeys.MASKS]

        vis_masks = masks_visualization(out_masks, mask.getpalette())

        # Write one palettized PNG per frame for manual inspection.
        os.makedirs('test_result', exist_ok=True)
        for idx, vis_mask in enumerate(vis_masks):
            vis_mask.save(
                os.path.join('test_result', '{:05d}.png'.format(idx)))

        print('test_video_object_segmentation DONE')
|
||||
|
||||
|
||||
# Allow running this test file directly without the pytest runner.
if __name__ == '__main__':
    unittest.main()
|
||||
Reference in New Issue
Block a user