From 5ca44baea36b7c66043342afc9ffb966e6d24417 Mon Sep 17 00:00:00 2001 From: Yioutpi Date: Fri, 14 Apr 2023 04:46:20 +0800 Subject: [PATCH 1/4] update return note --- app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.py b/app.py index 9db7f3b..3b695e7 100644 --- a/app.py +++ b/app.py @@ -31,7 +31,7 @@ def get_frames_from_video(video_input, play_state): video_path:str timestamp:float64 Return - [[0:nearest_frame-1], [nearest_frame+1], nearest_frame] + [[0:nearest_frame], [nearest_frame:], nearest_frame] """ video_path = video_input timestamp = play_state[1] - play_state[0] From c11c310361b3c57ed608d27dad5f302c46f35050 Mon Sep 17 00:00:00 2001 From: gaomingqi Date: Fri, 14 Apr 2023 05:37:08 +0800 Subject: [PATCH 2/4] update base_tracker --- tracker/base_tracker.py | 58 ++++++++++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/tracker/base_tracker.py b/tracker/base_tracker.py index 1165bc3..e4f59cf 100644 --- a/tracker/base_tracker.py +++ b/tracker/base_tracker.py @@ -9,9 +9,16 @@ import torch import yaml from model.network import XMem from inference.inference_core import InferenceCore +from inference.data.mask_mapper import MaskMapper + # for data transormation from torchvision import transforms from dataset.range_transform import im_normalization +import torch.nn.functional as F + +import sys +sys.path.insert(0, sys.path[0]+"/../") +from tools.painter import mask_painter class BaseTracker: @@ -32,6 +39,14 @@ class BaseTracker: transforms.ToTensor(), im_normalization, ]) + self.device = device + + def resize_mask(self, mask): + # mask transform is applied AFTER mapper, so we need to post-process it in eval.py + h, w = mask.shape[-2:] + min_hw = min(h, w) + return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)), + mode='nearest') def track(self, frames, first_frame_annotation): """ @@ -40,14 +55,36 @@ class BaseTracker: first_frame_annotation: numpy array: H, W Output: - masks: numpy arrays: T, H, W + masks: numpy arrays: H, W """ - # data transformation - for frame in frames: - frame = self.im_transform(frame) - - # tracking + shape = np.array(frames).shape[1:3] # H, W + frame_list = [self.im_transform(frame).to(self.device) for frame in frames] + frame_tensors = torch.stack(frame_list, dim=0) + # data transformation + mapper = MaskMapper() + + vid_length = len(frame_tensors) + + for ti, frame_tensor in enumerate(frame_tensors): + if ti == 0: + mask, labels = mapper.convert_mask(first_frame_annotation) + mask = torch.Tensor(mask).to(self.device) + self.tracker.set_all_labels(list(mapper.remappings.values())) + else: + mask = None + labels = None + + # track one frame + prob = self.tracker.step(frame_tensor, mask, labels, end=(ti==vid_length-1)) + + out_mask = torch.argmax(prob, dim=0) + out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) + + painted_image = mask_painter(frames[ti], out_mask) + # save + painted_image = Image.fromarray(painted_image) + painted_image.save(f'/ssd1/gaomingqi/results/TrackA/{ti}.png') if __name__ == '__main__': @@ -56,20 +93,17 @@ if __name__ == '__main__': video_path_list.sort() # first frame first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png' - # load frames -<<<<<<< HEAD -======= - frames = ["test_confict"] ->>>>>>> a5606340a199569856ffa1585aeeff5a40cc34ba + frames = [] for video_path in video_path_list: frames.append(np.array(Image.open(video_path).convert('RGB'))) frames = np.stack(frames, 0) # N, H, W, C - # load first frame annotation first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C + # ---------------------------------------------------------- # initalise tracker + # ---------------------------------------------------------- device = 'cuda:0' XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth' tracker = BaseTracker('cuda:0', XMEM_checkpoint) From c8ca9078ecce7f089a25029ba05a4a35f3368576 Mon Sep 17 00:00:00 2001 From: gaomingqi Date: Fri, 14 Apr 2023 10:17:41 +0800 Subject: [PATCH 3/4] segment videos via base_tracker --- tracker/base_tracker.py | 46 ++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tracker/base_tracker.py b/tracker/base_tracker.py index 99d328f..a544408 100644 --- a/tracker/base_tracker.py +++ b/tracker/base_tracker.py @@ -39,8 +39,10 @@ class BaseTracker: transforms.ToTensor(), im_normalization, ]) + self.mapper = MaskMapper() self.device = device + @torch.no_grad() def resize_mask(self, mask): # mask transform is applied AFTER mapper, so we need to post-process it in eval.py h, w = mask.shape[-2:] @@ -48,6 +50,7 @@ class BaseTracker: return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)), mode='nearest') + @torch.no_grad() def track(self, frames, first_frame_annotation): """ Input: @@ -57,34 +60,28 @@ class BaseTracker: Output: masks: numpy arrays: H, W """ - shape = np.array(frames).shape[1:3] # H, W - frame_list = [self.im_transform(frame).to(self.device) for frame in frames] - frame_tensors = torch.stack(frame_list, dim=0) - - # data transformation - mapper = MaskMapper() + vid_length = len(frames) + masks = [] - vid_length = len(frame_tensors) - - for ti, frame_tensor in enumerate(frame_tensors): + for ti, frame in enumerate(frames): + # convert to tensor + frame_tensor = self.im_transform(frame).to(self.device) if ti == 0: - mask, labels = mapper.convert_mask(first_frame_annotation) + mask, labels = self.mapper.convert_mask(first_frame_annotation) mask = torch.Tensor(mask).to(self.device) - self.tracker.set_all_labels(list(mapper.remappings.values())) + self.tracker.set_all_labels(list(self.mapper.remappings.values())) else: mask = None labels = None # track one frame prob = self.tracker.step(frame_tensor, mask, labels, end=(ti==vid_length-1)) - + # convert to mask out_mask = torch.argmax(prob, dim=0) out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) + masks.append(out_mask) - painted_image = mask_painter(frames[ti], out_mask) - # save - painted_image = Image.fromarray(painted_image) - painted_image.save(f'/ssd1/gaomingqi/results/TrackA/{ti}.png') + return np.stack(masks, 0) if __name__ == '__main__': @@ -94,11 +91,7 @@ if __name__ == '__main__': # first frame first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png' # load frames -<<<<<<< HEAD frames = [] -======= - frames = ["test_confict"] ->>>>>>> 5ca44baea36b7c66043342afc9ffb966e6d24417 for video_path in video_path_list: frames.append(np.array(Image.open(video_path).convert('RGB'))) frames = np.stack(frames, 0) # N, H, W, C @@ -108,9 +101,16 @@ if __name__ == '__main__': # ---------------------------------------------------------- # initalise tracker # ---------------------------------------------------------- - device = 'cuda:0' + device = 'cuda:1' XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth' - tracker = BaseTracker('cuda:0', XMEM_checkpoint) + tracker = BaseTracker(device, XMEM_checkpoint) # track anything given in the first frame annotation - tracker.track(frames, first_frame_annotation) + masks = tracker.track(frames, first_frame_annotation) + + # save + for ti, (frame, mask) in enumerate(zip(frames, masks)): + painted_image = mask_painter(frame, mask) + # save + painted_image = Image.fromarray(painted_image) + painted_image.save(f'/ssd1/gaomingqi/results/TrackA/{ti:05d}.png') From bb602906208bf62ec9cc476035040189c39704f4 Mon Sep 17 00:00:00 2001 From: ShangGaoG <12132332@mail.sustech.edu.cn> Date: Fri, 14 Apr 2023 10:19:51 +0800 Subject: [PATCH 4/4] mingqi --- tracker/base_tracker.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tracker/base_tracker.py b/tracker/base_tracker.py index 99d328f..d008893 100644 --- a/tracker/base_tracker.py +++ b/tracker/base_tracker.py @@ -94,11 +94,7 @@ if __name__ == '__main__': # first frame first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png' # load frames -<<<<<<< HEAD - frames = [] -======= frames = ["test_confict"] ->>>>>>> 5ca44baea36b7c66043342afc9ffb966e6d24417 for video_path in video_path_list: frames.append(np.array(Image.open(video_path).convert('RGB'))) frames = np.stack(frames, 0) # N, H, W, C