diff --git a/data/test/images/vidt_test1.jpg b/data/test/images/vidt_test1.jpg new file mode 100644 index 00000000..6f4bc051 --- /dev/null +++ b/data/test/images/vidt_test1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7e87ea289bc59863ed81129d5991ede97bf5335c173ab9f36e4e4cfdc858e41 +size 120137 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 48a6330f..ca7c6162 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -83,6 +83,7 @@ class Models(object): video_deinterlace = 'video-deinterlace' quadtree_attention_image_matching = 'quadtree-attention-image-matching' vision_middleware = 'vision-middleware' + vidt = 'vidt' video_stabilization = 'video-stabilization' real_basicvsr = 'real-basicvsr' rcp_sceneflow_estimation = 'rcp-sceneflow-estimation' @@ -361,6 +362,7 @@ class Pipelines(object): image_skychange = 'image-skychange' video_human_matting = 'video-human-matting' vision_middleware_multi_task = 'vision-middleware-multi-task' + vidt = 'vidt' video_frame_interpolation = 'video-frame-interpolation' video_object_segmentation = 'video-object-segmentation' video_deinterlace = 'video-deinterlace' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index 2f5f689f..782d25f1 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -25,7 +25,7 @@ from . 
import (action_recognition, animal_recognition, bad_image_detecting, table_recognition, video_deinterlace, video_frame_interpolation, video_object_segmentation, video_panoptic_segmentation, video_single_object_tracking, video_stabilization, - video_summarization, video_super_resolution, virual_tryon, + video_summarization, video_super_resolution, vidt, virual_tryon, vision_middleware, vop_retrieval) # yapf: enable diff --git a/modelscope/models/cv/vidt/__init__.py b/modelscope/models/cv/vidt/__init__.py new file mode 100644 index 00000000..785d0274 --- /dev/null +++ b/modelscope/models/cv/vidt/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .model import VidtModel +else: + _import_structure = { + 'model': ['VidtModel'], + } + import sys + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/vidt/backbone.py b/modelscope/models/cv/vidt/backbone.py new file mode 100644 index 00000000..198ab498 --- /dev/null +++ b/modelscope/models/cv/vidt/backbone.py @@ -0,0 +1,1061 @@ +# The implementation here is modified based on timm, +# originally Apache 2.0 License and publicly available at +# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/swin_w_ram.py + +import math +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = 
nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def masked_sin_pos_encoding(x, + mask, + num_pos_feats, + temperature=10000, + scale=2 * math.pi): + """ Masked Sinusoidal Positional Encoding + + Args: + x: [PATCH] tokens + mask: the padding mask for [PATCH] tokens + num_pos_feats: the size of channel dimension + temperature: the temperature value + scale: the normalization scale + + Returns: + pos: Sinusoidal positional encodings + """ + + num_pos_feats = num_pos_feats // 2 + not_mask = ~mask + + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale + + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3) + + return pos + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): 
Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class ReconfiguredAttentionModule(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias -> extended with RAM. + It supports both of shifted and non-shifted window. + + !!!!!!!!!!! IMPORTANT !!!!!!!!!!! + The original attention module in Swin is replaced with the reconfigured attention module in Section 3. + All the Args are shared, so only the forward function is modified. + See https://arxiv.org/pdf/2110.03921.pdf + !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, + x, + det, + mask=None, + cross_attn=False, + cross_attn_mask=None): + """ Forward function. 
+ RAM module receives [Patch] and [DET] tokens and returns their calibrated ones + + Args: + x: [PATCH] tokens + det: [DET] tokens + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None -> mask for shifted window attention + + "additional inputs for RAM" + cross_attn: whether to use cross-attention [det x patch] (for selective cross-attention) + cross_attn_mask: mask for cross-attention + + Returns: + patch_x: the calibrated [PATCH] tokens + det_x: the calibrated [DET] tokens + """ + + assert self.window_size[0] == self.window_size[1] + window_size = self.window_size[0] + local_map_size = window_size * window_size + + # projection before window partitioning + if not cross_attn: + B, H, W, C = x.shape + N = H * W + x = x.view(B, N, C) + x = torch.cat([x, det], dim=1) + full_qkv = self.qkv(x) + patch_qkv, det_qkv = full_qkv[:, :N, :], full_qkv[:, N:, :] + else: + B, H, W, C = x[0].shape + N = H * W + _, ori_H, ori_W, _ = x[1].shape + ori_N = ori_H * ori_W + + shifted_x = x[0].view(B, N, C) + cross_x = x[1].view(B, ori_N, C) + x = torch.cat([shifted_x, cross_x, det], dim=1) + full_qkv = self.qkv(x) + patch_qkv, cross_patch_qkv, det_qkv = \ + full_qkv[:, :N, :], full_qkv[:, N:N + ori_N, :], full_qkv[:, N + ori_N:, :] + patch_qkv = patch_qkv.view(B, H, W, -1) + + # window partitioning for [PATCH] tokens + patch_qkv = window_partition( + patch_qkv, window_size) # nW*B, window_size, window_size, C + B_ = patch_qkv.shape[0] + patch_qkv = patch_qkv.reshape(B_, window_size * window_size, 3, + self.num_heads, C // self.num_heads) + _patch_qkv = patch_qkv.permute(2, 0, 3, 1, 4) + patch_q, patch_k, patch_v = _patch_qkv[0], _patch_qkv[1], _patch_qkv[2] + + # [PATCH x PATCH] self-attention using window partitions + patch_q = patch_q * self.scale + patch_attn = (patch_q @ patch_k.transpose(-2, -1)) + # add relative pos bias for [patch x patch] self-attention + relative_position_bias = self.relative_position_bias_table[ + 
self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + patch_attn = patch_attn + relative_position_bias.unsqueeze(0) + + # if shifted window is used, it needs to apply the mask + if mask is not None: + nW = mask.shape[0] + tmp0 = patch_attn.view(B_ // nW, nW, self.num_heads, + local_map_size, local_map_size) + tmp1 = mask.unsqueeze(1).unsqueeze(0) + patch_attn = tmp0 + tmp1 + patch_attn = patch_attn.view(-1, self.num_heads, local_map_size, + local_map_size) + + patch_attn = self.softmax(patch_attn) + patch_attn = self.attn_drop(patch_attn) + patch_x = (patch_attn @ patch_v).transpose(1, 2).reshape( + B_, window_size, window_size, C) + + # extract qkv for [DET] tokens + det_qkv = det_qkv.view(B, -1, 3, self.num_heads, C // self.num_heads) + det_qkv = det_qkv.permute(2, 0, 3, 1, 4) + det_q, det_k, det_v = det_qkv[0], det_qkv[1], det_qkv[2] + + # if cross-attention is activated + if cross_attn: + + # reconstruct the spatial form of [PATCH] tokens for global [DET x PATCH] attention + cross_patch_qkv = cross_patch_qkv.view(B, ori_H, ori_W, 3, + self.num_heads, + C // self.num_heads) + patch_kv = cross_patch_qkv[:, :, :, + 1:, :, :].permute(3, 0, 4, 1, 2, + 5).contiguous() + patch_kv = patch_kv.view(2, B, self.num_heads, ori_H * ori_W, -1) + + # extract "key and value" of [PATCH] tokens for cross-attention + cross_patch_k, cross_patch_v = patch_kv[0], patch_kv[1] + + # bind key and value of [PATCH] and [DET] tokens for [DET X [PATCH, DET]] attention + det_k, det_v = torch.cat([cross_patch_k, det_k], + dim=2), torch.cat([cross_patch_v, det_v], + dim=2) + + # [DET x DET] self-attention or binded [DET x [PATCH, DET]] attention + det_q = det_q * self.scale + det_attn = (det_q @ det_k.transpose(-2, -1)) + # apply cross-attention mask if available + if cross_attn_mask is 
not None: + det_attn = det_attn + cross_attn_mask + det_attn = self.softmax(det_attn) + det_attn = self.attn_drop(det_attn) + det_x = (det_attn @ det_v).transpose(1, 2).reshape(B, -1, C) + + # reverse window for [PATCH] tokens <- the output of [PATCH x PATCH] self attention + patch_x = window_reverse(patch_x, window_size, H, W) + + # projection for outputs from multi-head + x = torch.cat([patch_x.view(B, H * W, C), det_x], dim=1) + x = self.proj(x) + x = self.proj_drop(x) + + # decompose after FFN into [PATCH] and [DET] tokens + patch_x = x[:, :H * W, :].view(B, H, W, C) + det_x = x[:, H * W:, :] + + return patch_x, det_x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = ReconfiguredAttentionModule( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix, pos, cross_attn, cross_attn_mask): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W + DET, C). i.e., binded [PATCH, DET] tokens + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. 
+ + "additional inputs' + pos: (patch_pos, det_pos) + cross_attn: whether to use cross attn [det x [det + patch]] + cross_attn_mask: attention mask for cross-attention + + Returns: + x: calibrated & binded [PATCH, DET] tokens + """ + + B, L, C = x.shape + H, W = self.H, self.W + + assert L == H * W + self.det_token_num, 'input feature has wrong size' + + shortcut = x + x = self.norm1(x) + x, det = x[:, :H * W, :], x[:, H * W:, :] + x = x.view(B, H, W, C) + orig_x = x + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # projection for det positional encodings: make the channel size suitable for the current layer + patch_pos, det_pos = pos + det_pos = self.det_pos_linear(det_pos) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # prepare cross-attn and add positional encodings + if cross_attn: + # patch token (for cross-attention) + Sinusoidal pos encoding + cross_patch = orig_x + patch_pos + # det token + learnable pos encoding + det = det + det_pos + shifted_x = (shifted_x, cross_patch) + else: + # it cross_attn is deativated, only [PATCH] and [DET] self-attention are performed + det = det + det_pos + shifted_x = shifted_x + + # W-MSA/SW-MSA + shifted_x, det = self.attn( + shifted_x, + mask=attn_mask, + # additional args + det=det, + cross_attn=cross_attn, + cross_attn_mask=cross_attn_mask) + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + x = torch.cat([x, det], 
dim=1) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm, expand=True): + super().__init__() + self.dim = dim + + # if expand is True, the channel size will be expanded, otherwise, return 256 size of channel + expand_dim = 2 * dim if expand else 256 + self.reduction = nn.Linear(4 * dim, expand_dim, bias=False) + self.norm = norm_layer(4 * dim) + + # added for detection token [please ignore, not used for training] + # not implemented yet. + self.expansion = nn.Linear(dim, expand_dim, bias=False) + self.norm2 = norm_layer(dim) + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C), i.e., binded [PATCH, DET] tokens + H, W: Spatial resolution of the input feature. + + Returns: + x: merged [PATCH, DET] tokens; + only [PATCH] tokens are reduced in spatial dim, while [DET] tokens is fix-scale + """ + + B, L, C = x.shape + assert L == H * W + self.det_token_num, 'input feature has wrong size' + + x, det = x[:, :H * W, :], x[:, H * W:, :] + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + # simply repeating for DET tokens + det = det.repeat(1, 1, 4) + + x = torch.cat([x, det], dim=1) + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. 
+ + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + last=False, + use_checkpoint=False): + + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.dim = dim + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + dim=dim, norm_layer=norm_layer, expand=(not last)) + else: + self.downsample = None + + def forward(self, x, H, W, det_pos, input_mask, cross_attn=False): + 
""" Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + det_pos: pos encoding for det token + input_mask: padding mask for inputs + cross_attn: whether to use cross attn [det x [det + patch]] + """ + + B = x.shape[0] + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # mask for cyclic shift + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + + # compute sinusoidal pos encoding and cross-attn mask here to avoid redundant computation + if cross_attn: + + _H, _W = input_mask.shape[1:] + if not (_H == H and _W == W): + input_mask = F.interpolate( + input_mask[None].float(), size=(H, W)).to(torch.bool)[0] + + # sinusoidal pos encoding for [PATCH] tokens used in cross-attention + patch_pos = masked_sin_pos_encoding(x, input_mask, self.dim) + + # attention padding mask due to the zero padding in inputs + # the zero (padded) area is masked by 1.0 in 'input_mask' + cross_attn_mask = input_mask.float() + cross_attn_mask = cross_attn_mask.masked_fill(cross_attn_mask != 0.0, float(-100.0)). 
\ + masked_fill(cross_attn_mask == 0.0, float(0.0)) + + # pad for detection token (this padding is required to process the binded [PATCH, DET] attention + cross_attn_mask = cross_attn_mask.view( + B, H * W).unsqueeze(1).unsqueeze(2) + cross_attn_mask = F.pad( + cross_attn_mask, (0, self.det_token_num), value=0) + + else: + patch_pos = None + cross_attn_mask = None + + # zip pos encodings + pos = (patch_pos, det_pos) + + for n_blk, blk in enumerate(self.blocks): + blk.H, blk.W = H, W + + # for selective cross-attention + if cross_attn: + _cross_attn = True + _cross_attn_mask = cross_attn_mask + _pos = pos # i.e., (patch_pos, det_pos) + else: + _cross_attn = False + _cross_attn_mask = None + _pos = (None, det_pos) + + if self.use_checkpoint: + x = checkpoint.checkpoint( + blk, + x, + attn_mask, + # additional inputs + pos=_pos, + cross_attn=_cross_attn, + cross_attn_mask=_cross_attn_mask) + else: + x = blk( + x, + attn_mask, + # additional inputs + pos=_pos, + cross_attn=_cross_attn, + cross_attn_mask=_cross_attn_mask) + + # reduce the number of patch tokens, but maintaining a fixed-scale det tokens + # meanwhile, the channel dim increases by a factor of 2 + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, + (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. 
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any args. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=[1, 2, + 3], # not used in the current version, please ignore. + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1] + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], + patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in 
torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + # modified by ViDT + downsample=PatchMerging if + (i_layer < self.num_layers) else None, + last=None if (i_layer < self.num_layers - 1) else True, + # + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + # Not used in the current version -> please ignore. this error will be fixed later + # we leave this lines to load the pre-trained model ... 
+ for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return {'det_pos_embed', 'det_token'} + + def finetune_det(self, + method, + det_token_num=100, + pos_dim=256, + cross_indices=[3]): + """ A funtion to add neccessary (leanable) variables to Swin Transformer for object detection + + Args: + method: vidt or vidt_wo_neck + det_token_num: the number of object to detect, i.e., number of object queries + pos_dim: the channel dimension of positional encodings for [DET] and [PATCH] tokens + cross_indices: the indices where to use the [DET X PATCH] cross-attention + there are four possible stages in [0, 1, 2, 3]. 3 indicates Stage 4 in the ViDT paper. + """ + + # which method? + self.method = method + + # how many object we detect? 
+ self.det_token_num = det_token_num + self.det_token = nn.Parameter( + torch.zeros(1, det_token_num, self.num_features[0])) + self.det_token = trunc_normal_(self.det_token, std=.02) + + # dim size of pos encoding + self.pos_dim = pos_dim + + # learnable positional encoding for detection tokens + det_pos_embed = torch.zeros(1, det_token_num, pos_dim) + det_pos_embed = trunc_normal_(det_pos_embed, std=.02) + self.det_pos_embed = torch.nn.Parameter(det_pos_embed) + + # info for detection + self.num_channels = [ + self.num_features[i + 1] + for i in range(len(self.num_features) - 1) + ] + if method == 'vidt': + self.num_channels.append( + self.pos_dim) # default: 256 (same to the default pos_dim) + self.cross_indices = cross_indices + # divisor to reduce the spatial size of the mask + self.mask_divisor = 2**(len(self.layers) - len(self.cross_indices)) + + # projection matrix for det pos encoding in each Swin layer (there are 4 blocks) + for layer in self.layers: + layer.det_token_num = det_token_num + if layer.downsample is not None: + layer.downsample.det_token_num = det_token_num + for block in layer.blocks: + block.det_token_num = det_token_num + block.det_pos_linear = nn.Linear(pos_dim, block.dim) + + # neck-free model do not require downsamling at the last stage. + if method == 'vidt_wo_neck': + self.layers[-1].downsample = None + + def forward(self, x, mask): + """ Forward function. + + Args: + x: input rgb images + mask: input padding masks [0: rgb values, 1: padded values] + + Returns: + patch_outs: multi-scale [PATCH] tokens (four scales are used) + these tokens are the first input of the neck decoder + det_tgt: final [DET] tokens obtained at the last stage + this tokens are the second input of the neck decoder + det_pos: the learnable pos encoding for [DET] tokens. 
+ these encodings are used to generate reference points in deformable attention + """ + + # original input shape + B, _, _ = x.shape[0], x.shape[2], x.shape[3] + + # patch embedding + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + # expand det_token for all examples in the batch + det_token = self.det_token.expand(B, -1, -1) + + # det pos encoding -> will be projected in each block + det_pos = self.det_pos_embed + + # prepare a mask for cross attention + mask = F.interpolate( + mask[None].float(), + size=(Wh // self.mask_divisor, + Ww // self.mask_divisor)).to(torch.bool)[0] + + patch_outs = [] + for stage in range(self.num_layers): + layer = self.layers[stage] + + # whether to use cross-attention + cross_attn = True if stage in self.cross_indices else False + + # concat input + x = torch.cat([x, det_token], dim=1) + + # inference + x_out, H, W, x, Wh, Ww = layer( + x, + Wh, + Ww, + # additional input for VIDT + input_mask=mask, + det_pos=det_pos, + cross_attn=cross_attn) + + x, det_token = x[:, :-self.det_token_num, :], x[:, -self. 
+ det_token_num:, :] + + # Aggregate intermediate outputs + if stage > 0: + patch_out = x_out[:, :-self.det_token_num, :].view( + B, H, W, -1).permute(0, 3, 1, 2) + patch_outs.append(patch_out) + + # patch token reduced from last stage output + patch_outs.append(x.view(B, Wh, Ww, -1).permute(0, 3, 1, 2)) + + # det token + det_tgt = x_out[:, -self.det_token_num:, :].permute(0, 2, 1) + + # det token pos encoding + det_pos = det_pos.permute(0, 2, 1) + + features_0, features_1, features_2, features_3 = patch_outs + return features_0, features_1, features_2, features_3, det_tgt, det_pos + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + # not working in the current version + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[ + 0] * self.patches_resolution[1] // (2**self.num_layers) + flops += self.num_features * self.num_classes + return flops diff --git a/modelscope/models/cv/vidt/deformable_transformer.py b/modelscope/models/cv/vidt/deformable_transformer.py new file mode 100644 index 00000000..7344ce5d --- /dev/null +++ b/modelscope/models/cv/vidt/deformable_transformer.py @@ -0,0 +1,616 @@ +# The implementation here is modified based on timm, +# originally Apache 2.0 License and publicly available at +# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/deformable_transformer.py + +import copy +import math +import warnings + +import torch +import torch.nn.functional as F +from timm.models.layers import DropPath +from torch import nn +from torch.nn.init import constant_, normal_, xavier_uniform_ + + +class DeformableTransformer(nn.Module): + """ A Deformable Transformer for the neck in a detector + + The transformer encoder is completely removed for ViDT + Args: + d_model: the channel dimension for 
attention [default=256] + nhead: the number of heads [default=8] + num_decoder_layers: the number of decoding layers [default=6] + dim_feedforward: the channel dim of point-wise FFNs [default=1024] + dropout: the degree of dropout used in FFNs [default=0.1] + activation: An activation function to use [default='relu'] + return_intermediate_dec: whether to return all the indermediate outputs [default=True] + num_feature_levels: the number of scales for extracted features [default=4] + dec_n_points: the number of reference points for deformable attention [default=4] + drop_path: the ratio of stochastic depth for decoding layers [default=0.0] + token_label: whether to use the token label loss for training [default=False]. This is an additional trick + proposed in https://openreview.net/forum?id=LhbD74dsZFL (ICLR'22) for further improvement + """ + + def __init__(self, + d_model=256, + nhead=8, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0.1, + activation='relu', + return_intermediate_dec=True, + num_feature_levels=4, + dec_n_points=4, + drop_path=0., + token_label=False): + super().__init__() + + self.d_model = d_model + self.nhead = nhead + decoder_layer = DeformableTransformerDecoderLayer( + d_model, + dim_feedforward, + dropout, + activation, + num_feature_levels, + nhead, + dec_n_points, + drop_path=drop_path) + self.decoder = DeformableTransformerDecoder(decoder_layer, + num_decoder_layers, + return_intermediate_dec) + + self.level_embed = nn.Parameter( + torch.Tensor(num_feature_levels, d_model)) + self.token_label = token_label + + self.reference_points = nn.Linear(d_model, 2) + + if self.token_label: + self.enc_output = nn.Linear(d_model, d_model) + self.enc_output_norm = nn.LayerNorm(d_model) + + self.token_embed = nn.Linear(d_model, 91) + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.token_embed.bias.data = torch.ones(91) * bias_value + + self._reset_parameters() + + def _reset_parameters(self): + for p in 
self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + + normal_(self.level_embed) + + def get_proposal_pos_embed(self, proposals): + num_pos_feats = 128 + temperature = 10000 + scale = 2 * math.pi + + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), + dim=4).flatten(2) + return pos + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, + spatial_shapes): + N_, S_, C_ = memory.shape + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view( + N_, H_, W_, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace( + 0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace( + 0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), + valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += (H_ * W_) + output_proposals = torch.cat(proposals, 1) + tmp = (output_proposals > 0.01) & (output_proposals < 0.99) + output_proposals_valid = tmp.all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill( + memory_padding_mask.unsqueeze(-1), float('inf')) + 
output_proposals = output_proposals.masked_fill( + ~output_proposals_valid, float('inf')) + + output_memory = memory + output_memory = output_memory.masked_fill( + memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, + float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward(self, srcs, masks, tgt, query_pos): + """ The forward step of the decoder + + Args: + srcs: [Patch] tokens + masks: input padding mask + tgt: [DET] tokens + query_pos: [DET] token pos encodings + + Returns: + hs: calibrated [DET] tokens + init_reference_out: init reference points + inter_references_out: intermediate reference points for box refinement + enc_token_class_unflat: info. 
for token labeling + """ + + # prepare input for the Transformer decoder + src_flatten = [] + mask_flatten = [] + spatial_shapes = [] + for lvl, (src, mask) in enumerate(zip(srcs, masks)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + src = src.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=src_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + memory = src_flatten + bs, _, c = memory.shape + tgt = tgt # [DET] tokens + query_pos = query_pos.expand(bs, -1, -1) # [DET] token pos encodings + + # prepare input for token label + if self.token_label: + output_memory, output_proposals = self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes) + enc_token_class_unflat = None + if self.token_label: + enc_token_class = self.token_embed(output_memory) + enc_token_class_unflat = [] + for st, (h, w) in zip(level_start_index, spatial_shapes): + enc_token_class_unflat.append( + enc_token_class[:, st:st + h * w, :].view(bs, h, w, 91)) + + # reference points for deformable attention + reference_points = self.reference_points(query_pos).sigmoid() + init_reference_out = reference_points # query_pos -> reference point + + # decoder + hs, inter_references = self.decoder(tgt, reference_points, memory, + spatial_shapes, level_start_index, + valid_ratios, query_pos, + mask_flatten) + + inter_references_out = inter_references + + return hs, init_reference_out, inter_references_out, enc_token_class_unflat + + +class DeformableTransformerDecoderLayer(nn.Module): + """ A decoder layer. 
+ + Args: + d_model: the channel dimension for attention [default=256] + d_ffn: the channel dim of point-wise FFNs [default=1024] + dropout: the degree of dropout used in FFNs [default=0.1] + activation: An activation function to use [default='relu'] + n_levels: the number of scales for extracted features [default=4] + n_heads: the number of heads [default=8] + n_points: the number of reference points for deformable attention [default=4] + drop_path: the ratio of stochastic depth for decoding layers [default=0.0] + """ + + def __init__(self, + d_model=256, + d_ffn=1024, + dropout=0.1, + activation='relu', + n_levels=4, + n_heads=8, + n_points=4, + drop_path=0.): + super().__init__() + + # [DET x PATCH] deformable cross-attention + self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # [DET x DET] self-attention + self.self_attn = nn.MultiheadAttention( + d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn for multi-heaed + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + # stochastic depth + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else None + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward(self, + tgt, + query_pos, + reference_points, + src, + src_spatial_shapes, + level_start_index, + src_padding_mask=None): + + # [DET] self-attention + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q.transpose(0, 1), k.transpose(0, 1), + tgt.transpose(0, 1))[0].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # Multi-scale deformable cross-attention in Eq. (1) in the ViDT paper + tgt2 = self.cross_attn( + self.with_pos_embed(tgt, query_pos), reference_points, src, + src_spatial_shapes, level_start_index, src_padding_mask) + + if self.drop_path is None: + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + # ffn + tgt = self.forward_ffn(tgt) + else: + tgt = tgt + self.drop_path(self.dropout1(tgt2)) + tgt2 = self.linear2( + self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.drop_path(self.dropout4(tgt2)) + tgt = self.norm3(tgt) + + return tgt + + +class DeformableTransformerDecoder(nn.Module): + """ A Decoder consisting of multiple layers + + Args: + decoder_layer: a deformable decoding layer + num_layers: the number of layers + return_intermediate: whether to return intermediate resutls + """ + + def __init__(self, decoder_layer, num_layers, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + # hack implementation for iterative bounding box refinement + self.bbox_embed = None + self.class_embed = None + + def forward(self, + tgt, + reference_points, + src, + src_spatial_shapes, + src_level_start_index, + src_valid_ratios, + query_pos=None, + 
src_padding_mask=None): + """ The forwared step of the Deformable Decoder + + Args: + tgt: [DET] tokens + reference_points: reference points for deformable attention + src: the [PATCH] tokens fattened into a 1-d sequence + src_spatial_shapes: the spatial shape of each multi-scale feature map + src_level_start_index: the start index to refer different scale inputs + src_valid_ratios: the ratio of multi-scale feature maps + query_pos: the pos encoding for [DET] tokens + src_padding_mask: the input padding mask + + Returns: + output: [DET] tokens calibrated (i.e., object embeddings) + reference_points: A reference points + + If return_intermediate = True, output & reference_points are returned from all decoding layers + """ + + output = tgt + intermediate = [] + intermediate_reference_points = [] + + # iterative bounding box refinement (handling the [DET] tokens produced from Swin with RAM) + if self.bbox_embed is not None: + tmp = self.bbox_embed[0](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[ + ..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + # + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + tmp0 = reference_points[:, :, None] + tmp1 = torch.cat([src_valid_ratios, src_valid_ratios], + -1)[:, None] + reference_points_input = tmp0 * tmp1 + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, + None] * src_valid_ratios[:, + None] + + # deformable operation + output = layer(output, query_pos, reference_points_input, src, + 
src_spatial_shapes, src_level_start_index, + src_padding_mask) + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[lid + 1](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid( + reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[ + ..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + # + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points) + + return output, reference_points + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + + if activation == 'relu': + return F.relu + if activation == 'gelu': + return F.gelu + if activation == 'glu': + return F.glu + raise RuntimeError(F'activation should be relu/gelu, not {activation}.') + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, + sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], + dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape( + N_ * M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> 
N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, + lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, + sampling_grid_l_, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape( + N_ * M_, 1, Lq_, L_ * P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) + * attention_weights).sum(-1).view(N_, M_ * D_, Lq_) + return output.transpose(1, 2).contiguous() + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError( + 'invalid input for _is_power_of_2: {} (type: {})'.format( + n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError( + 'd_model must be divisible by n_heads, but got {} and {}'. 
+ format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn( + "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + 'which is more efficient in our CUDA implementation.') + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, + n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, + n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.) + thetas = torch.arange( + self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init + / grid_init.abs().max(-1, keepdim=True)[0]).view( + self.n_heads, 1, 1, 2).repeat(1, self.n_levels, + self.n_points, 1) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.) + constant_(self.attention_weights.bias.data, 0.) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) 
+ + def forward(self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2) + :param input_flatten (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ) + :param input_padding_mask (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l) + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] + * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, + self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view( + N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + # attn weights for each sampled query. 
+ attention_weights = self.attention_weights(query).view( + N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, + -1).view(N, Len_q, self.n_heads, + self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], + -1) + tmp0 = reference_points[:, :, None, :, None, :] + tmp1 = sampling_offsets / offset_normalizer[None, None, None, :, + None, :] + sampling_locations = tmp0 + tmp1 + elif reference_points.shape[-1] == 4: + tmp0 = reference_points[:, :, None, :, None, :2] + tmp1 = sampling_offsets / self.n_points * reference_points[:, :, + None, :, + None, + 2:] * 0.5 + sampling_locations = tmp0 + tmp1 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.' + .format(reference_points.shape[-1])) + output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, + sampling_locations, + attention_weights) + output = self.output_proj(output) + + return output + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) diff --git a/modelscope/models/cv/vidt/fpn_fusion.py b/modelscope/models/cv/vidt/fpn_fusion.py new file mode 100644 index 00000000..b48ba0fe --- /dev/null +++ b/modelscope/models/cv/vidt/fpn_fusion.py @@ -0,0 +1,248 @@ +# The implementation here is modified based on timm, +# originally Apache 2.0 License and publicly available at +# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/fpn_fusion.py + +import torch.nn as nn + + +class FPNFusionModule(nn.Module): + """ This is a fpn-style cross-scale feature fusion module" """ + + def __init__(self, embed_dims, fuse_dim=256, n_block=4, use_bn=False): + super().__init__() + """ Initializes the model. 
+ Args: + embed_dims: the list of channel dim for different scale feature maps (i.e., the input) + fuse_dim: the channel dim of the fused feature map (i.e., the output) + n_block: the number of multi-scale features (default=4) + use_bn: whether to use bn + """ + + self.embed_dims = embed_dims + self.fuse_dim = fuse_dim + self.n_block = n_block + + # cross-scale fusion layers + self.multi_scaler = _make_multi_scale_layers( + embed_dims, fuse_dim, use_bn=use_bn, n_block=n_block) + + def forward(self, x_blocks): + + x_blocks = x_blocks + + # preperation: channel reduction and normalization + for idx in range(self.n_block - 1, -1, -1): + x_blocks[idx] = getattr(self.multi_scaler, f'layer_{idx}_rn')( + x_blocks[idx]) + x_blocks[idx] = getattr(self.multi_scaler, f'p_norm_{idx}')( + x_blocks[idx]) + + # cross-scale fusion + refined_embeds = [] + for idx in range(self.n_block - 1, -1, -1): + if idx == self.n_block - 1: + path = getattr(self.multi_scaler, + f'refinenet_{idx}')([x_blocks[idx]], None) + else: + path = getattr(self.multi_scaler, + f'refinenet_{idx}')([path, x_blocks[idx]], + x_blocks[idx].size()[2:]) + refined_embeds.append(path) + + return refined_embeds + + +def _make_multi_scale_layers(in_shape, + out_shape, + n_block=4, + groups=1, + use_bn=False): + + out_shapes = [out_shape for _ in range(n_block)] + multi_scaler = nn.Module() + + for idx in range(n_block - 1, -1, -1): + """ + 1 x 1 conv for dim reduction -> group norm + """ + layer_name = f'layer_{(idx)}_rn' + multi_scaler.add_module( + layer_name, + nn.Conv2d(in_shape[idx], out_shapes[idx], kernel_size=1)) + + layer_name = f'p_norm_{(idx)}' + multi_scaler.add_module(layer_name, nn.GroupNorm(32, out_shapes[idx])) + + layer_name = f'refinenet_{idx}' + multi_scaler.add_module(layer_name, + _make_fusion_block(out_shape, use_bn)) + + # initialize for the 1x1 conv + nn.init.xavier_uniform_( + getattr(multi_scaler, f'layer_{idx}_rn').weight, gain=1) + nn.init.constant_(getattr(multi_scaler, 
f'layer_{idx}_rn').bias, 0) + + return multi_scaler + + +def _make_fusion_block(features, use_bn): + """ We use a resnet bottleneck structure for fpn """ + + return FeatureFusionBlock( + features, + nn.ReLU(False), + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class FeatureFusionBlock(nn.Module): + """ Feature fusion block """ + + def __init__(self, + features, + activation, + bn=False, + expand=False, + align_corners=True): + """Init. + Args: + features (int): channel dim of the input feature + activation: activation function to use + bn: whether to use bn + expand: whether to exapnd feature or not + align_corners: wheter to use align_corners for interpolation + """ + + super(FeatureFusionBlock, self).__init__() + self.align_corners = align_corners + self.groups = 1 + self.expand = expand + out_features = features + + if self.expand is True: + out_features = features // 2 + + self.smoothing = nn.Conv2d( + features, + out_features, + kernel_size=1, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, xs, up_size): + """ Forward pass. 
+ Args + xs: xs[0]: the feature refined from the previous step, xs[1]: the next scale features to fuse + up_size: the size for upsampling; xs[0] is upsampled before merging with xs[1] + Returns: + output: the fused feature, which is fed to the next fusion step as an input + """ + + output = xs[0] + if len(xs) == 2: + # upsampling + output = nn.functional.interpolate( + output, + size=up_size, + mode='bilinear', + align_corners=self.align_corners) + # feature smoothing since the upsampled feature is coarse-grain + output = self.smoothing(output) + + # refine the next scale feature before fusion + res = self.resConfUnit1(xs[1]) + + # fusion + output = self.skip_add.add(output, res) + + # post refine after fusion + output = self.resConfUnit2(output) + + return output + + +class ResidualConvUnit(nn.Module): + """ Residual convolution module. """ + + def __init__(self, features, activation, bn): + """Init. + Args: + features (int): channel dim of the input + activation: activation function + bn: whether to use bn + """ + + super().__init__() + + self.bn = bn + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + 64, + kernel_size=1, + stride=1, + bias=not self.bn, + groups=self.groups, + ) + self.conv2 = nn.Conv2d( + 64, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + self.conv3 = nn.Conv2d( + 64, + features, + kernel_size=1, + stride=1, + bias=not self.bn, + groups=self.groups, + ) + if self.bn is True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + self.bn3 = nn.BatchNorm2d(features) + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """ Forward pass + + Args: + x (tensor): input feature + + Returns: + tensor: output feature + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn is True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn is True: + out = self.bn2(out) + + 
out = self.activation(out) + out = self.conv3(out) + if self.bn is True: + out = self.bn3(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) diff --git a/modelscope/models/cv/vidt/head.py b/modelscope/models/cv/vidt/head.py new file mode 100644 index 00000000..28737e96 --- /dev/null +++ b/modelscope/models/cv/vidt/head.py @@ -0,0 +1,413 @@ +# The implementation here is modified based on timm, +# originally Apache 2.0 License and publicly available at +# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/detector.py + +import copy +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Detector(nn.Module): + """ This is a combination of "Swin with RAM" and a "Neck-free Deformable Decoder" """ + + def __init__( + self, + backbone, + transformer, + num_classes, + num_queries, + aux_loss=False, + with_box_refine=False, + # The three additional techniques for ViDT+ + epff=None, # (1) Efficient Pyramid Feature Fusion Module + with_vector=False, + processor_dct=None, + vector_hidden_dim=256, # (2) UQR Module + iou_aware=False, + token_label=False, # (3) Additional losses + distil=False): + """ Initializes the model. + Args: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + num_queries: number of object queries (i.e., det tokens). This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + with_box_refine: iterative bounding box refinement + epff: None or fusion module available + iou_aware: True if iou_aware is to be used. + see the original paper https://arxiv.org/abs/1912.05992 + token_label: True if token_label is to be used. 
            see the original paper https://arxiv.org/abs/2104.10858
            distil: whether to use knowledge distillation with token matching
        """

        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        # per-query classification head and a 3-layer MLP regressing
        # normalized (cx, cy, w, h) box coordinates
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

        # two essential techniques used [default use]
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine

        # For UQR module for ViDT+
        self.with_vector = with_vector
        self.processor_dct = processor_dct
        if self.with_vector:
            print(
                f'Training with vector_hidden_dim {vector_hidden_dim}.',
                flush=True)
            self.vector_embed = MLP(hidden_dim, vector_hidden_dim,
                                    self.processor_dct.n_keep, 3)

        # For two additional losses for ViDT+
        self.iou_aware = iou_aware
        self.token_label = token_label

        # distillation
        self.distil = distil

        # For EPFF module for ViDT+
        if epff is None:
            # no fusion module: reduce each backbone stage to hidden_dim with
            # a per-scale 1x1 conv + GroupNorm
            num_backbone_outs = len(backbone.num_channels)
            input_proj_list = []
            for _ in range(num_backbone_outs):
                in_channels = backbone.num_channels[_]
                input_proj_list.append(
                    nn.Sequential(
                        # This is 1x1 conv -> so linear layer
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    ))
            self.input_proj = nn.ModuleList(input_proj_list)

            # initialize the projection layer for [PATCH] tokens
            for proj in self.input_proj:
                nn.init.xavier_uniform_(proj[0].weight, gain=1)
                nn.init.constant_(proj[0].bias, 0)
            self.fusion = None
        else:
            # the cross scale fusion module has its own reduction layers
            self.fusion = epff

        # channel dim reduction for [DET] tokens
        self.tgt_proj = nn.Sequential(
            # This is 1x1 conv -> so linear layer
            nn.Conv2d(backbone.num_channels[-2], hidden_dim, kernel_size=1),
            nn.GroupNorm(32, hidden_dim),
        )

        # channel dim reduction for [DET] learnable pos encodings
        self.query_pos_proj = nn.Sequential(
            # This is 1x1 conv -> so linear layer
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1),
            nn.GroupNorm(32, hidden_dim),
        )

        # initialize detection head: box regression and classification
        # focal-loss-style bias init so initial class probabilities are close
        # to prior_prob, which stabilizes early training
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)

        # initialize projection layer for [DET] tokens and encodings
        nn.init.xavier_uniform_(self.tgt_proj[0].weight, gain=1)
        nn.init.constant_(self.tgt_proj[0].bias, 0)
        nn.init.xavier_uniform_(self.query_pos_proj[0].weight, gain=1)
        nn.init.constant_(self.query_pos_proj[0].bias, 0)

        if self.with_vector:
            nn.init.constant_(self.vector_embed.layers[-1].weight.data, 0)
            nn.init.constant_(self.vector_embed.layers[-1].bias.data, 0)

        # the prediction is made for each decoding layers + the standalone detector (Swin with RAM)
        num_pred = transformer.decoder.num_layers + 1

        # set up all required nn.Module for additional techniques
        if with_box_refine:
            # independent (deep-copied) heads per prediction level
            self.class_embed = _get_clones(self.class_embed, num_pred)
            self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:],
                              -2.0)
            # hack implementation for iterative bounding box refinement
            self.transformer.decoder.bbox_embed = self.bbox_embed
        else:
            # shared heads: the same module instance is reused at every level
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
            self.class_embed = nn.ModuleList(
                [self.class_embed for _ in range(num_pred)])
            self.bbox_embed = nn.ModuleList(
                [self.bbox_embed for _ in range(num_pred)])
            self.transformer.decoder.bbox_embed = None

        if self.with_vector:
            nn.init.constant_(self.vector_embed.layers[-1].bias.data[2:], -2.0)
            self.vector_embed = nn.ModuleList(
                [self.vector_embed for _ in range(num_pred)])

        if self.iou_aware:
            # small MLP predicting one scalar IoU score per query
            self.iou_embed = MLP(hidden_dim, hidden_dim, 1, 3)
            if with_box_refine:
                self.iou_embed = _get_clones(self.iou_embed, num_pred)
            else:
                self.iou_embed = nn.ModuleList(
                    [self.iou_embed for _ in range(num_pred)])

    def forward(self, features_0, features_1, features_2, features_3, det_tgt,
                det_pos, mask):
        """ The forward step of ViDT

        Args:
            features_0: backbone feature map of the first (finest) stage
            features_1: backbone feature map of the second stage
            features_2: backbone feature map of the third stage
            features_3: backbone feature map of the fourth (coarsest) stage
            det_tgt: [DET] token features produced by the backbone
            det_pos: learnable [DET] positional encodings
            mask: image padding mask

        Returns:
            A tuple (out_pred_logits, out_pred_boxes) taken from the last
            decoding layer:
            - out_pred_logits: the classification logits (including no-object)
              for all queries. Shape= [batch_size x num_queries x (num_classes + 1)]
            - out_pred_boxes: The normalized boxes coordinates for all queries,
              represented as (center_x, center_y, height, width). These values
              are normalized in [0, 1], relative to the size of each individual
              image (disregarding possible padding).
              See PostProcess for information on how to retrieve the
              unnormalized bounding box.
        """
        features = [features_0, features_1, features_2, features_3]

        # [DET] token and encoding projection to compact representation for the input to the Neck-free transformer
        # unsqueeze/squeeze turns the 1x1-conv projection into a linear map,
        # then permute to (batch, num_det, channel)
        det_tgt = self.tgt_proj(det_tgt.unsqueeze(-1)).squeeze(-1).permute(
            0, 2, 1)
        det_pos = self.query_pos_proj(
            det_pos.unsqueeze(-1)).squeeze(-1).permute(0, 2, 1)

        # [PATCH] token projection
        # NOTE(review): `shapes` is filled here and again below but is never
        # read afterwards — looks like dead code; confirm before removing.
        shapes = []
        for le, src in enumerate(features):
            shapes.append(src.shape[-2:])

        srcs = []
        if self.fusion is None:
            # per-scale 1x1 projection to the transformer dimension
            for le, src in enumerate(features):
                srcs.append(self.input_proj[le](src))
        else:
            # EPFF (multi-scale fusion) is used if fusion is activated
            srcs = self.fusion(features)

        # resize the padding mask to each feature scale
        masks = []
        for le, src in enumerate(srcs):
            # resize mask
            shapes.append(src.shape[-2:])
            _mask = F.interpolate(
                mask[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
            masks.append(_mask)
            assert mask is not None

        outputs_classes = []
        outputs_coords = []

        # return the output of the neck-free decoder
        hs, init_reference, inter_references, enc_token_class_unflat = self.transformer(
            srcs, masks, det_tgt, det_pos)

        # perform predictions via the detection head
        for lvl in range(hs.shape[0]):
            # reference point of this level: the initial one for level 0,
            # otherwise the refined reference from the previous level
            reference = init_reference if lvl == 0 else inter_references[lvl
                                                                         - 1]
            reference = inverse_sigmoid(reference)

            outputs_class = self.class_embed[lvl](hs[lvl])
            # bbox output + reference
            tmp = self.bbox_embed[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                # full-box reference: refine all four coordinates
                tmp += reference
            else:
                # point reference: refine only the box center
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference

            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)

        # stack all predictions made from each decoding layers
        outputs_class = torch.stack(outputs_classes)
        outputs_coord = torch.stack(outputs_coords)

        outputs_vector = None
        if self.with_vector:
            # UQR: predict DCT vectors per level for instance segmentation
            outputs_vectors = []
            for lvl in range(hs.shape[0]):
                outputs_vector = self.vector_embed[lvl](hs[lvl])
                outputs_vectors.append(outputs_vector)
            outputs_vector = torch.stack(outputs_vectors)

        # final prediction is made the last decoding layer
        out = {
            'pred_logits': outputs_class[-1],
            'pred_boxes': outputs_coord[-1]
        }

        if self.with_vector:
            out.update({'pred_vectors': outputs_vector[-1]})

        # aux loss is defined by using the rest predictions
        if self.aux_loss and self.transformer.decoder.num_layers > 0:
            out['aux_outputs'] = self._set_aux_loss(outputs_class,
                                                    outputs_coord,
                                                    outputs_vector)

        # iou awareness loss is defined for each decoding layer similar to auxiliary decoding loss
        if self.iou_aware:
            outputs_ious = []
            for lvl in range(hs.shape[0]):
                outputs_ious.append(self.iou_embed[lvl](hs[lvl]))
            outputs_iou = torch.stack(outputs_ious)
            out['pred_ious'] = outputs_iou[-1]

            # NOTE(review): this reads out['aux_outputs'], which is only set
            # when decoder.num_layers > 0 — would raise KeyError otherwise;
            # presumably num_layers > 0 always holds here, confirm.
            if self.aux_loss:
                for i, aux in enumerate(out['aux_outputs']):
                    aux['pred_ious'] = outputs_iou[i]

        # token label loss
        if self.token_label:
            out['enc_tokens'] = {'pred_logits': enc_token_class_unflat}

        if self.distil:
            # 'patch_token': multi-scale patch tokens from each stage
            # 'body_det_token' and 'neck_det_tgt': the input det_token for multiple detection heads
            out['distil_tokens'] = {
                'patch_token': srcs,
                'body_det_token': det_tgt,
                'neck_det_token': hs
            }

        # only the last-layer logits/boxes are returned to the caller
        out_pred_logits = out['pred_logits']
        out_pred_boxes = out['pred_boxes']
        return out_pred_logits, out_pred_boxes

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord, outputs_vector):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.

        # Build one prediction dict per intermediate decoding layer
        # (the last layer is excluded — it is the main output).
        if outputs_vector is None:
            return [{
                'pred_logits': a,
                'pred_boxes': b
            } for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
        else:
            return [{
                'pred_logits': a,
                'pred_boxes': b,
                'pred_vectors': c
            } for a, b, c in zip(outputs_class[:-1], outputs_coord[:-1],
                                 outputs_vector[:-1])]


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN).

    A stack of `num_layers` Linear layers with ReLU between them
    (no activation after the last layer).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # hidden widths for all but the final layer
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        # ReLU on every layer except the last (raw output for regression)
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


def inverse_sigmoid(x, eps=1e-5):
    """ Numerically-stable inverse of the sigmoid (logit) function.

    Inputs are clamped to [0, 1] and both operands of the ratio are clamped
    to at least `eps` to avoid log(0) / division by zero.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


def box_cxcywh_to_xyxy(x):
    """ Convert boxes from (center_x, center_y, w, h) to (x_min, y_min, x_max, y_max).

    Operates on the last dimension; all other dimensions are preserved.
    """
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


# process post_results
def get_predictions(post_results, bbox_thu=0.40):
    """ Filter PostProcess outputs by score threshold.

    Args:
        post_results: list (one entry per image) of dicts with 'scores',
            'labels' and 'boxes' tensors as produced by PostProcess.
        bbox_thu: minimum score for a detection to be kept.

    Returns:
        A list (one entry per image) of [score, label, [x_min, y_min,
        x_max, y_max]] triples; box coordinates are truncated to int.
    """
    batch_final_res = []
    for per_img_res in post_results:
        per_img_final_res = []
        for i in range(len(per_img_res['scores'])):
            score = float(per_img_res['scores'][i].cpu())
            label = int(per_img_res['labels'][i].cpu())
            bbox = []
            # int() truncates the (float) pixel coordinates
            for it in per_img_res['boxes'][i].cpu():
                bbox.append(int(it))
            if score >= bbox_thu:
                per_img_final_res.append([score, label, bbox])
        batch_final_res.append(per_img_final_res)
    return batch_final_res


class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""

    def __init__(self, processor_dct=None):
        super().__init__()
        # For instance segmentation using UQR module
        self.processor_dct = processor_dct

    @torch.no_grad()
    def forward(self, out_logits, out_bbox, target_sizes):
        """ Perform the computation

        Args:
            out_logits: raw logits outputs of the model
            out_bbox: raw bbox outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
                For evaluation, this must be the original image size (before any data augmentation)
                For visualization, this should be the image size after data augment, but before padding

        Returns:
            A list (one dict per image) with keys 'scores', 'labels' and
            'boxes' for the top-100 detections, boxes in absolute
            (x_min, y_min, x_max, y_max) pixel coordinates.
        """
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = out_logits.sigmoid()
        # top-100 (query, class) pairs over the flattened per-image scores
        topk_values, topk_indexes = torch.topk(
            prob.view(out_logits.shape[0], -1), 100, dim=1)
        scores = topk_values
        # recover the query index and class index from the flat index
        topk_boxes = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        boxes = box_cxcywh_to_xyxy(out_bbox)
        # pick the box of each selected query (a query may appear once per class)
        boxes = torch.gather(boxes, 1,
                             topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h],
                                dim=1).to(torch.float32)
        boxes = boxes * scale_fct[:, None, :]

        results = [{
            'scores': s,
            'labels': l,
            'boxes': b
        } for s, l, b in zip(scores, labels, boxes)]

        return results


def _get_clones(module, N):
    """ Clone a module N times (independent deep copies in a ModuleList). """

    # NOTE(review): relies on `copy` being imported at module top (outside
    # this view) — verify the import exists.
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
diff --git a/modelscope/models/cv/vidt/model.py b/modelscope/models/cv/vidt/model.py
new file mode 100644
index 00000000..65940637
--- /dev/null
+++ b/modelscope/models/cv/vidt/model.py
@@ -0,0 +1,98 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import os

import torch

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from .backbone import SwinTransformer
from .deformable_transformer import DeformableTransformer
from .fpn_fusion import FPNFusionModule
from .head import Detector


@MODELS.register_module(Tasks.image_object_detection, module_name=Models.vidt)
class VidtModel(TorchModel):
    """
    The implementation of 'ViDT for joint-learning of object detection and instance segmentation'.
    This model is dynamically initialized with the following parts:
    - 'backbone': pre-trained backbone model with parameters.
    - 'head': detection and segmentation head with fine-tuning.
    """

    def __init__(self, model_dir: str, **kwargs):
        """ Initialize a Vidt Model.

        Builds a Swin-Tiny backbone (with RAM fine-tuning for [DET] tokens)
        and a ViDT+ Detector head (EPFF fusion + IoU-aware loss enabled,
        UQR/token-label/distillation disabled), then loads both weight sets
        from the checkpoint. Weights are loaded on CPU; move the module to
        the target device afterwards.

        Args:
            model_dir: model id or path, where model_dir/pytorch_model.pt contains:
                - 'backbone_weights': parameters of backbone.
                - 'head_weights': parameters of head.
        """
        super(VidtModel, self).__init__()

        model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        model_dict = torch.load(model_path, map_location='cpu')

        # build backbone
        # Swin-Tiny configuration (embed_dim=96, depths [2, 2, 6, 2])
        backbone = SwinTransformer(
            pretrain_img_size=[224, 224],
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=7,
            drop_path_rate=0.2)
        # attach 300 [DET] tokens; cross-attention only at the last stage
        backbone.finetune_det(
            method='vidt', det_token_num=300, pos_dim=256, cross_indices=[3])
        self.backbone = backbone
        self.backbone.load_state_dict(
            model_dict['backbone_weights'], strict=True)

        # build head
        epff = FPNFusionModule(backbone.num_channels, fuse_dim=256)
        deform_transformers = DeformableTransformer(
            d_model=256,
            nhead=8,
            num_decoder_layers=6,
            dim_feedforward=1024,
            dropout=0.1,
            activation='relu',
            return_intermediate_dec=True,
            num_feature_levels=4,
            dec_n_points=4,
            token_label=False)
        head = Detector(
            backbone,
            deform_transformers,
            num_classes=2,
            num_queries=300,
            # two essential techniques used in ViDT
            aux_loss=True,
            with_box_refine=True,
            # an epff module for ViDT+
            epff=epff,
            # an UQR module for ViDT+
            with_vector=False,
            processor_dct=None,
            # two additional losses for VIDT+
            iou_aware=True,
            token_label=False,
            vector_hidden_dim=256,
            # distil
            distil=False)
        self.head = head
        self.head.load_state_dict(model_dict['head_weights'], strict=True)

    def forward(self, x, mask):
        """ Dynamic forward function of VidtModel.

        Args:
            x: input images (B, 3, H, W)
            mask: input padding masks (B, H, W)

        Returns:
            A tuple (out_pred_logits, out_pred_boxes) from the detection
            head (last decoding layer); boxes are normalized cxcywh.
        """
        # backbone yields four multi-scale feature maps plus the [DET]
        # tokens and their positional encodings
        features_0, features_1, features_2, features_3, det_tgt, det_pos = self.backbone(
            x, mask)
        out_pred_logits, out_pred_boxes = self.head(features_0, features_1,
                                                    features_2, features_3,
                                                    det_tgt, det_pos, mask)
        return out_pred_logits, out_pred_boxes
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index 443d4d43..f1c027a0 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -81,6 +81,7 @@ if TYPE_CHECKING:
     from .vision_efficient_tuning_prefix_pipeline import VisionEfficientTuningPrefixPipeline
     from .vision_efficient_tuning_lora_pipeline import VisionEfficientTuningLoRAPipeline
     from .vision_middleware_pipeline import VisionMiddlewarePipeline
+    from .vidt_pipeline import VidtPipeline
     from .video_frame_interpolation_pipeline import VideoFrameInterpolationPipeline
     from .image_skychange_pipeline import ImageSkychangePipeline
     from .image_driving_perception_pipeline import ImageDrivingPerceptionPipeline
@@ -219,6 +220,7 @@ else:
             'VisionEfficientTuningLoRAPipeline'
         ],
         'vision_middleware_pipeline': ['VisionMiddlewarePipeline'],
