diff --git a/data/test/videos/video_object_segmentation_test/Annotations/00000.png b/data/test/videos/video_object_segmentation_test/Annotations/00000.png new file mode 100644 index 00000000..fa61c872 --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/Annotations/00000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3044047d7452f27ad5391ecfc83f1a366a5a82e3eb8c8c151a6bf02cbb37c046 +size 5366 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00000.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00000.jpg new file mode 100644 index 00000000..4ce9c234 --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00000.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:401ebc9ffc8ba5f47ef31f96ebd6a2e7f82b976c49e9a36f1915f5b9202f3d38 +size 45113 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00005.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00005.jpg new file mode 100644 index 00000000..4a2f33e6 --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00005.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24a6702646a071988aa3a77a40d92cfc2c3ada00ecef5afeaf9c475272f30af1 +size 50254 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00010.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00010.jpg new file mode 100644 index 00000000..de124df9 --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00010.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37071e9ad7c91af44b2ed9b6d8e3ea42c69f284352d4431ef680b4de01521261 +size 51068 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00015.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00015.jpg new file mode 100644 index 00000000..4066426d --- /dev/null +++ 
b/data/test/videos/video_object_segmentation_test/JPEGImages/00015.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a019761196033e0571c40ccadfdc2eff454e40207e16c3de66a1e322ec727227 +size 50164 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00020.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00020.jpg new file mode 100644 index 00000000..f819bc1e --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00020.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8cd500e4f54ac5b85c9eb2a8b7d8706a061ce7ff80502d29850e641cd59a9e9f +size 51056 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00025.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00025.jpg new file mode 100644 index 00000000..4182d95b --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00025.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd943f6fb068f7ece7a009f729ed356dd09295ab4d19b0da858256b0259bf467 +size 51580 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00030.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00030.jpg new file mode 100644 index 00000000..9afba8bf --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00030.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:554e1725877410cd7ba7d24e151a8b7572eb88ed0d8760a70b15fc20843cda76 +size 50718 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00035.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00035.jpg new file mode 100644 index 00000000..9aee928b --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00035.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a912a376204cdc39e128a0b1d777bc73b88f1f6b6e9d16a28755cf4a89fcfe87 +size 51000 diff --git 
a/data/test/videos/video_object_segmentation_test/JPEGImages/00040.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00040.jpg new file mode 100644 index 00000000..c6d7b2f8 --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00040.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e83de2980b89b971c3894e102a27a0684307a9e0c841775ce9419f2f5ab43d7 +size 50267 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00045.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00045.jpg new file mode 100644 index 00000000..b9e3de1d --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00045.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3cd21580132bcbc401cd443e0b0c5bdf6f3fcbfec91e0e767cabc4f6480185a9 +size 50328 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00050.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00050.jpg new file mode 100644 index 00000000..f25911d8 --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00050.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a1effa23c00b45e8b516dfa64e376756b9600fe94c43be7b2ae6331e5041474 +size 50885 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00055.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00055.jpg new file mode 100644 index 00000000..0dfd910c --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00055.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0f16febf7ef8739393f16cab88e2da00aa3b1a60674ef33cbec7c3dd75584ed +size 51075 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00060.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00060.jpg new file mode 100644 index 00000000..ad19b253 --- /dev/null +++ 
b/data/test/videos/video_object_segmentation_test/JPEGImages/00060.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc1ffd892d12d82c10331039bc998822c9a601a1b80fb60d367f5532e3b5958f +size 49797 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00065.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00065.jpg new file mode 100644 index 00000000..dde3c38f --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00065.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cae709a0b5216f469591173beecd05d1e7786409dec6b41323751312c21ffbe8 +size 51894 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00070.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00070.jpg new file mode 100644 index 00000000..dec68d78 --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00070.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2f650e5cc18dcb058da04822181d4a25eb85329d6daf1da395c88e9f537760a +size 52597 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00075.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00075.jpg new file mode 100644 index 00000000..05fc9716 --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00075.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f88d10b8ebc1f8b6da16faf5032a1b49ba1dab94854319637fb390a201af614 +size 51468 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00080.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00080.jpg new file mode 100644 index 00000000..00534f40 --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00080.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc914fbab93a722b2d263e36f713427ae07d82b7aa52e2d189f7517fa6c4bf1b +size 51824 diff --git 
a/data/test/videos/video_object_segmentation_test/JPEGImages/00085.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00085.jpg new file mode 100644 index 00000000..8cd5b25a --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00085.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79f27be589e2ac046d52bcc0d927e7aef8c2e1562dbb07d63aa0b7a50e0d6b90 +size 52512 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00090.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00090.jpg new file mode 100644 index 00000000..d20ed6d3 --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00090.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2db022e12a83b900afc12493474b2b078f373d864d64f8babbf3c8cd50b62978 +size 50119 diff --git a/data/test/videos/video_object_segmentation_test/JPEGImages/00095.jpg b/data/test/videos/video_object_segmentation_test/JPEGImages/00095.jpg new file mode 100644 index 00000000..68fbe96c --- /dev/null +++ b/data/test/videos/video_object_segmentation_test/JPEGImages/00095.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:402b7a8fcef26bbba5eebfc5fcc141b3a5a1a5fd06ad61ee6d3150c3beeb3824 +size 50961 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 0af35cc4..70fe52cc 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -58,6 +58,7 @@ class Models(object): product_segmentation = 'product-segmentation' image_body_reshaping = 'image-body-reshaping' video_human_matting = 'video-human-matting' + video_object_segmentation = 'video-object-segmentation' # EasyCV models yolox = 'YOLOX' @@ -243,6 +244,7 @@ class Pipelines(object): image_body_reshaping = 'flow-based-body-reshaping' referring_video_object_segmentation = 'referring-video-object-segmentation' video_human_matting = 'video-human-matting' + video_object_segmentation = 'video-object-segmentation' # nlp 
tasks automatic_post_editing = 'automatic-post-editing' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index de972032..b781a89d 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -14,6 +14,7 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, object_detection, product_retrieval_embedding, realtime_object_detection, referring_video_object_segmentation, salient_detection, shop_segmentation, super_resolution, - video_single_object_tracking, video_summarization, virual_tryon) + video_object_segmentation, video_single_object_tracking, + video_summarization, virual_tryon) # yapf: enable diff --git a/modelscope/models/cv/video_object_segmentation/__init__.py b/modelscope/models/cv/video_object_segmentation/__init__.py new file mode 100644 index 00000000..2318a824 --- /dev/null +++ b/modelscope/models/cv/video_object_segmentation/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .model import VideoObjectSegmentation + +else: + _import_structure = {'model': ['VideoObjectSegmentation']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/video_object_segmentation/aggregate.py b/modelscope/models/cv/video_object_segmentation/aggregate.py new file mode 100644 index 00000000..c7c8d035 --- /dev/null +++ b/modelscope/models/cv/video_object_segmentation/aggregate.py @@ -0,0 +1,29 @@ +# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022 +# under MIT License + +import torch +import torch.nn.functional as F + + +# Soft aggregation from STM +def aggregate(prob, keep_bg=False): + # Caclulate the probability of the background. 
+ background_prob = torch.prod(1 - prob, dim=0, keepdim=True) + # Concatenate the probabilities of background and foreground objects. + new_prob = torch.cat([background_prob, prob], 0).clamp(1e-7, 1 - 1e-7) + + # logit function + logits = torch.log((new_prob / (1 - new_prob))) + + if keep_bg: + return F.softmax(logits, dim=0) + else: + return F.softmax(logits, dim=0)[1:] + + +if __name__ == '__main__': + prob = torch.randn(size=(1, 2, 1, 1)) + prob = torch.sigmoid(prob) + new = aggregate(prob, keep_bg=True) + print(prob) + print(new) diff --git a/modelscope/models/cv/video_object_segmentation/cbam.py b/modelscope/models/cv/video_object_segmentation/cbam.py new file mode 100644 index 00000000..fc0cdc6a --- /dev/null +++ b/modelscope/models/cv/video_object_segmentation/cbam.py @@ -0,0 +1,123 @@ +# The implementation is modified from CBAM +# under the MIT License +# https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicConv(nn.Module): + + def __init__(self, + in_planes, + out_planes, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True): + super(BasicConv, self).__init__() + self.out_channels = out_planes + self.conv = nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + + def forward(self, x): + x = self.conv(x) + return x + + +class Flatten(nn.Module): + + def forward(self, x): + return x.view(x.size(0), -1) + + +class ChannelGate(nn.Module): + + def __init__(self, + gate_channels, + reduction_ratio=16, + pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels)) + self.pool_types = pool_types + + def forward(self, x): 
+ channel_att_sum = None + for pool_type in self.pool_types: + if pool_type == 'avg': + avg_pool = F.avg_pool2d( + x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(avg_pool) + elif pool_type == 'max': + max_pool = F.max_pool2d( + x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(max_pool) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze( + 3).expand_as(x) + return x * scale + + +class ChannelPool(nn.Module): + + def forward(self, x): + return torch.cat( + (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), + dim=1) + + +class SpatialGate(nn.Module): + + def __init__(self): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + self.spatial = BasicConv( + 2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2) + + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = torch.sigmoid(x_out) # broadcasting + return x * scale + + +class CBAM(nn.Module): + + def __init__(self, + gate_channels, + reduction_ratio=16, + pool_types=['avg', 'max'], + no_spatial=False): + super(CBAM, self).__init__() + self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, + pool_types) + self.no_spatial = no_spatial + if not no_spatial: + self.SpatialGate = SpatialGate() + + def forward(self, x): + x_out = self.ChannelGate(x) + if not self.no_spatial: + x_out = self.SpatialGate(x_out) + return x_out diff --git a/modelscope/models/cv/video_object_segmentation/eval_network.py b/modelscope/models/cv/video_object_segmentation/eval_network.py new file mode 100644 index 00000000..e43aabd6 --- /dev/null +++ b/modelscope/models/cv/video_object_segmentation/eval_network.py @@ -0,0 +1,62 @@ +# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022 +# under MIT License + 
+import torch +import torch.nn as nn + +from modelscope.models.cv.video_object_segmentation.modules import ( + KeyEncoder, KeyProjection, MemCrompress, ValueEncoder) +from modelscope.models.cv.video_object_segmentation.network import Decoder + + +class RDE_VOS(nn.Module): + + def __init__(self, repeat=0): + super().__init__() + self.key_encoder = KeyEncoder() + self.value_encoder = ValueEncoder() + + # Projection from f16 feature space to key space + self.key_proj = KeyProjection(1024, keydim=64) + + # Compress f16 a bit to use in decoding later on + self.key_comp = nn.Conv2d(1024, 512, kernel_size=3, padding=1) + + self.decoder = Decoder() + self.mem_compress = MemCrompress(repeat=repeat) + + def encode_value(self, frame, kf16, masks): + k, _, h, w = masks.shape + + # Extract memory key/value for a frame with multiple masks + frame = frame.view(1, 3, h, w).repeat(k, 1, 1, 1) + # Compute the "others" mask + if k != 1: + others = torch.cat([ + torch.sum( + masks[[j for j in range(k) if i != j]], + dim=0, + keepdim=True) for i in range(k) + ], 0) + else: + others = torch.zeros_like(masks) + + f16 = self.value_encoder(frame, kf16.repeat(k, 1, 1, 1), masks, others) + + return f16.unsqueeze(2) + + def encode_key(self, frame): + f16, f8, f4 = self.key_encoder(frame) + k16 = self.key_proj(f16) + f16_thin = self.key_comp(f16) + + return k16, f16_thin, f16, f8, f4 + + def segment_with_query(self, mem_bank, qf8, qf4, qk16, qv16): + k = mem_bank.num_objects + + readout_mem = mem_bank.match_memory(qk16) + qv16 = qv16.expand(k, -1, -1, -1) + qv16 = torch.cat([readout_mem, qv16], 1) + + return torch.sigmoid(self.decoder(qv16, qf8, qf4)) diff --git a/modelscope/models/cv/video_object_segmentation/inference_core.py b/modelscope/models/cv/video_object_segmentation/inference_core.py new file mode 100644 index 00000000..03a97a00 --- /dev/null +++ b/modelscope/models/cv/video_object_segmentation/inference_core.py @@ -0,0 +1,128 @@ +# Adopted from 
https://github.com/Limingxing00/RDE-VOS-CVPR2022 +# under MIT License + +import torch +import torch.nn.functional as F + +from modelscope.models.cv.video_object_segmentation.aggregate import aggregate +from modelscope.models.cv.video_object_segmentation.inference_memory_bank import \ + MemoryBank + + +def pad_divide_by(in_img, d, in_size=None): + if in_size is None: + h, w = in_img.shape[-2:] + else: + h, w = in_size + + if h % d > 0: + new_h = h + d - h % d + else: + new_h = h + if w % d > 0: + new_w = w + d - w % d + else: + new_w = w + lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2) + lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2) + pad_array = (int(lw), int(uw), int(lh), int(uh)) + out = F.pad(in_img, pad_array) + return out, pad_array + + +class InferenceCore: + + def __init__(self, + prop_net, + is_cuda, + images, + num_objects, + top_k=20, + mem_every=5, + include_last=False): + self.prop_net = prop_net + self.is_cuda = is_cuda + self.mem_every = mem_every + self.include_last = include_last + + # True dimensions + t = images.shape[1] + h, w = images.shape[-2:] + + # Pad each side to multiple of 16 + images, self.pad = pad_divide_by(images, 16) + # Padded dimensions + nh, nw = images.shape[-2:] + + self.images = images + if self.is_cuda: + self.device = 'cuda' + else: + self.device = 'cpu' + + self.k = num_objects + + # Background included, not always consistent (i.e. 
sum up to 1) + self.prob = torch.zeros((self.k + 1, t, 1, nh, nw), + dtype=torch.float32, + device=self.device) + self.prob[0] = 1e-7 + + self.t, self.h, self.w = t, h, w + self.nh, self.nw = nh, nw + self.kh = self.nh // 16 + self.kw = self.nw // 16 + + self.mem_bank = MemoryBank( + prop_net.mem_compress, + k=self.k, + top_k=top_k, + mode='two-frames-compress') + # Compress memory bank + + def encode_key(self, idx): + + result = self.prop_net.encode_key(self.images[:, idx]) + return result + + def do_pass(self, first_k, first_v, idx, end_idx): + global tt1, tt2, tt3, tt4 + self.mem_bank.add_memory(first_k, first_v) + closest_ti = end_idx + + # Note that we never reach closest_ti, just the frame before it + this_range = range(idx + 1, closest_ti) + end = closest_ti - 1 + for ti in this_range: + k16, qv16, qf16, qf8, qf4 = self.encode_key(ti) + + out_mask = self.prop_net.segment_with_query( + self.mem_bank, qf8, qf4, k16, qv16) + + out_mask = aggregate(out_mask, keep_bg=True) + self.prob[:, ti] = out_mask + + if ti != end: + is_mem_frame = ((ti % self.mem_every) == 0) + if self.include_last or is_mem_frame: + prev_value = self.prop_net.encode_value( + self.images[:, ti], qf16, out_mask[1:]) + prev_key = k16.unsqueeze(2) + self.mem_bank.add_memory( + prev_key, prev_value, is_temp=not is_mem_frame) + return closest_ti + + def interact(self, mask, frame_idx, end_idx): + + mask, _ = pad_divide_by(mask, 16) + + self.prob[:, frame_idx] = aggregate(mask, keep_bg=True) + + # KV pair for the interacting frame + first_k, _, qf16, _, _ = self.encode_key(frame_idx) + first_v = self.prop_net.encode_value(self.images[:, frame_idx], qf16, + self.prob[1:, frame_idx]) + first_k = first_k.unsqueeze(2) + + # Propagate + self.do_pass(first_k, first_v, frame_idx, end_idx) diff --git a/modelscope/models/cv/video_object_segmentation/inference_memory_bank.py b/modelscope/models/cv/video_object_segmentation/inference_memory_bank.py new file mode 100644 index 00000000..4d4869b1 --- /dev/null 
+++ b/modelscope/models/cv/video_object_segmentation/inference_memory_bank.py @@ -0,0 +1,255 @@ +# Adopted from https://github.com/Limingxing00/RDE-VOS-CVPR2022 +# under MIT License + +import math + +import torch + + +def softmax_w_top(x, top): + values, indices = torch.topk(x, k=top, dim=1) + x_exp = values.exp_() + + x_exp /= torch.sum(x_exp, dim=1, keepdim=True) + x.zero_().scatter_(1, indices, x_exp) # B * THW * HW + + return x + + +def make_gaussian(y_idx, x_idx, height, width, sigma=7): + yv, xv = torch.meshgrid([torch.arange(0, height), torch.arange(0, width)]) + + yv = yv.reshape(height * width).unsqueeze(0).float().cuda() + xv = xv.reshape(height * width).unsqueeze(0).float().cuda() + + y_idx = y_idx.transpose(0, 1) + x_idx = x_idx.transpose(0, 1) + + g = torch.exp(-((yv - y_idx)**2 + (xv - x_idx)**2) / (2 * sigma**2)) + + return g + + +def kmn(x, top=None, gauss=None): + if top is not None: + if gauss is not None: + maxes = torch.max(x, dim=1, keepdim=True)[0] + x_exp = torch.exp(x - maxes) * gauss + x_exp, indices = torch.topk(x_exp, k=top, dim=1) + else: + values, indices = torch.topk(x, k=top, dim=1) + x_exp = torch.exp(values - values[:, 0]) + + x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) + x_exp /= x_exp_sum + x.zero_().scatter_(1, indices, x_exp) # B * THW * HW + + output = x + else: + maxes = torch.max(x, dim=1, keepdim=True)[0] + if gauss is not None: + x_exp = torch.exp(x - maxes) * gauss + + x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) + x_exp /= x_exp_sum + output = x_exp + + return output + + +class MemoryBank: + + def __init__(self, compress, k, top_k=20, mode='stm'): + self.top_k = top_k + + self.CK = None + self.CV = None + + self.mem_k = None + self.mem_v = None + + self.num_objects = k + self.km = 5.6 + + self.compress = compress + + self.init_mode(mode) + + def init_mode(self, mode): + """ + stm, two-frames, gt, last, compress, gt-compress, + last-compress, two-frames-compress + """ + self.is_compress = None + self.use_gt = 
None + self.use_last = None + self.stm = None + print('mode is {}'.format(mode)) + if mode == 'stm': + self.stm = True + elif mode == 'two-frames': + self.use_gt = True + self.use_last = True + elif mode == 'gt': + self.use_gt = True + elif mode == 'last': + self.use_last = True + elif mode == 'compress': + self.is_compress = True + elif mode == 'gt-compress': + self.use_gt = True + self.is_compress = True + elif mode == 'last-compress': + self.use_last = True + self.is_compress = True + elif mode == 'two-frames-compress': + self.use_gt = True + self.use_last = True + self.is_compress = True + else: + raise RuntimeError('check mode!') + + def _global_matching(self, mk, qk, H, W): + # NE means number of elements -- typically T*H*W + mk = mk.flatten(start_dim=2) + qk = qk.flatten(start_dim=2) + B, CK, NE = mk.shape + + a = mk.pow(2).sum(1).unsqueeze(2) + b = 2 * (mk.transpose(1, 2) @ qk) + # We don't actually need this, will update paper later + # c = qk.pow(2).expand(B, -1, -1).sum(1).unsqueeze(1) + + affinity = (-a + b) / math.sqrt(CK) # B, NE, HW + # if self.km is not None: + # # Make a bunch of Gaussian distributions + # argmax_idx = affinity.max(2)[1] + # y_idx, x_idx = argmax_idx//W, argmax_idx%W + # g = make_gaussian(y_idx, x_idx, H, W, sigma=self.km) + # g = g.view(B, NE, H*W) + + # affinity = kmn(affinity, top=20, gauss=g) # B, THW, HW + affinity = softmax_w_top(affinity, top=self.top_k) # B, THW, HW + + return affinity + + def _readout(self, affinity, mv): + return torch.bmm(mv, affinity) + + def match_memory(self, qk): + k = self.num_objects + _, _, h, w = qk.shape + + qk = qk.flatten(start_dim=2) + + # use gt+last+mem + if self.temp_k is not None and self.is_compress and self.use_last and \ + self.use_gt: + # print("mode: gt+last+mem") + mk = torch.cat([self.mem_k, self.temp_k, self.gt_k, self.gt_k], 2) + # mv = torch.cat([self.mem_v, self.temp_v], 2) + try: + mv = torch.cat([self.mem_v, self.temp_v, self.gt_v, self.gt_v], + 2) + except Exception: + mv = 
torch.cat([ + self.mem_v, + self.temp_v.unsqueeze(0), + self.gt_v.unsqueeze(0), + self.gt_v.unsqueeze(0) + ], 3) + # use gt+last + elif self.temp_k is not None and self.use_last and self.use_gt: + # print("mode: gt+last") + mk = torch.cat([self.temp_k, self.gt_k, self.gt_k], 2) + # mv = torch.cat([self.mem_v, self.temp_v], 2) + try: + mv = torch.cat([self.temp_v, self.gt_v, self.gt_v], 2) + except Exception: + mv = torch.cat([ + self.temp_v.unsqueeze(0), + self.gt_v.unsqueeze(0), + self.gt_v.unsqueeze(0) + ], 3) + # use last+mem + elif self.temp_k is not None and self.is_compress and self.use_last: + # print("mode: last+mem") + mk = torch.cat([self.mem_k, self.temp_k], 2) + try: + mv = torch.cat([self.mem_v, self.temp_v], 2) + except Exception: + mv = torch.cat([self.mem_v, self.temp_v.unsqueeze(0)], 3) + # use gt+mem + elif self.is_compress and self.use_gt: + # print("mode: gt+mem") + # mk = self.mem_k + # mv = self.mem_v + mk = torch.cat([self.mem_k, self.gt_k], 2) + try: + mv = torch.cat([self.mem_v, self.gt_v], 2) + except Exception: + mv = torch.cat([self.mem_v, self.gt_v.unsqueeze(0)], 3) + # use last + elif self.temp_k is not None and self.use_last: + # print("mode: last") + mk = self.temp_k + mv = self.temp_v + # use gt or only use our embedding + else: + # print("mode: gt") + # use nothing + mk = self.mem_k + mv = self.mem_v + + affinity = self._global_matching(mk, qk, h, w) + if len(mv.shape) == 6: + mv = mv.squeeze(0) + mv = mv.flatten(start_dim=2) + + # One affinity for all + readout_mem = self._readout(affinity.expand(k, -1, -1), mv) + + return readout_mem.view(k, self.CV, h, w) + + def add_memory(self, key, value, is_temp=False): + # Temp is for "last frame" + # Not always used + # But can always be flushed + self.temp_k = None + self.temp_v = None + + if self.mem_k is None: + # First frame, just shove it in + self.mem_k = key # gt + self.mem_v = value + self.CK = key.shape[1] + self.CV = value.shape[1] + self.gt_k = key + self.gt_v = value + + elif 
self.is_compress: + # compress the two frames + if len(self.mem_v.shape) == 5: + self.mem_v = self.mem_v.unsqueeze(0) + k = torch.cat([self.mem_k, key], 2) # [1, 64, 2, 30, 57] + v = torch.cat([self.mem_v, value.unsqueeze(0)], + 3) # [1, 2, 512, 2, 30, 57] + self.mem_k, self.mem_v = self.compress(k, v) + + # check if use last frame + if self.use_last: + self.temp_k = key # [1, 64, 1, 30, 57] + self.temp_v = value # [2, 512, 1, 30, 57] + + elif self.stm: + # stm style + # print("stm", self.mem_k.shape) + self.mem_k = torch.cat([self.mem_k, key], 2) + self.mem_v = torch.cat([self.mem_v, value], 2) + + else: + # no compress + # check if use last frame + if self.use_last: + self.temp_k = key + self.temp_v = value diff --git a/modelscope/models/cv/video_object_segmentation/mod_resnet.py b/modelscope/models/cv/video_object_segmentation/mod_resnet.py new file mode 100644 index 00000000..5bc15500 --- /dev/null +++ b/modelscope/models/cv/video_object_segmentation/mod_resnet.py @@ -0,0 +1,229 @@ +# The implementation is modified from torchvision +# under BSD-3-Clause License +# https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py + +import math +from collections import OrderedDict + +import torch +import torch.nn as nn +from torch.utils import model_zoo + + +def load_weights_sequential(target, source_state, extra_chan=1): + + new_dict = OrderedDict() + + for k1, v1 in target.state_dict().items(): + if 'num_batches_tracked' not in k1: + if k1 in source_state: + tar_v = source_state[k1] + + if v1.shape != tar_v.shape: + # Init the new segmentation channel with zeros + # print(v1.shape, tar_v.shape) + c, _, w, h = v1.shape + pads = torch.zeros((c, extra_chan, w, h), + device=tar_v.device) + nn.init.orthogonal_(pads) + tar_v = torch.cat([tar_v, pads], 1) + + new_dict[k1] = tar_v + + target.load_state_dict(new_dict, strict=False) + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet50': 
'https://download.pytorch.org/models/resnet50-19c8e357.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1, bias=True): + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + bias=bias, + dilation=dilation) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + dilation=1, + bias=True): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3( + inplanes, planes, stride=stride, dilation=dilation, bias=bias) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3( + planes, planes, stride=1, dilation=dilation, bias=bias) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + dilation=1, + bias=True): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + dilation=dilation, + padding=dilation, + bias=bias) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=bias) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample 
is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers=(3, 4, 23, 3), extra_chan=1, bias=True): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d( + 3 + extra_chan, 64, kernel_size=7, stride=2, padding=3, bias=bias) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], bias=bias) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, bias=bias) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, bias=bias) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, bias=bias) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + try: + m.bias.data.zero_() + except Exception: + pass + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, + block, + planes, + blocks, + stride=1, + dilation=1, + bias=True): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=bias), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [ + block(self.inplanes, planes, stride, downsample, dilation, bias) + ] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block(self.inplanes, planes, dilation=dilation, bias=bias)) + + return nn.Sequential(*layers) + + +def resnet18(pretrained=True, extra_chan=0): + model = ResNet(BasicBlock, [2, 2, 2, 2], extra_chan=extra_chan) + if pretrained and torch.distributed.is_initialized(): + local_rank = torch.distributed.get_rank() + load_weights_sequential( + model, + 
def resnet50(pretrained=True, extra_chan=0):
    """Build the modified ResNet-50 trunk (bias-free convs, ``3 + extra_chan``
    input channels).

    Pretrained torchvision weights are fetched only when torch.distributed is
    initialized; each rank downloads into its own ``pretrain/resnet50-<rank>``
    directory, presumably to avoid concurrent-write races — TODO confirm.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], extra_chan=extra_chan, bias=False)

    if pretrained and torch.distributed.is_initialized():
        # NOTE(review): get_rank() returns the *global* rank despite the
        # variable name; here it is only used to build a per-rank cache dir.
        local_rank = torch.distributed.get_rank()
        load_weights_sequential(
            model,
            model_zoo.load_url(
                model_urls['resnet50'],
                model_dir='pretrain/resnet50-{}'.format(local_rank)),
            extra_chan)
        print(torch.distributed.get_rank(), 'resnet 50 is loading...')
    return model


@MODELS.register_module(
    Tasks.video_object_segmentation,
    module_name=Models.video_object_segmentation)
class VideoObjectSegmentation(TorchModel):
    """ModelScope wrapper around the RDE-VOS evaluation network.

    Loads the checkpoint found under ``model_dir`` into an ``RDE_VOS``
    instance and puts it in eval mode.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        # Weights are loaded on CPU; device placement is left to the caller.
        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        params = torch.load(model_path, map_location='cpu')
        self.model = RDE_VOS()
        self.model.load_state_dict(params, strict=True)
        self.model.eval()

    def forward(self, inputs: Dict[str, Any]):
        # Delegate directly to the wrapped network.
        return self.model(inputs)
class ResBlock(nn.Module):
    """Pre-activation residual block (ReLU precedes each conv).

    When ``indim != outdim`` the skip path is projected with a 3x3 conv
    (note: 3x3, not the 1x1 used by torchvision ResNets).
    """

    def __init__(self, indim, outdim=None):
        super(ResBlock, self).__init__()
        if outdim is None:
            outdim = indim
        if indim == outdim:
            self.downsample = None
        else:
            self.downsample = nn.Conv2d(
                indim, outdim, kernel_size=3, padding=1)

        self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1)

    def forward(self, x):
        r = self.conv1(F.relu(x))
        r = self.conv2(F.relu(r))

        if self.downsample is not None:
            x = self.downsample(x)

        return x + r


