remove redundant code

This commit is contained in:
gaomingqi
2023-04-12 13:21:43 +08:00
parent 9f30e59c45
commit caf539d3ca
150 changed files with 25 additions and 8484 deletions

View File

View File

@@ -0,0 +1,12 @@
import torchvision.transforms as transforms
# RGB fill colour in the 0-255 range; numerically close to 255 * the
# normalization mean defined below.
im_mean = (124, 116, 104)

# Per-channel statistics shared by the forward and the inverse normalization.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Forward normalization: (x - mean) / std, channel by channel.
im_normalization = transforms.Normalize(mean=_norm_mean, std=_norm_std)

# Inverse normalization: normalizing with (-mean/std, 1/std) maps y back to y * std + mean.
inv_im_trans = transforms.Normalize(
    mean=[-m / s for m, s in zip(_norm_mean, _norm_std)],
    std=[1 / s for s in _norm_std],
)

View File

@@ -0,0 +1,6 @@
import torch
import random
def reseed(seed):
    """Seed Python's ``random`` module and PyTorch's default generator with ``seed``.

    NumPy's global generator is not touched by this function.
    """
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)

View File

@@ -0,0 +1,179 @@
import os
from os import path
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from PIL import Image
import numpy as np
from dataset.range_transform import im_normalization, im_mean
from dataset.tps import random_tps_warp
from dataset.reseed import reseed
class StaticTransformDataset(Dataset):
    """
    Generate pseudo VOS data by applying random transforms on static images.
    Every source image contributes a single object mask; __getitem__ can paste
    up to max_num_obj such samples on top of each other.

    Method 0 - FSS style (class/1.jpg class/1.png)
    Method 1 - Others style (XXX.jpg XXX.png)

    Each item is a dict with the keys 'rgb', 'first_frame_gt', 'cls_gt',
    'selector' and 'info' (see __getitem__). Frames are produced at 384x384.
    """
    def __init__(self, parameters, num_frames=3, max_num_obj=1):
        """
        parameters  - iterable of (root, method, multiplier) triplets; the images
                      found under root are repeated `multiplier` times in the index
        num_frames  - number of pseudo frames generated per sample
        max_num_obj - upper bound on the number of samples merged into one item
        """
        self.num_frames = num_frames
        self.max_num_obj = max_num_obj
        self.im_list = []
        for parameter in parameters:
            root, method, multiplier = parameter
            if method == 0:
                # Get images
                classes = os.listdir(root)
                for c in classes:
                    imgs = os.listdir(path.join(root, c))
                    # keep names whose last three characters are 'jpg' (case-insensitive)
                    jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()]
                    joint_list = [path.join(root, c, im) for im in jpg_list]
                    self.im_list.extend(joint_list * multiplier)
            elif method == 1:
                # NOTE(review): substring match -- any name containing '.jpg' is accepted
                self.im_list.extend([path.join(root, im) for im in os.listdir(root) if '.jpg' in im] * multiplier)
        print(f'{len(self.im_list)} images found.')
        # This set of transforms is applied identically to an im/gt pair (same seed),
        # but differs between the sampled frames
        self.pair_im_lone_transform = transforms.Compose([
            transforms.ColorJitter(0.1, 0.05, 0.05, 0), # No hue change here as that's not realistic
        ])
        self.pair_im_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=im_mean),
            transforms.Resize(384, InterpolationMode.BICUBIC),
            transforms.RandomCrop((384, 384), pad_if_needed=True, fill=im_mean),
        ])
        # Mask counterpart of the transform above: same geometry, background (0) as fill value
        self.pair_gt_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=0),
            transforms.Resize(384, InterpolationMode.NEAREST),
            transforms.RandomCrop((384, 384), pad_if_needed=True, fill=0),
        ])
        # These transforms are the same for all pairs in the sampled sequence
        self.all_im_lone_transform = transforms.Compose([
            transforms.ColorJitter(0.1, 0.05, 0.05, 0.05),
            transforms.RandomGrayscale(0.05),
        ])
        self.all_im_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=im_mean),
            transforms.RandomHorizontalFlip(),
        ])
        self.all_gt_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=0),
            transforms.RandomHorizontalFlip(),
        ])
        # Final transform without randomness
        self.final_im_transform = transforms.Compose([
            transforms.ToTensor(),
            im_normalization,
        ])
        self.final_gt_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
    def _get_sample(self, idx):
        """
        Build num_frames randomly transformed views of image idx and its mask.

        The mask path is the image path with its last three characters replaced by 'png'.
        Returns (images, masks): images is a stacked torch tensor, masks a numpy array;
        both have the frame index as their first dimension.
        """
        im = Image.open(self.im_list[idx]).convert('RGB')
        gt = Image.open(self.im_list[idx][:-3]+'png').convert('L')
        # One seed for the whole sequence: the "all_*" transforms are re-run from this
        # seed for every frame. reseed() does not touch NumPy's generator.
        sequence_seed = np.random.randint(2147483647)
        images = []
        masks = []
        for _ in range(self.num_frames):
            # Seed before the image and again before the mask so both see the same random draws
            reseed(sequence_seed)
            this_im = self.all_im_dual_transform(im)
            this_im = self.all_im_lone_transform(this_im)
            reseed(sequence_seed)
            this_gt = self.all_gt_dual_transform(gt)
            # A fresh seed per frame makes the "pair_*" transforms differ between frames
            pairwise_seed = np.random.randint(2147483647)
            reseed(pairwise_seed)
            this_im = self.pair_im_dual_transform(this_im)
            this_im = self.pair_im_lone_transform(this_im)
            reseed(pairwise_seed)
            this_gt = self.pair_gt_dual_transform(this_gt)
            # Use TPS only some of the times
            # Not because TPS is bad -- just that it is too slow and I need to speed up data loading
            if np.random.rand() < 0.33:
                this_im, this_gt = random_tps_warp(this_im, this_gt, scale=0.02)
            this_im = self.final_im_transform(this_im)
            this_gt = self.final_gt_transform(this_gt)
            images.append(this_im)
            masks.append(this_gt)
        images = torch.stack(images, 0)
        masks = torch.stack(masks, 0)
        return images, masks.numpy()
    def __getitem__(self, idx):
        """
        Return one training item built from image idx plus 0..max_num_obj-1 randomly
        chosen extra images, each pasted over the previous ones with its own label.
        """
        # randint(high) draws from [0, high), so at most max_num_obj-1 extra samples are added
        additional_objects = np.random.randint(self.max_num_obj)
        indices = [idx, *np.random.randint(self.__len__(), size=additional_objects)]
        merged_images = None
        merged_masks = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
        for i, list_id in enumerate(indices):
            images, masks = self._get_sample(list_id)
            if merged_images is None:
                merged_images = images
            else:
                # composite the new sample over the running result wherever its mask is set
                merged_images = merged_images*(1-masks) + images*masks
            # later samples overwrite earlier labels where they overlap
            merged_masks[masks[:,0]>0.5] = (i+1)
        masks = merged_masks
        # Only objects visible in the first frame become targets
        labels = np.unique(masks[0])
        # Remove background
        labels = labels[labels!=0]
        target_objects = labels.tolist()
        # Generate one-hot ground-truth
        cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
        first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
        for i, l in enumerate(target_objects):
            this_mask = (masks==l)
            # relabel the targets to consecutive ids 1..len(target_objects)
            cls_gt[this_mask] = i+1
            first_frame_gt[0,i] = (this_mask[0])
        cls_gt = np.expand_dims(cls_gt, 1)
        info = {}
        info['name'] = self.im_list[idx]
        # reported as at least 1 even when no object survived the transforms
        info['num_objects'] = max(1, len(target_objects))
        # 1 if object exist, 0 otherwise
        selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
        selector = torch.FloatTensor(selector)
        data = {
            'rgb': merged_images,
            'first_frame_gt': first_frame_gt,
            'cls_gt': cls_gt,
            'selector': selector,
            'info': info
        }
        return data
    def __len__(self):
        """Number of indexed images (after applying the per-root multipliers)."""
        return len(self.im_list)

37
tracker/dataset/tps.py Normal file
View File

@@ -0,0 +1,37 @@
import numpy as np
from PIL import Image
import cv2
import thinplate as tps
# Turn off OpenCV's internal multi-threading for this process.
# NOTE(review): presumably to avoid thread oversubscription inside DataLoader workers -- confirm.
cv2.setNumThreads(0)
def pick_random_points(h, w, n_samples):
    """
    Draw n_samples distinct row indices from [0, h) and n_samples distinct column
    indices from [0, w) using NumPy's global generator, and return them divided
    by h and w respectively as the tuple (rows, cols).
    """
    def _distinct_indices(extent):
        # sampling without replacement: every index appears at most once
        return np.random.choice(np.arange(extent), size=n_samples, replace=False)

    # rows are drawn before columns; the order matters for reproducibility under a fixed seed
    rows = _distinct_indices(h)
    cols = _distinct_indices(w)
    return rows / h, cols / w
def warp_dual_cv(img, mask, c_src, c_dst):
    """
    Warp img and mask with the same thin-plate-spline mapping defined by the
    control points c_src -> c_dst.

    The image is resampled with linear interpolation and the mask with
    nearest-neighbour interpolation. Returns (warped_img, warped_mask).
    """
    out_shape = img.shape
    theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True)
    grid = tps.tps_grid(theta, c_dst, out_shape)
    mapx, mapy = tps.tps_grid_to_remap(grid, out_shape)
    # one remap table, two interpolation modes
    warped_img = cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR)
    warped_mask = cv2.remap(mask, mapx, mapy, cv2.INTER_NEAREST)
    return warped_img, warped_mask
def random_tps_warp(img, mask, scale, n_ctrl_pts=12):
    """
    Apply one random thin-plate-spline warp to both an image and its mask.

    img, mask  - array-convertible inputs; mask must be two-dimensional
    scale      - standard deviation of the Gaussian offsets added to the control points
    n_ctrl_pts - number of control points
    All randomness comes from NumPy's global generator.
    Returns the warped image and mask as PIL images.
    """
    img_arr = np.asarray(img)
    mask_arr = np.asarray(mask)
    height, width = mask_arr.shape

    # control points in normalized coordinates, one point per row
    c_src = np.stack(pick_random_points(height, width, n_ctrl_pts), 1)
    jitter = np.random.normal(scale=scale, size=c_src.shape)
    c_dst = c_src + jitter

    warp_im, warp_gt = warp_dual_cv(img_arr, mask_arr, c_src, c_dst)
    return Image.fromarray(warp_im), Image.fromarray(warp_gt)

13
tracker/dataset/util.py Normal file
View File

@@ -0,0 +1,13 @@
import numpy as np
def all_to_onehot(masks, labels):
    """
    Convert an index mask into a stack of binary masks, one per label.

    masks  - array-like of label indices with any number of dimensions
    labels - sequence of label values; output channel k marks where masks == labels[k]
    Returns a uint8 array of shape (len(labels), *masks.shape).
    """
    masks = np.asarray(masks)
    # Prepending the label axis to masks.shape handles every rank uniformly
    # instead of special-casing 2-D and 3-D inputs.
    onehot = np.zeros((len(labels), *masks.shape), dtype=np.uint8)
    for channel, label in enumerate(labels):
        onehot[channel] = (masks == label)
    return onehot

View File

@@ -0,0 +1,216 @@
import os
from os import path, replace
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from PIL import Image
import numpy as np
from dataset.range_transform import im_normalization, im_mean
from dataset.reseed import reseed
class VOSDataset(Dataset):
    """
    Works for DAVIS/YouTubeVOS/BL30K training
    For each sequence:
    - Pick num_frames frames (three by default)
    - Pick at most max_num_obj objects that are visible in the first picked frame
    - Apply some random transforms that are the same for all frames
    - Apply random transform to each of the frame
    - The distance between frames is controlled
    """
    def __init__(self, im_root, gt_root, max_jump, is_bl, subset=None, num_frames=3, max_num_obj=3, finetune=False):
        """
        im_root     - directory with one sub-directory of frames per video
        gt_root     - directory with one sub-directory of png masks per video
        max_jump    - maximum index distance allowed when growing the sampled frame set
        is_bl       - selects the crop scale, disables the per-frame affine jitter and
                      enables the small-object filter in __getitem__
        subset      - optional container of video names; other videos are ignored
        num_frames  - number of frames sampled per item
        max_num_obj - maximum number of target objects per item
        finetune    - when true, the per-frame affine jitter is disabled
        """
        self.im_root = im_root
        self.gt_root = gt_root
        self.max_jump = max_jump
        self.is_bl = is_bl
        self.num_frames = num_frames
        self.max_num_obj = max_num_obj
        self.videos = []
        self.frames = {}
        vid_list = sorted(os.listdir(self.im_root))
        # Pre-filtering
        for vid in vid_list:
            if subset is not None:
                if vid not in subset:
                    continue
            frames = sorted(os.listdir(os.path.join(self.im_root, vid)))
            # videos too short to sample num_frames distinct frames are dropped
            if len(frames) < num_frames:
                continue
            self.frames[vid] = frames
            self.videos.append(vid)
        print('%d out of %d videos accepted in %s.' % (len(self.videos), len(vid_list), im_root))
        # This set of transforms is applied identically to an im/gt pair (same seed),
        # but differs between the sampled frames
        self.pair_im_lone_transform = transforms.Compose([
            transforms.ColorJitter(0.01, 0.01, 0.01, 0),
        ])
        self.pair_im_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.BILINEAR, fill=im_mean),
        ])
        self.pair_gt_dual_transform = transforms.Compose([
            transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.NEAREST, fill=0),
        ])
        # These transforms are the same for all pairs in the sampled sequence
        self.all_im_lone_transform = transforms.Compose([
            transforms.ColorJitter(0.1, 0.03, 0.03, 0),
            transforms.RandomGrayscale(0.05),
        ])
        if self.is_bl:
            # Use a different cropping scheme for the blender dataset because the image size is different
            self.all_im_dual_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.BILINEAR)
            ])
            self.all_gt_dual_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.NEAREST)
            ])
        else:
            self.all_im_dual_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.BILINEAR)
            ])
            self.all_gt_dual_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.NEAREST)
            ])
        # Final transform without randomness
        self.final_im_transform = transforms.Compose([
            transforms.ToTensor(),
            im_normalization,
        ])
    def __getitem__(self, idx):
        """
        Sample num_frames frames from video idx and return a dict with the keys
        'rgb', 'first_frame_gt', 'cls_gt', 'selector' and 'info'.

        Sampling is retried up to five times until the first sampled frame contains
        at least one usable object; if every trial fails the item has no targets.
        """
        video = self.videos[idx]
        info = {}
        info['name'] = video
        vid_im_path = path.join(self.im_root, video)
        vid_gt_path = path.join(self.gt_root, video)
        frames = self.frames[video]
        trials = 0
        while trials < 5:
            info['frames'] = [] # Appended with actual frames
            num_frames = self.num_frames
            length = len(frames)
            this_max_jump = min(len(frames), self.max_jump)
            # iterative sampling: start from a random frame, then repeatedly pick a new frame
            # that lies within this_max_jump of any frame picked so far
            frames_idx = [np.random.randint(length)]
            acceptable_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1))).difference(set(frames_idx))
            while(len(frames_idx) < num_frames):
                # NOTE(review): this rebinds the `idx` argument; it is not read again afterwards
                idx = np.random.choice(list(acceptable_set))
                frames_idx.append(idx)
                new_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1)))
                acceptable_set = acceptable_set.union(new_set).difference(set(frames_idx))
            frames_idx = sorted(frames_idx)
            if np.random.rand() < 0.5:
                # Reverse time
                frames_idx = frames_idx[::-1]
            # One seed for the whole sequence; reseed() does not touch NumPy's generator
            sequence_seed = np.random.randint(2147483647)
            images = []
            masks = []
            target_objects = []
            for f_idx in frames_idx:
                # the listed file name minus its last four characters is used as the stem
                jpg_name = frames[f_idx][:-4] + '.jpg'
                png_name = frames[f_idx][:-4] + '.png'
                info['frames'].append(jpg_name)
                # Seed before the image and again before the mask so both get the same crop/flip
                reseed(sequence_seed)
                this_im = Image.open(path.join(vid_im_path, jpg_name)).convert('RGB')
                this_im = self.all_im_dual_transform(this_im)
                this_im = self.all_im_lone_transform(this_im)
                reseed(sequence_seed)
                this_gt = Image.open(path.join(vid_gt_path, png_name)).convert('P')
                this_gt = self.all_gt_dual_transform(this_gt)
                # A fresh seed per frame makes the "pair_*" transforms differ between frames
                pairwise_seed = np.random.randint(2147483647)
                reseed(pairwise_seed)
                this_im = self.pair_im_dual_transform(this_im)
                this_im = self.pair_im_lone_transform(this_im)
                reseed(pairwise_seed)
                this_gt = self.pair_gt_dual_transform(this_gt)
                this_im = self.final_im_transform(this_im)
                this_gt = np.array(this_gt)
                images.append(this_im)
                masks.append(this_gt)
            images = torch.stack(images, 0)
            # Only objects visible in the first sampled frame can become targets
            labels = np.unique(masks[0])
            # Remove background
            labels = labels[labels!=0]
            if self.is_bl:
                # Find large enough labels
                good_lables = []
                for l in labels:
                    pixel_sum = (masks[0]==l).sum()
                    if pixel_sum > 10*10:
                        # OK if the object is always this small
                        # Not OK if it is actually much bigger
                        if pixel_sum > 30*30:
                            good_lables.append(l)
                        # NOTE(review): only frames 1 and 2 are inspected, whatever num_frames is
                        elif max((masks[1]==l).sum(), (masks[2]==l).sum()) < 20*20:
                            good_lables.append(l)
                labels = np.array(good_lables, dtype=np.uint8)
            if len(labels) == 0:
                # nothing usable in this sample: retry with a new set of frames
                target_objects = []
                trials += 1
            else:
                target_objects = labels.tolist()
                break
        if len(target_objects) > self.max_num_obj:
            # too many objects: keep a random subset
            target_objects = np.random.choice(target_objects, size=self.max_num_obj, replace=False)
        # reported as at least 1 even when no object was found
        info['num_objects'] = max(1, len(target_objects))
        masks = np.stack(masks, 0)
        # Generate one-hot ground-truth
        cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
        first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
        for i, l in enumerate(target_objects):
            this_mask = (masks==l)
            # relabel the targets to consecutive ids 1..len(target_objects)
            cls_gt[this_mask] = i+1
            first_frame_gt[0,i] = (this_mask[0])
        cls_gt = np.expand_dims(cls_gt, 1)
        # 1 if object exist, 0 otherwise
        selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
        selector = torch.FloatTensor(selector)
        data = {
            'rgb': images,
            'first_frame_gt': first_frame_gt,
            'cls_gt': cls_gt,
            'selector': selector,
            'info': info,
        }
        return data
    def __len__(self):
        """Number of videos that passed the pre-filtering in __init__."""
        return len(self.videos)

