update base_tracker

This commit is contained in:
gaomingqi
2023-04-14 22:17:47 +08:00
parent 2809c45402
commit 746d276121

View File

@@ -1,10 +1,9 @@
# input: frame list, first frame mask
# output: segmentation results on all frames
# import for debugging
import os
import glob
import numpy as np
from PIL import Image
# import for base_tracker
import torch
import yaml
import torch.nn.functional as F
@@ -13,8 +12,6 @@ from inference.inference_core import InferenceCore
from util.mask_mapper import MaskMapper
from torchvision import transforms
from util.range_transform import im_normalization
import sys
sys.path.insert(0, sys.path[0]+"/../")
from tools.painter import mask_painter
@@ -39,7 +36,8 @@ class BaseTracker:
im_normalization,
])
self.device = device
# changable properties
self.mapper = MaskMapper()
self.initialised = False
@@ -109,7 +107,7 @@ if __name__ == '__main__':
# ----------------------------------------------------------
device = 'cuda:1'
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
tracker = BaseTracker(device, XMEM_checkpoint)
tracker = BaseTracker(XMEM_checkpoint, device)
# track anything given in the first frame annotation
for ti, frame in enumerate(frames):