mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
add multi-object support to base_tracker
This commit is contained in:
@@ -9,7 +9,7 @@ https://user-images.githubusercontent.com/28050374/232070852-af2e85e5-a834-4bbc-
|
||||
|
||||
## Get Started
|
||||
|
||||
This is Get Started.
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything) and [XMem](https://github.com/hkchengrex/XMem). Thanks to the authors for their efforts.
|
||||
|
||||
@@ -87,7 +87,11 @@ class BaseTracker:
|
||||
# convert to mask
|
||||
out_mask = torch.argmax(probs, dim=0)
|
||||
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
|
||||
painted_image = mask_painter(frame, out_mask)
|
||||
|
||||
num_objs = out_mask.max()
|
||||
painted_image = frame
|
||||
for obj in range(1, num_objs+1):
|
||||
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
|
||||
return out_mask, probs, painted_image
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -114,11 +118,11 @@ class BaseTracker:
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# video frames
|
||||
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg'))
|
||||
# video frames (multiple objects)
|
||||
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
|
||||
video_path_list.sort()
|
||||
# first frame
|
||||
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
|
||||
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
|
||||
# load frames
|
||||
frames = []
|
||||
for video_path in video_path_list:
|
||||
@@ -130,21 +134,22 @@ if __name__ == '__main__':
|
||||
# ----------------------------------------------------------
|
||||
# initialise tracker
|
||||
# ----------------------------------------------------------
|
||||
device = 'cuda:1'
|
||||
device = 'cuda:0'
|
||||
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
||||
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||
model_type = 'vit_h'
|
||||
tracker = BaseTracker(XMEM_checkpoint, device, SAM_checkpoint, model_type)
|
||||
|
||||
# # track anything given in the first frame annotation
|
||||
# for ti, frame in enumerate(frames):
|
||||
# if ti == 0:
|
||||
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||
# else:
|
||||
# mask, prob, painted_image = tracker.track(frame)
|
||||
# # save
|
||||
# painted_image = Image.fromarray(painted_image)
|
||||
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/dance-twirl/{ti:05d}.png')
|
||||
|
||||
# track anything given in the first frame annotation
|
||||
for ti, frame in enumerate(frames):
|
||||
if ti == 0:
|
||||
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||
else:
|
||||
mask, prob, painted_image = tracker.track(frame)
|
||||
# save
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
|
||||
|
||||
# # ----------------------------------------------------------
|
||||
# # another video
|
||||
@@ -175,21 +180,21 @@ if __name__ == '__main__':
|
||||
# painted_image = Image.fromarray(painted_image)
|
||||
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png')
|
||||
|
||||
# failure case test
|
||||
failure_path = '/ssd1/gaomingqi/failure'
|
||||
frames = np.load(os.path.join(failure_path, 'video_frames.npy'))
|
||||
# first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB'))
|
||||
first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P'))
|
||||
first_mask = np.clip(first_mask, 0, 1)
|
||||
# # failure case test
|
||||
# failure_path = '/ssd1/gaomingqi/failure'
|
||||
# frames = np.load(os.path.join(failure_path, 'video_frames.npy'))
|
||||
# # first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB'))
|
||||
# first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P'))
|
||||
# first_mask = np.clip(first_mask, 0, 1)
|
||||
|
||||
for ti, frame in enumerate(frames):
|
||||
if ti == 0:
|
||||
mask, probs, painted_image = tracker.track(frame, first_mask)
|
||||
else:
|
||||
mask, probs, painted_image = tracker.track(frame)
|
||||
# save
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png')
|
||||
prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
|
||||
# for ti, frame in enumerate(frames):
|
||||
# if ti == 0:
|
||||
# mask, probs, painted_image = tracker.track(frame, first_mask)
|
||||
# else:
|
||||
# mask, probs, painted_image = tracker.track(frame)
|
||||
# # save
|
||||
# painted_image = Image.fromarray(painted_image)
|
||||
# painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png')
|
||||
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
|
||||
|
||||
# prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
|
||||
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
|
||||
|
||||
Reference in New Issue
Block a user