260
tracker/eval.py Normal file
View File

@@ -0,0 +1,260 @@
import os
from os import path
from argparse import ArgumentParser
import shutil
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image
from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset
from inference.data.mask_mapper import MaskMapper
from model.network import XMem
from inference.inference_core import InferenceCore
from progressbar import progressbar
# hickle is optional: in this script it is only used to dump score files when
# --save_scores is given, so a missing installation is reported but tolerated.
try:
    import hickle as hkl
except ImportError:
    print('Failed to import hickle. Fine if not using multi-scale testing.')
"""
Arguments loading
"""
parser = ArgumentParser()
# Path of the checkpoint to evaluate
parser.add_argument('--model', default='/ssd1/gaomingqi/checkpoints/XMem-s012.pth')
# Data options
parser.add_argument('--d16_path', default='../DAVIS/2016')
parser.add_argument('--d17_path', default='../DAVIS/2017')
parser.add_argument('--y18_path', default='/ssd1/gaomingqi/datasets/youtube-vos/2018')
parser.add_argument('--y19_path', default='../YouTube')
parser.add_argument('--lv_path', default='../long_video_set')
# For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations"
parser.add_argument('--generic_path')
parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='D17')
parser.add_argument('--split', help='val/test', default='val')
parser.add_argument('--output', default=None)
parser.add_argument('--save_all', action='store_true',
            help='Save all frames. Useful only in YouTubeVOS/long-time video', )
parser.add_argument('--benchmark', action='store_true', help='enable to disable amp for FPS benchmarking')
# Long-term memory options
parser.add_argument('--disable_long_term', action='store_true')
parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
                                                type=int, default=10000)
parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)
parser.add_argument('--top_k', type=int, default=30)
parser.add_argument('--mem_every', help='r in paper. Increase to improve running speed.', type=int, default=5)
parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
# Multi-scale options
parser.add_argument('--save_scores', action='store_true')
parser.add_argument('--flip', action='store_true')
parser.add_argument('--size', default=480, type=int,
            help='Resize the shorter side to this size. -1 to use original resolution. ')
args = parser.parse_args()
# vars() returns the namespace's own dict, so `config` and `args` share their values
config = vars(args)
# the rest of the pipeline reads the positive form of the flag
config['enable_long_term'] = not config['disable_long_term']
if args.output is None:
    args.output = f'../output/{args.dataset}_{args.split}'
    print(f'Output path not provided. Defaulting to {args.output}')
"""
Data preparation
"""
# The dataset family is derived from the prefix of --dataset
is_youtube = args.dataset.startswith('Y')
is_davis = args.dataset.startswith('D')
is_lv = args.dataset.startswith('LV')

# YouTubeVOS submissions and score dumps keep the masks in an 'Annotations' sub-folder
if is_youtube or args.save_scores:
    out_path = path.join(args.output, 'Annotations')
else:
    out_path = args.output

if is_youtube:
    if args.dataset == 'Y18':
        yv_path = args.y18_path
    elif args.dataset == 'Y19':
        yv_path = args.y19_path
    else:
        # Previously an unknown 'Y…' code fell through and failed later with a
        # NameError on yv_path; fail like every other unsupported option instead.
        raise NotImplementedError(f'Unknown YouTubeVOS dataset: {args.dataset}')

    if args.split == 'val':
        # the YouTubeVOS folders call this split 'valid'
        args.split = 'valid'
        meta_dataset = YouTubeVOSTestDataset(data_root=yv_path, split='valid', size=args.size)
    elif args.split == 'test':
        meta_dataset = YouTubeVOSTestDataset(data_root=yv_path, split='test', size=args.size)
    else:
        raise NotImplementedError
elif is_davis:
    if args.dataset == 'D16':
        if args.split == 'val':
            # Set up Dataset, a small hack to use the image set in the 2017 folder because the 2016 one is of a different format
            meta_dataset = DAVISTestDataset(args.d16_path, imset='../../2017/trainval/ImageSets/2016/val.txt', size=args.size)
        else:
            raise NotImplementedError
        palette = None
    elif args.dataset == 'D17':
        if args.split == 'val':
            meta_dataset = DAVISTestDataset(path.join(args.d17_path, 'trainval'), imset='2017/val.txt', size=args.size)
        elif args.split == 'test':
            meta_dataset = DAVISTestDataset(path.join(args.d17_path, 'test-dev'), imset='2017/test-dev.txt', size=args.size)
        else:
            raise NotImplementedError
    else:
        # Previously an unknown 'D…' code left meta_dataset undefined (NameError later on)
        raise NotImplementedError(f'Unknown DAVIS dataset: {args.dataset}')
elif is_lv:
    if args.dataset == 'LV1':
        meta_dataset = LongTestDataset(path.join(args.lv_path, 'long_video'))
    elif args.dataset == 'LV3':
        meta_dataset = LongTestDataset(path.join(args.lv_path, 'long_video_x3'))
    else:
        raise NotImplementedError
elif args.dataset == 'G':
    meta_dataset = LongTestDataset(path.join(args.generic_path), size=args.size)
    if not args.save_all:
        args.save_all = True
        print('save_all is forced to be true in generic evaluation mode.')
else:
    raise NotImplementedError
# Inference only: switch off gradient tracking globally
torch.autograd.set_grad_enabled(False)
# Set up loader
# get_datasets() yields one VideoReader per video
meta_loader = meta_dataset.get_datasets()
# Load our checkpoint
network = XMem(config, args.model).cuda().eval()
if args.model is not None:
    model_weights = torch.load(args.model)
    network.load_weights(model_weights, init_as_zero_if_needed=True)
else:
    print('No model loaded.')
# Accumulators for the timing summary printed after the evaluation loop
total_process_time = 0
total_frames = 0
# Start eval
# Outer loop: one VideoReader per video. Inner loop: one frame at a time.
for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
    loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2)
    vid_name = vid_reader.vid_name
    vid_length = len(loader)
    # no need to count usage for LT if the video is not that long anyway
    config['enable_long_term_count_usage'] = (
        config['enable_long_term'] and
        (vid_length
            / (config['max_mid_term_frames']-config['min_mid_term_frames'])
            * config['num_prototypes'])
        >= config['max_long_term_elements']
    )
    # Fresh label mapper and inference state for every video
    mapper = MaskMapper()
    processor = InferenceCore(network, config=config)
    first_mask_loaded = False
    for ti, data in enumerate(loader):
        with torch.cuda.amp.autocast(enabled=not args.benchmark):
            # batch_size is 1, so index 0 removes the batch dimension added by the DataLoader
            rgb = data['rgb'].cuda()[0]
            # 'mask' is only present for frames where the reader loaded an annotation
            msk = data.get('mask')
            info = data['info']
            frame = info['frame'][0]
            shape = info['shape']
            need_resize = info['need_resize'][0]
            """
            For timing see https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964
            Seems to be very similar in testing as my previous timing method
            with two cuda sync + time.time() in STCN though
            """
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            if not first_mask_loaded:
                if msk is not None:
                    first_mask_loaded = True
                else:
                    # no point to do anything without a mask
                    # (frames skipped here are not counted in the timing statistics)
                    continue
            if args.flip:
                # test-time flip: mirror the inputs horizontally, the output is mirrored back below
                rgb = torch.flip(rgb, dims=[-1])
                msk = torch.flip(msk, dims=[-1]) if msk is not None else None
            # Map possibly non-continuous labels to continuous ones
            if msk is not None:
                msk, labels = mapper.convert_mask(msk[0].numpy())
                msk = torch.Tensor(msk).cuda()
                if need_resize:
                    # the image was resized by the reader's transform; bring the mask to the same size
                    msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
                processor.set_all_labels(list(mapper.remappings.values()))
            else:
                labels = None
            # Run the model on this frame
            prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1))
            # consider prob as prompt to refine segment results
            # Upsample to original size if needed
            if need_resize:
                prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0]
            end.record()
            # wait for the GPU so that elapsed_time() is valid
            torch.cuda.synchronize()
            # elapsed_time() is in milliseconds; accumulate seconds
            total_process_time += (start.elapsed_time(end)/1000)
            total_frames += 1
            if args.flip:
                prob = torch.flip(prob, dims=[-1])
            # Probability mask -> index mask
            out_mask = torch.argmax(prob, dim=0)
            out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
            if args.save_scores:
                # quantize the probabilities to uint8 for storage
                prob = (prob.detach().cpu().numpy()*255).astype(np.uint8)
            # Save the mask
            if args.save_all or info['save'][0]:
                this_out_path = path.join(out_path, vid_name)
                os.makedirs(this_out_path, exist_ok=True)
                # translate the consecutive internal ids back to the original label values
                out_mask = mapper.remap_index_mask(out_mask)
                out_img = Image.fromarray(out_mask)
                if vid_reader.get_palette() is not None:
                    out_img.putpalette(vid_reader.get_palette())
                out_img.save(os.path.join(this_out_path, frame[:-4]+'.png'))
            if args.save_scores:
                np_path = path.join(args.output, 'Scores', vid_name)
                os.makedirs(np_path, exist_ok=True)
                # the label mapping is written once, on the last frame of the video
                if ti==len(loader)-1:
                    hkl.dump(mapper.remappings, path.join(np_path, f'backward.hkl'), mode='w')
                if args.save_all or info['save'][0]:
                    hkl.dump(prob, path.join(np_path, f'{frame[:-4]}.hkl'), mode='w', compression='lzf')
# Timing / memory summary for the whole run
print(f'Total processing time: {total_process_time}')
print(f'Total processed frames: {total_frames}')
if total_process_time > 0:
    print(f'FPS: {total_frames / total_process_time}')
else:
    # No frame was timed (e.g. empty dataset or no video with a first mask);
    # avoid a ZeroDivisionError that would also skip the archive step below.
    print('FPS: n/a (no frames were processed)')
print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')

# Package the results for the evaluation servers (not done when dumping scores)
if not args.save_scores:
    if is_youtube:
        print('Making zip for YouTubeVOS...')
        # archive only the 'Annotations' folder, placed inside the output directory
        shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
    elif is_davis and args.split == 'test':
        print('Making zip for DAVIS test-dev...')
        shutil.make_archive(args.output, 'zip', args.output)
View File

View File

View File

@@ -0,0 +1,64 @@
import numpy as np
import torch
from dataset.util import all_to_onehot
class MaskMapper:
    """
    Converts an indexed mask into a one-hot representation and keeps track of
    how non-continuous label values map onto consecutive internal indices.

    Two modes are supported by convert_mask:
        1. Default. Only masks with new indices are supposed to go into the remapper.
            This is also the case for YouTubeVOS.
            i.e., regions with index 0 are not "background", but "don't care".
        2. Exhaustive. Regions with index 0 are considered "background".
            Every single pixel is considered to be "labeled".
    """
    def __init__(self):
        # original label values in the order they were first seen
        self.labels = []
        # original label value -> consecutive internal index (starting at 1)
        self.remappings = {}
        # stays True while every label equals its internal index, i.e. no remapping is needed
        self.coherent = True

    def convert_mask(self, mask, exhaustive=False):
        """
        Register the labels found in `mask` (index representation, H*W numpy array)
        and return (one_hot_mask, new_mapped_labels), where one_hot_mask is a float
        tensor with one channel per label known so far.
        """
        present = np.unique(mask).astype(np.uint8)
        present = present[present != 0].tolist()
        unseen = list(set(present) - set(self.labels))
        if not exhaustive:
            assert len(unseen) == len(present), 'Old labels found in non-exhaustive mode'

        # add new remappings
        num_known = len(self.labels)
        for offset, label in enumerate(unseen):
            mapped = offset + num_known + 1
            self.remappings[label] = mapped
            if self.coherent and mapped != label:
                self.coherent = False

        num_total = num_known + len(unseen)
        if exhaustive:
            new_mapped_labels = range(1, num_total + 1)
        elif self.coherent:
            new_mapped_labels = unseen
        else:
            new_mapped_labels = range(num_known + 1, num_total + 1)

        self.labels.extend(unseen)
        one_hot = all_to_onehot(mask, self.labels)
        mask = torch.from_numpy(one_hot).float()

        # mask num_objects*H*W
        return mask, new_mapped_labels

    def remap_index_mask(self, mask):
        """
        Translate internal indices in `mask` (index representation, H*W numpy array)
        back to the original label values. The input is returned untouched while
        the mapping is coherent.
        """
        if self.coherent:
            return mask

        restored = np.zeros_like(mask)
        for original_label, mapped_label in self.remappings.items():
            restored[mask == mapped_label] = original_label
        return restored

View File

@@ -0,0 +1,96 @@
import os
from os import path
import json
from inference.data.video_reader import VideoReader
class LongTestDataset:
    """
    Test-time dataset over a folder containing 'JPEGImages' and 'Annotations';
    every sub-directory of 'JPEGImages' is treated as one video.
    """
    def __init__(self, data_root, size=-1):
        self.size = size
        self.image_dir = path.join(data_root, 'JPEGImages')
        self.mask_dir = path.join(data_root, 'Annotations')
        self.vid_list = sorted(os.listdir(self.image_dir))

    def get_datasets(self):
        """Yield one VideoReader per video, in sorted name order."""
        for video in self.vid_list:
            video_image_dir = path.join(self.image_dir, video)
            video_mask_dir = path.join(self.mask_dir, video)
            # frames to save: every annotation file name minus its last four characters
            annotated = [name[:-4] for name in os.listdir(video_mask_dir)]
            yield VideoReader(
                video,
                video_image_dir,
                video_mask_dir,
                to_save=annotated,
                size=self.size,
            )

    def __len__(self):
        return len(self.vid_list)
