# Track-Anything/tracker/model/memory_util.py
import math
import numpy as np
import torch
from typing import Optional


def get_similarity(mk, ms, qk, qe):
    # used for training/inference and memory reading/memory potentiation
    # mk: B x CK x [N]    - Memory keys
    # ms: B x  1 x [N]    - Memory shrinkage
    # qk: B x CK x [HW/P] - Query keys
    # qe: B x CK x [HW/P] - Query selection
    # Dimensions in [] are flattened
    CK = mk.shape[1]
    mk = mk.flatten(start_dim=2)
    ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
    qk = qk.flatten(start_dim=2)
    qe = qe.flatten(start_dim=2) if qe is not None else None

    if qe is not None:
        # See appendix for derivation
        # or you can just trust me ヽ(ー_ー )
        mk = mk.transpose(1, 2)
        a_sq = (mk.pow(2) @ qe)
        two_ab = 2 * (mk @ (qk * qe))
        b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
        similarity = (-a_sq + two_ab - b_sq)
    else:
        # similar to STCN if we don't have the selection term
        a_sq = mk.pow(2).sum(1).unsqueeze(2)
        two_ab = 2 * (mk.transpose(1, 2) @ qk)
        similarity = (-a_sq + two_ab)

    if ms is not None:
        similarity = similarity * ms / math.sqrt(CK)  # B*N*HW
    else:
        similarity = similarity / math.sqrt(CK)  # B*N*HW

    return similarity
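

def _similarity_sanity_check():
    # Hypothetical helper, not part of the original file: a minimal sketch
    # checking that the fused a_sq / two_ab / b_sq expression above equals
    # the explicit selection-weighted negative squared distance
    #     -sum_c qe_c * (mk_c - qk_c)^2,
    # scaled by the shrinkage term and 1/sqrt(CK). All sizes below are
    # illustrative assumptions.
    B, CK, N, HW = 1, 4, 3, 5
    mk = torch.randn(B, CK, N)
    ms = torch.rand(B, 1, N)
    qk = torch.randn(B, CK, HW)
    qe = torch.rand(B, CK, HW)

    sim = get_similarity(mk, ms, qk, qe)

    # explicit reference: every memory element against every query position
    ref = -((mk.unsqueeze(3) - qk.unsqueeze(2)).pow(2) * qe.unsqueeze(2)).sum(1)
    ref = ref * ms.transpose(1, 2) / math.sqrt(CK)
    assert torch.allclose(sim, ref, atol=1e-5)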


def do_softmax(similarity, top_k: Optional[int] = None, inplace=False, return_usage=False):
    # normalize similarity with top-k softmax
    # similarity: B x N x [HW/P]
    # use inplace with care
    if top_k is not None:
        values, indices = torch.topk(similarity, k=top_k, dim=1)

        x_exp = values.exp_()
        x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
        if inplace:
            similarity.zero_().scatter_(1, indices, x_exp)  # B*N*HW
            affinity = similarity
        else:
            affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp)  # B*N*HW
    else:
        maxes = torch.max(similarity, dim=1, keepdim=True)[0]
        x_exp = torch.exp(similarity - maxes)
        x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
        affinity = x_exp / x_exp_sum
        indices = None

    if return_usage:
        return affinity, affinity.sum(dim=2)

    return affinity
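

def _do_softmax_sketch():
    # Hypothetical usage sketch, not part of the original file: with top_k
    # set, only the k largest memory logits per query position keep non-zero
    # affinity, and each query column still sums to 1 over the kept entries.
    B, N, HW = 1, 6, 4
    similarity = torch.randn(B, N, HW)

    affinity = do_softmax(similarity, top_k=3)
    assert torch.allclose(affinity.sum(dim=1), torch.ones(B, HW))
    assert int((affinity > 0).sum(dim=1).max()) <= 3

    # return_usage additionally reports how much total affinity each memory
    # element received, summed over the query positions
    affinity, usage = do_softmax(similarity, top_k=3, return_usage=True)
    assert usage.shape == (B, N)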


def get_affinity(mk, ms, qk, qe):
    # shorthand used in training with no top-k
    similarity = get_similarity(mk, ms, qk, qe)
    affinity = do_softmax(similarity)
    return affinity


def readout(affinity, mv):
    # aggregate memory values with the affinity matrix
    # affinity: B x (T*H*W) x HW   - normalized attention weights
    # mv:       B x CV x T x H x W - memory values
    B, CV, T, H, W = mv.shape

    mo = mv.view(B, CV, T * H * W)
    mem = torch.bmm(mo, affinity)  # weighted sum over memory elements, B x CV x HW
    mem = mem.view(B, CV, H, W)

    return mem
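

def _memory_read_sketch():
    # Hypothetical end-to-end sketch, not part of the original file; all
    # sizes are illustrative assumptions. T memory frames are flattened into
    # N = T*H*W key entries, a single query frame attends to them, and
    # readout returns the attended memory value at every query position.
    B, CK, CV, T, H, W = 1, 64, 512, 3, 2, 2
    mk = torch.randn(B, CK, T * H * W)  # memory keys, [N] pre-flattened
    ms = torch.rand(B, 1, T * H * W)    # memory shrinkage
    mv = torch.randn(B, CV, T, H, W)    # memory values
    qk = torch.randn(B, CK, H * W)      # query keys
    qe = torch.rand(B, CK, H * W)       # query selection

    affinity = get_affinity(mk, ms, qk, qe)  # B x (T*H*W) x (H*W)
    mem = readout(affinity, mv)              # B x CV x H x W
    assert mem.shape == (B, CV, H, W)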