remove redundant code

This commit is contained in:
gaomingqi
2023-04-14 12:37:38 +08:00
parent c8ca9078ec
commit 8e2b65f547
22 changed files with 86 additions and 4941 deletions

View File

@@ -7,14 +7,13 @@ from PIL import Image
import torch
import yaml
import torch.nn.functional as F
from model.network import XMem
from inference.inference_core import InferenceCore
from inference.data.mask_mapper import MaskMapper
# for data transormation
from util.mask_mapper import MaskMapper
from torchvision import transforms
from dataset.range_transform import im_normalization
import torch.nn.functional as F
from util.range_transform import im_normalization
import sys
sys.path.insert(0, sys.path[0]+"/../")
@@ -39,9 +38,11 @@ class BaseTracker:
transforms.ToTensor(),
im_normalization,
])
self.mapper = MaskMapper()
self.device = device
self.mapper = MaskMapper()
self.initialised = False
@torch.no_grad()
def resize_mask(self, mask):
# mask transform is applied AFTER mapper, so we need to post-process it in eval.py
@@ -51,37 +52,42 @@ class BaseTracker:
mode='nearest')
@torch.no_grad()
def track(self, frames, first_frame_annotation):
def track(self, frame, first_frame_annotation=None):
"""
Input:
frames: numpy arrays: T, H, W, 3 (T: number of frames)
first_frame_annotation: numpy array: H, W
frames: numpy arrays (H, W, 3)
first_frame_annotation: numpy array (H, W)
Output:
masks: numpy arrays: H, W
mask: numpy arrays (H, W)
prob: numpy arrays, probability map (H, W)
painted_image: numpy array (H, W, 3)
"""
vid_length = len(frames)
masks = []
if first_frame_annotation is not None:
# initialisation
mask, labels = self.mapper.convert_mask(first_frame_annotation)
mask = torch.Tensor(mask).to(self.device)
self.tracker.set_all_labels(list(self.mapper.remappings.values()))
else:
mask = None
labels = None
for ti, frame in enumerate(frames):
# convert to tensor
frame_tensor = self.im_transform(frame).to(self.device)
if ti == 0:
mask, labels = self.mapper.convert_mask(first_frame_annotation)
mask = torch.Tensor(mask).to(self.device)
self.tracker.set_all_labels(list(self.mapper.remappings.values()))
else:
mask = None
labels = None
# track one frame
prob = self.tracker.step(frame_tensor, mask, labels, end=(ti==vid_length-1))
# convert to mask
out_mask = torch.argmax(prob, dim=0)
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
masks.append(out_mask)
# prepare inputs
frame_tensor = self.im_transform(frame).to(self.device)
# track one frame
prob = self.tracker.step(frame_tensor, mask, labels)
# convert to mask
out_mask = torch.argmax(prob, dim=0)
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
painted_image = mask_painter(frame, out_mask)
return np.stack(masks, 0)
# mask, _, painted_frame
return out_mask, prob, painted_image
@torch.no_grad()
def clear_memory(self):
self.tracker.clear_memory()
self.mapper.clear_labels()
if __name__ == '__main__':
@@ -106,11 +112,40 @@ if __name__ == '__main__':
tracker = BaseTracker(device, XMEM_checkpoint)
# track anything given in the first frame annotation
masks = tracker.track(frames, first_frame_annotation)
# save
for ti, (frame, mask) in enumerate(zip(frames, masks)):
painted_image = mask_painter(frame, mask)
for ti, frame in enumerate(frames):
if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
else:
mask, prob, painted_image = tracker.track(frame)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/{ti:05d}.png')
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/dance-twirl/{ti:05d}.png')
# ----------------------------------------------------------
# another video
# ----------------------------------------------------------
# video frames
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/camel', '*.jpg'))
video_path_list.sort()
# first frame
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/camel/00000.png'
# load frames
frames = []
for video_path in video_path_list:
frames.append(np.array(Image.open(video_path).convert('RGB')))
frames = np.stack(frames, 0) # N, H, W, C
# load first frame annotation
first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
print('first video done. clear.')
tracker.clear_memory()
# track anything given in the first frame annotation
for ti, frame in enumerate(frames):
if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
else:
mask, prob, painted_image = tracker.track(frame)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png')

View File

@@ -1,6 +0,0 @@
import torch
import random
def reseed(seed):
random.seed(seed)
torch.manual_seed(seed)

View File

