mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
add multi-object support to base_tracker
This commit is contained in:
@@ -9,7 +9,7 @@ https://user-images.githubusercontent.com/28050374/232070852-af2e85e5-a834-4bbc-
|
|||||||
|
|
||||||
## Get Started
|
## Get Started
|
||||||
|
|
||||||
This is Get Started.
|
|
||||||
## Acknowledgement
|
## Acknowledgement
|
||||||
|
|
||||||
The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything) and [XMem](https://github.com/hkchengrex/XMem). Thanks for the authors for their efforts.
|
The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything) and [XMem](https://github.com/hkchengrex/XMem). Thanks for the authors for their efforts.
|
||||||
|
|||||||
@@ -87,7 +87,11 @@ class BaseTracker:
|
|||||||
# convert to mask
|
# convert to mask
|
||||||
out_mask = torch.argmax(probs, dim=0)
|
out_mask = torch.argmax(probs, dim=0)
|
||||||
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
|
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
|
||||||
painted_image = mask_painter(frame, out_mask)
|
|
||||||
|
num_objs = out_mask.max()
|
||||||
|
painted_image = frame
|
||||||
|
for obj in range(1, num_objs+1):
|
||||||
|
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
|
||||||
return out_mask, probs, painted_image
|
return out_mask, probs, painted_image
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -114,11 +118,11 @@ class BaseTracker:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# video frames
|
# video frames (multiple objects)
|
||||||
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg'))
|
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
|
||||||
video_path_list.sort()
|
video_path_list.sort()
|
||||||
# first frame
|
# first frame
|
||||||
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
|
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
|
||||||
# load frames
|
# load frames
|
||||||
frames = []
|
frames = []
|
||||||
for video_path in video_path_list:
|
for video_path in video_path_list:
|
||||||
@@ -130,21 +134,22 @@ if __name__ == '__main__':
|
|||||||
# ----------------------------------------------------------
|
# ----------------------------------------------------------
|
||||||
# initalise tracker
|
# initalise tracker
|
||||||
# ----------------------------------------------------------
|
# ----------------------------------------------------------
|
||||||
device = 'cuda:1'
|
device = 'cuda:0'
|
||||||
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
||||||
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||||
model_type = 'vit_h'
|
model_type = 'vit_h'
|
||||||
tracker = BaseTracker(XMEM_checkpoint, device, SAM_checkpoint, model_type)
|
tracker = BaseTracker(XMEM_checkpoint, device, SAM_checkpoint, model_type)
|
||||||
|
|
||||||
# # track anything given in the first frame annotation
|
|
||||||
# for ti, frame in enumerate(frames):
|
# track anything given in the first frame annotation
|
||||||
# if ti == 0:
|
for ti, frame in enumerate(frames):
|
||||||
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
if ti == 0:
|
||||||
# else:
|
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||||
# mask, prob, painted_image = tracker.track(frame)
|
else:
|
||||||
# # save
|
mask, prob, painted_image = tracker.track(frame)
|
||||||
# painted_image = Image.fromarray(painted_image)
|
# save
|
||||||
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/dance-twirl/{ti:05d}.png')
|
painted_image = Image.fromarray(painted_image)
|
||||||
|
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
|
||||||
|
|
||||||
# # ----------------------------------------------------------
|
# # ----------------------------------------------------------
|
||||||
# # another video
|
# # another video
|
||||||
@@ -161,7 +166,7 @@ if __name__ == '__main__':
|
|||||||
# frames = np.stack(frames, 0) # N, H, W, C
|
# frames = np.stack(frames, 0) # N, H, W, C
|
||||||
# # load first frame annotation
|
# # load first frame annotation
|
||||||
# first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
|
# first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
|
||||||
|
|
||||||
# print('first video done. clear.')
|
# print('first video done. clear.')
|
||||||
|
|
||||||
# tracker.clear_memory()
|
# tracker.clear_memory()
|
||||||
@@ -175,21 +180,21 @@ if __name__ == '__main__':
|
|||||||
# painted_image = Image.fromarray(painted_image)
|
# painted_image = Image.fromarray(painted_image)
|
||||||
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png')
|
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png')
|
||||||
|
|
||||||
# failure case test
|
# # failure case test
|
||||||
failure_path = '/ssd1/gaomingqi/failure'
|
# failure_path = '/ssd1/gaomingqi/failure'
|
||||||
frames = np.load(os.path.join(failure_path, 'video_frames.npy'))
|
# frames = np.load(os.path.join(failure_path, 'video_frames.npy'))
|
||||||
# first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB'))
|
# # first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB'))
|
||||||
first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P'))
|
# first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P'))
|
||||||
first_mask = np.clip(first_mask, 0, 1)
|
# first_mask = np.clip(first_mask, 0, 1)
|
||||||
|
|
||||||
for ti, frame in enumerate(frames):
|
# for ti, frame in enumerate(frames):
|
||||||
if ti == 0:
|
# if ti == 0:
|
||||||
mask, probs, painted_image = tracker.track(frame, first_mask)
|
# mask, probs, painted_image = tracker.track(frame, first_mask)
|
||||||
else:
|
# else:
|
||||||
mask, probs, painted_image = tracker.track(frame)
|
# mask, probs, painted_image = tracker.track(frame)
|
||||||
# save
|
# # save
|
||||||
painted_image = Image.fromarray(painted_image)
|
# painted_image = Image.fromarray(painted_image)
|
||||||
painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png')
|
# painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png')
|
||||||
prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
|
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
|
||||||
|
|
||||||
# prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
|
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
|
||||||
|
|||||||
Reference in New Issue
Block a user