class DAVISTestDataset:
    """
    Test-time dataset for a DAVIS-style folder layout.

    data_root - folder containing 'JPEGImages', 'Annotations' and 'ImageSets'
    imset     - image-set file, relative to data_root/ImageSets, with one video name per line
    size      - 480 selects the '480p' folders; any other value selects 'Full-Resolution',
                falling back to '1080p' when that folder does not exist
    """
    def __init__(self, data_root, imset='2017/val.txt', size=-1):
        if size != 480:
            self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution')
            self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution')
            if not path.exists(self.image_dir):
                print(f'{self.image_dir} not found. Look at other options.')
                self.image_dir = path.join(data_root, 'JPEGImages', '1080p')
                self.mask_dir = path.join(data_root, 'Annotations', '1080p')
            assert path.exists(self.image_dir), 'path not found'
        else:
            self.image_dir = path.join(data_root, 'JPEGImages', '480p')
            self.mask_dir = path.join(data_root, 'Annotations', '480p')
        # the output shape is always read from the 480p frames (passed on as size_dir)
        self.size_dir = path.join(data_root, 'JPEGImages', '480p')
        self.size = size

        with open(path.join(data_root, 'ImageSets', imset)) as f:
            stripped = (line.strip() for line in f)
            # Skip blank lines: an empty name would make path.join(image_dir, '')
            # point at the parent folder instead of a video.
            self.vid_list = sorted(name for name in stripped if name)

    def get_datasets(self):
        """Yield one VideoReader per video, in sorted name order."""
        for video in self.vid_list:
            yield VideoReader(video,
                path.join(self.image_dir, video),
                path.join(self.mask_dir, video),
                size=self.size,
                size_dir=path.join(self.size_dir, video),
            )

    def __len__(self):
        return len(self.vid_list)
class YouTubeVOSTestDataset:
    """
    Test-time dataset for a YouTubeVOS-style folder layout.

    Frames are read from data_root/all_frames/<split>_all_frames/JPEGImages, masks from
    data_root/<split>/Annotations, and the frames required for evaluation from
    data_root/<split>/meta.json.
    """
    def __init__(self, data_root, split, size=480):
        self.size = size
        self.image_dir = path.join(data_root, 'all_frames', split+'_all_frames', 'JPEGImages')
        self.mask_dir = path.join(data_root, split, 'Annotations')
        self.vid_list = sorted(os.listdir(self.image_dir))
        self.req_frame_list = {}

        # read meta.json to know which frame is required for evaluation
        with open(path.join(data_root, split, 'meta.json')) as f:
            meta = json.load(f)['videos']

        for vid in self.vid_list:
            # union of the frame names listed for every object of the video
            listed = [
                frame
                for obj in meta[vid]['objects'].values()
                for frame in obj['frames']
            ]
            self.req_frame_list[vid] = list(set(listed))

    def get_datasets(self):
        """Yield one VideoReader per video, in sorted name order."""
        for video in self.vid_list:
            yield VideoReader(
                video,
                path.join(self.image_dir, video),
                path.join(self.mask_dir, video),
                size=self.size,
                to_save=self.req_frame_list[video],
                use_all_mask=True
            )

    def __len__(self):
        return len(self.vid_list)

View File

@@ -0,0 +1,100 @@
import os
from os import path
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import torch.nn.functional as F
from PIL import Image
import numpy as np
from dataset.range_transform import im_normalization
class VideoReader(Dataset):
    """
    This class is used to read a video, one frame at a time
    """
    def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None):
        """
        image_dir - points to a directory of jpg images
        mask_dir - points to a directory of png masks
        size - resize min. side to size. Does nothing if <0.
        to_save - optionally contains a list of file names without extensions
            where the segmentation mask is required
        use_all_mask - when true, read all available mask in mask_dir.
            Default false. Set to true for YouTubeVOS validation.
        size_dir - optional directory holding same-named frames whose resolution is
            reported as the output shape; defaults to image_dir
        """
        self.vid_name = vid_name
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.to_save = to_save
        self.use_all_mask = use_all_mask
        self.size_dir = self.image_dir if size_dir is None else size_dir

        self.frames = sorted(os.listdir(self.image_dir))

        # List the mask directory once; the first mask (in sorted order) provides
        # both the colour palette and the path of the initial annotation.
        first_mask_name = sorted(os.listdir(self.mask_dir))[0]
        self.first_gt_path = path.join(self.mask_dir, first_mask_name)
        with Image.open(self.first_gt_path) as first_mask:
            self.palette = first_mask.getpalette()

        steps = [
            transforms.ToTensor(),
            im_normalization,
        ]
        if size >= 0:
            steps.append(transforms.Resize(size, interpolation=InterpolationMode.BILINEAR))
        self.im_transform = transforms.Compose(steps)
        self.size = size

    def __getitem__(self, idx):
        """
        Return a dict with 'rgb' (transformed frame), 'info' (frame name, save flag,
        output shape, need_resize flag) and, when an annotation is loaded, 'mask'.
        """
        frame = self.frames[idx]
        info = {}
        data = {}
        info['frame'] = frame
        info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)

        im_path = path.join(self.image_dir, frame)
        img = Image.open(im_path).convert('RGB')

        # The output shape comes from the frame itself, or from its counterpart in size_dir
        if self.image_dir == self.size_dir:
            shape = np.array(img).shape[:2]
        else:
            size_path = path.join(self.size_dir, frame)
            size_im = Image.open(size_path).convert('RGB')
            shape = np.array(size_im).shape[:2]

        gt_path = path.join(self.mask_dir, frame[:-4]+'.png')
        img = self.im_transform(img)

        # Without use_all_mask only the very first annotation is ever loaded
        load_mask = self.use_all_mask or (gt_path == self.first_gt_path)
        if load_mask and path.exists(gt_path):
            mask = Image.open(gt_path).convert('P')
            mask = np.array(mask, dtype=np.uint8)
            data['mask'] = mask

        info['shape'] = shape
        info['need_resize'] = not (self.size < 0)
        data['rgb'] = img
        data['info'] = info

        return data

    def resize_mask(self, mask):
        """
        Resize `mask` (last two dims are H, W) with nearest-neighbour interpolation
        so that its shorter side equals self.size.
        """
        # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
        h, w = mask.shape[-2:]
        min_hw = min(h, w)
        # Multiply before dividing (same order as torchvision's Resize): the previous
        # h/min_hw*size could round just below an integer and lose a pixel
        # (e.g. 480x490 at size 480 gave a width of 489), mismatching the resized image.
        new_h = int(self.size * h / min_hw)
        new_w = int(self.size * w / min_hw)
        return F.interpolate(mask, (new_h, new_w), mode='nearest')

    def get_palette(self):
        """Palette of the first mask in mask_dir (as returned by PIL's getpalette)."""
        return self.palette

    def __len__(self):
        return len(self.frames)

View File

@@ -0,0 +1,110 @@
from inference.memory_manager import MemoryManager
from model.network import XMem
from model.aggregate import aggregate
from util.tensor_util import pad_divide_by, unpad
class InferenceCore:
def __init__(self, network:XMem, config):
    """
    Store the network and configuration, copy the scheduling options out of
    `config` and start with a freshly cleared memory.
    """
    self.network = network
    self.config = config
    self.all_labels = None

    for option in ('mem_every', 'deep_update_every', 'enable_long_term'):
        setattr(self, option, config[option])

    # if deep_update_every < 0, synchronize deep update with memory frame
    self.deep_update_sync = self.deep_update_every < 0

    # must run after the options above are set: clear_memory reads them
    self.clear_memory()
def clear_memory(self):
    """Reset the frame counters and replace the memory with a new MemoryManager."""
    self.curr_ti, self.last_mem_ti = -1, 0
    if not self.deep_update_sync:
        # only tracked when deep updates follow their own schedule
        self.last_deep_update_ti = -self.deep_update_every
    self.memory = MemoryManager(config=self.config)
def update_config(self, config):
    """
    Re-read the scheduling options from `config` and forward the new
    configuration to the memory manager.
    """
    for option in ('mem_every', 'deep_update_every', 'enable_long_term'):
        setattr(self, option, config[option])

    # if deep_update_every < 0, synchronize deep update with memory frame
    self.deep_update_sync = self.deep_update_every < 0

    self.memory.update_config(config)
def set_all_labels(self, all_labels):
    """Store the full collection of object labels; it is kept as passed, without copying."""
    # Earlier variant, kept for reference: converted each label with .item()
    # self.all_labels = [l.item() for l in all_labels]
    self.all_labels = all_labels
def step(self, image, mask=None, valid_labels=None, end=False):
# image: 3*H*W
# mask: num_objects*H*W or None
self.curr_ti += 1
image, self.pad = pad_divide_by(image, 16)
image = image.unsqueeze(0) # add the batch dimension
is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end)
need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels)))
is_deep_update = (
(self.deep_update_sync and is_mem_frame) or # synchronized
(not self.deep_update_sync and self.curr_ti-self.last_deep_update_ti >= self.deep_update_every) # no-sync
) and (not end)
is_normal_update = (not self.deep_update_sync or not is_deep_update) and (not end)
key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image,
need_ek=(self.enable_long_term or need_segment),
need_sk=is_mem_frame)
multi_scale_features = (f16, f8, f4)
# segment the current frame is needed
if need_segment:
memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout,
self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False)
# remove batch dim
pred_prob_with_bg = pred_prob_with_bg[0]
pred_prob_no_bg = pred_prob_with_bg[1:]
if is_normal_update:
self.memory.set_hidden(hidden)
else:
pred_prob_no_bg = pred_prob_with_bg = None
# use the input mask if any
if mask is not None:
mask, _ = pad_divide_by(mask, 16)
if pred_prob_no_bg is not None:
# if we have a predicted mask, we work on it
# make pred_prob_no_bg consistent with the input mask
mask_regions = (mask.sum(0) > 0.5)
pred_prob_no_bg[:, mask_regions] = 0
# shift by 1 because mask/pred_prob_no_bg do not contain background
mask = mask.type_as(pred_prob_no_bg)
if valid_labels is not None:
shift_by_one_non_labels = [i for i in range(pred_prob_no_bg.shape[0]) if (i+1) not in valid_labels]
# non-labelled objects are copied from the predicted mask
mask[shift_by_one_non_labels] = pred_prob_no_bg[shift_by_one_non_labels]
pred_prob_with_bg = aggregate(mask, dim=0)
# also create new hidden states
self.memory.create_hidden_state(len(self.all_labels), key)
# save as memory if needed
if is_mem_frame:
value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(),
pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=is_deep_update)
self.memory.add_memory(key, shrinkage, value, self.all_labels,
selection=selection if self.enable_long_term else None)
self.last_mem_ti = self.curr_ti
if is_deep_update:
self.memory.set_hidden(hidden)
self.last_deep_update_ti = self.curr_ti
return unpad(pred_prob_with_bg, self.pad)

View File

