2023-04-14 03:13:58 +08:00
|
|
|
# input: frame list, first frame mask
|
|
|
|
|
# output: segmentation results on all frames
|
|
|
|
|
import os
|
|
|
|
|
import glob
|
|
|
|
|
import numpy as np
|
|
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import yaml
|
|
|
|
|
from model.network import XMem
|
|
|
|
|
from inference.inference_core import InferenceCore
|
2023-04-14 05:37:08 +08:00
|
|
|
from inference.data.mask_mapper import MaskMapper
|
|
|
|
|
|
2023-04-14 04:10:51 +08:00
|
|
|
# for data transformation
|
|
|
|
|
from torchvision import transforms
|
|
|
|
|
from dataset.range_transform import im_normalization
|
2023-04-14 05:37:08 +08:00
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
import sys
|
|
|
|
|
sys.path.insert(0, sys.path[0]+"/../")
|
|
|
|
|
from tools.painter import mask_painter
|
2023-04-14 03:13:58 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseTracker:
    """Propagate a first-frame annotation through a video with XMem."""

    def __init__(self, device, xmem_checkpoint, config_path="tracker/config/config.yaml") -> None:
        """
        device: model device; the XMem network and every frame tensor are moved here
        xmem_checkpoint: checkpoint of XMem model, passed to the XMem constructor
        config_path: YAML configuration shared by XMem and InferenceCore. The default
            is the original hard-coded path, resolved relative to the current working
            directory.
        """
        # load configurations
        with open(config_path, 'r') as stream:
            config = yaml.safe_load(stream)
        # initialise XMem (eval() switches layer behaviour, it does NOT disable autograd)
        network = XMem(config, xmem_checkpoint).to(device).eval()
        # initialise InferenceCore
        self.tracker = InferenceCore(network, config)
        # data transformation: HWC image -> CHW float tensor, then im_normalization
        self.im_transform = transforms.Compose([
            transforms.ToTensor(),
            im_normalization,
        ])
        self.device = device

    def resize_mask(self, mask, size=None):
        """Resize ``mask`` so its shorter spatial side equals ``size``, keeping aspect ratio.

        mask: tensor whose last two dims are (H, W). Because a 2-element output size
            is passed to F.interpolate, the input must be 4-D (N, C, H, W).
        size: target length of the shorter side. When omitted, ``self.size`` is used
            if the instance defines it (``__init__`` does not set it).
        Returns the resized tensor. Nearest-neighbour mode is used so label values
        are copied, never blended.
        Raises ValueError if no target size is available.
        """
        if size is None:
            size = getattr(self, 'size', None)
        if size is None:
            raise ValueError("resize_mask needs a target size: pass size=... or set self.size")
        h, w = mask.shape[-2:]
        min_hw = min(h, w)
        return F.interpolate(mask, (int(h / min_hw * size), int(w / min_hw * size)),
                             mode='nearest')

    def track(self, frames, first_frame_annotation, save_dir='/ssd1/gaomingqi/results/TrackA'):
        """
        Input:
        frames: numpy arrays: T, H, W, 3 (T: number of frames)
        first_frame_annotation: numpy array: H, W
        save_dir: directory that receives one painted PNG per frame, named '{ti}.png'.
            Defaults to the original hard-coded location and is created if missing.
            Pass None to skip painting and saving.

        Output:
        masks: numpy array: T, H, W (uint8), the per-frame argmax over dim 0 of the
            tracker output. Empty input gives an array of shape (0, H, W).
        """
        if len(frames) == 0:
            h, w = np.asarray(first_frame_annotation).shape[:2]
            return np.zeros((0, h, w), dtype=np.uint8)

        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)

        mapper = MaskMapper()
        vid_length = len(frames)
        masks = []

        # Pure inference: without no_grad every step would record an autograd graph,
        # and any tensor the tracker keeps between steps would keep that graph alive.
        with torch.no_grad():
            for ti, frame in enumerate(frames):
                # Transform lazily, one frame at a time, instead of stacking the whole
                # video on the device up front.
                frame_tensor = self.im_transform(frame).to(self.device)

                if ti == 0:
                    # Only the first frame carries an annotation; later frames are
                    # propagated from the tracker's state.
                    mask, labels = mapper.convert_mask(first_frame_annotation)
                    mask = torch.Tensor(mask).to(self.device)
                    self.tracker.set_all_labels(list(mapper.remappings.values()))
                else:
                    mask = None
                    labels = None

                # track one frame; `end` flags the final frame of the video
                prob = self.tracker.step(frame_tensor, mask, labels, end=(ti == vid_length - 1))

                # assumes dim 0 of `prob` indexes background + objects -- TODO confirm
                out_mask = torch.argmax(prob, dim=0)
                out_mask = out_mask.detach().cpu().numpy().astype(np.uint8)
                masks.append(out_mask)

                if save_dir is not None:
                    painted_image = mask_painter(frames[ti], out_mask)
                    Image.fromarray(painted_image).save(os.path.join(save_dir, f'{ti}.png'))

        return np.stack(masks, 0)
|
2023-04-14 03:13:58 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # video frames, sorted by filename so they are in temporal order
    video_path_list = sorted(glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg')))
    # first frame annotation
    first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'

    # load frames
    frames = []
    for video_path in video_path_list:
        with Image.open(video_path) as img:
            frames.append(np.array(img.convert('RGB')))
    frames = np.stack(frames, 0)  # N, H, W, C

    # load first frame annotation; a palette ('P') image converts to H, W
    with Image.open(first_frame_path) as img:
        first_frame_annotation = np.array(img.convert('P'))  # H, W

    # ----------------------------------------------------------
    # initialise tracker
    # ----------------------------------------------------------
    device = 'cuda:0'
    XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
    tracker = BaseTracker(device, XMEM_checkpoint)

    # track anything given in the first frame annotation
    tracker.track(frames, first_frame_annotation)
|