class FeatureFusionBlock(nn.Module):
    """Fuse a value feature with the key feature f16.

    Concatenates the two inputs, reduces channels with a ResBlock, applies
    CBAM attention as a residual, then refines with a second ResBlock.
    """

    def __init__(self, indim, outdim):
        super().__init__()

        self.block1 = ResBlock(indim, outdim)
        self.attention = cbam.CBAM(outdim)
        self.block2 = ResBlock(outdim, outdim)

    def forward(self, x, f16):
        x = torch.cat([x, f16], 1)
        x = self.block1(x)
        r = self.attention(x)
        x = self.block2(x + r)

        return x


# Single object version, used only in static image pretraining
# See model.py (load_network) for the modification procedure
class ValueEncoderSO(nn.Module):
    """Value encoder for the single-object setting.

    A modified ResNet-18 whose stem takes RGB + 1 mask channel (4 channels);
    its 1/16 feature is fused with the key encoder's f16 feature.
    """

    def __init__(self):
        super().__init__()

        resnet = mod_resnet.resnet18(pretrained=False, extra_chan=1)
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu  # 1/2, 64
        self.maxpool = resnet.maxpool

        self.layer1 = resnet.layer1  # 1/4, 64
        self.layer2 = resnet.layer2  # 1/8, 128
        self.layer3 = resnet.layer3  # 1/16, 256

        self.fuser = FeatureFusionBlock(1024 + 256, 512)

    def forward(self, image, key_f16, mask):
        # key_f16 is the feature from the key encoder

        f = torch.cat([image, mask], 1)

        x = self.conv1(f)
        x = self.bn1(x)
        x = self.relu(x)  # 1/2, 64
        x = self.maxpool(x)  # 1/4, 64
        x = self.layer1(x)  # 1/4, 64
        x = self.layer2(x)  # 1/8, 128
        x = self.layer3(x)  # 1/16, 256

        x = self.fuser(x, key_f16)

        return x


