2023-04-14 22:17:47 +08:00
|
|
|
# import for debugging
|
2023-04-14 03:13:58 +08:00
|
|
|
import os
|
|
|
|
|
import glob
|
|
|
|
|
import numpy as np
|
|
|
|
|
from PIL import Image
|
2023-04-14 22:17:47 +08:00
|
|
|
# import for base_tracker
|
2023-04-14 03:13:58 +08:00
|
|
|
import torch
|
|
|
|
|
import yaml
|
2023-04-14 12:37:38 +08:00
|
|
|
import torch.nn.functional as F
|
2023-04-14 03:13:58 +08:00
|
|
|
from model.network import XMem
|
|
|
|
|
from inference.inference_core import InferenceCore
|
2023-04-14 12:37:38 +08:00
|
|
|
from util.mask_mapper import MaskMapper
|
2023-04-14 04:10:51 +08:00
|
|
|
from torchvision import transforms
|
2023-04-14 12:37:38 +08:00
|
|
|
from util.range_transform import im_normalization
|
2023-04-14 05:37:08 +08:00
|
|
|
import sys
|
|
|
|
|
sys.path.insert(0, sys.path[0]+"/../")
|
|
|
|
|
from tools.painter import mask_painter
|
2023-04-16 16:49:38 +08:00
|
|
|
from tools.base_segmenter import BaseSegmenter
|
|
|
|
|
from torchvision.transforms import Resize
|
2023-04-14 03:13:58 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseTracker:
    """Frame-by-frame mask tracker built on an XMem network and InferenceCore.

    The first call to `track` for a video supplies an annotation mask; later
    calls propagate it. Object labels in the annotation are remapped by a
    `MaskMapper` before inference and mapped back in the returned mask.
    """

    def __init__(self, xmem_checkpoint, device, sam_model=None, model_type=None) -> None:
        """
        device: model device
        xmem_checkpoint: checkpoint of XMem model
        sam_model: optional segmenter, only used by `sam_refinement`
        model_type: accepted for interface compatibility; not used by this class
        """
        # load configurations
        # NOTE: this path is resolved relative to the current working directory.
        with open("tracker/config/config.yaml", 'r') as stream:
            config = yaml.safe_load(stream)
        # initialise XMem
        network = XMem(config, xmem_checkpoint).to(device).eval()
        # initialise InferenceCore
        self.tracker = InferenceCore(network, config)
        # data transformation
        self.im_transform = transforms.Compose([
            transforms.ToTensor(),
            im_normalization,
        ])
        self.device = device

        # changeable properties
        self.mapper = MaskMapper()
        self.initialised = False

        # SAM-based refinement: `sam_refinement` reads both attributes, so they
        # must exist even when no SAM model is supplied.
        self.sam_model = sam_model
        self.resizer = Resize([256, 256])

    @torch.no_grad()
    def resize_mask(self, mask, size=None):
        """Resize `mask` with nearest-neighbour interpolation so its shorter side equals `size`.

        mask: tensor whose last two dimensions are (H, W); passed to
            `F.interpolate`, so it needs the batch/channel layout that function expects
        size: target length of the shorter side; when omitted, `self.size` is used

        Raises AttributeError when neither `size` nor `self.size` is available.
        """
        # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
        if size is None:
            size = getattr(self, 'size', None)
        if size is None:
            # __init__ never sets `size`; fail with an explicit message instead
            # of an opaque attribute lookup error.
            raise AttributeError(
                "BaseTracker.resize_mask needs a target size: pass `size=` or set `self.size`"
            )
        h, w = mask.shape[-2:]
        min_hw = min(h, w)
        return F.interpolate(mask, (int(h/min_hw*size), int(w/min_hw*size)),
                             mode='nearest')

    @torch.no_grad()
    def track(self, frame, first_frame_annotation=None):
        """
        Input:
        frame: numpy array (H, W, 3)
        first_frame_annotation: label mask for the first frame (presumably a
            numpy array (H, W) - confirm against callers); None for later frames

        Output:
        mask: numpy array (H, W), uint8, using the labels of the annotation
        mask: the same array again (this slot is kept for interface
            compatibility; no probability map is returned)
        painted_image: numpy array (H, W, 3)
        """
        if first_frame_annotation is not None:   # first frame mask
            # initialisation
            mask, labels = self.mapper.convert_mask(first_frame_annotation)
            mask = torch.Tensor(mask).to(self.device)
            self.tracker.set_all_labels(list(self.mapper.remappings.values()))
        else:
            mask = None
            labels = None
        # prepare inputs
        frame_tensor = self.im_transform(frame).to(self.device)
        # track one frame
        probs, _ = self.tracker.step(frame_tensor, mask, labels)

        # convert to mask: argmax over dim 0 picks one (remapped) label per pixel
        out_mask = torch.argmax(probs, dim=0)
        out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)

        # map back from the tracker's internal labels to the caller's labels
        final_mask = np.zeros_like(out_mask)
        for k, v in self.mapper.remappings.items():
            final_mask[out_mask == v] = k

        # Convert to a Python int so `num_objs + 1` cannot wrap around in uint8 arithmetic.
        num_objs = int(final_mask.max())
        painted_image = frame
        for obj in range(1, num_objs + 1):
            obj_mask = final_mask == obj
            if not obj_mask.any():
                # label value not present in this frame
                continue
            painted_image = mask_painter(painted_image, obj_mask.astype('uint8'), mask_color=obj+1)

        return final_mask, final_mask, painted_image

    @torch.no_grad()
    def sam_refinement(self, frame, logits, ti, save_dir='/ssd1/gaomingqi/refine'):
        """
        refine segmentation results with mask prompt

        frame: image handed to the SAM model and painted on
        logits: tensor used as the mask prompt (presumably (H, W) - confirm);
            it is unsqueezed and resized to 256 x 256
        ti: frame index, used for the output file name
        save_dir: directory the painted result is written to (created if missing)

        Returns the mask with the highest score among the SAM predictions.
        Raises ValueError when the tracker was built without a SAM model.
        """
        if self.sam_model is None:
            raise ValueError("sam_refinement requires a sam_model; pass one to BaseTracker(...)")

        self.sam_model.set_image(frame)
        try:
            mode = 'mask'
            # convert to 1, 256, 256
            logits = self.resizer(logits.unsqueeze(0)).cpu().numpy()
            prompts = {'mask_input': logits}
            # original note: masks (n, h, w), scores (n,), logits (n, 256, 256)
            masks, scores, logits = self.sam_model.predict(prompts, mode, multimask=True)
            best_mask = masks[np.argmax(scores)]
            painted_image = mask_painter(frame, best_mask.astype('uint8'), mask_alpha=0.8)
            os.makedirs(save_dir, exist_ok=True)
            Image.fromarray(painted_image).save(os.path.join(save_dir, f'{ti:05d}.png'))
        finally:
            # always release the image held by the SAM model, even on failure
            self.sam_model.reset_image()
        return best_mask

    @torch.no_grad()
    def clear_memory(self):
        """Clear the tracker's memory and the mapper's labels (e.g. before a new video)."""
        self.tracker.clear_memory()
        self.mapper.clear_labels()