@@ -1,179 +0,0 @@
import os
from os import path
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from PIL import Image
import numpy as np
from dataset.range_transform import im_normalization, im_mean
from dataset.tps import random_tps_warp
from dataset.reseed import reseed
class StaticTransformDataset(Dataset):
"""
Generate pseudo VOS data by applying random transforms on static images.
Single-object only.
Method 0 - FSS style (class/1.jpg class/1.png)
Method 1 - Others style (XXX.jpg XXX.png)
"""
def __init__(self, parameters, num_frames=3, max_num_obj=1):
self.num_frames = num_frames
self.max_num_obj = max_num_obj
self.im_list = []
for parameter in parameters:
root, method, multiplier = parameter
if method == 0:
# Get images
classes = os.listdir(root)
for c in classes:
imgs = os.listdir(path.join(root, c))
jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()]
joint_list = [path.join(root, c, im) for im in jpg_list]
self.im_list.extend(joint_list * multiplier)
elif method == 1:
self.im_list.extend([path.join(root, im) for im in os.listdir(root) if '.jpg' in im] * multiplier)
print(f'{len(self.im_list)} images found.')
# These set of transform is the same for im/gt pairs, but different among the 3 sampled frames
self.pair_im_lone_transform = transforms.Compose([
transforms.ColorJitter(0.1, 0.05, 0.05, 0), # No hue change here as that's not realistic
])
self.pair_im_dual_transform = transforms.Compose([
transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=im_mean),
transforms.Resize(384, InterpolationMode.BICUBIC),
transforms.RandomCrop((384, 384), pad_if_needed=True, fill=im_mean),
])
self.pair_gt_dual_transform = transforms.Compose([
transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=0),
transforms.Resize(384, InterpolationMode.NEAREST),
transforms.RandomCrop((384, 384), pad_if_needed=True, fill=0),
])
# These transform are the same for all pairs in the sampled sequence
self.all_im_lone_transform = transforms.Compose([
transforms.ColorJitter(0.1, 0.05, 0.05, 0.05),
transforms.RandomGrayscale(0.05),
])
self.all_im_dual_transform = transforms.Compose([
transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=im_mean),
transforms.RandomHorizontalFlip(),
])
self.all_gt_dual_transform = transforms.Compose([
transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=0),
transforms.RandomHorizontalFlip(),
])
# Final transform without randomness
self.final_im_transform = transforms.Compose([
transforms.ToTensor(),
im_normalization,
])
self.final_gt_transform = transforms.Compose([
transforms.ToTensor(),
])
def _get_sample(self, idx):
im = Image.open(self.im_list[idx]).convert('RGB')
gt = Image.open(self.im_list[idx][:-3]+'png').convert('L')
sequence_seed = np.random.randint(2147483647)
images = []
masks = []
for _ in range(self.num_frames):
reseed(sequence_seed)
this_im = self.all_im_dual_transform(im)
this_im = self.all_im_lone_transform(this_im)
reseed(sequence_seed)
this_gt = self.all_gt_dual_transform(gt)
pairwise_seed = np.random.randint(2147483647)
reseed(pairwise_seed)
this_im = self.pair_im_dual_transform(this_im)
this_im = self.pair_im_lone_transform(this_im)
reseed(pairwise_seed)
this_gt = self.pair_gt_dual_transform(this_gt)
# Use TPS only some of the times
# Not because TPS is bad -- just that it is too slow and I need to speed up data loading
if np.random.rand() < 0.33:
this_im, this_gt = random_tps_warp(this_im, this_gt, scale=0.02)
this_im = self.final_im_transform(this_im)
this_gt = self.final_gt_transform(this_gt)
images.append(this_im)
masks.append(this_gt)
images = torch.stack(images, 0)
masks = torch.stack(masks, 0)
return images, masks.numpy()
def __getitem__(self, idx):
additional_objects = np.random.randint(self.max_num_obj)
indices = [idx, *np.random.randint(self.__len__(), size=additional_objects)]
merged_images = None
merged_masks = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
for i, list_id in enumerate(indices):
images, masks = self._get_sample(list_id)
if merged_images is None:
merged_images = images
else:
merged_images = merged_images*(1-masks) + images*masks
merged_masks[masks[:,0]>0.5] = (i+1)
masks = merged_masks
labels = np.unique(masks[0])
# Remove background
labels = labels[labels!=0]
target_objects = labels.tolist()
# Generate one-hot ground-truth
cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
for i, l in enumerate(target_objects):
this_mask = (masks==l)
cls_gt[this_mask] = i+1
first_frame_gt[0,i] = (this_mask[0])
cls_gt = np.expand_dims(cls_gt, 1)
info = {}
info['name'] = self.im_list[idx]
info['num_objects'] = max(1, len(target_objects))
# 1 if object exist, 0 otherwise
selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
selector = torch.FloatTensor(selector)
data = {
'rgb': merged_images,
'first_frame_gt': first_frame_gt,
'cls_gt': cls_gt,
'selector': selector,
'info': info
}
return data
def __len__(self):
return len(self.im_list)

View File

@@ -1,37 +0,0 @@
import numpy as np
from PIL import Image
import cv2
import thinplate as tps
cv2.setNumThreads(0)
def pick_random_points(h, w, n_samples):
y_idx = np.random.choice(np.arange(h), size=n_samples, replace=False)
x_idx = np.random.choice(np.arange(w), size=n_samples, replace=False)
return y_idx/h, x_idx/w
def warp_dual_cv(img, mask, c_src, c_dst):
dshape = img.shape
theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True)
grid = tps.tps_grid(theta, c_dst, dshape)
mapx, mapy = tps.tps_grid_to_remap(grid, img.shape)
return cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR), cv2.remap(mask, mapx, mapy, cv2.INTER_NEAREST)
def random_tps_warp(img, mask, scale, n_ctrl_pts=12):
"""
Apply a random TPS warp of the input image and mask
Uses randomness from numpy
"""
img = np.asarray(img)
mask = np.asarray(mask)
h, w = mask.shape
points = pick_random_points(h, w, n_ctrl_pts)
c_src = np.stack(points, 1)
c_dst = c_src + np.random.normal(scale=scale, size=c_src.shape)
warp_im, warp_gt = warp_dual_cv(img, mask, c_src, c_dst)
return Image.fromarray(warp_im), Image.fromarray(warp_gt)

View File

@@ -1,13 +0,0 @@
import numpy as np
def all_to_onehot(masks, labels):
if len(masks.shape) == 3:
Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
else:
Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
for ni, l in enumerate(labels):
Ms[ni] = (masks == l).astype(np.uint8)
return Ms

View File