# Multiple objects version, used in other times
class ValueEncoder(nn.Module):
    """Value encoder for the multi-object setting.

    Same trunk as ValueEncoderSO, but the stem takes RGB + this object's
    mask + the union of the other objects' masks (5 channels).
    """

    def __init__(self):
        super().__init__()

        resnet = mod_resnet.resnet18(pretrained=False, extra_chan=2)
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu  # 1/2, 64
        self.maxpool = resnet.maxpool

        self.layer1 = resnet.layer1  # 1/4, 64
        self.layer2 = resnet.layer2  # 1/8, 128
        self.layer3 = resnet.layer3  # 1/16, 256

        self.fuser = FeatureFusionBlock(1024 + 256, 512)

    def forward(self, image, key_f16, mask, other_masks):
        # key_f16 is the feature from the key encoder

        f = torch.cat([image, mask, other_masks], 1)

        x = self.conv1(f)
        x = self.bn1(x)
        x = self.relu(x)  # 1/2, 64
        x = self.maxpool(x)  # 1/4, 64
        x = self.layer1(x)  # 1/4, 64
        x = self.layer2(x)  # 1/8, 128
        x = self.layer3(x)  # 1/16, 256
        x = self.fuser(x, key_f16)

        return x


class KeyEncoder(nn.Module):
    """Query/key image encoder: a torchvision ResNet-50 trunk (no fc head).

    Returns the 1/16, 1/8 and 1/4 feature maps used for memory matching and
    decoding. Weights are expected to be loaded from a checkpoint
    (pretrained=False here).
    """

    def __init__(self):
        super().__init__()
        resnet = models.resnet50(pretrained=False)
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu  # 1/2, 64
        self.maxpool = resnet.maxpool

        self.res2 = resnet.layer1  # 1/4, 256
        self.layer2 = resnet.layer2  # 1/8, 512
        self.layer3 = resnet.layer3  # 1/16, 1024

    def forward(self, f):
        x = self.conv1(f)
        x = self.bn1(x)
        x = self.relu(x)  # 1/2, 64
        x = self.maxpool(x)  # 1/4, 64
        f4 = self.res2(x)  # 1/4, 256
        f8 = self.layer2(f4)  # 1/8, 512
        f16 = self.layer3(f8)  # 1/16, 1024

        return f16, f8, f4


