mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
update base_tracker
This commit is contained in:
@@ -9,6 +9,9 @@ import torch
|
|||||||
import yaml
|
import yaml
|
||||||
from model.network import XMem
|
from model.network import XMem
|
||||||
from inference.inference_core import InferenceCore
|
from inference.inference_core import InferenceCore
|
||||||
|
# for data transormation
|
||||||
|
from torchvision import transforms
|
||||||
|
from dataset.range_transform import im_normalization
|
||||||
|
|
||||||
|
|
||||||
class BaseTracker:
|
class BaseTracker:
|
||||||
@@ -24,14 +27,27 @@ class BaseTracker:
|
|||||||
network = XMem(config, xmem_checkpoint).to(device).eval()
|
network = XMem(config, xmem_checkpoint).to(device).eval()
|
||||||
# initialise IncerenceCore
|
# initialise IncerenceCore
|
||||||
self.tracker = InferenceCore(network, config)
|
self.tracker = InferenceCore(network, config)
|
||||||
# set data transformation
|
# data transformation
|
||||||
# self.data_transform =
|
self.im_transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
im_normalization,
|
||||||
|
])
|
||||||
|
|
||||||
def track(self, frames, first_frame_annotation):
|
def track(self, frames, first_frame_annotation):
|
||||||
|
"""
|
||||||
|
Input:
|
||||||
|
frames: numpy arrays: T, H, W, 3 (T: number of frames)
|
||||||
|
first_frame_annotation: numpy array: H, W
|
||||||
|
|
||||||
|
Output:
|
||||||
|
masks: numpy arrays: T, H, W
|
||||||
|
"""
|
||||||
# data transformation
|
# data transformation
|
||||||
|
for frame in frames:
|
||||||
|
frame = self.im_transform(frame)
|
||||||
|
|
||||||
# tracking
|
# tracking
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user