Merge branch 'master' of github.com:gaomingqi/Track-Anything

memoryunreal committed 2023-04-13 20:15:13 +00:00
5 changed files with 210 additions and 87 deletions

app.py (6 changed lines)

@@ -7,10 +7,12 @@ from PIL import Image
import numpy as np
# from tools.interact_tools import initialize
from tools.interact_tools import SamControler
samc = SamControler()
# initialize()
def pause_video(play_state):
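
For context, a minimal sketch (not part of this commit) of how the new module-level samc singleton could be driven from an app.py event handler; the handler name and the click coordinates are assumptions:

def on_first_frame_click(image, x, y):  # hypothetical handler, not in this diff
    points = np.array([[x, y]])  # a single click at (x, y)
    labels = np.array([1])       # 1 = foreground click, 0 = background click
    mask, logit, painted_image = samc.first_frame_click(image, points, labels)
    return painted_image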

tools/base_segmenter.py

@@ -7,7 +7,7 @@ from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import matplotlib.pyplot as plt
import PIL
from mask_painter import mask_painter
from .mask_painter import mask_painter
class BaseSegmenter:
@@ -78,7 +78,7 @@ class BaseSegmenter:
if __name__ == "__main__":
# load and show an image
image = cv2.imread('images/truck.jpg')
image = cv2.imread('/hhd3/gaoshang/truck.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
# initialise BaseSegmenter
@@ -100,7 +100,7 @@ if __name__ == "__main__":
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('images/truck_point.jpg', painted_image)
cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
# both ------------------------
mode = 'both'
@@ -114,13 +114,15 @@ if __name__ == "__main__":
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('images/truck_both.jpg', painted_image)
cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
# mask only ------------------------
mode = 'mask'
mask_input = logits[np.argmax(scores), :, :]
prompts = {'mask_input': mask_input[None, :, :]}
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('images/truck_mask.jpg', painted_image)
cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
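
For reference, the point-mode block that writes truck_point.jpg above starts from a prompt of this shape; a minimal sketch, with click coordinates that are illustrative assumptions rather than values from the diff:

# point ------------------------
mode = 'point'
prompts = {
    'point_coords': np.array([[500, 375]]),  # (N, 2) pixel coordinates, values assumed
    'point_labels': np.array([1]),           # 1 = foreground, 0 = background
}
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False)  # masks (n, h, w), scores (n,), logits (n, 256, 256)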

tools/interact_tools.py

@@ -7,11 +7,13 @@ from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import matplotlib.pyplot as plt
import PIL
from tools.mask_painter import mask_painter as mask_painter2
from base_segmenter import BaseSegmenter
from painter import mask_painter, point_painter
from .mask_painter import mask_painter as mask_painter2
from .base_segmenter import BaseSegmenter
from .painter import mask_painter, point_painter
import os
import requests
import sys
mask_color = 3
mask_alpha = 0.7
@@ -24,7 +26,6 @@ point_radius = 15
contour_color = 2
contour_width = 5
def download_checkpoint(url, folder, filename):
os.makedirs(folder, exist_ok=True)
filepath = os.path.join(folder, filename)
@@ -38,94 +39,184 @@ def download_checkpoint(url, folder, filename):
return filepath
def initialize():
'''
initialize sam controler
'''
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
folder = "segmenter"
SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
model_type = 'vit_h'
device = "cuda:0"
sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
return sam_controler
def seg_again(sam_controler, image: np.ndarray):
'''
it is used when interact in video
'''
sam_controler.reset_image()
sam_controler.set_image(image)
return
def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
'''
it is used in first frame in video
return: mask, logit, painted image(mask+point)
'''
sam_controler.set_image(image)
prompts = {
'point_coords': points,
'point_labels': labels,
}
masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
assert len(points)==len(labels)
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image)
return mask, logit, painted_image
def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
if same:
class SamControler():
def __init__(self):
'''
true; loop in the same image
initialize sam controler
'''
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
folder = "segmenter/checkpoints"
SAM_checkpoint = 'sam_vit_h_4b8939.pth'
SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
# SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
model_type = 'vit_h'
device = "cuda:0"
self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
def seg_again(self, image: np.ndarray):
'''
it is used when interact in video
'''
self.sam_controler.reset_image()
self.sam_controler.set_image(image)
return
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
'''
it is used in first frame in video
return: mask, logit, painted image(mask+point)
'''
self.sam_controler.set_image(image)
prompts = {
'point_coords': points,
'point_labels': labels,
'mask_input': logits[None, :, :]
}
masks, scores, logits = sam_controler.predict(prompts, 'both', multimask)
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
assert len(points)==len(labels)
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image)
return mask, logit, painted_image
else:
'''
loop in the different image, interact in the video
'''
if image is None:
raise('Image error')
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
if same:
'''
true; loop in the same image
'''
prompts = {
'point_coords': points,
'point_labels': labels,
'mask_input': logits[None, :, :]
}
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image)
return mask, logit, painted_image
else:
seg_again(sam_controler, image)
prompts = {
'point_coords': points,
'point_labels': labels,
}
masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
'''
loop in the different image, interact in the video
'''
if image is None:
raise ValueError('Image error')
else:
self.seg_again(image)
prompts = {
'point_coords': points,
'point_labels': labels,
}
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image)
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image)
return mask, logit, painted_image
return mask, logit, painted_image
# def initialize():
# '''
# initialize sam controler
# '''
# checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
# folder = "segmenter"
# SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
# download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
# model_type = 'vit_h'
# device = "cuda:0"
# sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
# return sam_controler
# def seg_again(sam_controler, image: np.ndarray):
# '''
# it is used when interact in video
# '''
# sam_controler.reset_image()
# sam_controler.set_image(image)
# return
# def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
# '''
# it is used in first frame in video
# return: mask, logit, painted image(mask+point)
# '''
# sam_controler.set_image(image)
# prompts = {
# 'point_coords': points,
# 'point_labels': labels,
# }
# masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
# assert len(points)==len(labels)
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
# painted_image = Image.fromarray(painted_image)
# return mask, logit, painted_image
# def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
# if same:
# '''
# true; loop in the same image
# '''
# prompts = {
# 'point_coords': points,
# 'point_labels': labels,
# 'mask_input': logits[None, :, :]
# }
# masks, scores, logits = sam_controler.predict(prompts, 'both', multimask)
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
# painted_image = Image.fromarray(painted_image)
# return mask, logit, painted_image
# else:
# '''
# loop in the different image, interact in the video
# '''
# if image is None:
# raise('Image error')
# else:
# seg_again(sam_controler, image)
# prompts = {
# 'point_coords': points,
# 'point_labels': labels,
# }
# masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
# painted_image = Image.fromarray(painted_image)
# return mask, logit, painted_image
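
A hedged usage sketch for the new class-based interface above; the frame contents and click positions are placeholders, and a CUDA device is needed because the class hard-codes cuda:0:

import numpy as np
from tools.interact_tools import SamControler

samc = SamControler()                            # downloads sam_vit_h_4b8939.pth on first use
frame = np.zeros((480, 854, 3), dtype=np.uint8)  # placeholder RGB frame

points = np.array([[100, 200]])
labels = np.array([1])                           # positive click
mask, logit, painted = samc.first_frame_click(frame, points, labels)

# refine on the same frame, feeding back the best logit from the previous step
points = np.vstack([points, [150, 220]])
labels = np.append(labels, 0)                    # add a negative click
mask, logit, painted = samc.interact_loop(frame, True, points, labels, logits=logit)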

track_anything.py (new file, 12 lines)

@@ -0,0 +1,12 @@
from tools.interact_tools import SamControler
from tracker.xmem import XMem
class TrackingAnything():
def __init__(self, cfg):
self.cfg = cfg
self.samcontroler = SamControler()
self.xmem = None  # TODO: wire up the XMem tracker; still a stub in this commit
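
A hedged sketch of how this skeleton is presumably meant to be driven; only the SamControler call is grounded in this diff, while the cfg contents and the XMem half are assumptions:

import numpy as np

ta = TrackingAnything(cfg={})                          # cfg keys are not defined yet in this commit
first_frame = np.zeros((480, 854, 3), dtype=np.uint8)  # placeholder RGB frame
points, labels = np.array([[100, 200]]), np.array([1])
mask, logit, painted = ta.samcontroler.first_frame_click(first_frame, points, labels)
# the XMem side is still a stub; presumably it will propagate mask through the remaining frames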

tracker/base_tracker.py

@@ -9,6 +9,9 @@ import torch
import yaml
from model.network import XMem
from inference.inference_core import InferenceCore
# for data transformation
from torchvision import transforms
from dataset.range_transform import im_normalization
class BaseTracker:
@@ -24,14 +27,27 @@ class BaseTracker:
network = XMem(config, xmem_checkpoint).to(device).eval()
# initialise InferenceCore
self.tracker = InferenceCore(network, config)
# set data transformation
# self.data_transform =
# data transformation
self.im_transform = transforms.Compose([
transforms.ToTensor(),
im_normalization,
])
def track(self, frames, first_frame_annotation):
"""
Input:
frames: numpy arrays: T, H, W, 3 (T: number of frames)
first_frame_annotation: numpy array: H, W
Output:
masks: numpy arrays: T, H, W
"""
# data transformation
frames = [self.im_transform(frame) for frame in frames]  # collect the transformed tensors; reassigning the loop variable would discard them
# tracking
pass
if __name__ == '__main__':
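
A hypothetical shape check for the track interface documented above; the tracking logic itself is still a stub in this commit, and the constructor arguments are assumptions:

import numpy as np

frames = np.zeros((10, 480, 854, 3), dtype=np.uint8)           # T, H, W, 3
first_frame_annotation = np.zeros((480, 854), dtype=np.uint8)  # H, W
# tracker = BaseTracker(xmem_checkpoint, device='cuda:0')      # constructor signature assumed
# masks = tracker.track(frames, first_frame_annotation)        # expected output: T, H, W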