Merge branches 'master' and 'master' of github.com:gaomingqi/Track-Anything
select template frame
4
.gitignore
vendored
@@ -2,4 +2,6 @@ __pycache__/
|
||||
.vscode/
|
||||
docs/
|
||||
*.pth
|
||||
|
||||
*.mp4
|
||||
debug_images/
|
||||
*.png
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Track-Anything
|
||||
**Track-Anything** is an Efficient Development Toolkit for Video Object Tracking and Segmentation.
|
||||
|
||||

|
||||
|
||||
## Demo
|
||||
|
||||
This is a demo.
|
||||
|
||||
BIN
images/mask_painter.png
Normal file
|
After Width: | Height: | Size: 701 KiB |
BIN
images/painter_input_image.jpg
Normal file
|
After Width: | Height: | Size: 438 KiB |
BIN
images/painter_input_mask.jpg
Normal file
|
After Width: | Height: | Size: 23 KiB |
BIN
images/painter_output_image.png
Normal file
|
After Width: | Height: | Size: 704 KiB |
BIN
images/painter_output_image__.png
Normal file
|
After Width: | Height: | Size: 4.9 KiB |
BIN
images/point_painter.png
Normal file
|
After Width: | Height: | Size: 739 KiB |
BIN
images/point_painter_1.png
Normal file
|
After Width: | Height: | Size: 739 KiB |
BIN
images/point_painter_2.png
Normal file
|
After Width: | Height: | Size: 739 KiB |
BIN
overview.png
Normal file
|
After Width: | Height: | Size: 2.8 MiB |
149
tools/interact_tools.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import time
|
||||
import torch
|
||||
import cv2
|
||||
from PIL import Image, ImageDraw, ImageOps
|
||||
import numpy as np
|
||||
from typing import Union
|
||||
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
||||
import matplotlib.pyplot as plt
|
||||
import PIL
|
||||
from mask_painter import mask_painter as mask_painter2
|
||||
from base_segmenter import BaseSegmenter
|
||||
from painter import mask_painter, point_painter
|
||||
|
||||
|
||||
# Drawing parameters shared by the interactive helpers below. The *_color
# values are integer colour ids handed to the painter functions.
mask_color = 3
mask_alpha = 0.7
point_color_ne = 8
point_color_ps = 50
point_alpha = 0.9
point_radius = 15
# contour_color / contour_width used to be assigned twice in this list
# (contour_color = 1 then 2, contour_width = 5 then 5). Only the last
# assignment ever took effect, so the effective values are kept once here.
contour_color = 2
contour_width = 5
|
||||
|
||||
|
||||
def initialize(sam_checkpoint='/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth',
               model_type='vit_h',
               device='cuda:0'):
    '''
    Create the SAM controller used by the interactive helpers.

    All arguments are forwarded unchanged to ``BaseSegmenter``; the defaults
    reproduce the values that used to be hard-coded here, so ``initialize()``
    without arguments behaves as before.

    Args:
        sam_checkpoint: checkpoint path passed to ``BaseSegmenter``.
        model_type: model type string passed to ``BaseSegmenter``.
        device: device string passed to ``BaseSegmenter``.

    Returns:
        The constructed ``BaseSegmenter`` instance.
    '''
    return BaseSegmenter(sam_checkpoint, model_type, device)
|
||||
|
||||
|
||||
def seg_again(sam_controler, image: np.ndarray) -> None:
    '''
    Switch the controller to a new image (used when interacting in a video).

    Calls ``reset_image()`` on the controller and then ``set_image(image)``.
    Nothing is returned.
    '''
    sam_controler.reset_image()
    sam_controler.set_image(image)
