add multi-object support to base_tracker

This commit is contained in:
gaomingqi
2023-04-17 12:33:37 +08:00
parent 7237edfaea
commit 32cc1c33c3
2 changed files with 38 additions and 33 deletions

View File

@@ -9,7 +9,7 @@ https://user-images.githubusercontent.com/28050374/232070852-af2e85e5-a834-4bbc-
## Get Started ## Get Started
This is Get Started.
## Acknowledgement ## Acknowledgement
The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything) and [XMem](https://github.com/hkchengrex/XMem). Thanks to the authors for their efforts. The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything) and [XMem](https://github.com/hkchengrex/XMem). Thanks to the authors for their efforts.

View File

@@ -87,7 +87,11 @@ class BaseTracker:
# convert to mask # convert to mask
out_mask = torch.argmax(probs, dim=0) out_mask = torch.argmax(probs, dim=0)
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
painted_image = mask_painter(frame, out_mask)
num_objs = out_mask.max()
painted_image = frame
for obj in range(1, num_objs+1):
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
return out_mask, probs, painted_image return out_mask, probs, painted_image
@torch.no_grad() @torch.no_grad()
@@ -114,11 +118,11 @@ class BaseTracker:
if __name__ == '__main__': if __name__ == '__main__':
# video frames # video frames (multiple objects)
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg')) video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
video_path_list.sort() video_path_list.sort()
# first frame # first frame
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png' first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
# load frames # load frames
frames = [] frames = []
for video_path in video_path_list: for video_path in video_path_list:
@@ -130,21 +134,22 @@ if __name__ == '__main__':
# ---------------------------------------------------------- # ----------------------------------------------------------
# initialise tracker # initialise tracker
# ---------------------------------------------------------- # ----------------------------------------------------------
device = 'cuda:1' device = 'cuda:0'
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth' XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
model_type = 'vit_h' model_type = 'vit_h'
tracker = BaseTracker(XMEM_checkpoint, device, SAM_checkpoint, model_type) tracker = BaseTracker(XMEM_checkpoint, device, SAM_checkpoint, model_type)
# # track anything given in the first frame annotation
# for ti, frame in enumerate(frames): # track anything given in the first frame annotation
# if ti == 0: for ti, frame in enumerate(frames):
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation) if ti == 0:
# else: mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
# mask, prob, painted_image = tracker.track(frame) else:
# # save mask, prob, painted_image = tracker.track(frame)
# painted_image = Image.fromarray(painted_image) # save
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/dance-twirl/{ti:05d}.png') painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
# # ---------------------------------------------------------- # # ----------------------------------------------------------
# # another video # # another video
@@ -161,7 +166,7 @@ if __name__ == '__main__':
# frames = np.stack(frames, 0) # N, H, W, C # frames = np.stack(frames, 0) # N, H, W, C
# # load first frame annotation # # load first frame annotation
# first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W # first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W
# print('first video done. clear.') # print('first video done. clear.')
# tracker.clear_memory() # tracker.clear_memory()
@@ -175,21 +180,21 @@ if __name__ == '__main__':
# painted_image = Image.fromarray(painted_image) # painted_image = Image.fromarray(painted_image)
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png') # painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png')
# failure case test # # failure case test
failure_path = '/ssd1/gaomingqi/failure' # failure_path = '/ssd1/gaomingqi/failure'
frames = np.load(os.path.join(failure_path, 'video_frames.npy')) # frames = np.load(os.path.join(failure_path, 'video_frames.npy'))
# first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB')) # # first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB'))
first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P')) # first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P'))
first_mask = np.clip(first_mask, 0, 1) # first_mask = np.clip(first_mask, 0, 1)
for ti, frame in enumerate(frames): # for ti, frame in enumerate(frames):
if ti == 0: # if ti == 0:
mask, probs, painted_image = tracker.track(frame, first_mask) # mask, probs, painted_image = tracker.track(frame, first_mask)
else: # else:
mask, probs, painted_image = tracker.track(frame) # mask, probs, painted_image = tracker.track(frame)
# save # # save
painted_image = Image.fromarray(painted_image) # painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png') # painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png')
prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8')) # prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
# prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png') # # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')