@@ -1,216 +0,0 @@
import os
from os import path, replace
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from PIL import Image
import numpy as np
from dataset.range_transform import im_normalization, im_mean
from dataset.reseed import reseed
class VOSDataset(Dataset):
"""
Works for DAVIS/YouTubeVOS/BL30K training
For each sequence:
- Pick three frames
- Pick two objects
- Apply some random transforms that are the same for all frames
- Apply random transform to each of the frame
- The distance between frames is controlled
"""
def __init__(self, im_root, gt_root, max_jump, is_bl, subset=None, num_frames=3, max_num_obj=3, finetune=False):
self.im_root = im_root
self.gt_root = gt_root
self.max_jump = max_jump
self.is_bl = is_bl
self.num_frames = num_frames
self.max_num_obj = max_num_obj
self.videos = []
self.frames = {}
vid_list = sorted(os.listdir(self.im_root))
# Pre-filtering
for vid in vid_list:
if subset is not None:
if vid not in subset:
continue
frames = sorted(os.listdir(os.path.join(self.im_root, vid)))
if len(frames) < num_frames:
continue
self.frames[vid] = frames
self.videos.append(vid)
print('%d out of %d videos accepted in %s.' % (len(self.videos), len(vid_list), im_root))
# These set of transform is the same for im/gt pairs, but different among the 3 sampled frames
self.pair_im_lone_transform = transforms.Compose([
transforms.ColorJitter(0.01, 0.01, 0.01, 0),
])
self.pair_im_dual_transform = transforms.Compose([
transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.BILINEAR, fill=im_mean),
])
self.pair_gt_dual_transform = transforms.Compose([
transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.NEAREST, fill=0),
])
# These transform are the same for all pairs in the sampled sequence
self.all_im_lone_transform = transforms.Compose([
transforms.ColorJitter(0.1, 0.03, 0.03, 0),
transforms.RandomGrayscale(0.05),
])
if self.is_bl:
# Use a different cropping scheme for the blender dataset because the image size is different
self.all_im_dual_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.BILINEAR)
])
self.all_gt_dual_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.NEAREST)
])
else:
self.all_im_dual_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.BILINEAR)
])
self.all_gt_dual_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.NEAREST)
])
# Final transform without randomness
self.final_im_transform = transforms.Compose([
transforms.ToTensor(),
im_normalization,
])
def __getitem__(self, idx):
video = self.videos[idx]
info = {}
info['name'] = video
vid_im_path = path.join(self.im_root, video)
vid_gt_path = path.join(self.gt_root, video)
frames = self.frames[video]
trials = 0
while trials < 5:
info['frames'] = [] # Appended with actual frames
num_frames = self.num_frames
length = len(frames)
this_max_jump = min(len(frames), self.max_jump)
# iterative sampling
frames_idx = [np.random.randint(length)]
acceptable_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1))).difference(set(frames_idx))
while(len(frames_idx) < num_frames):
idx = np.random.choice(list(acceptable_set))
frames_idx.append(idx)
new_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1)))
acceptable_set = acceptable_set.union(new_set).difference(set(frames_idx))
frames_idx = sorted(frames_idx)
if np.random.rand() < 0.5:
# Reverse time
frames_idx = frames_idx[::-1]
sequence_seed = np.random.randint(2147483647)
images = []
masks = []
target_objects = []
for f_idx in frames_idx:
jpg_name = frames[f_idx][:-4] + '.jpg'
png_name = frames[f_idx][:-4] + '.png'
info['frames'].append(jpg_name)
reseed(sequence_seed)
this_im = Image.open(path.join(vid_im_path, jpg_name)).convert('RGB')
this_im = self.all_im_dual_transform(this_im)
this_im = self.all_im_lone_transform(this_im)
reseed(sequence_seed)
this_gt = Image.open(path.join(vid_gt_path, png_name)).convert('P')
this_gt = self.all_gt_dual_transform(this_gt)
pairwise_seed = np.random.randint(2147483647)
reseed(pairwise_seed)
this_im = self.pair_im_dual_transform(this_im)
this_im = self.pair_im_lone_transform(this_im)
reseed(pairwise_seed)
this_gt = self.pair_gt_dual_transform(this_gt)
this_im = self.final_im_transform(this_im)
this_gt = np.array(this_gt)
images.append(this_im)
masks.append(this_gt)
images = torch.stack(images, 0)
labels = np.unique(masks[0])
# Remove background
labels = labels[labels!=0]
if self.is_bl:
# Find large enough labels
good_lables = []
for l in labels:
pixel_sum = (masks[0]==l).sum()
if pixel_sum > 10*10:
# OK if the object is always this small
# Not OK if it is actually much bigger
if pixel_sum > 30*30:
good_lables.append(l)
elif max((masks[1]==l).sum(), (masks[2]==l).sum()) < 20*20:
good_lables.append(l)
labels = np.array(good_lables, dtype=np.uint8)
if len(labels) == 0:
target_objects = []
trials += 1
else:
target_objects = labels.tolist()
break
if len(target_objects) > self.max_num_obj:
target_objects = np.random.choice(target_objects, size=self.max_num_obj, replace=False)
info['num_objects'] = max(1, len(target_objects))
masks = np.stack(masks, 0)
# Generate one-hot ground-truth
cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
for i, l in enumerate(target_objects):
this_mask = (masks==l)
cls_gt[this_mask] = i+1
first_frame_gt[0,i] = (this_mask[0])
cls_gt = np.expand_dims(cls_gt, 1)
# 1 if object exist, 0 otherwise
selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
selector = torch.FloatTensor(selector)
data = {
'rgb': images,
'first_frame_gt': first_frame_gt,
'cls_gt': cls_gt,
'selector': selector,
'info': info,
}
return data
def __len__(self):
return len(self.videos)

View File

