update base_tracker

This commit is contained in:
gaomingqi
2023-04-14 04:10:51 +08:00
parent 996ed9bf37
commit 6fd26eb9e8

View File

@@ -9,6 +9,9 @@ import torch
import yaml import yaml
from model.network import XMem from model.network import XMem
from inference.inference_core import InferenceCore from inference.inference_core import InferenceCore
# for data transormation
from torchvision import transforms
from dataset.range_transform import im_normalization
class BaseTracker: class BaseTracker:
@@ -24,14 +27,27 @@ class BaseTracker:
network = XMem(config, xmem_checkpoint).to(device).eval() network = XMem(config, xmem_checkpoint).to(device).eval()
# initialise IncerenceCore # initialise IncerenceCore
self.tracker = InferenceCore(network, config) self.tracker = InferenceCore(network, config)
# set data transformation # data transformation
# self.data_transform = self.im_transform = transforms.Compose([
transforms.ToTensor(),
im_normalization,
])
def track(self, frames, first_frame_annotation): def track(self, frames, first_frame_annotation):
"""
Input:
frames: numpy arrays: T, H, W, 3 (T: number of frames)
first_frame_annotation: numpy array: H, W
Output:
masks: numpy arrays: T, H, W
"""
# data transformation # data transformation
for frame in frames:
frame = self.im_transform(frame)
# tracking # tracking
pass
if __name__ == '__main__': if __name__ == '__main__':