class UpsampleBlock(nn.Module):
    """Decoder stage: project the skip feature, add the bilinearly upsampled
    coarse feature, then refine with a ResBlock."""

    def __init__(self, skip_c, up_c, out_c, scale_factor=2):
        super().__init__()
        self.skip_conv = nn.Conv2d(skip_c, up_c, kernel_size=3, padding=1)
        self.out_conv = ResBlock(up_c, out_c)
        self.scale_factor = scale_factor

    def forward(self, skip_f, up_f):
        x = self.skip_conv(skip_f)
        x = x + F.interpolate(
            up_f,
            scale_factor=self.scale_factor,
            mode='bilinear',
            align_corners=False)
        x = self.out_conv(x)
        return x


class KeyProjection(nn.Module):
    """3x3 conv projecting backbone features into the key space
    (orthogonally initialized weight, zero bias)."""

    def __init__(self, indim, keydim):
        super().__init__()
        self.key_proj = nn.Conv2d(indim, keydim, kernel_size=3, padding=1)

        nn.init.orthogonal_(self.key_proj.weight.data)
        nn.init.zeros_(self.key_proj.bias.data)

    def forward(self, x):
        return self.key_proj(x)
class NONLocalBlock1D(_NonLocalBlockND):
    """1D specialization of the non-local block (defaults: sub-sampled,
    with a norm layer after W)."""

    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 sub_sample=True,
                 bn_layer=True):
        super(NONLocalBlock1D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=1,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    """2D specialization of the non-local block (defaults: sub-sampled,
    no norm layer after W)."""

    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 sub_sample=True,
                 bn_layer=False):
        super(NONLocalBlock2D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    """3D specialization of the non-local block, operating on
    (b, c, t, h, w) volumes."""

    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 sub_sample=True,
                 bn_layer=True):
        super(NONLocalBlock3D, self).__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=3,
            sub_sample=sub_sample,
            bn_layer=bn_layer)


