From 0db89e924879d30d136bc0de3694bb3cc8cbb927 Mon Sep 17 00:00:00 2001 From: gaomingqi Date: Mon, 17 Apr 2023 19:02:32 +0800 Subject: [PATCH] fix memory management --- tracker/base_tracker.py | 55 ++++++++++++++++++++++------ tracker/config/config.yaml | 4 +- tracker/inference/kv_memory_store.py | 48 +++++++----------------- tracker/inference/memory_manager.py | 28 ++++---------- 4 files changed, 67 insertions(+), 68 deletions(-) diff --git a/tracker/base_tracker.py b/tracker/base_tracker.py index 4b139dd..46cab98 100644 --- a/tracker/base_tracker.py +++ b/tracker/base_tracker.py @@ -20,7 +20,7 @@ from torchvision.transforms import Resize class BaseTracker: - def __init__(self, xmem_checkpoint, device, sam_checkpoint, model_type) -> None: + def __init__(self, xmem_checkpoint, device, sam_model, model_type=None) -> None: """ device: model device xmem_checkpoint: checkpoint of XMem model @@ -43,9 +43,9 @@ class BaseTracker: self.mapper = MaskMapper() self.initialised = False - # SAM-based refinement - self.sam_model = BaseSegmenter(sam_checkpoint, model_type, device=device) - self.resizer = Resize([256, 256]) + # # SAM-based refinement + # self.sam_model = sam_model + # self.resizer = Resize([256, 256]) @torch.no_grad() def resize_mask(self, mask): @@ -78,8 +78,7 @@ class BaseTracker: # prepare inputs frame_tensor = self.im_transform(frame).to(self.device) # track one frame - probs, logits = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W - + probs, _ = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W # # refine # if first_frame_annotation is None: # out_mask = self.sam_refinement(frame, logits[1], ti) @@ -92,7 +91,8 @@ class BaseTracker: painted_image = frame for obj in range(1, num_objs+1): painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1) - return out_mask, probs, painted_image + + return out_mask, out_mask, painted_image @torch.no_grad() def sam_refinement(self, frame, logits, ti): @@ -134,22 +134,52 @@ if __name__ == '__main__': # ---------------------------------------------------------- # initalise tracker # ---------------------------------------------------------- - device = 'cuda:0' + device = 'cuda:4' XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth' SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' model_type = 'vit_h' - tracker = BaseTracker(XMEM_checkpoint, device, SAM_checkpoint, model_type) + # sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device) + tracker = BaseTracker(XMEM_checkpoint, device, None, device) + + # test for storage efficiency + frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy') + first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png')) - # track anything given in the first frame annotation for ti, frame in enumerate(frames): + print(ti) + if ti > 200: + break if ti == 0: mask, prob, painted_image = tracker.track(frame, first_frame_annotation) else: mask, prob, painted_image = tracker.track(frame) # save painted_image = Image.fromarray(painted_image) - painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png') + painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png') + + tracker.clear_memory() + for ti, frame in enumerate(frames): + print(ti) + # if ti > 200: + # break + if ti == 0: + mask, prob, painted_image = tracker.track(frame, first_frame_annotation) + else: + mask, prob, painted_image = tracker.track(frame) + # save + painted_image = Image.fromarray(painted_image) + painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png') + + # # track anything given in the first frame annotation + # for ti, frame in enumerate(frames): + # if ti == 0: + # mask, prob, painted_image = tracker.track(frame, first_frame_annotation) + # else: + # mask, prob, painted_image = tracker.track(frame) + # # save + # painted_image = Image.fromarray(painted_image) + # painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png') # # ---------------------------------------------------------- # # another video @@ -198,3 +228,6 @@ if __name__ == '__main__': # prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8')) # # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png') + + + diff --git a/tracker/config/config.yaml b/tracker/config/config.yaml index 280f7f9..3c99064 100644 --- a/tracker/config/config.yaml +++ b/tracker/config/config.yaml @@ -1,9 +1,9 @@ # config info for XMem benchmark: False disable_long_term: False -max_mid_term_frames: 15 +max_mid_term_frames: 10 min_mid_term_frames: 5 -max_long_term_elements: 10000 +max_long_term_elements: 1000 num_prototypes: 128 top_k: 30 mem_every: 5 diff --git a/tracker/inference/kv_memory_store.py b/tracker/inference/kv_memory_store.py index 05c2efc..8e11130 100644 --- a/tracker/inference/kv_memory_store.py +++ b/tracker/inference/kv_memory_store.py @@ -29,9 +29,6 @@ class KeyValueMemoryStore: # shrinkage and selection are also single tensors self.s = self.e = None - # cumulated probs, softmax(HW) -> THW - self.cumulated_probs = [] - # usage if self.count_usage: self.use_count = self.life_count = None @@ -101,7 +98,7 @@ class KeyValueMemoryStore: self.use_count += usage.view_as(self.use_count) self.life_count += 1 - def sieve_by_range(self, start: int, end: int, min_size: int, keeps: List): + def sieve_by_range(self, start: int, end: int, min_size: int): # keep only the elements *outside* of this range (with some boundary conditions) # i.e., concat (a[:start], a[end:]) # min_size is only used for values, we do not sieve values under this size @@ -122,31 +119,18 @@ class KeyValueMemoryStore: if self.v[gi].shape[-1] >= min_size: self.v[gi] = self.v[gi][:,:,:start] else: - # self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1) - # if self.count_usage: - # self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1) - # self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1) - # if self.s is not None: - # self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1) - # if self.e is not None: - # self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1) - - # for gi in range(self.num_groups): - # if self.v[gi].shape[-1] >= min_size: - # self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1) - - # key memory to be kept - self.k = torch.cat([self.k[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1) + self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1) if self.count_usage: - self.use_count = torch.cat([self.use_count[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1) - self.life_count = torch.cat([self.life_count[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1) + self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1) + self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1) if self.s is not None: - self.s = torch.cat([self.s[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1) + self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1) if self.e is not None: - self.e = torch.cat([self.e[:,:,keep[0]:keep[1]] for keep in keeps], dim=-1) + self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1) + for gi in range(self.num_groups): if self.v[gi].shape[-1] >= min_size: - self.v[gi] = torch.cat([self.v[gi][:,:,keep[0]:keep[1]] for keep in keeps], dim=-1) + self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1) def remove_obsolete_features(self, max_size: int): # normalize with life duration @@ -179,7 +163,7 @@ class KeyValueMemoryStore: usage = self.use_count / self.life_count return usage - def get_all_sliced(self, start: int, end: int, spans: List): + def get_all_sliced(self, start: int, end: int): # return k, sk, ek, usage in order, sliced by start and end if end == 0: @@ -189,15 +173,10 @@ class KeyValueMemoryStore: ek = self.e[:,:,start:] if self.e is not None else None usage = self.get_usage()[:,:,start:] else: - # k = self.k[:,:,start:end] - # sk = self.s[:,:,start:end] if self.s is not None else None - # ek = self.e[:,:,start:end] if self.e is not None else None - # usage = self.get_usage()[:,:,start:end] - - k = torch.cat([self.k[:,:,span[0]:span[1]] for span in spans], dim=-1) - sk = torch.cat([self.s[:,:,span[0]:span[1]] for span in spans], dim=-1) if self.s is not None else None - ek = torch.cat([self.e[:,:,span[0]:span[1]] for span in spans], dim=-1) if self.e is not None else None - usage = torch.cat([self.get_usage()[:,:,span[0]:span[1]] for span in spans], dim=-1) + k = self.k[:,:,start:end] + sk = self.s[:,:,start:end] if self.s is not None else None + ek = self.e[:,:,start:end] if self.e is not None else None + usage = self.get_usage()[:,:,start:end] return k, sk, ek, usage @@ -233,4 +212,3 @@ class KeyValueMemoryStore: @property def selection(self): return self.e - diff --git a/tracker/inference/memory_manager.py b/tracker/inference/memory_manager.py index 9c4528e..adf6c85 100644 --- a/tracker/inference/memory_manager.py +++ b/tracker/inference/memory_manager.py @@ -138,7 +138,7 @@ class MemoryManager: affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):], top_k=self.top_k, inplace=(gi==num_groups-1)) affinity.append(affinity_one_group) - + all_memory_value = self.work_mem.value # Shared affinity within each group @@ -182,13 +182,13 @@ class MemoryManager: if self.enable_long_term: # Do memory compressed if needed if self.work_mem.size >= self.max_work_elements: + print('remove memory') # Remove obsolete features if needed if self.long_mem.size >= (self.max_long_elements-self.num_prototypes): self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes) self.compress_features() - def create_hidden_state(self, n, sample_key): # n is the TOTAL number of objects h, w = sample_key.shape[-2:] @@ -212,13 +212,6 @@ class MemoryManager: HW = self.HW candidate_value = [] total_work_mem_size = self.work_mem.size - - # determine memory indices to be compressed and removed - # uniform sampling from 1 to -2 - num_memory_elements = total_work_mem_size // HW - spans = [[6*HW, 8*HW]] - keeps = [[0, 6*HW], [8*HW, 15*HW]] - for gv in self.work_mem.value: # Some object groups might be added later in the video # So not all keys have values associated with all objects @@ -226,12 +219,7 @@ class MemoryManager: mem_size_in_this_group = gv.shape[-1] if mem_size_in_this_group == total_work_mem_size: # full LT - # candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW]) - candidate_value.append(torch.cat([gv[:,:,span[0]:span[1]] for span in spans], dim=-1)) - # values = [] - # for span in spans: - # values.append(gv[:,:,span[0]:span[1]]) - # candidate_value.append(torch.concat(values, dim=-1)) + candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW]) else: # mem_size is smaller than total_work_mem_size, but at least HW assert HW <= mem_size_in_this_group < total_work_mem_size @@ -244,15 +232,15 @@ class MemoryManager: # perform memory consolidation prototype_key, prototype_value, prototype_shrinkage = self.consolidation( - *self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW, spans), candidate_value) + *self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value) # remove consolidated working memory - self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW, keeps=keeps) - - # print('remove working memory') + self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW) # add to long-term memory self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None) + print(f'long memory size: {self.long_mem.size}') + print(f'work memory size: {self.work_mem.size}') def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value): # keys: 1*C*N @@ -295,4 +283,4 @@ class MemoryManager: # readout the shrinkage term prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None - return prototype_key, prototype_value, prototype_shrinkage + return prototype_key, prototype_value, prototype_shrinkage \ No newline at end of file