Mirror of https://github.com/gaomingqi/Track-Anything.git (synced 2025-12-16 16:37:58 +01:00)
Merge branch 'master' of github.com:gaomingqi/Track-Anything
app.py
@@ -81,7 +81,7 @@ def get_frames_from_video(video_input, play_state):
         video_path:str
         timestamp:float64
     Return
-        [[0:nearest_frame-1], [nearest_frame+1], nearest_frame]
+        [[0:nearest_frame], [nearest_frame:], nearest_frame]
     """
     video_path = video_input
     timestamp = play_state[1] - play_state[0]
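The only change in app.py is the docstring of get_frames_from_video: the split convention is now written as [0:nearest_frame] and [nearest_frame:], so the frame nearest the requested timestamp is kept rather than dropped from both halves. A rough sketch of how that convention reads, using illustrative names taken from the docstring rather than the function body:

    # Illustrative only: `frames` and `nearest_frame` stand in for whatever
    # get_frames_from_video computes internally; the point is the slicing.
    frames = list(range(10))           # stand-in for the decoded video frames
    nearest_frame = 4                  # index of the frame closest to the timestamp
    before = frames[0:nearest_frame]   # frames strictly before the nearest frame
    after = frames[nearest_frame:]     # the nearest frame and everything after it
    assert before + after == frames    # nothing is dropped, unlike the old
                                       # [0:nearest_frame-1] / [nearest_frame+1] reading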
(second changed file in this commit; its name is not shown in this view, but it defines class BaseTracker)
@@ -9,9 +9,16 @@ import torch
 import yaml
 from model.network import XMem
 from inference.inference_core import InferenceCore
+from inference.data.mask_mapper import MaskMapper
 
 # for data transormation
 from torchvision import transforms
 from dataset.range_transform import im_normalization
+import torch.nn.functional as F
+
+import sys
+sys.path.insert(0, sys.path[0]+"/../")
+from tools.painter import mask_painter
+
 
 class BaseTracker:
@@ -32,7 +39,18 @@ class BaseTracker:
             transforms.ToTensor(),
             im_normalization,
         ])
+        self.mapper = MaskMapper()
+        self.device = device
 
+    @torch.no_grad()
+    def resize_mask(self, mask):
+        # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
+        h, w = mask.shape[-2:]
+        min_hw = min(h, w)
+        return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
+                    mode='nearest')
+
+    @torch.no_grad()
     def track(self, frames, first_frame_annotation):
         """
         Input:
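The new resize_mask helper rescales a label map so its shorter side matches self.size while keeping the aspect ratio, and uses nearest-neighbor interpolation so object ids stay discrete. self.size is set elsewhere in __init__ and is not visible in this diff; the sketch below assumes a value of 480 just to show the arithmetic:

    # Assumed self.size = 480 (not shown in the diff); dummy N, C, H, W label map.
    import torch
    import torch.nn.functional as F

    size = 480
    mask = torch.zeros(1, 1, 720, 1280)
    h, w = mask.shape[-2:]
    min_hw = min(h, w)
    resized = F.interpolate(mask, (int(h/min_hw*size), int(w/min_hw*size)), mode='nearest')
    print(resized.shape)   # torch.Size([1, 1, 480, 853]) -- shorter side pinned to 480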
@@ -40,14 +58,30 @@ class BaseTracker:
         first_frame_annotation: numpy array: H, W
 
         Output:
-        masks: numpy arrays: T, H, W
+        masks: numpy arrays: H, W
         """
-        # data transformation
-        for frame in frames:
-            frame = self.im_transform(frame)
-
-        # tracking
-
+        vid_length = len(frames)
+        masks = []
+
+        for ti, frame in enumerate(frames):
+            # convert to tensor
+            frame_tensor = self.im_transform(frame).to(self.device)
+            if ti == 0:
+                mask, labels = self.mapper.convert_mask(first_frame_annotation)
+                mask = torch.Tensor(mask).to(self.device)
+                self.tracker.set_all_labels(list(self.mapper.remappings.values()))
+            else:
+                mask = None
+                labels = None
+
+            # track one frame
+            prob = self.tracker.step(frame_tensor, mask, labels, end=(ti==vid_length-1))
+            # convert to mask
+            out_mask = torch.argmax(prob, dim=0)
+            out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
+            masks.append(out_mask)
+
+        return np.stack(masks, 0)
 
 
 if __name__ == '__main__':
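The rewritten track method seeds XMem with the first-frame annotation (remapped to contiguous label ids by MaskMapper), then steps through every frame and collects one uint8 label mask per frame. Note that although the docstring line now reads "H, W", the method still returns np.stack(masks, 0), i.e. a T, H, W array. A shape-only sketch with dummy data (the checkpoint path is a placeholder, and actually running it requires the real XMem weights):

    import numpy as np

    T, H, W = 5, 480, 854
    frames = np.zeros((T, H, W, 3), dtype=np.uint8)            # T, H, W, C video frames
    first_frame_annotation = np.zeros((H, W), dtype=np.uint8)  # 0 = background
    first_frame_annotation[100:200, 100:200] = 1               # one labelled object

    # tracker = BaseTracker('cuda:0', '/path/to/XMem-s012.pth')  # placeholder path
    # masks = tracker.track(frames, first_frame_annotation)
    # masks.shape == (T, H, W); masks[t] holds a uint8 object id per pixel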
@@ -56,20 +90,27 @@ if __name__ == '__main__':
     video_path_list.sort()
     # first frame
     first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
 
     # load frames
-    frames = ["test_confict"]
+    frames = []
     for video_path in video_path_list:
         frames.append(np.array(Image.open(video_path).convert('RGB')))
     frames = np.stack(frames, 0)    # N, H, W, C
 
     # load first frame annotation
     first_frame_annotation = np.array(Image.open(first_frame_path).convert('P'))    # H, W, C
 
+    # ----------------------------------------------------------
     # initalise tracker
-    device = 'cuda:0'
+    # ----------------------------------------------------------
+    device = 'cuda:1'
     XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
-    tracker = BaseTracker('cuda:0', XMEM_checkpoint)
+    tracker = BaseTracker(device, XMEM_checkpoint)
 
     # track anything given in the first frame annotation
-    tracker.track(frames, first_frame_annotation)
+    masks = tracker.track(frames, first_frame_annotation)
+
+    # save
+    for ti, (frame, mask) in enumerate(zip(frames, masks)):
+        painted_image = mask_painter(frame, mask)
+        # save
+        painted_image = Image.fromarray(painted_image)
+        painted_image.save(f'/ssd1/gaomingqi/results/TrackA/{ti:05d}.png')
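tools.painter.mask_painter itself is not part of this commit, so the new save loop is easiest to read with a stand-in in mind: something that takes an H, W, 3 uint8 frame and an H, W label mask and returns the frame with the masked region tinted. A minimal, clearly hypothetical stand-in:

    # Not from this commit: a stand-in for tools.painter.mask_painter, only to
    # illustrate what the save loop expects; the real implementation may differ.
    import numpy as np

    def mask_painter_stub(frame, mask, color=(255, 0, 0), alpha=0.5):
        painted = frame.astype(np.float32).copy()
        region = mask > 0                      # pixels assigned to any object
        painted[region] = (1 - alpha) * painted[region] + alpha * np.array(color, dtype=np.float32)
        return painted.astype(np.uint8)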