Mirror of https://github.com/gaomingqi/Track-Anything.git
Synced 2025-12-16 08:27:49 +01:00
286 lines · 12 KiB · Python
import torch
|
|
import warnings
|
|
|
|
from inference.kv_memory_store import KeyValueMemoryStore
|
|
from model.memory_util import *
|
|
|
|
|
|
class MemoryManager:
    """
    Manages all three memory stores and the transition between working/long-term memory.

    Holds a working (mid-term) KeyValueMemoryStore, optionally a long-term
    KeyValueMemoryStore, and the recurrent hidden state shared by all objects.
    Keys/values are organized in "object groups"; later groups may cover only a
    suffix of the stored keys (groups added later in the video have fewer frames).
    """
    def __init__(self, config):
        # config: dict of hyperparameters read by key below.
        self.hidden_dim = config['hidden_dim']
        self.top_k = config['top_k']

        self.enable_long_term = config['enable_long_term']
        self.enable_long_term_usage = config['enable_long_term_count_usage']
        if self.enable_long_term:
            # Working-memory capacity bounds expressed in frames; converted to
            # element counts in add_memory once the spatial size (H*W) is known.
            self.max_mt_frames = config['max_mid_term_frames']
            self.min_mt_frames = config['min_mid_term_frames']
            self.num_prototypes = config['num_prototypes']
            self.max_long_elements = config['max_long_term_elements']

        # dimensions will be inferred from input later
        self.CK = self.CV = None
        self.H = self.W = None

        # The hidden state will be stored in a single tensor for all objects
        # B x num_objects x CH x H x W
        self.hidden = None

        # Usage counting on the working memory is only needed when elements may
        # later be consolidated into long-term memory.
        self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term)
        if self.enable_long_term:
            self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage)

        # Forces add_memory to (re-)derive the element-count limits on its next call.
        self.reset_config = True

    def update_config(self, config):
        """Update tunable hyperparameters in place.

        The long-term on/off switches cannot change after construction
        (asserted below); the capacity-related settings are refreshed and
        re-applied lazily via reset_config.
        """
        self.reset_config = True
        self.hidden_dim = config['hidden_dim']
        self.top_k = config['top_k']

        assert self.enable_long_term == config['enable_long_term'], 'cannot update this'
        assert self.enable_long_term_usage == config['enable_long_term_count_usage'], 'cannot update this'

        # No-op given the assert above; kept for symmetry with __init__.
        self.enable_long_term_usage = config['enable_long_term_count_usage']
        if self.enable_long_term:
            self.max_mt_frames = config['max_mid_term_frames']
            self.min_mt_frames = config['min_mid_term_frames']
            self.num_prototypes = config['num_prototypes']
            self.max_long_elements = config['max_long_term_elements']

    def _readout(self, affinity, v):
        # this function is for a single object group
        # v: value tensor (C^v along dim -2, memory elements along dim -1);
        # affinity: softmax-normalized attention (memory elements x query positions).
        # The matmul aggregates memory values per query position.
        return v @ affinity

    def match_memory(self, query_key, selection):
        """Read out memory features for a query frame.

        Computes key similarity between the query and all stored memory
        (long-term + working when engaged, working only otherwise), converts it
        to a top-k softmax affinity per object group, and returns the
        affinity-weighted memory values. Also records element usage counts as a
        side effect (used later for consolidation/eviction).

        Returns a tensor reshaped to (num_value_rows, CV, h, w), where
        num_value_rows is the concatenation over object groups along dim 0.
        """
        # query_key: B x C^k x H x W
        # selection: B x C^k x H x W
        num_groups = self.work_mem.num_groups
        h, w = query_key.shape[-2:]

        # Flatten spatial dims: B x C^k x (H*W).
        query_key = query_key.flatten(start_dim=2)
        selection = selection.flatten(start_dim=2) if selection is not None else None

        """
        Memory readout using keys
        """

        if self.enable_long_term and self.long_mem.engaged():
            # Use long-term memory
            # Long-term elements come FIRST along the memory axis; keep
            # long_mem_size so similarity/usage can be split back apart below.
            long_mem_size = self.long_mem.size
            memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1)
            shrinkage = torch.cat([self.long_mem.shrinkage, self.work_mem.shrinkage], -1)

            similarity = get_similarity(memory_key, shrinkage, query_key, selection)
            work_mem_similarity = similarity[:, long_mem_size:]
            long_mem_similarity = similarity[:, :long_mem_size]

            # get the usage with the first group
            # the first group always have all the keys valid
            affinity, usage = do_softmax(
                torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(0):], work_mem_similarity], 1),
                top_k=self.top_k, inplace=True, return_usage=True)
            affinity = [affinity]

            # compute affinity group by group as later groups only have a subset of keys
            # (negative slicing selects the suffix of memory elements valid for group gi)
            for gi in range(1, num_groups):
                if gi < self.long_mem.num_groups:
                    # merge working and lt similarities before softmax
                    affinity_one_group = do_softmax(
                        torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):],
                                   work_mem_similarity[:, -self.work_mem.get_v_size(gi):]], 1),
                        top_k=self.top_k, inplace=True)
                else:
                    # no long-term memory for this group
                    # inplace is only safe on the last iteration: earlier groups
                    # would otherwise corrupt work_mem_similarity for later slices.
                    affinity_one_group = do_softmax(work_mem_similarity[:, -self.work_mem.get_v_size(gi):],
                        top_k=self.top_k, inplace=(gi==num_groups-1))
                affinity.append(affinity_one_group)

            all_memory_value = []
            for gi, gv in enumerate(self.work_mem.value):
                # merge the working and lt values before readout
                # (order matches the key concatenation above: long-term first)
                if gi < self.long_mem.num_groups:
                    all_memory_value.append(torch.cat([self.long_mem.value[gi], self.work_mem.value[gi]], -1))
                else:
                    all_memory_value.append(gv)

            """
            Record memory usage for working and long-term memory
            """
            # ignore the index return for long-term memory
            work_usage = usage[:, long_mem_size:]
            self.work_mem.update_usage(work_usage.flatten())

            if self.enable_long_term_usage:
                # ignore the index return for working memory
                long_usage = usage[:, :long_mem_size]
                self.long_mem.update_usage(long_usage.flatten())
        else:
            # No long-term memory
            similarity = get_similarity(self.work_mem.key, self.work_mem.shrinkage, query_key, selection)

            if self.enable_long_term:
                # Long-term mode but not yet engaged: still track usage so that
                # future consolidation can rank elements.
                affinity, usage = do_softmax(similarity, inplace=(num_groups==1),
                    top_k=self.top_k, return_usage=True)

                # Record memory usage for working memory
                self.work_mem.update_usage(usage.flatten())
            else:
                affinity = do_softmax(similarity, inplace=(num_groups==1),
                    top_k=self.top_k, return_usage=False)

            affinity = [affinity]

            # compute affinity group by group as later groups only have a subset of keys
            for gi in range(1, num_groups):
                affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):],
                    top_k=self.top_k, inplace=(gi==num_groups-1))
                affinity.append(affinity_one_group)

            all_memory_value = self.work_mem.value

        # Shared affinity within each group
        all_readout_mem = torch.cat([
            self._readout(affinity[gi], gv)
            for gi, gv in enumerate(all_memory_value)
        ], 0)

        return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w)

    def add_memory(self, key, shrinkage, value, objects, selection=None):
        """Append one frame's features to working memory; may trigger
        consolidation into long-term memory when the working store is full."""
        # key: 1*C*H*W
        # value: 1*num_objects*C*H*W
        # objects contain a list of object indices
        if self.H is None or self.reset_config:
            # First call (or config changed): derive per-element limits from
            # the frame's spatial size.
            self.reset_config = False
            self.H, self.W = key.shape[-2:]
            self.HW = self.H*self.W
            if self.enable_long_term:
                # convert from num. frames to num. nodes
                self.min_work_elements = self.min_mt_frames*self.HW
                self.max_work_elements = self.max_mt_frames*self.HW

        # key: 1*C*N
        # value: num_objects*C*N
        key = key.flatten(start_dim=2)
        shrinkage = shrinkage.flatten(start_dim=2)
        value = value[0].flatten(start_dim=2)

        # Cache channel dimensions (CV is needed by match_memory's final view).
        self.CK = key.shape[1]
        self.CV = value.shape[1]

        if selection is not None:
            if not self.enable_long_term:
                warnings.warn('the selection factor is only needed in long-term mode', UserWarning)
            selection = selection.flatten(start_dim=2)

        self.work_mem.add(key, value, shrinkage, selection, objects)

        # long-term memory cleanup
        if self.enable_long_term:
            # Do memory compressed if needed
            if self.work_mem.size >= self.max_work_elements:
                print('remove memory')
                # Remove obsolete features if needed
                # (make room so the new prototypes fit under max_long_elements)
                if self.long_mem.size >= (self.max_long_elements-self.num_prototypes):
                    self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes)

                self.compress_features()

    def create_hidden_state(self, n, sample_key):
        """Ensure the hidden state exists and covers n objects, zero-initializing
        states for any newly added objects. sample_key supplies spatial size/device."""
        # n is the TOTAL number of objects
        h, w = sample_key.shape[-2:]
        if self.hidden is None:
            self.hidden = torch.zeros((1, n, self.hidden_dim, h, w), device=sample_key.device)
        elif self.hidden.shape[1] != n:
            # Grow by appending zero states for the new objects; existing
            # object states are preserved. (n must not shrink — see assert.)
            self.hidden = torch.cat([
                self.hidden,
                torch.zeros((1, n-self.hidden.shape[1], self.hidden_dim, h, w), device=sample_key.device)
            ], 1)

        assert(self.hidden.shape[1] == n)

    def set_hidden(self, hidden):
        # Replace the hidden state wholesale (e.g. after a decoder update).
        self.hidden = hidden

    def get_hidden(self):
        # Current hidden state, or None before create_hidden_state/set_hidden.
        return self.hidden

    def compress_features(self):
        """Consolidate the middle portion of working memory into long-term
        prototypes, keeping the first frame (first HW elements) and the most
        recent min_work_elements in working memory."""
        HW = self.HW
        candidate_value = []
        total_work_mem_size = self.work_mem.size
        for gv in self.work_mem.value:
            # Some object groups might be added later in the video
            # So not all keys have values associated with all objects
            # We need to keep track of the key->value validity
            mem_size_in_this_group = gv.shape[-1]
            if mem_size_in_this_group == total_work_mem_size:
                # full LT
                candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
            else:
                # mem_size is smaller than total_work_mem_size, but at least HW
                assert HW <= mem_size_in_this_group < total_work_mem_size
                if mem_size_in_this_group > self.min_work_elements+HW:
                    # part of this object group still goes into LT
                    candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
                else:
                    # this object group cannot go to the LT at all
                    candidate_value.append(None)

        # perform memory consolidation
        # get_all_sliced returns (key, shrinkage, selection, usage) for the same
        # [HW : -min_work_elements+HW] range sliced above — TODO confirm order
        # against KeyValueMemoryStore.
        prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
            *self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value)

        # remove consolidated working memory
        self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW)

        # add to long-term memory
        self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None)
        print(f'long memory size: {self.long_mem.size}')
        print(f'work memory size: {self.work_mem.size}')

    def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value):
        """Compress candidate memory into num_prototypes elements.

        Selects the most-used candidate positions as prototype anchors, then
        re-reads (potentiates) values/shrinkage at those anchors via a
        softmax-attention readout over all candidates.

        Returns (prototype_key, prototype_value, prototype_shrinkage), where
        prototype_value is a per-group list (None for groups with no valid
        prototypes).
        """
        # keys: 1*C*N
        # values: num_objects*C*N
        N = candidate_key.shape[-1]

        # find the indices with max usage
        _, max_usage_indices = torch.topk(usage, k=self.num_prototypes, dim=-1, sorted=True)
        prototype_indices = max_usage_indices.flatten()

        # Prototypes are invalid for out-of-bound groups
        # (a prototype index is valid for group gi only if it falls within that
        # group's suffix of the candidate axis)
        validity = [prototype_indices >= (N-gv.shape[2]) if gv is not None else None for gv in candidate_value]

        prototype_key = candidate_key[:, :, prototype_indices]
        prototype_selection = candidate_selection[:, :, prototype_indices] if candidate_selection is not None else None

        """
        Potentiation step
        """
        # Note the role reversal: prototypes act as the query here.
        similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key, prototype_selection)

        # convert similarity to affinity
        # need to do it group by group since the softmax normalization would be different
        affinity = [
            do_softmax(similarity[:, -gv.shape[2]:, validity[gi]]) if gv is not None else None
            for gi, gv in enumerate(candidate_value)
        ]

        # some groups can have all-False validity. Weed them out.
        affinity = [
            aff if aff is None or aff.shape[-1] > 0 else None for aff in affinity
        ]

        # readout the values
        prototype_value = [
            self._readout(affinity[gi], gv) if affinity[gi] is not None else None
            for gi, gv in enumerate(candidate_value)
        ]

        # readout the shrinkage term
        # (uses the first group's affinity: the first group always spans all
        # candidates, matching the shrinkage tensor's full extent)
        prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None

        return prototype_key, prototype_value, prototype_shrinkage