# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicConv(nn.Module):
    # Thin wrapper around nn.Conv2d; used by SpatialGate for its 7x7 convolution.
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)

    def forward(self, x):
        x = self.conv(x)
        return x

class Flatten(nn.Module):
    # Flattens (N, C, 1, 1) pooled maps to (N, C) so they can feed the MLP.
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    # Channel attention: squeeze the spatial dims with global avg/max pooling,
    # pass each pooled descriptor through a shared bottleneck MLP, sum the
    # results, and rescale the input channels by the sigmoid of that sum.
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                # Global average pooling over the full H x W extent.
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                # Global max pooling over the full H x W extent.
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        # (N, C) -> (N, C, 1, 1), then broadcast back to the input shape.
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

class ChannelPool(nn.Module):
    # Compresses channels to two maps (max and mean): (N, C, H, W) -> (N, 2, H, W).
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)

class SpatialGate(nn.Module):
    # Spatial attention: pool across channels, convolve the resulting 2-channel
    # map with a 7x7 kernel, and rescale the input by the sigmoid of the output.
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # (N, 1, H, W), broadcast over channels
        return x * scale

class CBAM(nn.Module):
    # Convolutional Block Attention Module (Woo et al., ECCV 2018): channel
    # attention followed by optional spatial attention.
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
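
# Minimal usage sketch (illustrative, not part of the original module). The
# tensor shape below is a hypothetical example, chosen only to show that CBAM
# preserves the input's (N, C, H, W) shape: both gates merely rescale
# activations by attention weights in [0, 1].
if __name__ == "__main__":
    feat = torch.randn(2, 64, 32, 32)  # hypothetical feature map
    cbam = CBAM(gate_channels=64)      # gate_channels must match C
    out = cbam(feat)
    assert out.shape == feat.shape     # attention does not change the shape
    print(out.shape)                   # torch.Size([2, 64, 32, 32])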