|
||||
|
||||
|
||||
def first_frame_click(sam_controler, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True):
    '''
    Segment the first frame of a video from point clicks.

    Args:
        sam_controler: object providing ``set_image`` and ``predict``.
        image: frame passed to ``set_image`` and used as the painting canvas.
        points: click coordinates, one row per click.
        labels: one value per click; clicks with label > 0 and clicks with
            label < 1 are painted as two separate groups.
        multimask: forwarded to ``sam_controler.predict``.

    Returns:
        (mask, logit, painted_image): the mask and logit with the highest
        score, and a PIL image with the mask and the clicks drawn on it.

    Raises:
        ValueError: if ``points`` and ``labels`` differ in length.
    '''
    # Validate before touching the segmenter so that bad input fails fast
    # and is still rejected when Python runs with -O (asserts are stripped).
    if len(points) != len(labels):
        raise ValueError('points and labels must have the same length')

    sam_controler.set_image(image)
    prompts = {
        'point_coords': points,
        'point_labels': labels,
    }
    masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
    best = np.argmax(scores)
    mask, logit = masks[best], logits[best, :, :]

    painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
    # NOTE(review): labels > 0 are drawn with point_color_ne and labels < 1
    # with point_color_ps; the names suggest the opposite pairing - confirm.
    painted_image = point_painter(painted_image, points[labels > 0], point_color_ne, point_alpha, point_radius, contour_color, contour_width)
    painted_image = point_painter(painted_image, points[labels < 1], point_color_ps, point_alpha, point_radius, contour_color, contour_width)
    painted_image = Image.fromarray(painted_image)

    return mask, logit, painted_image
|
||||
|
||||
def _paint_mask_and_points(image, mask, points, labels):
    '''Draw ``mask`` and both click groups on ``image``; return a PIL image.'''
    painted = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
    painted = point_painter(painted, points[labels > 0], point_color_ne, point_alpha, point_radius, contour_color, contour_width)
    painted = point_painter(painted, points[labels < 1], point_color_ps, point_alpha, point_radius, contour_color, contour_width)
    return Image.fromarray(painted)


def interact_loop(sam_controler, image: np.ndarray, same: bool, points: np.ndarray, labels: np.ndarray, logits: np.ndarray = None, multimask=True):
    '''
    Run one interaction step and paint the result.

    Args:
        sam_controler: object providing ``predict`` (and, through
            ``seg_again``, ``reset_image`` / ``set_image``).
        image: image used as the painting canvas; when ``same`` is False it
            is also loaded into the controller first.
        same: True to keep working on the image already held by the
            controller, prompting with points plus the previous ``logits``;
            False to switch to ``image`` and prompt with points only.
        points: click coordinates, one row per click.
        labels: one value per click; label > 0 and label < 1 form the two
            painted groups.
        logits: previous logit, required when ``same`` is True; it is passed
            on as ``mask_input`` with a leading axis added.
        multimask: forwarded to ``sam_controler.predict``.

    Returns:
        (mask, logit, painted_image): the highest-scoring mask and logit and
        a PIL image with the mask and the clicks drawn on it.

    Raises:
        ValueError: if ``same`` is True and ``logits`` is None, or if
            ``same`` is False and ``image`` is None.
    '''
    if same:
        # Loop on the same image: refine with the previous logits.
        if logits is None:
            raise ValueError('logits are required when same=True')
        prompts = {
            'point_coords': points,
            'point_labels': labels,
            'mask_input': logits[None, :, :],
        }
        mode = 'both'
    else:
        # Different image (interaction inside a video): reload it first.
        if image is None:
            # The original `raise('Image error')` raised a plain string,
            # which is itself a TypeError; raise a real exception instead.
            raise ValueError('Image error')
        seg_again(sam_controler, image)
        prompts = {
            'point_coords': points,
            'point_labels': labels,
        }
        mode = 'point'

    masks, scores, logits = sam_controler.predict(prompts, mode, multimask)
    best = np.argmax(scores)
    mask, logit = masks[best], logits[best, :, :]

    painted_image = _paint_mask_and_points(image, mask, points, labels)
    return mask, logit, painted_image
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: click on a sample image, then run the "same image"
    # and "different image" interaction paths and save the painted results.
    points = np.array([[500, 375], [1125, 625]])
    labels = np.array([1, 1])
    image_path = '/hhd3/gaoshang/truck.jpg'
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports an unreadable file by returning None, which
        # would otherwise only fail later inside cvtColor.
        raise FileNotFoundError('cannot read image: ' + image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    sam_controler = initialize()

    def save_results(mask, painted_image_full, overlay_path, full_path):
        # Write the mask_painter2 overlay (converted back to BGR for OpenCV)
        # and the fully painted PIL image.
        painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
        painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
        cv2.imwrite(overlay_path, painted_image)
        painted_image_full.save(full_path)

    mask, logit, painted_image_full = first_frame_click(sam_controler, image, points, labels, multimask=True)
    save_results(mask, painted_image_full, '/hhd3/gaoshang/truck_point.jpg', '/hhd3/gaoshang/truck_point_full.jpg')
    # NOTE(review): image is RGB at this point, so this file is written with
    # swapped channels; presumably a check that painting left the input
    # untouched - confirm.
    cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)

    mask, logit, painted_image_full = interact_loop(sam_controler, image, True, points, np.array([1, 0]), logit, multimask=True)
    save_results(mask, painted_image_full, '/hhd3/gaoshang/truck_same.jpg', '/hhd3/gaoshang/truck_same_full.jpg')

    mask, logit, painted_image_full = interact_loop(sam_controler, image, False, points, labels, multimask=True)
    save_results(mask, painted_image_full, '/hhd3/gaoshang/truck_diff.jpg', '/hhd3/gaoshang/truck_diff_full.jpg')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