@@ -1,257 +0,0 @@
import os
from os import path
from argparse import ArgumentParser
import shutil
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image
from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset
from inference.data.mask_mapper import MaskMapper
from model.network import XMem
from inference.inference_core import InferenceCore
from progressbar import progressbar
try:
import hickle as hkl
except ImportError:
print('Failed to import hickle. Fine if not using multi-scale testing.')
"""
Arguments loading
"""
parser = ArgumentParser()
parser.add_argument('--model', default='/ssd1/gaomingqi/checkpoints/XMem-s012.pth')
# Data options
parser.add_argument('--d16_path', default='../DAVIS/2016')
parser.add_argument('--d17_path', default='../DAVIS/2017')
parser.add_argument('--y18_path', default='/ssd1/gaomingqi/datasets/youtube-vos/2018')
parser.add_argument('--y19_path', default='../YouTube')
parser.add_argument('--lv_path', default='../long_video_set')
# For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations"
parser.add_argument('--generic_path')
parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='D17')
parser.add_argument('--split', help='val/test', default='val')
parser.add_argument('--output', default=None)
parser.add_argument('--save_all', action='store_true',
help='Save all frames. Useful only in YouTubeVOS/long-time video', )
parser.add_argument('--benchmark', action='store_true', help='enable to disable amp for FPS benchmarking')
# Long-term memory options
parser.add_argument('--disable_long_term', action='store_true')
parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
type=int, default=10000)
parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)
parser.add_argument('--top_k', type=int, default=30)
parser.add_argument('--mem_every', help='r in paper. Increase to improve running speed.', type=int, default=5)
parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
# Multi-scale options
parser.add_argument('--save_scores', action='store_true')
parser.add_argument('--flip', action='store_true')
parser.add_argument('--size', default=480, type=int,
help='Resize the shorter side to this size. -1 to use original resolution. ')
args = parser.parse_args()
config = vars(args)
config['enable_long_term'] = not config['disable_long_term']
if args.output is None:
args.output = f'../output/{args.dataset}_{args.split}'
print(f'Output path not provided. Defaulting to {args.output}')
"""
Data preparation
"""
is_youtube = args.dataset.startswith('Y')
is_davis = args.dataset.startswith('D')
is_lv = args.dataset.startswith('LV')
if is_youtube or args.save_scores:
out_path = path.join(args.output, 'Annotations')
else:
out_path = args.output
if is_youtube:
if args.dataset == 'Y18':
yv_path = args.y18_path
elif args.dataset == 'Y19':
yv_path = args.y19_path
if args.split == 'val':
args.split = 'valid'
meta_dataset = YouTubeVOSTestDataset(data_root=yv_path, split='valid', size=args.size)
elif args.split == 'test':
meta_dataset = YouTubeVOSTestDataset(data_root=yv_path, split='test', size=args.size)
else:
raise NotImplementedError
elif is_davis:
if args.dataset == 'D16':
if args.split == 'val':
# Set up Dataset, a small hack to use the image set in the 2017 folder because the 2016 one is of a different format
meta_dataset = DAVISTestDataset(args.d16_path, imset='../../2017/trainval/ImageSets/2016/val.txt', size=args.size)
else:
raise NotImplementedError
palette = None
elif args.dataset == 'D17':
if args.split == 'val':
meta_dataset = DAVISTestDataset(path.join(args.d17_path, 'trainval'), imset='2017/val.txt', size=args.size)
elif args.split == 'test':
meta_dataset = DAVISTestDataset(path.join(args.d17_path, 'test-dev'), imset='2017/test-dev.txt', size=args.size)
else:
raise NotImplementedError
elif is_lv:
if args.dataset == 'LV1':
meta_dataset = LongTestDataset(path.join(args.lv_path, 'long_video'))
elif args.dataset == 'LV3':
meta_dataset = LongTestDataset(path.join(args.lv_path, 'long_video_x3'))
else:
raise NotImplementedError
elif args.dataset == 'G':
meta_dataset = LongTestDataset(path.join(args.generic_path), size=args.size)
if not args.save_all:
args.save_all = True
print('save_all is forced to be true in generic evaluation mode.')
else:
raise NotImplementedError
torch.autograd.set_grad_enabled(False)
# Set up loader
meta_loader = meta_dataset.get_datasets()
# Load our checkpoint
network = XMem(config, args.model).cuda().eval()
if args.model is not None:
model_weights = torch.load(args.model)
network.load_weights(model_weights, init_as_zero_if_needed=True)
else:
print('No model loaded.')
total_process_time = 0
total_frames = 0
# Start eval
for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2)
vid_name = vid_reader.vid_name
vid_length = len(loader)
# no need to count usage for LT if the video is not that long anyway
config['enable_long_term_count_usage'] = (
config['enable_long_term'] and
(vid_length
/ (config['max_mid_term_frames']-config['min_mid_term_frames'])
* config['num_prototypes'])
>= config['max_long_term_elements']
)
mapper = MaskMapper()
processor = InferenceCore(network, config=config)
first_mask_loaded = False
for ti, data in enumerate(loader):
with torch.cuda.amp.autocast(enabled=not args.benchmark):
rgb = data['rgb'].cuda()[0]
msk = data.get('mask')
info = data['info']
frame = info['frame'][0]
shape = info['shape']
need_resize = info['need_resize'][0]
"""
For timing see https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964
Seems to be very similar in testing as my previous timing method
with two cuda sync + time.time() in STCN though
"""
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
if not first_mask_loaded:
if msk is not None:
first_mask_loaded = True
else:
# no point to do anything without a mask
continue
if args.flip:
rgb = torch.flip(rgb, dims=[-1])
msk = torch.flip(msk, dims=[-1]) if msk is not None else None
# Map possibly non-continuous labels to continuous ones
if msk is not None:
msk, labels = mapper.convert_mask(msk[0].numpy())
msk = torch.Tensor(msk).cuda()
if need_resize:
msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
processor.set_all_labels(list(mapper.remappings.values()))
else:
labels = None
# Run the model on this frame
prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1)) # 0, background, >0, objects
# Upsample to original size if needed
if need_resize:
prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0]
end.record()
torch.cuda.synchronize()
total_process_time += (start.elapsed_time(end)/1000)
total_frames += 1
if args.flip:
prob = torch.flip(prob, dims=[-1])
# Probability mask -> index mask
out_mask = torch.argmax(prob, dim=0)
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
if args.save_scores:
prob = (prob.detach().cpu().numpy()*255).astype(np.uint8)
# Save the mask
if args.save_all or info['save'][0]:
this_out_path = path.join(out_path, vid_name)
os.makedirs(this_out_path, exist_ok=True)
out_mask = mapper.remap_index_mask(out_mask)
out_img = Image.fromarray(out_mask)
if vid_reader.get_palette() is not None:
out_img.putpalette(vid_reader.get_palette())
out_img.save(os.path.join(this_out_path, frame[:-4]+'.png'))
if args.save_scores:
np_path = path.join(args.output, 'Scores', vid_name)
os.makedirs(np_path, exist_ok=True)
if ti==len(loader)-1:
hkl.dump(mapper.remappings, path.join(np_path, f'backward.hkl'), mode='w')
if args.save_all or info['save'][0]:
hkl.dump(prob, path.join(np_path, f'{frame[:-4]}.hkl'), mode='w', compression='lzf')
print(f'Total processing time: {total_process_time}')
print(f'Total processed frames: {total_frames}')
print(f'FPS: {total_frames / total_process_time}')
print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')
if not args.save_scores:
if is_youtube:
print('Making zip for YouTubeVOS...')
shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
elif is_davis and args.split == 'test':
print('Making zip for DAVIS test-dev...')
shutil.make_archive(args.output, 'zip', args.output)