@@ -0,0 +1,215 @@
import torch
from typing import List
class KeyValueMemoryStore:
    """
    Works for key/value pairs type storage
    e.g., working and long-term memory

    An object group is created when new objects enter the video
    Objects in the same group share the same temporal extent
    i.e., objects initialized in the same frame are in the same group
    For DAVIS/interactive, there is only one object group
    For YouTubeVOS, there can be multiple object groups
    """

    def __init__(self, count_usage: bool):
        self.count_usage = count_usage

        # keys are stored in a single tensor and are shared between groups/objects
        # values are stored as a list indexed by object groups
        self.k = None
        self.v = []
        self.obj_groups = []
        # for debugging only
        self.all_objects = []

        # shrinkage and selection are also single tensors
        self.s = self.e = None

        # usage
        if self.count_usage:
            self.use_count = self.life_count = None

    def add(self, key, value, shrinkage, selection, objects: List[int]):
        """
        Append new memory elements along the last dimension.

        When ``objects`` is given, ``value`` must be a tensor indexed by
        (object id - 1); when ``objects`` is None, ``value`` must be a list
        with one entry (tensor or None) per object group.
        """
        # usage counters are only needed (and only stored) when counting is enabled
        if self.count_usage:
            new_count = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32)
            new_life = torch.zeros((key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32) + 1e-7

        # add the key
        if self.k is None:
            self.k = key
            self.s = shrinkage
            self.e = selection
            if self.count_usage:
                self.use_count = new_count
                self.life_count = new_life
        else:
            self.k = torch.cat([self.k, key], -1)
            if shrinkage is not None:
                self.s = torch.cat([self.s, shrinkage], -1)
            if selection is not None:
                self.e = torch.cat([self.e, selection], -1)
            if self.count_usage:
                self.use_count = torch.cat([self.use_count, new_count], -1)
                self.life_count = torch.cat([self.life_count, new_life], -1)

        # add the value
        if objects is not None:
            # When objects is given, v is a tensor; used in working memory
            assert isinstance(value, torch.Tensor)
            # First consume objects that are already in the memory bank
            # cannot use set here because we need to preserve order
            # shift by one as background is not part of value
            remaining_objects = [obj-1 for obj in objects]
            for gi, group in enumerate(self.obj_groups):
                for obj in group:
                    # list.remove raises ValueError if an already-stored object
                    # is missing from `objects` (or groups overlap)
                    remaining_objects.remove(obj)
                self.v[gi] = torch.cat([self.v[gi], value[group]], -1)

            # If there are remaining objects, add them as a new group
            if len(remaining_objects) > 0:
                new_group = list(remaining_objects)
                self.v.append(value[new_group])
                self.obj_groups.append(new_group)
                self.all_objects.extend(new_group)

                assert sorted(self.all_objects) == self.all_objects, 'Objects MUST be inserted in sorted order '
        else:
            # When objects is not given, v is a list that already has the object groups sorted
            # used in long-term memory
            assert isinstance(value, list)
            for gi, gv in enumerate(value):
                if gv is None:
                    continue
                if gi < self.num_groups:
                    self.v[gi] = torch.cat([self.v[gi], gv], -1)
                else:
                    self.v.append(gv)

    def update_usage(self, usage):
        """Add ``usage`` to the use counters and age every element by one (no-op if not counting)."""
        # increase all life count by 1
        # increase use of indexed elements
        if not self.count_usage:
            return

        self.use_count += usage.view_as(self.use_count)
        self.life_count += 1

    def sieve_by_range(self, start: int, end: int, min_size: int):
        """
        Keep only the elements *outside* of [start, end), i.e. concat (a[:start], a[end:]).

        ``end`` is expected to be a non-positive offset from the tail; ``end == 0``
        means "up to the very end" (a literal ``a[0:]`` would keep everything).
        min_size is only used for values, we do not sieve values under this size
        (because they are not consolidated)
        """
        def keep_outside(t):
            head = t[:,:,:start]
            if end == 0:
                # negative 0 would not work as the end index!
                return head
            return torch.cat([head, t[:,:,end:]], -1)

        self.k = keep_outside(self.k)
        if self.count_usage:
            self.use_count = keep_outside(self.use_count)
            self.life_count = keep_outside(self.life_count)
        if self.s is not None:
            self.s = keep_outside(self.s)
        if self.e is not None:
            self.e = keep_outside(self.e)

        for gi in range(self.num_groups):
            if self.v[gi].shape[-1] >= min_size:
                self.v[gi] = keep_outside(self.v[gi])

    def remove_obsolete_features(self, max_size: int):
        """
        Drop the least-used elements so that at most ``max_size`` remain.

        Does nothing when the store already holds ``max_size`` elements or fewer.
        Raises RuntimeError (via get_usage) if usage is not being counted and
        NotImplementedError if there is more than one object group; in both
        cases the store is left untouched.
        """
        # normalize with life duration
        usage = self.get_usage().flatten()

        num_to_remove = self.size - max_size
        if num_to_remove <= 0:
            # nothing to remove; topk with k=0 would return an empty tensor
            # and indexing its last element would fail
            return

        # validate BEFORE mutating anything so a failure cannot leave keys and
        # values out of sync
        if self.num_groups > 1:
            raise NotImplementedError("""The current data structure does not support feature removal with
multiple object groups (e.g., some objects start to appear later in the video)
The indices for "survived" is based on keys but not all values are present for every key
Basically we need to remap the indices for keys to values
""")

        values, _ = torch.topk(usage, k=num_to_remove, largest=False, sorted=True)
        # elements tied with the threshold are removed as well
        survived = (usage > values[-1])

        self.k = self.k[:, :, survived]
        self.s = self.s[:, :, survived] if self.s is not None else None
        # Long-term memory does not store ek so this should not be needed
        self.e = self.e[:, :, survived] if self.e is not None else None

        for gi in range(self.num_groups):
            self.v[gi] = self.v[gi][:, :, survived]

        self.use_count = self.use_count[:, :, survived]
        self.life_count = self.life_count[:, :, survived]

    def get_usage(self):
        """Return usage normalized by life duration; raises RuntimeError if not counting."""
        if not self.count_usage:
            raise RuntimeError('I did not count usage!')
        return self.use_count / self.life_count

    def get_all_sliced(self, start: int, end: int):
        """Return k, sk, ek, usage in order, sliced by start and end (``end == 0`` means "to the end")."""
        # negative 0 would not work as the end index!
        stop = None if end == 0 else end

        k = self.k[:,:,start:stop]
        sk = self.s[:,:,start:stop] if self.s is not None else None
        ek = self.e[:,:,start:stop] if self.e is not None else None
        usage = self.get_usage()[:,:,start:stop]

        return k, sk, ek, usage

    def get_v_size(self, ni: int):
        """Number of stored elements for object group ``ni``."""
        return self.v[ni].shape[2]

    def engaged(self):
        """True once at least one key has been added."""
        return self.k is not None

    @property
    def size(self):
        if self.k is None:
            return 0
        return self.k.shape[-1]

    @property
    def num_groups(self):
        return len(self.v)

    @property
    def key(self):
        return self.k

    @property
    def value(self):
        return self.v

    @property
    def shrinkage(self):
        return self.s

    @property
    def selection(self):
        return self.e

View File

@@ -0,0 +1,284 @@
import torch
import warnings
from inference.kv_memory_store import KeyValueMemoryStore
from model.memory_util import *
class MemoryManager:
    """
    Manages all three memory stores and the transition between working/long-term memory

    The stores visible in this class are the hidden state (``self.hidden``), the
    working memory (``self.work_mem``) and, when ``enable_long_term`` is set, the
    long-term memory (``self.long_mem``).
    """
    def __init__(self, config):
        # config is indexed like a dict; the long-term keys are only read
        # when config['enable_long_term'] is truthy
        self.hidden_dim = config['hidden_dim']
        self.top_k = config['top_k']

        self.enable_long_term = config['enable_long_term']
        self.enable_long_term_usage = config['enable_long_term_count_usage']
        if self.enable_long_term:
            self.max_mt_frames = config['max_mid_term_frames']
            self.min_mt_frames = config['min_mid_term_frames']
            self.num_prototypes = config['num_prototypes']
            self.max_long_elements = config['max_long_term_elements']

        # dimensions will be inferred from input later
        self.CK = self.CV = None
        self.H = self.W = None

        # The hidden state will be stored in a single tensor for all objects
        # B x num_objects x CH x H x W
        self.hidden = None

        # working memory counts usage only when long-term memory is enabled
        self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term)
        if self.enable_long_term:
            self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage)

        # when set, add_memory re-derives H/W and the element limits on its next call
        self.reset_config = True

    def update_config(self, config):
        """
        Re-read tunable options from ``config``.

        ``enable_long_term`` and ``enable_long_term_count_usage`` must match the
        values given at construction (checked with assert).
        """
        self.reset_config = True
        self.hidden_dim = config['hidden_dim']
        self.top_k = config['top_k']

        assert self.enable_long_term == config['enable_long_term'], 'cannot update this'
        assert self.enable_long_term_usage == config['enable_long_term_count_usage'], 'cannot update this'

        self.enable_long_term_usage = config['enable_long_term_count_usage']
        if self.enable_long_term:
            self.max_mt_frames = config['max_mid_term_frames']
            self.min_mt_frames = config['min_mid_term_frames']
            self.num_prototypes = config['num_prototypes']
            self.max_long_elements = config['max_long_term_elements']

    def _readout(self, affinity, v):
        # this function is for a single object group
        return v @ affinity

    def match_memory(self, query_key, selection):
        """
        Read out memory values for a query frame.

        Computes one affinity per object group (later groups only cover the
        trailing part of the keys), records usage when enabled, and returns the
        concatenated per-group readouts reshaped to (..., self.CV, h, w).
        """
        # query_key: B x C^k x H x W
        # selection: B x C^k x H x W
        num_groups = self.work_mem.num_groups
        h, w = query_key.shape[-2:]

        query_key = query_key.flatten(start_dim=2)
        selection = selection.flatten(start_dim=2) if selection is not None else None

        """
        Memory readout using keys
        """

        if self.enable_long_term and self.long_mem.engaged():
            # Use long-term memory
            # long-term keys come first in the concatenation, working keys after
            long_mem_size = self.long_mem.size
            memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1)
            shrinkage = torch.cat([self.long_mem.shrinkage, self.work_mem.shrinkage], -1)

            similarity = get_similarity(memory_key, shrinkage, query_key, selection)
            work_mem_similarity = similarity[:, long_mem_size:]
            long_mem_similarity = similarity[:, :long_mem_size]

            # get the usage with the first group
            # the first group always have all the keys valid
            affinity, usage = do_softmax(
                torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(0):], work_mem_similarity], 1),
                top_k=self.top_k, inplace=True, return_usage=True)
            affinity = [affinity]

            # compute affinity group by group as later groups only have a subset of keys
            for gi in range(1, num_groups):
                if gi < self.long_mem.num_groups:
                    # merge working and lt similarities before softmax
                    affinity_one_group = do_softmax(
                        torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):],
                                   work_mem_similarity[:, -self.work_mem.get_v_size(gi):]], 1),
                        top_k=self.top_k, inplace=True)
                else:
                    # no long-term memory for this group
                    # in-place is only requested for the last group
                    affinity_one_group = do_softmax(work_mem_similarity[:, -self.work_mem.get_v_size(gi):],
                        top_k=self.top_k, inplace=(gi==num_groups-1))
                affinity.append(affinity_one_group)

            all_memory_value = []
            for gi, gv in enumerate(self.work_mem.value):
                # merge the working and lt values before readout
                if gi < self.long_mem.num_groups:
                    all_memory_value.append(torch.cat([self.long_mem.value[gi], self.work_mem.value[gi]], -1))
                else:
                    all_memory_value.append(gv)

            """
            Record memory usage for working and long-term memory
            """
            # ignore the index return for long-term memory
            work_usage = usage[:, long_mem_size:]
            self.work_mem.update_usage(work_usage.flatten())

            if self.enable_long_term_usage:
                # ignore the index return for working memory
                long_usage = usage[:, :long_mem_size]
                self.long_mem.update_usage(long_usage.flatten())
        else:
            # No long-term memory
            similarity = get_similarity(self.work_mem.key, self.work_mem.shrinkage, query_key, selection)

            if self.enable_long_term:
                # long-term mode is on but long_mem is still empty: usage is still tracked
                affinity, usage = do_softmax(similarity, inplace=(num_groups==1),
                    top_k=self.top_k, return_usage=True)

                # Record memory usage for working memory
                self.work_mem.update_usage(usage.flatten())
            else:
                affinity = do_softmax(similarity, inplace=(num_groups==1),
                    top_k=self.top_k, return_usage=False)

            affinity = [affinity]

            # compute affinity group by group as later groups only have a subset of keys
            for gi in range(1, num_groups):
                affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):],
                    top_k=self.top_k, inplace=(gi==num_groups-1))
                affinity.append(affinity_one_group)

            all_memory_value = self.work_mem.value

        # Shared affinity within each group
        all_readout_mem = torch.cat([
            self._readout(affinity[gi], gv)
            for gi, gv in enumerate(all_memory_value)
        ], 0)

        return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w)

    def add_memory(self, key, shrinkage, value, objects, selection=None):
        """
        Add one frame to the working memory and, in long-term mode, compress
        the working memory into long-term memory once it is full.
        """
        # key: 1*C*H*W
        # value: 1*num_objects*C*H*W
        # objects contain a list of object indices
        if self.H is None or self.reset_config:
            self.reset_config = False
            self.H, self.W = key.shape[-2:]
            self.HW = self.H*self.W
            if self.enable_long_term:
                # convert from num. frames to num. nodes
                self.min_work_elements = self.min_mt_frames*self.HW
                self.max_work_elements = self.max_mt_frames*self.HW

        # key: 1*C*N
        # value: num_objects*C*N
        key = key.flatten(start_dim=2)
        shrinkage = shrinkage.flatten(start_dim=2)
        value = value[0].flatten(start_dim=2)

        self.CK = key.shape[1]
        self.CV = value.shape[1]

        if selection is not None:
            if not self.enable_long_term:
                warnings.warn('the selection factor is only needed in long-term mode', UserWarning)
            selection = selection.flatten(start_dim=2)

        self.work_mem.add(key, value, shrinkage, selection, objects)

        # long-term memory cleanup
        if self.enable_long_term:
            # Do memory compressed if needed
            if self.work_mem.size >= self.max_work_elements:
                # Remove obsolete features if needed
                # (leaves room for the num_prototypes elements compress_features adds)
                if self.long_mem.size >= (self.max_long_elements-self.num_prototypes):
                    self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes)

                self.compress_features()

    def create_hidden_state(self, n, sample_key):
        """Create the hidden state, or extend it with zeros, so that it covers ``n`` objects."""
        # n is the TOTAL number of objects
        h, w = sample_key.shape[-2:]
        if self.hidden is None:
            self.hidden = torch.zeros((1, n, self.hidden_dim, h, w), device=sample_key.device)
        elif self.hidden.shape[1] != n:
            # append zero states for the objects that are not covered yet
            self.hidden = torch.cat([
                self.hidden,
                torch.zeros((1, n-self.hidden.shape[1], self.hidden_dim, h, w), device=sample_key.device)
            ], 1)

        assert(self.hidden.shape[1] == n)

    def set_hidden(self, hidden):
        self.hidden = hidden

    def get_hidden(self):
        return self.hidden

    def compress_features(self):
        """
        Consolidate part of the working memory into long-term memory.

        The candidate range is [HW, -min_work_elements+HW) of the working
        memory; it is consolidated into prototypes, removed from the working
        memory, and the prototypes are appended to the long-term memory.
        """
        HW = self.HW
        candidate_value = []
        total_work_mem_size = self.work_mem.size
        for gv in self.work_mem.value:
            # Some object groups might be added later in the video
            # So not all keys have values associated with all objects
            # We need to keep track of the key->value validity
            mem_size_in_this_group = gv.shape[-1]
            if mem_size_in_this_group == total_work_mem_size:
                # full LT
                candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
            else:
                # mem_size is smaller than total_work_mem_size, but at least HW
                assert HW <= mem_size_in_this_group < total_work_mem_size
                if mem_size_in_this_group > self.min_work_elements+HW:
                    # part of this object group still goes into LT
                    candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW])
                else:
                    # this object group cannot go to the LT at all
                    candidate_value.append(None)

        # perform memory consolidation
        prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
            *self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value)

        # remove consolidated working memory
        self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW)

        # add to long-term memory
        self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None)

    def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value):
        """
        Pick the ``num_prototypes`` most-used candidates as prototypes and
        aggregate candidate values / shrinkage onto them.

        Returns (prototype_key, prototype_value, prototype_shrinkage), where
        prototype_value is a list with one entry (tensor or None) per entry of
        ``candidate_value``.
        """
        # keys: 1*C*N
        # values: num_objects*C*N
        N = candidate_key.shape[-1]

        # find the indices with max usage
        _, max_usage_indices = torch.topk(usage, k=self.num_prototypes, dim=-1, sorted=True)
        prototype_indices = max_usage_indices.flatten()

        # Prototypes are invalid for out-of-bound groups
        validity = [prototype_indices >= (N-gv.shape[2]) if gv is not None else None for gv in candidate_value]

        prototype_key = candidate_key[:, :, prototype_indices]
        prototype_selection = candidate_selection[:, :, prototype_indices] if candidate_selection is not None else None

        """
        Potentiation step
        """
        similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key, prototype_selection)

        # convert similarity to affinity
        # need to do it group by group since the softmax normalization would be different
        affinity = [
            do_softmax(similarity[:, -gv.shape[2]:, validity[gi]]) if gv is not None else None
            for gi, gv in enumerate(candidate_value)
        ]

        # some values can be have all False validity. Weed them out.
        affinity = [
            aff if aff is None or aff.shape[-1] > 0 else None for aff in affinity
        ]

        # readout the values
        prototype_value = [
            self._readout(affinity[gi], gv) if affinity[gi] is not None else None
            for gi, gv in enumerate(candidate_value)
        ]

        # readout the shrinkage term
        # uses the first group's affinity only
        prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None

        return prototype_key, prototype_value, prototype_shrinkage

View File

View File

@@ -0,0 +1,17 @@
import torch
import torch.nn.functional as F
# Soft aggregation from STM
def aggregate(prob, dim, return_logits=False):
    """
    Soft-aggregate per-object probabilities together with an implicit background.

    A background entry, computed as the product of (1 - prob) over ``dim``, is
    prepended along ``dim``; everything is clamped to [1e-7, 1-1e-7], turned
    into logits and normalised with a softmax over ``dim``.

    Returns the probabilities, or ``(logits, probabilities)`` when
    ``return_logits`` is True.
    """
    background = torch.prod(1-prob, dim=dim, keepdim=True)
    stacked = torch.cat([background, prob], dim).clamp(1e-7, 1-1e-7)
    logits = torch.log((stacked /(1-stacked)))
    prob = F.softmax(logits, dim=dim)

    if not return_logits:
        return prob
    return logits, prob

77
tracker/model/cbam.py Normal file
View File

@@ -0,0 +1,77 @@
# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicConv(nn.Module):
    """A single ``nn.Conv2d`` stored as ``self.conv``; also records ``out_channels``."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.out_channels = out_planes
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)

    def forward(self, x):
        return self.conv(x)
class Flatten(nn.Module):
    """Collapse every dimension after the first into one (uses ``view``)."""

    def forward(self, x):
        leading = x.size(0)
        return x.view(leading, -1)
class ChannelGate(nn.Module):
    """
    Channel attention gate.

    The input is pooled over its spatial dimensions once per entry of
    ``pool_types`` ('avg' and/or 'max'), each pooled vector goes through a
    shared two-layer MLP, the MLP outputs are summed, and the sigmoid of the
    sum rescales the input channel-wise.

    Raises ValueError at construction if ``pool_types`` is empty or contains
    anything other than 'avg' / 'max'.
    """
    _SUPPORTED_POOL_TYPES = ('avg', 'max')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        super(ChannelGate, self).__init__()
        # immutable default + local copy: no state shared between instances
        pool_types = tuple(pool_types)
        if not pool_types:
            raise ValueError('pool_types must contain at least one pooling type')
        unknown = [p for p in pool_types if p not in self._SUPPORTED_POOL_TYPES]
        if unknown:
            # previously an unknown type silently re-used the previous pooling
            # result (or failed with a NameError when it came first)
            raise ValueError(f'unsupported pool type(s) {unknown}; expected a subset of {self._SUPPORTED_POOL_TYPES}')

        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            # parameter-free, so the mlp.1 / mlp.3 parameter keys are unchanged
            nn.Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        # pooling window == full spatial extent -> one value per channel
        spatial = (x.size(2), x.size(3))
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = F.avg_pool2d(x, spatial, stride=spatial)
            else:  # 'max' (validated in __init__)
                pooled = F.max_pool2d(x, spatial, stride=spatial)
            channel_att_raw = self.mlp(pooled)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
class ChannelPool(nn.Module):
    """Reduce dimension 1 to two entries: its maximum followed by its mean."""

    def forward(self, x):
        channel_max = torch.max(x,1)[0].unsqueeze(1)
        channel_mean = torch.mean(x,1).unsqueeze(1)
        return torch.cat((channel_max, channel_mean), dim=1)
class SpatialGate(nn.Module):
    """
    Spatial attention gate: pool the input with ChannelPool, run a 7x7
    BasicConv (2 -> 1 channels, padding 3) on the result, and multiply the
    input by the sigmoid of that map.
    """

    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        same_padding = (kernel_size-1) // 2
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=same_padding)

    def forward(self, x):
        attention_logits = self.spatial(self.compress(x))
        # single-channel map, broadcast against x in the product below
        scale = torch.sigmoid(attention_logits)
        return x * scale
class CBAM(nn.Module):
    """
    Applies a ChannelGate and then, unless ``no_spatial`` is set, a SpatialGate.

    ``gate_channels``, ``reduction_ratio`` and ``pool_types`` are forwarded to
    ChannelGate unchanged.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max'), no_spatial=False):
        # the default is a tuple rather than a list so that no mutable object
        # is shared between calls
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out

View File

@@ -0,0 +1,82 @@
"""
Group-specific modules
They handle features that also depends on the mask.
Features are typically of shape
batch_size * num_objects * num_channels * H * W
All of them are permutation equivariant w.r.t. to the num_objects dimension
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def interpolate_groups(g, ratio, mode, align_corners):
    """
    Run ``F.interpolate`` on a tensor with two leading dimensions.

    The first two dimensions of ``g`` are merged before interpolation (with
    ``scale_factor=ratio``) and restored afterwards.
    """
    leading_dims = g.shape[:2]
    merged = g.flatten(start_dim=0, end_dim=1)
    resized = F.interpolate(merged,
                scale_factor=ratio, mode=mode, align_corners=align_corners)
    return resized.view(*leading_dims, *resized.shape[1:])
def upsample_groups(g, ratio=2, mode='bilinear', align_corners=False):
    """Shorthand for ``interpolate_groups`` with upsampling defaults (x2, bilinear)."""
    upsampled = interpolate_groups(g, ratio, mode, align_corners)
    return upsampled
def downsample_groups(g, ratio=1/2, mode='area', align_corners=None):
    """Shorthand for ``interpolate_groups`` with downsampling defaults (x1/2, area)."""
    downsampled = interpolate_groups(g, ratio, mode, align_corners)
    return downsampled
class GConv2D(nn.Conv2d):
    """``nn.Conv2d`` for inputs with two leading dimensions: they are merged
    for the convolution and restored on the output."""

    def forward(self, g):
        leading_dims = g.shape[:2]
        merged = g.flatten(start_dim=0, end_dim=1)
        convolved = super().forward(merged)
        return convolved.view(*leading_dims, *convolved.shape[1:])
class GroupResBlock(nn.Module):
    """
    Residual block built from two 3x3 GConv2D layers, each preceded by a ReLU.

    When ``in_dim != out_dim`` the skip connection goes through an extra 3x3
    GConv2D (``self.downsample``) so that the channel counts match.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()

        # created before conv1/conv2, same as the layout of the state dict
        needs_projection = in_dim != out_dim
        if needs_projection:
            self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
        else:
            self.downsample = None

        self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
        self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, g):
        residual = self.conv1(F.relu(g))
        residual = self.conv2(F.relu(residual))

        skip = g if self.downsample is None else self.downsample(g)

        return residual + skip
