From 6fd26eb9e80a1870815ba28f8c61288eb35e8c46 Mon Sep 17 00:00:00 2001 From: gaomingqi Date: Fri, 14 Apr 2023 04:10:51 +0800 Subject: [PATCH] update base_tracker --- tracker/base_tracker.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tracker/base_tracker.py b/tracker/base_tracker.py index 67940e1..f128bfd 100644 --- a/tracker/base_tracker.py +++ b/tracker/base_tracker.py @@ -9,6 +9,9 @@ import torch import yaml from model.network import XMem from inference.inference_core import InferenceCore +# for data transormation +from torchvision import transforms +from dataset.range_transform import im_normalization class BaseTracker: @@ -24,14 +27,27 @@ class BaseTracker: network = XMem(config, xmem_checkpoint).to(device).eval() # initialise IncerenceCore self.tracker = InferenceCore(network, config) - # set data transformation - # self.data_transform = + # data transformation + self.im_transform = transforms.Compose([ + transforms.ToTensor(), + im_normalization, + ]) def track(self, frames, first_frame_annotation): + """ + Input: + frames: numpy arrays: T, H, W, 3 (T: number of frames) + first_frame_annotation: numpy array: H, W + + Output: + masks: numpy arrays: T, H, W + """ # data transformation + for frame in frames: + frame = self.im_transform(frame) # tracking - pass + if __name__ == '__main__':