196
tools/painter.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# paint masks, contours, or points on images, with specified colors
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import copy
|
||||
import time
|
||||
|
||||
|
||||
def colormap(rgb=True):
    '''
    Return the colour table as a float32 array of shape (81, 3), scaled to
    the 0-255 range.

    Args:
        rgb: when False the channel order of every row is reversed.
    '''
    unit_colors = [
        (0.000, 0.000, 0.000),
        (1.000, 1.000, 1.000),
        (1.000, 0.498, 0.313),
        (0.392, 0.581, 0.929),
        (0.000, 0.447, 0.741),
        (0.850, 0.325, 0.098),
        (0.929, 0.694, 0.125),
        (0.494, 0.184, 0.556),
        (0.466, 0.674, 0.188),
        (0.301, 0.745, 0.933),
        (0.635, 0.078, 0.184),
        (0.300, 0.300, 0.300),
        (0.600, 0.600, 0.600),
        (1.000, 0.000, 0.000),
        (1.000, 0.500, 0.000),
        (0.749, 0.749, 0.000),
        (0.000, 1.000, 0.000),
        (0.000, 0.000, 1.000),
        (0.667, 0.000, 1.000),
        (0.333, 0.333, 0.000),
        (0.333, 0.667, 0.000),
        (0.333, 1.000, 0.000),
        (0.667, 0.333, 0.000),
        (0.667, 0.667, 0.000),
        (0.667, 1.000, 0.000),
        (1.000, 0.333, 0.000),
        (1.000, 0.667, 0.000),
        (1.000, 1.000, 0.000),
        (0.000, 0.333, 0.500),
        (0.000, 0.667, 0.500),
        (0.000, 1.000, 0.500),
        (0.333, 0.000, 0.500),
        (0.333, 0.333, 0.500),
        (0.333, 0.667, 0.500),
        (0.333, 1.000, 0.500),
        (0.667, 0.000, 0.500),
        (0.667, 0.333, 0.500),
        (0.667, 0.667, 0.500),
        (0.667, 1.000, 0.500),
        (1.000, 0.000, 0.500),
        (1.000, 0.333, 0.500),
        (1.000, 0.667, 0.500),
        (1.000, 1.000, 0.500),
        (0.000, 0.333, 1.000),
        (0.000, 0.667, 1.000),
        (0.000, 1.000, 1.000),
        (0.333, 0.000, 1.000),
        (0.333, 0.333, 1.000),
        (0.333, 0.667, 1.000),
        (0.333, 1.000, 1.000),
        (0.667, 0.000, 1.000),
        (0.667, 0.333, 1.000),
        (0.667, 0.667, 1.000),
        (0.667, 1.000, 1.000),
        (1.000, 0.000, 1.000),
        (1.000, 0.333, 1.000),
        (1.000, 0.667, 1.000),
        (0.167, 0.000, 0.000),
        (0.333, 0.000, 0.000),
        (0.500, 0.000, 0.000),
        (0.667, 0.000, 0.000),
        (0.833, 0.000, 0.000),
        (1.000, 0.000, 0.000),
        (0.000, 0.167, 0.000),
        (0.000, 0.333, 0.000),
        (0.000, 0.500, 0.000),
        (0.000, 0.667, 0.000),
        (0.000, 0.833, 0.000),
        (0.000, 1.000, 0.000),
        (0.000, 0.000, 0.167),
        (0.000, 0.000, 0.333),
        (0.000, 0.000, 0.500),
        (0.000, 0.000, 0.667),
        (0.000, 0.000, 0.833),
        (0.000, 0.000, 1.000),
        (0.143, 0.143, 0.143),
        (0.286, 0.286, 0.286),
        (0.429, 0.429, 0.429),
        (0.571, 0.571, 0.571),
        (0.714, 0.714, 0.714),
        (0.857, 0.857, 0.857),
    ]
    # One row per colour, values in [0, 1]; scale to 0-255 afterwards.
    table = np.array(unit_colors, dtype=np.float32) * 255
    return table if rgb else table[:, ::-1]