class MainToGroupDistributor(nn.Module):
    """
    Combine a feature ``x`` with a grouped feature ``g``.

    ``x`` (optionally passed through ``x_transform``) gets a new dimension 1
    expanded to ``g.shape[1]`` and is then either concatenated with ``g``
    along dimension 2 (``method='cat'``; ``reverse_order`` puts ``g`` first)
    or added to it (``method='add'``). Any other method raises
    NotImplementedError.
    """

    def __init__(self, x_transform=None, method='cat', reverse_order=False):
        super().__init__()

        self.x_transform = x_transform
        self.method = method
        self.reverse_order = reverse_order

    def forward(self, x, g):
        num_objects = g.shape[1]

        if self.x_transform is not None:
            x = self.x_transform(x)

        if self.method == 'add':
            return x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1) + g

        if self.method != 'cat':
            raise NotImplementedError

        x_per_object = x.unsqueeze(1).expand(-1,num_objects,-1,-1,-1)
        ordered = [g, x_per_object] if self.reverse_order else [x_per_object, g]
        return torch.cat(ordered, 2)

68
tracker/model/losses.py Normal file
View File

@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
def dice_loss(input_mask, cls_gt):
    """
    Mean soft-Dice loss over the object channels of ``input_mask``.

    Channel ``i`` (dimension 1) is compared against the pixels where
    ``cls_gt == i+1``; the background is not part of ``input_mask``, hence the
    shift by one. Numerator and denominator are both smoothed by +1.
    """
    per_object_losses = []
    for obj_idx in range(input_mask.shape[1]):
        pred = input_mask[:,obj_idx].flatten(start_dim=1)
        target = (cls_gt==(obj_idx+1)).float().flatten(start_dim=1)
        overlap = 2 * (pred * target).sum(-1)
        total = pred.sum(-1) + target.sum(-1)
        per_object_losses.append(1 - (overlap + 1) / (total + 1))
    return torch.cat(per_object_losses).mean()
# https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch
class BootstrappedCE(nn.Module):
    """
    Bootstrapped cross entropy: after a warm-up, only the hardest fraction of
    pixels contributes to the loss.

    * ``it < start_warm``: plain cross entropy over all pixels.
    * ``start_warm <= it < end_warm``: the kept fraction decreases linearly
      from 1 to ``top_p``.
    * ``it >= end_warm``: the kept fraction is ``top_p``.
    """

    def __init__(self, start_warm, end_warm, top_p=0.15):
        super().__init__()

        self.start_warm = start_warm
        self.end_warm = end_warm
        self.top_p = top_p

    def forward(self, input, target, it):
        """Return ``(loss, fraction_of_pixels_used)`` for iteration ``it``."""
        if it < self.start_warm:
            return F.cross_entropy(input, target), 1.0

        raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
        num_pixels = raw_loss.numel()

        # `>=` instead of `>`: at it == end_warm the interpolation below would
        # also yield top_p, but it divides by zero when start_warm == end_warm
        if it >= self.end_warm:
            this_p = self.top_p
        else:
            this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm))

        # always keep at least one pixel (if there is one); averaging an empty
        # top-k selection would produce NaN
        num_kept = min(num_pixels, max(1, int(num_pixels * this_p)))
        loss, _ = torch.topk(raw_loss, num_kept, sorted=False)
        return loss.mean(), this_p
class LossComputer:
    """
    Sums, for every frame index from 1 onward, a bootstrapped cross-entropy
    term (averaged over the batch) and a dice term into ``total_loss``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bce = BootstrappedCE(config['start_warm'], config['end_warm'])

    def compute(self, data, num_objects, it):
        """
        Return a dict of losses computed from ``data``.

        Reads ``data['rgb']`` (for the batch/time sizes), ``data['cls_gt']``,
        and per frame ``data['logits_<ti>']`` / ``data['masks_<ti>']``.
        ``num_objects[bi]`` is the number of objects of batch element ``bi``.
        """
        losses = defaultdict(int)

        batch_size, num_frames = data['rgb'].shape[:2]

        losses['total_loss'] = 0
        for ti in range(1, num_frames):
            frame_logits = data[f'logits_{ti}']
            for bi in range(batch_size):
                # only the background + the objects that exist for this sample
                sample_logits = frame_logits[bi:bi+1, :num_objects[bi]+1]
                sample_gt = data['cls_gt'][bi:bi+1,ti,0]
                loss, p = self.bce(sample_logits, sample_gt, it)
                losses['p'] += p / batch_size / (num_frames-1)
                losses[f'ce_loss_{ti}'] += loss / batch_size

            losses['total_loss'] += losses['ce_loss_%d'%ti]
            losses[f'dice_loss_{ti}'] = dice_loss(data[f'masks_{ti}'], data['cls_gt'][:,ti,0])
            losses['total_loss'] += losses[f'dice_loss_{ti}']

        return losses

View File

@@ -0,0 +1,80 @@
import math
import numpy as np
import torch
from typing import Optional
def get_similarity(mk, ms, qk, qe):
    """
    Similarity between memory keys and query keys (negative squared distance,
    optionally weighted), scaled by 1/sqrt(CK).

    mk: B x CK x [N]    - Memory keys
    ms: B x  1 x [N]    - Memory shrinkage (or None)
    qk: B x CK x [HW/P] - Query keys
    qe: B x CK x [HW/P] - Query selection (or None)
    Dimensions in [] are flattened.

    With ``qe`` the per-channel terms are weighted by the selection; without
    it the query-norm term is omitted. Returns B x N x [HW/P].
    """
    CK = mk.shape[1]
    mem_key = mk.flatten(start_dim=2)
    shrink = None if ms is None else ms.flatten(start_dim=1).unsqueeze(2)
    query_key = qk.flatten(start_dim=2)
    select = None if qe is None else qe.flatten(start_dim=2)

    if select is None:
        # no selection term: -|mk|^2 + 2 mk.qk
        a_sq = mem_key.pow(2).sum(1).unsqueeze(2)
        two_ab = 2 * (mem_key.transpose(1, 2) @ query_key)
        similarity = (-a_sq+two_ab)
    else:
        # selection-weighted: -sum_c e (mk - qk)^2, expanded into three terms
        mem_key_t = mem_key.transpose(1, 2)
        a_sq = (mem_key_t.pow(2) @ select)
        two_ab = 2 * (mem_key_t @ (query_key * select))
        b_sq = (select * query_key.pow(2)).sum(1, keepdim=True)
        similarity = (-a_sq+two_ab-b_sq)

    if shrink is None:
        return similarity / math.sqrt(CK) # B*N*HW
    return similarity * shrink / math.sqrt(CK) # B*N*HW
def do_softmax(similarity, top_k: Optional[int]=None, inplace=False, return_usage=False):
    """
    Normalize similarity over the memory dimension (dim 1) with a softmax,
    optionally restricted to the ``top_k`` largest entries per query.

    similarity: B x N x [HW/P]
    top_k: number of memory entries kept per query; None means a full softmax.
           Values larger than N are clamped to N.
    inplace: with ``top_k`` set, overwrite ``similarity`` with the result
             instead of allocating a new tensor - use with care. Ignored for
             the full softmax.
    return_usage: also return the affinity summed over the query dimension
                  (B x N).
    """
    if top_k is not None:
        # never ask topk for more entries than exist
        k = min(top_k, similarity.shape[1])
        values, indices = torch.topk(similarity, k=k, dim=1)

        # subtract the per-query maximum before exponentiating (as in the
        # full-softmax branch): mathematically a no-op, but it prevents a
        # 0/0 = NaN when every selected value underflows
        values -= torch.max(values, dim=1, keepdim=True)[0]
        x_exp = values.exp_()
        x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
        if inplace:
            similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
            affinity = similarity
        else:
            affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW
    else:
        maxes = torch.max(similarity, dim=1, keepdim=True)[0]
        x_exp = torch.exp(similarity - maxes)
        x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
        affinity = x_exp / x_exp_sum

    if return_usage:
        return affinity, affinity.sum(dim=2)

    return affinity
def get_affinity(mk, ms, qk, qe):
    """Shorthand used in training: ``get_similarity`` followed by a full (no top-k) softmax."""
    return do_softmax(get_similarity(mk, ms, qk, qe))
def readout(affinity, mv):
    """
    Read memory values through an affinity matrix.

    ``mv`` is B x CV x T x H x W; it is flattened to B x CV x (T*H*W),
    batch-multiplied with ``affinity`` and the result is viewed as
    B x CV x H x W.
    """
    B, CV, T, H, W = mv.shape

    flat_values = mv.view(B, CV, T*H*W)
    weighted = torch.bmm(flat_values, affinity)
    return weighted.view(B, CV, H, W)

250
tracker/model/modules.py Normal file
View File

@@ -0,0 +1,250 @@
"""
modules.py - This file stores the rather boring network blocks.
x - usually means features that only depends on the image
g - usually means features that also depends on the mask.
They might have an extra "group" or "num_objects" dimension, hence
batch_size * num_objects * num_channels * H * W
The trailing number of a variable usually denote the stride
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.group_modules import *
from model import resnet
from model.cbam import CBAM
class FeatureFusionBlock(nn.Module):
    """
    Fuse an image feature ``x`` with a grouped feature ``g``: distribute ``x``
    onto ``g``, apply a GroupResBlock, add a CBAM-attended copy of the result,
    and apply a second GroupResBlock.
    """

    def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim):
        super().__init__()

        self.distributor = MainToGroupDistributor()
        self.block1 = GroupResBlock(x_in_dim+g_in_dim, g_mid_dim)
        self.attention = CBAM(g_mid_dim)
        self.block2 = GroupResBlock(g_mid_dim, g_out_dim)

    def forward(self, x, g):
        batch_size, num_objects = g.shape[:2]

        fused = self.block1(self.distributor(x, g))

        # CBAM gets the first two dimensions merged; they are restored afterwards
        attended = self.attention(fused.flatten(start_dim=0, end_dim=1))
        attended = attended.view(batch_size, num_objects, *attended.shape[1:])

        return self.block2(fused+attended)
class HiddenUpdater(nn.Module):
    """
    Used in the decoder, multi-scale feature + GRU.

    The three entries of ``g`` are projected to ``mid_dim`` with 1x1
    convolutions (the second and third after downsampling by 1/2 and 1/4),
    summed, and used together with the hidden state ``h`` in a GRU-like update.
    """

    def __init__(self, g_dims, mid_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1)
        self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1)
        self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1)

        self.transform = GConv2D(mid_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1)

        nn.init.xavier_normal_(self.transform.weight)

    def forward(self, g, h):
        from_g16 = self.g16_conv(g[0])
        from_g8 = self.g8_conv(downsample_groups(g[1], ratio=1/2))
        from_g4 = self.g4_conv(downsample_groups(g[2], ratio=1/4))
        merged = from_g16 + from_g8 + from_g4

        merged = torch.cat([merged, h], 2)

        # defined slightly differently than standard GRU,
        # namely the new value is generated before the forget gate.
        # might provide better gradient but frankly it was initially just an
        # implementation error that I never bothered fixing
        values = self.transform(merged)
        dim = self.hidden_dim
        forget_gate = torch.sigmoid(values[:,:,:dim])
        update_gate = torch.sigmoid(values[:,:,dim:dim*2])
        new_value = torch.tanh(values[:,:,dim*2:])
        return forget_gate*h*(1-update_gate) + update_gate*new_value
class HiddenReinforcer(nn.Module):
    """
    Used in the value encoder, a single GRU.

    ``g`` and the hidden state ``h`` are concatenated along dimension 2 and
    passed through one 3x3 convolution that produces the forget gate, the
    update gate and the candidate value.
    """

    def __init__(self, g_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.transform = GConv2D(g_dim+hidden_dim, hidden_dim*3, kernel_size=3, padding=1)

        nn.init.xavier_normal_(self.transform.weight)

    def forward(self, g, h):
        joint = torch.cat([g, h], 2)

        # defined slightly differently than standard GRU,
        # namely the new value is generated before the forget gate.
        # might provide better gradient but frankly it was initially just an
        # implementation error that I never bothered fixing
        values = self.transform(joint)
        dim = self.hidden_dim
        forget_gate = torch.sigmoid(values[:,:,:dim])
        update_gate = torch.sigmoid(values[:,:,dim:dim*2])
        new_value = torch.tanh(values[:,:,dim*2:])
        return forget_gate*h*(1-update_gate) + update_gate*new_value
class ValueEncoder(nn.Module):
    """
    Encode an image together with per-object masks into value features.

    Uses the stem and layers 1-3 of a ResNet-18 built with extra input
    channels (1 when ``single_object``, otherwise 2), fuses the result with
    the key-encoder feature, and optionally updates the hidden state.
    """

    def __init__(self, value_dim, hidden_dim, single_object=False):
        super().__init__()

        self.single_object = single_object
        extra_dim = 1 if single_object else 2
        network = resnet.resnet18(pretrained=True, extra_dim=extra_dim)
        self.conv1 = network.conv1
        self.bn1 = network.bn1
        self.relu = network.relu # 1/2, 64
        self.maxpool = network.maxpool

        self.layer1 = network.layer1 # 1/4, 64
        self.layer2 = network.layer2 # 1/8, 128
        self.layer3 = network.layer3 # 1/16, 256

        self.distributor = MainToGroupDistributor()
        self.fuser = FeatureFusionBlock(1024, 256, value_dim, value_dim)
        if hidden_dim > 0:
            self.hidden_reinforce = HiddenReinforcer(value_dim, hidden_dim)
        else:
            self.hidden_reinforce = None

    def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True):
        # image_feat_f16 is the feature from the key encoder
        if self.single_object:
            g = masks.unsqueeze(2)
        else:
            g = torch.stack([masks, others], 2)

        g = self.distributor(image, g)

        batch_size, num_objects = g.shape[:2]
        feat = g.flatten(start_dim=0, end_dim=1)

        feat = self.conv1(feat)
        feat = self.bn1(feat) # 1/2, 64
        feat = self.maxpool(feat) # 1/4, 64
        feat = self.relu(feat)

        feat = self.layer1(feat) # 1/4
        feat = self.layer2(feat) # 1/8
        feat = self.layer3(feat) # 1/16

        g = feat.view(batch_size, num_objects, *feat.shape[1:])
        g = self.fuser(image_feat_f16, g)

        # the hidden state is only touched on deep updates (and if a reinforcer exists)
        if is_deep_update and self.hidden_reinforce is not None:
            h = self.hidden_reinforce(g, h)

        return g, h
class KeyEncoder(nn.Module):
    """Image encoder built from the stem and layers 1-3 of this project's ResNet-50."""

    def __init__(self):
        super().__init__()
        backbone = resnet.resnet50(pretrained=True)
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu  # 1/2, 64
        self.maxpool = backbone.maxpool

        self.res2 = backbone.layer1  # 1/4, 256
        self.layer2 = backbone.layer2  # 1/8, 512
        self.layer3 = backbone.layer3  # 1/16, 1024

    def forward(self, f):
        """Return the feature maps ``(f16, f8, f4)``, coarsest first."""
        # Stem: conv -> bn -> relu (1/2, 64) -> maxpool (1/4, 64)
        stem = self.maxpool(self.relu(self.bn1(self.conv1(f))))
        f4 = self.res2(stem)    # 1/4, 256
        f8 = self.layer2(f4)    # 1/8, 512
        f16 = self.layer3(f8)   # 1/16, 1024

        return f16, f8, f4
