Files
Track-Anything/inference/interact/fbrs/model/is_deeplab_model.py
gaomingqi 9f30e59c45 add xmem
2023-04-12 08:24:08 +08:00

87 lines
3.2 KiB
Python

import torch
import torch.nn as nn
from .ops import DistMaps
from .modeling.deeplab_v3 import DeepLabV3Plus
from .modeling.basic_blocks import SepConvHead
def get_deeplab_model(backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5,
                      norm_layer=nn.BatchNorm2d, backbone_norm_layer=None,
                      use_rgb_conv=True, cpu_dist_maps=False,
                      norm_radius=260):
    """Assemble a ``DistMapsModel`` around a DeepLabV3+ feature extractor.

    Args:
        backbone: Backbone identifier forwarded to ``DeepLabV3Plus``.
        deeplab_ch: Passed to ``DeepLabV3Plus`` as ``ch`` and to the head as
            ``in_channels``; the head's ``mid_channels`` is ``deeplab_ch // 2``.
        aspp_dropout: Forwarded to ``DeepLabV3Plus`` as ``project_dropout``.
        norm_layer: Normalisation layer factory shared by the feature
            extractor, the head and the wrapping ``DistMapsModel``.
        backbone_norm_layer: Forwarded unchanged to ``DeepLabV3Plus``.
        use_rgb_conv: Forwarded unchanged to ``DistMapsModel``.
        cpu_dist_maps: Forwarded unchanged to ``DistMapsModel``.
        norm_radius: Forwarded unchanged to ``DistMapsModel``.

    Returns:
        The constructed ``DistMapsModel`` instance.
    """
    # Build the two sub-modules first (extractor, then head) so the final
    # constructor call reads as a plain wiring step.
    extractor = DeepLabV3Plus(
        backbone=backbone,
        ch=deeplab_ch,
        project_dropout=aspp_dropout,
        norm_layer=norm_layer,
        backbone_norm_layer=backbone_norm_layer,
    )
    # First positional argument of SepConvHead is 1 (a single output map is
    # presumably intended -- confirm against SepConvHead's signature).
    seg_head = SepConvHead(
        1,
        in_channels=deeplab_ch,
        mid_channels=deeplab_ch // 2,
        num_layers=2,
        norm_layer=norm_layer,
    )
    return DistMapsModel(
        feature_extractor=extractor,
        head=seg_head,
        use_rgb_conv=use_rgb_conv,
        norm_layer=norm_layer,
        norm_radius=norm_radius,
        cpu_dist_maps=cpu_dist_maps,
    )
class DistMapsModel(nn.Module):
    """Segmentation model that feeds an image plus click-derived maps to a backbone.

    The click maps come from ``DistMaps``; they are either fused with the
    image through a small learned 1x1-conv stack (``rgb_conv``) or used to
    modulate the image directly.
    """

    def __init__(self, feature_extractor, head, norm_layer=nn.BatchNorm2d, use_rgb_conv=True,
                 cpu_dist_maps=False, norm_radius=260):
        """Store the sub-modules and optionally build the input-fusion conv stack.

        Args:
            feature_extractor: Module called on the fused input; element ``[0]``
                of its result is passed to ``head``.
            head: Module producing the prediction from the extracted features.
            norm_layer: Normalisation layer factory used inside ``rgb_conv``.
            use_rgb_conv: When true, build the 5 -> 8 -> 3 channel 1x1-conv
                stack; otherwise ``rgb_conv`` is ``None``.
            cpu_dist_maps: Forwarded to ``DistMaps`` as ``cpu_mode``.
            norm_radius: Forwarded to ``DistMaps`` as ``norm_radius``.
        """
        super().__init__()

        fusion = None
        if use_rgb_conv:
            # 5 input channels: presumably 3 image + 2 click-map channels
            # (verify against DistMaps' output).
            fusion = nn.Sequential(
                nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1),
                nn.LeakyReLU(negative_slope=0.2),
                norm_layer(8),
                nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1),
            )
        self.rgb_conv = fusion

        self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0,
                                  cpu_mode=cpu_dist_maps)
        self.feature_extractor = feature_extractor
        self.head = head

    def forward(self, image, points):
        """Run the model on ``image`` guided by ``points``.

        Args:
            image: Input tensor; its trailing two dimensions define the output
                resolution.
            points: Click description handed to ``self.dist_maps`` unchanged.

        Returns:
            A dict with the single key ``'instances'`` holding the head output
            bilinearly resized to the spatial size of ``image``.
        """
        click_maps = self.dist_maps(image, points)

        if self.rgb_conv is None:
            # No learned fusion: append an all-ones map to the two halves of the
            # click maps and blend the modulated image with the raw image.
            first_half, second_half = torch.chunk(click_maps, 2, dim=1)
            ones = torch.ones_like(first_half)
            click_maps = torch.cat((first_half, second_half, ones), dim=1)
            x = 0.8 * image * click_maps + 0.2 * image
        else:
            stacked = torch.cat((image, click_maps), dim=1)
            x = self.rgb_conv(stacked)

        features = self.feature_extractor(x)
        # Only the first element returned by the feature extractor is used.
        prediction = self.head(features[0])
        prediction = nn.functional.interpolate(
            prediction,
            size=image.shape[2:],
            mode='bilinear',
            align_corners=True,
        )
        return {'instances': prediction}

    def load_weights(self, path_to_weights):
        """Overlay a saved state dict on the current weights and load the result.

        Entries absent from the checkpoint keep their current values because
        the checkpoint is merged into a copy of the model's own state dict.

        Args:
            path_to_weights: Location passed to ``torch.load`` (tensors are
                mapped to CPU).
        """
        merged = self.state_dict()
        # NOTE(review): torch.load deserialises with pickle; only point this at
        # trusted checkpoint files.
        merged.update(torch.load(path_to_weights, map_location='cpu'))
        self.load_state_dict(merged)

    def get_trainable_params(self):
        """Split trainable parameters into backbone and non-backbone groups.

        Returns:
            A ``(backbone_params, other_params)`` pair of ``nn.ParameterList``
            objects. A parameter goes to the first list when its qualified name
            contains ``'backbone'``; parameters with ``requires_grad`` unset are
            left out of both.
        """
        backbone_params = nn.ParameterList()
        other_params = nn.ParameterList()
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            target = backbone_params if 'backbone' in name else other_params
            target.append(param)
        return backbone_params, other_params