mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
142 lines
5.6 KiB
Python
142 lines
5.6 KiB
Python
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import torch._utils
|
||
|
|
import torch.nn.functional as F
|
||
|
|
|
||
|
|
|
||
|
|
class SpatialGather_Module(nn.Module):
    """
    Aggregate the context features according to the initial
    predicted probability distribution.
    Employ the soft-weighted method to aggregate the context: every region k
    receives a weighted average of all pixel features, with weights given by
    a softmax of that region's logits over the spatial positions.

    Args:
        cls_num: number of regions/classes K. Stored on the module but not
            read by ``forward``; kept for API compatibility.
        scale: multiplier applied to ``probs`` before the spatial softmax
            (values > 1 sharpen the weights, values < 1 flatten them).

    Forward inputs:
        feats: batch x c x h x w feature map.
        probs: batch x k x h x w region logits; its h*w must equal that of
            ``feats``.

    Returns:
        batch x c x k x 1 tensor of region (object) representations.
    """

    def __init__(self, cls_num=0, scale=1):
        super(SpatialGather_Module, self).__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats, probs):
        batch_size, num_regions = probs.size(0), probs.size(1)
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors,
        # are accepted; contiguous inputs still get a zero-copy view.
        probs = probs.reshape(batch_size, num_regions, -1)  # batch x k x hw
        feats = feats.reshape(batch_size, feats.size(1), -1)  # batch x c x hw
        feats = feats.permute(0, 2, 1)  # batch x hw x c
        # normalize each region's logits over the spatial positions
        probs = F.softmax(self.scale * probs, dim=2)  # batch x k x hw
        ocr_context = torch.matmul(probs, feats)  # batch x k x c
        return ocr_context.permute(0, 2, 1).unsqueeze(3)  # batch x c x k x 1
|
||
|
|
|
||
|
|
|
||
|
|
class SpatialOCR_Module(nn.Module):
    """
    Implementation of the OCR module:
    We aggregate the global object representation to update the representation for each pixel.

    The object-context block attends from each pixel of ``feats`` to the
    object representations in ``proxy_feats``. Its output (``in_channels``
    channels) is concatenated with ``feats``. The result is fused by a 1x1 conv,
    ``norm_layer`` + ReLU, and channel dropout down to ``out_channels``.

    Args:
        in_channels: channels of ``feats`` and ``proxy_feats``.
        key_channels: width of the attention key/query/value transforms.
        out_channels: channels of the returned feature map.
        scale: downsampling factor used inside the attention block.
        dropout: probability for the final ``nn.Dropout2d``.
        norm_layer: normalization layer class, called as ``norm_layer(channels)``.
        align_corners: forwarded to the attention block's bilinear upsampling.
    """

    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 scale=1,
                 dropout=0.1,
                 norm_layer=nn.BatchNorm2d,
                 align_corners=True):
        super().__init__()
        self.object_context_block = ObjectAttentionBlock2D(
            in_channels, key_channels, scale, norm_layer, align_corners)

        # the attention context (in_channels) is concatenated with the input
        # features (in_channels), so the fusion conv sees twice as many channels
        fused_channels = in_channels * 2
        projection = nn.Conv2d(fused_channels, out_channels,
                               kernel_size=1, padding=0, bias=False)
        norm_act = nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True))
        channel_dropout = nn.Dropout2d(dropout)
        self.conv_bn_dropout = nn.Sequential(projection, norm_act, channel_dropout)

    def forward(self, feats, proxy_feats):
        fused = torch.cat(
            (self.object_context_block(feats, proxy_feats), feats), dim=1)
        return self.conv_bn_dropout(fused)
|
||
|
|
|
||
|
|
|
||
|
|
class ObjectAttentionBlock2D(nn.Module):
    '''
    The basic implementation for object context block.

    Every pixel of ``x`` attends (scaled dot-product, softmax over regions)
    to the object-region features in ``proxy``. The attended context is then
    projected back to ``in_channels``.

    Parameters:
        in_channels : the dimension of the input feature map
        key_channels : the dimension after the key/query transform
        scale : if > 1, ``x`` is max-pooled with a (scale, scale) kernel before
            attention (saves memory cost) and the context is bilinearly
            upsampled back to the input resolution afterwards
        norm_layer : normalization layer class, called as ``norm_layer(channels)``
        align_corners : ``align_corners`` flag for the bilinear upsampling
    Input:
        x : N X C X H X W pixel features (C == in_channels)
        proxy : N X C X K1 X K2 object-region features (C == in_channels);
            all K1*K2 spatial positions are treated as regions
            (e.g. N X C X K X 1)
    Return:
        N X C X H X W
    '''

    def __init__(self,
                 in_channels,
                 key_channels,
                 scale=1,
                 norm_layer=nn.BatchNorm2d,
                 align_corners=True):
        super(ObjectAttentionBlock2D, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.align_corners = align_corners

        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        # query transform for pixels: two conv-norm-relu stages
        self.f_pixel = nn.Sequential(
            *self._conv_norm_relu(self.in_channels, self.key_channels, norm_layer),
            *self._conv_norm_relu(self.key_channels, self.key_channels, norm_layer),
        )
        # key transform for object regions: two conv-norm-relu stages
        self.f_object = nn.Sequential(
            *self._conv_norm_relu(self.in_channels, self.key_channels, norm_layer),
            *self._conv_norm_relu(self.key_channels, self.key_channels, norm_layer),
        )
        # value transform for object regions
        self.f_down = nn.Sequential(
            *self._conv_norm_relu(self.in_channels, self.key_channels, norm_layer),
        )
        # projects the attended context back to in_channels
        self.f_up = nn.Sequential(
            *self._conv_norm_relu(self.key_channels, self.in_channels, norm_layer),
        )

    @staticmethod
    def _conv_norm_relu(in_ch, out_ch, norm_layer):
        """Return [1x1 bias-free conv, Sequential(norm, ReLU)] as two modules.

        The two entries are spliced into the caller's ``nn.Sequential`` so the
        submodule indices (and thus ``state_dict`` keys such as
        ``f_pixel.1.0.weight``) and the module construction order are the
        same as in the hand-written layout.
        """
        return [
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(out_ch), nn.ReLU(inplace=True)),
        ]

    def forward(self, x, proxy):
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        if self.scale > 1:
            x = self.pool(x)

        # query: N x (h'*w') x key_channels
        query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        # key: N x key_channels x K
        key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
        # value: N x K x key_channels
        value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
        value = value.permute(0, 2, 1)

        # scaled dot-product similarity, normalized over the K regions
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # per-pixel context: N x (h'*w') x key_channels -> N x key_channels x h' x w'
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.key_channels, *x.size()[2:])
        context = self.f_up(context)
        if self.scale > 1:
            # restore the pre-pooling resolution
            context = F.interpolate(input=context, size=(h, w),
                                    mode='bilinear', align_corners=self.align_corners)

        return context
|