|
2023-04-14 03:13:58 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo: track the objects annotated in the first frame of a DAVIS sequence
    # and write one painted image per frame.

    # video frames (multiple objects)
    video_path_list = sorted(
        glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
    )
    if not video_path_list:
        # np.stack on an empty list fails with an unhelpful message
        raise FileNotFoundError('no *.jpg frames found for the demo sequence')
    # first frame
    first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
    # load frames
    frames = np.stack(
        [np.array(Image.open(video_path).convert('RGB')) for video_path in video_path_list],
        0,
    )  # N, H, W, C
    # load first frame annotation (palette image, so one label index per pixel: H, W)
    first_frame_annotation = np.array(Image.open(first_frame_path).convert('P'))

    # ----------------------------------------------------------
    # initialise tracker
    # ----------------------------------------------------------
    device = 'cuda:4'
    XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
    SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
    model_type = 'vit_h'

    # sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
    # The fourth argument is `model_type`, not the device.
    tracker = BaseTracker(XMEM_checkpoint, device, None, model_type)

    # # test for storage efficiency
    # frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
    # first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))

    # Relabel objects 1 and 2 to non-contiguous ids (15 and 20) before tracking.
    first_frame_annotation[first_frame_annotation == 1] = 15
    first_frame_annotation[first_frame_annotation == 2] = 20

    save_path = '/ssd1/gaomingqi/results/TrackA/multi-change1'
    # creates missing parents and tolerates an existing directory
    os.makedirs(save_path, exist_ok=True)

    for ti, frame in enumerate(frames):
        if ti == 0:
            mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
        else:
            mask, prob, painted_image = tracker.track(frame)
        # save
        painted_image = Image.fromarray(painted_image)
        painted_image.save(f'{save_path}/{ti:05d}.png')
