mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
116 lines
5.1 KiB
Python
116 lines
5.1 KiB
Python
from inference.memory_manager import MemoryManager
|
|
from model.network import XMem
|
|
from model.aggregate import aggregate
|
|
|
|
from tracker.util.tensor_util import pad_divide_by, unpad
|
|
|
|
|
|
class InferenceCore:
    """Frame-by-frame inference driver around an XMem network and its memory.

    Each call to :meth:`step` processes one frame. It encodes the frame's key,
    segments the frame by reading from memory when needed, and merges in any
    user-provided mask. It then decides whether the frame is written to memory
    and whether the hidden state receives a deep update.
    """

    def __init__(self, network: "XMem", config):
        """Create the driver with an empty memory.

        Args:
            network: model exposing ``encode_key``, ``segment`` and
                ``encode_value`` (an ``XMem`` instance).
            config: mapping providing at least ``'mem_every'``,
                ``'deep_update_every'`` and ``'enable_long_term'``. It is also
                passed to ``MemoryManager``, which presumably reads further
                keys (see that class).
        """
        self.network = network
        self._apply_config(config)
        self.clear_memory()
        # Labels of all tracked objects. step() calls len() on this when a mask
        # (or a non-None valid_labels after the first frame) is given, so it
        # must be set with set_all_labels() before such frames.
        self.all_labels = None

    def _apply_config(self, config):
        """Store ``config`` and derive the frame-scheduling settings from it."""
        self.config = config
        self.mem_every = config['mem_every']
        self.deep_update_every = config['deep_update_every']
        self.enable_long_term = config['enable_long_term']
        # if deep_update_every < 0, synchronize deep update with memory frame
        self.deep_update_sync = (self.deep_update_every < 0)

    def clear_memory(self):
        """Reset the frame counters and replace the memory with a fresh MemoryManager."""
        self.curr_ti = -1
        self.last_mem_ti = 0
        if not self.deep_update_sync:
            # Chosen so that the first frame (curr_ti == 0) triggers a deep update.
            self.last_deep_update_ti = -self.deep_update_every
        self.memory = MemoryManager(config=self.config)

    def update_config(self, config):
        """Apply a new configuration to this driver and to its memory.

        The config is stored, so a later :meth:`clear_memory` builds its
        ``MemoryManager`` from it rather than from the original config.
        """
        was_sync = self.deep_update_sync
        self._apply_config(config)
        if was_sync and not self.deep_update_sync:
            # In synced mode the unsynced deep-update counter is never
            # initialised by clear_memory() and may be missing or stale.
            # Re-anchor it so the next frame performs a deep update, exactly as
            # after a fresh clear_memory().
            self.last_deep_update_ti = self.curr_ti + 1 - self.deep_update_every
        self.memory.update_config(config)

    def set_all_labels(self, all_labels):
        """Set the labels of every tracked object (read by step())."""
        self.all_labels = all_labels

    def step(self, image, mask=None, valid_labels=None, end=False):
        """Process one frame and return its prediction with padding removed.

        Args:
            image: 3*H*W frame tensor.
            mask: num_objects*H*W mask tensor for this frame, or None.
            valid_labels: 1-based object ids whose channels are taken from
                ``mask``. When a prediction exists, the other channels are
                copied from it. If it has as many entries as ``all_labels``,
                segmentation is skipped.
            end: when True, the frame is not written to memory and the
                segmentation does not update the hidden state.

        Returns:
            ``(prob_with_bg, logits_with_bg)``, both passed through ``unpad``.
            ``prob_with_bg`` comes from the segmentation or, when ``mask`` is
            given, from aggregating the mask. ``logits_with_bg`` is None when
            the frame was not segmented.
        """
        self.curr_ti += 1
        image, self.pad = pad_divide_by(image, 16)
        image = image.unsqueeze(0)  # add the batch dimension

        is_mem_frame = (
            (self.curr_ti - self.last_mem_ti >= self.mem_every) or (mask is not None)
        ) and (not end)
        # No segmentation on the first frame, nor when valid_labels has as many
        # entries as all_labels (the given mask then defines the frame alone).
        need_segment = (self.curr_ti > 0) and (
            (valid_labels is None) or (len(self.all_labels) != len(valid_labels)))
        is_deep_update = (
            (self.deep_update_sync and is_mem_frame) or  # synchronized
            (not self.deep_update_sync
             and self.curr_ti - self.last_deep_update_ti >= self.deep_update_every)  # no-sync
        ) and (not end)
        # In synced mode a deep update replaces the normal hidden-state update.
        is_normal_update = (not self.deep_update_sync or not is_deep_update) and (not end)

        key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(
            image,
            need_ek=(self.enable_long_term or need_segment),
            need_sk=is_mem_frame)
        multi_scale_features = (f16, f8, f4)

        # segment the current frame if needed
        if need_segment:
            memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
            hidden, pred_logits_with_bg, pred_prob_with_bg = self.network.segment(
                multi_scale_features, memory_readout, self.memory.get_hidden(),
                h_out=is_normal_update, strip_bg=False)
            # remove batch dim; channel 0 is the background
            pred_prob_with_bg = pred_prob_with_bg[0]
            pred_prob_no_bg = pred_prob_with_bg[1:]
            pred_logits_with_bg = pred_logits_with_bg[0]

            if is_normal_update:
                self.memory.set_hidden(hidden)
        else:
            pred_prob_no_bg = pred_prob_with_bg = pred_logits_with_bg = None

        # use the input mask if any
        if mask is not None:
            mask, _ = pad_divide_by(mask, 16)

            if pred_prob_no_bg is not None:
                # if we have a predicted mask, we work on it:
                # make pred_prob_no_bg consistent with the input mask
                mask_regions = (mask.sum(0) > 0.5)
                pred_prob_no_bg[:, mask_regions] = 0
                mask = mask.type_as(pred_prob_no_bg)
                if valid_labels is not None:
                    # shift by 1 because mask/pred_prob_no_bg do not contain background
                    shift_by_one_non_labels = [
                        i for i in range(pred_prob_no_bg.shape[0])
                        if (i + 1) not in valid_labels]
                    # non-labelled objects are copied from the predicted mask
                    mask[shift_by_one_non_labels] = pred_prob_no_bg[shift_by_one_non_labels]
            pred_prob_with_bg = aggregate(mask, dim=0)

            # also create new hidden states
            self.memory.create_hidden_state(len(self.all_labels), key)

        # save as memory if needed
        # NOTE(review): this assumes pred_prob_with_bg was produced above (by
        # segmentation or from the mask); callers must not request a memory
        # frame otherwise.
        if is_mem_frame:
            value, hidden = self.network.encode_value(
                image, f16, self.memory.get_hidden(),
                pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=is_deep_update)
            self.memory.add_memory(
                key, shrinkage, value, self.all_labels,
                selection=selection if self.enable_long_term else None)
            self.last_mem_ti = self.curr_ti

            # Deep updates only take effect on memory frames.
            if is_deep_update:
                self.memory.set_hidden(hidden)
                self.last_deep_update_ti = self.curr_ti

        prob = unpad(pred_prob_with_bg, self.pad)
        if pred_logits_with_bg is None:
            return prob, None
        return prob, unpad(pred_logits_with_bg, self.pad)
|