View File

@@ -1,96 +0,0 @@
import os
from os import path
import json
from inference.data.video_reader import VideoReader
class LongTestDataset:
def __init__(self, data_root, size=-1):
self.image_dir = path.join(data_root, 'JPEGImages')
self.mask_dir = path.join(data_root, 'Annotations')
self.size = size
self.vid_list = sorted(os.listdir(self.image_dir))
def get_datasets(self):
for video in self.vid_list:
yield VideoReader(video,
path.join(self.image_dir, video),
path.join(self.mask_dir, video),
to_save = [
name[:-4] for name in os.listdir(path.join(self.mask_dir, video))
],
size=self.size,
)
def __len__(self):
return len(self.vid_list)
class DAVISTestDataset:
def __init__(self, data_root, imset='2017/val.txt', size=-1):
if size != 480:
self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution')
self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution')
if not path.exists(self.image_dir):
print(f'{self.image_dir} not found. Look at other options.')
self.image_dir = path.join(data_root, 'JPEGImages', '1080p')
self.mask_dir = path.join(data_root, 'Annotations', '1080p')
assert path.exists(self.image_dir), 'path not found'
else:
self.image_dir = path.join(data_root, 'JPEGImages', '480p')
self.mask_dir = path.join(data_root, 'Annotations', '480p')
self.size_dir = path.join(data_root, 'JPEGImages', '480p')
self.size = size
with open(path.join(data_root, 'ImageSets', imset)) as f:
self.vid_list = sorted([line.strip() for line in f])
def get_datasets(self):
for video in self.vid_list:
yield VideoReader(video,
path.join(self.image_dir, video),
path.join(self.mask_dir, video),
size=self.size,
size_dir=path.join(self.size_dir, video),
)
def __len__(self):
return len(self.vid_list)
class YouTubeVOSTestDataset:
def __init__(self, data_root, split, size=480):
self.image_dir = path.join(data_root, 'all_frames', split+'_all_frames', 'JPEGImages')
self.mask_dir = path.join(data_root, split, 'Annotations')
self.size = size
self.vid_list = sorted(os.listdir(self.image_dir))
self.req_frame_list = {}
with open(path.join(data_root, split, 'meta.json')) as f:
# read meta.json to know which frame is required for evaluation
meta = json.load(f)['videos']
for vid in self.vid_list:
req_frames = []
objects = meta[vid]['objects']
for value in objects.values():
req_frames.extend(value['frames'])
req_frames = list(set(req_frames))
self.req_frame_list[vid] = req_frames
def get_datasets(self):
for video in self.vid_list:
yield VideoReader(video,
path.join(self.image_dir, video),
path.join(self.mask_dir, video),
size=self.size,
to_save=self.req_frame_list[video],
use_all_mask=True
)
def __len__(self):
return len(self.vid_list)

View File

@@ -1,100 +0,0 @@
import os
from os import path
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import torch.nn.functional as F
from PIL import Image
import numpy as np
from dataset.range_transform import im_normalization
class VideoReader(Dataset):
"""
This class is used to read a video, one frame at a time
"""
def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None):
"""
image_dir - points to a directory of jpg images
mask_dir - points to a directory of png masks
size - resize min. side to size. Does nothing if <0.
to_save - optionally contains a list of file names without extensions
where the segmentation mask is required
use_all_mask - when true, read all available mask in mask_dir.
Default false. Set to true for YouTubeVOS validation.
"""
self.vid_name = vid_name
self.image_dir = image_dir
self.mask_dir = mask_dir
self.to_save = to_save
self.use_all_mask = use_all_mask
if size_dir is None:
self.size_dir = self.image_dir
else:
self.size_dir = size_dir
self.frames = sorted(os.listdir(self.image_dir))
self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette()
self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0])
if size < 0:
self.im_transform = transforms.Compose([
transforms.ToTensor(),
im_normalization,
])
else:
self.im_transform = transforms.Compose([
transforms.ToTensor(),
im_normalization,
transforms.Resize(size, interpolation=InterpolationMode.BILINEAR),
])
self.size = size
def __getitem__(self, idx):
frame = self.frames[idx]
info = {}
data = {}
info['frame'] = frame
info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)
im_path = path.join(self.image_dir, frame)
img = Image.open(im_path).convert('RGB')
if self.image_dir == self.size_dir:
shape = np.array(img).shape[:2]
else:
size_path = path.join(self.size_dir, frame)
size_im = Image.open(size_path).convert('RGB')
shape = np.array(size_im).shape[:2]
gt_path = path.join(self.mask_dir, frame[:-4]+'.png')
img = self.im_transform(img)
load_mask = self.use_all_mask or (gt_path == self.first_gt_path)
if load_mask and path.exists(gt_path):
mask = Image.open(gt_path).convert('P')
mask = np.array(mask, dtype=np.uint8)
data['mask'] = mask
info['shape'] = shape
info['need_resize'] = not (self.size < 0)
data['rgb'] = img
data['info'] = info
return data
def resize_mask(self, mask):
# mask transform is applied AFTER mapper, so we need to post-process it in eval.py
h, w = mask.shape[-2:]
min_hw = min(h, w)
return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
mode='nearest')
def get_palette(self):
return self.palette
def __len__(self):
return len(self.frames)

