Merge branch 'master' of github.com:gaomingqi/Track-Anything

This commit is contained in:
memoryunreal
2023-04-14 02:27:50 +00:00
2 changed files with 54 additions and 13 deletions

2
app.py
View File

@@ -81,7 +81,7 @@ def get_frames_from_video(video_input, play_state):
video_path:str video_path:str
timestamp:float64 timestamp:float64
Return Return
[[0:nearest_frame-1], [nearest_frame+1], nearest_frame] [[0:nearest_frame], [nearest_frame:], nearest_frame]
""" """
video_path = video_input video_path = video_input
timestamp = play_state[1] - play_state[0] timestamp = play_state[1] - play_state[0]

View File

@@ -9,9 +9,16 @@ import torch
import yaml import yaml
from model.network import XMem from model.network import XMem
from inference.inference_core import InferenceCore from inference.inference_core import InferenceCore
from inference.data.mask_mapper import MaskMapper
# for data transormation # for data transormation
from torchvision import transforms from torchvision import transforms
from dataset.range_transform import im_normalization from dataset.range_transform import im_normalization
import torch.nn.functional as F
import sys
sys.path.insert(0, sys.path[0]+"/../")
from tools.painter import mask_painter
class BaseTracker: class BaseTracker:
@@ -32,7 +39,18 @@ class BaseTracker:
transforms.ToTensor(), transforms.ToTensor(),
im_normalization, im_normalization,
]) ])
self.mapper = MaskMapper()
self.device = device
@torch.no_grad()
def resize_mask(self, mask):
# mask transform is applied AFTER mapper, so we need to post-process it in eval.py
h, w = mask.shape[-2:]
min_hw = min(h, w)
return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
mode='nearest')
@torch.no_grad()
def track(self, frames, first_frame_annotation): def track(self, frames, first_frame_annotation):
""" """
Input: Input:
@@ -40,14 +58,30 @@ class BaseTracker:
first_frame_annotation: numpy array: H, W first_frame_annotation: numpy array: H, W
Output: Output:
masks: numpy arrays: T, H, W masks: numpy arrays: H, W
""" """
# data transformation vid_length = len(frames)
for frame in frames: masks = []
frame = self.im_transform(frame)
# tracking for ti, frame in enumerate(frames):
# convert to tensor
frame_tensor = self.im_transform(frame).to(self.device)
if ti == 0:
mask, labels = self.mapper.convert_mask(first_frame_annotation)
mask = torch.Tensor(mask).to(self.device)
self.tracker.set_all_labels(list(self.mapper.remappings.values()))
else:
mask = None
labels = None
# track one frame
prob = self.tracker.step(frame_tensor, mask, labels, end=(ti==vid_length-1))
# convert to mask
out_mask = torch.argmax(prob, dim=0)
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
masks.append(out_mask)
return np.stack(masks, 0)
if __name__ == '__main__': if __name__ == '__main__':
@@ -56,20 +90,27 @@ if __name__ == '__main__':
video_path_list.sort() video_path_list.sort()
# first frame # first frame
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png' first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
# load frames # load frames
frames = ["test_confict"] frames = []
for video_path in video_path_list: for video_path in video_path_list:
frames.append(np.array(Image.open(video_path).convert('RGB'))) frames.append(np.array(Image.open(video_path).convert('RGB')))
frames = np.stack(frames, 0) # N, H, W, C frames = np.stack(frames, 0) # N, H, W, C
# load first frame annotation # load first frame annotation
first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
# ----------------------------------------------------------
# initalise tracker # initalise tracker
device = 'cuda:0' # ----------------------------------------------------------
device = 'cuda:1'
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth' XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
tracker = BaseTracker('cuda:0', XMEM_checkpoint) tracker = BaseTracker(device, XMEM_checkpoint)
# track anything given in the first frame annotation # track anything given in the first frame annotation
tracker.track(frames, first_frame_annotation) masks = tracker.track(frames, first_frame_annotation)
# save
for ti, (frame, mask) in enumerate(zip(frames, masks)):
painted_image = mask_painter(frame, mask)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/{ti:05d}.png')