fix memory management

This commit is contained in:
gaomingqi
2023-04-17 19:02:32 +08:00
parent 8b7333664f
commit 0db89e9248
4 changed files with 67 additions and 68 deletions

View File

@@ -20,7 +20,7 @@ from torchvision.transforms import Resize
class BaseTracker:
def __init__(self, xmem_checkpoint, device, sam_checkpoint, model_type) -> None:
def __init__(self, xmem_checkpoint, device, sam_model, model_type=None) -> None:
"""
device: model device
xmem_checkpoint: checkpoint of XMem model
@@ -43,9 +43,9 @@ class BaseTracker:
self.mapper = MaskMapper()
self.initialised = False
# SAM-based refinement
self.sam_model = BaseSegmenter(sam_checkpoint, model_type, device=device)
self.resizer = Resize([256, 256])
# # SAM-based refinement
# self.sam_model = sam_model
# self.resizer = Resize([256, 256])
@torch.no_grad()
def resize_mask(self, mask):
@@ -78,8 +78,7 @@ class BaseTracker:
# prepare inputs
frame_tensor = self.im_transform(frame).to(self.device)
# track one frame
probs, logits = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
probs, _ = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
# # refine
# if first_frame_annotation is None:
# out_mask = self.sam_refinement(frame, logits[1], ti)
@@ -92,7 +91,8 @@ class BaseTracker:
painted_image = frame
for obj in range(1, num_objs+1):
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
return out_mask, probs, painted_image
return out_mask, out_mask, painted_image
@torch.no_grad()
def sam_refinement(self, frame, logits, ti):
@@ -134,22 +134,52 @@ if __name__ == '__main__':
# ----------------------------------------------------------
# initalise tracker
# ----------------------------------------------------------
device = 'cuda:0'
device = 'cuda:4'
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
model_type = 'vit_h'
tracker = BaseTracker(XMEM_checkpoint, device, SAM_checkpoint, model_type)
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
# test for storage efficiency
frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
# track anything given in the first frame annotation
for ti, frame in enumerate(frames):
print(ti)
if ti > 200:
break
if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
else:
mask, prob, painted_image = tracker.track(frame)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
tracker.clear_memory()
for ti, frame in enumerate(frames):
print(ti)
# if ti > 200:
# break
if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
else:
mask, prob, painted_image = tracker.track(frame)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
# # track anything given in the first frame annotation
# for ti, frame in enumerate(frames):
# if ti == 0:
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
# else:
# mask, prob, painted_image = tracker.track(frame)
# # save
# painted_image = Image.fromarray(painted_image)
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
# # ----------------------------------------------------------
# # another video
@@ -198,3 +228,6 @@ if __name__ == '__main__':
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')

View File

@@ -1,9 +1,9 @@
# config info for XMem
benchmark: False
disable_long_term: False
max_mid_term_frames: 15
max_mid_term_frames: 10
min_mid_term_frames: 5
max_long_term_elements: 10000
max_long_term_elements: 1000
num_prototypes: 128
top_k: 30
mem_every: 5

View File

@@ -29,9 +29,6 @@ class KeyValueMemoryStore:
# shrinkage and selection are also single tensors
self.s = self.e = None
# cumulated probs, softmax(HW) -> THW
self.cumulated_probs = []
# usage
if self.count_usage:
self.use_count = self.life_count = None
@@ -101,7 +98,7 @@ class KeyValueMemoryStore:
self.use_count += usage.view_as(self.use_count)
self.life_count += 1
def sieve_by_range(self, start: int, end: int, min_size: int, keeps: List):
def sieve_by_range(self, start: int, end: int, min_size: int):
# keep only the elements *outside* of this range (with some boundary conditions)
# i.e., concat (a[:start], a[end:])
# min_size is only used for values, we do not sieve values under this size
@@ -122,31 +119,18 @@ class KeyValueMemoryStore:
if self.v[gi].shape[-1] >= min_size:
self.v[gi] = self.v[gi][:,:,:start]
else:
# self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1)
# if self.count_usage:
# self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1)
# self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1)
# if self.s is not None:
# self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1)
# if self.e is not None:
# self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1)
# for gi in range(self.num_groups):
# if self.v[gi].shape[-1] >= min_size:
# self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1)
# key memory to be kept
self.k = torch.cat([self.k[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1)
if self.count_usage:
self.use_count = torch.cat([self.use_count[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
self.life_count = torch.cat([self.life_count[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1)
self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1)
if self.s is not None:
self.s = torch.cat([self.s[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1)
if self.e is not None:
self.e = torch.cat([self.e[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1)
for gi in range(self.num_groups):
if self.v[gi].shape[-1] >= min_size:
self.v[gi] = torch.cat([self.v[gi][:,:,keep[0]:keep[1]] for keep in keeps], dim=-1)
self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1)
def remove_obsolete_features(self, max_size: int):
# normalize with life duration
@@ -179,7 +163,7 @@ class KeyValueMemoryStore:
usage = self.use_count / self.life_count
return usage
def get_all_sliced(self, start: int, end: int, spans: List):
def get_all_sliced(self, start: int, end: int):
# return k, sk, ek, usage in order, sliced by start and end
if end == 0:
@@ -189,15 +173,10 @@ class KeyValueMemoryStore:
ek = self.e[:,:,start:] if self.e is not None else None
usage = self.get_usage()[:,:,start:]
else:
# k = self.k[:,:,start:end]
# sk = self.s[:,:,start:end] if self.s is not None else None
# ek = self.e[:,:,start:end] if self.e is not None else None
# usage = self.get_usage()[:,:,start:end]
k = torch.cat([self.k[:,:,span[0]:span[1]] for span in spans], dim=-1)
sk = torch.cat([self.s[:,:,span[0]:span[1]] for span in spans], dim=-1) if self.s is not None else None
ek = torch.cat([self.e[:,:,span[0]:span[1]] for span in spans], dim=-1) if self.e is not None else None
usage = torch.cat([self.get_usage()[:,:,span[0]:span[1]] for span in spans], dim=-1)
k = self.k[:,:,start:end]
sk = self.s[:,:,start:end] if self.s is not None else None
ek = self.e[:,:,start:end] if self.e is not None else None
usage = self.get_usage()[:,:,start:end]
return k, sk, ek, usage
@@ -233,4 +212,3 @@ class KeyValueMemoryStore:
@property
def selection(self):
return self.e

View File

@@ -138,7 +138,7 @@ class MemoryManager:
affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):],
top_k=self.top_k, inplace=(gi==num_groups-1))
affinity.append(affinity_one_group)
all_memory_value = self.work_mem.value
# Shared affinity within each group
@@ -182,13 +182,13 @@ class MemoryManager:
if self.enable_long_term:
# Do memory compressed if needed
if self.work_mem.size >= self.max_work_elements:
print('remove memory')
# Remove obsolete features if needed
if self.long_mem.size >= (self.max_long_elements-self.num_prototypes):
self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes)
self.compress_features()
def create_hidden_state(self, n, sample_key):
# n is the TOTAL number of objects
h, w = sample_key.shape[-2:]
@@ -212,13 +212,6 @@ class MemoryManager:
HW = self.HW
candidate_value = []
total_work_mem_size = self.work_mem.size
# determine memory indices to be compressed and removed
# uniform sampling from 1 to -2
num_memory_elements = total_work_mem_size // HW
spans = [[6*HW, 8*HW]]
keeps = [[0, 6*HW], [8*HW, 15*HW]]
for gv in self.work_mem.value:
# Some object groups might be added later in the video
# So not all keys have values associated with all objects
@@ -226,12 +219,7 @@ class MemoryManager:
mem_size_in_this_group = gv.shape[-1]
if mem_size_in_this_group == total_work_mem_size:
# full LT
# candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
candidate_value.append(torch.cat([gv[:,:,span[0]:span[1]] for span in spans], dim=-1))
# values = []
# for span in spans:
# values.append(gv[:,:,span[0]:span[1]])
# candidate_value.append(torch.concat(values, dim=-1))
candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
else:
# mem_size is smaller than total_work_mem_size, but at least HW
assert HW <= mem_size_in_this_group < total_work_mem_size
@@ -244,15 +232,15 @@ class MemoryManager:
# perform memory consolidation
prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
*self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW, spans), candidate_value)
*self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value)
# remove consolidated working memory
self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW, keeps=keeps)
# print('remove working memory')
self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW)
# add to long-term memory
self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None)
print(f'long memory size: {self.long_mem.size}')
print(f'work memory size: {self.work_mem.size}')
def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value):
# keys: 1*C*N
@@ -295,4 +283,4 @@ class MemoryManager:
# readout the shrinkage term
prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None
return prototype_key, prototype_value, prototype_shrinkage
return prototype_key, prototype_value, prototype_shrinkage