mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 08:27:49 +01:00
update base_tracker
This commit is contained in:
@@ -9,6 +9,9 @@ import torch
|
||||
import yaml
|
||||
from model.network import XMem
|
||||
from inference.inference_core import InferenceCore
|
||||
# for data transormation
|
||||
from torchvision import transforms
|
||||
from dataset.range_transform import im_normalization
|
||||
|
||||
|
||||
class BaseTracker:
|
||||
@@ -24,14 +27,27 @@ class BaseTracker:
|
||||
network = XMem(config, xmem_checkpoint).to(device).eval()
|
||||
# initialise IncerenceCore
|
||||
self.tracker = InferenceCore(network, config)
|
||||
# set data transformation
|
||||
# self.data_transform =
|
||||
# data transformation
|
||||
self.im_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
im_normalization,
|
||||
])
|
||||
|
||||
def track(self, frames, first_frame_annotation):
|
||||
"""
|
||||
Input:
|
||||
frames: numpy arrays: T, H, W, 3 (T: number of frames)
|
||||
first_frame_annotation: numpy array: H, W
|
||||
|
||||
Output:
|
||||
masks: numpy arrays: T, H, W
|
||||
"""
|
||||
# data transformation
|
||||
for frame in frames:
|
||||
frame = self.im_transform(frame)
|
||||
|
||||
# tracking
|
||||
pass
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user