diff --git a/app.py b/app.py
index 79ba26b..1cd9603 100644
--- a/app.py
+++ b/app.py
@@ -7,10 +7,12 @@
 from PIL import Image
 import numpy as np
 
-# from tools.interact_tools import initialize
+from tools.interact_tools import SamControler
+
+samc = SamControler()
+
-# initialize()
 
 
 def pause_video(play_state):
diff --git a/tools/base_segmenter.py b/tools/base_segmenter.py
index 42e017c..c5758a8 100644
--- a/tools/base_segmenter.py
+++ b/tools/base_segmenter.py
@@ -7,7 +7,7 @@ from typing import Union
 from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
 import matplotlib.pyplot as plt
 import PIL
-from mask_painter import mask_painter
+from .mask_painter import mask_painter
 
 
 class BaseSegmenter:
@@ -78,7 +78,7 @@ class BaseSegmenter:
 
 if __name__ == "__main__":
     # load and show an image
-    image = cv2.imread('images/truck.jpg')
+    image = cv2.imread('/hhd3/gaoshang/truck.jpg')
     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # numpy array (h, w, 3)
 
     # initialise BaseSegmenter
@@ -100,7 +100,7 @@ if __name__ == "__main__":
     masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False)  # masks (n, h, w), scores (n,), logits (n, 256, 256)
     painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
     painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
-    cv2.imwrite('images/truck_point.jpg', painted_image)
+    cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
 
     # both ------------------------
     mode = 'both'
@@ -114,13 +114,15 @@ if __name__ == "__main__":
     masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True)  # masks (n, h, w), scores (n,), logits (n, 256, 256)
     painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
     painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
-    cv2.imwrite('images/truck_both.jpg', painted_image)
+    cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
 
     # mask only ------------------------
     mode = 'mask'
     mask_input = logits[np.argmax(scores), :, :]
+    prompts = {'mask_input': mask_input[None, :, :]}
+
     masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True)  # masks (n, h, w), scores (n,), logits (n, 256, 256)
     painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
     painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
-    cv2.imwrite('images/truck_mask.jpg', painted_image)
+    cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
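The mask-only branch of the demo above previously painted results from the earlier 'both' run; the hunk restores the missing predict call, feeding the best low-resolution logit back in as a mask prompt. For reference, the same round trip written directly against segment_anything's SamPredictor API; the checkpoint path, device, and click coordinates below are illustrative, not taken from this repository:

```python
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# illustrative setup: any SAM ViT-H checkpoint and a CUDA device
sam = sam_model_registry['vit_h'](checkpoint='sam_vit_h_4b8939.pth').to('cuda:0')
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread('truck.jpg'), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# point prompt: masks (n, h, w), scores (n,), low-res logits (n, 256, 256)
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)

# feed the best logit back as a (1, 256, 256) mask prompt to refine
best = int(np.argmax(scores))
masks, scores, logits = predictor.predict(
    mask_input=logits[best][None, :, :],
    multimask_output=False,
)
```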
diff --git a/tools/interact_tools.py b/tools/interact_tools.py
index 997fcdf..f674db5 100644
--- a/tools/interact_tools.py
+++ b/tools/interact_tools.py
@@ -7,11 +7,13 @@ from typing import Union
 from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
 import matplotlib.pyplot as plt
 import PIL
-from tools.mask_painter import mask_painter as mask_painter2
-from base_segmenter import BaseSegmenter
-from painter import mask_painter, point_painter
+from .mask_painter import mask_painter as mask_painter2
+from .base_segmenter import BaseSegmenter
+from .painter import mask_painter, point_painter
 import os
 import requests
+import sys
+
 
 mask_color = 3
 mask_alpha = 0.7
@@ -24,7 +26,6 @@ point_radius = 15
 contour_color = 2
 contour_width = 5
 
-
 def download_checkpoint(url, folder, filename):
     os.makedirs(folder, exist_ok=True)
     filepath = os.path.join(folder, filename)
@@ -38,94 +39,184 @@ def download_checkpoint(url, folder, filename):
     return filepath
 
-
-def initialize():
-    '''
-    initialize sam controler
-    '''
-    checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
-    folder = "segmenter"
-    SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
-    download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
-
-
-    model_type = 'vit_h'
-    device = "cuda:0"
-    sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
-    return sam_controler
-
-
-def seg_again(sam_controler, image: np.ndarray):
-    '''
-    it is used when interact in video
-    '''
-    sam_controler.reset_image()
-    sam_controler.set_image(image)
-    return
-
-
-def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
-    '''
-    it is used in first frame in video
-    return: mask, logit, painted image(mask+point)
-    '''
-    sam_controler.set_image(image)
-    prompts = {
-        'point_coords': points,
-        'point_labels': labels,
-    }
-    masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
-    mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
-
-    assert len(points)==len(labels)
-
-    painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
-    painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
-    painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
-    painted_image = Image.fromarray(painted_image)
-
-    return mask, logit, painted_image
-
-def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
-    if same:
-        '''
-        true; loop in the same image
-        '''
-        prompts = {
-            'point_coords': points,
-            'point_labels': labels,
-            'mask_input': logits[None, :, :]
-        }
-        masks, scores, logits = sam_controler.predict(prompts, 'both', multimask)
-        mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
-
-        painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
-        painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
-        painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
-        painted_image = Image.fromarray(painted_image)
-
-        return mask, logit, painted_image
-    else:
-        '''
-        loop in the different image, interact in the video
-        '''
-        if image is None:
-            raise('Image error')
-        else:
-            seg_again(sam_controler, image)
-        prompts = {
-            'point_coords': points,
-            'point_labels': labels,
-        }
-        masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
-        mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
-
-        painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
-        painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
-        painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
-        painted_image = Image.fromarray(painted_image)
-
-        return mask, logit, painted_image
+class SamControler():
+    def __init__(self):
+        '''
+        initialize the SAM controler
+        '''
+        checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+        folder = "segmenter/checkpoints"
+        SAM_checkpoint = 'sam_vit_h_4b8939.pth'
+        SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
+        # SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
+        model_type = 'vit_h'
+        device = "cuda:0"
+        self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
+
+    def seg_again(self, image: np.ndarray):
+        '''
+        used when interacting with a new frame of the video
+        '''
+        self.sam_controler.reset_image()
+        self.sam_controler.set_image(image)
+        return
+
+    def first_frame_click(self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True):
+        '''
+        used on the first frame of the video
+        return: mask, logit, painted image (mask + points)
+        '''
+        assert len(points) == len(labels)
+
+        self.sam_controler.set_image(image)
+        prompts = {
+            'point_coords': points,
+            'point_labels': labels,
+        }
+        masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
+        mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+
+        painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+        painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels > 0)], axis=1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+        painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels < 1)], axis=1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+        painted_image = Image.fromarray(painted_image)
+
+        return mask, logit, painted_image
+
+    def interact_loop(self, image: np.ndarray, same: bool, points: np.ndarray, labels: np.ndarray, logits: np.ndarray = None, multimask=True):
+        if same:
+            # same frame: refine the image that is already embedded,
+            # feeding the previous logit back in as a mask prompt
+            prompts = {
+                'point_coords': points,
+                'point_labels': labels,
+                'mask_input': logits[None, :, :]
+            }
+            masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
+            mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+        else:
+            # a different frame of the video: reset and re-embed first
+            if image is None:
+                raise ValueError('Image error')
+            self.seg_again(image)
+            prompts = {
+                'point_coords': points,
+                'point_labels': labels,
+            }
+            masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
+            mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+
+        painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+        painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels > 0)], axis=1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+        painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels < 1)], axis=1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+        painted_image = Image.fromarray(painted_image)
+
+        return mask, logit, painted_image
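download_checkpoint appears only as hunk context here; its body lies between the two hunks and is not shown. As a sketch of what such a helper typically does with the imported requests module, assuming the usual streaming pattern rather than the repository's exact body:

```python
import os
import requests

def download_checkpoint(url, folder, filename):
    """Download `url` into `folder/filename` once; return the local path."""
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)
    if not os.path.exists(filepath):
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(filepath, 'wb') as f:
            # stream in chunks so the multi-gigabyte ViT-H checkpoint
            # never has to fit in memory at once
            for chunk in response.iter_content(chunk_size=1 << 20):
                if chunk:
                    f.write(chunk)
    return filepath
```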
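Taken together, the refactor folds the old free functions, and the sam_controler handle they threaded through every call, into one stateful SamControler class. A minimal end-to-end sketch of the new interface, assuming a CUDA device is available (the class hard-codes "cuda:0"); the frame path and click coordinates below are made up:

```python
import cv2
import numpy as np
from tools.interact_tools import SamControler

samc = SamControler()  # downloads the SAM ViT-H checkpoint on first use

frame = cv2.cvtColor(cv2.imread('frame_0000.jpg'), cv2.COLOR_BGR2RGB)

# one positive click on the object of interest in the first frame
points = np.array([[320, 240]])
labels = np.array([1])
mask, logit, painted = samc.first_frame_click(frame, points, labels)

# refine on the same frame by passing the previous logit back in
mask, logit, painted = samc.interact_loop(frame, True, points, labels, logits=logit)
painted.save('frame_0000_painted.png')  # painted is a PIL Image
```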
diff --git a/track_anything.py b/track_anything.py
new file mode 100644
index 0000000..df8240b
--- /dev/null
+++ b/track_anything.py
@@ -0,0 +1,11 @@
+from tools.interact_tools import SamControler
+from tracker.xmem import XMem
+
+
+
+
+class TrackingAnything():
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.samcontroler = SamControler()
+        self.xmem = None  # placeholder until the XMem tracker is wired up
\ No newline at end of file
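track_anything.py lands as a stub: the tracker.xmem import is unused and self.xmem is only a placeholder. Purely as a sketch of where the glue class appears to be headed, here is one hypothetical wiring that pairs SamControler with the BaseTracker wrapper from the next file; the BaseTracker constructor arguments and the cfg fields are assumptions, not this repository's API:

```python
from tools.interact_tools import SamControler
from tracker.base_tracker import BaseTracker  # hypothetical: wrap XMem via BaseTracker

class TrackingAnything():
    def __init__(self, cfg):
        self.cfg = cfg
        self.samcontroler = SamControler()
        # assumed cfg fields; the real config layout is not part of this diff
        self.xmem = BaseTracker(cfg.xmem_checkpoint, device='cuda:0')

    def forward(self, frames, points, labels):
        # click-segment the first frame, then propagate the mask through the video
        mask, logit, painted = self.samcontroler.first_frame_click(frames[0], points, labels)
        masks = self.xmem.track(frames, mask)
        return masks
```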
diff --git a/tracker/base_tracker.py b/tracker/base_tracker.py
index 67940e1..f128bfd 100644
--- a/tracker/base_tracker.py
+++ b/tracker/base_tracker.py
@@ -9,6 +9,9 @@ import torch
 import yaml
 from model.network import XMem
 from inference.inference_core import InferenceCore
+# for data transformation
+from torchvision import transforms
+from dataset.range_transform import im_normalization
 
 
 class BaseTracker:
@@ -24,14 +27,27 @@ class BaseTracker:
         network = XMem(config, xmem_checkpoint).to(device).eval()
         # initialise InferenceCore
         self.tracker = InferenceCore(network, config)
-        # set data transformation
-        # self.data_transform = 
+        # data transformation
+        self.im_transform = transforms.Compose([
+            transforms.ToTensor(),
+            im_normalization,
+        ])
 
     def track(self, frames, first_frame_annotation):
+        """
+        Input:
+        frames: numpy arrays: T, H, W, 3 (T: number of frames)
+        first_frame_annotation: numpy array: H, W
+
+        Output:
+        masks: numpy arrays: T, H, W
+        """
         # data transformation
+        # collect the transformed frames (assigning to the loop variable would discard them)
+        frames = [self.im_transform(frame) for frame in frames]
         # tracking
-        pass
+
 
 if __name__ == '__main__':
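track() above now transforms its input but still has an empty tracking step. A hedged sketch of how the loop might finish, assuming InferenceCore exposes a step(image, mask=None) method as in the public XMem reference code and that __init__ stores the device as self.device; the tensor shapes and the argmax decoding are likewise assumptions:

```python
import numpy as np
import torch

def track(self, frames, first_frame_annotation):
    # frames: (T, H, W, 3) uint8; first_frame_annotation: (H, W) -> masks: (T, H, W)
    masks = []
    with torch.no_grad():
        for i, frame in enumerate(frames):
            frame_t = self.im_transform(frame).to(self.device)  # (3, H, W), normalised
            if i == 0:
                # prime the memory with the user-provided first-frame mask
                mask_t = torch.from_numpy(first_frame_annotation).float().to(self.device)
                prob = self.tracker.step(frame_t, mask_t[None])  # assumed signature
            else:
                prob = self.tracker.step(frame_t)
            # collapse per-object probabilities into a label map
            masks.append(torch.argmax(prob, dim=0).cpu().numpy().astype(np.uint8))
    return np.stack(masks)
```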