class UpsampleBlock(nn.Module):
    """Upsample group features and merge them with a projected skip connection."""

    def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2):
        super().__init__()
        # Projects the skip feature to the channel count of the upsampled features.
        self.skip_conv = nn.Conv2d(skip_dim, g_up_dim, kernel_size=3, padding=1)
        self.distributor = MainToGroupDistributor(method='add')
        self.out_conv = GroupResBlock(g_up_dim, g_out_dim)
        self.scale_factor = scale_factor

    def forward(self, skip_f, up_g):
        """Return the features obtained by upsampling ``up_g`` and combining with ``skip_f``."""
        projected_skip = self.skip_conv(skip_f)
        upsampled = upsample_groups(up_g, ratio=self.scale_factor)
        merged = self.distributor(projected_skip, upsampled)
        return self.out_conv(merged)
class KeyProjection(nn.Module):
    """Project a feature map to a key and, optionally, shrinkage and selection terms."""

    def __init__(self, in_dim, keydim):
        super().__init__()

        self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
        # shrinkage
        self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1)
        # selection
        self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)

        # Only the key projection receives a custom initialisation.
        nn.init.orthogonal_(self.key_proj.weight.data)
        nn.init.zeros_(self.key_proj.bias.data)

    def forward(self, x, need_s, need_e):
        """Return ``(key, shrinkage, selection)``.

        ``shrinkage`` (``d_proj(x)**2 + 1``, hence >= 1) is ``None`` unless
        ``need_s`` is truthy; ``selection`` (a sigmoid, hence in [0, 1]) is
        ``None`` unless ``need_e`` is truthy.
        """
        shrinkage = None
        if need_s:
            shrinkage = self.d_proj(x)**2 + 1

        selection = None
        if need_e:
            selection = torch.sigmoid(self.e_proj(x))

        key = self.key_proj(x)
        return key, shrinkage, selection
class Decoder(nn.Module):
    """Decode the memory readout (and optional hidden state) into per-object logits."""

    def __init__(self, val_dim, hidden_dim):
        super().__init__()

        self.fuser = FeatureFusionBlock(1024, val_dim+hidden_dim, 512, 512)
        # The hidden-state updater is only built when a hidden state is in use.
        self.hidden_update = HiddenUpdater([512, 256, 256+1], 256, hidden_dim) if hidden_dim > 0 else None

        self.up_16_8 = UpsampleBlock(512, 512, 256)  # 1/16 -> 1/8
        self.up_8_4 = UpsampleBlock(256, 256, 256)  # 1/8 -> 1/4

        self.pred = nn.Conv2d(256, 1, kernel_size=3, padding=1, stride=1)

    def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True):
        """Return ``(hidden_state, logits)``.

        ``hidden_state`` is ``None`` in the result unless ``h_out`` is true
        and a hidden updater exists. ``logits`` are bilinearly upsampled by 4
        and reshaped to ``(batch, num_objects, H, W)``.
        """
        batch_size, num_objects = memory_readout.shape[:2]

        if self.hidden_update is None:
            fuse_input = memory_readout
        else:
            # Concatenate readout and hidden state along dim 2.
            fuse_input = torch.cat([memory_readout, hidden_state], 2)
        g16 = self.fuser(f16, fuse_input)

        g8 = self.up_16_8(f8, g16)
        g4 = self.up_8_4(f4, g8)
        # The prediction head is a 2D conv, so fold objects into the batch axis.
        logits = self.pred(F.relu(g4.flatten(start_dim=0, end_dim=1)))

        if h_out and self.hidden_update is not None:
            # Append the logits as one extra channel before updating the hidden state.
            logits_as_channel = logits.view(batch_size, num_objects, 1, *logits.shape[-2:])
            g4 = torch.cat([g4, logits_as_channel], 2)
            hidden_state = self.hidden_update([g16, g8, g4], hidden_state)
        else:
            hidden_state = None

        logits = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False)
        logits = logits.view(batch_size, num_objects, *logits.shape[-2:])

        return hidden_state, logits

198
tracker/model/network.py Normal file
View File

@@ -0,0 +1,198 @@
"""
This file defines XMem, the highest level nn.Module interface
During training, it is used by trainer.py
During evaluation, it is used by inference_core.py
It further depends on modules.py which gives more detailed implementations of sub-modules
"""
import torch
import torch.nn as nn
from model.aggregate import aggregate
from model.modules import *
from model.memory_util import *
class XMem(nn.Module):
    """Highest-level XMem module: key encoder, value encoder, key projection and decoder.

    ``forward`` dispatches on a mode string to ``encode_key``, ``encode_value``,
    ``read_memory`` or ``segment`` so that a single wrapped module can expose
    all four operations.
    """

    def __init__(self, config, model_path=None, map_location=None):
        """
        model_path/map_location are used in evaluation only
        map_location is for converting models saved in cuda to cpu
        """
        super().__init__()
        model_weights = self.init_hyperparameters(config, model_path, map_location)

        self.single_object = config.get('single_object', False)
        print(f'Single object mode: {self.single_object}')

        self.key_encoder = KeyEncoder()
        self.value_encoder = ValueEncoder(self.value_dim, self.hidden_dim, self.single_object)

        # Projection from f16 feature space to key/value space
        self.key_proj = KeyProjection(1024, self.key_dim)

        self.decoder = Decoder(self.value_dim, self.hidden_dim)

        if model_weights is not None:
            self.load_weights(model_weights, init_as_zero_if_needed=True)

    def encode_key(self, frame, need_sk=True, need_ek=True):
        """Encode frames into key, shrinkage, selection and multi-scale features.

        ``frame`` is either b*t*c*h*w or b*c*h*w. Returns
        ``(key, shrinkage, selection, f16, f8, f4)``; ``shrinkage`` and
        ``selection`` are ``None`` when ``need_sk`` / ``need_ek`` are false.

        Raises:
            NotImplementedError: if ``frame`` is neither 4D nor 5D.
        """
        # Determine input shape
        if len(frame.shape) == 5:
            # shape is b*t*c*h*w
            need_reshape = True
            b, t = frame.shape[:2]
            # flatten so that we can feed them into a 2D CNN
            frame = frame.flatten(start_dim=0, end_dim=1)
        elif len(frame.shape) == 4:
            # shape is b*c*h*w
            need_reshape = False
        else:
            raise NotImplementedError(
                f'encode_key expects a 4D or 5D frame tensor, got shape {tuple(frame.shape)}')

        f16, f8, f4 = self.key_encoder(frame)
        key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek)

        if need_reshape:
            # B*C*T*H*W
            key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous()
            if shrinkage is not None:
                shrinkage = shrinkage.view(b, t, *shrinkage.shape[-3:]).transpose(1, 2).contiguous()
            if selection is not None:
                selection = selection.view(b, t, *selection.shape[-3:]).transpose(1, 2).contiguous()

            # B*T*C*H*W
            f16 = f16.view(b, t, *f16.shape[-3:])
            f8 = f8.view(b, t, *f8.shape[-3:])
            f4 = f4.view(b, t, *f4.shape[-3:])

        return key, shrinkage, selection, f16, f8, f4

    def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True):
        """Encode ``frame`` and ``masks`` into value features; returns ``(g16, h16)``.

        For every object the sum of all *other* objects' masks is passed to
        the value encoder as well (all zeros when there is a single object).
        """
        num_objects = masks.shape[1]
        if num_objects != 1:
            others = torch.cat([
                torch.sum(
                    masks[:, [j for j in range(num_objects) if i != j]],
                    dim=1, keepdim=True)
                for i in range(num_objects)], 1)
        else:
            others = torch.zeros_like(masks)

        g16, h16 = self.value_encoder(frame, image_feat_f16, h16, masks, others, is_deep_update)

        return g16, h16

    # Used in training only.
    # This step is replaced by MemoryManager in test time
    def read_memory(self, query_key, query_selection, memory_key,
                    memory_shrinkage, memory_value):
        """
        query_key       : B * CK * H * W
        query_selection : B * CK * H * W
        memory_key      : B * CK * T * H * W
        memory_shrinkage: B * 1  * T * H * W
        memory_value    : B * num_objects * CV * T * H * W
        """
        batch_size, num_objects = memory_value.shape[:2]
        # Merge the object and value-channel axes for the readout.
        memory_value = memory_value.flatten(start_dim=1, end_dim=2)

        affinity = get_affinity(memory_key, memory_shrinkage, query_key, query_selection)
        memory = readout(affinity, memory_value)
        memory = memory.view(batch_size, num_objects, self.value_dim, *memory.shape[-2:])

        return memory

    def segment(self, multi_scale_features, memory_readout,
                hidden_state, selector=None, h_out=True, strip_bg=True):
        """Decode a memory readout; returns ``(hidden_state, logits, prob)``.

        ``selector``, when given, is multiplied into the per-object
        probabilities before aggregation. With ``strip_bg`` the first channel
        of the aggregated probabilities (the background) is dropped.
        """
        hidden_state, logits = self.decoder(*multi_scale_features, hidden_state, memory_readout, h_out=h_out)
        prob = torch.sigmoid(logits)
        if selector is not None:
            prob = prob * selector

        logits, prob = aggregate(prob, dim=1, return_logits=True)
        if strip_bg:
            # Strip away the background
            prob = prob[:, 1:]

        return hidden_state, logits, prob

    def forward(self, mode, *args, **kwargs):
        """Dispatch to the method named by ``mode``.

        Raises:
            NotImplementedError: for an unknown ``mode``.
        """
        if mode == 'encode_key':
            return self.encode_key(*args, **kwargs)
        elif mode == 'encode_value':
            return self.encode_value(*args, **kwargs)
        elif mode == 'read_memory':
            return self.read_memory(*args, **kwargs)
        elif mode == 'segment':
            return self.segment(*args, **kwargs)
        else:
            raise NotImplementedError(f'Unknown forward mode: {mode!r}')

    def init_hyperparameters(self, config, model_path=None, map_location=None):
        """
        Init three hyperparameters: key_dim, value_dim, and hidden_dim
        If model_path is provided, we load these from the model weights
        The actual parameters are then updated to the config in-place

        Otherwise we load it either from the config or default

        Returns the loaded state dict, or None when no model_path is given.
        """
        if model_path is not None:
            # load the model and key/value/hidden dimensions with some hacks
            # config is updated with the loaded parameters
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            model_weights = torch.load(model_path, map_location=map_location)
            self.key_dim = model_weights['key_proj.key_proj.weight'].shape[0]
            self.value_dim = model_weights['value_encoder.fuser.block2.conv2.weight'].shape[0]
            self.disable_hidden = 'decoder.hidden_update.transform.weight' not in model_weights
            if self.disable_hidden:
                self.hidden_dim = 0
            else:
                # The transform emits three gates, each hidden_dim channels wide.
                self.hidden_dim = model_weights['decoder.hidden_update.transform.weight'].shape[0]//3
            print(f'Hyperparameters read from the model weights: '
                  f'C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}')
        else:
            model_weights = None
            # load dimensions from config or default
            if 'key_dim' not in config:
                self.key_dim = 64
                print(f'key_dim not found in config. Set to default {self.key_dim}')
            else:
                self.key_dim = config['key_dim']

            if 'value_dim' not in config:
                self.value_dim = 512
                print(f'value_dim not found in config. Set to default {self.value_dim}')
            else:
                self.value_dim = config['value_dim']

            if 'hidden_dim' not in config:
                self.hidden_dim = 64
                print(f'hidden_dim not found in config. Set to default {self.hidden_dim}')
            else:
                self.hidden_dim = config['hidden_dim']

            self.disable_hidden = (self.hidden_dim <= 0)

        config['key_dim'] = self.key_dim
        config['value_dim'] = self.value_dim
        config['hidden_dim'] = self.hidden_dim

        return model_weights

    def load_weights(self, src_dict, init_as_zero_if_needed=False):
        """Load ``src_dict`` into this module, converting SO weights to MO if needed.

        Maps SO weight (without other_mask) to MO weight (with other_mask):
        when the value encoder's first conv has 4 input channels, one extra
        input channel is appended (zeros when ``init_as_zero_if_needed`` is
        true, orthogonally initialised otherwise).

        The caller's ``src_dict`` is not modified.
        """
        import copy

        # Shallow copy: keeps every tensor (and any `_metadata` attribute of a
        # saved state dict) but lets us swap one entry without touching the
        # caller's mapping.
        state = copy.copy(src_dict)

        key = 'value_encoder.conv1.weight'
        if key in state and state[key].shape[1] == 4:
            print('Converting weights from single object to multiple objects.')
            weight = state[key]
            # Pad shape follows the stored weight (64 x 1 x 7 x 7 for the stock backbone).
            out_channels, _, k_h, k_w = weight.shape
            pads = torch.zeros((out_channels, 1, k_h, k_w), device=weight.device)
            if not init_as_zero_if_needed:
                print('Randomly initialized padding.')
                nn.init.orthogonal_(pads)
            else:
                print('Zero-initialized padding.')
            state[key] = torch.cat([weight, pads], 1)

        self.load_state_dict(state)

165
tracker/model/resnet.py Normal file
View File

@@ -0,0 +1,165 @@
"""
resnet.py - A modified ResNet structure
We append extra channels to the first conv by some network surgery
"""
from collections import OrderedDict
import math
import torch
import torch.nn as nn
from torch.utils import model_zoo
def load_weights_add_extra_dim(target, source_state, extra_dim=1):
    """Load ``source_state`` into ``target``, padding weights whose shapes differ.

    Only keys present in both ``target.state_dict()`` and ``source_state`` are
    loaded; ``num_batches_tracked`` entries are skipped. When a source tensor's
    shape differs from the target's, ``extra_dim`` additional channels are
    appended along dim 1 (the input-channel axis of a 4D conv weight).

    Args:
        target: module that receives the weights.
        source_state: mapping from parameter names to tensors.
        extra_dim: number of channels to append to mismatching weights.
    """
    new_dict = OrderedDict()

    for name, current in target.state_dict().items():
        if 'num_batches_tracked' in name or name not in source_state:
            continue

        loaded = source_state[name]
        if current.shape != loaded.shape:
            # The extra channels are allocated as zeros and then overwritten
            # with an orthogonal initialisation (they do NOT stay zero).
            out_channels, _, k0, k1 = current.shape
            pads = torch.zeros((out_channels, extra_dim, k0, k1), device=loaded.device)
            nn.init.orthogonal_(pads)
            loaded = torch.cat([loaded, pads], 1)

        new_dict[name] = loaded

    target.load_state_dict(new_dict)
# Download locations of the pretrained checkpoints, keyed by architecture name.
# resnet18() and resnet50() in this file look their weights up here.
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Build a bias-free 3x3 convolution whose padding equals its dilation."""
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
class BasicBlock(nn.Module):
    """Two-convolution residual block (channel expansion factor 1)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super().__init__()
        # Both convolutions are bias-free 3x3 convs with padding equal to the dilation.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The shortcut is the input itself unless a downsample module is given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 4`` channels."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super().__init__()
        # 1x1 reduction
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 convolution carrying the stride and dilation
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
                               padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 expansion
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three conv/bn stages, add the shortcut and apply the final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The shortcut is the input itself unless a downsample module is given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
class ResNet(nn.Module):
    """ResNet layer container whose first conv accepts ``3 + extra_dim`` input channels.

    Only the layers are built here; no ``forward`` is defined in this class.
    """

    def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3+extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # Normal init with std sqrt(2 / (kernel_h * kernel_w * out_channels)).
                k_h, k_w = module.kernel_size
                fan = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Stack ``blocks`` instances of ``block`` and return them as an ``nn.Sequential``.

        The first block receives ``stride`` and, when the stride or channel
        count changes, a 1x1 conv + batch-norm downsample path. ``dilation``
        is forwarded to the remaining blocks only.
        """
        out_channels = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        first = block(self.inplanes, planes, stride, downsample)
        self.inplanes = out_channels
        rest = [block(self.inplanes, planes, dilation=dilation) for _ in range(1, blocks)]

        return nn.Sequential(first, *rest)
def resnet18(pretrained=True, extra_dim=0):
    """Build a ``BasicBlock`` ResNet with layout [2, 2, 2, 2].

    When ``pretrained`` is true the checkpoint at ``model_urls['resnet18']``
    is fetched and loaded, extending the weights for ``extra_dim`` extra
    input channels.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
    if not pretrained:
        return model

    checkpoint = model_zoo.load_url(model_urls['resnet18'])
    load_weights_add_extra_dim(model, checkpoint, extra_dim)
    return model