|
||||
|
||||
|
||||
# Colour table as a plain Python list of 3-element uint8 rows, indexed by
# colour id in vis_add_mask.
color_list = colormap().astype('uint8').tolist()
|
||||
|
||||
|
||||
def vis_add_mask(image, mask, color, alpha):
    '''
    Blend a palette colour into ``image`` wherever ``mask`` exceeds 0.5.

    Args:
        image: array to paint on; it is modified in place before the uint8
            copy is returned, so callers pass a copy.
        mask: array thresholded at 0.5 to select the pixels to paint.
        color: index into the module-level ``color_list``.
        alpha: weight of the colour; the image keeps ``1 - alpha``.

    Returns:
        The painted image converted to uint8.
    '''
    rgb = np.array(color_list[color])
    selected = mask > 0.5
    blended = image[selected] * (1 - alpha) + rgb * alpha
    image[selected] = blended
    return image.astype('uint8')
|
||||
|
||||
def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5):
    '''
    Draw filled, outlined dots at the given points on a copy of the image.

    Args:
        input_image: array whose first two dimensions are height and width.
        input_points: iterable of points indexed as (x, y).
        point_color: colour id of the dot fill.
        point_alpha: blending weight of the dot fill.
        point_radius: side length of the elliptical dilation kernel.
        contour_color: colour id of the outline.
        contour_width: outline width; the band radius used is
            ``(contour_width - 1) // 2 + 2``.

    Returns:
        The painted uint8 image; ``input_image`` itself is not modified.
    '''
    height, width = input_image.shape[:2]

    # Seed one pixel per point (points are x, y; the array is row, column),
    # then grow every seed into a dot.
    dots = np.zeros((height, width)).astype('uint8')
    for pt in input_points:
        dots[pt[1], pt[0]] = 1
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (point_radius, point_radius))
    dots = cv2.dilate(dots, ellipse)

    # Signed distance to the dot boundary: positive inside, negative outside.
    band = (contour_width - 1) // 2 + 2
    inside = cv2.distanceTransform(dots, cv2.DIST_L2, 3)
    outside = cv2.distanceTransform(1 - dots, cv2.DIST_L2, 3)
    signed = inside - outside

    # Normalised |distance| clipped to the band; values above 0.5 are pushed
    # to 1 so that only pixels close to the boundary survive in 1 - edge.
    edge = np.abs(np.clip(signed, -band, band))
    edge = edge / np.max(edge)
    edge[edge > 0.5] = 1.

    # Fill first, outline (fully opaque) second.
    filled = vis_add_mask(input_image.copy(), dots, point_color, point_alpha)
    return vis_add_mask(filled.copy(), 1 - edge, contour_color, 1)
