Files
Track-Anything/tracker/base_tracker.py

76 lines
2.3 KiB
Python
Raw Normal View History

2023-04-14 03:13:58 +08:00
# input: frame list, first frame mask
# output: segmentation results on all frames
import os
import glob
import numpy as np
from PIL import Image
import torch
import yaml
from model.network import XMem
from inference.inference_core import InferenceCore
2023-04-14 04:10:51 +08:00
# for data transformation
from torchvision import transforms
from dataset.range_transform import im_normalization
2023-04-14 03:13:58 +08:00
class BaseTracker:
    """Tracker wrapper around an XMem network and its InferenceCore.

    Intended use: given RGB frames and an annotation for the first frame,
    produce a mask for every frame. NOTE: only frame preprocessing is
    implemented so far; the tracking step in ``track`` is still a TODO.
    """

    def __init__(self, device, xmem_checkpoint, config_path="tracker/config/config.yaml") -> None:
        """
        device: device the XMem network is moved to (e.g. 'cuda:0')
        xmem_checkpoint: checkpoint of XMem model, passed to ``XMem``
        config_path: YAML configuration passed to both ``XMem`` and
            ``InferenceCore``. A relative path is resolved against the current
            working directory, so the default assumes the process runs from
            the directory that contains ``tracker/``.
        """
        # load configurations (safe_load: the config is plain YAML data)
        with open(config_path, 'r', encoding='utf-8') as stream:
            config = yaml.safe_load(stream)
        # kept so the tracking step can move inputs onto the network's device
        self.device = device
        # initialise XMem in eval mode on the requested device
        network = XMem(config, xmem_checkpoint).to(device).eval()
        # initialise InferenceCore around the network
        self.tracker = InferenceCore(network, config)

        # per-frame data transformation: ToTensor followed by im_normalization
        # (presumably mean/std normalization -- confirm in dataset/range_transform.py)
        self.im_transform = transforms.Compose([
            transforms.ToTensor(),
            im_normalization,
        ])

    def _transform_frames(self, frames):
        """Apply ``self.im_transform`` to each frame; return the results as a list in input order."""
        return [self.im_transform(frame) for frame in frames]

    def track(self, frames, first_frame_annotation):
        """
        Input:
            frames: numpy arrays: T, H, W, 3 (T: number of frames)
            first_frame_annotation: numpy array: H, W
        Output:
            masks: numpy arrays: T, H, W
            NOTE: not implemented yet -- this method currently only
            preprocesses the frames and returns None.
        """
        # data transformation: keep the transformed frames instead of
        # rebinding a loop variable and discarding them
        frame_tensors = self._transform_frames(frames)
        # TODO: tracking -- feed frame_tensors (moved to self.device) and
        # first_frame_annotation through self.tracker, one mask per frame.
        del frame_tensors
        return None
2023-04-14 04:10:51 +08:00
2023-04-14 03:13:58 +08:00
if __name__ == '__main__':
    # Demo on the DAVIS "dance-twirl" sequence; all paths are machine-specific.
    video_dir = '/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl'
    first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
    xmem_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
    device = 'cuda:0'

    # load video frames in filename order; `with` closes each image file
    frames = []
    for frame_path in sorted(glob.glob(os.path.join(video_dir, '*.jpg'))):
        with Image.open(frame_path) as img:
            frames.append(np.array(img.convert('RGB')))
    frames = np.stack(frames, 0)  # T, H, W, 3

    # load first frame annotation as palette indices
    with Image.open(first_frame_path) as img:
        first_frame_annotation = np.array(img.convert('P'))  # H, W

    # initialise tracker on the configured device
    tracker = BaseTracker(device, xmem_checkpoint)
    # track anything given in the first frame annotation
    tracker.track(frames, first_frame_annotation)