+        'vidt_pipeline': ['VidtPipeline'],
         'video_frame_interpolation_pipeline': [
             'VideoFrameInterpolationPipeline'
         ],
diff --git a/modelscope/pipelines/cv/vidt_pipeline.py b/modelscope/pipelines/cv/vidt_pipeline.py
new file mode 100644
index 00000000..5c16c35e
--- /dev/null
+++ b/modelscope/pipelines/cv/vidt_pipeline.py
@@ -0,0 +1,207 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import Any, Dict

import torch
import torchvision.transforms as transforms
from torch import nn

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.image_object_detection, module_name=Pipelines.vidt)
class VidtPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a vidt pipeline for prediction

        Args:
            model: model id on modelscope hub.

        Example:
            >>> from modelscope.pipelines import pipeline
            >>> vidt_pipeline = pipeline('image-object-detection', 'damo/ViDT-logo-detection')
            >>> result = vidt_pipeline(
                'data/test/images/vidt_test1.jpg')
            >>> print(f'Output: {result}.')
        """
        super().__init__(model=model, **kwargs)

        self.model.eval()
        # fixed 640x640 input with ImageNet mean/std normalization
        self.transform = transforms.Compose([
            transforms.Resize([640, 640]),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.postprocessors = PostProcess()
        # maps the head's class indices to human-readable labels
        self.label_dic = {0: 'negative', 1: 'positive'}

    def preprocess(self, inputs: Input, **preprocess_params):
        """ Load the image, resize/normalize it and build the padding mask.

        Returns a dict with 'tensors' (B, 3, 640, 640), 'mask' (B, 640, 640)
        and 'orig_target_sizes' (B, 2) as [height, width] of the original
        image, used later to rescale boxes.
        """
        img = LoadImage.convert_to_img(inputs)
        # PIL size is (width, height); PostProcess expects [height, width]
        ori_size = [img.size[1], img.size[0]]
        image = self.transform(img)
        tensor_list = [image]
        orig_target_sizes = [ori_size]
        orig_target_sizes = torch.tensor(orig_target_sizes).to(self.device)
        samples = nested_tensor_from_tensor_list(tensor_list)
        samples = samples.to(self.device)
        res = {}
        res['tensors'] = samples.tensors
        res['mask'] = samples.mask
        res['orig_target_sizes'] = orig_target_sizes
        return res

    def forward(self, inputs: Dict[str, Any], **forward_params):
        """ Run the model without gradients and pass predictions through. """
        tensors = inputs['tensors']
        mask = inputs['mask']
        orig_target_sizes = inputs['orig_target_sizes']
        with torch.no_grad():
            out_pred_logits, out_pred_boxes = self.model(tensors, mask)
        res = {}
        res['out_pred_logits'] = out_pred_logits
        res['out_pred_boxes'] = out_pred_boxes
        res['orig_target_sizes'] = orig_target_sizes
        return res

    def postprocess(self, inputs: Dict[str, Any], **post_params):
        """ Convert raw predictions into score/label/box lists. """
        results = self.postprocessors(inputs['out_pred_logits'],
                                      inputs['out_pred_boxes'],
                                      inputs['orig_target_sizes'])
        # only single-image inference is supported
        batch_predictions = get_predictions(results)[0]
        scores = []
        labels = []
        boxes = []
        for sub_pre in batch_predictions:
            scores.append(sub_pre[0])
            labels.append(self.label_dic[sub_pre[1]])
            boxes.append(sub_pre[2])  # [xmin, ymin, xmax, ymax]
        outputs = {}
        outputs['scores'] = scores
        outputs['labels'] = labels
        outputs['boxes'] = boxes
        return outputs


def nested_tensor_from_tensor_list(tensor_list):
    """ Pad a list of (C, H, W) tensors to a common size.

    Returns a NestedTensor whose `tensors` is the zero-padded batch and
    whose `mask` is True on padded pixels, False on valid ones.
    """
    # TODO make it support different-sized images
    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    b, c, h, w = batch_shape
    dtype = tensor_list[0].dtype
    device = tensor_list[0].device
    tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
    mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
    for img, pad_img, m in zip(tensor_list, tensor, mask):
        pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img)
        m[:img.shape[1], :img.shape[2]] = False
    return NestedTensor(tensor, mask)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    # Element-wise maximum over a list of equal-length int lists.
    # NOTE(review): `maxes` aliases the_list[0], so the first sublist is
    # mutated in place; harmless for the throwaway lists built by the
    # caller above, but verify before reusing elsewhere.
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


class NestedTensor(object):
    """ A batched image tensor bundled with its padding mask. """

    def __init__(self, tensors, mask):
        # tensors: (B, C, H, W) padded batch; mask: (B, H, W) bool,
        # True on padded pixels
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        # Returns a new NestedTensor on `device` (mask may be None).
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            # NOTE(review): this assert is redundant inside the None check
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        # Split back into (tensors, mask).
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)


def box_cxcywh_to_xyxy(x):
    """ Convert boxes from (center_x, center_y, w, h) to (x_min, y_min, x_max, y_max).

    Operates on the last dimension; all other dimensions are preserved.
    """
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


# process post_results
def get_predictions(post_results, bbox_thu=0.40):
    """ Filter PostProcess outputs by score threshold.

    Args:
        post_results: list (one entry per image) of dicts with 'scores',
            'labels' and 'boxes' tensors as produced by PostProcess.
        bbox_thu: minimum score for a detection to be kept.

    Returns:
        A list (one entry per image) of [score, label, [x_min, y_min,
        x_max, y_max]] triples; box coordinates are truncated to int.
    """
    batch_final_res = []
    for per_img_res in post_results:
        per_img_final_res = []
        for i in range(len(per_img_res['scores'])):
            score = float(per_img_res['scores'][i].cpu())
            label = int(per_img_res['labels'][i].cpu())
            bbox = []
            # int() truncates the (float) pixel coordinates
            for it in per_img_res['boxes'][i].cpu():
                bbox.append(int(it))
            if score >= bbox_thu:
                per_img_final_res.append([score, label, bbox])
        batch_final_res.append(per_img_final_res)
    return batch_final_res


class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""

    def __init__(self, processor_dct=None):
        super().__init__()
        # For instance segmentation using UQR module
        self.processor_dct = processor_dct

    @torch.no_grad()
    def forward(self, out_logits, out_bbox, target_sizes):
        """ Perform the computation

        Parameters:
            out_logits: raw logits outputs of the model
            out_bbox: raw bbox outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
                For evaluation, this must be the original image size (before any data augmentation)
                For visualization, this should be the image size after data augment, but before padding

        Returns:
            A list (one dict per image) with keys 'scores', 'labels' and
            'boxes' for the top-100 detections, boxes in absolute
            (x_min, y_min, x_max, y_max) pixel coordinates.
        """
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = out_logits.sigmoid()
        # top-100 (query, class) pairs over the flattened per-image scores
        topk_values, topk_indexes = torch.topk(
            prob.view(out_logits.shape[0], -1), 100, dim=1)
        scores = topk_values
        # recover the query index and class index from the flat index
        topk_boxes = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        boxes = box_cxcywh_to_xyxy(out_bbox)
        # pick the box of each selected query (a query may appear once per class)
        boxes = torch.gather(boxes, 1,
                             topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h],
                                dim=1).to(torch.float32)
        boxes = boxes * scale_fct[:, None, :]

        results = [{
            'scores': s,
            'labels': l,
            'boxes': b
        } for s, l, b in zip(scores, labels, boxes)]

        return results
diff --git a/tests/pipelines/test_vidt_face.py b/tests/pipelines/test_vidt_face.py
new file mode 100644
index 00000000..8640d128
--- /dev/null
+++ b/tests/pipelines/test_vidt_face.py
@@ -0,0 +1,31 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest

from modelscope.models import Model
from modelscope.models.cv.vidt import VidtModel
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class VidtTest(unittest.TestCase, DemoCompatibilityCheck):
    """ Smoke tests for the ViDT face-detection model and pipeline. """

    def setUp(self) -> None:
        self.task = Tasks.image_object_detection
        self.model_id = 'damo/ViDT-face-detection'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_pipeline(self):
        # end-to-end inference on a bundled sample image
        vidt_pipeline = pipeline(self.task, self.model_id)
        result = vidt_pipeline('data/test/images/vidt_test1.jpg')
        print(f'Vidt output: {result}.')

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_load_model_from_pretrained(self):
        # the hub id must resolve to the registered VidtModel class
        model = Model.from_pretrained('damo/ViDT-face-detection')
        self.assertTrue(model.__class__ == VidtModel)


if __name__ == '__main__':
    unittest.main()
diff --git a/tests/pipelines/test_vidt_logo.py b/tests/pipelines/test_vidt_logo.py
new file mode 100644
index 00000000..143eb205
--- /dev/null
+++ b/tests/pipelines/test_vidt_logo.py
@@ -0,0 +1,31 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest

from modelscope.models import Model
from modelscope.models.cv.vidt import VidtModel
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class VidtTest(unittest.TestCase, DemoCompatibilityCheck):
    """ Smoke tests for the ViDT logo-detection model and pipeline. """

    def setUp(self) -> None:
        self.task = Tasks.image_object_detection
        self.model_id = 'damo/ViDT-logo-detection'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_pipeline(self):
        # end-to-end inference on a bundled sample image
        vidt_pipeline = pipeline(self.task, self.model_id)
        result = vidt_pipeline('data/test/images/vidt_test1.jpg')
        print(f'Vidt output: {result}.')

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_load_model_from_pretrained(self):
        # the hub id must resolve to the registered VidtModel class
        model = Model.from_pretrained('damo/ViDT-logo-detection')
        self.assertTrue(model.__class__ == VidtModel)


if __name__ == '__main__':
    unittest.main()