mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
mingqi
This commit is contained in:
@@ -39,8 +39,10 @@ class BaseTracker:
|
||||
transforms.ToTensor(),
|
||||
im_normalization,
|
||||
])
|
||||
self.mapper = MaskMapper()
|
||||
self.device = device
|
||||
|
||||
@torch.no_grad()
|
||||
def resize_mask(self, mask):
|
||||
# mask transform is applied AFTER mapper, so we need to post-process it in eval.py
|
||||
h, w = mask.shape[-2:]
|
||||
@@ -48,6 +50,7 @@ class BaseTracker:
|
||||
return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
|
||||
mode='nearest')
|
||||
|
||||
@torch.no_grad()
|
||||
def track(self, frames, first_frame_annotation):
|
||||
"""
|
||||
Input:
|
||||
@@ -57,34 +60,28 @@ class BaseTracker:
|
||||
Output:
|
||||
masks: numpy arrays: H, W
|
||||
"""
|
||||
shape = np.array(frames).shape[1:3] # H, W
|
||||
frame_list = [self.im_transform(frame).to(self.device) for frame in frames]
|
||||
frame_tensors = torch.stack(frame_list, dim=0)
|
||||
vid_length = len(frames)
|
||||
masks = []
|
||||
|
||||
# data transformation
|
||||
mapper = MaskMapper()
|
||||
|
||||
vid_length = len(frame_tensors)
|
||||
|
||||
for ti, frame_tensor in enumerate(frame_tensors):
|
||||
for ti, frame in enumerate(frames):
|
||||
# convert to tensor
|
||||
frame_tensor = self.im_transform(frame).to(self.device)
|
||||
if ti == 0:
|
||||
mask, labels = mapper.convert_mask(first_frame_annotation)
|
||||
mask, labels = self.mapper.convert_mask(first_frame_annotation)
|
||||
mask = torch.Tensor(mask).to(self.device)
|
||||
self.tracker.set_all_labels(list(mapper.remappings.values()))
|
||||
self.tracker.set_all_labels(list(self.mapper.remappings.values()))
|
||||
else:
|
||||
mask = None
|
||||
labels = None
|
||||
|
||||
# track one frame
|
||||
prob = self.tracker.step(frame_tensor, mask, labels, end=(ti==vid_length-1))
|
||||
|
||||
# convert to mask
|
||||
out_mask = torch.argmax(prob, dim=0)
|
||||
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
|
||||
masks.append(out_mask)
|
||||
|
||||
painted_image = mask_painter(frames[ti], out_mask)
|
||||
# save
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/{ti}.png')
|
||||
return np.stack(masks, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -94,7 +91,7 @@ if __name__ == '__main__':
|
||||
# first frame
|
||||
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
|
||||
# load frames
|
||||
frames = ["test_confict"]
|
||||
frames = []
|
||||
for video_path in video_path_list:
|
||||
frames.append(np.array(Image.open(video_path).convert('RGB')))
|
||||
frames = np.stack(frames, 0) # N, H, W, C
|
||||
@@ -104,9 +101,16 @@ if __name__ == '__main__':
|
||||
# ----------------------------------------------------------
|
||||
# initalise tracker
|
||||
# ----------------------------------------------------------
|
||||
device = 'cuda:0'
|
||||
device = 'cuda:1'
|
||||
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
||||
tracker = BaseTracker('cuda:0', XMEM_checkpoint)
|
||||
tracker = BaseTracker(device, XMEM_checkpoint)
|
||||
|
||||
# track anything given in the first frame annotation
|
||||
tracker.track(frames, first_frame_annotation)
|
||||
masks = tracker.track(frames, first_frame_annotation)
|
||||
|
||||
# save
|
||||
for ti, (frame, mask) in enumerate(zip(frames, masks)):
|
||||
painted_image = mask_painter(frame, mask)
|
||||
# save
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/{ti:05d}.png')
|
||||
|
||||
Reference in New Issue
Block a user