Mirror of https://github.com/gaomingqi/Track-Anything.git (synced 2025-12-16 08:27:49 +01:00)
Merge branch 'master' of github.com:gaomingqi/Track-Anything
app.py (6 changed lines)
@@ -7,10 +7,12 @@ from PIL import Image
 import numpy as np


 # from tools.interact_tools import initialize
+from tools.interact_tools import SamControler

+samc = SamControler()


 # initialize()


 def pause_video(play_state):
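app.py now builds a module-level SamControler instance (samc) instead of calling the old initialize() helper. A minimal sketch of how a click callback could use it, in line with the first_frame_click() signature added further below; the callback name, its arguments, and the single positive click are illustrative assumptions, not part of this commit:

 # Sketch only: the callback name/signature is an assumption; samc is the
 # module-level SamControler created in app.py above.
 import numpy as np

 def on_first_frame_click(frame, click_x, click_y):
     points = np.array([[click_x, click_y]])   # one click at (x, y)
     labels = np.array([1])                    # 1 = positive click
     mask, logit, painted = samc.first_frame_click(frame, points, labels, multimask=True)
     return painted                            # PIL image with mask and point overlay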
@@ -7,7 +7,7 @@ from typing import Union
 from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
 import matplotlib.pyplot as plt
 import PIL
-from mask_painter import mask_painter
+from .mask_painter import mask_painter


 class BaseSegmenter:

@@ -78,7 +78,7 @@ class BaseSegmenter:

 if __name__ == "__main__":
     # load and show an image
-    image = cv2.imread('images/truck.jpg')
+    image = cv2.imread('/hhd3/gaoshang/truck.jpg')
     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)

     # initialise BaseSegmenter
@@ -100,7 +100,7 @@ if __name__ == "__main__":
     masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
     painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
     painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
-    cv2.imwrite('images/truck_point.jpg', painted_image)
+    cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)

     # both ------------------------
     mode = 'both'

@@ -114,13 +114,15 @@ if __name__ == "__main__":
     masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
     painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
     painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
-    cv2.imwrite('images/truck_both.jpg', painted_image)
+    cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)

     # mask only ------------------------
     mode = 'mask'
     mask_input = logits[np.argmax(scores), :, :]

     prompts = {'mask_input': mask_input[None, :, :]}

     masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
     painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
     painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
-    cv2.imwrite('images/truck_mask.jpg', painted_image)
+    cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
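For reference, a minimal sketch of driving BaseSegmenter directly, using only the constructor, set_image(), predict() and the prompt keys that appear in the hunks above; the module path tools.base_segmenter is assumed, while the checkpoint path, test image, and click coordinates are the ones used elsewhere in this diff:

 # Sketch only: the import path is an assumption; everything else mirrors the demo above.
 import cv2
 import numpy as np
 from tools.base_segmenter import BaseSegmenter

 segmenter = BaseSegmenter('./checkpoints/sam_vit_h_4b8939.pth', 'vit_h', 'cuda:0')
 image = cv2.cvtColor(cv2.imread('images/truck.jpg'), cv2.COLOR_BGR2RGB)
 segmenter.set_image(image)

 # single positive click at (x=500, y=375)
 prompts = {'point_coords': np.array([[500, 375]]), 'point_labels': np.array([1])}
 masks, scores, logits = segmenter.predict(prompts, 'point', multimask=True)
 best = np.argmax(scores)
 mask, logit = masks[best], logits[best]        # mask: (h, w), logit: (256, 256)

 # refine: 'both' mode combines the points with the previous low-res logit
 prompts['mask_input'] = logit[None, :, :]
 masks, scores, logits = segmenter.predict(prompts, 'both', multimask=True)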
@@ -7,11 +7,13 @@ from typing import Union
 from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
 import matplotlib.pyplot as plt
 import PIL
-from tools.mask_painter import mask_painter as mask_painter2
-from base_segmenter import BaseSegmenter
-from painter import mask_painter, point_painter
+from .mask_painter import mask_painter as mask_painter2
+from .base_segmenter import BaseSegmenter
+from .painter import mask_painter, point_painter
 import os
 import requests
 import sys


 mask_color = 3
 mask_alpha = 0.7

@@ -24,7 +26,6 @@ point_radius = 15
 contour_color = 2
 contour_width = 5

 def download_checkpoint(url, folder, filename):
     os.makedirs(folder, exist_ok=True)
     filepath = os.path.join(folder, filename)
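The body of download_checkpoint() falls between the hunks shown here. What follows is only a plausible sketch of such a helper, streaming the file with the requests module imported above; it is not the repository's actual implementation:

 # Sketch only: a typical checkpoint downloader, not the code from this commit.
 import os
 import requests

 def download_checkpoint(url, folder, filename):
     os.makedirs(folder, exist_ok=True)
     filepath = os.path.join(folder, filename)
     if not os.path.exists(filepath):                  # skip if already cached
         with requests.get(url, stream=True) as r:
             r.raise_for_status()
             with open(filepath, 'wb') as f:
                 for chunk in r.iter_content(chunk_size=8192):
                     f.write(chunk)
     return filepath                                   # matches the hunk below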
@@ -38,43 +39,40 @@ def download_checkpoint(url, folder, filename):

     return filepath


-def initialize():
+class SamControler():
+    def __init__(self):
         '''
         initialize sam controler
         '''
         checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
-    folder = "segmenter"
-    SAM_checkpoint = './checkpoints/sam_vit_h_4b8939.pth'
-    download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
+        folder = "segmenter/checkpoints"
+        SAM_checkpoint = 'sam_vit_h_4b8939.pth'
+        SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
+        # SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
         model_type = 'vit_h'
         device = "cuda:0"
-    sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
-    return sam_controler
+        self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)


-def seg_again(sam_controler, image: np.ndarray):
+    def seg_again(self, image: np.ndarray):
         '''
         it is used when interact in video
         '''
-    sam_controler.reset_image()
-    sam_controler.set_image(image)
+        self.sam_controler.reset_image()
+        self.sam_controler.set_image(image)
         return


-def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
+    def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
         '''
         it is used in first frame in video
         return: mask, logit, painted image(mask+point)
         '''
-    sam_controler.set_image(image)
+        self.sam_controler.set_image(image)
         prompts = {
             'point_coords': points,
             'point_labels': labels,
         }
-    masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
+        masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
         mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

         assert len(points)==len(labels)
@@ -86,7 +84,7 @@ def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, label

         return mask, logit, painted_image

-def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
+    def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
         if same:
             '''
             true; loop in the same image

@@ -96,7 +94,7 @@ def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray
                 'point_labels': labels,
                 'mask_input': logits[None, :, :]
             }
-        masks, scores, logits = sam_controler.predict(prompts, 'both', multimask)
+            masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
             mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

             painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
@@ -112,12 +110,12 @@ def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray
             if image is None:
                 raise('Image error')
             else:
-            seg_again(sam_controler, image)
+                self.seg_again(image)
             prompts = {
                 'point_coords': points,
                 'point_labels': labels,
             }
-        masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
+            masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
             mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

             painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
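Taken together, these hunks turn the old module-level helpers into methods on a single object. A minimal sketch of the intended interaction flow, using the method signatures shown above; the frame variables and click data are illustrative placeholders:

 # Sketch only: frame0/frame1, points and labels are illustrative placeholders.
 import numpy as np
 from tools.interact_tools import SamControler

 controler = SamControler()            # downloads/loads the SAM ViT-H checkpoint

 # first frame: segment the object from user clicks
 points = np.array([[500, 375]])       # (x, y) positive click
 labels = np.array([1])                # 1 = positive, 0 = negative
 mask, logit, painted = controler.first_frame_click(frame0, points, labels, multimask=True)

 # keep refining on the same frame, feeding back the previous logit
 mask, logit, painted = controler.interact_loop(frame0, same=True,
                                                points=points, labels=labels, logits=logit)

 # move to another frame of the video (re-embeds the image via seg_again internally)
 mask, logit, painted = controler.interact_loop(frame1, same=False,
                                                points=points, labels=labels)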
@@ -130,6 +128,99 @@ def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray


+# def initialize():
+#     '''
+#     initialize sam controler
+#     '''
+#     checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+#     folder = "segmenter"
+#     SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
+#     download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
+
+#     model_type = 'vit_h'
+#     device = "cuda:0"
+#     sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
+#     return sam_controler


+# def seg_again(sam_controler, image: np.ndarray):
+#     '''
+#     it is used when interact in video
+#     '''
+#     sam_controler.reset_image()
+#     sam_controler.set_image(image)
+#     return


+# def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
+#     '''
+#     it is used in first frame in video
+#     return: mask, logit, painted image(mask+point)
+#     '''
+#     sam_controler.set_image(image)
+#     prompts = {
+#         'point_coords': points,
+#         'point_labels': labels,
+#     }
+#     masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
+#     mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

+#     assert len(points)==len(labels)

+#     painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+#     painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+#     painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+#     painted_image = Image.fromarray(painted_image)

+#     return mask, logit, painted_image

+# def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
+#     if same:
+#         '''
+#         true; loop in the same image
+#         '''
+#         prompts = {
+#             'point_coords': points,
+#             'point_labels': labels,
+#             'mask_input': logits[None, :, :]
+#         }
+#         masks, scores, logits = sam_controler.predict(prompts, 'both', multimask)
+#         mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

+#         painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+#         painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+#         painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+#         painted_image = Image.fromarray(painted_image)

+#         return mask, logit, painted_image
+#     else:
+#         '''
+#         loop in the different image, interact in the video
+#         '''
+#         if image is None:
+#             raise('Image error')
+#         else:
+#             seg_again(sam_controler, image)
+#         prompts = {
+#             'point_coords': points,
+#             'point_labels': labels,
+#         }
+#         masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
+#         mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

+#         painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+#         painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+#         painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+#         painted_image = Image.fromarray(painted_image)

+#         return mask, logit, painted_image



 if __name__ == "__main__":
     points = np.array([[500, 375], [1125, 625]])
     labels = np.array([1, 1])
track_anything.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+from tools.interact_tools import SamControler
+from tracker.xmem import XMem




+class TrackingAnything():
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.samcontroler = SamControler()
+        self.xmem =
+        pass
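The new track_anything.py stubs out the glue class but leaves the XMem side unassigned (self.xmem = has no right-hand side). A minimal sketch of how the two pieces could be composed once a tracker wrapper is available; the BaseTracker import path, its constructor arguments, the cfg fields, and the forward() method are assumptions for illustration, not part of this commit:

 # Sketch only: the tracker wiring and cfg attributes are assumed.
 import numpy as np
 from tools.interact_tools import SamControler
 from tracker.base_tracker import BaseTracker   # assumed module path

 class TrackingAnything():
     def __init__(self, cfg):
         self.cfg = cfg
         self.samcontroler = SamControler()
         # hypothetical wiring of the XMem-based tracker shown further below
         self.xmem = BaseTracker(cfg.xmem_checkpoint, device=cfg.device)

     def forward(self, frames, points, labels):
         # segment the object on the first frame from user clicks ...
         mask, logit, painted = self.samcontroler.first_frame_click(frames[0], points, labels)
         # ... then propagate that mask through the rest of the video
         masks = self.xmem.track(frames, first_frame_annotation=mask.astype(np.uint8))
         return masks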
@@ -9,6 +9,9 @@ import torch
 import yaml
 from model.network import XMem
 from inference.inference_core import InferenceCore
+# for data transormation
+from torchvision import transforms
+from dataset.range_transform import im_normalization


 class BaseTracker:

@@ -24,14 +27,27 @@ class BaseTracker:
         network = XMem(config, xmem_checkpoint).to(device).eval()
         # initialise IncerenceCore
         self.tracker = InferenceCore(network, config)
-        # set data transformation
-        # self.data_transform =
+        # data transformation
+        self.im_transform = transforms.Compose([
+            transforms.ToTensor(),
+            im_normalization,
+        ])

     def track(self, frames, first_frame_annotation):
+        """
+        Input:
+        frames: numpy arrays: T, H, W, 3 (T: number of frames)
+        first_frame_annotation: numpy array: H, W

+        Output:
+        masks: numpy arrays: T, H, W
+        """
+        # data transformation
+        for frame in frames:
+            frame = self.im_transform(frame)

+        # tracking
         pass



 if __name__ == '__main__':
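The track() method above stops at the data transformation and ends in a pass. A hedged sketch of how the loop could drive XMem's InferenceCore, assuming it exposes a step(image, mask, labels) call as in the upstream XMem inference code; the device, the one-object one-hot handling, and the argmax post-processing are assumptions, not part of this commit:

 # Sketch only: InferenceCore.step's signature and the mask handling follow the
 # upstream XMem inference code and are assumptions, not code from this commit.
 import numpy as np
 import torch

 def track(self, frames, first_frame_annotation):
     device = 'cuda:0'                                  # assumed, as elsewhere in this diff
     masks = []
     with torch.no_grad():
         for i, frame in enumerate(frames):
             frame_t = self.im_transform(frame).to(device)          # (3, H, W)
             if i == 0:
                 # one channel per tracked object; a single object here (assumption)
                 mask_t = torch.from_numpy(
                     (first_frame_annotation == 1).astype(np.float32))[None].to(device)
                 prob = self.tracker.step(frame_t, mask_t, [1])
             else:
                 prob = self.tracker.step(frame_t)
             # prob: (num_objects + 1, H, W); argmax over channels gives a label map
             masks.append(torch.argmax(prob, dim=0).cpu().numpy().astype(np.uint8))
     return np.stack(masks, axis=0)                                  # (T, H, W)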