mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 08:27:49 +01:00
update base_tracker
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
# input: frame list, first frame mask
|
||||
# output: segmentation results on all frames
|
||||
# import for debugging
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
# import for base_tracker
|
||||
import torch
|
||||
import yaml
|
||||
import torch.nn.functional as F
|
||||
@@ -13,8 +12,6 @@ from inference.inference_core import InferenceCore
|
||||
from util.mask_mapper import MaskMapper
|
||||
from torchvision import transforms
|
||||
from util.range_transform import im_normalization
|
||||
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, sys.path[0]+"/../")
|
||||
from tools.painter import mask_painter
|
||||
@@ -39,7 +36,8 @@ class BaseTracker:
|
||||
im_normalization,
|
||||
])
|
||||
self.device = device
|
||||
|
||||
|
||||
# changable properties
|
||||
self.mapper = MaskMapper()
|
||||
self.initialised = False
|
||||
|
||||
@@ -109,7 +107,7 @@ if __name__ == '__main__':
|
||||
# ----------------------------------------------------------
|
||||
device = 'cuda:1'
|
||||
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
||||
tracker = BaseTracker(device, XMEM_checkpoint)
|
||||
tracker = BaseTracker(XMEM_checkpoint, device)
|
||||
|
||||
# track anything given in the first frame annotation
|
||||
for ti, frame in enumerate(frames):
|
||||
|
||||
Reference in New Issue
Block a user