View File

@@ -63,8 +63,6 @@ class InferenceCore:
if need_segment:
memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout,
self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False)
# remove batch dim

View File

@@ -1,138 +0,0 @@
from argparse import ArgumentParser
def none_or_default(x, default):
return x if x is not None else default
class Configuration():
def parse(self, unknown_arg_ok=False):
parser = ArgumentParser()
# Enable torch.backends.cudnn.benchmark -- Faster in some cases, test in your own environment
parser.add_argument('--benchmark', action='store_true')
parser.add_argument('--no_amp', action='store_true')
# Save paths
parser.add_argument('--save_path', default='/ssd1/gaomingqi/output/xmem-sam')
# Data parameters
parser.add_argument('--static_root', help='Static training data root', default='/ssd1/gaomingqi/datasets/static')
parser.add_argument('--bl_root', help='Blender training data root', default='../BL30K')
parser.add_argument('--yv_root', help='YouTubeVOS data root', default='/ssd1/gaomingqi/datasets/youtube-vos/2018')
parser.add_argument('--davis_root', help='DAVIS data root', default='/ssd1/gaomingqi/datasets/davis')
parser.add_argument('--num_workers', help='Total number of dataloader workers across all GPUs processes', type=int, default=16)
parser.add_argument('--key_dim', default=64, type=int)
parser.add_argument('--value_dim', default=512, type=int)
parser.add_argument('--hidden_dim', default=64, help='Set to =0 to disable', type=int)
parser.add_argument('--deep_update_prob', default=0.2, type=float)
parser.add_argument('--stages', help='Training stage (0-static images, 1-Blender dataset, 2-DAVIS+YouTubeVOS)', default='02')
"""
Stage-specific learning parameters
Batch sizes are effective -- you don't have to scale them when you scale the number processes
"""
# Stage 0, static images
parser.add_argument('--s0_batch_size', default=16, type=int)
parser.add_argument('--s0_iterations', default=150000, type=int)
parser.add_argument('--s0_finetune', default=0, type=int)
parser.add_argument('--s0_steps', nargs="*", default=[], type=int)
parser.add_argument('--s0_lr', help='Initial learning rate', default=1e-5, type=float)
parser.add_argument('--s0_num_ref_frames', default=2, type=int)
parser.add_argument('--s0_num_frames', default=3, type=int)
parser.add_argument('--s0_start_warm', default=20000, type=int)
parser.add_argument('--s0_end_warm', default=70000, type=int)
# Stage 1, BL30K
parser.add_argument('--s1_batch_size', default=8, type=int)
parser.add_argument('--s1_iterations', default=250000, type=int)
# fine-tune means fewer augmentations to train the sensory memory
parser.add_argument('--s1_finetune', default=0, type=int)
parser.add_argument('--s1_steps', nargs="*", default=[200000], type=int)
parser.add_argument('--s1_lr', help='Initial learning rate', default=1e-5, type=float)
parser.add_argument('--s1_num_ref_frames', default=3, type=int)
parser.add_argument('--s1_num_frames', default=8, type=int)
parser.add_argument('--s1_start_warm', default=20000, type=int)
parser.add_argument('--s1_end_warm', default=70000, type=int)
# Stage 2, DAVIS+YoutubeVOS, longer
parser.add_argument('--s2_batch_size', default=8, type=int)
parser.add_argument('--s2_iterations', default=150000, type=int)
# fine-tune means fewer augmentations to train the sensory memory
parser.add_argument('--s2_finetune', default=10000, type=int)
parser.add_argument('--s2_steps', nargs="*", default=[120000], type=int)
parser.add_argument('--s2_lr', help='Initial learning rate', default=1e-5, type=float)
parser.add_argument('--s2_num_ref_frames', default=3, type=int)
parser.add_argument('--s2_num_frames', default=8, type=int)
parser.add_argument('--s2_start_warm', default=20000, type=int)
parser.add_argument('--s2_end_warm', default=70000, type=int)
# Stage 3, DAVIS+YoutubeVOS, shorter
parser.add_argument('--s3_batch_size', default=8, type=int)
parser.add_argument('--s3_iterations', default=100000, type=int)
# fine-tune means fewer augmentations to train the sensory memory
parser.add_argument('--s3_finetune', default=10000, type=int)
parser.add_argument('--s3_steps', nargs="*", default=[80000], type=int)
parser.add_argument('--s3_lr', help='Initial learning rate', default=1e-5, type=float)
parser.add_argument('--s3_num_ref_frames', default=3, type=int)
parser.add_argument('--s3_num_frames', default=8, type=int)
parser.add_argument('--s3_start_warm', default=20000, type=int)
parser.add_argument('--s3_end_warm', default=70000, type=int)
parser.add_argument('--gamma', help='LR := LR*gamma at every decay step', default=0.1, type=float)
parser.add_argument('--weight_decay', default=0.05, type=float)
# Loading
parser.add_argument('--load_network', help='Path to pretrained network weight only')
parser.add_argument('--load_checkpoint', help='Path to the checkpoint file, including network, optimizer and such')
# Logging information
parser.add_argument('--log_text_interval', default=100, type=int)
parser.add_argument('--log_image_interval', default=1000, type=int)
parser.add_argument('--save_network_interval', default=25000, type=int)
parser.add_argument('--save_checkpoint_interval', default=50000, type=int)
parser.add_argument('--exp_id', help='Experiment UNIQUE id, use NULL to disable logging to tensorboard', default='NULL')
parser.add_argument('--debug', help='Debug mode which logs information more often', action='store_true')
# # Multiprocessing parameters, not set by users
# parser.add_argument('--local_rank', default=0, type=int, help='Local rank of this process')
if unknown_arg_ok:
args, _ = parser.parse_known_args()
self.args = vars(args)
else:
self.args = vars(parser.parse_args())
self.args['amp'] = not self.args['no_amp']
# check if the stages are valid
stage_to_perform = list(self.args['stages'])
for s in stage_to_perform:
if s not in ['0', '1', '2', '3']:
raise NotImplementedError
def get_stage_parameters(self, stage):
parameters = {
'batch_size': self.args['s%s_batch_size'%stage],
'iterations': self.args['s%s_iterations'%stage],
'finetune': self.args['s%s_finetune'%stage],
'steps': self.args['s%s_steps'%stage],
'lr': self.args['s%s_lr'%stage],
'num_ref_frames': self.args['s%s_num_ref_frames'%stage],
'num_frames': self.args['s%s_num_frames'%stage],
'start_warm': self.args['s%s_start_warm'%stage],
'end_warm': self.args['s%s_end_warm'%stage],
}
return parameters
def __getitem__(self, key):
return self.args[key]
def __setitem__(self, key, value):
self.args[key] = value
def __str__(self):
return str(self.args)