class _ASPPModule3D(nn.Module):
    """One atrous 3D conv branch of ASPP3D (conv + ReLU, no norm)."""

    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        super(_ASPPModule3D, self).__init__()
        self.atrous_conv = nn.Conv3d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False)

    def forward(self, x):
        x = self.atrous_conv(x)
        return F.relu(x, inplace=True)


class ASPP3D(nn.Module):
    """Atrous spatial pyramid pooling over 3D features.

    Four parallel branches (1x1x1 plus three 1x3x3 convs with spatial
    dilations 2/4/6; the temporal axis is never dilated) are concatenated
    and fused by a final 1x3x3 conv.
    """

    def __init__(self, in_plane, out_plane, reduction=4):
        super().__init__()
        dilations = [1, 2, 4, 6]
        mid_plane = out_plane // reduction
        self.aspp1 = _ASPPModule3D(
            in_plane, mid_plane, 1, padding=0, dilation=dilations[0])
        self.aspp2 = _ASPPModule3D(
            in_plane,
            mid_plane, (1, 3, 3),
            padding=(0, dilations[1], dilations[1]),
            dilation=(1, dilations[1], dilations[1]))
        self.aspp3 = _ASPPModule3D(
            in_plane,
            mid_plane, (1, 3, 3),
            padding=(0, dilations[2], dilations[2]),
            dilation=(1, dilations[2], dilations[2]))
        self.aspp4 = _ASPPModule3D(
            in_plane,
            mid_plane, (1, 3, 3),
            padding=(0, dilations[3], dilations[3]),
            dilation=(1, dilations[3], dilations[3]))
        self.conv1 = nn.Conv3d(
            mid_plane * 4,
            out_plane,
            kernel_size=(1, 3, 3),
            padding=(0, 1, 1),
            bias=False)

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = F.relu(self.conv1(x), inplace=True)
        return x


