mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2026-05-18 05:05:29 +02:00
remove redundant code
This commit is contained in:
0
tracker/dataset/__init__.py
Normal file
0
tracker/dataset/__init__.py
Normal file
12
tracker/dataset/range_transform.py
Normal file
12
tracker/dataset/range_transform.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
im_mean = (124, 116, 104)
|
||||
|
||||
im_normalization = transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
|
||||
inv_im_trans = transforms.Normalize(
|
||||
mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
|
||||
std=[1/0.229, 1/0.224, 1/0.225])
|
||||
6
tracker/dataset/reseed.py
Normal file
6
tracker/dataset/reseed.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import torch
|
||||
import random
|
||||
|
||||
def reseed(seed):
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
179
tracker/dataset/static_dataset.py
Normal file
179
tracker/dataset/static_dataset.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import os
|
||||
from os import path
|
||||
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from dataset.range_transform import im_normalization, im_mean
|
||||
from dataset.tps import random_tps_warp
|
||||
from dataset.reseed import reseed
|
||||
|
||||
|
||||
class StaticTransformDataset(Dataset):
|
||||
"""
|
||||
Generate pseudo VOS data by applying random transforms on static images.
|
||||
Single-object only.
|
||||
|
||||
Method 0 - FSS style (class/1.jpg class/1.png)
|
||||
Method 1 - Others style (XXX.jpg XXX.png)
|
||||
"""
|
||||
def __init__(self, parameters, num_frames=3, max_num_obj=1):
|
||||
self.num_frames = num_frames
|
||||
self.max_num_obj = max_num_obj
|
||||
|
||||
self.im_list = []
|
||||
for parameter in parameters:
|
||||
root, method, multiplier = parameter
|
||||
if method == 0:
|
||||
# Get images
|
||||
classes = os.listdir(root)
|
||||
for c in classes:
|
||||
imgs = os.listdir(path.join(root, c))
|
||||
jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()]
|
||||
|
||||
joint_list = [path.join(root, c, im) for im in jpg_list]
|
||||
self.im_list.extend(joint_list * multiplier)
|
||||
|
||||
elif method == 1:
|
||||
self.im_list.extend([path.join(root, im) for im in os.listdir(root) if '.jpg' in im] * multiplier)
|
||||
|
||||
print(f'{len(self.im_list)} images found.')
|
||||
|
||||
# These set of transform is the same for im/gt pairs, but different among the 3 sampled frames
|
||||
self.pair_im_lone_transform = transforms.Compose([
|
||||
transforms.ColorJitter(0.1, 0.05, 0.05, 0), # No hue change here as that's not realistic
|
||||
])
|
||||
|
||||
self.pair_im_dual_transform = transforms.Compose([
|
||||
transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=im_mean),
|
||||
transforms.Resize(384, InterpolationMode.BICUBIC),
|
||||
transforms.RandomCrop((384, 384), pad_if_needed=True, fill=im_mean),
|
||||
])
|
||||
|
||||
self.pair_gt_dual_transform = transforms.Compose([
|
||||
transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=0),
|
||||
transforms.Resize(384, InterpolationMode.NEAREST),
|
||||
transforms.RandomCrop((384, 384), pad_if_needed=True, fill=0),
|
||||
])
|
||||
|
||||
|
||||
# These transform are the same for all pairs in the sampled sequence
|
||||
self.all_im_lone_transform = transforms.Compose([
|
||||
transforms.ColorJitter(0.1, 0.05, 0.05, 0.05),
|
||||
transforms.RandomGrayscale(0.05),
|
||||
])
|
||||
|
||||
self.all_im_dual_transform = transforms.Compose([
|
||||
transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=im_mean),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
])
|
||||
|
||||
self.all_gt_dual_transform = transforms.Compose([
|
||||
transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=0),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
])
|
||||
|
||||
# Final transform without randomness
|
||||
self.final_im_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
im_normalization,
|
||||
])
|
||||
|
||||
self.final_gt_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
def _get_sample(self, idx):
|
||||
im = Image.open(self.im_list[idx]).convert('RGB')
|
||||
gt = Image.open(self.im_list[idx][:-3]+'png').convert('L')
|
||||
|
||||
sequence_seed = np.random.randint(2147483647)
|
||||
|
||||
images = []
|
||||
masks = []
|
||||
for _ in range(self.num_frames):
|
||||
reseed(sequence_seed)
|
||||
this_im = self.all_im_dual_transform(im)
|
||||
this_im = self.all_im_lone_transform(this_im)
|
||||
reseed(sequence_seed)
|
||||
this_gt = self.all_gt_dual_transform(gt)
|
||||
|
||||
pairwise_seed = np.random.randint(2147483647)
|
||||
reseed(pairwise_seed)
|
||||
this_im = self.pair_im_dual_transform(this_im)
|
||||
this_im = self.pair_im_lone_transform(this_im)
|
||||
reseed(pairwise_seed)
|
||||
this_gt = self.pair_gt_dual_transform(this_gt)
|
||||
|
||||
# Use TPS only some of the times
|
||||
# Not because TPS is bad -- just that it is too slow and I need to speed up data loading
|
||||
if np.random.rand() < 0.33:
|
||||
this_im, this_gt = random_tps_warp(this_im, this_gt, scale=0.02)
|
||||
|
||||
this_im = self.final_im_transform(this_im)
|
||||
this_gt = self.final_gt_transform(this_gt)
|
||||
|
||||
images.append(this_im)
|
||||
masks.append(this_gt)
|
||||
|
||||
images = torch.stack(images, 0)
|
||||
masks = torch.stack(masks, 0)
|
||||
|
||||
return images, masks.numpy()
|
||||
|
||||
def __getitem__(self, idx):
|
||||
additional_objects = np.random.randint(self.max_num_obj)
|
||||
indices = [idx, *np.random.randint(self.__len__(), size=additional_objects)]
|
||||
|
||||
merged_images = None
|
||||
merged_masks = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
|
||||
|
||||
for i, list_id in enumerate(indices):
|
||||
images, masks = self._get_sample(list_id)
|
||||
if merged_images is None:
|
||||
merged_images = images
|
||||
else:
|
||||
merged_images = merged_images*(1-masks) + images*masks
|
||||
merged_masks[masks[:,0]>0.5] = (i+1)
|
||||
|
||||
masks = merged_masks
|
||||
|
||||
labels = np.unique(masks[0])
|
||||
# Remove background
|
||||
labels = labels[labels!=0]
|
||||
target_objects = labels.tolist()
|
||||
|
||||
# Generate one-hot ground-truth
|
||||
cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
|
||||
first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
|
||||
for i, l in enumerate(target_objects):
|
||||
this_mask = (masks==l)
|
||||
cls_gt[this_mask] = i+1
|
||||
first_frame_gt[0,i] = (this_mask[0])
|
||||
cls_gt = np.expand_dims(cls_gt, 1)
|
||||
|
||||
info = {}
|
||||
info['name'] = self.im_list[idx]
|
||||
info['num_objects'] = max(1, len(target_objects))
|
||||
|
||||
# 1 if object exist, 0 otherwise
|
||||
selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
|
||||
selector = torch.FloatTensor(selector)
|
||||
|
||||
data = {
|
||||
'rgb': merged_images,
|
||||
'first_frame_gt': first_frame_gt,
|
||||
'cls_gt': cls_gt,
|
||||
'selector': selector,
|
||||
'info': info
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.im_list)
|
||||
37
tracker/dataset/tps.py
Normal file
37
tracker/dataset/tps.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import thinplate as tps
|
||||
|
||||
cv2.setNumThreads(0)
|
||||
|
||||
def pick_random_points(h, w, n_samples):
|
||||
y_idx = np.random.choice(np.arange(h), size=n_samples, replace=False)
|
||||
x_idx = np.random.choice(np.arange(w), size=n_samples, replace=False)
|
||||
return y_idx/h, x_idx/w
|
||||
|
||||
|
||||
def warp_dual_cv(img, mask, c_src, c_dst):
|
||||
dshape = img.shape
|
||||
theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True)
|
||||
grid = tps.tps_grid(theta, c_dst, dshape)
|
||||
mapx, mapy = tps.tps_grid_to_remap(grid, img.shape)
|
||||
return cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR), cv2.remap(mask, mapx, mapy, cv2.INTER_NEAREST)
|
||||
|
||||
|
||||
def random_tps_warp(img, mask, scale, n_ctrl_pts=12):
|
||||
"""
|
||||
Apply a random TPS warp of the input image and mask
|
||||
Uses randomness from numpy
|
||||
"""
|
||||
img = np.asarray(img)
|
||||
mask = np.asarray(mask)
|
||||
|
||||
h, w = mask.shape
|
||||
points = pick_random_points(h, w, n_ctrl_pts)
|
||||
c_src = np.stack(points, 1)
|
||||
c_dst = c_src + np.random.normal(scale=scale, size=c_src.shape)
|
||||
warp_im, warp_gt = warp_dual_cv(img, mask, c_src, c_dst)
|
||||
|
||||
return Image.fromarray(warp_im), Image.fromarray(warp_gt)
|
||||
|
||||
13
tracker/dataset/util.py
Normal file
13
tracker/dataset/util.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def all_to_onehot(masks, labels):
|
||||
if len(masks.shape) == 3:
|
||||
Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
|
||||
else:
|
||||
Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
|
||||
|
||||
for ni, l in enumerate(labels):
|
||||
Ms[ni] = (masks == l).astype(np.uint8)
|
||||
|
||||
return Ms
|
||||
216
tracker/dataset/vos_dataset.py
Normal file
216
tracker/dataset/vos_dataset.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import os
|
||||
from os import path, replace
|
||||
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from dataset.range_transform import im_normalization, im_mean
|
||||
from dataset.reseed import reseed
|
||||
|
||||
|
||||
class VOSDataset(Dataset):
|
||||
"""
|
||||
Works for DAVIS/YouTubeVOS/BL30K training
|
||||
For each sequence:
|
||||
- Pick three frames
|
||||
- Pick two objects
|
||||
- Apply some random transforms that are the same for all frames
|
||||
- Apply random transform to each of the frame
|
||||
- The distance between frames is controlled
|
||||
"""
|
||||
def __init__(self, im_root, gt_root, max_jump, is_bl, subset=None, num_frames=3, max_num_obj=3, finetune=False):
|
||||
self.im_root = im_root
|
||||
self.gt_root = gt_root
|
||||
self.max_jump = max_jump
|
||||
self.is_bl = is_bl
|
||||
self.num_frames = num_frames
|
||||
self.max_num_obj = max_num_obj
|
||||
|
||||
self.videos = []
|
||||
self.frames = {}
|
||||
|
||||
vid_list = sorted(os.listdir(self.im_root))
|
||||
# Pre-filtering
|
||||
for vid in vid_list:
|
||||
if subset is not None:
|
||||
if vid not in subset:
|
||||
continue
|
||||
frames = sorted(os.listdir(os.path.join(self.im_root, vid)))
|
||||
if len(frames) < num_frames:
|
||||
continue
|
||||
self.frames[vid] = frames
|
||||
self.videos.append(vid)
|
||||
|
||||
print('%d out of %d videos accepted in %s.' % (len(self.videos), len(vid_list), im_root))
|
||||
|
||||
# These set of transform is the same for im/gt pairs, but different among the 3 sampled frames
|
||||
self.pair_im_lone_transform = transforms.Compose([
|
||||
transforms.ColorJitter(0.01, 0.01, 0.01, 0),
|
||||
])
|
||||
|
||||
self.pair_im_dual_transform = transforms.Compose([
|
||||
transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.BILINEAR, fill=im_mean),
|
||||
])
|
||||
|
||||
self.pair_gt_dual_transform = transforms.Compose([
|
||||
transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.NEAREST, fill=0),
|
||||
])
|
||||
|
||||
# These transform are the same for all pairs in the sampled sequence
|
||||
self.all_im_lone_transform = transforms.Compose([
|
||||
transforms.ColorJitter(0.1, 0.03, 0.03, 0),
|
||||
transforms.RandomGrayscale(0.05),
|
||||
])
|
||||
|
||||
if self.is_bl:
|
||||
# Use a different cropping scheme for the blender dataset because the image size is different
|
||||
self.all_im_dual_transform = transforms.Compose([
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.BILINEAR)
|
||||
])
|
||||
|
||||
self.all_gt_dual_transform = transforms.Compose([
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.NEAREST)
|
||||
])
|
||||
else:
|
||||
self.all_im_dual_transform = transforms.Compose([
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.BILINEAR)
|
||||
])
|
||||
|
||||
self.all_gt_dual_transform = transforms.Compose([
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.NEAREST)
|
||||
])
|
||||
|
||||
# Final transform without randomness
|
||||
self.final_im_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
im_normalization,
|
||||
])
|
||||
|
||||
def __getitem__(self, idx):
|
||||
video = self.videos[idx]
|
||||
info = {}
|
||||
info['name'] = video
|
||||
|
||||
vid_im_path = path.join(self.im_root, video)
|
||||
vid_gt_path = path.join(self.gt_root, video)
|
||||
frames = self.frames[video]
|
||||
|
||||
trials = 0
|
||||
while trials < 5:
|
||||
info['frames'] = [] # Appended with actual frames
|
||||
|
||||
num_frames = self.num_frames
|
||||
length = len(frames)
|
||||
this_max_jump = min(len(frames), self.max_jump)
|
||||
|
||||
# iterative sampling
|
||||
frames_idx = [np.random.randint(length)]
|
||||
acceptable_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1))).difference(set(frames_idx))
|
||||
while(len(frames_idx) < num_frames):
|
||||
idx = np.random.choice(list(acceptable_set))
|
||||
frames_idx.append(idx)
|
||||
new_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1)))
|
||||
acceptable_set = acceptable_set.union(new_set).difference(set(frames_idx))
|
||||
|
||||
frames_idx = sorted(frames_idx)
|
||||
if np.random.rand() < 0.5:
|
||||
# Reverse time
|
||||
frames_idx = frames_idx[::-1]
|
||||
|
||||
sequence_seed = np.random.randint(2147483647)
|
||||
images = []
|
||||
masks = []
|
||||
target_objects = []
|
||||
for f_idx in frames_idx:
|
||||
jpg_name = frames[f_idx][:-4] + '.jpg'
|
||||
png_name = frames[f_idx][:-4] + '.png'
|
||||
info['frames'].append(jpg_name)
|
||||
|
||||
reseed(sequence_seed)
|
||||
this_im = Image.open(path.join(vid_im_path, jpg_name)).convert('RGB')
|
||||
this_im = self.all_im_dual_transform(this_im)
|
||||
this_im = self.all_im_lone_transform(this_im)
|
||||
reseed(sequence_seed)
|
||||
this_gt = Image.open(path.join(vid_gt_path, png_name)).convert('P')
|
||||
this_gt = self.all_gt_dual_transform(this_gt)
|
||||
|
||||
pairwise_seed = np.random.randint(2147483647)
|
||||
reseed(pairwise_seed)
|
||||
this_im = self.pair_im_dual_transform(this_im)
|
||||
this_im = self.pair_im_lone_transform(this_im)
|
||||
reseed(pairwise_seed)
|
||||
this_gt = self.pair_gt_dual_transform(this_gt)
|
||||
|
||||
this_im = self.final_im_transform(this_im)
|
||||
this_gt = np.array(this_gt)
|
||||
|
||||
images.append(this_im)
|
||||
masks.append(this_gt)
|
||||
|
||||
images = torch.stack(images, 0)
|
||||
|
||||
labels = np.unique(masks[0])
|
||||
# Remove background
|
||||
labels = labels[labels!=0]
|
||||
|
||||
if self.is_bl:
|
||||
# Find large enough labels
|
||||
good_lables = []
|
||||
for l in labels:
|
||||
pixel_sum = (masks[0]==l).sum()
|
||||
if pixel_sum > 10*10:
|
||||
# OK if the object is always this small
|
||||
# Not OK if it is actually much bigger
|
||||
if pixel_sum > 30*30:
|
||||
good_lables.append(l)
|
||||
elif max((masks[1]==l).sum(), (masks[2]==l).sum()) < 20*20:
|
||||
good_lables.append(l)
|
||||
labels = np.array(good_lables, dtype=np.uint8)
|
||||
|
||||
if len(labels) == 0:
|
||||
target_objects = []
|
||||
trials += 1
|
||||
else:
|
||||
target_objects = labels.tolist()
|
||||
break
|
||||
|
||||
if len(target_objects) > self.max_num_obj:
|
||||
target_objects = np.random.choice(target_objects, size=self.max_num_obj, replace=False)
|
||||
|
||||
info['num_objects'] = max(1, len(target_objects))
|
||||
|
||||
masks = np.stack(masks, 0)
|
||||
|
||||
# Generate one-hot ground-truth
|
||||
cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
|
||||
first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
|
||||
for i, l in enumerate(target_objects):
|
||||
this_mask = (masks==l)
|
||||
cls_gt[this_mask] = i+1
|
||||
first_frame_gt[0,i] = (this_mask[0])
|
||||
cls_gt = np.expand_dims(cls_gt, 1)
|
||||
|
||||
# 1 if object exist, 0 otherwise
|
||||
selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
|
||||
selector = torch.FloatTensor(selector)
|
||||
|
||||
data = {
|
||||
'rgb': images,
|
||||
'first_frame_gt': first_frame_gt,
|
||||
'cls_gt': cls_gt,
|
||||
'selector': selector,
|
||||
'info': info,
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.videos)
|
||||
260
tracker/eval.py
Normal file
260
tracker/eval.py
Normal file
@@ -0,0 +1,260 @@
|
||||
import os
|
||||
from os import path
|
||||
from argparse import ArgumentParser
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset
|
||||
from inference.data.mask_mapper import MaskMapper
|
||||
from model.network import XMem
|
||||
from inference.inference_core import InferenceCore
|
||||
|
||||
from progressbar import progressbar
|
||||
|
||||
try:
|
||||
import hickle as hkl
|
||||
except ImportError:
|
||||
print('Failed to import hickle. Fine if not using multi-scale testing.')
|
||||
|
||||
|
||||
"""
|
||||
Arguments loading
|
||||
"""
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--model', default='/ssd1/gaomingqi/checkpoints/XMem-s012.pth')
|
||||
|
||||
# Data options
|
||||
parser.add_argument('--d16_path', default='../DAVIS/2016')
|
||||
parser.add_argument('--d17_path', default='../DAVIS/2017')
|
||||
parser.add_argument('--y18_path', default='/ssd1/gaomingqi/datasets/youtube-vos/2018')
|
||||
parser.add_argument('--y19_path', default='../YouTube')
|
||||
parser.add_argument('--lv_path', default='../long_video_set')
|
||||
# For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations"
|
||||
parser.add_argument('--generic_path')
|
||||
|
||||
parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='D17')
|
||||
parser.add_argument('--split', help='val/test', default='val')
|
||||
parser.add_argument('--output', default=None)
|
||||
parser.add_argument('--save_all', action='store_true',
|
||||
help='Save all frames. Useful only in YouTubeVOS/long-time video', )
|
||||
|
||||
parser.add_argument('--benchmark', action='store_true', help='enable to disable amp for FPS benchmarking')
|
||||
|
||||
# Long-term memory options
|
||||
parser.add_argument('--disable_long_term', action='store_true')
|
||||
parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
|
||||
parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
|
||||
parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
|
||||
type=int, default=10000)
|
||||
parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)
|
||||
|
||||
parser.add_argument('--top_k', type=int, default=30)
|
||||
parser.add_argument('--mem_every', help='r in paper. Increase to improve running speed.', type=int, default=5)
|
||||
parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
|
||||
|
||||
# Multi-scale options
|
||||
parser.add_argument('--save_scores', action='store_true')
|
||||
parser.add_argument('--flip', action='store_true')
|
||||
parser.add_argument('--size', default=480, type=int,
|
||||
help='Resize the shorter side to this size. -1 to use original resolution. ')
|
||||
|
||||
args = parser.parse_args()
|
||||
config = vars(args)
|
||||
config['enable_long_term'] = not config['disable_long_term']
|
||||
|
||||
if args.output is None:
|
||||
args.output = f'../output/{args.dataset}_{args.split}'
|
||||
print(f'Output path not provided. Defaulting to {args.output}')
|
||||
|
||||
"""
|
||||
Data preparation
|
||||
"""
|
||||
is_youtube = args.dataset.startswith('Y')
|
||||
is_davis = args.dataset.startswith('D')
|
||||
is_lv = args.dataset.startswith('LV')
|
||||
|
||||
if is_youtube or args.save_scores:
|
||||
out_path = path.join(args.output, 'Annotations')
|
||||
else:
|
||||
out_path = args.output
|
||||
|
||||
if is_youtube:
|
||||
if args.dataset == 'Y18':
|
||||
yv_path = args.y18_path
|
||||
elif args.dataset == 'Y19':
|
||||
yv_path = args.y19_path
|
||||
|
||||
if args.split == 'val':
|
||||
args.split = 'valid'
|
||||
meta_dataset = YouTubeVOSTestDataset(data_root=yv_path, split='valid', size=args.size)
|
||||
elif args.split == 'test':
|
||||
meta_dataset = YouTubeVOSTestDataset(data_root=yv_path, split='test', size=args.size)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
elif is_davis:
|
||||
if args.dataset == 'D16':
|
||||
if args.split == 'val':
|
||||
# Set up Dataset, a small hack to use the image set in the 2017 folder because the 2016 one is of a different format
|
||||
meta_dataset = DAVISTestDataset(args.d16_path, imset='../../2017/trainval/ImageSets/2016/val.txt', size=args.size)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
palette = None
|
||||
elif args.dataset == 'D17':
|
||||
if args.split == 'val':
|
||||
meta_dataset = DAVISTestDataset(path.join(args.d17_path, 'trainval'), imset='2017/val.txt', size=args.size)
|
||||
elif args.split == 'test':
|
||||
meta_dataset = DAVISTestDataset(path.join(args.d17_path, 'test-dev'), imset='2017/test-dev.txt', size=args.size)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
elif is_lv:
|
||||
if args.dataset == 'LV1':
|
||||
meta_dataset = LongTestDataset(path.join(args.lv_path, 'long_video'))
|
||||
elif args.dataset == 'LV3':
|
||||
meta_dataset = LongTestDataset(path.join(args.lv_path, 'long_video_x3'))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
elif args.dataset == 'G':
|
||||
meta_dataset = LongTestDataset(path.join(args.generic_path), size=args.size)
|
||||
if not args.save_all:
|
||||
args.save_all = True
|
||||
print('save_all is forced to be true in generic evaluation mode.')
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
torch.autograd.set_grad_enabled(False)
|
||||
|
||||
# Set up loader
|
||||
meta_loader = meta_dataset.get_datasets()
|
||||
|
||||
# Load our checkpoint
|
||||
network = XMem(config, args.model).cuda().eval()
|
||||
if args.model is not None:
|
||||
model_weights = torch.load(args.model)
|
||||
network.load_weights(model_weights, init_as_zero_if_needed=True)
|
||||
else:
|
||||
print('No model loaded.')
|
||||
|
||||
total_process_time = 0
|
||||
total_frames = 0
|
||||
|
||||
# Start eval
|
||||
for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
|
||||
|
||||
loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2)
|
||||
vid_name = vid_reader.vid_name
|
||||
vid_length = len(loader)
|
||||
# no need to count usage for LT if the video is not that long anyway
|
||||
config['enable_long_term_count_usage'] = (
|
||||
config['enable_long_term'] and
|
||||
(vid_length
|
||||
/ (config['max_mid_term_frames']-config['min_mid_term_frames'])
|
||||
* config['num_prototypes'])
|
||||
>= config['max_long_term_elements']
|
||||
)
|
||||
|
||||
mapper = MaskMapper()
|
||||
processor = InferenceCore(network, config=config)
|
||||
first_mask_loaded = False
|
||||
|
||||
for ti, data in enumerate(loader):
|
||||
with torch.cuda.amp.autocast(enabled=not args.benchmark):
|
||||
rgb = data['rgb'].cuda()[0]
|
||||
msk = data.get('mask')
|
||||
info = data['info']
|
||||
frame = info['frame'][0]
|
||||
shape = info['shape']
|
||||
need_resize = info['need_resize'][0]
|
||||
|
||||
"""
|
||||
For timing see https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964
|
||||
Seems to be very similar in testing as my previous timing method
|
||||
with two cuda sync + time.time() in STCN though
|
||||
"""
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start.record()
|
||||
|
||||
if not first_mask_loaded:
|
||||
if msk is not None:
|
||||
first_mask_loaded = True
|
||||
else:
|
||||
# no point to do anything without a mask
|
||||
continue
|
||||
|
||||
if args.flip:
|
||||
rgb = torch.flip(rgb, dims=[-1])
|
||||
msk = torch.flip(msk, dims=[-1]) if msk is not None else None
|
||||
|
||||
# Map possibly non-continuous labels to continuous ones
|
||||
if msk is not None:
|
||||
msk, labels = mapper.convert_mask(msk[0].numpy())
|
||||
msk = torch.Tensor(msk).cuda()
|
||||
if need_resize:
|
||||
msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
|
||||
processor.set_all_labels(list(mapper.remappings.values()))
|
||||
else:
|
||||
labels = None
|
||||
|
||||
# Run the model on this frame
|
||||
prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1))
|
||||
|
||||
# consider prob as prompt to refine segment results
|
||||
|
||||
|
||||
# Upsample to original size if needed
|
||||
if need_resize:
|
||||
prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0]
|
||||
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
total_process_time += (start.elapsed_time(end)/1000)
|
||||
total_frames += 1
|
||||
|
||||
if args.flip:
|
||||
prob = torch.flip(prob, dims=[-1])
|
||||
|
||||
# Probability mask -> index mask
|
||||
out_mask = torch.argmax(prob, dim=0)
|
||||
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
|
||||
|
||||
if args.save_scores:
|
||||
prob = (prob.detach().cpu().numpy()*255).astype(np.uint8)
|
||||
|
||||
# Save the mask
|
||||
if args.save_all or info['save'][0]:
|
||||
this_out_path = path.join(out_path, vid_name)
|
||||
os.makedirs(this_out_path, exist_ok=True)
|
||||
out_mask = mapper.remap_index_mask(out_mask)
|
||||
out_img = Image.fromarray(out_mask)
|
||||
if vid_reader.get_palette() is not None:
|
||||
out_img.putpalette(vid_reader.get_palette())
|
||||
out_img.save(os.path.join(this_out_path, frame[:-4]+'.png'))
|
||||
|
||||
if args.save_scores:
|
||||
np_path = path.join(args.output, 'Scores', vid_name)
|
||||
os.makedirs(np_path, exist_ok=True)
|
||||
if ti==len(loader)-1:
|
||||
hkl.dump(mapper.remappings, path.join(np_path, f'backward.hkl'), mode='w')
|
||||
if args.save_all or info['save'][0]:
|
||||
hkl.dump(prob, path.join(np_path, f'{frame[:-4]}.hkl'), mode='w', compression='lzf')
|
||||
|
||||
|
||||
print(f'Total processing time: {total_process_time}')
|
||||
print(f'Total processed frames: {total_frames}')
|
||||
print(f'FPS: {total_frames / total_process_time}')
|
||||
print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')
|
||||
|
||||
if not args.save_scores:
|
||||
if is_youtube:
|
||||
print('Making zip for YouTubeVOS...')
|
||||
shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
|
||||
elif is_davis and args.split == 'test':
|
||||
print('Making zip for DAVIS test-dev...')
|
||||
shutil.make_archive(args.output, 'zip', args.output)
|
||||
0
tracker/inference/__init__.py
Normal file
0
tracker/inference/__init__.py
Normal file
0
tracker/inference/data/__init__.py
Normal file
0
tracker/inference/data/__init__.py
Normal file
64
tracker/inference/data/mask_mapper.py
Normal file
64
tracker/inference/data/mask_mapper.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from dataset.util import all_to_onehot
|
||||
|
||||
|
||||
class MaskMapper:
|
||||
"""
|
||||
This class is used to convert a indexed-mask to a one-hot representation.
|
||||
It also takes care of remapping non-continuous indices
|
||||
It has two modes:
|
||||
1. Default. Only masks with new indices are supposed to go into the remapper.
|
||||
This is also the case for YouTubeVOS.
|
||||
i.e., regions with index 0 are not "background", but "don't care".
|
||||
|
||||
2. Exhaustive. Regions with index 0 are considered "background".
|
||||
Every single pixel is considered to be "labeled".
|
||||
"""
|
||||
def __init__(self):
|
||||
self.labels = []
|
||||
self.remappings = {}
|
||||
|
||||
# if coherent, no mapping is required
|
||||
self.coherent = True
|
||||
|
||||
def convert_mask(self, mask, exhaustive=False):
|
||||
# mask is in index representation, H*W numpy array
|
||||
labels = np.unique(mask).astype(np.uint8)
|
||||
labels = labels[labels!=0].tolist()
|
||||
|
||||
new_labels = list(set(labels) - set(self.labels))
|
||||
if not exhaustive:
|
||||
assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode'
|
||||
|
||||
# add new remappings
|
||||
for i, l in enumerate(new_labels):
|
||||
self.remappings[l] = i+len(self.labels)+1
|
||||
if self.coherent and i+len(self.labels)+1 != l:
|
||||
self.coherent = False
|
||||
|
||||
if exhaustive:
|
||||
new_mapped_labels = range(1, len(self.labels)+len(new_labels)+1)
|
||||
else:
|
||||
if self.coherent:
|
||||
new_mapped_labels = new_labels
|
||||
else:
|
||||
new_mapped_labels = range(len(self.labels)+1, len(self.labels)+len(new_labels)+1)
|
||||
|
||||
self.labels.extend(new_labels)
|
||||
mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()
|
||||
|
||||
# mask num_objects*H*W
|
||||
return mask, new_mapped_labels
|
||||
|
||||
|
||||
def remap_index_mask(self, mask):
|
||||
# mask is in index representation, H*W numpy array
|
||||
if self.coherent:
|
||||
return mask
|
||||
|
||||
new_mask = np.zeros_like(mask)
|
||||
for l, i in self.remappings.items():
|
||||
new_mask[mask==i] = l
|
||||
return new_mask
|
||||
96
tracker/inference/data/test_datasets.py
Normal file
96
tracker/inference/data/test_datasets.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import os
|
||||
from os import path
|
||||
import json
|
||||
|
||||
from inference.data.video_reader import VideoReader
|
||||
|
||||
|
||||
class LongTestDataset:
|
||||
def __init__(self, data_root, size=-1):
|
||||
self.image_dir = path.join(data_root, 'JPEGImages')
|
||||
self.mask_dir = path.join(data_root, 'Annotations')
|
||||
self.size = size
|
||||
|
||||
self.vid_list = sorted(os.listdir(self.image_dir))
|
||||
|
||||
def get_datasets(self):
|
||||
for video in self.vid_list:
|
||||
yield VideoReader(video,
|
||||
path.join(self.image_dir, video),
|
||||
path.join(self.mask_dir, video),
|
||||
to_save = [
|
||||
name[:-4] for name in os.listdir(path.join(self.mask_dir, video))
|
||||
],
|
||||
size=self.size,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.vid_list)
|
||||
|
||||
|
||||
class DAVISTestDataset:
|
||||
def __init__(self, data_root, imset='2017/val.txt', size=-1):
|
||||
if size != 480:
|
||||
self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution')
|
||||
self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution')
|
||||
if not path.exists(self.image_dir):
|
||||
print(f'{self.image_dir} not found. Look at other options.')
|
||||
self.image_dir = path.join(data_root, 'JPEGImages', '1080p')
|
||||
self.mask_dir = path.join(data_root, 'Annotations', '1080p')
|
||||
assert path.exists(self.image_dir), 'path not found'
|
||||
else:
|
||||
self.image_dir = path.join(data_root, 'JPEGImages', '480p')
|
||||
self.mask_dir = path.join(data_root, 'Annotations', '480p')
|
||||
self.size_dir = path.join(data_root, 'JPEGImages', '480p')
|
||||
self.size = size
|
||||
|
||||
with open(path.join(data_root, 'ImageSets', imset)) as f:
|
||||
self.vid_list = sorted([line.strip() for line in f])
|
||||
|
||||
def get_datasets(self):
|
||||
for video in self.vid_list:
|
||||
yield VideoReader(video,
|
||||
path.join(self.image_dir, video),
|
||||
path.join(self.mask_dir, video),
|
||||
size=self.size,
|
||||
size_dir=path.join(self.size_dir, video),
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.vid_list)
|
||||
|
||||
|
||||
class YouTubeVOSTestDataset:
|
||||
def __init__(self, data_root, split, size=480):
|
||||
self.image_dir = path.join(data_root, 'all_frames', split+'_all_frames', 'JPEGImages')
|
||||
self.mask_dir = path.join(data_root, split, 'Annotations')
|
||||
self.size = size
|
||||
|
||||
self.vid_list = sorted(os.listdir(self.image_dir))
|
||||
self.req_frame_list = {}
|
||||
|
||||
with open(path.join(data_root, split, 'meta.json')) as f:
|
||||
# read meta.json to know which frame is required for evaluation
|
||||
meta = json.load(f)['videos']
|
||||
|
||||
for vid in self.vid_list:
|
||||
req_frames = []
|
||||
objects = meta[vid]['objects']
|
||||
for value in objects.values():
|
||||
req_frames.extend(value['frames'])
|
||||
|
||||
req_frames = list(set(req_frames))
|
||||
self.req_frame_list[vid] = req_frames
|
||||
|
||||
def get_datasets(self):
|
||||
for video in self.vid_list:
|
||||
yield VideoReader(video,
|
||||
path.join(self.image_dir, video),
|
||||
path.join(self.mask_dir, video),
|
||||
size=self.size,
|
||||
to_save=self.req_frame_list[video],
|
||||
use_all_mask=True
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.vid_list)
|
||||
100
tracker/inference/data/video_reader.py
Normal file
100
tracker/inference/data/video_reader.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import os
|
||||
from os import path
|
||||
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from dataset.range_transform import im_normalization
|
||||
|
||||
|
||||
class VideoReader(Dataset):
|
||||
"""
|
||||
This class is used to read a video, one frame at a time
|
||||
"""
|
||||
def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None):
|
||||
"""
|
||||
image_dir - points to a directory of jpg images
|
||||
mask_dir - points to a directory of png masks
|
||||
size - resize min. side to size. Does nothing if <0.
|
||||
to_save - optionally contains a list of file names without extensions
|
||||
where the segmentation mask is required
|
||||
use_all_mask - when true, read all available mask in mask_dir.
|
||||
Default false. Set to true for YouTubeVOS validation.
|
||||
"""
|
||||
self.vid_name = vid_name
|
||||
self.image_dir = image_dir
|
||||
self.mask_dir = mask_dir
|
||||
self.to_save = to_save
|
||||
self.use_all_mask = use_all_mask
|
||||
if size_dir is None:
|
||||
self.size_dir = self.image_dir
|
||||
else:
|
||||
self.size_dir = size_dir
|
||||
|
||||
self.frames = sorted(os.listdir(self.image_dir))
|
||||
self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette()
|
||||
self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0])
|
||||
|
||||
if size < 0:
|
||||
self.im_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
im_normalization,
|
||||
])
|
||||
else:
|
||||
self.im_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
im_normalization,
|
||||
transforms.Resize(size, interpolation=InterpolationMode.BILINEAR),
|
||||
])
|
||||
self.size = size
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
frame = self.frames[idx]
|
||||
info = {}
|
||||
data = {}
|
||||
info['frame'] = frame
|
||||
info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)
|
||||
|
||||
im_path = path.join(self.image_dir, frame)
|
||||
img = Image.open(im_path).convert('RGB')
|
||||
|
||||
if self.image_dir == self.size_dir:
|
||||
shape = np.array(img).shape[:2]
|
||||
else:
|
||||
size_path = path.join(self.size_dir, frame)
|
||||
size_im = Image.open(size_path).convert('RGB')
|
||||
shape = np.array(size_im).shape[:2]
|
||||
|
||||
gt_path = path.join(self.mask_dir, frame[:-4]+'.png')
|
||||
img = self.im_transform(img)
|
||||
|
||||
load_mask = self.use_all_mask or (gt_path == self.first_gt_path)
|
||||
if load_mask and path.exists(gt_path):
|
||||
mask = Image.open(gt_path).convert('P')
|
||||
mask = np.array(mask, dtype=np.uint8)
|
||||
data['mask'] = mask
|
||||
|
||||
info['shape'] = shape
|
||||
info['need_resize'] = not (self.size < 0)
|
||||
data['rgb'] = img
|
||||
data['info'] = info
|
||||
|
||||
return data
|
||||
|
||||
def resize_mask(self, mask):
|
||||
# mask transform is applied AFTER mapper, so we need to post-process it in eval.py
|
||||
h, w = mask.shape[-2:]
|
||||
min_hw = min(h, w)
|
||||
return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
|
||||
mode='nearest')
|
||||
|
||||
def get_palette(self):
|
||||
return self.palette
|
||||
|
||||
def __len__(self):
|
||||
return len(self.frames)
|
||||
110
tracker/inference/inference_core.py
Normal file
110
tracker/inference/inference_core.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from inference.memory_manager import MemoryManager
|
||||
from model.network import XMem
|
||||
from model.aggregate import aggregate
|
||||
|
||||
from util.tensor_util import pad_divide_by, unpad
|
||||
|
||||
|
||||
class InferenceCore:
|
||||
def __init__(self, network:XMem, config):
|
||||
self.config = config
|
||||
self.network = network
|
||||
self.mem_every = config['mem_every']
|
||||
self.deep_update_every = config['deep_update_every']
|
||||
self.enable_long_term = config['enable_long_term']
|
||||
|
||||
# if deep_update_every < 0, synchronize deep update with memory frame
|
||||
self.deep_update_sync = (self.deep_update_every < 0)
|
||||
|
||||
self.clear_memory()
|
||||
self.all_labels = None
|
||||
|
||||
def clear_memory(self):
|
||||
self.curr_ti = -1
|
||||
self.last_mem_ti = 0
|
||||
if not self.deep_update_sync:
|
||||
self.last_deep_update_ti = -self.deep_update_every
|
||||
self.memory = MemoryManager(config=self.config)
|
||||
|
||||
def update_config(self, config):
|
||||
self.mem_every = config['mem_every']
|
||||
self.deep_update_every = config['deep_update_every']
|
||||
self.enable_long_term = config['enable_long_term']
|
||||
|
||||
# if deep_update_every < 0, synchronize deep update with memory frame
|
||||
self.deep_update_sync = (self.deep_update_every < 0)
|
||||
self.memory.update_config(config)
|
||||
|
||||
def set_all_labels(self, all_labels):
|
||||
# self.all_labels = [l.item() for l in all_labels]
|
||||
self.all_labels = all_labels
|
||||
|
||||
def step(self, image, mask=None, valid_labels=None, end=False):
|
||||
# image: 3*H*W
|
||||
# mask: num_objects*H*W or None
|
||||
self.curr_ti += 1
|
||||
image, self.pad = pad_divide_by(image, 16)
|
||||
image = image.unsqueeze(0) # add the batch dimension
|
||||
|
||||
is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end)
|
||||
need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels)))
|
||||
is_deep_update = (
|
||||
(self.deep_update_sync and is_mem_frame) or # synchronized
|
||||
(not self.deep_update_sync and self.curr_ti-self.last_deep_update_ti >= self.deep_update_every) # no-sync
|
||||
) and (not end)
|
||||
is_normal_update = (not self.deep_update_sync or not is_deep_update) and (not end)
|
||||
|
||||
key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image,
|
||||
need_ek=(self.enable_long_term or need_segment),
|
||||
need_sk=is_mem_frame)
|
||||
multi_scale_features = (f16, f8, f4)
|
||||
|
||||
# segment the current frame is needed
|
||||
if need_segment:
|
||||
memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
|
||||
|
||||
|
||||
|
||||
hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout,
|
||||
self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False)
|
||||
# remove batch dim
|
||||
pred_prob_with_bg = pred_prob_with_bg[0]
|
||||
pred_prob_no_bg = pred_prob_with_bg[1:]
|
||||
if is_normal_update:
|
||||
self.memory.set_hidden(hidden)
|
||||
else:
|
||||
pred_prob_no_bg = pred_prob_with_bg = None
|
||||
|
||||
# use the input mask if any
|
||||
if mask is not None:
|
||||
mask, _ = pad_divide_by(mask, 16)
|
||||
|
||||
if pred_prob_no_bg is not None:
|
||||
# if we have a predicted mask, we work on it
|
||||
# make pred_prob_no_bg consistent with the input mask
|
||||
mask_regions = (mask.sum(0) > 0.5)
|
||||
pred_prob_no_bg[:, mask_regions] = 0
|
||||
# shift by 1 because mask/pred_prob_no_bg do not contain background
|
||||
mask = mask.type_as(pred_prob_no_bg)
|
||||
if valid_labels is not None:
|
||||
shift_by_one_non_labels = [i for i in range(pred_prob_no_bg.shape[0]) if (i+1) not in valid_labels]
|
||||
# non-labelled objects are copied from the predicted mask
|
||||
mask[shift_by_one_non_labels] = pred_prob_no_bg[shift_by_one_non_labels]
|
||||
pred_prob_with_bg = aggregate(mask, dim=0)
|
||||
|
||||
# also create new hidden states
|
||||
self.memory.create_hidden_state(len(self.all_labels), key)
|
||||
|
||||
# save as memory if needed
|
||||
if is_mem_frame:
|
||||
value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(),
|
||||
pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=is_deep_update)
|
||||
self.memory.add_memory(key, shrinkage, value, self.all_labels,
|
||||
selection=selection if self.enable_long_term else None)
|
||||
self.last_mem_ti = self.curr_ti
|
||||
|
||||
if is_deep_update:
|
||||
self.memory.set_hidden(hidden)
|
||||
self.last_deep_update_ti = self.curr_ti
|
||||
|
||||
return unpad(pred_prob_with_bg, self.pad)
|
||||
215
tracker/inference/kv_memory_store.py
Normal file
215
tracker/inference/kv_memory_store.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import torch
|
||||
from typing import List
|
||||
|
||||
class KeyValueMemoryStore:
|
||||
"""
|
||||
Works for key/value pairs type storage
|
||||
e.g., working and long-term memory
|
||||
"""
|
||||
|
||||
"""
|
||||
An object group is created when new objects enter the video
|
||||
Objects in the same group share the same temporal extent
|
||||
i.e., objects initialized in the same frame are in the same group
|
||||
For DAVIS/interactive, there is only one object group
|
||||
For YouTubeVOS, there can be multiple object groups
|
||||
"""
|
||||
|
||||
def __init__(self, count_usage: bool):
|
||||
self.count_usage = count_usage
|
||||
|
||||
# keys are stored in a single tensor and are shared between groups/objects
|
||||
# values are stored as a list indexed by object groups
|
||||
self.k = None
|
||||
self.v = []
|
||||
self.obj_groups = []
|
||||
# for debugging only
|
||||
self.all_objects = []
|
||||
|
||||
# shrinkage and selection are also single tensors
|
||||
self.s = self.e = None
|
||||
|
||||
# usage
|
||||
if self.count_usage:
|
||||
self.use_count = self.life_count = None
|
||||
|
||||
def add(self, key, value, shrinkage, selection, objects: List[int]):
|
||||
new_count = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32)
|
||||
new_life = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32) + 1e-7
|
||||
|
||||
# add the key
|
||||
if self.k is None:
|
||||
self.k = key
|
||||
self.s = shrinkage
|
||||
self.e = selection
|
||||
if self.count_usage:
|
||||
self.use_count = new_count
|
||||
self.life_count = new_life
|
||||
else:
|
||||
self.k = torch.cat([self.k, key], -1)
|
||||
if shrinkage is not None:
|
||||
self.s = torch.cat([self.s, shrinkage], -1)
|
||||
if selection is not None:
|
||||
self.e = torch.cat([self.e, selection], -1)
|
||||
if self.count_usage:
|
||||
self.use_count = torch.cat([self.use_count, new_count], -1)
|
||||
self.life_count = torch.cat([self.life_count, new_life], -1)
|
||||
|
||||
# add the value
|
||||
if objects is not None:
|
||||
# When objects is given, v is a tensor; used in working memory
|
||||
assert isinstance(value, torch.Tensor)
|
||||
# First consume objects that are already in the memory bank
|
||||
# cannot use set here because we need to preserve order
|
||||
# shift by one as background is not part of value
|
||||
remaining_objects = [obj-1 for obj in objects]
|
||||
for gi, group in enumerate(self.obj_groups):
|
||||
for obj in group:
|
||||
# should properly raise an error if there are overlaps in obj_groups
|
||||
remaining_objects.remove(obj)
|
||||
self.v[gi] = torch.cat([self.v[gi], value[group]], -1)
|
||||
|
||||
# If there are remaining objects, add them as a new group
|
||||
if len(remaining_objects) > 0:
|
||||
new_group = list(remaining_objects)
|
||||
self.v.append(value[new_group])
|
||||
self.obj_groups.append(new_group)
|
||||
self.all_objects.extend(new_group)
|
||||
|
||||
assert sorted(self.all_objects) == self.all_objects, 'Objects MUST be inserted in sorted order '
|
||||
else:
|
||||
# When objects is not given, v is a list that already has the object groups sorted
|
||||
# used in long-term memory
|
||||
assert isinstance(value, list)
|
||||
for gi, gv in enumerate(value):
|
||||
if gv is None:
|
||||
continue
|
||||
if gi < self.num_groups:
|
||||
self.v[gi] = torch.cat([self.v[gi], gv], -1)
|
||||
else:
|
||||
self.v.append(gv)
|
||||
|
||||
def update_usage(self, usage):
|
||||
# increase all life count by 1
|
||||
# increase use of indexed elements
|
||||
if not self.count_usage:
|
||||
return
|
||||
|
||||
self.use_count += usage.view_as(self.use_count)
|
||||
self.life_count += 1
|
||||
|
||||
def sieve_by_range(self, start: int, end: int, min_size: int):
|
||||
# keep only the elements *outside* of this range (with some boundary conditions)
|
||||
# i.e., concat (a[:start], a[end:])
|
||||
# min_size is only used for values, we do not sieve values under this size
|
||||
# (because they are not consolidated)
|
||||
|
||||
if end == 0:
|
||||
# negative 0 would not work as the end index!
|
||||
self.k = self.k[:,:,:start]
|
||||
if self.count_usage:
|
||||
self.use_count = self.use_count[:,:,:start]
|
||||
self.life_count = self.life_count[:,:,:start]
|
||||
if self.s is not None:
|
||||
self.s = self.s[:,:,:start]
|
||||
if self.e is not None:
|
||||
self.e = self.e[:,:,:start]
|
||||
|
||||
for gi in range(self.num_groups):
|
||||
if self.v[gi].shape[-1] >= min_size:
|
||||
self.v[gi] = self.v[gi][:,:,:start]
|
||||
else:
|
||||
self.k = torch.cat([self.k[:,:,:start], self.k[:,:,end:]], -1)
|
||||
if self.count_usage:
|
||||
self.use_count = torch.cat([self.use_count[:,:,:start], self.use_count[:,:,end:]], -1)
|
||||
self.life_count = torch.cat([self.life_count[:,:,:start], self.life_count[:,:,end:]], -1)
|
||||
if self.s is not None:
|
||||
self.s = torch.cat([self.s[:,:,:start], self.s[:,:,end:]], -1)
|
||||
if self.e is not None:
|
||||
self.e = torch.cat([self.e[:,:,:start], self.e[:,:,end:]], -1)
|
||||
|
||||
for gi in range(self.num_groups):
|
||||
if self.v[gi].shape[-1] >= min_size:
|
||||
self.v[gi] = torch.cat([self.v[gi][:,:,:start], self.v[gi][:,:,end:]], -1)
|
||||
|
||||
def remove_obsolete_features(self, max_size: int):
|
||||
# normalize with life duration
|
||||
usage = self.get_usage().flatten()
|
||||
|
||||
values, _ = torch.topk(usage, k=(self.size-max_size), largest=False, sorted=True)
|
||||
survived = (usage > values[-1])
|
||||
|
||||
self.k = self.k[:, :, survived]
|
||||
self.s = self.s[:, :, survived] if self.s is not None else None
|
||||
# Long-term memory does not store ek so this should not be needed
|
||||
self.e = self.e[:, :, survived] if self.e is not None else None
|
||||
if self.num_groups > 1:
|
||||
raise NotImplementedError("""The current data structure does not support feature removal with
|
||||
multiple object groups (e.g., some objects start to appear later in the video)
|
||||
The indices for "survived" is based on keys but not all values are present for every key
|
||||
Basically we need to remap the indices for keys to values
|
||||
""")
|
||||
for gi in range(self.num_groups):
|
||||
self.v[gi] = self.v[gi][:, :, survived]
|
||||
|
||||
self.use_count = self.use_count[:, :, survived]
|
||||
self.life_count = self.life_count[:, :, survived]
|
||||
|
||||
def get_usage(self):
|
||||
# return normalized usage
|
||||
if not self.count_usage:
|
||||
raise RuntimeError('I did not count usage!')
|
||||
else:
|
||||
usage = self.use_count / self.life_count
|
||||
return usage
|
||||
|
||||
def get_all_sliced(self, start: int, end: int):
|
||||
# return k, sk, ek, usage in order, sliced by start and end
|
||||
|
||||
if end == 0:
|
||||
# negative 0 would not work as the end index!
|
||||
k = self.k[:,:,start:]
|
||||
sk = self.s[:,:,start:] if self.s is not None else None
|
||||
ek = self.e[:,:,start:] if self.e is not None else None
|
||||
usage = self.get_usage()[:,:,start:]
|
||||
else:
|
||||
k = self.k[:,:,start:end]
|
||||
sk = self.s[:,:,start:end] if self.s is not None else None
|
||||
ek = self.e[:,:,start:end] if self.e is not None else None
|
||||
usage = self.get_usage()[:,:,start:end]
|
||||
|
||||
return k, sk, ek, usage
|
||||
|
||||
def get_v_size(self, ni: int):
|
||||
return self.v[ni].shape[2]
|
||||
|
||||
def engaged(self):
|
||||
return self.k is not None
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
if self.k is None:
|
||||
return 0
|
||||
else:
|
||||
return self.k.shape[-1]
|
||||
|
||||
@property
|
||||
def num_groups(self):
|
||||
return len(self.v)
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
return self.k
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.v
|
||||
|
||||
@property
|
||||
def shrinkage(self):
|
||||
return self.s
|
||||
|
||||
@property
|
||||
def selection(self):
|
||||
return self.e
|
||||
|
||||
284
tracker/inference/memory_manager.py
Normal file
284
tracker/inference/memory_manager.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
from inference.kv_memory_store import KeyValueMemoryStore
|
||||
from model.memory_util import *
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""
|
||||
Manages all three memory stores and the transition between working/long-term memory
|
||||
"""
|
||||
def __init__(self, config):
|
||||
self.hidden_dim = config['hidden_dim']
|
||||
self.top_k = config['top_k']
|
||||
|
||||
self.enable_long_term = config['enable_long_term']
|
||||
self.enable_long_term_usage = config['enable_long_term_count_usage']
|
||||
if self.enable_long_term:
|
||||
self.max_mt_frames = config['max_mid_term_frames']
|
||||
self.min_mt_frames = config['min_mid_term_frames']
|
||||
self.num_prototypes = config['num_prototypes']
|
||||
self.max_long_elements = config['max_long_term_elements']
|
||||
|
||||
# dimensions will be inferred from input later
|
||||
self.CK = self.CV = None
|
||||
self.H = self.W = None
|
||||
|
||||
# The hidden state will be stored in a single tensor for all objects
|
||||
# B x num_objects x CH x H x W
|
||||
self.hidden = None
|
||||
|
||||
self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term)
|
||||
if self.enable_long_term:
|
||||
self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage)
|
||||
|
||||
self.reset_config = True
|
||||
|
||||
def update_config(self, config):
|
||||
self.reset_config = True
|
||||
self.hidden_dim = config['hidden_dim']
|
||||
self.top_k = config['top_k']
|
||||
|
||||
assert self.enable_long_term == config['enable_long_term'], 'cannot update this'
|
||||
assert self.enable_long_term_usage == config['enable_long_term_count_usage'], 'cannot update this'
|
||||
|
||||
self.enable_long_term_usage = config['enable_long_term_count_usage']
|
||||
if self.enable_long_term:
|
||||
self.max_mt_frames = config['max_mid_term_frames']
|
||||
self.min_mt_frames = config['min_mid_term_frames']
|
||||
self.num_prototypes = config['num_prototypes']
|
||||
self.max_long_elements = config['max_long_term_elements']
|
||||
|
||||
def _readout(self, affinity, v):
|
||||
# this function is for a single object group
|
||||
return v @ affinity
|
||||
|
||||
def match_memory(self, query_key, selection):
|
||||
# query_key: B x C^k x H x W
|
||||
# selection: B x C^k x H x W
|
||||
num_groups = self.work_mem.num_groups
|
||||
h, w = query_key.shape[-2:]
|
||||
|
||||
query_key = query_key.flatten(start_dim=2)
|
||||
selection = selection.flatten(start_dim=2) if selection is not None else None
|
||||
|
||||
"""
|
||||
Memory readout using keys
|
||||
"""
|
||||
|
||||
if self.enable_long_term and self.long_mem.engaged():
|
||||
# Use long-term memory
|
||||
long_mem_size = self.long_mem.size
|
||||
memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1)
|
||||
shrinkage = torch.cat([self.long_mem.shrinkage, self.work_mem.shrinkage], -1)
|
||||
|
||||
similarity = get_similarity(memory_key, shrinkage, query_key, selection)
|
||||
work_mem_similarity = similarity[:, long_mem_size:]
|
||||
long_mem_similarity = similarity[:, :long_mem_size]
|
||||
|
||||
# get the usage with the first group
|
||||
# the first group always have all the keys valid
|
||||
affinity, usage = do_softmax(
|
||||
torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(0):], work_mem_similarity], 1),
|
||||
top_k=self.top_k, inplace=True, return_usage=True)
|
||||
affinity = [affinity]
|
||||
|
||||
# compute affinity group by group as later groups only have a subset of keys
|
||||
for gi in range(1, num_groups):
|
||||
if gi < self.long_mem.num_groups:
|
||||
# merge working and lt similarities before softmax
|
||||
affinity_one_group = do_softmax(
|
||||
torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):],
|
||||
work_mem_similarity[:, -self.work_mem.get_v_size(gi):]], 1),
|
||||
top_k=self.top_k, inplace=True)
|
||||
else:
|
||||
# no long-term memory for this group
|
||||
affinity_one_group = do_softmax(work_mem_similarity[:, -self.work_mem.get_v_size(gi):],
|
||||
top_k=self.top_k, inplace=(gi==num_groups-1))
|
||||
affinity.append(affinity_one_group)
|
||||
|
||||
all_memory_value = []
|
||||
for gi, gv in enumerate(self.work_mem.value):
|
||||
# merge the working and lt values before readout
|
||||
if gi < self.long_mem.num_groups:
|
||||
all_memory_value.append(torch.cat([self.long_mem.value[gi], self.work_mem.value[gi]], -1))
|
||||
else:
|
||||
all_memory_value.append(gv)
|
||||
|
||||
"""
|
||||
Record memory usage for working and long-term memory
|
||||
"""
|
||||
# ignore the index return for long-term memory
|
||||
work_usage = usage[:, long_mem_size:]
|
||||
self.work_mem.update_usage(work_usage.flatten())
|
||||
|
||||
if self.enable_long_term_usage:
|
||||
# ignore the index return for working memory
|
||||
long_usage = usage[:, :long_mem_size]
|
||||
self.long_mem.update_usage(long_usage.flatten())
|
||||
else:
|
||||
# No long-term memory
|
||||
similarity = get_similarity(self.work_mem.key, self.work_mem.shrinkage, query_key, selection)
|
||||
|
||||
if self.enable_long_term:
|
||||
affinity, usage = do_softmax(similarity, inplace=(num_groups==1),
|
||||
top_k=self.top_k, return_usage=True)
|
||||
|
||||
# Record memory usage for working memory
|
||||
self.work_mem.update_usage(usage.flatten())
|
||||
else:
|
||||
affinity = do_softmax(similarity, inplace=(num_groups==1),
|
||||
top_k=self.top_k, return_usage=False)
|
||||
|
||||
affinity = [affinity]
|
||||
|
||||
# compute affinity group by group as later groups only have a subset of keys
|
||||
for gi in range(1, num_groups):
|
||||
affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):],
|
||||
top_k=self.top_k, inplace=(gi==num_groups-1))
|
||||
affinity.append(affinity_one_group)
|
||||
|
||||
all_memory_value = self.work_mem.value
|
||||
|
||||
# Shared affinity within each group
|
||||
all_readout_mem = torch.cat([
|
||||
self._readout(affinity[gi], gv)
|
||||
for gi, gv in enumerate(all_memory_value)
|
||||
], 0)
|
||||
|
||||
return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w)
|
||||
|
||||
def add_memory(self, key, shrinkage, value, objects, selection=None):
|
||||
# key: 1*C*H*W
|
||||
# value: 1*num_objects*C*H*W
|
||||
# objects contain a list of object indices
|
||||
if self.H is None or self.reset_config:
|
||||
self.reset_config = False
|
||||
self.H, self.W = key.shape[-2:]
|
||||
self.HW = self.H*self.W
|
||||
if self.enable_long_term:
|
||||
# convert from num. frames to num. nodes
|
||||
self.min_work_elements = self.min_mt_frames*self.HW
|
||||
self.max_work_elements = self.max_mt_frames*self.HW
|
||||
|
||||
# key: 1*C*N
|
||||
# value: num_objects*C*N
|
||||
key = key.flatten(start_dim=2)
|
||||
shrinkage = shrinkage.flatten(start_dim=2)
|
||||
value = value[0].flatten(start_dim=2)
|
||||
|
||||
self.CK = key.shape[1]
|
||||
self.CV = value.shape[1]
|
||||
|
||||
if selection is not None:
|
||||
if not self.enable_long_term:
|
||||
warnings.warn('the selection factor is only needed in long-term mode', UserWarning)
|
||||
selection = selection.flatten(start_dim=2)
|
||||
|
||||
self.work_mem.add(key, value, shrinkage, selection, objects)
|
||||
|
||||
# long-term memory cleanup
|
||||
if self.enable_long_term:
|
||||
# Do memory compressed if needed
|
||||
if self.work_mem.size >= self.max_work_elements:
|
||||
# Remove obsolete features if needed
|
||||
if self.long_mem.size >= (self.max_long_elements-self.num_prototypes):
|
||||
self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes)
|
||||
|
||||
self.compress_features()
|
||||
|
||||
|
||||
def create_hidden_state(self, n, sample_key):
|
||||
# n is the TOTAL number of objects
|
||||
h, w = sample_key.shape[-2:]
|
||||
if self.hidden is None:
|
||||
self.hidden = torch.zeros((1, n, self.hidden_dim, h, w), device=sample_key.device)
|
||||
elif self.hidden.shape[1] != n:
|
||||
self.hidden = torch.cat([
|
||||
self.hidden,
|
||||
torch.zeros((1, n-self.hidden.shape[1], self.hidden_dim, h, w), device=sample_key.device)
|
||||
], 1)
|
||||
|
||||
assert(self.hidden.shape[1] == n)
|
||||
|
||||
def set_hidden(self, hidden):
|
||||
self.hidden = hidden
|
||||
|
||||
def get_hidden(self):
|
||||
return self.hidden
|
||||
|
||||
def compress_features(self):
|
||||
HW = self.HW
|
||||
candidate_value = []
|
||||
total_work_mem_size = self.work_mem.size
|
||||
for gv in self.work_mem.value:
|
||||
# Some object groups might be added later in the video
|
||||
# So not all keys have values associated with all objects
|
||||
# We need to keep track of the key->value validity
|
||||
mem_size_in_this_group = gv.shape[-1]
|
||||
if mem_size_in_this_group == total_work_mem_size:
|
||||
# full LT
|
||||
candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
|
||||
else:
|
||||
# mem_size is smaller than total_work_mem_size, but at least HW
|
||||
assert HW <= mem_size_in_this_group < total_work_mem_size
|
||||
if mem_size_in_this_group > self.min_work_elements+HW:
|
||||
# part of this object group still goes into LT
|
||||
candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
|
||||
else:
|
||||
# this object group cannot go to the LT at all
|
||||
candidate_value.append(None)
|
||||
|
||||
# perform memory consolidation
|
||||
prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
|
||||
*self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value)
|
||||
|
||||
# remove consolidated working memory
|
||||
self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW)
|
||||
|
||||
# add to long-term memory
|
||||
self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None)
|
||||
|
||||
def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value):
|
||||
# keys: 1*C*N
|
||||
# values: num_objects*C*N
|
||||
N = candidate_key.shape[-1]
|
||||
|
||||
# find the indices with max usage
|
||||
_, max_usage_indices = torch.topk(usage, k=self.num_prototypes, dim=-1, sorted=True)
|
||||
prototype_indices = max_usage_indices.flatten()
|
||||
|
||||
# Prototypes are invalid for out-of-bound groups
|
||||
validity = [prototype_indices >= (N-gv.shape[2]) if gv is not None else None for gv in candidate_value]
|
||||
|
||||
prototype_key = candidate_key[:, :, prototype_indices]
|
||||
prototype_selection = candidate_selection[:, :, prototype_indices] if candidate_selection is not None else None
|
||||
|
||||
"""
|
||||
Potentiation step
|
||||
"""
|
||||
similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key, prototype_selection)
|
||||
|
||||
# convert similarity to affinity
|
||||
# need to do it group by group since the softmax normalization would be different
|
||||
affinity = [
|
||||
do_softmax(similarity[:, -gv.shape[2]:, validity[gi]]) if gv is not None else None
|
||||
for gi, gv in enumerate(candidate_value)
|
||||
]
|
||||
|
||||
# some values can be have all False validity. Weed them out.
|
||||
affinity = [
|
||||
aff if aff is None or aff.shape[-1] > 0 else None for aff in affinity
|
||||
]
|
||||
|
||||
# readout the values
|
||||
prototype_value = [
|
||||
self._readout(affinity[gi], gv) if affinity[gi] is not None else None
|
||||
for gi, gv in enumerate(candidate_value)
|
||||
]
|
||||
|
||||
# readout the shrinkage term
|
||||
prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None
|
||||
|
||||
return prototype_key, prototype_value, prototype_shrinkage
|
||||
0
tracker/model/__init__.py
Normal file
0
tracker/model/__init__.py
Normal file
17
tracker/model/aggregate.py
Normal file
17
tracker/model/aggregate.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# Soft aggregation from STM
|
||||
def aggregate(prob, dim, return_logits=False):
|
||||
new_prob = torch.cat([
|
||||
torch.prod(1-prob, dim=dim, keepdim=True),
|
||||
prob
|
||||
], dim).clamp(1e-7, 1-1e-7)
|
||||
logits = torch.log((new_prob /(1-new_prob)))
|
||||
prob = F.softmax(logits, dim=dim)
|
||||
|
||||
if return_logits:
|
||||
return logits, prob
|
||||
else:
|
||||
return prob
|
||||
77
tracker/model/cbam.py
Normal file
77
tracker/model/cbam.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class BasicConv(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
super(BasicConv, self).__init__()
|
||||
self.out_channels = out_planes
|
||||
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, x):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
class ChannelGate(nn.Module):
|
||||
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
|
||||
super(ChannelGate, self).__init__()
|
||||
self.gate_channels = gate_channels
|
||||
self.mlp = nn.Sequential(
|
||||
Flatten(),
|
||||
nn.Linear(gate_channels, gate_channels // reduction_ratio),
|
||||
nn.ReLU(),
|
||||
nn.Linear(gate_channels // reduction_ratio, gate_channels)
|
||||
)
|
||||
self.pool_types = pool_types
|
||||
def forward(self, x):
|
||||
channel_att_sum = None
|
||||
for pool_type in self.pool_types:
|
||||
if pool_type=='avg':
|
||||
avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
|
||||
channel_att_raw = self.mlp( avg_pool )
|
||||
elif pool_type=='max':
|
||||
max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
|
||||
channel_att_raw = self.mlp( max_pool )
|
||||
|
||||
if channel_att_sum is None:
|
||||
channel_att_sum = channel_att_raw
|
||||
else:
|
||||
channel_att_sum = channel_att_sum + channel_att_raw
|
||||
|
||||
scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
|
||||
return x * scale
|
||||
|
||||
class ChannelPool(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
|
||||
|
||||
class SpatialGate(nn.Module):
|
||||
def __init__(self):
|
||||
super(SpatialGate, self).__init__()
|
||||
kernel_size = 7
|
||||
self.compress = ChannelPool()
|
||||
self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2)
|
||||
def forward(self, x):
|
||||
x_compress = self.compress(x)
|
||||
x_out = self.spatial(x_compress)
|
||||
scale = torch.sigmoid(x_out) # broadcasting
|
||||
return x * scale
|
||||
|
||||
class CBAM(nn.Module):
|
||||
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
|
||||
super(CBAM, self).__init__()
|
||||
self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
|
||||
self.no_spatial=no_spatial
|
||||
if not no_spatial:
|
||||
self.SpatialGate = SpatialGate()
|
||||
def forward(self, x):
|
||||
x_out = self.ChannelGate(x)
|
||||
if not self.no_spatial:
|
||||
x_out = self.SpatialGate(x_out)
|
||||
return x_out
|
||||
82
tracker/model/group_modules.py
Normal file
82
tracker/model/group_modules.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
Group-specific modules
|
||||
They handle features that also depends on the mask.
|
||||
Features are typically of shape
|
||||
batch_size * num_objects * num_channels * H * W
|
||||
|
||||
All of them are permutation equivariant w.r.t. to the num_objects dimension
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def interpolate_groups(g, ratio, mode, align_corners):
|
||||
batch_size, num_objects = g.shape[:2]
|
||||
g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
|
||||
scale_factor=ratio, mode=mode, align_corners=align_corners)
|
||||
g = g.view(batch_size, num_objects, *g.shape[1:])
|
||||
return g
|
||||
|
||||
def upsample_groups(g, ratio=2, mode='bilinear', align_corners=False):
|
||||
return interpolate_groups(g, ratio, mode, align_corners)
|
||||
|
||||
def downsample_groups(g, ratio=1/2, mode='area', align_corners=None):
|
||||
return interpolate_groups(g, ratio, mode, align_corners)
|
||||
|
||||
|
||||
class GConv2D(nn.Conv2d):
|
||||
def forward(self, g):
|
||||
batch_size, num_objects = g.shape[:2]
|
||||
g = super().forward(g.flatten(start_dim=0, end_dim=1))
|
||||
return g.view(batch_size, num_objects, *g.shape[1:])
|
||||
|
||||
|
||||
class GroupResBlock(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
|
||||
if in_dim == out_dim:
|
||||
self.downsample = None
|
||||
else:
|
||||
self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
|
||||
|
||||
self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
|
||||
self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, g):
|
||||
out_g = self.conv1(F.relu(g))
|
||||
out_g = self.conv2(F.relu(out_g))
|
||||
|
||||
if self.downsample is not None:
|
||||
g = self.downsample(g)
|
||||
|
||||
return out_g + g
|
||||
|
||||
|
||||
class MainToGroupDistributor(nn.Module):
|
||||
def __init__(self, x_transform=None, method='cat', reverse_order=False):
|
||||
super().__init__()
|
||||
|
||||
self.x_transform = x_transform
|
||||
self.method = method
|
||||
self.reverse_order = reverse_order
|
||||
|
||||
def forward(self, x, g):
|
||||
num_objects = g.shape[1]
|
||||
|
||||
if self.x_transform is not None:
|
||||
x = self.x_transform(x)
|
||||
|
||||
if self.method == 'cat':
|
||||
if self.reverse_order:
|
||||
g = torch.cat([g, x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1)], 2)
|
||||
else:
|
||||
g = torch.cat([x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1), g], 2)
|
||||
elif self.method == 'add':
|
||||
g = x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1) + g
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return g
|
||||
68
tracker/model/losses.py
Normal file
68
tracker/model/losses.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def dice_loss(input_mask, cls_gt):
|
||||
num_objects = input_mask.shape[1]
|
||||
losses = []
|
||||
for i in range(num_objects):
|
||||
mask = input_mask[:,i].flatten(start_dim=1)
|
||||
# background not in mask, so we add one to cls_gt
|
||||
gt = (cls_gt==(i+1)).float().flatten(start_dim=1)
|
||||
numerator = 2 * (mask * gt).sum(-1)
|
||||
denominator = mask.sum(-1) + gt.sum(-1)
|
||||
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
losses.append(loss)
|
||||
return torch.cat(losses).mean()
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch
|
||||
class BootstrappedCE(nn.Module):
|
||||
def __init__(self, start_warm, end_warm, top_p=0.15):
|
||||
super().__init__()
|
||||
|
||||
self.start_warm = start_warm
|
||||
self.end_warm = end_warm
|
||||
self.top_p = top_p
|
||||
|
||||
def forward(self, input, target, it):
|
||||
if it < self.start_warm:
|
||||
return F.cross_entropy(input, target), 1.0
|
||||
|
||||
raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
|
||||
num_pixels = raw_loss.numel()
|
||||
|
||||
if it > self.end_warm:
|
||||
this_p = self.top_p
|
||||
else:
|
||||
this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm))
|
||||
loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
|
||||
return loss.mean(), this_p
|
||||
|
||||
|
||||
class LossComputer:
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.bce = BootstrappedCE(config['start_warm'], config['end_warm'])
|
||||
|
||||
def compute(self, data, num_objects, it):
|
||||
losses = defaultdict(int)
|
||||
|
||||
b, t = data['rgb'].shape[:2]
|
||||
|
||||
losses['total_loss'] = 0
|
||||
for ti in range(1, t):
|
||||
for bi in range(b):
|
||||
loss, p = self.bce(data[f'logits_{ti}'][bi:bi+1, :num_objects[bi]+1], data['cls_gt'][bi:bi+1,ti,0], it)
|
||||
losses['p'] += p / b / (t-1)
|
||||
losses[f'ce_loss_{ti}'] += loss / b
|
||||
|
||||
losses['total_loss'] += losses['ce_loss_%d'%ti]
|
||||
losses[f'dice_loss_{ti}'] = dice_loss(data[f'masks_{ti}'], data['cls_gt'][:,ti,0])
|
||||
losses['total_loss'] += losses[f'dice_loss_{ti}']
|
||||
|
||||
return losses
|
||||
80
tracker/model/memory_util.py
Normal file
80
tracker/model/memory_util.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_similarity(mk, ms, qk, qe):
|
||||
# used for training/inference and memory reading/memory potentiation
|
||||
# mk: B x CK x [N] - Memory keys
|
||||
# ms: B x 1 x [N] - Memory shrinkage
|
||||
# qk: B x CK x [HW/P] - Query keys
|
||||
# qe: B x CK x [HW/P] - Query selection
|
||||
# Dimensions in [] are flattened
|
||||
CK = mk.shape[1]
|
||||
mk = mk.flatten(start_dim=2)
|
||||
ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
|
||||
qk = qk.flatten(start_dim=2)
|
||||
qe = qe.flatten(start_dim=2) if qe is not None else None
|
||||
|
||||
if qe is not None:
|
||||
# See appendix for derivation
|
||||
# or you can just trust me ヽ(ー_ー )ノ
|
||||
mk = mk.transpose(1, 2)
|
||||
a_sq = (mk.pow(2) @ qe)
|
||||
two_ab = 2 * (mk @ (qk * qe))
|
||||
b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
|
||||
similarity = (-a_sq+two_ab-b_sq)
|
||||
else:
|
||||
# similar to STCN if we don't have the selection term
|
||||
a_sq = mk.pow(2).sum(1).unsqueeze(2)
|
||||
two_ab = 2 * (mk.transpose(1, 2) @ qk)
|
||||
similarity = (-a_sq+two_ab)
|
||||
|
||||
if ms is not None:
|
||||
similarity = similarity * ms / math.sqrt(CK) # B*N*HW
|
||||
else:
|
||||
similarity = similarity / math.sqrt(CK) # B*N*HW
|
||||
|
||||
return similarity
|
||||
|
||||
def do_softmax(similarity, top_k: Optional[int]=None, inplace=False, return_usage=False):
|
||||
# normalize similarity with top-k softmax
|
||||
# similarity: B x N x [HW/P]
|
||||
# use inplace with care
|
||||
if top_k is not None:
|
||||
values, indices = torch.topk(similarity, k=top_k, dim=1)
|
||||
|
||||
x_exp = values.exp_()
|
||||
x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
|
||||
if inplace:
|
||||
similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
|
||||
affinity = similarity
|
||||
else:
|
||||
affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW
|
||||
else:
|
||||
maxes = torch.max(similarity, dim=1, keepdim=True)[0]
|
||||
x_exp = torch.exp(similarity - maxes)
|
||||
x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
|
||||
affinity = x_exp / x_exp_sum
|
||||
indices = None
|
||||
|
||||
if return_usage:
|
||||
return affinity, affinity.sum(dim=2)
|
||||
|
||||
return affinity
|
||||
|
||||
def get_affinity(mk, ms, qk, qe):
|
||||
# shorthand used in training with no top-k
|
||||
similarity = get_similarity(mk, ms, qk, qe)
|
||||
affinity = do_softmax(similarity)
|
||||
return affinity
|
||||
|
||||
def readout(affinity, mv):
|
||||
B, CV, T, H, W = mv.shape
|
||||
|
||||
mo = mv.view(B, CV, T*H*W)
|
||||
mem = torch.bmm(mo, affinity)
|
||||
mem = mem.view(B, CV, H, W)
|
||||
|
||||
return mem
|
||||
250
tracker/model/modules.py
Normal file
250
tracker/model/modules.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
modules.py - This file stores the rather boring network blocks.
|
||||
|
||||
x - usually means features that only depends on the image
|
||||
g - usually means features that also depends on the mask.
|
||||
They might have an extra "group" or "num_objects" dimension, hence
|
||||
batch_size * num_objects * num_channels * H * W
|
||||
|
||||
The trailing number of a variable usually denote the stride
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model.group_modules import *
|
||||
from model import resnet
|
||||
from model.cbam import CBAM
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim):
|
||||
super().__init__()
|
||||
|
||||
self.distributor = MainToGroupDistributor()
|
||||
self.block1 = GroupResBlock(x_in_dim+g_in_dim, g_mid_dim)
|
||||
self.attention = CBAM(g_mid_dim)
|
||||
self.block2 = GroupResBlock(g_mid_dim, g_out_dim)
|
||||
|
||||
def forward(self, x, g):
|
||||
batch_size, num_objects = g.shape[:2]
|
||||
|
||||
g = self.distributor(x, g)
|
||||
g = self.block1(g)
|
||||
r = self.attention(g.flatten(start_dim=0, end_dim=1))
|
||||
r = r.view(batch_size, num_objects, *r.shape[1:])
|
||||
|
||||
g = self.block2(g+r)
|
||||
|
||||
return g
|
||||
|
||||
|
||||
class HiddenUpdater(nn.Module):
|
||||
# Used in the decoder, multi-scale feature + GRU
|
||||
def __init__(self, g_dims, mid_dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
|
||||
self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1)
|
||||
self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1)
|
||||
self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1)
|
||||
|
||||
self.transform = GConv2D(mid_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1)
|
||||
|
||||
nn.init.xavier_normal_(self.transform.weight)
|
||||
|
||||
def forward(self, g, h):
|
||||
g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
|
||||
self.g4_conv(downsample_groups(g[2], ratio=1/4))
|
||||
|
||||
g = torch.cat([g, h], 2)
|
||||
|
||||
# defined slightly differently than standard GRU,
|
||||
# namely the new value is generated before the forget gate.
|
||||
# might provide better gradient but frankly it was initially just an
|
||||
# implementation error that I never bothered fixing
|
||||
values = self.transform(g)
|
||||
forget_gate = torch.sigmoid(values[:,:,:self.hidden_dim])
|
||||
update_gate = torch.sigmoid(values[:,:,self.hidden_dim:self.hidden_dim*2])
|
||||
new_value = torch.tanh(values[:,:,self.hidden_dim*2:])
|
||||
new_h = forget_gate*h*(1-update_gate) + update_gate*new_value
|
||||
|
||||
return new_h
|
||||
|
||||
|
||||
class HiddenReinforcer(nn.Module):
|
||||
# Used in the value encoder, a single GRU
|
||||
def __init__(self, g_dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.transform = GConv2D(g_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1)
|
||||
|
||||
nn.init.xavier_normal_(self.transform.weight)
|
||||
|
||||
def forward(self, g, h):
|
||||
g = torch.cat([g, h], 2)
|
||||
|
||||
# defined slightly differently than standard GRU,
|
||||
# namely the new value is generated before the forget gate.
|
||||
# might provide better gradient but frankly it was initially just an
|
||||
# implementation error that I never bothered fixing
|
||||
values = self.transform(g)
|
||||
forget_gate = torch.sigmoid(values[:,:,:self.hidden_dim])
|
||||
update_gate = torch.sigmoid(values[:,:,self.hidden_dim:self.hidden_dim*2])
|
||||
new_value = torch.tanh(values[:,:,self.hidden_dim*2:])
|
||||
new_h = forget_gate*h*(1-update_gate) + update_gate*new_value
|
||||
|
||||
return new_h
|
||||
|
||||
|
||||
class ValueEncoder(nn.Module):
|
||||
def __init__(self, value_dim, hidden_dim, single_object=False):
|
||||
super().__init__()
|
||||
|
||||
self.single_object = single_object
|
||||
network = resnet.resnet18(pretrained=True, extra_dim=1 if single_object else 2)
|
||||
self.conv1 = network.conv1
|
||||
self.bn1 = network.bn1
|
||||
self.relu = network.relu # 1/2, 64
|
||||
self.maxpool = network.maxpool
|
||||
|
||||
self.layer1 = network.layer1 # 1/4, 64
|
||||
self.layer2 = network.layer2 # 1/8, 128
|
||||
self.layer3 = network.layer3 # 1/16, 256
|
||||
|
||||
self.distributor = MainToGroupDistributor()
|
||||
self.fuser = FeatureFusionBlock(1024, 256, value_dim, value_dim)
|
||||
if hidden_dim > 0:
|
||||
self.hidden_reinforce = HiddenReinforcer(value_dim, hidden_dim)
|
||||
else:
|
||||
self.hidden_reinforce = None
|
||||
|
||||
def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True):
|
||||
# image_feat_f16 is the feature from the key encoder
|
||||
if not self.single_object:
|
||||
g = torch.stack([masks, others], 2)
|
||||
else:
|
||||
g = masks.unsqueeze(2)
|
||||
g = self.distributor(image, g)
|
||||
|
||||
batch_size, num_objects = g.shape[:2]
|
||||
g = g.flatten(start_dim=0, end_dim=1)
|
||||
|
||||
g = self.conv1(g)
|
||||
g = self.bn1(g) # 1/2, 64
|
||||
g = self.maxpool(g) # 1/4, 64
|
||||
g = self.relu(g)
|
||||
|
||||
g = self.layer1(g) # 1/4
|
||||
g = self.layer2(g) # 1/8
|
||||
g = self.layer3(g) # 1/16
|
||||
|
||||
g = g.view(batch_size, num_objects, *g.shape[1:])
|
||||
g = self.fuser(image_feat_f16, g)
|
||||
|
||||
if is_deep_update and self.hidden_reinforce is not None:
|
||||
h = self.hidden_reinforce(g, h)
|
||||
|
||||
return g, h
|
||||
|
||||
|
||||
class KeyEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
network = resnet.resnet50(pretrained=True)
|
||||
self.conv1 = network.conv1
|
||||
self.bn1 = network.bn1
|
||||
self.relu = network.relu # 1/2, 64
|
||||
self.maxpool = network.maxpool
|
||||
|
||||
self.res2 = network.layer1 # 1/4, 256
|
||||
self.layer2 = network.layer2 # 1/8, 512
|
||||
self.layer3 = network.layer3 # 1/16, 1024
|
||||
|
||||
def forward(self, f):
|
||||
x = self.conv1(f)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x) # 1/2, 64
|
||||
x = self.maxpool(x) # 1/4, 64
|
||||
f4 = self.res2(x) # 1/4, 256
|
||||
f8 = self.layer2(f4) # 1/8, 512
|
||||
f16 = self.layer3(f8) # 1/16, 1024
|
||||
|
||||
return f16, f8, f4
|
||||
|
||||
|
||||
class UpsampleBlock(nn.Module):
|
||||
def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2):
|
||||
super().__init__()
|
||||
self.skip_conv = nn.Conv2d(skip_dim, g_up_dim, kernel_size=3, padding=1)
|
||||
self.distributor = MainToGroupDistributor(method='add')
|
||||
self.out_conv = GroupResBlock(g_up_dim, g_out_dim)
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def forward(self, skip_f, up_g):
|
||||
skip_f = self.skip_conv(skip_f)
|
||||
g = upsample_groups(up_g, ratio=self.scale_factor)
|
||||
g = self.distributor(skip_f, g)
|
||||
g = self.out_conv(g)
|
||||
return g
|
||||
|
||||
|
||||
class KeyProjection(nn.Module):
|
||||
def __init__(self, in_dim, keydim):
|
||||
super().__init__()
|
||||
|
||||
self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
|
||||
# shrinkage
|
||||
self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1)
|
||||
# selection
|
||||
self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
|
||||
|
||||
nn.init.orthogonal_(self.key_proj.weight.data)
|
||||
nn.init.zeros_(self.key_proj.bias.data)
|
||||
|
||||
def forward(self, x, need_s, need_e):
|
||||
shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None
|
||||
selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
|
||||
|
||||
return self.key_proj(x), shrinkage, selection
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, val_dim, hidden_dim):
|
||||
super().__init__()
|
||||
|
||||
self.fuser = FeatureFusionBlock(1024, val_dim+hidden_dim, 512, 512)
|
||||
if hidden_dim > 0:
|
||||
self.hidden_update = HiddenUpdater([512, 256, 256+1], 256, hidden_dim)
|
||||
else:
|
||||
self.hidden_update = None
|
||||
|
||||
self.up_16_8 = UpsampleBlock(512, 512, 256) # 1/16 -> 1/8
|
||||
self.up_8_4 = UpsampleBlock(256, 256, 256) # 1/8 -> 1/4
|
||||
|
||||
self.pred = nn.Conv2d(256, 1, kernel_size=3, padding=1, stride=1)
|
||||
|
||||
def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True):
|
||||
batch_size, num_objects = memory_readout.shape[:2]
|
||||
|
||||
if self.hidden_update is not None:
|
||||
g16 = self.fuser(f16, torch.cat([memory_readout, hidden_state], 2))
|
||||
else:
|
||||
g16 = self.fuser(f16, memory_readout)
|
||||
|
||||
g8 = self.up_16_8(f8, g16)
|
||||
g4 = self.up_8_4(f4, g8)
|
||||
logits = self.pred(F.relu(g4.flatten(start_dim=0, end_dim=1)))
|
||||
|
||||
if h_out and self.hidden_update is not None:
|
||||
g4 = torch.cat([g4, logits.view(batch_size, num_objects, 1, *logits.shape[-2:])], 2)
|
||||
hidden_state = self.hidden_update([g16, g8, g4], hidden_state)
|
||||
else:
|
||||
hidden_state = None
|
||||
|
||||
logits = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False)
|
||||
logits = logits.view(batch_size, num_objects, *logits.shape[-2:])
|
||||
|
||||
return hidden_state, logits
|
||||
198
tracker/model/network.py
Normal file
198
tracker/model/network.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
This file defines XMem, the highest level nn.Module interface
|
||||
During training, it is used by trainer.py
|
||||
During evaluation, it is used by inference_core.py
|
||||
|
||||
It further depends on modules.py which gives more detailed implementations of sub-modules
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from model.aggregate import aggregate
|
||||
from model.modules import *
|
||||
from model.memory_util import *
|
||||
|
||||
|
||||
class XMem(nn.Module):
|
||||
def __init__(self, config, model_path=None, map_location=None):
|
||||
"""
|
||||
model_path/map_location are used in evaluation only
|
||||
map_location is for converting models saved in cuda to cpu
|
||||
"""
|
||||
super().__init__()
|
||||
model_weights = self.init_hyperparameters(config, model_path, map_location)
|
||||
|
||||
self.single_object = config.get('single_object', False)
|
||||
print(f'Single object mode: {self.single_object}')
|
||||
|
||||
self.key_encoder = KeyEncoder()
|
||||
self.value_encoder = ValueEncoder(self.value_dim, self.hidden_dim, self.single_object)
|
||||
|
||||
# Projection from f16 feature space to key/value space
|
||||
self.key_proj = KeyProjection(1024, self.key_dim)
|
||||
|
||||
self.decoder = Decoder(self.value_dim, self.hidden_dim)
|
||||
|
||||
if model_weights is not None:
|
||||
self.load_weights(model_weights, init_as_zero_if_needed=True)
|
||||
|
||||
def encode_key(self, frame, need_sk=True, need_ek=True):
|
||||
# Determine input shape
|
||||
if len(frame.shape) == 5:
|
||||
# shape is b*t*c*h*w
|
||||
need_reshape = True
|
||||
b, t = frame.shape[:2]
|
||||
# flatten so that we can feed them into a 2D CNN
|
||||
frame = frame.flatten(start_dim=0, end_dim=1)
|
||||
elif len(frame.shape) == 4:
|
||||
# shape is b*c*h*w
|
||||
need_reshape = False
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
f16, f8, f4 = self.key_encoder(frame)
|
||||
key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek)
|
||||
|
||||
if need_reshape:
|
||||
# B*C*T*H*W
|
||||
key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous()
|
||||
if shrinkage is not None:
|
||||
shrinkage = shrinkage.view(b, t, *shrinkage.shape[-3:]).transpose(1, 2).contiguous()
|
||||
if selection is not None:
|
||||
selection = selection.view(b, t, *selection.shape[-3:]).transpose(1, 2).contiguous()
|
||||
|
||||
# B*T*C*H*W
|
||||
f16 = f16.view(b, t, *f16.shape[-3:])
|
||||
f8 = f8.view(b, t, *f8.shape[-3:])
|
||||
f4 = f4.view(b, t, *f4.shape[-3:])
|
||||
|
||||
return key, shrinkage, selection, f16, f8, f4
|
||||
|
||||
def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True):
|
||||
num_objects = masks.shape[1]
|
||||
if num_objects != 1:
|
||||
others = torch.cat([
|
||||
torch.sum(
|
||||
masks[:, [j for j in range(num_objects) if i!=j]]
|
||||
, dim=1, keepdim=True)
|
||||
for i in range(num_objects)], 1)
|
||||
else:
|
||||
others = torch.zeros_like(masks)
|
||||
|
||||
g16, h16 = self.value_encoder(frame, image_feat_f16, h16, masks, others, is_deep_update)
|
||||
|
||||
return g16, h16
|
||||
|
||||
# Used in training only.
|
||||
# This step is replaced by MemoryManager in test time
|
||||
def read_memory(self, query_key, query_selection, memory_key,
|
||||
memory_shrinkage, memory_value):
|
||||
"""
|
||||
query_key : B * CK * H * W
|
||||
query_selection : B * CK * H * W
|
||||
memory_key : B * CK * T * H * W
|
||||
memory_shrinkage: B * 1 * T * H * W
|
||||
memory_value : B * num_objects * CV * T * H * W
|
||||
"""
|
||||
batch_size, num_objects = memory_value.shape[:2]
|
||||
memory_value = memory_value.flatten(start_dim=1, end_dim=2)
|
||||
|
||||
affinity = get_affinity(memory_key, memory_shrinkage, query_key, query_selection)
|
||||
memory = readout(affinity, memory_value)
|
||||
memory = memory.view(batch_size, num_objects, self.value_dim, *memory.shape[-2:])
|
||||
|
||||
return memory
|
||||
|
||||
def segment(self, multi_scale_features, memory_readout,
|
||||
hidden_state, selector=None, h_out=True, strip_bg=True):
|
||||
|
||||
hidden_state, logits = self.decoder(*multi_scale_features, hidden_state, memory_readout, h_out=h_out)
|
||||
prob = torch.sigmoid(logits)
|
||||
if selector is not None:
|
||||
prob = prob * selector
|
||||
|
||||
logits, prob = aggregate(prob, dim=1, return_logits=True)
|
||||
if strip_bg:
|
||||
# Strip away the background
|
||||
prob = prob[:, 1:]
|
||||
|
||||
return hidden_state, logits, prob
|
||||
|
||||
def forward(self, mode, *args, **kwargs):
|
||||
if mode == 'encode_key':
|
||||
return self.encode_key(*args, **kwargs)
|
||||
elif mode == 'encode_value':
|
||||
return self.encode_value(*args, **kwargs)
|
||||
elif mode == 'read_memory':
|
||||
return self.read_memory(*args, **kwargs)
|
||||
elif mode == 'segment':
|
||||
return self.segment(*args, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def init_hyperparameters(self, config, model_path=None, map_location=None):
|
||||
"""
|
||||
Init three hyperparameters: key_dim, value_dim, and hidden_dim
|
||||
If model_path is provided, we load these from the model weights
|
||||
The actual parameters are then updated to the config in-place
|
||||
|
||||
Otherwise we load it either from the config or default
|
||||
"""
|
||||
if model_path is not None:
|
||||
# load the model and key/value/hidden dimensions with some hacks
|
||||
# config is updated with the loaded parameters
|
||||
model_weights = torch.load(model_path, map_location=map_location)
|
||||
self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0]
|
||||
self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0]
|
||||
self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights
|
||||
if self.disable_hidden:
|
||||
self.hidden_dim = 0
|
||||
else:
|
||||
self.hidden_dim = model_weights['decoder.hidden_update.transform.weight'].shape[0]//3
|
||||
print(f'Hyperparameters read from the model weights: '
|
||||
f'C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}')
|
||||
else:
|
||||
model_weights = None
|
||||
# load dimensions from config or default
|
||||
if 'key_dim' not in config:
|
||||
self.key_dim = 64
|
||||
print(f'key_dim not found in config. Set to default {self.key_dim}')
|
||||
else:
|
||||
self.key_dim = config['key_dim']
|
||||
|
||||
if 'value_dim' not in config:
|
||||
self.value_dim = 512
|
||||
print(f'value_dim not found in config. Set to default {self.value_dim}')
|
||||
else:
|
||||
self.value_dim = config['value_dim']
|
||||
|
||||
if 'hidden_dim' not in config:
|
||||
self.hidden_dim = 64
|
||||
print(f'hidden_dim not found in config. Set to default {self.hidden_dim}')
|
||||
else:
|
||||
self.hidden_dim = config['hidden_dim']
|
||||
|
||||
self.disable_hidden = (self.hidden_dim <= 0)
|
||||
|
||||
config['key_dim'] = self.key_dim
|
||||
config['value_dim'] = self.value_dim
|
||||
config['hidden_dim'] = self.hidden_dim
|
||||
|
||||
return model_weights
|
||||
|
||||
def load_weights(self, src_dict, init_as_zero_if_needed=False):
|
||||
# Maps SO weight (without other_mask) to MO weight (with other_mask)
|
||||
for k in list(src_dict.keys()):
|
||||
if k == 'value_encoder.conv1.weight':
|
||||
if src_dict[k].shape[1] == 4:
|
||||
print('Converting weights from single object to multiple objects.')
|
||||
pads = torch.zeros((64,1,7,7), device=src_dict[k].device)
|
||||
if not init_as_zero_if_needed:
|
||||
print('Randomly initialized padding.')
|
||||
nn.init.orthogonal_(pads)
|
||||
else:
|
||||
print('Zero-initialized padding.')
|
||||
src_dict[k] = torch.cat([src_dict[k], pads], 1)
|
||||
|
||||
self.load_state_dict(src_dict)
|
||||
165
tracker/model/resnet.py
Normal file
165
tracker/model/resnet.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
resnet.py - A modified ResNet structure
|
||||
We append extra channels to the first conv by some network surgery
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils import model_zoo
|
||||
|
||||
|
||||
def load_weights_add_extra_dim(target, source_state, extra_dim=1):
|
||||
new_dict = OrderedDict()
|
||||
|
||||
for k1, v1 in target.state_dict().items():
|
||||
if not 'num_batches_tracked' in k1:
|
||||
if k1 in source_state:
|
||||
tar_v = source_state[k1]
|
||||
|
||||
if v1.shape != tar_v.shape:
|
||||
# Init the new segmentation channel with zeros
|
||||
# print(v1.shape, tar_v.shape)
|
||||
c, _, w, h = v1.shape
|
||||
pads = torch.zeros((c,extra_dim,w,h), device=tar_v.device)
|
||||
nn.init.orthogonal_(pads)
|
||||
tar_v = torch.cat([tar_v, pads], 1)
|
||||
|
||||
new_dict[k1] = tar_v
|
||||
|
||||
target.load_state_dict(new_dict)
|
||||
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
}
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, dilation=dilation, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
|
||||
padding=dilation, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
|
||||
self.inplanes = 64
|
||||
super(ResNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3+extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = [block(self.inplanes, planes, stride, downsample)]
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, dilation=dilation))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def resnet18(pretrained=True, extra_dim=0):
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
|
||||
if pretrained:
|
||||
load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim)
|
||||
return model
|
||||
|
||||
def resnet50(pretrained=True, extra_dim=0):
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
|
||||
if pretrained:
|
||||
load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim)
|
||||
return model
|
||||
|
||||
244
tracker/model/trainer.py
Normal file
244
tracker/model/trainer.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
trainer.py - warpper and utility functions for network training
|
||||
Compute loss, back-prop, update parameters, logging, etc.
|
||||
"""
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from model.network import XMem
|
||||
from model.losses import LossComputer
|
||||
from util.log_integrator import Integrator
|
||||
from util.image_saver import pool_pairs
|
||||
|
||||
|
||||
class XMemTrainer:
|
||||
def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1):
|
||||
self.config = config
|
||||
self.num_frames = config['num_frames']
|
||||
self.num_ref_frames = config['num_ref_frames']
|
||||
self.deep_update_prob = config['deep_update_prob']
|
||||
self.local_rank = local_rank
|
||||
|
||||
self.XMem = nn.parallel.DistributedDataParallel(
|
||||
XMem(config).cuda(),
|
||||
device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)
|
||||
|
||||
# Set up logger when local_rank=0
|
||||
self.logger = logger
|
||||
self.save_path = save_path
|
||||
if logger is not None:
|
||||
self.last_time = time.time()
|
||||
self.logger.log_string('model_size', str(sum([param.nelement() for param in self.XMem.parameters()])))
|
||||
self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size)
|
||||
self.loss_computer = LossComputer(config)
|
||||
|
||||
self.train()
|
||||
self.optimizer = optim.AdamW(filter(
|
||||
lambda p: p.requires_grad, self.XMem.parameters()), lr=config['lr'], weight_decay=config['weight_decay'])
|
||||
self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, config['steps'], config['gamma'])
|
||||
if config['amp']:
|
||||
self.scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
# Logging info
|
||||
self.log_text_interval = config['log_text_interval']
|
||||
self.log_image_interval = config['log_image_interval']
|
||||
self.save_network_interval = config['save_network_interval']
|
||||
self.save_checkpoint_interval = config['save_checkpoint_interval']
|
||||
if config['debug']:
|
||||
self.log_text_interval = self.log_image_interval = 1
|
||||
|
||||
def do_pass(self, data, max_it, it=0):
|
||||
# No need to store the gradient outside training
|
||||
torch.set_grad_enabled(self._is_train)
|
||||
|
||||
for k, v in data.items():
|
||||
if type(v) != list and type(v) != dict and type(v) != int:
|
||||
data[k] = v.cuda(non_blocking=True)
|
||||
|
||||
out = {}
|
||||
frames = data['rgb']
|
||||
first_frame_gt = data['first_frame_gt'].float()
|
||||
b = frames.shape[0]
|
||||
num_filled_objects = [o.item() for o in data['info']['num_objects']]
|
||||
num_objects = first_frame_gt.shape[2]
|
||||
selector = data['selector'].unsqueeze(2).unsqueeze(2)
|
||||
|
||||
global_avg = 0
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=self.config['amp']):
|
||||
# image features never change, compute once
|
||||
key, shrinkage, selection, f16, f8, f4 = self.XMem('encode_key', frames)
|
||||
|
||||
filler_one = torch.zeros(1, dtype=torch.int64)
|
||||
hidden = torch.zeros((b, num_objects, self.config['hidden_dim'], *key.shape[-2:]))
|
||||
v16, hidden = self.XMem('encode_value', frames[:,0], f16[:,0], hidden, first_frame_gt[:,0])
|
||||
values = v16.unsqueeze(3) # add the time dimension
|
||||
|
||||
for ti in range(1, self.num_frames):
|
||||
if ti <= self.num_ref_frames:
|
||||
ref_values = values
|
||||
ref_keys = key[:,:,:ti]
|
||||
ref_shrinkage = shrinkage[:,:,:ti] if shrinkage is not None else None
|
||||
else:
|
||||
# pick num_ref_frames random frames
|
||||
# this is not very efficient but I think we would
|
||||
# need broadcasting in gather which we don't have
|
||||
indices = [
|
||||
torch.cat([filler_one, torch.randperm(ti-1)[:self.num_ref_frames-1]+1])
|
||||
for _ in range(b)]
|
||||
ref_values = torch.stack([
|
||||
values[bi, :, :, indices[bi]] for bi in range(b)
|
||||
], 0)
|
||||
ref_keys = torch.stack([
|
||||
key[bi, :, indices[bi]] for bi in range(b)
|
||||
], 0)
|
||||
ref_shrinkage = torch.stack([
|
||||
shrinkage[bi, :, indices[bi]] for bi in range(b)
|
||||
], 0) if shrinkage is not None else None
|
||||
|
||||
# Segment frame ti
|
||||
memory_readout = self.XMem('read_memory', key[:,:,ti], selection[:,:,ti] if selection is not None else None,
|
||||
ref_keys, ref_shrinkage, ref_values)
|
||||
hidden, logits, masks = self.XMem('segment', (f16[:,ti], f8[:,ti], f4[:,ti]), memory_readout,
|
||||
hidden, selector, h_out=(ti < (self.num_frames-1)))
|
||||
|
||||
# No need to encode the last frame
|
||||
if ti < (self.num_frames-1):
|
||||
is_deep_update = np.random.rand() < self.deep_update_prob
|
||||
v16, hidden = self.XMem('encode_value', frames[:,ti], f16[:,ti], hidden, masks, is_deep_update=is_deep_update)
|
||||
values = torch.cat([values, v16.unsqueeze(3)], 3)
|
||||
|
||||
out[f'masks_{ti}'] = masks
|
||||
out[f'logits_{ti}'] = logits
|
||||
|
||||
if self._do_log or self._is_train:
|
||||
losses = self.loss_computer.compute({**data, **out}, num_filled_objects, it)
|
||||
|
||||
# Logging
|
||||
if self._do_log:
|
||||
self.integrator.add_dict(losses)
|
||||
if self._is_train:
|
||||
if it % self.log_image_interval == 0 and it != 0:
|
||||
if self.logger is not None:
|
||||
images = {**data, **out}
|
||||
size = (384, 384)
|
||||
self.logger.log_cv2('train/pairs', pool_pairs(images, size, num_filled_objects), it)
|
||||
|
||||
if self._is_train:
|
||||
|
||||
if (it) % self.log_text_interval == 0 and it != 0:
|
||||
time_spent = time.time()-self.last_time
|
||||
|
||||
if self.logger is not None:
|
||||
self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it)
|
||||
self.logger.log_metrics('train', 'time', (time_spent)/self.log_text_interval, it)
|
||||
|
||||
global_avg = 0.5*(global_avg) + 0.5*(time_spent)
|
||||
eta_seconds = global_avg * (max_it - it) / 100
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
print(f'ETA: {eta_string}')
|
||||
|
||||
self.last_time = time.time()
|
||||
self.train_integrator.finalize('train', it)
|
||||
self.train_integrator.reset_except_hooks()
|
||||
|
||||
if it % self.save_network_interval == 0 and it != 0:
|
||||
if self.logger is not None:
|
||||
self.save_network(it)
|
||||
|
||||
if it % self.save_checkpoint_interval == 0 and it != 0:
|
||||
if self.logger is not None:
|
||||
self.save_checkpoint(it)
|
||||
|
||||
# Backward pass
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
if self.config['amp']:
|
||||
self.scaler.scale(losses['total_loss']).backward()
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
else:
|
||||
losses['total_loss'].backward()
|
||||
self.optimizer.step()
|
||||
|
||||
self.scheduler.step()
|
||||
|
||||
def save_network(self, it):
|
||||
if self.save_path is None:
|
||||
print('Saving has been disabled.')
|
||||
return
|
||||
|
||||
os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
|
||||
model_path = f'{self.save_path}_{it}.pth'
|
||||
torch.save(self.XMem.module.state_dict(), model_path)
|
||||
print(f'Network saved to {model_path}.')
|
||||
|
||||
def save_checkpoint(self, it):
|
||||
if self.save_path is None:
|
||||
print('Saving has been disabled.')
|
||||
return
|
||||
|
||||
os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
|
||||
checkpoint_path = f'{self.save_path}_checkpoint_{it}.pth'
|
||||
checkpoint = {
|
||||
'it': it,
|
||||
'network': self.XMem.module.state_dict(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'scheduler': self.scheduler.state_dict()}
|
||||
torch.save(checkpoint, checkpoint_path)
|
||||
print(f'Checkpoint saved to {checkpoint_path}.')
|
||||
|
||||
def load_checkpoint(self, path):
|
||||
# This method loads everything and should be used to resume training
|
||||
map_location = 'cuda:%d' % self.local_rank
|
||||
checkpoint = torch.load(path, map_location={'cuda:0': map_location})
|
||||
|
||||
it = checkpoint['it']
|
||||
network = checkpoint['network']
|
||||
optimizer = checkpoint['optimizer']
|
||||
scheduler = checkpoint['scheduler']
|
||||
|
||||
map_location = 'cuda:%d' % self.local_rank
|
||||
self.XMem.module.load_state_dict(network)
|
||||
self.optimizer.load_state_dict(optimizer)
|
||||
self.scheduler.load_state_dict(scheduler)
|
||||
|
||||
print('Network weights, optimizer states, and scheduler states loaded.')
|
||||
|
||||
return it
|
||||
|
||||
def load_network_in_memory(self, src_dict):
|
||||
self.XMem.module.load_weights(src_dict)
|
||||
print('Network weight loaded from memory.')
|
||||
|
||||
def load_network(self, path):
|
||||
# This method loads only the network weight and should be used to load a pretrained model
|
||||
map_location = 'cuda:%d' % self.local_rank
|
||||
src_dict = torch.load(path, map_location={'cuda:0': map_location})
|
||||
|
||||
self.load_network_in_memory(src_dict)
|
||||
print(f'Network weight loaded from {path}')
|
||||
|
||||
def train(self):
|
||||
self._is_train = True
|
||||
self._do_log = True
|
||||
self.integrator = self.train_integrator
|
||||
self.XMem.eval()
|
||||
return self
|
||||
|
||||
def val(self):
|
||||
self._is_train = False
|
||||
self._do_log = True
|
||||
self.XMem.eval()
|
||||
return self
|
||||
|
||||
def test(self):
|
||||
self._is_train = False
|
||||
self._do_log = False
|
||||
self.XMem.eval()
|
||||
return self
|
||||
|
||||
0
tracker/util/__init__.py
Normal file
0
tracker/util/__init__.py
Normal file
138
tracker/util/configuration.py
Normal file
138
tracker/util/configuration.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from argparse import ArgumentParser
|
||||
|
||||
|
||||
def none_or_default(x, default):
|
||||
return x if x is not None else default
|
||||
|
||||
class Configuration():
|
||||
def parse(self, unknown_arg_ok=False):
|
||||
parser = ArgumentParser()
|
||||
|
||||
# Enable torch.backends.cudnn.benchmark -- Faster in some cases, test in your own environment
|
||||
parser.add_argument('--benchmark', action='store_true')
|
||||
parser.add_argument('--no_amp', action='store_true')
|
||||
|
||||
# Save paths
|
||||
parser.add_argument('--save_path', default='/ssd1/gaomingqi/output/xmem-sam')
|
||||
|
||||
# Data parameters
|
||||
parser.add_argument('--static_root', help='Static training data root', default='/ssd1/gaomingqi/datasets/static')
|
||||
parser.add_argument('--bl_root', help='Blender training data root', default='../BL30K')
|
||||
parser.add_argument('--yv_root', help='YouTubeVOS data root', default='/ssd1/gaomingqi/datasets/youtube-vos/2018')
|
||||
parser.add_argument('--davis_root', help='DAVIS data root', default='/ssd1/gaomingqi/datasets/davis')
|
||||
parser.add_argument('--num_workers', help='Total number of dataloader workers across all GPUs processes', type=int, default=16)
|
||||
|
||||
parser.add_argument('--key_dim', default=64, type=int)
|
||||
parser.add_argument('--value_dim', default=512, type=int)
|
||||
parser.add_argument('--hidden_dim', default=64, help='Set to =0 to disable', type=int)
|
||||
|
||||
parser.add_argument('--deep_update_prob', default=0.2, type=float)
|
||||
|
||||
parser.add_argument('--stages', help='Training stage (0-static images, 1-Blender dataset, 2-DAVIS+YouTubeVOS)', default='02')
|
||||
|
||||
"""
|
||||
Stage-specific learning parameters
|
||||
Batch sizes are effective -- you don't have to scale them when you scale the number processes
|
||||
"""
|
||||
# Stage 0, static images
|
||||
parser.add_argument('--s0_batch_size', default=16, type=int)
|
||||
parser.add_argument('--s0_iterations', default=150000, type=int)
|
||||
parser.add_argument('--s0_finetune', default=0, type=int)
|
||||
parser.add_argument('--s0_steps', nargs="*", default=[], type=int)
|
||||
parser.add_argument('--s0_lr', help='Initial learning rate', default=1e-5, type=float)
|
||||
parser.add_argument('--s0_num_ref_frames', default=2, type=int)
|
||||
parser.add_argument('--s0_num_frames', default=3, type=int)
|
||||
parser.add_argument('--s0_start_warm', default=20000, type=int)
|
||||
parser.add_argument('--s0_end_warm', default=70000, type=int)
|
||||
|
||||
# Stage 1, BL30K
|
||||
parser.add_argument('--s1_batch_size', default=8, type=int)
|
||||
parser.add_argument('--s1_iterations', default=250000, type=int)
|
||||
# fine-tune means fewer augmentations to train the sensory memory
|
||||
parser.add_argument('--s1_finetune', default=0, type=int)
|
||||
parser.add_argument('--s1_steps', nargs="*", default=[200000], type=int)
|
||||
parser.add_argument('--s1_lr', help='Initial learning rate', default=1e-5, type=float)
|
||||
parser.add_argument('--s1_num_ref_frames', default=3, type=int)
|
||||
parser.add_argument('--s1_num_frames', default=8, type=int)
|
||||
parser.add_argument('--s1_start_warm', default=20000, type=int)
|
||||
parser.add_argument('--s1_end_warm', default=70000, type=int)
|
||||
|
||||
# Stage 2, DAVIS+YoutubeVOS, longer
|
||||
parser.add_argument('--s2_batch_size', default=8, type=int)
|
||||
parser.add_argument('--s2_iterations', default=150000, type=int)
|
||||
# fine-tune means fewer augmentations to train the sensory memory
|
||||
parser.add_argument('--s2_finetune', default=10000, type=int)
|
||||
parser.add_argument('--s2_steps', nargs="*", default=[120000], type=int)
|
||||
parser.add_argument('--s2_lr', help='Initial learning rate', default=1e-5, type=float)
|
||||
parser.add_argument('--s2_num_ref_frames', default=3, type=int)
|
||||
parser.add_argument('--s2_num_frames', default=8, type=int)
|
||||
parser.add_argument('--s2_start_warm', default=20000, type=int)
|
||||
parser.add_argument('--s2_end_warm', default=70000, type=int)
|
||||
|
||||
# Stage 3, DAVIS+YoutubeVOS, shorter
|
||||
parser.add_argument('--s3_batch_size', default=8, type=int)
|
||||
parser.add_argument('--s3_iterations', default=100000, type=int)
|
||||
# fine-tune means fewer augmentations to train the sensory memory
|
||||
parser.add_argument('--s3_finetune', default=10000, type=int)
|
||||
parser.add_argument('--s3_steps', nargs="*", default=[80000], type=int)
|
||||
parser.add_argument('--s3_lr', help='Initial learning rate', default=1e-5, type=float)
|
||||
parser.add_argument('--s3_num_ref_frames', default=3, type=int)
|
||||
parser.add_argument('--s3_num_frames', default=8, type=int)
|
||||
parser.add_argument('--s3_start_warm', default=20000, type=int)
|
||||
parser.add_argument('--s3_end_warm', default=70000, type=int)
|
||||
|
||||
parser.add_argument('--gamma', help='LR := LR*gamma at every decay step', default=0.1, type=float)
|
||||
parser.add_argument('--weight_decay', default=0.05, type=float)
|
||||
|
||||
# Loading
|
||||
parser.add_argument('--load_network', help='Path to pretrained network weight only')
|
||||
parser.add_argument('--load_checkpoint', help='Path to the checkpoint file, including network, optimizer and such')
|
||||
|
||||
# Logging information
|
||||
parser.add_argument('--log_text_interval', default=100, type=int)
|
||||
parser.add_argument('--log_image_interval', default=1000, type=int)
|
||||
parser.add_argument('--save_network_interval', default=25000, type=int)
|
||||
parser.add_argument('--save_checkpoint_interval', default=50000, type=int)
|
||||
parser.add_argument('--exp_id', help='Experiment UNIQUE id, use NULL to disable logging to tensorboard', default='NULL')
|
||||
parser.add_argument('--debug', help='Debug mode which logs information more often', action='store_true')
|
||||
|
||||
# # Multiprocessing parameters, not set by users
|
||||
# parser.add_argument('--local_rank', default=0, type=int, help='Local rank of this process')
|
||||
|
||||
if unknown_arg_ok:
|
||||
args, _ = parser.parse_known_args()
|
||||
self.args = vars(args)
|
||||
else:
|
||||
self.args = vars(parser.parse_args())
|
||||
|
||||
self.args['amp'] = not self.args['no_amp']
|
||||
|
||||
# check if the stages are valid
|
||||
stage_to_perform = list(self.args['stages'])
|
||||
for s in stage_to_perform:
|
||||
if s not in ['0', '1', '2', '3']:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_stage_parameters(self, stage):
|
||||
parameters = {
|
||||
'batch_size': self.args['s%s_batch_size'%stage],
|
||||
'iterations': self.args['s%s_iterations'%stage],
|
||||
'finetune': self.args['s%s_finetune'%stage],
|
||||
'steps': self.args['s%s_steps'%stage],
|
||||
'lr': self.args['s%s_lr'%stage],
|
||||
'num_ref_frames': self.args['s%s_num_ref_frames'%stage],
|
||||
'num_frames': self.args['s%s_num_frames'%stage],
|
||||
'start_warm': self.args['s%s_start_warm'%stage],
|
||||
'end_warm': self.args['s%s_end_warm'%stage],
|
||||
}
|
||||
|
||||
return parameters
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.args[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.args[key] = value
|
||||
|
||||
def __str__(self):
|
||||
return str(self.args)
|
||||
60
tracker/util/davis_subset.txt
Normal file
60
tracker/util/davis_subset.txt
Normal file
@@ -0,0 +1,60 @@
|
||||
bear
|
||||
bmx-bumps
|
||||
boat
|
||||
boxing-fisheye
|
||||
breakdance-flare
|
||||
bus
|
||||
car-turn
|
||||
cat-girl
|
||||
classic-car
|
||||
color-run
|
||||
crossing
|
||||
dance-jump
|
||||
dancing
|
||||
disc-jockey
|
||||
dog-agility
|
||||
dog-gooses
|
||||
dogs-scale
|
||||
drift-turn
|
||||
drone
|
||||
elephant
|
||||
flamingo
|
||||
hike
|
||||
hockey
|
||||
horsejump-low
|
||||
kid-football
|
||||
kite-walk
|
||||
koala
|
||||
lady-running
|
||||
lindy-hop
|
||||
longboard
|
||||
lucia
|
||||
mallard-fly
|
||||
mallard-water
|
||||
miami-surf
|
||||
motocross-bumps
|
||||
motorbike
|
||||
night-race
|
||||
paragliding
|
||||
planes-water
|
||||
rallye
|
||||
rhino
|
||||
rollerblade
|
||||
schoolgirls
|
||||
scooter-board
|
||||
scooter-gray
|
||||
sheep
|
||||
skate-park
|
||||
snowboard
|
||||
soccerball
|
||||
stroller
|
||||
stunt
|
||||
surf
|
||||
swing
|
||||
tennis
|
||||
tractor-sand
|
||||
train
|
||||
tuk-tuk
|
||||
upside-down
|
||||
varanus-cage
|
||||
walking
|
||||
136
tracker/util/image_saver.py
Normal file
136
tracker/util/image_saver.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from dataset.range_transform import inv_im_trans
|
||||
from collections import defaultdict
|
||||
|
||||
def tensor_to_numpy(image):
|
||||
image_np = (image.numpy() * 255).astype('uint8')
|
||||
return image_np
|
||||
|
||||
def tensor_to_np_float(image):
|
||||
image_np = image.numpy().astype('float32')
|
||||
return image_np
|
||||
|
||||
def detach_to_cpu(x):
|
||||
return x.detach().cpu()
|
||||
|
||||
def transpose_np(x):
|
||||
return np.transpose(x, [1,2,0])
|
||||
|
||||
def tensor_to_gray_im(x):
|
||||
x = detach_to_cpu(x)
|
||||
x = tensor_to_numpy(x)
|
||||
x = transpose_np(x)
|
||||
return x
|
||||
|
||||
def tensor_to_im(x):
|
||||
x = detach_to_cpu(x)
|
||||
x = inv_im_trans(x).clamp(0, 1)
|
||||
x = tensor_to_numpy(x)
|
||||
x = transpose_np(x)
|
||||
return x
|
||||
|
||||
# Predefined key <-> caption dict
|
||||
key_captions = {
|
||||
'im': 'Image',
|
||||
'gt': 'GT',
|
||||
}
|
||||
|
||||
"""
|
||||
Return an image array with captions
|
||||
keys in dictionary will be used as caption if not provided
|
||||
values should contain lists of cv2 images
|
||||
"""
|
||||
def get_image_array(images, grid_shape, captions={}):
|
||||
h, w = grid_shape
|
||||
cate_counts = len(images)
|
||||
rows_counts = len(next(iter(images.values())))
|
||||
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
|
||||
output_image = np.zeros([w*cate_counts, h*(rows_counts+1), 3], dtype=np.uint8)
|
||||
col_cnt = 0
|
||||
for k, v in images.items():
|
||||
|
||||
# Default as key value itself
|
||||
caption = captions.get(k, k)
|
||||
|
||||
# Handles new line character
|
||||
dy = 40
|
||||
for i, line in enumerate(caption.split('\n')):
|
||||
cv2.putText(output_image, line, (10, col_cnt*w+100+i*dy),
|
||||
font, 0.8, (255,255,255), 2, cv2.LINE_AA)
|
||||
|
||||
# Put images
|
||||
for row_cnt, img in enumerate(v):
|
||||
im_shape = img.shape
|
||||
if len(im_shape) == 2:
|
||||
img = img[..., np.newaxis]
|
||||
|
||||
img = (img * 255).astype('uint8')
|
||||
|
||||
output_image[(col_cnt+0)*w:(col_cnt+1)*w,
|
||||
(row_cnt+1)*h:(row_cnt+2)*h, :] = img
|
||||
|
||||
col_cnt += 1
|
||||
|
||||
return output_image
|
||||
|
||||
def base_transform(im, size):
|
||||
im = tensor_to_np_float(im)
|
||||
if len(im.shape) == 3:
|
||||
im = im.transpose((1, 2, 0))
|
||||
else:
|
||||
im = im[:, :, None]
|
||||
|
||||
# Resize
|
||||
if im.shape[1] != size:
|
||||
im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
return im.clip(0, 1)
|
||||
|
||||
def im_transform(im, size):
|
||||
return base_transform(inv_im_trans(detach_to_cpu(im)), size=size)
|
||||
|
||||
def mask_transform(mask, size):
|
||||
return base_transform(detach_to_cpu(mask), size=size)
|
||||
|
||||
def out_transform(mask, size):
|
||||
return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size)
|
||||
|
||||
def pool_pairs(images, size, num_objects):
|
||||
req_images = defaultdict(list)
|
||||
|
||||
b, t = images['rgb'].shape[:2]
|
||||
|
||||
# limit the number of images saved
|
||||
b = min(2, b)
|
||||
|
||||
# find max num objects
|
||||
max_num_objects = max(num_objects[:b])
|
||||
|
||||
GT_suffix = ''
|
||||
for bi in range(b):
|
||||
GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4]
|
||||
|
||||
for bi in range(b):
|
||||
for ti in range(t):
|
||||
req_images['RGB'].append(im_transform(images['rgb'][bi,ti], size))
|
||||
for oi in range(max_num_objects):
|
||||
if ti == 0 or oi >= num_objects[bi]:
|
||||
req_images['Mask_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
|
||||
# req_images['Mask_X8_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
|
||||
# req_images['Mask_X16_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
|
||||
else:
|
||||
req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi], size))
|
||||
# req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][2], size))
|
||||
# req_images['Mask_X8_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][1], size))
|
||||
# req_images['Mask_X16_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][0], size))
|
||||
req_images['GT_%d_%s'%(oi, GT_suffix)].append(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size))
|
||||
# print((images['cls_gt'][bi,ti,0]==(oi+1)).shape)
|
||||
# print(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size).shape)
|
||||
|
||||
|
||||
return get_image_array(req_images, size, key_captions)
|
||||
16
tracker/util/load_subset.py
Normal file
16
tracker/util/load_subset.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
load_subset.py - Presents a subset of data
|
||||
DAVIS - only the training set
|
||||
YouTubeVOS - I manually filtered some erroneous ones out but I haven't checked all
|
||||
"""
|
||||
|
||||
|
||||
def load_sub_davis(path='util/davis_subset.txt'):
|
||||
with open(path, mode='r') as f:
|
||||
subset = set(f.read().splitlines())
|
||||
return subset
|
||||
|
||||
def load_sub_yv(path='util/yv_subset.txt'):
|
||||
with open(path, mode='r') as f:
|
||||
subset = set(f.read().splitlines())
|
||||
return subset
|
||||
80
tracker/util/log_integrator.py
Normal file
80
tracker/util/log_integrator.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Integrate numerical values for some iterations
|
||||
Typically used for loss computation / logging to tensorboard
|
||||
Call finalize and create a new Integrator when you want to display/log
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Integrator:
|
||||
def __init__(self, logger, distributed=True, local_rank=0, world_size=1):
|
||||
self.values = {}
|
||||
self.counts = {}
|
||||
self.hooks = [] # List is used here to maintain insertion order
|
||||
|
||||
self.logger = logger
|
||||
|
||||
self.distributed = distributed
|
||||
self.local_rank = local_rank
|
||||
self.world_size = world_size
|
||||
|
||||
def add_tensor(self, key, tensor):
|
||||
if key not in self.values:
|
||||
self.counts[key] = 1
|
||||
if type(tensor) == float or type(tensor) == int:
|
||||
self.values[key] = tensor
|
||||
else:
|
||||
self.values[key] = tensor.mean().item()
|
||||
else:
|
||||
self.counts[key] += 1
|
||||
if type(tensor) == float or type(tensor) == int:
|
||||
self.values[key] += tensor
|
||||
else:
|
||||
self.values[key] += tensor.mean().item()
|
||||
|
||||
def add_dict(self, tensor_dict):
|
||||
for k, v in tensor_dict.items():
|
||||
self.add_tensor(k, v)
|
||||
|
||||
def add_hook(self, hook):
|
||||
"""
|
||||
Adds a custom hook, i.e. compute new metrics using values in the dict
|
||||
The hook takes the dict as argument, and returns a (k, v) tuple
|
||||
e.g. for computing IoU
|
||||
"""
|
||||
if type(hook) == list:
|
||||
self.hooks.extend(hook)
|
||||
else:
|
||||
self.hooks.append(hook)
|
||||
|
||||
def reset_except_hooks(self):
|
||||
self.values = {}
|
||||
self.counts = {}
|
||||
|
||||
# Average and output the metrics
|
||||
def finalize(self, prefix, it, f=None):
|
||||
|
||||
for hook in self.hooks:
|
||||
k, v = hook(self.values)
|
||||
self.add_tensor(k, v)
|
||||
|
||||
for k, v in self.values.items():
|
||||
|
||||
if k[:4] == 'hide':
|
||||
continue
|
||||
|
||||
avg = v / self.counts[k]
|
||||
|
||||
if self.distributed:
|
||||
# Inplace operation
|
||||
avg = torch.tensor(avg).cuda()
|
||||
torch.distributed.reduce(avg, dst=0)
|
||||
|
||||
if self.local_rank == 0:
|
||||
avg = (avg/self.world_size).cpu().item()
|
||||
self.logger.log_metrics(prefix, k, avg, it, f)
|
||||
else:
|
||||
# Simple does it
|
||||
self.logger.log_metrics(prefix, k, avg, it, f)
|
||||
|
||||
101
tracker/util/logger.py
Normal file
101
tracker/util/logger.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Dumps things to tensorboard and console
|
||||
"""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
||||
def tensor_to_numpy(image):
|
||||
image_np = (image.numpy() * 255).astype('uint8')
|
||||
return image_np
|
||||
|
||||
def detach_to_cpu(x):
|
||||
return x.detach().cpu()
|
||||
|
||||
def fix_width_trunc(x):
|
||||
return ('{:.9s}'.format('{:0.9f}'.format(x)))
|
||||
|
||||
class TensorboardLogger:
|
||||
def __init__(self, short_id, id, git_info):
|
||||
self.short_id = short_id
|
||||
if self.short_id == 'NULL':
|
||||
self.short_id = 'DEBUG'
|
||||
|
||||
if id is None:
|
||||
self.no_log = True
|
||||
warnings.warn('Logging has been disbaled.')
|
||||
else:
|
||||
self.no_log = False
|
||||
|
||||
self.inv_im_trans = transforms.Normalize(
|
||||
mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
|
||||
std=[1/0.229, 1/0.224, 1/0.225])
|
||||
|
||||
self.inv_seg_trans = transforms.Normalize(
|
||||
mean=[-0.5/0.5],
|
||||
std=[1/0.5])
|
||||
|
||||
log_path = os.path.join('.', 'logs', '%s' % id)
|
||||
self.logger = SummaryWriter(log_path)
|
||||
|
||||
self.log_string('git', git_info)
|
||||
|
||||
def log_scalar(self, tag, x, step):
|
||||
if self.no_log:
|
||||
warnings.warn('Logging has been disabled.')
|
||||
return
|
||||
self.logger.add_scalar(tag, x, step)
|
||||
|
||||
def log_metrics(self, l1_tag, l2_tag, val, step, f=None):
|
||||
tag = l1_tag + '/' + l2_tag
|
||||
text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), l2_tag, fix_width_trunc(val))
|
||||
print(text)
|
||||
if f is not None:
|
||||
f.write(text + '\n')
|
||||
f.flush()
|
||||
self.log_scalar(tag, val, step)
|
||||
|
||||
def log_im(self, tag, x, step):
|
||||
if self.no_log:
|
||||
warnings.warn('Logging has been disabled.')
|
||||
return
|
||||
x = detach_to_cpu(x)
|
||||
x = self.inv_im_trans(x)
|
||||
x = tensor_to_numpy(x)
|
||||
self.logger.add_image(tag, x, step)
|
||||
|
||||
def log_cv2(self, tag, x, step):
|
||||
if self.no_log:
|
||||
warnings.warn('Logging has been disabled.')
|
||||
return
|
||||
x = x.transpose((2, 0, 1))
|
||||
self.logger.add_image(tag, x, step)
|
||||
|
||||
def log_seg(self, tag, x, step):
|
||||
if self.no_log:
|
||||
warnings.warn('Logging has been disabled.')
|
||||
return
|
||||
x = detach_to_cpu(x)
|
||||
x = self.inv_seg_trans(x)
|
||||
x = tensor_to_numpy(x)
|
||||
self.logger.add_image(tag, x, step)
|
||||
|
||||
def log_gray(self, tag, x, step):
|
||||
if self.no_log:
|
||||
warnings.warn('Logging has been disabled.')
|
||||
return
|
||||
x = detach_to_cpu(x)
|
||||
x = tensor_to_numpy(x)
|
||||
self.logger.add_image(tag, x, step)
|
||||
|
||||
def log_string(self, tag, x):
|
||||
print(tag, x)
|
||||
if self.no_log:
|
||||
warnings.warn('Logging has been disabled.')
|
||||
return
|
||||
self.logger.add_text(tag, x)
|
||||
|
||||
3
tracker/util/palette.py
Normal file
3
tracker/util/palette.py
Normal file
@@ -0,0 +1,3 @@
|
||||
davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0'
|
||||
|
||||
youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f'
|
||||
47
tracker/util/tensor_util.py
Normal file
47
tracker/util/tensor_util.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def compute_tensor_iu(seg, gt):
|
||||
intersection = (seg & gt).float().sum()
|
||||
union = (seg | gt).float().sum()
|
||||
|
||||
return intersection, union
|
||||
|
||||
def compute_tensor_iou(seg, gt):
|
||||
intersection, union = compute_tensor_iu(seg, gt)
|
||||
iou = (intersection + 1e-6) / (union + 1e-6)
|
||||
|
||||
return iou
|
||||
|
||||
# STM
|
||||
def pad_divide_by(in_img, d):
|
||||
h, w = in_img.shape[-2:]
|
||||
|
||||
if h % d > 0:
|
||||
new_h = h + d - h % d
|
||||
else:
|
||||
new_h = h
|
||||
if w % d > 0:
|
||||
new_w = w + d - w % d
|
||||
else:
|
||||
new_w = w
|
||||
lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
|
||||
lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
|
||||
pad_array = (int(lw), int(uw), int(lh), int(uh))
|
||||
out = F.pad(in_img, pad_array)
|
||||
return out, pad_array
|
||||
|
||||
def unpad(img, pad):
|
||||
if len(img.shape) == 4:
|
||||
if pad[2]+pad[3] > 0:
|
||||
img = img[:,:,pad[2]:-pad[3],:]
|
||||
if pad[0]+pad[1] > 0:
|
||||
img = img[:,:,:,pad[0]:-pad[1]]
|
||||
elif len(img.shape) == 3:
|
||||
if pad[2]+pad[3] > 0:
|
||||
img = img[:,pad[2]:-pad[3],:]
|
||||
if pad[0]+pad[1] > 0:
|
||||
img = img[:,:,pad[0]:-pad[1]]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return img
|
||||
3464
tracker/util/yv_subset.txt
Normal file
3464
tracker/util/yv_subset.txt
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user