View File

@@ -1,60 +0,0 @@
bear
bmx-bumps
boat
boxing-fisheye
breakdance-flare
bus
car-turn
cat-girl
classic-car
color-run
crossing
dance-jump
dancing
disc-jockey
dog-agility
dog-gooses
dogs-scale
drift-turn
drone
elephant
flamingo
hike
hockey
horsejump-low
kid-football
kite-walk
koala
lady-running
lindy-hop
longboard
lucia
mallard-fly
mallard-water
miami-surf
motocross-bumps
motorbike
night-race
paragliding
planes-water
rallye
rhino
rollerblade
schoolgirls
scooter-board
scooter-gray
sheep
skate-park
snowboard
soccerball
stroller
stunt
surf
swing
tennis
tractor-sand
train
tuk-tuk
upside-down
varanus-cage
walking

View File

@@ -1,136 +0,0 @@
import cv2
import numpy as np
import torch
from dataset.range_transform import inv_im_trans
from collections import defaultdict
def tensor_to_numpy(image):
image_np = (image.numpy() * 255).astype('uint8')
return image_np
def tensor_to_np_float(image):
image_np = image.numpy().astype('float32')
return image_np
def detach_to_cpu(x):
return x.detach().cpu()
def transpose_np(x):
return np.transpose(x, [1,2,0])
def tensor_to_gray_im(x):
x = detach_to_cpu(x)
x = tensor_to_numpy(x)
x = transpose_np(x)
return x
def tensor_to_im(x):
x = detach_to_cpu(x)
x = inv_im_trans(x).clamp(0, 1)
x = tensor_to_numpy(x)
x = transpose_np(x)
return x
# Predefined key <-> caption dict
key_captions = {
'im': 'Image',
'gt': 'GT',
}
"""
Return an image array with captions
keys in dictionary will be used as caption if not provided
values should contain lists of cv2 images
"""
def get_image_array(images, grid_shape, captions={}):
h, w = grid_shape
cate_counts = len(images)
rows_counts = len(next(iter(images.values())))
font = cv2.FONT_HERSHEY_SIMPLEX
output_image = np.zeros([w*cate_counts, h*(rows_counts+1), 3], dtype=np.uint8)
col_cnt = 0
for k, v in images.items():
# Default as key value itself
caption = captions.get(k, k)
# Handles new line character
dy = 40
for i, line in enumerate(caption.split('\n')):
cv2.putText(output_image, line, (10, col_cnt*w+100+i*dy),
font, 0.8, (255,255,255), 2, cv2.LINE_AA)
# Put images
for row_cnt, img in enumerate(v):
im_shape = img.shape
if len(im_shape) == 2:
img = img[..., np.newaxis]
img = (img * 255).astype('uint8')
output_image[(col_cnt+0)*w:(col_cnt+1)*w,
(row_cnt+1)*h:(row_cnt+2)*h, :] = img
col_cnt += 1
return output_image
def base_transform(im, size):
im = tensor_to_np_float(im)
if len(im.shape) == 3:
im = im.transpose((1, 2, 0))
else:
im = im[:, :, None]
# Resize
if im.shape[1] != size:
im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)
return im.clip(0, 1)
def im_transform(im, size):
return base_transform(inv_im_trans(detach_to_cpu(im)), size=size)
def mask_transform(mask, size):
return base_transform(detach_to_cpu(mask), size=size)
def out_transform(mask, size):
return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size)
def pool_pairs(images, size, num_objects):
req_images = defaultdict(list)
b, t = images['rgb'].shape[:2]
# limit the number of images saved
b = min(2, b)
# find max num objects
max_num_objects = max(num_objects[:b])
GT_suffix = ''
for bi in range(b):
GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4]
for bi in range(b):
for ti in range(t):
req_images['RGB'].append(im_transform(images['rgb'][bi,ti], size))
for oi in range(max_num_objects):
if ti == 0 or oi >= num_objects[bi]:
req_images['Mask_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
# req_images['Mask_X8_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
# req_images['Mask_X16_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
else:
req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi], size))
# req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][2], size))
# req_images['Mask_X8_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][1], size))
# req_images['Mask_X16_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi][0], size))
req_images['GT_%d_%s'%(oi, GT_suffix)].append(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size))
# print((images['cls_gt'][bi,ti,0]==(oi+1)).shape)
# print(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size).shape)
return get_image_array(req_images, size, key_captions)

View File

