# Mirror of https://github.com/gaomingqi/Track-Anything.git
# (synced 2025-12-16 16:37:58 +01:00; original file: 60 lines, 1.8 KiB, Python)
|
|
# Input: the list of video frames and the annotation mask of the first frame.
# Output: segmentation results for all frames.
|
||
|
|
import os
|
||
|
|
import glob
|
||
|
|
import numpy as np
|
||
|
|
from PIL import Image
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import yaml
|
||
|
|
from model.network import XMem
|
||
|
|
from inference.inference_core import InferenceCore
|
||
|
|
|
||
|
|
|
||
|
|
class BaseTracker:
    """Wrap an XMem network and its ``InferenceCore`` for video object tracking.

    The intended contract is: given the frames of a video and the annotation
    mask of the first frame, produce segmentation results for every frame.
    ``track`` is not implemented yet.
    """

    def __init__(self, device, xmem_checkpoint, *, config_path="tracker/config/config.yaml") -> None:
        """
        device: model device; the XMem network is moved there with ``.to(device)``
        xmem_checkpoint: checkpoint of XMem model, passed unchanged to the ``XMem``
            constructor (presumably a path to a weights file -- confirm in XMem)
        config_path: YAML configuration shared by ``XMem`` and ``InferenceCore``.
            The default is the previously hard-coded relative path, which is
            resolved against the current working directory, so pass an absolute
            path when running from elsewhere.

        Raises FileNotFoundError if ``config_path`` does not exist.
        """
        # load configurations (safe_load: the YAML is parsed as plain data only)
        with open(config_path, "r", encoding="utf-8") as stream:
            config = yaml.safe_load(stream)
        # initialise XMem and put it in eval mode
        network = XMem(config, xmem_checkpoint).to(device).eval()
        # initialise InferenceCore around the network with the same configuration
        self.tracker = InferenceCore(network, config)
        # TODO: set data transformation (self.data_transform) once defined

    def track(self, frames, first_frame_annotation):
        """Track the objects annotated in the first frame through all frames.

        NOTE(review): not implemented yet -- this currently does nothing and
        returns None.

        frames: video frames; the demo below passes an (N, H, W, C) RGB array
            built with ``np.stack`` -- confirm this is the intended input format
        first_frame_annotation: mask of the first frame; the demo passes a 2-D
            (H, W) array of palette indices
        """
        # TODO: data transformation
        # TODO: tracking
        return None
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == '__main__':
    # Demo on the DAVIS "dance-twirl" sequence. The absolute paths below point
    # at the original author's machine; adjust them for a local run.
    video_dir = '/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl'
    first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
    XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
    device = 'cuda:0'

    # video frames; lexicographic sort gives temporal order assuming
    # zero-padded file names such as 00000.jpg
    video_path_list = sorted(glob.glob(os.path.join(video_dir, '*.jpg')))
    if not video_path_list:
        # np.stack on an empty list would otherwise fail with an unclear message
        raise FileNotFoundError(f'no .jpg frames found in {video_dir}')

    # load frames; `with` closes the file handle Image.open keeps open lazily
    frames = []
    for video_path in video_path_list:
        with Image.open(video_path) as img:
            frames.append(np.array(img.convert('RGB')))
    frames = np.stack(frames, 0)  # N, H, W, C

    # load first frame annotation; a 'P' (palette) image converts to a 2-D
    # array of palette indices, not a multi-channel image
    with Image.open(first_frame_path) as img:
        first_frame_annotation = np.array(img.convert('P'))  # H, W

    # initialise tracker on the configured device
    tracker = BaseTracker(device, XMEM_checkpoint)

    # track anything given in the first frame annotation
    tracker.track(frames, first_frame_annotation)
|