|
||||
|
||||
def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3):
    '''
    Draw a translucent mask and its outline on a copy of the image.

    Args:
        input_image: array whose first two dimensions must equal the shape
            of ``input_mask``.
        input_mask: mask array; values are clipped to 0 (background) and
            1 (foreground).
        mask_color: colour id of the mask fill.
        mask_alpha: blending weight of the mask fill.
        contour_color: colour id of the outline.
        contour_width: outline width; the band radius used is
            ``(contour_width - 1) // 2 + 2``.

    Returns:
        The painted uint8 image; ``input_image`` itself is not modified.

    Raises:
        AssertionError: if image and mask shapes disagree.
    '''
    assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'

    # 0: background, 1: foreground
    binary = np.clip(input_mask, 0, 1)

    # Signed distance to the mask boundary: positive inside, negative outside.
    band = (contour_width - 1) // 2 + 2
    inside = cv2.distanceTransform(binary, cv2.DIST_L2, 3)
    outside = cv2.distanceTransform(1 - binary, cv2.DIST_L2, 3)
    signed = inside - outside

    # Normalised |distance| clipped to the band; values above 0.5 are pushed
    # to 1 so that only pixels close to the boundary survive in 1 - edge.
    edge = np.abs(np.clip(signed, -band, band))
    edge = edge / np.max(edge)
    edge[edge > 0.5] = 1.

    # Fill first, outline (fully opaque) second.
    filled = vis_add_mask(input_image.copy(), binary.copy(), mask_color, mask_alpha)
    return vis_add_mask(filled.copy(), 1 - edge, contour_color, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual check of both painters on the sample files under images/.
    image_path = 'images/painter_input_image.jpg'
    input_image = np.array(Image.open(image_path).convert('RGB'))
    input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))

    # example of mask painter
    mask_color, mask_alpha = 3, 0.7
    contour_color, contour_width = 1, 5

    # save the untouched input
    Image.fromarray(input_image).save('images/original.png')

    painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width)
    # NOTE(review): the mask_painter result is not saved; the input image is
    # written a second time instead. Presumably this checks that the input
    # was not modified by the call - confirm.
    painted_image = Image.fromarray(input_image)
    painted_image.save('images/original1.png')

    # example of point painter
    input_image = np.array(Image.open(image_path).convert('RGB'))
    input_points = np.array([[500, 375], [70, 600]])  # x, y
    point_color, point_alpha, point_radius = 5, 0.9, 15
    contour_color, contour_width = 2, 5
    painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width)
    # save
    Image.fromarray(painted_image_1).save('images/point_painter_1.png')

    # same points again with a different fill colour, radius and outline colour
    input_image = np.array(Image.open(image_path).convert('RGB'))
    painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29)
    # save
    Image.fromarray(painted_image_2).save('images/point_painter_2.png')
|
||||
@@ -205,10 +205,6 @@ for vid_reader in progressbar(meta_loader, max_value=len(meta_dataset), redirect
|
||||
# Run the model on this frame
|
||||
prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1)) # 0, background, >0, objects
|
||||
|
||||
# consider prob (only object channels) as prompt to refine segment results
|
||||
|
||||
|
||||
|
||||
# Upsample to original size if needed
|
||||
if need_resize:
|
||||
prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0]
|
||||
|
||||
29
tracker/xmem.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# input: frame list, first frame mask
|
||||
# output: segmentation results on all frames
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class XMem:
    '''Placeholder class, not implemented yet (based on https://github.com/hkchengrex/XMem).'''
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # video frames, in filename order
    frame_pattern = os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg')
    video_path_list = sorted(glob.glob(frame_pattern))
    # first frame
    first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'

    # load every frame as RGB and stack them along a new leading axis
    frames = np.stack(
        [np.array(Image.open(video_path).convert('RGB')) for video_path in video_path_list],
        0,
    )  # N, H, W, C

    # load first frame annotation; mode 'P' yields one palette index per pixel
    first_frame_annotation = np.array(Image.open(first_frame_path).convert('P'))  # H, W
|
||||
|
||||