mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-15 16:07:51 +01:00
fix memory management
This commit is contained in:
@@ -20,7 +20,7 @@ from torchvision.transforms import Resize
|
||||
|
||||
|
||||
class BaseTracker:
|
||||
def __init__(self, xmem_checkpoint, device, sam_checkpoint, model_type) -> None:
|
||||
def __init__(self, xmem_checkpoint, device, sam_model, model_type=None) -> None:
|
||||
"""
|
||||
device: model device
|
||||
xmem_checkpoint: checkpoint of XMem model
|
||||
@@ -43,9 +43,9 @@ class BaseTracker:
|
||||
self.mapper = MaskMapper()
|
||||
self.initialised = False
|
||||
|
||||
# SAM-based refinement
|
||||
self.sam_model = BaseSegmenter(sam_checkpoint, model_type, device=device)
|
||||
self.resizer = Resize([256, 256])
|
||||
# # SAM-based refinement
|
||||
# self.sam_model = sam_model
|
||||
# self.resizer = Resize([256, 256])
|
||||
|
||||
@torch.no_grad()
|
||||
def resize_mask(self, mask):
|
||||
@@ -78,8 +78,7 @@ class BaseTracker:
|
||||
# prepare inputs
|
||||
frame_tensor = self.im_transform(frame).to(self.device)
|
||||
# track one frame
|
||||
probs, logits = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
|
||||
|
||||
probs, _ = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
|
||||
# # refine
|
||||
# if first_frame_annotation is None:
|
||||
# out_mask = self.sam_refinement(frame, logits[1], ti)
|
||||
@@ -92,7 +91,8 @@ class BaseTracker:
|
||||
painted_image = frame
|
||||
for obj in range(1, num_objs+1):
|
||||
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
|
||||
return out_mask, probs, painted_image
|
||||
|
||||
return out_mask, out_mask, painted_image
|
||||
|
||||
@torch.no_grad()
|
||||
def sam_refinement(self, frame, logits, ti):
|
||||
@@ -134,22 +134,52 @@ if __name__ == '__main__':
|
||||
# ----------------------------------------------------------
|
||||
# initialise tracker
|
||||
# ----------------------------------------------------------
|
||||
device = 'cuda:0'
|
||||
device = 'cuda:4'
|
||||
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
||||
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||
model_type = 'vit_h'
|
||||
tracker = BaseTracker(XMEM_checkpoint, device, SAM_checkpoint, model_type)
|
||||
|
||||
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
|
||||
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
|
||||
|
||||
# test for storage efficiency
|
||||
frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
|
||||
first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
|
||||
|
||||
# track anything given in the first frame annotation
|
||||
for ti, frame in enumerate(frames):
|
||||
print(ti)
|
||||
if ti > 200:
|
||||
break
|
||||
if ti == 0:
|
||||
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||
else:
|
||||
mask, prob, painted_image = tracker.track(frame)
|
||||
# save
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
|
||||
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
|
||||
|
||||
tracker.clear_memory()
|
||||
for ti, frame in enumerate(frames):
|
||||
print(ti)
|
||||
# if ti > 200:
|
||||
# break
|
||||
if ti == 0:
|
||||
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||
else:
|
||||
mask, prob, painted_image = tracker.track(frame)
|
||||
# save
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
|
||||
|
||||
# # track anything given in the first frame annotation
|
||||
# for ti, frame in enumerate(frames):
|
||||
# if ti == 0:
|
||||
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||
# else:
|
||||
# mask, prob, painted_image = tracker.track(frame)
|
||||
# # save
|
||||
# painted_image = Image.fromarray(painted_image)
|
||||
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
|
||||
|
||||
# # ----------------------------------------------------------
|
||||
# # another video
|
||||
@@ -198,3 +228,6 @@ if __name__ == '__main__':
|
||||
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
|
||||
|
||||
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# config info for XMem
|
||||
benchmark: False
|
||||
disable_long_term: False
|
||||
max_mid_term_frames: 15
|
||||
max_mid_term_frames: 10
|
||||
min_mid_term_frames: 5
|
||||
max_long_term_elements: 10000
|
||||
max_long_term_elements: 1000
|
||||
num_prototypes: 128
|
||||
top_k: 30
|
||||
mem_every: 5
|
||||
|
||||
@@ -29,9 +29,6 @@ class KeyValueMemoryStore:
|
||||
# shrinkage and selection are also single tensors
|
||||
self.s = self.e = None
|
||||
|
||||
# cumulated probs, softmax(HW) -> THW
|
||||
self.cumulated_probs = []
|
||||
|
||||
# usage
|
||||
if self.count_usage:
|
||||
self.use_count = self.life_count = None
|
||||
@@ -101,7 +98,7 @@ class KeyValueMemoryStore:
|
||||
self.use_count += usage.view_as(self.use_count)
|
||||
self.life_count += 1
|
||||
|
||||
def sieve_by_range(self, start: int, end: int, min_size: int, keeps: List):
|
||||
def sieve_by_range(self, start: int, end: int, min_size: int):
|
||||
# keep only the elements *outside* of this range (with some boundary conditions)
|
||||
# i.e., concat (a[:start], a[end:])
|
||||
# min_size is only used for values, we do not sieve values under this size
|
||||
@@ -122,31 +119,18 @@ class KeyValueMemoryStore:
|
||||
if self.v[gi].shape[-1] >= min_size:
|
||||
self.v[gi] = self.v[gi][:,:,:start]
|
||||
else:
|
||||
# self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1)
|
||||
# if self.count_usage:
|
||||
# self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1)
|
||||
# self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1)
|
||||
# if self.s is not None:
|
||||
# self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1)
|
||||
# if self.e is not None:
|
||||
# self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1)
|
||||
|
||||
# for gi in range(self.num_groups):
|
||||
# if self.v[gi].shape[-1] >= min_size:
|
||||
# self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1)
|
||||
|
||||
# key memory to be kept
|
||||
self.k = torch.cat([self.k[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
|
||||
self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1)
|
||||
if self.count_usage:
|
||||
self.use_count = torch.cat([self.use_count[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
|
||||
self.life_count = torch.cat([self.life_count[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
|
||||
self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1)
|
||||
self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1)
|
||||
if self.s is not None:
|
||||
self.s = torch.cat([self.s[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
|
||||
self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1)
|
||||
if self.e is not None:
|
||||
self.e = torch.cat([self.e[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
|
||||
self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1)
|
||||
|
||||
for gi in range(self.num_groups):
|
||||
if self.v[gi].shape[-1] >= min_size:
|
||||
self.v[gi] = torch.cat([self.v[gi][:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
|
||||
self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1)
|
||||
|
||||
def remove_obsolete_features(self, max_size: int):
|
||||
# normalize with life duration
|
||||
@@ -179,7 +163,7 @@ class KeyValueMemoryStore:
|
||||
usage = self.use_count / self.life_count
|
||||
return usage
|
||||
|
||||
def get_all_sliced(self, start: int, end: int, spans: List):
|
||||
def get_all_sliced(self, start: int, end: int):
|
||||
# return k, sk, ek, usage in order, sliced by start and end
|
||||
|
||||
if end == 0:
|
||||
@@ -189,15 +173,10 @@ class KeyValueMemoryStore:
|
||||
ek = self.e[:,:,start:] if self.e is not None else None
|
||||
usage = self.get_usage()[:,:,start:]
|
||||
else:
|
||||
# k = self.k[:,:,start:end]
|
||||
# sk = self.s[:,:,start:end] if self.s is not None else None
|
||||
# ek = self.e[:,:,start:end] if self.e is not None else None
|
||||
# usage = self.get_usage()[:,:,start:end]
|
||||
|
||||
k = torch.cat([self.k[:,:,span[0]:span[1]] for span in spans], dim=-1)
|
||||
sk = torch.cat([self.s[:,:,span[0]:span[1]] for span in spans], dim=-1) if self.s is not None else None
|
||||
ek = torch.cat([self.e[:,:,span[0]:span[1]] for span in spans], dim=-1) if self.e is not None else None
|
||||
usage = torch.cat([self.get_usage()[:,:,span[0]:span[1]] for span in spans], dim=-1)
|
||||
k = self.k[:,:,start:end]
|
||||
sk = self.s[:,:,start:end] if self.s is not None else None
|
||||
ek = self.e[:,:,start:end] if self.e is not None else None
|
||||
usage = self.get_usage()[:,:,start:end]
|
||||
|
||||
return k, sk, ek, usage
|
||||
|
||||
@@ -233,4 +212,3 @@ class KeyValueMemoryStore:
|
||||
@property
|
||||
def selection(self):
|
||||
return self.e
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ class MemoryManager:
|
||||
affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):],
|
||||
top_k=self.top_k, inplace=(gi==num_groups-1))
|
||||
affinity.append(affinity_one_group)
|
||||
|
||||
|
||||
all_memory_value = self.work_mem.value
|
||||
|
||||
# Shared affinity within each group
|
||||
@@ -182,13 +182,13 @@ class MemoryManager:
|
||||
if self.enable_long_term:
|
||||
# Compress working memory if needed
|
||||
if self.work_mem.size >= self.max_work_elements:
|
||||
print('remove memory')
|
||||
# Remove obsolete features if needed
|
||||
if self.long_mem.size >= (self.max_long_elements-self.num_prototypes):
|
||||
self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes)
|
||||
|
||||
self.compress_features()
|
||||
|
||||
|
||||
def create_hidden_state(self, n, sample_key):
|
||||
# n is the TOTAL number of objects
|
||||
h, w = sample_key.shape[-2:]
|
||||
@@ -212,13 +212,6 @@ class MemoryManager:
|
||||
HW = self.HW
|
||||
candidate_value = []
|
||||
total_work_mem_size = self.work_mem.size
|
||||
|
||||
# determine memory indices to be compressed and removed
|
||||
# uniform sampling from 1 to -2
|
||||
num_memory_elements = total_work_mem_size // HW
|
||||
spans = [[6*HW, 8*HW]]
|
||||
keeps = [[0, 6*HW], [8*HW, 15*HW]]
|
||||
|
||||
for gv in self.work_mem.value:
|
||||
# Some object groups might be added later in the video
|
||||
# So not all keys have values associated with all objects
|
||||
@@ -226,12 +219,7 @@ class MemoryManager:
|
||||
mem_size_in_this_group = gv.shape[-1]
|
||||
if mem_size_in_this_group == total_work_mem_size:
|
||||
# full LT
|
||||
# candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
|
||||
candidate_value.append(torch.cat([gv[:,:,span[0]:span[1]] for span in spans], dim=-1))
|
||||
# values = []
|
||||
# for span in spans:
|
||||
# values.append(gv[:,:,span[0]:span[1]])
|
||||
# candidate_value.append(torch.concat(values, dim=-1))
|
||||
candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
|
||||
else:
|
||||
# mem_size is smaller than total_work_mem_size, but at least HW
|
||||
assert HW <= mem_size_in_this_group < total_work_mem_size
|
||||
@@ -244,15 +232,15 @@ class MemoryManager:
|
||||
|
||||
# perform memory consolidation
|
||||
prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
|
||||
*self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW, spans), candidate_value)
|
||||
*self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value)
|
||||
|
||||
# remove consolidated working memory
|
||||
self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW, keeps=keeps)
|
||||
|
||||
# print('remove working memory')
|
||||
self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW)
|
||||
|
||||
# add to long-term memory
|
||||
self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None)
|
||||
print(f'long memory size: {self.long_mem.size}')
|
||||
print(f'work memory size: {self.work_mem.size}')
|
||||
|
||||
def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value):
|
||||
# keys: 1*C*N
|
||||
@@ -295,4 +283,4 @@ class MemoryManager:
|
||||
# read out the shrinkage term
|
||||
prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None
|
||||
|
||||
return prototype_key, prototype_value, prototype_shrinkage
|
||||
return prototype_key, prototype_value, prototype_shrinkage
|
||||
Reference in New Issue
Block a user