mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-15 16:07:51 +01:00
add base_tracker
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,3 +5,4 @@ docs/
|
||||
*.mp4
|
||||
debug_images/
|
||||
*.png
|
||||
*.jpg
|
||||
|
||||
@@ -13,3 +13,4 @@ matplotlib
|
||||
onnxruntime
|
||||
onnx
|
||||
metaseg
|
||||
pyyaml
|
||||
|
||||
59
tracker/base_tracker.py
Normal file
59
tracker/base_tracker.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# input: frame list, first frame mask
|
||||
# output: segmentation results on all frames
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from model.network import XMem
|
||||
from inference.inference_core import InferenceCore
|
||||
|
||||
|
||||
class BaseTracker:
    """Video object segmentation tracker built on XMem.

    Given a first-frame mask, propagates the segmentation across all
    frames of a video (see module header: input = frame list + first
    frame mask, output = segmentation on all frames).
    """

    def __init__(self, device, xmem_checkpoint,
                 config_path="tracker/config/config.yaml") -> None:
        """
        device: model device (e.g. 'cuda:0')
        xmem_checkpoint: checkpoint of XMem model
        config_path: path to the XMem YAML config file. New optional
            parameter; defaults to the previously hard-coded location,
            so existing callers are unaffected.
        """
        # load configurations
        with open(config_path, 'r') as stream:
            config = yaml.safe_load(stream)
        # initialise XMem network and move it to the target device
        network = XMem(config, xmem_checkpoint).to(device).eval()
        # initialise InferenceCore (comment typo "IncerenceCore" fixed)
        self.tracker = InferenceCore(network, config)
        # set data transformation
        # self.data_transform =

    def track(self, frames, first_frame_annotation):
        """Propagate the first-frame mask over all frames.

        frames: stacked RGB frames — presumably N, H, W, C uint8,
            matching the np.stack in the demo below; TODO confirm.
        first_frame_annotation: H, W label mask for frame 0.

        NOTE(review): not yet implemented — currently a stub.
        """
        # data transformation

        # tracking
        pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # video frames: collect and sort DAVIS jpg frames in temporal order
    video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg'))
    video_path_list.sort()
    # first frame annotation path
    first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'

    # load frames as RGB arrays
    frames = []
    for video_path in video_path_list:
        frames.append(np.array(Image.open(video_path).convert('RGB')))
    frames = np.stack(frames, 0)  # N, H, W, C

    # load first frame annotation; 'P' (palette) mode yields a single
    # channel label map, so the result is H, W (original comment said
    # H, W, C — fixed)
    first_frame_annotation = np.array(Image.open(first_frame_path).convert('P'))  # H, W

    # initialise tracker
    device = 'cuda:0'
    XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
    # bug fix: pass the `device` variable instead of repeating the
    # 'cuda:0' literal, so changing `device` actually takes effect
    tracker = BaseTracker(device, XMEM_checkpoint)

    # track anything given in the first frame annotation
    tracker.track(frames, first_frame_annotation)
|
||||
15
tracker/config/config.yaml
Normal file
15
tracker/config/config.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
# config info for XMem
|
||||
benchmark: False
|
||||
disable_long_term: False
|
||||
max_mid_term_frames: 10
|
||||
min_mid_term_frames: 5
|
||||
max_long_term_elements: 10000
|
||||
num_prototypes: 128
|
||||
top_k: 30
|
||||
mem_every: 5
|
||||
deep_update_every: -1
|
||||
save_scores: False
|
||||
flip: False
|
||||
size: 480
|
||||
enable_long_term: True
|
||||
enable_long_term_count_usage: True
|
||||
@@ -213,4 +213,4 @@ class VOSDataset(Dataset):
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.videos)
|
||||
return len(self.videos)
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# input: frame list, first frame mask
|
||||
# output: segmentation results on all frames
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class XMem:
    """Placeholder tracker class.

    Based on https://github.com/hkchengrex/XMem; implementation to follow.
    """

    pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # video frames: gather the DAVIS jpg frames, sorted temporally
    video_path_list = sorted(
        glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg'))
    )
    # first frame annotation path
    first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'

    # load every frame as an RGB array and stack into one N, H, W, C array
    frames = np.stack(
        [np.array(Image.open(p).convert('RGB')) for p in video_path_list], 0
    )

    # load the first-frame annotation in palette mode
    first_frame_annotation = np.array(Image.open(first_frame_path).convert('P'))
|
||||
|
||||
Reference in New Issue
Block a user