class SELayerS(nn.Module):
    """Squeeze-and-excitation over channels *and* the 2-step time axis.

    The adaptive pool reduces (t, h, w) to (2, 1, 1), so the excitation MLP
    sees 2*C features and emits a per-(channel, time-step) gate.
    """

    def __init__(self, channel, reduction=16):
        super(SELayerS, self).__init__()
        channel = channel * 2  # 2 is time axis
        self.avg_pool = nn.AdaptiveAvgPool3d((2, 1, 1))
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False), nn.Sigmoid())

    def forward(self, x):
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x).view(b, 2 * c)
        y = self.fc(y).view(b, c, 2, 1, 1)
        return x * y.expand_as(x)


class SEBasicBlock(nn.Module):
    """Residual 3D block with SE attention (InstanceNorm3d, spatial-only
    1x3x3 convs)."""

    expansion = 1

    # https://github.com/moskomule/senet.pytorch

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 reduction=8):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv3d(
            inplanes, planes, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        self.in1 = nn.InstanceNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=(1, 3, 3), padding=(0, 1, 1))
        self.in2 = nn.InstanceNorm3d(planes)
        self.ses = SELayerS(planes, reduction)
        self.in3 = nn.InstanceNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.in1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.in2(out)

        out = self.ses(out)
        out = self.in3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SAM(nn.Module):
    """
    Spatio-temporal aggregation module (SAM)

    Pipeline: 3D non-local attention, then `repeat` SE residual blocks
    (if repeat > 0), then ASPP3D with a residual connection.
    """

    def __init__(self, indim, outdim=None, repeat=0, norm=False):
        super(SAM, self).__init__()
        self.indim = indim
        self.repeat = repeat
        if outdim is None:
            outdim = indim
        if repeat > 0:
            self.se_block = self.seRepeat(repeat)
        self.conv1 = ASPP3D(indim, outdim, reduction=4)  # norm is 4

        self.non_local = NONLocalBlock3D(indim, bn_layer=False)

    def seRepeat(self, repeat=2):
        # Stack `repeat` SE residual blocks into one sequential module.
        return nn.Sequential(*nn.ModuleList(
            [SEBasicBlock(self.indim, self.indim) for _ in range(repeat)]))

    def forward(self, x):
        x = self.non_local(x)

        if self.repeat > 0:
            x = self.se_block(x)
        r = x
        x = self.conv1(x) + r
        return x
