mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
fix memory management
This commit is contained in:
@@ -20,7 +20,7 @@ from torchvision.transforms import Resize
|
|||||||
|
|
||||||
|
|
||||||
class BaseTracker:
|
class BaseTracker:
|
||||||
def __init__(self, xmem_checkpoint, device, sam_checkpoint, model_type) -> None:
|
def __init__(self, xmem_checkpoint, device, sam_model, model_type=None) -> None:
|
||||||
"""
|
"""
|
||||||
device: model device
|
device: model device
|
||||||
xmem_checkpoint: checkpoint of XMem model
|
xmem_checkpoint: checkpoint of XMem model
|
||||||
@@ -43,9 +43,9 @@ class BaseTracker:
|
|||||||
self.mapper = MaskMapper()
|
self.mapper = MaskMapper()
|
||||||
self.initialised = False
|
self.initialised = False
|
||||||
|
|
||||||
# SAM-based refinement
|
# # SAM-based refinement
|
||||||
self.sam_model = BaseSegmenter(sam_checkpoint, model_type, device=device)
|
# self.sam_model = sam_model
|
||||||
self.resizer = Resize([256, 256])
|
# self.resizer = Resize([256, 256])
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def resize_mask(self, mask):
|
def resize_mask(self, mask):
|
||||||
@@ -78,8 +78,7 @@ class BaseTracker:
|
|||||||
# prepare inputs
|
# prepare inputs
|
||||||
frame_tensor = self.im_transform(frame).to(self.device)
|
frame_tensor = self.im_transform(frame).to(self.device)
|
||||||
# track one frame
|
# track one frame
|
||||||
probs, logits = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
|
probs, _ = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
|
||||||
|
|
||||||
# # refine
|
# # refine
|
||||||
# if first_frame_annotation is None:
|
# if first_frame_annotation is None:
|
||||||
# out_mask = self.sam_refinement(frame, logits[1], ti)
|
# out_mask = self.sam_refinement(frame, logits[1], ti)
|
||||||
@@ -92,7 +91,8 @@ class BaseTracker:
|
|||||||
painted_image = frame
|
painted_image = frame
|
||||||
for obj in range(1, num_objs+1):
|
for obj in range(1, num_objs+1):
|
||||||
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
|
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
|
||||||
return out_mask, probs, painted_image
|
|
||||||
|
return out_mask, out_mask, painted_image
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sam_refinement(self, frame, logits, ti):
|
def sam_refinement(self, frame, logits, ti):
|
||||||
@@ -134,22 +134,52 @@ if __name__ == '__main__':
|
|||||||
# ----------------------------------------------------------
|
# ----------------------------------------------------------
|
||||||
# initalise tracker
|
# initalise tracker
|
||||||
# ----------------------------------------------------------
|
# ----------------------------------------------------------
|
||||||
device = 'cuda:0'
|
device = 'cuda:4'
|
||||||
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
||||||
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||||
model_type = 'vit_h'
|
model_type = 'vit_h'
|
||||||
tracker = BaseTracker(XMEM_checkpoint, device, SAM_checkpoint, model_type)
|
|
||||||
|
|
||||||
|
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
|
||||||
|
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
|
||||||
|
|
||||||
|
# test for storage efficiency
|
||||||
|
frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
|
||||||
|
first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
|
||||||
|
|
||||||
# track anything given in the first frame annotation
|
|
||||||
for ti, frame in enumerate(frames):
|
for ti, frame in enumerate(frames):
|
||||||
|
print(ti)
|
||||||
|
if ti > 200:
|
||||||
|
break
|
||||||
if ti == 0:
|
if ti == 0:
|
||||||
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||||
else:
|
else:
|
||||||
mask, prob, painted_image = tracker.track(frame)
|
mask, prob, painted_image = tracker.track(frame)
|
||||||
# save
|
# save
|
||||||
painted_image = Image.fromarray(painted_image)
|
painted_image = Image.fromarray(painted_image)
|
||||||
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
|
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
|
||||||
|
|
||||||
|
tracker.clear_memory()
|
||||||
|
for ti, frame in enumerate(frames):
|
||||||
|
print(ti)
|
||||||
|
# if ti > 200:
|
||||||
|
# break
|
||||||
|
if ti == 0:
|
||||||
|
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||||
|
else:
|
||||||
|
mask, prob, painted_image = tracker.track(frame)
|
||||||
|
# save
|
||||||
|
painted_image = Image.fromarray(painted_image)
|
||||||
|
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
|
||||||
|
|
||||||
|
# # track anything given in the first frame annotation
|
||||||
|
# for ti, frame in enumerate(frames):
|
||||||
|
# if ti == 0:
|
||||||
|
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||||
|
# else:
|
||||||
|
# mask, prob, painted_image = tracker.track(frame)
|
||||||
|
# # save
|
||||||
|
# painted_image = Image.fromarray(painted_image)
|
||||||
|
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
|
||||||
|
|
||||||
# # ----------------------------------------------------------
|
# # ----------------------------------------------------------
|
||||||
# # another video
|
# # another video
|
||||||
@@ -198,3 +228,6 @@ if __name__ == '__main__':
|
|||||||
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
|
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
|
||||||
|
|
||||||
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
|
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
# config info for XMem
|
# config info for XMem
|
||||||
benchmark: False
|
benchmark: False
|
||||||
disable_long_term: False
|
disable_long_term: False
|
||||||
max_mid_term_frames: 15
|
max_mid_term_frames: 10
|
||||||
min_mid_term_frames: 5
|
min_mid_term_frames: 5
|
||||||
max_long_term_elements: 10000
|
max_long_term_elements: 1000
|
||||||
num_prototypes: 128
|
num_prototypes: 128
|
||||||
top_k: 30
|
top_k: 30
|
||||||
mem_every: 5
|
mem_every: 5
|
||||||
|
|||||||
@@ -29,9 +29,6 @@ class KeyValueMemoryStore:
|
|||||||
# shrinkage and selection are also single tensors
|
# shrinkage and selection are also single tensors
|
||||||
self.s = self.e = None
|
self.s = self.e = None
|
||||||
|
|
||||||
# cumulated probs, softmax(HW) -> THW
|
|
||||||
self.cumulated_probs = []
|
|
||||||
|
|
||||||
# usage
|
# usage
|
||||||
if self.count_usage:
|
if self.count_usage:
|
||||||
self.use_count = self.life_count = None
|
self.use_count = self.life_count = None
|
||||||
@@ -101,7 +98,7 @@ class KeyValueMemoryStore:
|
|||||||
self.use_count += usage.view_as(self.use_count)
|
self.use_count += usage.view_as(self.use_count)
|
||||||
self.life_count += 1
|
self.life_count += 1
|
||||||
|
|
||||||
def sieve_by_range(self, start: int, end: int, min_size: int, keeps: List):
|
def sieve_by_range(self, start: int, end: int, min_size: int):
|
||||||
# keep only the elements *outside* of this range (with some boundary conditions)
|
# keep only the elements *outside* of this range (with some boundary conditions)
|
||||||
# i.e., concat (a[:start], a[end:])
|
# i.e., concat (a[:start], a[end:])
|
||||||
# min_size is only used for values, we do not sieve values under this size
|
# min_size is only used for values, we do not sieve values under this size
|
||||||
@@ -122,31 +119,18 @@ class KeyValueMemoryStore:
|
|||||||
if self.v[gi].shape[-1] >= min_size:
|
if self.v[gi].shape[-1] >= min_size:
|
||||||
self.v[gi] = self.v[gi][:,:,:start]
|
self.v[gi] = self.v[gi][:,:,:start]
|
||||||
else:
|
else:
|
||||||
# self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1)
|
self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1)
|
||||||
# if self.count_usage:
|
|
||||||
# self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1)
|
|
||||||
# self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1)
|
|
||||||
# if self.s is not None:
|
|
||||||
# self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1)
|
|
||||||
# if self.e is not None:
|
|
||||||
# self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1)
|
|
||||||
|
|
||||||
# for gi in range(self.num_groups):
|
|
||||||
# if self.v[gi].shape[-1] >= min_size:
|
|
||||||
# self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1)
|
|
||||||
|
|
||||||
# key memory to be kept
|
|
||||||
self.k = torch.cat([self.k[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
|
|
||||||
if self.count_usage:
|
if self.count_usage:
|
||||||
self.use_count = torch.cat([self.use_count[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
|
self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1)
|
||||||
self.life_count = torch.cat([self.life_count[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
|
self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1)
|
||||||
if self.s is not None:
|
if self.s is not None:
|
||||||
self.s = torch.cat([self.s[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
|
self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1)
|
||||||
if self.e is not None:
|
if self.e is not None:
|
||||||
self.e = torch.cat([self.e[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
|
self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1)
|
||||||
|
|
||||||
for gi in range(self.num_groups):
|
for gi in range(self.num_groups):
|
||||||
if self.v[gi].shape[-1] >= min_size:
|
if self.v[gi].shape[-1] >= min_size:
|
||||||
self.v[gi] = torch.cat([self.v[gi][:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
|
self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1)
|
||||||
|
|
||||||
def remove_obsolete_features(self, max_size: int):
|
def remove_obsolete_features(self, max_size: int):
|
||||||
# normalize with life duration
|
# normalize with life duration
|
||||||
@@ -179,7 +163,7 @@ class KeyValueMemoryStore:
|
|||||||
usage = self.use_count / self.life_count
|
usage = self.use_count / self.life_count
|
||||||
return usage
|
return usage
|
||||||
|
|
||||||
def get_all_sliced(self, start: int, end: int, spans: List):
|
def get_all_sliced(self, start: int, end: int):
|
||||||
# return k, sk, ek, usage in order, sliced by start and end
|
# return k, sk, ek, usage in order, sliced by start and end
|
||||||
|
|
||||||
if end == 0:
|
if end == 0:
|
||||||
@@ -189,15 +173,10 @@ class KeyValueMemoryStore:
|
|||||||
ek = self.e[:,:,start:] if self.e is not None else None
|
ek = self.e[:,:,start:] if self.e is not None else None
|
||||||
usage = self.get_usage()[:,:,start:]
|
usage = self.get_usage()[:,:,start:]
|
||||||
else:
|
else:
|
||||||
# k = self.k[:,:,start:end]
|
k = self.k[:,:,start:end]
|
||||||
# sk = self.s[:,:,start:end] if self.s is not None else None
|
sk = self.s[:,:,start:end] if self.s is not None else None
|
||||||
# ek = self.e[:,:,start:end] if self.e is not None else None
|
ek = self.e[:,:,start:end] if self.e is not None else None
|
||||||
# usage = self.get_usage()[:,:,start:end]
|
usage = self.get_usage()[:,:,start:end]
|
||||||
|
|
||||||
k = torch.cat([self.k[:,:,span[0]:span[1]] for span in spans], dim=-1)
|
|
||||||
sk = torch.cat([self.s[:,:,span[0]:span[1]] for span in spans], dim=-1) if self.s is not None else None
|
|
||||||
ek = torch.cat([self.e[:,:,span[0]:span[1]] for span in spans], dim=-1) if self.e is not None else None
|
|
||||||
usage = torch.cat([self.get_usage()[:,:,span[0]:span[1]] for span in spans], dim=-1)
|
|
||||||
|
|
||||||
return k, sk, ek, usage
|
return k, sk, ek, usage
|
||||||
|
|
||||||
@@ -233,4 +212,3 @@ class KeyValueMemoryStore:
|
|||||||
@property
|
@property
|
||||||
def selection(self):
|
def selection(self):
|
||||||
return self.e
|
return self.e
|
||||||
|
|
||||||
|
|||||||
@@ -182,13 +182,13 @@ class MemoryManager:
|
|||||||
if self.enable_long_term:
|
if self.enable_long_term:
|
||||||
# Do memory compressed if needed
|
# Do memory compressed if needed
|
||||||
if self.work_mem.size >= self.max_work_elements:
|
if self.work_mem.size >= self.max_work_elements:
|
||||||
|
print('remove memory')
|
||||||
# Remove obsolete features if needed
|
# Remove obsolete features if needed
|
||||||
if self.long_mem.size >= (self.max_long_elements-self.num_prototypes):
|
if self.long_mem.size >= (self.max_long_elements-self.num_prototypes):
|
||||||
self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes)
|
self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes)
|
||||||
|
|
||||||
self.compress_features()
|
self.compress_features()
|
||||||
|
|
||||||
|
|
||||||
def create_hidden_state(self, n, sample_key):
|
def create_hidden_state(self, n, sample_key):
|
||||||
# n is the TOTAL number of objects
|
# n is the TOTAL number of objects
|
||||||
h, w = sample_key.shape[-2:]
|
h, w = sample_key.shape[-2:]
|
||||||
@@ -212,13 +212,6 @@ class MemoryManager:
|
|||||||
HW = self.HW
|
HW = self.HW
|
||||||
candidate_value = []
|
candidate_value = []
|
||||||
total_work_mem_size = self.work_mem.size
|
total_work_mem_size = self.work_mem.size
|
||||||
|
|
||||||
# determine memory indices to be compressed and removed
|
|
||||||
# uniform sampling from 1 to -2
|
|
||||||
num_memory_elements = total_work_mem_size // HW
|
|
||||||
spans = [[6*HW, 8*HW]]
|
|
||||||
keeps = [[0, 6*HW], [8*HW, 15*HW]]
|
|
||||||
|
|
||||||
for gv in self.work_mem.value:
|
for gv in self.work_mem.value:
|
||||||
# Some object groups might be added later in the video
|
# Some object groups might be added later in the video
|
||||||
# So not all keys have values associated with all objects
|
# So not all keys have values associated with all objects
|
||||||
@@ -226,12 +219,7 @@ class MemoryManager:
|
|||||||
mem_size_in_this_group = gv.shape[-1]
|
mem_size_in_this_group = gv.shape[-1]
|
||||||
if mem_size_in_this_group == total_work_mem_size:
|
if mem_size_in_this_group == total_work_mem_size:
|
||||||
# full LT
|
# full LT
|
||||||
# candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
|
candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
|
||||||
candidate_value.append(torch.cat([gv[:,:,span[0]:span[1]] for span in spans], dim=-1))
|
|
||||||
# values = []
|
|
||||||
# for span in spans:
|
|
||||||
# values.append(gv[:,:,span[0]:span[1]])
|
|
||||||
# candidate_value.append(torch.concat(values, dim=-1))
|
|
||||||
else:
|
else:
|
||||||
# mem_size is smaller than total_work_mem_size, but at least HW
|
# mem_size is smaller than total_work_mem_size, but at least HW
|
||||||
assert HW <= mem_size_in_this_group < total_work_mem_size
|
assert HW <= mem_size_in_this_group < total_work_mem_size
|
||||||
@@ -244,15 +232,15 @@ class MemoryManager:
|
|||||||
|
|
||||||
# perform memory consolidation
|
# perform memory consolidation
|
||||||
prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
|
prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
|
||||||
*self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW, spans), candidate_value)
|
*self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value)
|
||||||
|
|
||||||
# remove consolidated working memory
|
# remove consolidated working memory
|
||||||
self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW, keeps=keeps)
|
self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW)
|
||||||
|
|
||||||
# print('remove working memory')
|
|
||||||
|
|
||||||
# add to long-term memory
|
# add to long-term memory
|
||||||
self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None)
|
self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None)
|
||||||
|
print(f'long memory size: {self.long_mem.size}')
|
||||||
|
print(f'work memory size: {self.work_mem.size}')
|
||||||
|
|
||||||
def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value):
|
def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value):
|
||||||
# keys: 1*C*N
|
# keys: 1*C*N
|
||||||
|
|||||||
Reference in New Issue
Block a user