|
|
|
|
|
|
|
|
|
|
# tracker.clear_memory()
|
|
|
|
|
# for ti, frame in enumerate(frames):
|
|
|
|
|
# print(ti)
|
|
|
|
|
# # if ti > 200:
|
|
|
|
|
# # break
|
|
|
|
|
# if ti == 0:
|
|
|
|
|
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
|
|
|
|
# else:
|
|
|
|
|
# mask, prob, painted_image = tracker.track(frame)
|
|
|
|
|
# # save
|
|
|
|
|
# painted_image = Image.fromarray(painted_image)
|
|
|
|
|
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
|
2023-04-17 19:02:32 +08:00
|
|
|
|
|
|
|
|
# # track anything given in the first frame annotation
|
|
|
|
|
# for ti, frame in enumerate(frames):
|
|
|
|
|
# if ti == 0:
|
|
|
|
|
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
|
|
|
|
# else:
|
|
|
|
|
# mask, prob, painted_image = tracker.track(frame)
|
|
|
|
|
# # save
|
|
|
|
|
# painted_image = Image.fromarray(painted_image)
|
|
|
|
|
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/horsejump-high/{ti:05d}.png')
|
2023-04-14 12:37:38 +08:00
|
|
|
|
2023-04-16 16:49:38 +08:00
|
|
|
# # ----------------------------------------------------------
|
|
|
|
|
# # another video
|
|
|
|
|
# # ----------------------------------------------------------
|
|
|
|
|
# # video frames
|
|
|
|
|
# video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/camel', '*.jpg'))
|
|
|
|
|
# video_path_list.sort()
|
|
|
|
|
# # first frame
|
|
|
|
|
# first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/camel/00000.png'
|
|
|
|
|
# # load frames
|
|
|
|
|
# frames = []
|
|
|
|
|
# for video_path in video_path_list:
|
|
|
|
|
# frames.append(np.array(Image.open(video_path).convert('RGB')))
|
|
|
|
|
# frames = np.stack(frames, 0) # N, H, W, C
|
|
|
|
|
# # load first frame annotation
|
|
|
|
|
# first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
|
2023-04-17 12:33:37 +08:00
|
|
|
|
2023-04-16 16:49:38 +08:00
|
|
|
# print('first video done. clear.')
|
|
|
|
|
|
|
|
|
|
# tracker.clear_memory()
|
|
|
|
|
# # track anything given in the first frame annotation
|
|
|
|
|
# for ti, frame in enumerate(frames):
|
|
|
|
|
# if ti == 0:
|
|
|
|
|
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
|
|
|
|
# else:
|
|
|
|
|
# mask, prob, painted_image = tracker.track(frame)
|
|
|
|
|
# # save
|
|
|
|
|
# painted_image = Image.fromarray(painted_image)
|
|
|
|
|
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png')
|
|
|
|
|
|
2023-04-17 12:33:37 +08:00
|
|
|
# # failure case test
|
|
|
|
|
# failure_path = '/ssd1/gaomingqi/failure'
|
|
|
|
|
# frames = np.load(os.path.join(failure_path, 'video_frames.npy'))
|
|
|
|
|
# # first_frame = np.array(Image.open(os.path.join(failure_path, 'template_frame.png')).convert('RGB'))
|
|
|
|
|
# first_mask = np.array(Image.open(os.path.join(failure_path, 'template_mask.png')).convert('P'))
|
|
|
|
|
# first_mask = np.clip(first_mask, 0, 1)
|
2023-04-14 10:17:41 +08:00
|
|
|
|
2023-04-17 12:33:37 +08:00
|
|
|
# for ti, frame in enumerate(frames):
|
|
|
|
|
# if ti == 0:
|
|
|
|
|
# mask, probs, painted_image = tracker.track(frame, first_mask)
|
|
|
|
|
# else:
|
|
|
|
|
# mask, probs, painted_image = tracker.track(frame)
|
|
|
|
|
# # save
|
|
|
|
|
# painted_image = Image.fromarray(painted_image)
|
|
|
|
|
# painted_image.save(f'/ssd1/gaomingqi/failure/LJ/{ti:05d}.png')
|
|
|
|
|
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
|
2023-04-16 16:49:38 +08:00
|
|
|
|
2023-04-17 12:33:37 +08:00
|
|
|
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
|
2023-04-17 19:02:32 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|