class Decoder(nn.Module):
    """Mask decoder: 1/16 readout feature -> full-resolution logits.

    Compresses the 1024-ch readout to 512, upsamples through the 1/8 and 1/4
    skip features, predicts a single-channel logit map at 1/4 resolution and
    bilinearly upsamples it by 4x back to input resolution.
    """

    def __init__(self):
        super().__init__()
        self.compress = ResBlock(1024, 512)
        self.up_16_8 = UpsampleBlock(512, 512, 256)  # 1/16 -> 1/8
        self.up_8_4 = UpsampleBlock(256, 256, 256)  # 1/8 -> 1/4

        self.pred = nn.Conv2d(
            256, 1, kernel_size=(3, 3), padding=(1, 1), stride=1)

    def forward(self, f16, f8, f4):
        x = self.compress(f16)
        x = self.up_16_8(f8, x)
        x = self.up_8_4(f4, x)

        x = self.pred(F.relu(x))

        # 1/4 -> 1/1: recover the input resolution.
        x = F.interpolate(
            x, scale_factor=4, mode='bilinear', align_corners=False)
        return x
class RDE_VOS(nn.Module):
    """RDE-VOS training network (Recurrent Dynamic Embedding).

    Combines a ResNet-50 key encoder, a ResNet-18 value encoder (single- or
    two-object variant), a compressed memory bank (MemCrompress) and a mask
    decoder. `forward` is a mode dispatcher so DataParallel-style callers can
    invoke the individual stages.
    """

    def __init__(self, single_object, repeat=0, norm=False):
        super().__init__()
        self.single_object = single_object

        self.key_encoder = KeyEncoder()
        if single_object:
            self.value_encoder = ValueEncoderSO()
        else:
            self.value_encoder = ValueEncoder()
        # Compress memory bank
        self.mem_compress = MemCrompress(repeat=repeat, norm=norm)

        # Projection from f16 feature space to key space
        self.key_proj = KeyProjection(1024, keydim=64)

        # Compress f16 a bit to use in decoding later on
        self.key_comp = nn.Conv2d(1024, 512, kernel_size=3, padding=1)

        self.memory = MemoryReader()
        self.decoder = Decoder()

    def aggregate(self, prob):
        """Soft-aggregate per-object probabilities into (1+K)-channel logits.

        The background probability is the product of (1 - p) over objects;
        probabilities are clamped before the logit transform for stability.
        """
        # During training, torch.prod work on channel dimension.
        new_prob = torch.cat([torch.prod(1 - prob, dim=1, keepdim=True), prob],
                             1).clamp(1e-7, 1 - 1e-7)
        logits = torch.log((new_prob / (1 - new_prob)))
        return logits

    def encode_key(self, frame):
        """Encode a batch of frame sequences into keys and skip features."""
        # input: b*t*c*h*w
        b, t = frame.shape[:2]

        f16, f8, f4 = self.key_encoder(frame.flatten(start_dim=0, end_dim=1))
        k16 = self.key_proj(f16)
        f16_thin = self.key_comp(f16)

        # B*C*T*H*W
        k16 = k16.view(b, t, *k16.shape[-3:]).transpose(1, 2).contiguous()

        # B*T*C*H*W
        f16_thin = f16_thin.view(b, t, *f16_thin.shape[-3:])
        f16 = f16.view(b, t, *f16.shape[-3:])
        f8 = f8.view(b, t, *f8.shape[-3:])
        f4 = f4.view(b, t, *f4.shape[-3:])

        return k16, f16_thin, f16, f8, f4

    def encode_value(self, frame, kf16, mask, other_mask=None):
        # Extract memory key/value for a frame
        if self.single_object:
            f16 = self.value_encoder(frame, kf16, mask)
        else:
            f16 = self.value_encoder(frame, kf16, mask, other_mask)
        return f16.unsqueeze(2)  # B*512*T*H*W

    def segment(self, qk16, qv16, qf8, qf4, mk16, mv16, selector=None):
        """Segment the query frame against the memory.

        In the multi-object path exactly two objects are assumed
        (mv16[:, 0] / mv16[:, 1]); `selector` masks out absent objects.
        """
        # q - query, m - memory
        # qv16 is f16_thin above
        affinity = self.memory.get_affinity(mk16, qk16)

        if self.single_object:
            mv2qv = self.memory.readout(affinity, mv16, qv16)
            logits = self.decoder(mv2qv, qf8, qf4)
            prob = torch.sigmoid(logits)
        else:
            mv2qv_o1 = self.memory.readout(affinity, mv16[:, 0], qv16)
            mv2qv_o2 = self.memory.readout(affinity, mv16[:, 1], qv16)
            logits = torch.cat([
                self.decoder(mv2qv_o1, qf8, qf4),
                self.decoder(mv2qv_o2, qf8, qf4),
            ], 1)

            prob = torch.sigmoid(logits)
            prob = prob * selector.unsqueeze(2).unsqueeze(2)

        logits = self.aggregate(prob)
        prob = F.softmax(logits, dim=1)[:, 1:]

        if self.single_object:
            return logits, prob, mv2qv
        else:
            return logits, prob, mv2qv_o1, mv2qv_o2

    def memCrompress(self, key, value):
        # Compress the (T=2) memory bank into a single recurrent slot.
        return self.mem_compress(key, value)

    def forward(self, mode, *args, **kwargs):
        """Dispatch to a stage by name: 'encode_key', 'encode_value',
        'segment' or 'compress'."""
        if mode == 'encode_key':
            return self.encode_key(*args, **kwargs)
        elif mode == 'encode_value':
            return self.encode_value(*args, **kwargs)
        elif mode == 'segment':
            return self.segment(*args, **kwargs)
        elif mode == 'compress':
            return self.memCrompress(*args, **kwargs)
        else:
            raise NotImplementedError
def unpad(img, pad):
    """Strip F.pad-style padding from an NCHW tensor.

    Args:
        img: tensor of shape (N, C, H, W).
        pad: (left, right, top, bottom) padding amounts, as used by
            ``torch.nn.functional.pad``.

    Returns:
        The tensor with the padded border removed.

    Fix: the original sliced with ``img[:, :, pad[2]:-pad[3], :]``. When
    exactly one side of a dimension is padded (e.g. pad=(0, 0, 1, 0)),
    ``-pad[3]`` is ``-0 == 0`` and the slice ``x[1:0]`` is EMPTY, silently
    discarding the whole tensor. Using ``None`` as the stop index when the
    right/bottom pad is zero keeps the slice open-ended.
    """
    if pad[2] + pad[3] > 0:
        img = img[:, :, pad[2]:(-pad[3] or None), :]
    if pad[0] + pad[1] > 0:
        img = img[:, :, :, pad[0]:(-pad[1] or None)]
    return img


