# Source file: Track-Anything/inference/interact/fbrs/model/ops.py
# (84 lines, 3.0 KiB, Python; revision dated 2023-04-12 08:24:08 +08:00)

import torch
from torch import nn as nn
import numpy as np
from . import initializer as initializer
from ..utils.cython import get_dist_maps
def select_activation_function(activation):
    """Resolve an activation specification.

    Args:
        activation: One of
            * a string (case-insensitive), ``'relu'`` or ``'softplus'``;
            * an ``nn.Module`` subclass (e.g. ``nn.ReLU``);
            * an ``nn.Module`` instance.

    Returns:
        For a string, the matching layer *class* (``nn.ReLU`` or
        ``nn.Softplus``), not an instance. A module class or a module
        instance is returned unchanged.

    Raises:
        ValueError: If the string is not a known name, or the argument is
            neither a string, an ``nn.Module`` instance nor an
            ``nn.Module`` subclass.
    """
    if isinstance(activation, str):
        known = {'relu': nn.ReLU, 'softplus': nn.Softplus}
        try:
            return known[activation.lower()]
        except KeyError:
            raise ValueError(f"Unknown activation type {activation}") from None

    if isinstance(activation, nn.Module):
        return activation

    # String names resolve to classes, so a class passed directly is accepted
    # as well instead of being rejected as "unknown".
    if isinstance(activation, type) and issubclass(activation, nn.Module):
        return activation

    raise ValueError(f"Unknown activation type {activation}")
class BilinearConvTranspose2d(nn.ConvTranspose2d):
    """Bias-free transposed convolution initialised via ``initializer.Bilinear``.

    The stride equals ``scale`` and the kernel size is ``2 * scale`` for an
    even ``scale`` or ``2 * scale - 1`` for an odd one. Padding is fixed at 1.
    """

    def __init__(self, in_channels, out_channels, scale, groups=1):
        """Build the layer and apply the bilinear initializer to it.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            scale: Stride of the transposed convolution; also determines
                the kernel size.
            groups: Number of convolution groups.
        """
        # Plain attribute, stored before the parent constructor runs.
        self.scale = scale

        conv_options = dict(
            kernel_size=2 * scale - scale % 2,
            stride=scale,
            padding=1,
            groups=groups,
            bias=False,
        )
        super().__init__(in_channels, out_channels, **conv_options)

        bilinear_init = initializer.Bilinear(
            scale=scale, in_channels=in_channels, groups=groups
        )
        self.apply(bilinear_init)
class DistMaps(nn.Module):
    """Turn point coordinates into two-channel distance feature maps.

    For every sample the output has two channels. Each channel holds, per
    pixel, ``tanh(2 * sqrt(d))`` where ``d`` is the smallest squared
    distance (normalised by ``norm_radius * spatial_scale``) to any valid
    point in that channel's group.
    NOTE(review): the two groups are presumably positive and negative
    clicks -- confirm against the caller.
    """

    def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False):
        """Store the configuration.

        Args:
            norm_radius: Radius used to normalise distances.
            spatial_scale: Factor applied to point coordinates (and to
                ``norm_radius``) before distances are computed.
            cpu_mode: If true, distance maps are computed per sample with
                the ``get_dist_maps`` helper on NumPy arrays instead of
                with tensor operations.
        """
        super().__init__()
        self.spatial_scale = spatial_scale
        self.norm_radius = norm_radius
        self.cpu_mode = cpu_mode

    def get_coord_features(self, points, batchsize, rows, cols):
        """Compute the distance maps for a batch of points.

        Args:
            points: Tensor indexed as ``(batch, 2 * num_points, 2)`` with
                ``(row, col)`` coordinates. A point whose two coordinates
                are both negative is treated as absent. The tensor does not
                have to be contiguous.
            batchsize: Number of samples (only used in CPU mode).
            rows: Height of the produced maps.
            cols: Width of the produced maps.

        Returns:
            In tensor mode, a float32 tensor of shape ``(batch, 2, rows,
            cols)`` on ``points.device``. In CPU mode, the stacked
            ``get_dist_maps`` results converted to a float tensor on
            ``points.device`` (shape as returned by that helper). Both get
            the same final ``tanh(2 * sqrt(.))`` squashing.
        """
        if self.cpu_mode:
            # The normaliser is identical for every sample; compute it once.
            norm_delimeter = self.spatial_scale * self.norm_radius
            per_sample = [
                get_dist_maps(points[i].cpu().float().numpy(), rows, cols,
                              norm_delimeter)
                for i in range(batchsize)
            ]
            coords = torch.from_numpy(np.stack(per_sample, axis=0)).to(points.device).float()
        else:
            num_points = points.shape[1] // 2
            # reshape (not view) so sliced / non-contiguous inputs are accepted.
            points = points.reshape(-1, 2)
            invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0

            row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device)
            col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device)

            # Pixel grid for every point: channel 0 = row index, channel 1 =
            # column index. Filled by broadcasting into one buffer, which
            # avoids torch.meshgrid and its `indexing` deprecation warning.
            coords = torch.empty(points.size(0), 2, rows, cols,
                                 dtype=torch.float32, device=points.device)
            coords[:, 0] = row_array.view(1, rows, 1)
            coords[:, 1] = col_array.view(1, 1, cols)

            # Normalised squared offset from each point, computed in place.
            add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1)
            coords.add_(-add_xy)
            coords.div_(self.norm_radius * self.spatial_scale)
            coords.mul_(coords)

            # Sum the row and column terms into channel 0 and keep only it.
            coords[:, 0] += coords[:, 1]
            coords = coords[:, :1]

            # Absent points get a huge distance so they never win the min.
            coords[invalid_points, :, :, :] = 1e6

            coords = coords.view(-1, num_points, 1, rows, cols)
            coords = coords.min(dim=1)[0]  # -> (bs * num_masks * 2) x 1 x h x w
            coords = coords.view(-1, 2, rows, cols)

        # Squash distances into [0, 1).
        coords.sqrt_().mul_(2).tanh_()

        return coords

    def forward(self, x, coords):
        """Return distance maps sized like ``x``.

        Args:
            x: Tensor whose dimensions 0, 2 and 3 supply the batch size,
                height and width; its values are not read.
            coords: Point tensor, see ``get_coord_features``.
        """
        return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3])