Files
Track-Anything/tracker/base_tracker.py

196 lines
7.6 KiB
Python
Raw Normal View History

2023-04-14 22:17:47 +08:00
# import for debugging
2023-04-14 03:13:58 +08:00
import os
import glob
import numpy as np
from PIL import Image
2023-04-14 22:17:47 +08:00
# import for base_tracker
2023-04-14 03:13:58 +08:00
import torch
import yaml
2023-04-14 12:37:38 +08:00
import torch.nn.functional as F
2023-04-14 03:13:58 +08:00
from model.network import XMem
from inference.inference_core import InferenceCore
2023-04-14 12:37:38 +08:00
from util.mask_mapper import MaskMapper
2023-04-14 04:10:51 +08:00
from torchvision import transforms
2023-04-14 12:37:38 +08:00
from util.range_transform import im_normalization
2023-04-14 05:37:08 +08:00
import sys
sys.path.insert(0, sys.path[0]+"/../")
from tools.painter import mask_painter
2023-04-16 16:49:38 +08:00
from tools.base_segmenter import BaseSegmenter
from torchvision.transforms import Resize
2023-04-14 03:13:58 +08:00
class BaseTracker:
    """XMem-based video object tracker with optional SAM mask refinement.

    An XMem ``InferenceCore`` propagates a first-frame annotation through the
    subsequent frames, one frame per :meth:`track` call. A ``MaskMapper``
    converts the first-frame annotation into the mask/label inputs given to
    XMem. :meth:`clear_memory` resets both so that a new video can be tracked.
    """

    def __init__(self, xmem_checkpoint, device, sam_checkpoint, model_type) -> None:
        """
        xmem_checkpoint: path to the checkpoint of the XMem model
        device: device the models are moved to (e.g. 'cuda:0')
        sam_checkpoint: path to the SAM checkpoint passed to BaseSegmenter
        model_type: SAM backbone type matching sam_checkpoint (e.g. 'vit_h')
        """
        # Load configurations. The path is resolved relative to this file
        # (tracker/base_tracker.py -> tracker/config/config.yaml) so the
        # tracker does not depend on the current working directory.
        config_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'config', 'config.yaml'
        )
        with open(config_path, 'r') as stream:
            config = yaml.safe_load(stream)
        # initialise XMem
        network = XMem(config, xmem_checkpoint).to(device).eval()
        # initialise InferenceCore
        self.tracker = InferenceCore(network, config)
        # data transformation: HWC numpy image -> normalised CHW tensor
        self.im_transform = transforms.Compose([
            transforms.ToTensor(),
            im_normalization,
        ])
        self.device = device

        # changeable properties (reset by clear_memory)
        self.mapper = MaskMapper()
        self.initialised = False

        # SAM-based refinement; SAM mask prompts are 1 x 256 x 256
        self.sam_model = BaseSegmenter(sam_checkpoint, model_type, device=device)
        self.resizer = Resize([256, 256])

    @torch.no_grad()
    def resize_mask(self, mask, size=None):
        """
        Resize ``mask`` so that its shorter side equals ``size``, keeping the aspect ratio.

        mask: torch tensor whose last two dims are (H, W), in a layout accepted
            by F.interpolate with a 2-element size (i.e. (N, C, H, W))
        size: target length of the shorter side; defaults to ``self.size``

        Nearest-neighbour interpolation is used so label values are not blended.
        Raises ValueError when no size is given and ``self.size`` is not set.
        """
        # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
        if size is None:
            # self.size is never assigned in __init__, so fall back gracefully
            size = getattr(self, 'size', None)
        if size is None:
            raise ValueError(
                "resize_mask needs a target size: pass `size` or set `self.size`"
            )
        h, w = mask.shape[-2:]
        min_hw = min(h, w)
        return F.interpolate(
            mask,
            (int(h / min_hw * size), int(w / min_hw * size)),
            mode='nearest',
        )

    @torch.no_grad()
    def track(self, frame, first_frame_annotation=None):
        """
        Track the object(s) in one frame.

        Input:
            frame: numpy array (H, W, 3), the current video frame
            first_frame_annotation: optional numpy array (H, W) of object
                labels (presumably 0 = background); pass it on the first
                frame to initialise tracking and leave it None afterwards
        Output:
            mask: numpy array (H, W), uint8 label map (argmax of probs over dim 0)
            probs: torch tensor returned by InferenceCore.step, presumably
                (num_labels + 1, H, W) with background at index 0 — confirm
            painted_image: numpy array (H, W, 3), frame with the mask overlaid
        """
        if first_frame_annotation is not None:
            # initialisation: convert the annotation and register its labels
            mask, labels = self.mapper.convert_mask(first_frame_annotation)
            mask = torch.Tensor(mask).to(self.device)
            self.tracker.set_all_labels(list(self.mapper.remappings.values()))
        else:
            mask = None
            labels = None
        # prepare inputs
        frame_tensor = self.im_transform(frame).to(self.device)
        # track one frame; the second output (logits) is not used here
        probs, _ = self.tracker.step(frame_tensor, mask, labels)
        # convert to a label mask
        out_mask = torch.argmax(probs, dim=0).detach().cpu().numpy().astype(np.uint8)
        painted_image = mask_painter(frame, out_mask)
        return out_mask, probs, painted_image

    @torch.no_grad()
    def sam_refinement(self, frame, logits, ti, save_dir='/ssd1/gaomingqi/refine'):
        """
        Refine a segmentation result by prompting SAM with a mask.

        frame: numpy array (H, W, 3), image passed to SAM
        logits: torch tensor, presumably (H, W) foreground logits from the tracker
        ti: frame index, used to name the saved visualisation '{ti:05d}.png'
        save_dir: directory the visualisation is written to; None skips saving

        Returns the highest-scoring mask predicted by SAM.
        """
        self.sam_model.set_image(frame)
        try:
            # convert to 1, 256, 256 (SAM mask-prompt resolution)
            mask_prompt = self.resizer(logits.unsqueeze(0)).cpu().numpy()
            prompts = {'mask_input': mask_prompt}
            # masks (n, h, w), scores (n,), logits (n, 256, 256)
            masks, scores, _ = self.sam_model.predict(prompts, 'mask', multimask=True)
            best_mask = masks[np.argmax(scores)]
            if save_dir is not None:
                painted_image = mask_painter(frame, best_mask.astype('uint8'), mask_alpha=0.8)
                Image.fromarray(painted_image).save(os.path.join(save_dir, f'{ti:05d}.png'))
        finally:
            # always release the embedded image, even if prediction fails
            self.sam_model.reset_image()
        return best_mask

    @torch.no_grad()
    def clear_memory(self):
        """Reset the XMem memory and the label mapper before tracking a new video."""
        self.tracker.clear_memory()
        self.mapper.clear_labels()
