mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 08:27:49 +01:00
84 lines
3.0 KiB
Python
84 lines
3.0 KiB
Python
import torch
|
|
from torch import nn as nn
|
|
import numpy as np
|
|
|
|
from . import initializer as initializer
|
|
from ..utils.cython import get_dist_maps
|
|
|
|
|
|
def select_activation_function(activation):
    """Resolve an activation specification.

    Args:
        activation: One of

            * a string, matched case-insensitively: ``'relu'`` or
              ``'softplus'``;
            * an ``nn.Module`` instance, which is returned unchanged;
            * an ``nn.Module`` subclass (e.g. ``nn.ReLU``), which is returned
              unchanged. This mirrors what the string form returns, so the
              result of one call can be fed back into another.

    Returns:
        ``nn.ReLU`` or ``nn.Softplus`` (the classes, not instances) for a
        string argument; otherwise the very object that was passed in.

    Raises:
        ValueError: If the string names an unsupported activation or the
            argument is neither a string, an ``nn.Module`` instance nor an
            ``nn.Module`` subclass.
    """
    if isinstance(activation, str):
        by_name = {'relu': nn.ReLU, 'softplus': nn.Softplus}
        try:
            return by_name[activation.lower()]
        except KeyError:
            raise ValueError(f"Unknown activation type {activation}") from None

    if isinstance(activation, nn.Module):
        return activation

    # A module *class* is not an instance of nn.Module, so it needs its own
    # check; without it, passing e.g. nn.ReLU would be rejected.
    if isinstance(activation, type) and issubclass(activation, nn.Module):
        return activation

    raise ValueError(f"Unknown activation type {activation}")
|
|
|
|
|
|
class BilinearConvTranspose2d(nn.ConvTranspose2d):
    """Bias-free transposed convolution set up from an integer ``scale``.

    The layer uses ``scale`` as its stride, a fixed padding of 1 and a kernel
    of size ``2 * scale - scale % 2``. After construction,
    ``initializer.Bilinear(scale, in_channels, groups)`` is applied to the
    module through ``self.apply``.
    NOTE(review): presumably that initializer fills the weights with a
    bilinear-interpolation kernel -- confirm in the ``initializer`` module.
    """

    def __init__(self, in_channels, out_channels, scale, groups=1):
        # Stored before the parent constructor runs, as in the original code.
        self.scale = scale

        # Even scales give a kernel of 2 * scale, odd scales 2 * scale - 1.
        conv_kwargs = dict(
            kernel_size=2 * scale - scale % 2,
            stride=scale,
            padding=1,
            groups=groups,
            bias=False,
        )
        super().__init__(in_channels, out_channels, **conv_kwargs)

        weight_init = initializer.Bilinear(
            scale=scale, in_channels=in_channels, groups=groups
        )
        self.apply(weight_init)
|
|
|
|
|
|
class DistMaps(nn.Module):
    """Turn point coordinates into dense per-pixel distance maps.

    Attributes are plain Python values (no learnable parameters):

    * ``norm_radius``   -- distances are divided by
      ``norm_radius * spatial_scale`` before the final squashing;
    * ``spatial_scale`` -- factor applied to the incoming point coordinates;
    * ``cpu_mode``      -- when true, the maps are produced per sample by the
      external ``get_dist_maps`` helper instead of by tensor operations.
    """

    def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False):
        super().__init__()
        self.spatial_scale = spatial_scale
        self.norm_radius = norm_radius
        self.cpu_mode = cpu_mode

    def get_coord_features(self, points, batchsize, rows, cols):
        """Build the distance maps for a batch of point sets.

        Args:
            points: Tensor of point coordinates. The tensor path reshapes it
                to ``(-1, 2)`` and treats ``points.shape[1] // 2`` as the
                number of points per group, i.e. it assumes a layout of
                ``(batch, 2 * num_points, 2)`` with ``(row, col)`` pairs.
                NOTE(review): the two groups are presumably positive and
                negative clicks -- confirm against the caller.
                A point whose largest coordinate is negative is treated as
                "not set".
            batchsize: Number of samples; only used by the ``cpu_mode`` path.
            rows: Height of the maps to generate.
            cols: Width of the maps to generate.

        Returns:
            Float tensor of values ``tanh(2 * sqrt(d))``. On the tensor path
            it has shape ``(-1, 2, rows, cols)`` and ``d`` is the minimum,
            over the points in a group, of the squared normalised distance
            from each pixel to the point.
        """
        if self.cpu_mode:
            # The normalisation factor does not depend on the sample, so it
            # is computed once rather than on every loop iteration.
            norm_delimeter = self.spatial_scale * self.norm_radius
            coords = [
                get_dist_maps(points[i].cpu().float().numpy(), rows, cols, norm_delimeter)
                for i in range(batchsize)
            ]
            coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
        else:
            num_points = points.shape[1] // 2
            # reshape (rather than view) so that non-contiguous point tensors,
            # e.g. transposed or sliced ones, are accepted as well.
            points = points.reshape(-1, 2)
            invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0

            row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device)
            col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device)

            # Same grids as torch.meshgrid(row_array, col_array) with 'ij'
            # indexing, built by broadcasting so that no `indexing` argument
            # (and no deprecation warning) is involved on any torch version.
            coord_rows = row_array.view(rows, 1).expand(rows, cols)
            coord_cols = col_array.view(1, cols).expand(rows, cols)
            # stack() materialises a fresh tensor, so the in-place operations
            # below never write through the expanded views.
            coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1)

            # Per point: pixel coordinate minus scaled point coordinate,
            # normalised, then squared element-wise.
            add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1)
            coords.add_(-add_xy)
            coords.div_(self.norm_radius * self.spatial_scale)
            coords.mul_(coords)

            # Sum the squared row and column offsets into channel 0 and keep
            # only that channel.
            coords[:, 0] += coords[:, 1]
            coords = coords[:, :1]

            # Unset points get a huge distance so they cannot win the minimum
            # below; if every point of a group is unset the map saturates.
            coords[invalid_points, :, :, :] = 1e6

            coords = coords.view(-1, num_points, 1, rows, cols)
            coords = coords.min(dim=1)[0]  # -> (bs * num_masks * 2) x 1 x h x w
            coords = coords.view(-1, 2, rows, cols)

        # Applied to the result of either branch.
        # NOTE(review): this assumes get_dist_maps also returns squared
        # normalised distances -- confirm against the cython helper.
        coords.sqrt_().mul_(2).tanh_()

        return coords

    def forward(self, x, coords):
        """Return distance maps sized like ``x``.

        Args:
            x: Tensor whose dimensions 0, 2 and 3 supply the batch size,
                height and width of the maps.
            coords: Point tensor passed on to :meth:`get_coord_features`.
        """
        return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3])
|