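"""DeepLabV3+ decoder: ASPP plus a low-level skip projection on a ResNet backbone."""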
from contextlib import ExitStack

import torch
from torch import nn
import torch.nn.functional as F

from .basic_blocks import SeparableConv2d
from .resnet import ResNetBackbone
from ...model import ops


class DeepLabV3Plus(nn.Module):
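    """DeepLabV3+ segmentation head: ASPP over the deepest backbone features,
    fused with a projected low-level skip connection and refined by a
    separable-convolution head.
    """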
    def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
                 backbone_norm_layer=None,
                 ch=256,
                 project_dropout=0.5,
                 inference_mode=False,
                 **kwargs):
        super(DeepLabV3Plus, self).__init__()
        if backbone_norm_layer is None:
            backbone_norm_layer = norm_layer

        self.backbone_name = backbone
        self.norm_layer = norm_layer
        self.backbone_norm_layer = backbone_norm_layer
        self.inference_mode = False
        self.ch = ch
        self.aspp_in_channels = 2048
        self.skip_project_in_channels = 256  # layer 1 out_channels

        self._kwargs = kwargs
        if backbone == 'resnet34':
            self.aspp_in_channels = 512
            self.skip_project_in_channels = 64

        self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False,
                                       norm_layer=self.backbone_norm_layer, **kwargs)

        self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch,
                                 norm_layer=self.norm_layer)
        self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer)
        self.aspp = _ASPP(in_channels=self.aspp_in_channels,
                          atrous_rates=[12, 24, 36],
                          out_channels=ch,
                          project_dropout=project_dropout,
                          norm_layer=self.norm_layer)

        if inference_mode:
            self.set_prediction_mode()

    def load_pretrained_weights(self):
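        """Copy weights from a pretrained ResNetBackbone into self.backbone;
        freeze the backbone parameters when the model is in inference mode.
        """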
        pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True,
                                    norm_layer=self.backbone_norm_layer, **self._kwargs)
        backbone_state_dict = self.backbone.state_dict()
        pretrained_state_dict = pretrained.state_dict()

        backbone_state_dict.update(pretrained_state_dict)
        self.backbone.load_state_dict(backbone_state_dict)

        if self.inference_mode:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def set_prediction_mode(self):
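        """Enable inference mode: forward() runs under torch.no_grad() and the model is set to eval()."""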
        self.inference_mode = True
        self.eval()

    def forward(self, x):
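        """Backbone -> ASPP -> bilinear upsample -> concat with skip -> head.

        Returns a 1-tuple containing the head output.
        """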
        with ExitStack() as stack:
            if self.inference_mode:
                stack.enter_context(torch.no_grad())

            c1, _, c3, c4 = self.backbone(x)
            c1 = self.skip_project(c1)

            x = self.aspp(c4)
            x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True)
            x = torch.cat((x, c1), dim=1)
            x = self.head(x)

        return x,


class _SkipProject(nn.Module):
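    """Project low-level backbone features with a 1x1 conv + norm + activation."""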
    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
        super(_SkipProject, self).__init__()
        _activation = ops.select_activation_function("relu")

        self.skip_project = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            norm_layer(out_channels),
            _activation()
        )

    def forward(self, x):
        return self.skip_project(x)


class _DeepLabHead(nn.Module):
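    """Decoder head: two 3x3 separable convs followed by a 1x1 output projection."""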
    def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d):
        super(_DeepLabHead, self).__init__()

        self.block = nn.Sequential(
            SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3,
                            dw_padding=1, activation='relu', norm_layer=norm_layer),
            SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3,
                            dw_padding=1, activation='relu', norm_layer=norm_layer),
            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
        )

    def forward(self, x):
        return self.block(x)


class _ASPP(nn.Module):
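    """Atrous Spatial Pyramid Pooling: a 1x1 branch, three dilated 3x3 branches,
    and an image-level pooling branch, concatenated and projected to out_channels.
    """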
    def __init__(self, in_channels, atrous_rates, out_channels=256,
                 project_dropout=0.5, norm_layer=nn.BatchNorm2d):
        super(_ASPP, self).__init__()

        b0 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False),
            norm_layer(out_channels),
            nn.ReLU()
        )

        rate1, rate2, rate3 = tuple(atrous_rates)
        b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
        b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
        b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
        b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)

        self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])

        project = [
            nn.Conv2d(in_channels=5 * out_channels, out_channels=out_channels,
                      kernel_size=1, bias=False),
            norm_layer(out_channels),
            nn.ReLU()
        ]
        if project_dropout > 0:
            project.append(nn.Dropout(project_dropout))
        self.project = nn.Sequential(*project)

    def forward(self, x):
        x = torch.cat([block(x) for block in self.concurent], dim=1)

        return self.project(x)


class _AsppPooling(nn.Module):
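    """Image-level ASPP branch: global average pool, 1x1 conv, then bilinear
    upsampling back to the input resolution.
    """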
    def __init__(self, in_channels, out_channels, norm_layer):
        super(_AsppPooling, self).__init__()

        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=1, bias=False),
            norm_layer(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        pool = self.gap(x)
        return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True)


def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
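    """3x3 conv branch with dilation = padding = atrous_rate (keeps spatial size)."""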
    block = nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                  kernel_size=3, padding=atrous_rate,
                  dilation=atrous_rate, bias=False),
        norm_layer(out_channels),
        nn.ReLU()
    )

    return block
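
# Usage sketch: a minimal example, assuming this module is imported as part of
# its package (the relative imports above prevent running the file directly)
# and that the backbone accepts a standard 3-channel input:
#
#   model = DeepLabV3Plus(backbone='resnet34', ch=128)
#   model.load_pretrained_weights()   # copy pretrained weights into the backbone
#   model.set_prediction_mode()       # eval() + torch.no_grad() inside forward()
#   feats, = model(torch.randn(1, 3, 320, 320))
#   # feats has `ch` channels at the stride-4 resolution of the layer1 skip
#   # features, e.g. torch.Size([1, 128, 80, 80]) here.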