2023-04-14 03:13:58 +08:00
if __name__ == '__main__':
    # ----------------------------------------------------------
    # initialise tracker
    # ----------------------------------------------------------
    device = 'cuda:1'
    XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
    SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
    model_type = 'vit_h'
    tracker = BaseTracker(XMEM_checkpoint, device, SAM_checkpoint, model_type)

    # ----------------------------------------------------------
    # failure case test
    # ----------------------------------------------------------
    failure_path = '/ssd1/gaomingqi/failure'
    output_dir = os.path.join(failure_path, 'LJ')
    # make sure the output directory exists before saving frames
    os.makedirs(output_dir, exist_ok=True)

    # presumably (N, H, W, 3) uint8 frames — confirm against the producer
    frames = np.load(os.path.join(failure_path, 'video_frames.npy'))
    first_mask = np.array(
        Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P')
    )
    # collapse every labelled object into a single foreground label (0/1)
    first_mask = np.clip(first_mask, 0, 1)

    for ti, frame in enumerate(frames):
        # only the first frame carries the annotation; later frames are propagated
        annotation = first_mask if ti == 0 else None
        mask, probs, painted_image = tracker.track(frame, annotation)
        # save the visualisation
        Image.fromarray(painted_image).save(os.path.join(output_dir, f'{ti:05d}.png'))