add multi-object support to base_tracker

This commit is contained in:
gaomingqi
2023-04-17 12:33:37 +08:00
parent 7237edfaea
commit 32cc1c33c3
2 changed files with 38 additions and 33 deletions

View File

@@ -9,7 +9,7 @@ https://user-images.githubusercontent.com/28050374/232070852-af2e85e5-a834-4bbc-
## Get Started
This is Get Started.
## Acknowledgement
The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything) and [XMem](https://github.com/hkchengrex/XMem). Thanks for the authors for their efforts.

View File

@@ -87,7 +87,11 @@ class BaseTracker:
# convert to mask
out_mask = torch.argmax(probs, dim=0)
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
painted_image = mask_painter(frame, out_mask)
num_objs = out_mask.max()
painted_image = frame
for obj in range(1, num_objs+1):
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
return out_mask, probs, painted_image
@torch.no_grad()
@@ -114,11 +118,11 @@ class BaseTracker:
if __name__ == '__main__':
# video frames
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg'))
# video frames (multiple objects)
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
video_path_list.sort()
# first frame
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
# load frames
frames = []
for video_path in video_path_list:
@@ -130,21 +134,22 @@ if __name__ == '__main__':
# ----------------------------------------------------------
# initalise tracker
# ----------------------------------------------------------
device = 'cuda:1'
device = 'cuda:0'
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
model_type = 'vit_h'
tracker = BaseTracker(XMEM_checkpoint, device, SAM_checkpoint, model_type)
# # track anything given in the first frame annotation
# for ti, frame in enumerate(frames):
# if ti == 0:
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
# else:
# mask, prob, painted_image = tracker.track(frame)
# # save
# painted_image = Image.fromarray(painted_image)
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/dance-twirl/{ti:05d}.png')
# track anything given in the first frame annotation
for ti, frame in enumerate(frames):
if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
else:
mask, prob, painted_image = tracker.track(frame)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
# # ----------------------------------------------------------
# # another video
@@ -175,21 +180,21 @@ if __name__ == '__main__':
# painted_image = Image.fromarray(painted_image)
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png')
# failure case test
failure_path = '/ssd1/gaomingqi/failure'
frames = np.load(os.path.join(failure_path, 'video_frames.npy'))
# first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB'))
first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P'))
first_mask = np.clip(first_mask, 0, 1)
# # failure case test
# failure_path = '/ssd1/gaomingqi/failure'
# frames = np.load(os.path.join(failure_path, 'video_frames.npy'))
# # first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB'))
# first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P'))
# first_mask = np.clip(first_mask, 0, 1)
for ti, frame in enumerate(frames):
if ti == 0:
mask, probs, painted_image = tracker.track(frame, first_mask)
else:
mask, probs, painted_image = tracker.track(frame)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png')
prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
# for ti, frame in enumerate(frames):
# if ti == 0:
# mask, probs, painted_image = tracker.track(frame, first_mask)
# else:
# mask, probs, painted_image = tracker.track(frame)
# # save
# painted_image = Image.fromarray(painted_image)
# painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png')
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
# prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')