@@ -1,16 +0,0 @@
"""
load_subset.py - Presents a subset of data
DAVIS - only the training set
YouTubeVOS - I manually filtered some erroneous ones out but I haven't checked all
"""
def load_sub_davis(path='util/davis_subset.txt'):
with open(path, mode='r') as f:
subset = set(f.read().splitlines())
return subset
def load_sub_yv(path='util/yv_subset.txt'):
with open(path, mode='r') as f:
subset = set(f.read().splitlines())
return subset

View File

@@ -1,80 +0,0 @@
"""
Integrate numerical values for some iterations
Typically used for loss computation / logging to tensorboard
Call finalize and create a new Integrator when you want to display/log
"""
import torch
class Integrator:
def __init__(self, logger, distributed=True, local_rank=0, world_size=1):
self.values = {}
self.counts = {}
self.hooks = [] # List is used here to maintain insertion order
self.logger = logger
self.distributed = distributed
self.local_rank = local_rank
self.world_size = world_size
def add_tensor(self, key, tensor):
if key not in self.values:
self.counts[key] = 1
if type(tensor) == float or type(tensor) == int:
self.values[key] = tensor
else:
self.values[key] = tensor.mean().item()
else:
self.counts[key] += 1
if type(tensor) == float or type(tensor) == int:
self.values[key] += tensor
else:
self.values[key] += tensor.mean().item()
def add_dict(self, tensor_dict):
for k, v in tensor_dict.items():
self.add_tensor(k, v)
def add_hook(self, hook):
"""
Adds a custom hook, i.e. compute new metrics using values in the dict
The hook takes the dict as argument, and returns a (k, v) tuple
e.g. for computing IoU
"""
if type(hook) == list:
self.hooks.extend(hook)
else:
self.hooks.append(hook)
def reset_except_hooks(self):
self.values = {}
self.counts = {}
# Average and output the metrics
def finalize(self, prefix, it, f=None):
for hook in self.hooks:
k, v = hook(self.values)
self.add_tensor(k, v)
for k, v in self.values.items():
if k[:4] == 'hide':
continue
avg = v / self.counts[k]
if self.distributed:
# Inplace operation
avg = torch.tensor(avg).cuda()
torch.distributed.reduce(avg, dst=0)
if self.local_rank == 0:
avg = (avg/self.world_size).cpu().item()
self.logger.log_metrics(prefix, k, avg, it, f)
else:
# Simple does it
self.logger.log_metrics(prefix, k, avg, it, f)

View File

@@ -1,101 +0,0 @@
"""
Dumps things to tensorboard and console
"""
import os
import warnings
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
def tensor_to_numpy(image):
image_np = (image.numpy() * 255).astype('uint8')
return image_np
def detach_to_cpu(x):
return x.detach().cpu()
def fix_width_trunc(x):
return ('{:.9s}'.format('{:0.9f}'.format(x)))
class TensorboardLogger:
def __init__(self, short_id, id, git_info):
self.short_id = short_id
if self.short_id == 'NULL':
self.short_id = 'DEBUG'
if id is None:
self.no_log = True
warnings.warn('Logging has been disbaled.')
else:
self.no_log = False
self.inv_im_trans = transforms.Normalize(
mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
std=[1/0.229, 1/0.224, 1/0.225])
self.inv_seg_trans = transforms.Normalize(
mean=[-0.5/0.5],
std=[1/0.5])
log_path = os.path.join('.', 'logs', '%s' % id)
self.logger = SummaryWriter(log_path)
self.log_string('git', git_info)
def log_scalar(self, tag, x, step):
if self.no_log:
warnings.warn('Logging has been disabled.')
return
self.logger.add_scalar(tag, x, step)
def log_metrics(self, l1_tag, l2_tag, val, step, f=None):
tag = l1_tag + '/' + l2_tag
text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), l2_tag, fix_width_trunc(val))
print(text)
if f is not None:
f.write(text + '\n')
f.flush()
self.log_scalar(tag, val, step)
def log_im(self, tag, x, step):
if self.no_log:
warnings.warn('Logging has been disabled.')
return
x = detach_to_cpu(x)
x = self.inv_im_trans(x)
x = tensor_to_numpy(x)
self.logger.add_image(tag, x, step)
def log_cv2(self, tag, x, step):
if self.no_log:
warnings.warn('Logging has been disabled.')
return
x = x.transpose((2, 0, 1))
self.logger.add_image(tag, x, step)
def log_seg(self, tag, x, step):
if self.no_log:
warnings.warn('Logging has been disabled.')
return
x = detach_to_cpu(x)
x = self.inv_seg_trans(x)
x = tensor_to_numpy(x)
self.logger.add_image(tag, x, step)
def log_gray(self, tag, x, step):
if self.no_log:
warnings.warn('Logging has been disabled.')
return
x = detach_to_cpu(x)
x = tensor_to_numpy(x)
self.logger.add_image(tag, x, step)
def log_string(self, tag, x):
print(tag, x)
if self.no_log:
warnings.warn('Logging has been disabled.')
return
self.logger.add_text(tag, x)

View File

@@ -1,8 +1,16 @@
import numpy as np
import torch
from dataset.util import all_to_onehot
def all_to_onehot(masks, labels):
if len(masks.shape) == 3:
Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
else:
Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
for ni, l in enumerate(labels):
Ms[ni] = (masks == l).astype(np.uint8)
return Ms
class MaskMapper:
"""
@@ -23,6 +31,12 @@ class MaskMapper:
# if coherent, no mapping is required
self.coherent = True
def clear_labels(self):
self.labels = []
self.remappings = {}
# if coherent, no mapping is required
self.coherent = True
def convert_mask(self, mask, exhaustive=False):
# mask is in index representation, H*W numpy array
labels = np.unique(mask).astype(np.uint8)

View File

@@ -1,3 +0,0 @@
davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0'
youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f'

File diff suppressed because it is too large Load Diff