def resnet50(pretrained=True, extra_dim=0):
    """Build a ``Bottleneck`` ResNet with layout [3, 4, 6, 3].

    When ``pretrained`` is true the checkpoint at ``model_urls['resnet50']``
    is fetched and loaded, extending the weights for ``extra_dim`` extra
    input channels.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
    if not pretrained:
        return model

    checkpoint = model_zoo.load_url(model_urls['resnet50'])
    load_weights_add_extra_dim(model, checkpoint, extra_dim)
    return model

244
tracker/model/trainer.py Normal file
View File

@@ -0,0 +1,244 @@
"""
trainer.py - wrapper and utility functions for network training
Compute loss, back-prop, update parameters, logging, etc.
"""
import datetime
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from model.network import XMem
from model.losses import LossComputer
from util.log_integrator import Integrator
from util.image_saver import pool_pairs
class XMemTrainer:
    """Training wrapper around XMem: forward pass, loss, back-prop, logging and checkpoints.

    The network is wrapped in ``DistributedDataParallel`` on ``local_rank``.
    ``logger`` may be ``None`` (the code treats that as "this process does not
    log"); saving is disabled when ``save_path`` is ``None``.
    """

    def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1):
        self.config = config
        self.num_frames = config['num_frames']
        self.num_ref_frames = config['num_ref_frames']
        self.deep_update_prob = config['deep_update_prob']
        self.local_rank = local_rank

        self.XMem = nn.parallel.DistributedDataParallel(
            XMem(config).cuda(),
            device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)

        # Set up logger when local_rank=0
        self.logger = logger
        self.save_path = save_path
        # The timer is started on every process: do_pass reads it before
        # checking whether a logger exists.
        self.last_time = time.time()
        # Running average of the wall time per text-log interval (for the ETA).
        self._avg_interval_time = None
        if logger is not None:
            self.logger.log_string('model_size', str(sum([param.nelement() for param in self.XMem.parameters()])))
        self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size)
        self.loss_computer = LossComputer(config)

        self.train()
        self.optimizer = optim.AdamW(filter(
            lambda p: p.requires_grad, self.XMem.parameters()), lr=config['lr'], weight_decay=config['weight_decay'])
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, config['steps'], config['gamma'])
        if config['amp']:
            self.scaler = torch.cuda.amp.GradScaler()

        # Logging info
        self.log_text_interval = config['log_text_interval']
        self.log_image_interval = config['log_image_interval']
        self.save_network_interval = config['save_network_interval']
        self.save_checkpoint_interval = config['save_checkpoint_interval']
        if config['debug']:
            self.log_text_interval = self.log_image_interval = 1

    def do_pass(self, data, max_it, it=0):
        """Run one iteration on ``data``: forward, loss, logging, saving and optimizer step.

        Args:
            data: batch dict; tensors in it are moved to the GPU in place.
                Reads 'rgb', 'first_frame_gt', 'selector' and data['info']['num_objects'].
            max_it: total number of iterations, used only for the ETA print.
            it: current iteration index.
        """
        # No need to store the gradient outside training
        torch.set_grad_enabled(self._is_train)

        # Move tensors to the GPU; non-tensor entries (lists, dicts, ints, ...) stay as they are.
        for k, v in data.items():
            if isinstance(v, torch.Tensor):
                data[k] = v.cuda(non_blocking=True)

        out = {}
        frames = data['rgb']
        first_frame_gt = data['first_frame_gt'].float()
        b = frames.shape[0]
        num_filled_objects = [o.item() for o in data['info']['num_objects']]
        num_objects = first_frame_gt.shape[2]
        selector = data['selector'].unsqueeze(2).unsqueeze(2)

        with torch.cuda.amp.autocast(enabled=self.config['amp']):
            # image features never change, compute once
            key, shrinkage, selection, f16, f8, f4 = self.XMem('encode_key', frames)

            filler_one = torch.zeros(1, dtype=torch.int64)
            hidden = torch.zeros((b, num_objects, self.config['hidden_dim'], *key.shape[-2:]))
            v16, hidden = self.XMem('encode_value', frames[:,0], f16[:,0], hidden, first_frame_gt[:,0])
            values = v16.unsqueeze(3) # add the time dimension

            for ti in range(1, self.num_frames):
                if ti <= self.num_ref_frames:
                    ref_values = values
                    ref_keys = key[:,:,:ti]
                    ref_shrinkage = shrinkage[:,:,:ti] if shrinkage is not None else None
                else:
                    # pick num_ref_frames random frames
                    # this is not very efficient but I think we would
                    # need broadcasting in gather which we don't have
                    # (frame 0 is always included via filler_one)
                    indices = [
                        torch.cat([filler_one, torch.randperm(ti-1)[:self.num_ref_frames-1]+1])
                        for _ in range(b)]
                    ref_values = torch.stack([
                        values[bi, :, :, indices[bi]] for bi in range(b)
                    ], 0)
                    ref_keys = torch.stack([
                        key[bi, :, indices[bi]] for bi in range(b)
                    ], 0)
                    ref_shrinkage = torch.stack([
                        shrinkage[bi, :, indices[bi]] for bi in range(b)
                    ], 0) if shrinkage is not None else None

                # Segment frame ti
                memory_readout = self.XMem('read_memory', key[:,:,ti], selection[:,:,ti] if selection is not None else None,
                                           ref_keys, ref_shrinkage, ref_values)
                hidden, logits, masks = self.XMem('segment', (f16[:,ti], f8[:,ti], f4[:,ti]), memory_readout,
                                                  hidden, selector, h_out=(ti < (self.num_frames-1)))

                # No need to encode the last frame
                if ti < (self.num_frames-1):
                    is_deep_update = np.random.rand() < self.deep_update_prob
                    v16, hidden = self.XMem('encode_value', frames[:,ti], f16[:,ti], hidden, masks, is_deep_update=is_deep_update)
                    values = torch.cat([values, v16.unsqueeze(3)], 3)

                out[f'masks_{ti}'] = masks
                out[f'logits_{ti}'] = logits

            if self._do_log or self._is_train:
                losses = self.loss_computer.compute({**data, **out}, num_filled_objects, it)

                # Logging
                if self._do_log:
                    self.integrator.add_dict(losses)
                    if self._is_train:
                        if it % self.log_image_interval == 0 and it != 0:
                            if self.logger is not None:
                                images = {**data, **out}
                                size = (384, 384)
                                self.logger.log_cv2('train/pairs', pool_pairs(images, size, num_filled_objects), it)

            if self._is_train:
                if (it) % self.log_text_interval == 0 and it != 0:
                    time_spent = time.time()-self.last_time

                    if self.logger is not None:
                        self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it)
                        self.logger.log_metrics('train', 'time', (time_spent)/self.log_text_interval, it)

                        # Exponential moving average of the time per logging interval,
                        # kept on the instance so it survives across calls.
                        if self._avg_interval_time is None:
                            self._avg_interval_time = time_spent
                        else:
                            self._avg_interval_time = 0.5*self._avg_interval_time + 0.5*time_spent
                        # time_spent covers log_text_interval iterations.
                        eta_seconds = self._avg_interval_time * (max_it - it) / self.log_text_interval
                        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                        print(f'ETA: {eta_string}')

                    self.last_time = time.time()
                    self.train_integrator.finalize('train', it)
                    self.train_integrator.reset_except_hooks()

                if it % self.save_network_interval == 0 and it != 0:
                    if self.logger is not None:
                        self.save_network(it)

                if it % self.save_checkpoint_interval == 0 and it != 0:
                    if self.logger is not None:
                        self.save_checkpoint(it)

        # Backward pass
        # NOTE(review): this runs unconditionally, so do_pass presumably is only
        # called in training mode -- confirm against the caller.
        self.optimizer.zero_grad(set_to_none=True)
        if self.config['amp']:
            self.scaler.scale(losses['total_loss']).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            losses['total_loss'].backward()
            self.optimizer.step()

        self.scheduler.step()

    def save_network(self, it):
        """Save only the network weights to ``{save_path}_{it}.pth``."""
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        model_path = f'{self.save_path}_{it}.pth'
        torch.save(self.XMem.module.state_dict(), model_path)
        print(f'Network saved to {model_path}.')

    def save_checkpoint(self, it):
        """Save network, optimizer and scheduler states to ``{save_path}_checkpoint_{it}.pth``."""
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        checkpoint_path = f'{self.save_path}_checkpoint_{it}.pth'
        checkpoint = {
            'it': it,
            'network': self.XMem.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()}
        torch.save(checkpoint, checkpoint_path)
        print(f'Checkpoint saved to {checkpoint_path}.')

    def load_checkpoint(self, path):
        """Restore network, optimizer and scheduler; returns the stored iteration index."""
        # This method loads everything and should be used to resume training
        map_location = 'cuda:%d' % self.local_rank
        checkpoint = torch.load(path, map_location={'cuda:0': map_location})

        it = checkpoint['it']
        network = checkpoint['network']
        optimizer = checkpoint['optimizer']
        scheduler = checkpoint['scheduler']

        self.XMem.module.load_state_dict(network)
        self.optimizer.load_state_dict(optimizer)
        self.scheduler.load_state_dict(scheduler)

        print('Network weights, optimizer states, and scheduler states loaded.')

        return it

    def load_network_in_memory(self, src_dict):
        """Load network weights from an already-loaded state dict."""
        self.XMem.module.load_weights(src_dict)
        print('Network weight loaded from memory.')

    def load_network(self, path):
        """Load only the network weights from ``path``."""
        # This method loads only the network weight and should be used to load a pretrained model
        map_location = 'cuda:%d' % self.local_rank
        src_dict = torch.load(path, map_location={'cuda:0': map_location})

        self.load_network_in_memory(src_dict)
        print(f'Network weight loaded from {path}')

    def train(self):
        """Switch to training behaviour (gradients and logging on); returns ``self``."""
        self._is_train = True
        self._do_log = True
        self.integrator = self.train_integrator
        # NOTE(review): the network is put in eval() mode even for training,
        # presumably to freeze BatchNorm statistics -- confirm before changing.
        self.XMem.eval()
        return self

    def val(self):
        """Switch to validation behaviour (no gradients, logging on); returns ``self``."""
        self._is_train = False
        self._do_log = True
        self.XMem.eval()
        return self

    def test(self):
        """Switch to test behaviour (no gradients, no logging); returns ``self``."""
        self._is_train = False
        self._do_log = False
        self.XMem.eval()
        return self

0
tracker/util/__init__.py Normal file
View File

View File

@@ -0,0 +1,138 @@
from argparse import ArgumentParser
def none_or_default(x, default):
    """Return ``x`` unless it is ``None``, in which case return ``default``."""
    if x is None:
        return default
    return x
class Configuration():
    """Command-line configuration for training, exposed through dict-style access.

    ``parse`` must be called before any other method; it stores the parsed
    arguments as a plain dict in ``self.args``.
    """

    def parse(self, unknown_arg_ok=False):
        """Parse the command line into ``self.args``.

        With ``unknown_arg_ok`` unrecognised arguments are ignored instead of
        being treated as an error. Raises ``NotImplementedError`` when
        ``--stages`` contains a character other than 0-3.
        """
        parser = ArgumentParser()

        # Enable torch.backends.cudnn.benchmark -- Faster in some cases, test in your own environment
        parser.add_argument('--benchmark', action='store_true')
        parser.add_argument('--no_amp', action='store_true')

        # Save paths
        parser.add_argument('--save_path', default='/ssd1/gaomingqi/output/xmem-sam')

        # Data parameters
        parser.add_argument('--static_root', help='Static training data root', default='/ssd1/gaomingqi/datasets/static')
        parser.add_argument('--bl_root', help='Blender training data root', default='../BL30K')
        parser.add_argument('--yv_root', help='YouTubeVOS data root', default='/ssd1/gaomingqi/datasets/youtube-vos/2018')
        parser.add_argument('--davis_root', help='DAVIS data root', default='/ssd1/gaomingqi/datasets/davis')
        parser.add_argument('--num_workers', help='Total number of dataloader workers across all GPUs processes', type=int, default=16)

        parser.add_argument('--key_dim', default=64, type=int)
        parser.add_argument('--value_dim', default=512, type=int)
        parser.add_argument('--hidden_dim', default=64, help='Set to =0 to disable', type=int)

        parser.add_argument('--deep_update_prob', default=0.2, type=float)

        parser.add_argument('--stages', help='Training stage (0-static images, 1-Blender dataset, 2-DAVIS+YouTubeVOS)', default='02')

        # Stage-specific learning parameters
        # Batch sizes are effective -- you don't have to scale them when you scale the number processes
        #
        #   stage 0: static images
        #   stage 1: BL30K
        #   stage 2: DAVIS+YoutubeVOS, longer
        #   stage 3: DAVIS+YoutubeVOS, shorter
        #
        # fine-tune means fewer augmentations to train the sensory memory
        # Columns: batch_size, iterations, finetune, steps, num_ref_frames, num_frames
        stage_defaults = (
            (0, 16, 150000, 0, [], 2, 3),
            (1, 8, 250000, 0, [200000], 3, 8),
            (2, 8, 150000, 10000, [120000], 3, 8),
            (3, 8, 100000, 10000, [80000], 3, 8),
        )
        for stage, batch_size, iterations, finetune, steps, num_ref_frames, num_frames in stage_defaults:
            prefix = '--s%d_' % stage
            parser.add_argument(prefix + 'batch_size', default=batch_size, type=int)
            parser.add_argument(prefix + 'iterations', default=iterations, type=int)
            parser.add_argument(prefix + 'finetune', default=finetune, type=int)
            parser.add_argument(prefix + 'steps', nargs="*", default=steps, type=int)
            parser.add_argument(prefix + 'lr', help='Initial learning rate', default=1e-5, type=float)
            parser.add_argument(prefix + 'num_ref_frames', default=num_ref_frames, type=int)
            parser.add_argument(prefix + 'num_frames', default=num_frames, type=int)
            parser.add_argument(prefix + 'start_warm', default=20000, type=int)
            parser.add_argument(prefix + 'end_warm', default=70000, type=int)

        parser.add_argument('--gamma', help='LR := LR*gamma at every decay step', default=0.1, type=float)
        parser.add_argument('--weight_decay', default=0.05, type=float)

        # Loading
        parser.add_argument('--load_network', help='Path to pretrained network weight only')
        parser.add_argument('--load_checkpoint', help='Path to the checkpoint file, including network, optimizer and such')

        # Logging information
        parser.add_argument('--log_text_interval', default=100, type=int)
        parser.add_argument('--log_image_interval', default=1000, type=int)
        parser.add_argument('--save_network_interval', default=25000, type=int)
        parser.add_argument('--save_checkpoint_interval', default=50000, type=int)
        parser.add_argument('--exp_id', help='Experiment UNIQUE id, use NULL to disable logging to tensorboard', default='NULL')
        parser.add_argument('--debug', help='Debug mode which logs information more often', action='store_true')

        # # Multiprocessing parameters, not set by users
        # parser.add_argument('--local_rank', default=0, type=int, help='Local rank of this process')

        if unknown_arg_ok:
            known_args, _ = parser.parse_known_args()
        else:
            known_args = parser.parse_args()
        self.args = vars(known_args)

        # AMP is on unless --no_amp was given.
        self.args['amp'] = not self.args['no_amp']

        # check if the stages are valid
        valid_stages = ('0', '1', '2', '3')
        if any(s not in valid_stages for s in list(self.args['stages'])):
            raise NotImplementedError

    def get_stage_parameters(self, stage):
        """Return the ``s<stage>_*`` arguments as a dict keyed without the stage prefix."""
        field_names = (
            'batch_size',
            'iterations',
            'finetune',
            'steps',
            'lr',
            'num_ref_frames',
            'num_frames',
            'start_warm',
            'end_warm',
        )
        return {name: self.args['s%s_%s' % (stage, name)] for name in field_names}

    def __getitem__(self, key):
        return self.args[key]

    def __setitem__(self, key, value):
        self.args[key] = value

    def __str__(self):
        return str(self.args)

View File

@@ -0,0 +1,60 @@
bear
bmx-bumps
boat
boxing-fisheye
breakdance-flare
bus
car-turn
cat-girl
classic-car
color-run
crossing
dance-jump
dancing
disc-jockey
dog-agility
dog-gooses
dogs-scale
drift-turn
drone
elephant
flamingo
hike
hockey
horsejump-low
kid-football
kite-walk
koala
lady-running
lindy-hop
longboard
lucia
mallard-fly
mallard-water
miami-surf
motocross-bumps
motorbike
night-race
paragliding
planes-water
rallye
rhino
rollerblade
schoolgirls
scooter-board
scooter-gray
sheep
skate-park
snowboard
soccerball
stroller
stunt
surf
swing
tennis
tractor-sand
train
tuk-tuk
upside-down
varanus-cage
walking

136
tracker/util/image_saver.py Normal file
View File

@@ -0,0 +1,136 @@
import cv2
import numpy as np
import torch
from dataset.range_transform import inv_im_trans
from collections import defaultdict
def tensor_to_numpy(image):
    """Convert a tensor to a numpy array scaled by 255 and cast to uint8."""
    scaled = image.numpy() * 255
    return scaled.astype('uint8')
def tensor_to_np_float(image):
    """Convert a tensor to a float32 numpy array (values are not rescaled)."""
    as_array = image.numpy()
    return as_array.astype('float32')
def detach_to_cpu(x):
    """Detach ``x`` from the autograd graph and return it on the CPU."""
    detached = x.detach()
    return detached.cpu()
def transpose_np(x):
    """Move the first axis of a 3D array to the end (axes order 1, 2, 0)."""
    axis_order = [1, 2, 0]
    return np.transpose(x, axis_order)
def tensor_to_gray_im(x):
    """Convert a tensor to a uint8 numpy image with its first axis moved last.

    Applies, in order: detach + move to CPU, scale by 255 and cast to uint8,
    transpose with axes (1, 2, 0).
    """
    return transpose_np(tensor_to_numpy(detach_to_cpu(x)))
def tensor_to_im(x):
    """Convert a normalised image tensor to a uint8 numpy image.

    Applies, in order: detach + move to CPU, ``inv_im_trans`` followed by a
    clamp to [0, 1], scale by 255 and cast to uint8, transpose with axes
    (1, 2, 0).
    """
    on_cpu = detach_to_cpu(x)
    unnormalised = inv_im_trans(on_cpu).clamp(0, 1)
    return transpose_np(tensor_to_numpy(unnormalised))
# Predefined key <-> caption dict
# Maps an image-group key to the caption text drawn next to that group.
# pool_pairs passes this dict to get_image_array, which falls back to the key
# itself for any key that is not listed here.
key_captions = {
'im': 'Image',
'gt': 'GT',
}
def get_image_array(images, grid_shape, captions=None):
    """
    Return an image array with captions
    keys in dictionary will be used as caption if not provided
    values should contain lists of cv2 images

    Each dict entry becomes one horizontal strip: the caption is drawn in the
    first cell and the entry's images fill the following cells.

    Args:
        images: mapping from key to a list of float images in [0, 1]
            (2D, or 3D with the channel axis last); the number of cells per
            strip is taken from the first entry.
        grid_shape: (width, height) of a single cell.
        captions: optional mapping from key to caption text; a missing key is
            captioned with the key itself. Captions may contain newlines.

    Returns:
        uint8 array of shape (height * len(images), width * (cells + 1), 3).
    """
    if captions is None:
        captions = {}

    # The first element is the horizontal cell size, the second the vertical
    # one (the same order as the dsize argument of cv2.resize).
    cell_w, cell_h = grid_shape
    cate_counts = len(images)
    # An empty dict yields an empty (zero-row) canvas instead of StopIteration.
    rows_counts = len(next(iter(images.values()), []))

    font = cv2.FONT_HERSHEY_SIMPLEX

    output_image = np.zeros([cell_h*cate_counts, cell_w*(rows_counts+1), 3], dtype=np.uint8)
    for strip_idx, (k, v) in enumerate(images.items()):
        top = strip_idx*cell_h

        # Default as key value itself
        caption = captions.get(k, k)

        # Handles new line character
        dy = 40
        for i, line in enumerate(caption.split('\n')):
            cv2.putText(output_image, line, (10, top+100+i*dy),
                        font, 0.8, (255,255,255), 2, cv2.LINE_AA)

        # Put images
        for cell_idx, img in enumerate(v):
            if len(img.shape) == 2:
                # Add a channel axis so a 2D image broadcasts over the 3 channels.
                img = img[..., np.newaxis]

            img = (img * 255).astype('uint8')

            # Cell 0 of each strip is reserved for the caption.
            left = (cell_idx+1)*cell_w
            output_image[top:top+cell_h, left:left+cell_w, :] = img

    return output_image
def base_transform(im, size):
    """Convert a CPU tensor to a float32 numpy image of the requested size, clipped to [0, 1].

    Args:
        im: 3D tensor with the channel axis first, or a 2D tensor.
        size: target size as (width, height), i.e. the ``dsize`` order of
            ``cv2.resize``.

    Returns:
        float32 array of shape (height, width, channels); a single-channel
        image is returned as (height, width), which is also what
        ``cv2.resize`` produces for it.
    """
    im = im.numpy().astype('float32')
    if len(im.shape) == 3:
        im = im.transpose((1, 2, 0))
    else:
        im = im[:, :, None]

    # Resize only when the spatial size actually differs. (The previous check
    # compared one dimension against the whole `size` tuple, which is never
    # equal, so every image was resized.)
    target_w, target_h = size
    if im.shape[:2] != (target_h, target_w):
        im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)
    elif im.shape[2] == 1:
        # Match cv2.resize, which drops a singleton channel axis.
        im = im[:, :, 0]

    return im.clip(0, 1)
def im_transform(im, size):
    """Undo the image normalisation with ``inv_im_trans`` and apply ``base_transform``."""
    unnormalised = inv_im_trans(detach_to_cpu(im))
    return base_transform(unnormalised, size=size)
def mask_transform(mask, size):
    """Detach ``mask``, move it to the CPU and apply ``base_transform``."""
    on_cpu = detach_to_cpu(mask)
    return base_transform(on_cpu, size=size)
def out_transform(mask, size):
    """Apply a sigmoid to ``mask``, then detach, move to the CPU and apply ``base_transform``."""
    probabilities = torch.sigmoid(mask)
    return base_transform(detach_to_cpu(probabilities), size=size)
def pool_pairs(images, size, num_objects):
    """
    Collect RGB frames, per-object masks and per-object ground truth from
    `images` and tile them into a single visualization via get_image_array.

    At most the first two batch entries are used. For frame 0, and for
    object slots at or beyond num_objects[bi], the mask shown is taken from
    'first_frame_gt'; otherwise it is taken from 'masks_<frame index>'.
    """
    pooled = defaultdict(list)

    batch, frames = images['rgb'].shape[:2]

    # limit the number of images saved
    batch = min(2, batch)

    # find max num objects
    max_num_objects = max(num_objects[:batch])

    # One caption line per batch entry, cut out of the entry's name
    GT_suffix = ''.join(
        ' \n%s' % images['info']['name'][bi][-25:-4] for bi in range(batch)
    )

    for bi in range(batch):
        for ti in range(frames):
            pooled['RGB'].append(im_transform(images['rgb'][bi,ti], size))
            for oi in range(max_num_objects):
                if ti == 0 or oi >= num_objects[bi]:
                    mask = images['first_frame_gt'][bi][0,oi]
                else:
                    mask = images['masks_%d'%ti][bi][oi]
                pooled['Mask_%d'%oi].append(mask_transform(mask, size))

                # Ground truth for object oi: pixels labelled oi+1 in cls_gt
                gt = images['cls_gt'][bi,ti,0]==(oi+1)
                pooled['GT_%d_%s'%(oi, GT_suffix)].append(mask_transform(gt, size))

    return get_image_array(pooled, size, key_captions)

View File

@@ -0,0 +1,16 @@
"""
load_subset.py - Presents a subset of data
DAVIS - only the training set
YouTubeVOS - I manually filtered some erroneous ones out but I haven't checked all
"""
def load_sub_davis(path='util/davis_subset.txt'):
    """Read the file at `path` and return its lines as a set (line endings removed)."""
    with open(path, mode='r') as f:
        return set(f.read().splitlines())
def load_sub_yv(path='util/yv_subset.txt'):
    """Read the file at `path` and return its lines as a set (line endings removed)."""
    with open(path, mode='r') as f:
        lines = f.read().splitlines()
    return set(lines)

View File

@@ -0,0 +1,80 @@
"""
Integrate numerical values for some iterations
Typically used for loss computation / logging to tensorboard
Call finalize and create a new Integrator when you want to display/log
"""
import torch
class Integrator:
    """
    Accumulates named scalar values over several iterations and logs their
    averages through `logger.log_metrics` when finalize() is called.
    """
    def __init__(self, logger, distributed=True, local_rank=0, world_size=1):
        """
        logger      - object providing log_metrics(prefix, key, value, it, f)
        distributed - if True, finalize() reduces the averages onto rank 0
                      with torch.distributed before logging
        local_rank  - rank of this process; only rank 0 logs when distributed
        world_size  - divisor applied to the reduced value on rank 0
        """
        self.values = {}
        self.counts = {}
        self.hooks = []  # List is used here to maintain insertion order

        self.logger = logger

        self.distributed = distributed
        self.local_rank = local_rank
        self.world_size = world_size

    def add_tensor(self, key, tensor):
        """
        Add one observation for `key`.
        Plain numbers are accumulated as-is; anything else is reduced with
        .mean().item() first.
        """
        # isinstance rather than type() == so int/float subclasses (e.g. bool) count as numbers
        if isinstance(tensor, (int, float)):
            value = tensor
        else:
            value = tensor.mean().item()

        # The value is computed before any state is touched, so a failing
        # reduction cannot leave counts and values out of sync
        if key in self.values:
            self.counts[key] += 1
            self.values[key] += value
        else:
            self.counts[key] = 1
            self.values[key] = value

    def add_dict(self, tensor_dict):
        """Call add_tensor for every (key, value) pair of `tensor_dict`."""
        for k, v in tensor_dict.items():
            self.add_tensor(k, v)

    def add_hook(self, hook):
        """
        Adds a custom hook, i.e. compute new metrics using values in the dict
        The hook takes the dict as argument, and returns a (k, v) tuple
        e.g. for computing IoU
        A list of hooks may be passed to register several at once.
        """
        if isinstance(hook, list):
            self.hooks.extend(hook)
        else:
            self.hooks.append(hook)

    def reset_except_hooks(self):
        """Discard all accumulated values and counts; registered hooks are kept."""
        self.values = {}
        self.counts = {}

    # Average and output the metrics
    def finalize(self, prefix, it, f=None):
        """
        Run the hooks, then log the average of every accumulated key.
        Keys whose first four characters are 'hide' are not logged.
        `prefix`, `it` and `f` are forwarded to logger.log_metrics.
        """
        for hook in self.hooks:
            k, v = hook(self.values)
            self.add_tensor(k, v)

        for k, v in self.values.items():

            if k[:4] == 'hide':
                continue

            avg = v / self.counts[k]

            if self.distributed:
                # Inplace operation: the reduced value is written into `avg` on rank 0
                avg = torch.tensor(avg).cuda()
                torch.distributed.reduce(avg, dst=0)

                if self.local_rank == 0:
                    # Dividing by world_size turns the reduced value (a sum with
                    # the default reduce op) into a mean over processes
                    avg = (avg/self.world_size).cpu().item()
                    self.logger.log_metrics(prefix, k, avg, it, f)
            else:
                # Simple does it
                self.logger.log_metrics(prefix, k, avg, it, f)

101
tracker/util/logger.py Normal file
View File

@@ -0,0 +1,101 @@
"""
Dumps things to tensorboard and console
"""
import os
import warnings
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
def tensor_to_numpy(image):
    """Scale `image` by 255 and return it as a uint8 numpy array."""
    scaled = image.numpy() * 255
    return scaled.astype('uint8')
def detach_to_cpu(x):
    """Return a CPU version of `x` that is detached from the autograd graph."""
    cpu_tensor = x.detach().cpu()
    return cpu_tensor
def fix_width_trunc(x):
    """Format `x` with nine decimals and keep at most the first nine characters."""
    nine_decimals = format(x, '0.9f')
    return nine_decimals[:9]
class TensorboardLogger:
    """
    Writes scalars, images and text to tensorboard and prints metrics to the console.
    When constructed with id=None nothing is written to tensorboard and every
    log_* call only emits a warning.
    """
    def __init__(self, short_id, id, git_info):
        """
        short_id - tag printed in front of every metric line ('NULL' is shown as 'DEBUG')
        id       - name of the run; logs are written to ./logs/<id>. None disables logging
        git_info - text logged once under the 'git' tag
        """
        self.short_id = short_id
        if self.short_id == 'NULL':
            self.short_id = 'DEBUG'

        if id is None:
            self.no_log = True
            # Same wording as the warning emitted by the log_* methods
            warnings.warn('Logging has been disabled.')
        else:
            self.no_log = False

            # Inverse of a Normalize with mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225)
            self.inv_im_trans = transforms.Normalize(
                mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
                std=[1/0.229, 1/0.224, 1/0.225])

            # Inverse of a Normalize with mean 0.5 and std 0.5
            self.inv_seg_trans = transforms.Normalize(
                mean=[-0.5/0.5],
                std=[1/0.5])

            log_path = os.path.join('.', 'logs', '%s' % id)
            self.logger = SummaryWriter(log_path)

        self.log_string('git', git_info)

    def log_scalar(self, tag, x, step):
        """Write scalar `x` under `tag` at `step`."""
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        self.logger.add_scalar(tag, x, step)

    def log_metrics(self, l1_tag, l2_tag, val, step, f=None):
        """
        Print one formatted metric line, append it to file object `f` if given,
        and log `val` as a scalar under '<l1_tag>/<l2_tag>'.
        """
        tag = l1_tag + '/' + l2_tag
        text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), l2_tag, fix_width_trunc(val))
        print(text)
        if f is not None:
            f.write(text + '\n')
            f.flush()
        self.log_scalar(tag, val, step)

    def log_im(self, tag, x, step):
        """Undo the image normalization on tensor `x` and log it as a uint8 image."""
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = detach_to_cpu(x)
        x = self.inv_im_trans(x)
        x = tensor_to_numpy(x)
        self.logger.add_image(tag, x, step)

    def log_cv2(self, tag, x, step):
        """Log numpy image `x` after moving its last axis to the front."""
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        # presumably converts a cv2-style channel-last image to channel-first - TODO confirm
        x = x.transpose((2, 0, 1))
        self.logger.add_image(tag, x, step)

    def log_seg(self, tag, x, step):
        """Undo the segmentation normalization on tensor `x` and log it as a uint8 image."""
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = detach_to_cpu(x)
        x = self.inv_seg_trans(x)
        x = tensor_to_numpy(x)
        self.logger.add_image(tag, x, step)

    def log_gray(self, tag, x, step):
        """Log tensor `x` as a uint8 image without undoing any normalization."""
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = detach_to_cpu(x)
        x = tensor_to_numpy(x)
        self.logger.add_image(tag, x, step)

    def log_string(self, tag, x):
        """Print `tag` and `x`, and also write them as tensorboard text unless logging is disabled."""
        print(tag, x)
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        self.logger.add_text(tag, x)

3
tracker/util/palette.py Normal file
View File

@@ -0,0 +1,3 @@
davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 
\xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0'
youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f'

View File

@@ -0,0 +1,47 @@
import torch.nn.functional as F
def compute_tensor_iu(seg, gt):
    """Return (intersection, union) element counts of two masks as float tensors."""
    overlap = seg & gt
    either = seg | gt
    return overlap.float().sum(), either.float().sum()
def compute_tensor_iou(seg, gt):
    """
    Return intersection-over-union of two masks.
    A small epsilon is added to both terms, so two empty masks give 1.
    """
    eps = 1e-6
    intersection, union = compute_tensor_iu(seg, gt)
    return (intersection + eps) / (union + eps)
# STM
def pad_divide_by(in_img, d):
    """
    Pad the last two dimensions of `in_img` up to the next multiple of `d`.
    The padding is split between both sides of each dimension; when the
    amount is odd the extra element goes to the trailing side.

    Returns (padded tensor, pad_array) where pad_array is the
    (left, right, top, bottom) tuple passed to F.pad.
    """
    h, w = in_img.shape[-2:]

    # Amount missing to reach the next multiple of d (0 when already aligned)
    extra_h = d - h % d if h % d > 0 else 0
    extra_w = d - w % d if w % d > 0 else 0

    lh = int(extra_h / 2)
    uh = int(extra_h) - lh
    lw = int(extra_w / 2)
    uw = int(extra_w) - lw

    pad_array = (lw, uw, lh, uh)
    return F.pad(in_img, pad_array), pad_array
def unpad(img, pad):
    """
    Remove padding from the last two dimensions of a 3-D or 4-D array.

    img - tensor/array with 3 or 4 dimensions
    pad - (left, right, top, bottom) amounts, in the layout returned by pad_divide_by

    Returns the cropped view, or `img` itself when there is nothing to remove.
    Raises NotImplementedError for any other number of dimensions.
    """
    if len(img.shape) not in (3, 4):
        raise NotImplementedError

    left, right, top, bottom = pad[:4]

    # Explicit end indices instead of `a:-b`: with b == 0 the slice `a:-0`
    # would be empty instead of running to the end of the axis
    if top + bottom > 0:
        img = img[..., top:img.shape[-2] - bottom, :]
    if left + right > 0:
        img = img[..., left:img.shape[-1] - right]

    return img

3464
tracker/util/yv_subset.txt Normal file

File diff suppressed because it is too large Load Diff