def all_to_onehot(masks, labels):
    """Convert an integer label map to a stack of binary one-hot masks.

    Args:
        masks: uint8 array of shape (T, H, W) or (H, W) with integer labels.
        labels: sequence of label values to extract (one output plane each).

    Returns:
        uint8 array of shape (len(labels), *masks.shape), where plane k is
        the indicator mask of ``labels[k]``.
    """
    if len(masks.shape) == 3:
        Ms = np.zeros(
            (len(labels), masks.shape[0], masks.shape[1], masks.shape[2]),
            dtype=np.uint8)
    else:
        Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]),
                      dtype=np.uint8)

    for k, lab in enumerate(labels):
        Ms[k] = (masks == lab).astype(np.uint8)

    return Ms
modelscope hub. + """ + + super().__init__(model=model, **kwargs) + logger.info('load model done') + self.im_transform = transforms.Compose([ + transforms.ToTensor(), + im_normalization, + transforms.Resize(480, interpolation=Image.BICUBIC), + ]) + self.mask_transform = transforms.Compose([ + transforms.Resize(480, interpolation=Image.NEAREST), + ]) + + def preprocess(self, input: Input) -> Dict[str, Any]: + + self.images = input['images'] + self.mask = input['mask'] + + frames = len(self.images) + shape = np.shape(self.mask) + + info = {} + info['name'] = 'maas_test_video' + info['frames'] = frames + info['size'] = shape # Real sizes + info['gt_obj'] = {} # Frames with labelled objects + + images = [] + masks = [] + for i in range(frames): + img = self.images[i] + images.append(self.im_transform(img)) + + palette = self.mask.getpalette() + masks.append(np.array(self.mask, dtype=np.uint8)) + this_labels = np.unique(masks[-1]) + this_labels = this_labels[this_labels != 0] + info['gt_obj'][i] = this_labels + + images = torch.stack(images, 0) + masks = np.stack(masks, 0) + + labels = np.unique(masks).astype(np.uint8) + labels = labels[labels != 0] + + masks = torch.from_numpy(all_to_onehot(masks, labels)).float() + # Resize to 480p + masks = self.mask_transform(masks) + masks = masks.unsqueeze(2) + + info['labels'] = labels + + result = { + 'rgb': images, + 'gt': masks, + 'info': info, + 'palette': np.array(palette), + } + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + rgb = input['rgb'].unsqueeze(0) + msk = input['gt'] + info = input['info'] + k = len(info['labels']) + size = info['size'] + + is_cuda = rgb.is_cuda + + processor = InferenceCore( + self.model.model, is_cuda, rgb, k, top_k=20, mem_every=4) + processor.interact(msk[:, 0], 0, rgb.shape[1]) + + # Do unpad -> upsample to original size + out_masks = torch.zeros((processor.t, 1, *size), + dtype=torch.uint8, + device='cuda' if is_cuda else 'cpu') + for ti in range(processor.t): 
+ prob = unpad(processor.prob[:, ti], processor.pad) + prob = F.interpolate( + prob, tuple(size), mode='bilinear', align_corners=False) + out_masks[ti] = torch.argmax(prob, dim=0) + + if is_cuda: + out_masks = (out_masks.detach().cpu().numpy()[:, + 0]).astype(np.uint8) + else: + out_masks = (out_masks.detach().numpy()[:, 0]).astype(np.uint8) + + return {OutputKeys.MASKS: out_masks} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index dc41794a..d527e4c9 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -89,6 +89,7 @@ class CVTasks(object): language_guided_video_summarization = 'language-guided-video-summarization' # video segmentation + video_object_segmentation = 'video-object-segmentation' referring_video_object_segmentation = 'referring-video-object-segmentation' video_human_matting = 'video-human-matting' diff --git a/modelscope/utils/cv/image_utils.py b/modelscope/utils/cv/image_utils.py index 0c72fe62..531889d2 100644 --- a/modelscope/utils/cv/image_utils.py +++ b/modelscope/utils/cv/image_utils.py @@ -1,8 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
def masks_visualization(masks, palette):
    """Convert per-frame label maps into palettized PIL images.

    Args:
        masks: uint8 array of shape [frame_num, height, width].
        palette: flat color palette list (as returned by Image.getpalette).

    Returns:
        A list of 'P'-mode PIL images, one per frame, sharing `palette`.
    """
    vis_masks = []
    for f in range(masks.shape[0]):
        img_E = Image.fromarray(masks[f])
        img_E.putpalette(palette)
        vis_masks.append(img_E)
    return vis_masks


class VideoObjectSegmentationTest(unittest.TestCase, DemoCompatibilityCheck):
    """Smoke test: run the RDE-VOS pipeline on the bundled demo clip and
    save the visualized masks."""

    def setUp(self) -> None:
        self.task = 'video-object-segmentation'
        self.model_id = 'damo/cv_rdevos_video-object-segmentation'
        # Demo data: JPEG frames plus a first-frame palettized annotation.
        self.input_location = 'data/test/videos/video_object_segmentation_test'
        self.images_dir = os.path.join(self.input_location, 'JPEGImages')
        self.mask_file = os.path.join(self.input_location, 'Annotations',
                                      '00000.png')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_video_object_segmentation(self):
        # Frames must be RGB; the annotation stays in 'P' mode so its
        # palette and integer labels are preserved.
        input_images = []
        for image_file in sorted(os.listdir(self.images_dir)):
            img = Image.open(os.path.join(self.images_dir, image_file))\
                .convert('RGB')
            input_images.append(img)
        mask = Image.open(self.mask_file).convert('P')
        input = {'images': input_images, 'mask': mask}

        segmentor = pipeline(
            Tasks.video_object_segmentation, model=self.model_id)
        result = segmentor(input)
        out_masks = result[OutputKeys.MASKS]

        # Re-apply the annotation's palette for human-viewable output.
        vis_masks = masks_visualization(out_masks, mask.getpalette())

        os.makedirs('test_result', exist_ok=True)
        for f, vis_mask in enumerate(vis_masks):
            vis_mask.save(os.path.join('test_result', '{:05d}.png'.format(f)))

        print('test_video_object_segmentation DONE')


if __name__ == '__main__':
    unittest.main()