Mirror of https://github.com/gaomingqi/Track-Anything.git, synced 2025-12-16 08:27:49 +01:00

Commit: remove redundant code
@@ -7,14 +7,13 @@ from PIL import Image
 import torch
 import yaml
+import torch.nn.functional as F
 from model.network import XMem
 from inference.inference_core import InferenceCore
-from inference.data.mask_mapper import MaskMapper
+from util.mask_mapper import MaskMapper
 
-# for data transormation
 from torchvision import transforms
-from dataset.range_transform import im_normalization
+from util.range_transform import im_normalization
-import torch.nn.functional as F
 
 import sys
 sys.path.insert(0, sys.path[0]+"/../")
@@ -39,9 +38,11 @@ class BaseTracker:
             transforms.ToTensor(),
             im_normalization,
         ])
-        self.mapper = MaskMapper()
         self.device = device
 
+        self.mapper = MaskMapper()
+        self.initialised = False
 
     @torch.no_grad()
     def resize_mask(self, mask):
         # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
@@ -51,37 +52,42 @@ class BaseTracker:
             mode='nearest')
 
     @torch.no_grad()
-    def track(self, frames, first_frame_annotation):
+    def track(self, frame, first_frame_annotation=None):
         """
         Input:
-        frames: numpy arrays: T, H, W, 3 (T: number of frames)
-        first_frame_annotation: numpy array: H, W
+        frames: numpy arrays (H, W, 3)
+        first_frame_annotation: numpy array (H, W)
 
         Output:
-        masks: numpy arrays: H, W
+        mask: numpy arrays (H, W)
+        prob: numpy arrays, probability map (H, W)
+        painted_image: numpy array (H, W, 3)
         """
-        vid_length = len(frames)
-        masks = []
-
-        for ti, frame in enumerate(frames):
-            # convert to tensor
-            frame_tensor = self.im_transform(frame).to(self.device)
-            if ti == 0:
-                mask, labels = self.mapper.convert_mask(first_frame_annotation)
-                mask = torch.Tensor(mask).to(self.device)
-                self.tracker.set_all_labels(list(self.mapper.remappings.values()))
-            else:
-                mask = None
-                labels = None
-
-            # track one frame
-            prob = self.tracker.step(frame_tensor, mask, labels, end=(ti==vid_length-1))
-            # convert to mask
-            out_mask = torch.argmax(prob, dim=0)
-            out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
-            masks.append(out_mask)
-
-        return np.stack(masks, 0)
+        if first_frame_annotation is not None:
+            # initialisation
+            mask, labels = self.mapper.convert_mask(first_frame_annotation)
+            mask = torch.Tensor(mask).to(self.device)
+            self.tracker.set_all_labels(list(self.mapper.remappings.values()))
+        else:
+            mask = None
+            labels = None
+
+        # prepare inputs
+        frame_tensor = self.im_transform(frame).to(self.device)
+        # track one frame
+        prob = self.tracker.step(frame_tensor, mask, labels)
+        # convert to mask
+        out_mask = torch.argmax(prob, dim=0)
+        out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
+        painted_image = mask_painter(frame, out_mask)
+
+        # mask, _, painted_frame
+        return out_mask, prob, painted_image
+
+    @torch.no_grad()
+    def clear_memory(self):
+        self.tracker.clear_memory()
+        self.mapper.clear_labels()
 
 if __name__ == '__main__':
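The refactored track() above is now stateful and consumes one frame per call. The old batch behaviour is easy to recover by looping; a minimal sketch, assuming frames (T, H, W, 3) and first_frame_annotation (H, W) are already loaded as numpy arrays and tracker is a constructed BaseTracker:

import numpy as np

def track_clip(tracker, frames, first_frame_annotation):
    # Feed the annotation together with frame 0, then let the tracker's
    # internal XMem memory propagate it through the remaining frames.
    masks = []
    for ti, frame in enumerate(frames):
        if ti == 0:
            mask, prob, painted = tracker.track(frame, first_frame_annotation)
        else:
            mask, prob, painted = tracker.track(frame)
        masks.append(mask)
    return np.stack(masks, 0)    # (T, H, W), what the old track() returned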
@@ -106,11 +112,40 @@ if __name__ == '__main__':
     tracker = BaseTracker(device, XMEM_checkpoint)
 
     # track anything given in the first frame annotation
-    masks = tracker.track(frames, first_frame_annotation)
-    # save
-    for ti, (frame, mask) in enumerate(zip(frames, masks)):
-        painted_image = mask_painter(frame, mask)
+    for ti, frame in enumerate(frames):
+        if ti == 0:
+            mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
+        else:
+            mask, prob, painted_image = tracker.track(frame)
         # save
         painted_image = Image.fromarray(painted_image)
-        painted_image.save(f'/ssd1/gaomingqi/results/TrackA/{ti:05d}.png')
+        painted_image.save(f'/ssd1/gaomingqi/results/TrackA/dance-twirl/{ti:05d}.png')
 
+    # ----------------------------------------------------------
+    # another video
+    # ----------------------------------------------------------
+    # video frames
+    video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/camel', '*.jpg'))
+    video_path_list.sort()
+    # first frame
+    first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/camel/00000.png'
+    # load frames
+    frames = []
+    for video_path in video_path_list:
+        frames.append(np.array(Image.open(video_path).convert('RGB')))
+    frames = np.stack(frames, 0)    # N, H, W, C
+    # load first frame annotation
+    first_frame_annotation = np.array(Image.open(first_frame_path).convert('P'))    # H, W, C
+
+    print('first video done. clear.')
+
+    tracker.clear_memory()
+    # track anything given in the first frame annotation
+    for ti, frame in enumerate(frames):
+        if ti == 0:
+            mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
+        else:
+            mask, prob, painted_image = tracker.track(frame)
+        # save
+        painted_image = Image.fromarray(painted_image)
+        painted_image.save(f'/ssd1/gaomingqi/results/TrackA/camel/{ti:05d}.png')
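Note the coupling this introduces: since track() now keeps state between calls, the demo invokes tracker.clear_memory() before starting the second (camel) sequence, so XMem's memory bank and the label mapper do not carry entries over from the first video.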
@@ -1,6 +0,0 @@
-import torch
-import random
-
-def reseed(seed):
-    random.seed(seed)
-    torch.manual_seed(seed)
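The deleted reseed() is what kept an image and its mask geometrically aligned under random torchvision transforms: seeding both RNGs to the same value immediately before each of the two transform calls makes them draw identical random parameters. A minimal sketch of the pattern (the transform and image sizes are illustrative):

import random
import torch
from PIL import Image
from torchvision import transforms

def reseed(seed):
    random.seed(seed)
    torch.manual_seed(seed)

img = Image.new('RGB', (64, 64))   # stand-ins for a real frame and its mask
gt = Image.new('L', (64, 64))
aug = transforms.RandomAffine(degrees=20)

seed = 12345
reseed(seed)
img_aug = aug(img)   # affine parameters drawn from the seeded RNG
reseed(seed)
gt_aug = aug(gt)     # same seed, so exactly the same parameters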
@@ -1,179 +0,0 @@
-import os
-from os import path
-
-import torch
-from torch.utils.data.dataset import Dataset
-from torchvision import transforms
-from torchvision.transforms import InterpolationMode
-from PIL import Image
-import numpy as np
-
-from dataset.range_transform import im_normalization, im_mean
-from dataset.tps import random_tps_warp
-from dataset.reseed import reseed
-
-
-class StaticTransformDataset(Dataset):
-    """
-    Generate pseudo VOS data by applying random transforms on static images.
-    Single-object only.
-
-    Method 0 - FSS style (class/1.jpg class/1.png)
-    Method 1 - Others style (XXX.jpg XXX.png)
-    """
-    def __init__(self, parameters, num_frames=3, max_num_obj=1):
-        self.num_frames = num_frames
-        self.max_num_obj = max_num_obj
-
-        self.im_list = []
-        for parameter in parameters:
-            root, method, multiplier = parameter
-            if method == 0:
-                # Get images
-                classes = os.listdir(root)
-                for c in classes:
-                    imgs = os.listdir(path.join(root, c))
-                    jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()]
-
-                    joint_list = [path.join(root, c, im) for im in jpg_list]
-                    self.im_list.extend(joint_list * multiplier)
-
-            elif method == 1:
-                self.im_list.extend([path.join(root, im) for im in os.listdir(root) if '.jpg' in im] * multiplier)
-
-        print(f'{len(self.im_list)} images found.')
-
-        # These set of transform is the same for im/gt pairs, but different among the 3 sampled frames
-        self.pair_im_lone_transform = transforms.Compose([
-            transforms.ColorJitter(0.1, 0.05, 0.05, 0), # No hue change here as that's not realistic
-        ])
-
-        self.pair_im_dual_transform = transforms.Compose([
-            transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=im_mean),
-            transforms.Resize(384, InterpolationMode.BICUBIC),
-            transforms.RandomCrop((384, 384), pad_if_needed=True, fill=im_mean),
-        ])
-
-        self.pair_gt_dual_transform = transforms.Compose([
-            transforms.RandomAffine(degrees=20, scale=(0.9,1.1), shear=10, interpolation=InterpolationMode.BICUBIC, fill=0),
-            transforms.Resize(384, InterpolationMode.NEAREST),
-            transforms.RandomCrop((384, 384), pad_if_needed=True, fill=0),
-        ])
-
-
-        # These transform are the same for all pairs in the sampled sequence
-        self.all_im_lone_transform = transforms.Compose([
-            transforms.ColorJitter(0.1, 0.05, 0.05, 0.05),
-            transforms.RandomGrayscale(0.05),
-        ])
-
-        self.all_im_dual_transform = transforms.Compose([
-            transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=im_mean),
-            transforms.RandomHorizontalFlip(),
-        ])
-
-        self.all_gt_dual_transform = transforms.Compose([
-            transforms.RandomAffine(degrees=0, scale=(0.8, 1.5), fill=0),
-            transforms.RandomHorizontalFlip(),
-        ])
-
-        # Final transform without randomness
-        self.final_im_transform = transforms.Compose([
-            transforms.ToTensor(),
-            im_normalization,
-        ])
-
-        self.final_gt_transform = transforms.Compose([
-            transforms.ToTensor(),
-        ])
-
-    def _get_sample(self, idx):
-        im = Image.open(self.im_list[idx]).convert('RGB')
-        gt = Image.open(self.im_list[idx][:-3]+'png').convert('L')
-
-        sequence_seed = np.random.randint(2147483647)
-
-        images = []
-        masks = []
-        for _ in range(self.num_frames):
-            reseed(sequence_seed)
-            this_im = self.all_im_dual_transform(im)
-            this_im = self.all_im_lone_transform(this_im)
-            reseed(sequence_seed)
-            this_gt = self.all_gt_dual_transform(gt)
-
-            pairwise_seed = np.random.randint(2147483647)
-            reseed(pairwise_seed)
-            this_im = self.pair_im_dual_transform(this_im)
-            this_im = self.pair_im_lone_transform(this_im)
-            reseed(pairwise_seed)
-            this_gt = self.pair_gt_dual_transform(this_gt)
-
-            # Use TPS only some of the times
-            # Not because TPS is bad -- just that it is too slow and I need to speed up data loading
-            if np.random.rand() < 0.33:
-                this_im, this_gt = random_tps_warp(this_im, this_gt, scale=0.02)
-
-            this_im = self.final_im_transform(this_im)
-            this_gt = self.final_gt_transform(this_gt)
-
-            images.append(this_im)
-            masks.append(this_gt)
-
-        images = torch.stack(images, 0)
-        masks = torch.stack(masks, 0)
-
-        return images, masks.numpy()
-
-    def __getitem__(self, idx):
-        additional_objects = np.random.randint(self.max_num_obj)
-        indices = [idx, *np.random.randint(self.__len__(), size=additional_objects)]
-
-        merged_images = None
-        merged_masks = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
-
-        for i, list_id in enumerate(indices):
-            images, masks = self._get_sample(list_id)
-            if merged_images is None:
-                merged_images = images
-            else:
-                merged_images = merged_images*(1-masks) + images*masks
-            merged_masks[masks[:,0]>0.5] = (i+1)
-
-        masks = merged_masks
-
-        labels = np.unique(masks[0])
-        # Remove background
-        labels = labels[labels!=0]
-        target_objects = labels.tolist()
-
-        # Generate one-hot ground-truth
-        cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
-        first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
-        for i, l in enumerate(target_objects):
-            this_mask = (masks==l)
-            cls_gt[this_mask] = i+1
-            first_frame_gt[0,i] = (this_mask[0])
-        cls_gt = np.expand_dims(cls_gt, 1)
-
-        info = {}
-        info['name'] = self.im_list[idx]
-        info['num_objects'] = max(1, len(target_objects))
-
-        # 1 if object exist, 0 otherwise
-        selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
-        selector = torch.FloatTensor(selector)
-
-        data = {
-            'rgb': merged_images,
-            'first_frame_gt': first_frame_gt,
-            'cls_gt': cls_gt,
-            'selector': selector,
-            'info': info
-        }
-
-        return data
-
-    def __len__(self):
-        return len(self.im_list)
@@ -1,37 +0,0 @@
-import numpy as np
-from PIL import Image
-import cv2
-import thinplate as tps
-
-cv2.setNumThreads(0)
-
-def pick_random_points(h, w, n_samples):
-    y_idx = np.random.choice(np.arange(h), size=n_samples, replace=False)
-    x_idx = np.random.choice(np.arange(w), size=n_samples, replace=False)
-    return y_idx/h, x_idx/w
-
-
-def warp_dual_cv(img, mask, c_src, c_dst):
-    dshape = img.shape
-    theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True)
-    grid = tps.tps_grid(theta, c_dst, dshape)
-    mapx, mapy = tps.tps_grid_to_remap(grid, img.shape)
-    return cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR), cv2.remap(mask, mapx, mapy, cv2.INTER_NEAREST)
-
-
-def random_tps_warp(img, mask, scale, n_ctrl_pts=12):
-    """
-    Apply a random TPS warp of the input image and mask
-    Uses randomness from numpy
-    """
-    img = np.asarray(img)
-    mask = np.asarray(mask)
-
-    h, w = mask.shape
-    points = pick_random_points(h, w, n_ctrl_pts)
-    c_src = np.stack(points, 1)
-    c_dst = c_src + np.random.normal(scale=scale, size=c_src.shape)
-    warp_im, warp_gt = warp_dual_cv(img, mask, c_src, c_dst)
-
-    return Image.fromarray(warp_im), Image.fromarray(warp_gt)
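Usage of the deleted helper, for reference: random_tps_warp jitters n_ctrl_pts control points (in normalised coordinates) by a Gaussian of the given scale and warps image and mask with the same thin-plate-spline field, bilinear for the image and nearest-neighbour for the mask. A hedged sketch; the file paths are placeholders and the thinplate package (imported above as tps) must be installed:

from PIL import Image

img = Image.open('frame.jpg').convert('RGB')   # illustrative paths
gt = Image.open('frame.png').convert('L')
warp_img, warp_gt = random_tps_warp(img, gt, scale=0.02, n_ctrl_pts=12)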
@@ -1,13 +0,0 @@
-import numpy as np
-
-
-def all_to_onehot(masks, labels):
-    if len(masks.shape) == 3:
-        Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
-    else:
-        Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
-
-    for ni, l in enumerate(labels):
-        Ms[ni] = (masks == l).astype(np.uint8)
-
-    return Ms
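A toy call to the deleted all_to_onehot(), which stacks one binary channel per requested label:

import numpy as np

masks = np.array([[0, 1],
                  [2, 1]], dtype=np.uint8)       # tiny (H, W) index mask
labels = np.array([1, 2], dtype=np.uint8)
onehot = all_to_onehot(masks, labels)            # shape (2, 2, 2)
# onehot[0] == (masks == 1), onehot[1] == (masks == 2)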
@@ -1,216 +0,0 @@
-import os
-from os import path, replace
-
-import torch
-from torch.utils.data.dataset import Dataset
-from torchvision import transforms
-from torchvision.transforms import InterpolationMode
-from PIL import Image
-import numpy as np
-
-from dataset.range_transform import im_normalization, im_mean
-from dataset.reseed import reseed
-
-
-class VOSDataset(Dataset):
-    """
-    Works for DAVIS/YouTubeVOS/BL30K training
-    For each sequence:
-    - Pick three frames
-    - Pick two objects
-    - Apply some random transforms that are the same for all frames
-    - Apply random transform to each of the frame
-    - The distance between frames is controlled
-    """
-    def __init__(self, im_root, gt_root, max_jump, is_bl, subset=None, num_frames=3, max_num_obj=3, finetune=False):
-        self.im_root = im_root
-        self.gt_root = gt_root
-        self.max_jump = max_jump
-        self.is_bl = is_bl
-        self.num_frames = num_frames
-        self.max_num_obj = max_num_obj
-
-        self.videos = []
-        self.frames = {}
-
-        vid_list = sorted(os.listdir(self.im_root))
-        # Pre-filtering
-        for vid in vid_list:
-            if subset is not None:
-                if vid not in subset:
-                    continue
-            frames = sorted(os.listdir(os.path.join(self.im_root, vid)))
-            if len(frames) < num_frames:
-                continue
-            self.frames[vid] = frames
-            self.videos.append(vid)
-
-        print('%d out of %d videos accepted in %s.' % (len(self.videos), len(vid_list), im_root))
-
-        # These set of transform is the same for im/gt pairs, but different among the 3 sampled frames
-        self.pair_im_lone_transform = transforms.Compose([
-            transforms.ColorJitter(0.01, 0.01, 0.01, 0),
-        ])
-
-        self.pair_im_dual_transform = transforms.Compose([
-            transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.BILINEAR, fill=im_mean),
-        ])
-
-        self.pair_gt_dual_transform = transforms.Compose([
-            transforms.RandomAffine(degrees=0 if finetune or self.is_bl else 15, shear=0 if finetune or self.is_bl else 10, interpolation=InterpolationMode.NEAREST, fill=0),
-        ])
-
-        # These transform are the same for all pairs in the sampled sequence
-        self.all_im_lone_transform = transforms.Compose([
-            transforms.ColorJitter(0.1, 0.03, 0.03, 0),
-            transforms.RandomGrayscale(0.05),
-        ])
-
-        if self.is_bl:
-            # Use a different cropping scheme for the blender dataset because the image size is different
-            self.all_im_dual_transform = transforms.Compose([
-                transforms.RandomHorizontalFlip(),
-                transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.BILINEAR)
-            ])
-
-            self.all_gt_dual_transform = transforms.Compose([
-                transforms.RandomHorizontalFlip(),
-                transforms.RandomResizedCrop((384, 384), scale=(0.25, 1.00), interpolation=InterpolationMode.NEAREST)
-            ])
-        else:
-            self.all_im_dual_transform = transforms.Compose([
-                transforms.RandomHorizontalFlip(),
-                transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.BILINEAR)
-            ])
-
-            self.all_gt_dual_transform = transforms.Compose([
-                transforms.RandomHorizontalFlip(),
-                transforms.RandomResizedCrop((384, 384), scale=(0.36,1.00), interpolation=InterpolationMode.NEAREST)
-            ])
-
-        # Final transform without randomness
-        self.final_im_transform = transforms.Compose([
-            transforms.ToTensor(),
-            im_normalization,
-        ])
-
-    def __getitem__(self, idx):
-        video = self.videos[idx]
-        info = {}
-        info['name'] = video
-
-        vid_im_path = path.join(self.im_root, video)
-        vid_gt_path = path.join(self.gt_root, video)
-        frames = self.frames[video]
-
-        trials = 0
-        while trials < 5:
-            info['frames'] = [] # Appended with actual frames
-
-            num_frames = self.num_frames
-            length = len(frames)
-            this_max_jump = min(len(frames), self.max_jump)
-
-            # iterative sampling
-            frames_idx = [np.random.randint(length)]
-            acceptable_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1))).difference(set(frames_idx))
-            while(len(frames_idx) < num_frames):
-                idx = np.random.choice(list(acceptable_set))
-                frames_idx.append(idx)
-                new_set = set(range(max(0, frames_idx[-1]-this_max_jump), min(length, frames_idx[-1]+this_max_jump+1)))
-                acceptable_set = acceptable_set.union(new_set).difference(set(frames_idx))
-
-            frames_idx = sorted(frames_idx)
-            if np.random.rand() < 0.5:
-                # Reverse time
-                frames_idx = frames_idx[::-1]
-
-            sequence_seed = np.random.randint(2147483647)
-            images = []
-            masks = []
-            target_objects = []
-            for f_idx in frames_idx:
-                jpg_name = frames[f_idx][:-4] + '.jpg'
-                png_name = frames[f_idx][:-4] + '.png'
-                info['frames'].append(jpg_name)
-
-                reseed(sequence_seed)
-                this_im = Image.open(path.join(vid_im_path, jpg_name)).convert('RGB')
-                this_im = self.all_im_dual_transform(this_im)
-                this_im = self.all_im_lone_transform(this_im)
-                reseed(sequence_seed)
-                this_gt = Image.open(path.join(vid_gt_path, png_name)).convert('P')
-                this_gt = self.all_gt_dual_transform(this_gt)
-
-                pairwise_seed = np.random.randint(2147483647)
-                reseed(pairwise_seed)
-                this_im = self.pair_im_dual_transform(this_im)
-                this_im = self.pair_im_lone_transform(this_im)
-                reseed(pairwise_seed)
-                this_gt = self.pair_gt_dual_transform(this_gt)
-
-                this_im = self.final_im_transform(this_im)
-                this_gt = np.array(this_gt)
-
-                images.append(this_im)
-                masks.append(this_gt)
-
-            images = torch.stack(images, 0)
-
-            labels = np.unique(masks[0])
-            # Remove background
-            labels = labels[labels!=0]
-
-            if self.is_bl:
-                # Find large enough labels
-                good_lables = []
-                for l in labels:
-                    pixel_sum = (masks[0]==l).sum()
-                    if pixel_sum > 10*10:
-                        # OK if the object is always this small
-                        # Not OK if it is actually much bigger
-                        if pixel_sum > 30*30:
-                            good_lables.append(l)
-                        elif max((masks[1]==l).sum(), (masks[2]==l).sum()) < 20*20:
-                            good_lables.append(l)
-                labels = np.array(good_lables, dtype=np.uint8)
-
-            if len(labels) == 0:
-                target_objects = []
-                trials += 1
-            else:
-                target_objects = labels.tolist()
-                break
-
-        if len(target_objects) > self.max_num_obj:
-            target_objects = np.random.choice(target_objects, size=self.max_num_obj, replace=False)
-
-        info['num_objects'] = max(1, len(target_objects))
-
-        masks = np.stack(masks, 0)
-
-        # Generate one-hot ground-truth
-        cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64)
-        first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64)
-        for i, l in enumerate(target_objects):
-            this_mask = (masks==l)
-            cls_gt[this_mask] = i+1
-            first_frame_gt[0,i] = (this_mask[0])
-        cls_gt = np.expand_dims(cls_gt, 1)
-
-        # 1 if object exist, 0 otherwise
-        selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
-        selector = torch.FloatTensor(selector)
-
-        data = {
-            'rgb': images,
-            'first_frame_gt': first_frame_gt,
-            'cls_gt': cls_gt,
-            'selector': selector,
-            'info': info,
-        }
-
-        return data
-
-    def __len__(self):
-        return len(self.videos)
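The "iterative sampling" loop in __getitem__ above grows an acceptable set of indices that lie within max_jump of some already-chosen frame, which bounds the temporal distance between sampled frames. Isolated as a self-contained sketch (assumes length >= num_frames, as the dataset's pre-filtering guarantees):

import numpy as np

def sample_frames(length, num_frames=3, max_jump=20):
    this_max_jump = min(length, max_jump)
    frames_idx = [np.random.randint(length)]
    acceptable = set(range(max(0, frames_idx[-1] - this_max_jump),
                           min(length, frames_idx[-1] + this_max_jump + 1))) - set(frames_idx)
    while len(frames_idx) < num_frames:
        idx = int(np.random.choice(list(acceptable)))
        frames_idx.append(idx)
        new = set(range(max(0, idx - this_max_jump), min(length, idx + this_max_jump + 1)))
        acceptable = (acceptable | new) - set(frames_idx)
    return sorted(frames_idx)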
tracker/eval.py (257 lines deleted)
@@ -1,257 +0,0 @@
-import os
-from os import path
-from argparse import ArgumentParser
-import shutil
-
-import torch
-import torch.nn.functional as F
-from torch.utils.data import DataLoader
-import numpy as np
-from PIL import Image
-
-from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset
-from inference.data.mask_mapper import MaskMapper
-from model.network import XMem
-from inference.inference_core import InferenceCore
-
-from progressbar import progressbar
-
-try:
-    import hickle as hkl
-except ImportError:
-    print('Failed to import hickle. Fine if not using multi-scale testing.')
-
-
-"""
-Arguments loading
-"""
-parser = ArgumentParser()
-parser.add_argument('--model', default='/ssd1/gaomingqi/checkpoints/XMem-s012.pth')
-
-# Data options
-parser.add_argument('--d16_path', default='../DAVIS/2016')
-parser.add_argument('--d17_path', default='../DAVIS/2017')
-parser.add_argument('--y18_path', default='/ssd1/gaomingqi/datasets/youtube-vos/2018')
-parser.add_argument('--y19_path', default='../YouTube')
-parser.add_argument('--lv_path', default='../long_video_set')
-# For generic (G) evaluation, point to a folder that contains "JPEGImages" and "Annotations"
-parser.add_argument('--generic_path')
-
-parser.add_argument('--dataset', help='D16/D17/Y18/Y19/LV1/LV3/G', default='D17')
-parser.add_argument('--split', help='val/test', default='val')
-parser.add_argument('--output', default=None)
-parser.add_argument('--save_all', action='store_true',
-            help='Save all frames. Useful only in YouTubeVOS/long-time video', )
-
-parser.add_argument('--benchmark', action='store_true', help='enable to disable amp for FPS benchmarking')
-
-# Long-term memory options
-parser.add_argument('--disable_long_term', action='store_true')
-parser.add_argument('--max_mid_term_frames', help='T_max in paper, decrease to save memory', type=int, default=10)
-parser.add_argument('--min_mid_term_frames', help='T_min in paper, decrease to save memory', type=int, default=5)
-parser.add_argument('--max_long_term_elements', help='LT_max in paper, increase if objects disappear for a long time',
-                                                type=int, default=10000)
-parser.add_argument('--num_prototypes', help='P in paper', type=int, default=128)
-
-parser.add_argument('--top_k', type=int, default=30)
-parser.add_argument('--mem_every', help='r in paper. Increase to improve running speed.', type=int, default=5)
-parser.add_argument('--deep_update_every', help='Leave -1 normally to synchronize with mem_every', type=int, default=-1)
-
-# Multi-scale options
-parser.add_argument('--save_scores', action='store_true')
-parser.add_argument('--flip', action='store_true')
-parser.add_argument('--size', default=480, type=int,
-            help='Resize the shorter side to this size. -1 to use original resolution. ')
-
-args = parser.parse_args()
-config = vars(args)
-config['enable_long_term'] = not config['disable_long_term']
-
-if args.output is None:
-    args.output = f'../output/{args.dataset}_{args.split}'
-    print(f'Output path not provided. Defaulting to {args.output}')
-
-"""
-Data preparation
-"""
-is_youtube = args.dataset.startswith('Y')
-is_davis = args.dataset.startswith('D')
-is_lv = args.dataset.startswith('LV')
-
-if is_youtube or args.save_scores:
-    out_path = path.join(args.output, 'Annotations')
-else:
-    out_path = args.output
-
-if is_youtube:
-    if args.dataset == 'Y18':
-        yv_path = args.y18_path
-    elif args.dataset == 'Y19':
-        yv_path = args.y19_path
-
-    if args.split == 'val':
-        args.split = 'valid'
-        meta_dataset = YouTubeVOSTestDataset(data_root=yv_path, split='valid', size=args.size)
-    elif args.split == 'test':
-        meta_dataset = YouTubeVOSTestDataset(data_root=yv_path, split='test', size=args.size)
-    else:
-        raise NotImplementedError
-
-elif is_davis:
-    if args.dataset == 'D16':
-        if args.split == 'val':
-            # Set up Dataset, a small hack to use the image set in the 2017 folder because the 2016 one is of a different format
-            meta_dataset = DAVISTestDataset(args.d16_path, imset='../../2017/trainval/ImageSets/2016/val.txt', size=args.size)
-        else:
-            raise NotImplementedError
-        palette = None
-    elif args.dataset == 'D17':
-        if args.split == 'val':
-            meta_dataset = DAVISTestDataset(path.join(args.d17_path, 'trainval'), imset='2017/val.txt', size=args.size)
-        elif args.split == 'test':
-            meta_dataset = DAVISTestDataset(path.join(args.d17_path, 'test-dev'), imset='2017/test-dev.txt', size=args.size)
-        else:
-            raise NotImplementedError
-
-elif is_lv:
-    if args.dataset == 'LV1':
-        meta_dataset = LongTestDataset(path.join(args.lv_path, 'long_video'))
-    elif args.dataset == 'LV3':
-        meta_dataset = LongTestDataset(path.join(args.lv_path, 'long_video_x3'))
-    else:
-        raise NotImplementedError
-elif args.dataset == 'G':
-    meta_dataset = LongTestDataset(path.join(args.generic_path), size=args.size)
-    if not args.save_all:
-        args.save_all = True
-        print('save_all is forced to be true in generic evaluation mode.')
-else:
-    raise NotImplementedError
-
-torch.autograd.set_grad_enabled(False)
-
-# Set up loader
-meta_loader = meta_dataset.get_datasets()
-
-# Load our checkpoint
-network = XMem(config, args.model).cuda().eval()
-if args.model is not None:
-    model_weights = torch.load(args.model)
-    network.load_weights(model_weights, init_as_zero_if_needed=True)
-else:
-    print('No model loaded.')
-
-total_process_time = 0
-total_frames = 0
-
-# Start eval
-for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect_stdout=True):
-
-    loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2)
-    vid_name = vid_reader.vid_name
-    vid_length = len(loader)
-    # no need to count usage for LT if the video is not that long anyway
-    config['enable_long_term_count_usage'] = (
-        config['enable_long_term'] and
-        (vid_length
-            / (config['max_mid_term_frames']-config['min_mid_term_frames'])
-            * config['num_prototypes'])
-        >= config['max_long_term_elements']
-    )
-
-    mapper = MaskMapper()
-    processor = InferenceCore(network, config=config)
-    first_mask_loaded = False
-
-    for ti, data in enumerate(loader):
-        with torch.cuda.amp.autocast(enabled=not args.benchmark):
-            rgb = data['rgb'].cuda()[0]
-            msk = data.get('mask')
-            info = data['info']
-            frame = info['frame'][0]
-            shape = info['shape']
-            need_resize = info['need_resize'][0]
-
-            """
-            For timing see https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964
-            Seems to be very similar in testing as my previous timing method
-            with two cuda sync + time.time() in STCN though
-            """
-            start = torch.cuda.Event(enable_timing=True)
-            end = torch.cuda.Event(enable_timing=True)
-            start.record()
-
-            if not first_mask_loaded:
-                if msk is not None:
-                    first_mask_loaded = True
-                else:
-                    # no point to do anything without a mask
-                    continue
-
-            if args.flip:
-                rgb = torch.flip(rgb, dims=[-1])
-                msk = torch.flip(msk, dims=[-1]) if msk is not None else None
-
-            # Map possibly non-continuous labels to continuous ones
-            if msk is not None:
-                msk, labels = mapper.convert_mask(msk[0].numpy())
-                msk = torch.Tensor(msk).cuda()
-                if need_resize:
-                    msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
-                processor.set_all_labels(list(mapper.remappings.values()))
-            else:
-                labels = None
-
-            # Run the model on this frame
-            prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1)) # 0, background, >0, objects
-
-            # Upsample to original size if needed
-            if need_resize:
-                prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0]
-
-            end.record()
-            torch.cuda.synchronize()
-            total_process_time += (start.elapsed_time(end)/1000)
-            total_frames += 1
-
-            if args.flip:
-                prob = torch.flip(prob, dims=[-1])
-
-            # Probability mask -> index mask
-            out_mask = torch.argmax(prob, dim=0)
-            out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
-
-            if args.save_scores:
-                prob = (prob.detach().cpu().numpy()*255).astype(np.uint8)
-
-        # Save the mask
-        if args.save_all or info['save'][0]:
-            this_out_path = path.join(out_path, vid_name)
-            os.makedirs(this_out_path, exist_ok=True)
-            out_mask = mapper.remap_index_mask(out_mask)
-            out_img = Image.fromarray(out_mask)
-            if vid_reader.get_palette() is not None:
-                out_img.putpalette(vid_reader.get_palette())
-            out_img.save(os.path.join(this_out_path, frame[:-4]+'.png'))
-
-        if args.save_scores:
-            np_path = path.join(args.output, 'Scores', vid_name)
-            os.makedirs(np_path, exist_ok=True)
-            if ti==len(loader)-1:
-                hkl.dump(mapper.remappings, path.join(np_path, f'backward.hkl'), mode='w')
-            if args.save_all or info['save'][0]:
-                hkl.dump(prob, path.join(np_path, f'{frame[:-4]}.hkl'), mode='w', compression='lzf')
-
-
-print(f'Total processing time: {total_process_time}')
-print(f'Total processed frames: {total_frames}')
-print(f'FPS: {total_frames / total_process_time}')
-print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')
-
-if not args.save_scores:
-    if is_youtube:
-        print('Making zip for YouTubeVOS...')
-        shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 'Annotations')
-    elif is_davis and args.split == 'test':
-        print('Making zip for DAVIS test-dev...')
-        shutil.make_archive(args.output, 'zip', args.output)
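Judging from the argument parser above, the deleted script was run from the tracker directory roughly like this (the checkpoint and dataset paths are the author's defaults and would need adjusting):

python eval.py --model /ssd1/gaomingqi/checkpoints/XMem-s012.pth --dataset D17 --split val --output ../output/D17_val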
@@ -1,96 +0,0 @@
-import os
-from os import path
-import json
-
-from inference.data.video_reader import VideoReader
-
-
-class LongTestDataset:
-    def __init__(self, data_root, size=-1):
-        self.image_dir = path.join(data_root, 'JPEGImages')
-        self.mask_dir = path.join(data_root, 'Annotations')
-        self.size = size
-
-        self.vid_list = sorted(os.listdir(self.image_dir))
-
-    def get_datasets(self):
-        for video in self.vid_list:
-            yield VideoReader(video,
-                path.join(self.image_dir, video),
-                path.join(self.mask_dir, video),
-                to_save = [
-                    name[:-4] for name in os.listdir(path.join(self.mask_dir, video))
-                ],
-                size=self.size,
-            )
-
-    def __len__(self):
-        return len(self.vid_list)
-
-
-class DAVISTestDataset:
-    def __init__(self, data_root, imset='2017/val.txt', size=-1):
-        if size != 480:
-            self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution')
-            self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution')
-            if not path.exists(self.image_dir):
-                print(f'{self.image_dir} not found. Look at other options.')
-                self.image_dir = path.join(data_root, 'JPEGImages', '1080p')
-                self.mask_dir = path.join(data_root, 'Annotations', '1080p')
-            assert path.exists(self.image_dir), 'path not found'
-        else:
-            self.image_dir = path.join(data_root, 'JPEGImages', '480p')
-            self.mask_dir = path.join(data_root, 'Annotations', '480p')
-        self.size_dir = path.join(data_root, 'JPEGImages', '480p')
-        self.size = size
-
-        with open(path.join(data_root, 'ImageSets', imset)) as f:
-            self.vid_list = sorted([line.strip() for line in f])
-
-    def get_datasets(self):
-        for video in self.vid_list:
-            yield VideoReader(video,
-                path.join(self.image_dir, video),
-                path.join(self.mask_dir, video),
-                size=self.size,
-                size_dir=path.join(self.size_dir, video),
-            )
-
-    def __len__(self):
-        return len(self.vid_list)
-
-
-class YouTubeVOSTestDataset:
-    def __init__(self, data_root, split, size=480):
-        self.image_dir = path.join(data_root, 'all_frames', split+'_all_frames', 'JPEGImages')
-        self.mask_dir = path.join(data_root, split, 'Annotations')
-        self.size = size
-
-        self.vid_list = sorted(os.listdir(self.image_dir))
-        self.req_frame_list = {}
-
-        with open(path.join(data_root, split, 'meta.json')) as f:
-            # read meta.json to know which frame is required for evaluation
-            meta = json.load(f)['videos']
-
-            for vid in self.vid_list:
-                req_frames = []
-                objects = meta[vid]['objects']
-                for value in objects.values():
-                    req_frames.extend(value['frames'])
-
-                req_frames = list(set(req_frames))
-                self.req_frame_list[vid] = req_frames
-
-    def get_datasets(self):
-        for video in self.vid_list:
-            yield VideoReader(video,
-                path.join(self.image_dir, video),
-                path.join(self.mask_dir, video),
-                size=self.size,
-                to_save=self.req_frame_list[video],
-                use_all_mask=True
-            )
-
-    def __len__(self):
-        return len(self.vid_list)
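All three deleted wrappers expose the same small interface, a get_datasets() generator yielding one VideoReader per video, which eval.py consumed. An illustrative use, assuming a DAVIS 2017 layout on disk:

meta_dataset = DAVISTestDataset('../DAVIS/2017/trainval', imset='2017/val.txt', size=480)
for vid_reader in meta_dataset.get_datasets():
    print(vid_reader.vid_name, len(vid_reader))   # one reader per video, one frame per item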
@@ -1,100 +0,0 @@
-import os
-from os import path
-
-from torch.utils.data.dataset import Dataset
-from torchvision import transforms
-from torchvision.transforms import InterpolationMode
-import torch.nn.functional as F
-from PIL import Image
-import numpy as np
-
-from dataset.range_transform import im_normalization
-
-
-class VideoReader(Dataset):
-    """
-    This class is used to read a video, one frame at a time
-    """
-    def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None):
-        """
-        image_dir - points to a directory of jpg images
-        mask_dir - points to a directory of png masks
-        size - resize min. side to size. Does nothing if <0.
-        to_save - optionally contains a list of file names without extensions
-            where the segmentation mask is required
-        use_all_mask - when true, read all available mask in mask_dir.
-            Default false. Set to true for YouTubeVOS validation.
-        """
-        self.vid_name = vid_name
-        self.image_dir = image_dir
-        self.mask_dir = mask_dir
-        self.to_save = to_save
-        self.use_all_mask = use_all_mask
-        if size_dir is None:
-            self.size_dir = self.image_dir
-        else:
-            self.size_dir = size_dir
-
-        self.frames = sorted(os.listdir(self.image_dir))
-        self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette()
-        self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0])
-
-        if size < 0:
-            self.im_transform = transforms.Compose([
-                transforms.ToTensor(),
-                im_normalization,
-            ])
-        else:
-            self.im_transform = transforms.Compose([
-                transforms.ToTensor(),
-                im_normalization,
-                transforms.Resize(size, interpolation=InterpolationMode.BILINEAR),
-            ])
-        self.size = size
-
-
-    def __getitem__(self, idx):
-        frame = self.frames[idx]
-        info = {}
-        data = {}
-        info['frame'] = frame
-        info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)
-
-        im_path = path.join(self.image_dir, frame)
-        img = Image.open(im_path).convert('RGB')
-
-        if self.image_dir == self.size_dir:
-            shape = np.array(img).shape[:2]
-        else:
-            size_path = path.join(self.size_dir, frame)
-            size_im = Image.open(size_path).convert('RGB')
-            shape = np.array(size_im).shape[:2]
-
-        gt_path = path.join(self.mask_dir, frame[:-4]+'.png')
-        img = self.im_transform(img)
-
-        load_mask = self.use_all_mask or (gt_path == self.first_gt_path)
-        if load_mask and path.exists(gt_path):
-            mask = Image.open(gt_path).convert('P')
-            mask = np.array(mask, dtype=np.uint8)
-            data['mask'] = mask
-
-        info['shape'] = shape
-        info['need_resize'] = not (self.size < 0)
-        data['rgb'] = img
-        data['info'] = info
-
-        return data
-
-    def resize_mask(self, mask):
-        # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
-        h, w = mask.shape[-2:]
-        min_hw = min(h, w)
-        return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
-                    mode='nearest')
-
-    def get_palette(self):
-        return self.palette
-
-    def __len__(self):
-        return len(self.frames)
@@ -63,8 +63,6 @@ class InferenceCore:
         if need_segment:
             memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
-
-
             hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout,
                                 self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False)
             # remove batch dim
@@ -1,138 +0,0 @@
-from argparse import ArgumentParser
-
-
-def none_or_default(x, default):
-    return x if x is not None else default
-
-class Configuration():
-    def parse(self, unknown_arg_ok=False):
-        parser = ArgumentParser()
-
-        # Enable torch.backends.cudnn.benchmark -- Faster in some cases, test in your own environment
-        parser.add_argument('--benchmark', action='store_true')
-        parser.add_argument('--no_amp', action='store_true')
-
-        # Save paths
-        parser.add_argument('--save_path', default='/ssd1/gaomingqi/output/xmem-sam')
-
-        # Data parameters
-        parser.add_argument('--static_root', help='Static training data root', default='/ssd1/gaomingqi/datasets/static')
-        parser.add_argument('--bl_root', help='Blender training data root', default='../BL30K')
-        parser.add_argument('--yv_root', help='YouTubeVOS data root', default='/ssd1/gaomingqi/datasets/youtube-vos/2018')
-        parser.add_argument('--davis_root', help='DAVIS data root', default='/ssd1/gaomingqi/datasets/davis')
-        parser.add_argument('--num_workers', help='Total number of dataloader workers across all GPUs processes', type=int, default=16)
-
-        parser.add_argument('--key_dim', default=64, type=int)
-        parser.add_argument('--value_dim', default=512, type=int)
-        parser.add_argument('--hidden_dim', default=64, help='Set to =0 to disable', type=int)
-
-        parser.add_argument('--deep_update_prob', default=0.2, type=float)
-
-        parser.add_argument('--stages', help='Training stage (0-static images, 1-Blender dataset, 2-DAVIS+YouTubeVOS)', default='02')
-
-        """
-        Stage-specific learning parameters
-        Batch sizes are effective -- you don't have to scale them when you scale the number processes
-        """
-        # Stage 0, static images
-        parser.add_argument('--s0_batch_size', default=16, type=int)
-        parser.add_argument('--s0_iterations', default=150000, type=int)
-        parser.add_argument('--s0_finetune', default=0, type=int)
-        parser.add_argument('--s0_steps', nargs="*", default=[], type=int)
-        parser.add_argument('--s0_lr', help='Initial learning rate', default=1e-5, type=float)
-        parser.add_argument('--s0_num_ref_frames', default=2, type=int)
-        parser.add_argument('--s0_num_frames', default=3, type=int)
-        parser.add_argument('--s0_start_warm', default=20000, type=int)
-        parser.add_argument('--s0_end_warm', default=70000, type=int)
-
-        # Stage 1, BL30K
-        parser.add_argument('--s1_batch_size', default=8, type=int)
-        parser.add_argument('--s1_iterations', default=250000, type=int)
-        # fine-tune means fewer augmentations to train the sensory memory
-        parser.add_argument('--s1_finetune', default=0, type=int)
-        parser.add_argument('--s1_steps', nargs="*", default=[200000], type=int)
-        parser.add_argument('--s1_lr', help='Initial learning rate', default=1e-5, type=float)
-        parser.add_argument('--s1_num_ref_frames', default=3, type=int)
-        parser.add_argument('--s1_num_frames', default=8, type=int)
-        parser.add_argument('--s1_start_warm', default=20000, type=int)
-        parser.add_argument('--s1_end_warm', default=70000, type=int)
-
-        # Stage 2, DAVIS+YoutubeVOS, longer
-        parser.add_argument('--s2_batch_size', default=8, type=int)
-        parser.add_argument('--s2_iterations', default=150000, type=int)
-        # fine-tune means fewer augmentations to train the sensory memory
-        parser.add_argument('--s2_finetune', default=10000, type=int)
-        parser.add_argument('--s2_steps', nargs="*", default=[120000], type=int)
-        parser.add_argument('--s2_lr', help='Initial learning rate', default=1e-5, type=float)
-        parser.add_argument('--s2_num_ref_frames', default=3, type=int)
-        parser.add_argument('--s2_num_frames', default=8, type=int)
-        parser.add_argument('--s2_start_warm', default=20000, type=int)
-        parser.add_argument('--s2_end_warm', default=70000, type=int)
-
-        # Stage 3, DAVIS+YoutubeVOS, shorter
-        parser.add_argument('--s3_batch_size', default=8, type=int)
-        parser.add_argument('--s3_iterations', default=100000, type=int)
-        # fine-tune means fewer augmentations to train the sensory memory
-        parser.add_argument('--s3_finetune', default=10000, type=int)
-        parser.add_argument('--s3_steps', nargs="*", default=[80000], type=int)
-        parser.add_argument('--s3_lr', help='Initial learning rate', default=1e-5, type=float)
-        parser.add_argument('--s3_num_ref_frames', default=3, type=int)
-        parser.add_argument('--s3_num_frames', default=8, type=int)
-        parser.add_argument('--s3_start_warm', default=20000, type=int)
-        parser.add_argument('--s3_end_warm', default=70000, type=int)
-
-        parser.add_argument('--gamma', help='LR := LR*gamma at every decay step', default=0.1, type=float)
-        parser.add_argument('--weight_decay', default=0.05, type=float)
-
-        # Loading
-        parser.add_argument('--load_network', help='Path to pretrained network weight only')
-        parser.add_argument('--load_checkpoint', help='Path to the checkpoint file, including network, optimizer and such')
-
-        # Logging information
-        parser.add_argument('--log_text_interval', default=100, type=int)
-        parser.add_argument('--log_image_interval', default=1000, type=int)
-        parser.add_argument('--save_network_interval', default=25000, type=int)
-        parser.add_argument('--save_checkpoint_interval', default=50000, type=int)
-        parser.add_argument('--exp_id', help='Experiment UNIQUE id, use NULL to disable logging to tensorboard', default='NULL')
-        parser.add_argument('--debug', help='Debug mode which logs information more often', action='store_true')
-
-        # # Multiprocessing parameters, not set by users
-        # parser.add_argument('--local_rank', default=0, type=int, help='Local rank of this process')
-
-        if unknown_arg_ok:
-            args, _ = parser.parse_known_args()
-            self.args = vars(args)
-        else:
-            self.args = vars(parser.parse_args())
-
-        self.args['amp'] = not self.args['no_amp']
-
-        # check if the stages are valid
-        stage_to_perform = list(self.args['stages'])
-        for s in stage_to_perform:
-            if s not in ['0', '1', '2', '3']:
-                raise NotImplementedError
-
-    def get_stage_parameters(self, stage):
-        parameters = {
-            'batch_size': self.args['s%s_batch_size'%stage],
-            'iterations': self.args['s%s_iterations'%stage],
-            'finetune': self.args['s%s_finetune'%stage],
-            'steps': self.args['s%s_steps'%stage],
-            'lr': self.args['s%s_lr'%stage],
-            'num_ref_frames': self.args['s%s_num_ref_frames'%stage],
-            'num_frames': self.args['s%s_num_frames'%stage],
-            'start_warm': self.args['s%s_start_warm'%stage],
-            'end_warm': self.args['s%s_end_warm'%stage],
-        }
-
-        return parameters
-
-    def __getitem__(self, key):
-        return self.args[key]
-
-    def __setitem__(self, key, value):
-        self.args[key] = value
-
-    def __str__(self):
-        return str(self.args)
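The deleted Configuration object parsed CLI flags once and then behaved like a dict; a sketch of how a training entry point would have used it:

config = Configuration()
config.parse()                          # or config.parse(unknown_arg_ok=True)
print(config['stages'])                 # e.g. '02'
s0 = config.get_stage_parameters('0')   # {'batch_size': 16, 'iterations': 150000, ...}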
@@ -1,60 +0,0 @@
-bear
-bmx-bumps
-boat
-boxing-fisheye
-breakdance-flare
-bus
-car-turn
-cat-girl
-classic-car
-color-run
-crossing
-dance-jump
-dancing
-disc-jockey
-dog-agility
-dog-gooses
-dogs-scale
-drift-turn
-drone
-elephant
-flamingo
-hike
-hockey
-horsejump-low
-kid-football
-kite-walk
-koala
-lady-running
-lindy-hop
-longboard
-lucia
-mallard-fly
-mallard-water
-miami-surf
-motocross-bumps
-motorbike
-night-race
-paragliding
-planes-water
-rallye
-rhino
-rollerblade
-schoolgirls
-scooter-board
-scooter-gray
-sheep
-skate-park
-snowboard
-soccerball
-stroller
-stunt
-surf
-swing
-tennis
-tractor-sand
-train
-tuk-tuk
-upside-down
-varanus-cage
-walking
@@ -1,136 +0,0 @@
|
|||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from dataset.range_transform import inv_im_trans
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
def tensor_to_numpy(image):
|
|
||||||
image_np = (image.numpy() * 255).astype('uint8')
|
|
||||||
return image_np
|
|
||||||
|
|
||||||
def tensor_to_np_float(image):
|
|
||||||
image_np = image.numpy().astype('float32')
|
|
||||||
return image_np
|
|
||||||
|
|
||||||
def detach_to_cpu(x):
|
|
||||||
return x.detach().cpu()
|
|
||||||
|
|
||||||
def transpose_np(x):
|
|
||||||
return np.transpose(x, [1,2,0])
|
|
||||||
|
|
||||||
def tensor_to_gray_im(x):
|
|
||||||
x = detach_to_cpu(x)
|
|
||||||
x = tensor_to_numpy(x)
|
|
||||||
x = transpose_np(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def tensor_to_im(x):
|
|
||||||
x = detach_to_cpu(x)
|
|
||||||
x = inv_im_trans(x).clamp(0, 1)
|
|
||||||
x = tensor_to_numpy(x)
|
|
||||||
x = transpose_np(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
# Predefined key <-> caption dict
|
|
||||||
key_captions = {
|
|
||||||
'im': 'Image',
|
|
||||||
'gt': 'GT',
|
|
||||||
}
|
|
||||||
|
|
||||||
"""
|
|
||||||
Return an image array with captions
|
|
||||||
keys in dictionary will be used as caption if not provided
|
|
||||||
values should contain lists of cv2 images
|
|
||||||
"""
|
|
||||||
def get_image_array(images, grid_shape, captions={}):
|
|
||||||
h, w = grid_shape
|
|
||||||
cate_counts = len(images)
|
|
||||||
rows_counts = len(next(iter(images.values())))
|
|
||||||
|
|
||||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
|
||||||
|
|
||||||
output_image = np.zeros([w*cate_counts, h*(rows_counts+1), 3], dtype=np.uint8)
|
|
||||||
col_cnt = 0
|
|
||||||
for k, v in images.items():
|
|
||||||
|
|
||||||
# Default as key value itself
|
|
||||||
caption = captions.get(k, k)
|
|
||||||
|
|
||||||
# Handles new line character
|
|
||||||
dy = 40
|
|
||||||
for i, line in enumerate(caption.split('\n')):
|
|
||||||
cv2.putText(output_image, line, (10, col_cnt*w+100+i*dy),
|
|
||||||
font, 0.8, (255,255,255), 2, cv2.LINE_AA)
|
|
||||||
|
|
||||||
# Put images
|
|
||||||
for row_cnt, img in enumerate(v):
|
|
||||||
im_shape = img.shape
|
|
||||||
if len(im_shape) == 2:
|
|
||||||
img = img[..., np.newaxis]
|
|
||||||
|
|
||||||
img = (img * 255).astype('uint8')
|
|
||||||
|
|
||||||
output_image[(col_cnt+0)*w:(col_cnt+1)*w,
|
|
||||||
(row_cnt+1)*h:(row_cnt+2)*h, :] = img
|
|
||||||
|
|
||||||
col_cnt += 1
|
|
||||||
|
|
||||||
return output_image
def base_transform(im, size):
    im = tensor_to_np_float(im)
    if len(im.shape) == 3:
        im = im.transpose((1, 2, 0))
    else:
        im = im[:, :, None]

    # Resize
    if im.shape[1] != size:
        im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)

    return im.clip(0, 1)

def im_transform(im, size):
    return base_transform(inv_im_trans(detach_to_cpu(im)), size=size)

def mask_transform(mask, size):
    return base_transform(detach_to_cpu(mask), size=size)

def out_transform(mask, size):
    return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size)

def pool_pairs(images, size, num_objects):
    req_images = defaultdict(list)

    b, t = images['rgb'].shape[:2]

    # limit the number of images saved
    b = min(2, b)

    # find max num objects
    max_num_objects = max(num_objects[:b])

    GT_suffix = ''
    for bi in range(b):
        GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4]

    for bi in range(b):
        for ti in range(t):
            req_images['RGB'].append(im_transform(images['rgb'][bi,ti], size))
            for oi in range(max_num_objects):
                if ti == 0 or oi >= num_objects[bi]:
                    req_images['Mask_%d'%oi].append(mask_transform(images['first_frame_gt'][bi][0,oi], size))
                else:
                    req_images['Mask_%d'%oi].append(mask_transform(images['masks_%d'%ti][bi][oi], size))
                req_images['GT_%d_%s'%(oi, GT_suffix)].append(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size))

    return get_image_array(req_images, size, key_captions)
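A quick hedged sketch of the transform helpers above; the tensor shapes are illustrative, and im_transform assumes a normalized CHW tensor as produced by the dataset pipeline:

import torch

im = torch.rand(3, 480, 854)                     # stand-in for a normalized CHW image tensor
im_np = im_transform(im, (384, 384))             # de-normalized HWC float array, clipped to [0, 1]
mask_np = mask_transform(torch.rand(480, 854), (384, 384))  # 2D input: channel axis added, then dropped by cv2.resize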
@@ -1,16 +0,0 @@
"""
load_subset.py - Presents a subset of data
DAVIS - only the training set
YouTubeVOS - I manually filtered out some erroneous ones, but I haven't checked all of them
"""


def load_sub_davis(path='util/davis_subset.txt'):
    with open(path, mode='r') as f:
        subset = set(f.read().splitlines())
    return subset

def load_sub_yv(path='util/yv_subset.txt'):
    with open(path, mode='r') as f:
        subset = set(f.read().splitlines())
    return subset
@@ -1,80 +0,0 @@
"""
Integrate numerical values for some iterations
Typically used for loss computation / logging to tensorboard
Call finalize and create a new Integrator when you want to display/log
"""

import torch


class Integrator:
    def __init__(self, logger, distributed=True, local_rank=0, world_size=1):
        self.values = {}
        self.counts = {}
        self.hooks = []  # List is used here to maintain insertion order

        self.logger = logger

        self.distributed = distributed
        self.local_rank = local_rank
        self.world_size = world_size

    def add_tensor(self, key, tensor):
        if key not in self.values:
            self.counts[key] = 1
            if type(tensor) == float or type(tensor) == int:
                self.values[key] = tensor
            else:
                self.values[key] = tensor.mean().item()
        else:
            self.counts[key] += 1
            if type(tensor) == float or type(tensor) == int:
                self.values[key] += tensor
            else:
                self.values[key] += tensor.mean().item()

    def add_dict(self, tensor_dict):
        for k, v in tensor_dict.items():
            self.add_tensor(k, v)

    def add_hook(self, hook):
        """
        Adds a custom hook, i.e. computes new metrics using values in the dict
        The hook takes the dict as argument and returns a (k, v) tuple
        e.g. for computing IoU
        """
        if type(hook) == list:
            self.hooks.extend(hook)
        else:
            self.hooks.append(hook)

    def reset_except_hooks(self):
        self.values = {}
        self.counts = {}

    # Average and output the metrics
    def finalize(self, prefix, it, f=None):

        for hook in self.hooks:
            k, v = hook(self.values)
            self.add_tensor(k, v)

        for k, v in self.values.items():

            if k[:4] == 'hide':
                continue

            avg = v / self.counts[k]

            if self.distributed:
                # Inplace operation
                avg = torch.tensor(avg).cuda()
                torch.distributed.reduce(avg, dst=0)

                if self.local_rank == 0:
                    avg = (avg/self.world_size).cpu().item()
                    self.logger.log_metrics(prefix, k, avg, it, f)
            else:
                # Simple does it
                self.logger.log_metrics(prefix, k, avg, it, f)
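A hedged usage sketch of the Integrator in non-distributed mode; the metric key, the hook, and the logger object are illustrative assumptions:

# logger: any object exposing log_metrics(prefix, key, value, iteration, file)
integrator = Integrator(logger, distributed=False)
# hooks run inside finalize() on the accumulated dict and return a (key, value) pair
integrator.add_hook(lambda values: ('loss_x2', values['total_loss'] * 2))
for it in range(100):
    loss = torch.rand(4)                       # stand-in for a per-sample loss tensor
    integrator.add_dict({'total_loss': loss})  # stores the mean, bumps the count
integrator.finalize('train', 100)              # runs hooks, averages by count, logs each key
integrator.reset_except_hooks()                # clear values/counts for the next logging window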
@@ -1,101 +0,0 @@
"""
Dumps things to tensorboard and console
"""

import os
import warnings

import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter


def tensor_to_numpy(image):
    image_np = (image.numpy() * 255).astype('uint8')
    return image_np

def detach_to_cpu(x):
    return x.detach().cpu()

def fix_width_trunc(x):
    return ('{:.9s}'.format('{:0.9f}'.format(x)))

class TensorboardLogger:
    def __init__(self, short_id, id, git_info):
        self.short_id = short_id
        if self.short_id == 'NULL':
            self.short_id = 'DEBUG'

        if id is None:
            self.no_log = True
            warnings.warn('Logging has been disabled.')
        else:
            self.no_log = False

        self.inv_im_trans = transforms.Normalize(
            mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
            std=[1/0.229, 1/0.224, 1/0.225])

        self.inv_seg_trans = transforms.Normalize(
            mean=[-0.5/0.5],
            std=[1/0.5])

        log_path = os.path.join('.', 'logs', '%s' % id)
        self.logger = SummaryWriter(log_path)

        self.log_string('git', git_info)

    def log_scalar(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        self.logger.add_scalar(tag, x, step)

    def log_metrics(self, l1_tag, l2_tag, val, step, f=None):
        tag = l1_tag + '/' + l2_tag
        text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), l2_tag, fix_width_trunc(val))
        print(text)
        if f is not None:
            f.write(text + '\n')
            f.flush()
        self.log_scalar(tag, val, step)

    def log_im(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = detach_to_cpu(x)
        x = self.inv_im_trans(x)
        x = tensor_to_numpy(x)
        self.logger.add_image(tag, x, step)

    def log_cv2(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = x.transpose((2, 0, 1))
        self.logger.add_image(tag, x, step)

    def log_seg(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = detach_to_cpu(x)
        x = self.inv_seg_trans(x)
        x = tensor_to_numpy(x)
        self.logger.add_image(tag, x, step)

    def log_gray(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = detach_to_cpu(x)
        x = tensor_to_numpy(x)
        self.logger.add_image(tag, x, step)

    def log_string(self, tag, x):
        print(tag, x)
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        self.logger.add_text(tag, x)
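A brief hedged sketch of typical TensorboardLogger use; the experiment ids and git label are placeholders:

logger = TensorboardLogger('exp01', 'exp01_run1', 'commit deadbeef')  # writes to ./logs/exp01_run1
logger.log_scalar('train/loss', 0.42, 100)
logger.log_metrics('train', 'total_loss', 0.42, 100)  # prints to console, then forwards to log_scalar
logger.log_string('note', 'resumed from checkpoint')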
@@ -1,8 +1,16 @@
 import numpy as np
 import torch
 
-from dataset.util import all_to_onehot
+def all_to_onehot(masks, labels):
+    if len(masks.shape) == 3:
+        Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
+    else:
+        Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
+
+    for ni, l in enumerate(labels):
+        Ms[ni] = (masks == l).astype(np.uint8)
+
+    return Ms
 
 class MaskMapper:
     """
@@ -23,6 +31,12 @@ class MaskMapper:
         # if coherent, no mapping is required
         self.coherent = True
 
+    def clear_labels(self):
+        self.labels = []
+        self.remappings = {}
+        # if coherent, no mapping is required
+        self.coherent = True
+
     def convert_mask(self, mask, exhaustive=False):
         # mask is in index representation, H*W numpy array
         labels = np.unique(mask).astype(np.uint8)
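To illustrate the newly inlined all_to_onehot helper, a tiny hedged example with toy values:

import numpy as np

masks = np.array([[0, 1], [2, 1]], dtype=np.uint8)  # index mask, 0 is background
onehot = all_to_onehot(masks, [1, 2])               # shape (2, 2, 2): one binary uint8 channel per label
# onehot[0] -> [[0, 1], [0, 1]] (pixels equal to 1)
# onehot[1] -> [[0, 0], [1, 0]] (pixels equal to 2)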
@@ -1,3 +0,0 @@
davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0'
youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f'
File diff suppressed because it is too large