mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 08:27:49 +01:00
gao base update
This commit is contained in:
6
app.py
6
app.py
@@ -7,10 +7,12 @@ from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
from tools.interact_tools import initialize
|
||||
from tools.interact_tools import SamControler
|
||||
|
||||
samc = SamControler()
|
||||
|
||||
|
||||
|
||||
initialize()
|
||||
|
||||
|
||||
def pause_video(play_state):
|
||||
|
||||
37
app_test.py
37
app_test.py
@@ -1,37 +0,0 @@
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
def capture_frame(video):
    """Return the frame of ``video`` at its current playback position.

    NOTE(review): ``video`` is assumed to expose ``current_time`` and
    ``get_frame_at_sec`` -- confirm against what the UI actually passes in.
    """
    return video.get_frame_at_sec(video.current_time)
|
||||
|
||||
def capture_time(video):
    """Block until ``video`` is paused, then return its playback position.

    Polls ``video.paused`` and, once it is truthy, returns
    ``video.current_time``. The previous loop re-checked the flag without
    ever yielding, which pins a CPU core while waiting; a short sleep
    between polls gives the same result without the busy-wait.

    NOTE(review): like the original, this never returns if the video is
    never paused -- confirm whether a timeout is wanted.
    """
    poll_interval = 0.05  # seconds to wait between checks of the pause flag
    while not video.paused:
        time.sleep(poll_interval)
    return video.current_time
|
||||
|
||||
iface = gr.Interface(fn=capture_frame,
|
||||
inputs=[gr.inputs.Video(type="mp4", label="Input video",
|
||||
source="upload")],
|
||||
outputs=["image"],
|
||||
server_port=12212,
|
||||
server_name="0.0.0.0",
|
||||
capture_session=True)
|
||||
|
||||
video_player = iface.video[0]
|
||||
video_player.pause = False
|
||||
|
||||
time_interface = gr.Interface(fn=capture_time,
|
||||
inputs=[gr.inputs.Video(type="mp4", label="Input video",
|
||||
source="upload", max_duration=10)],
|
||||
outputs=["text"],
|
||||
server_port=12212,
|
||||
server_name="0.0.0.0",
|
||||
capture_session=True)
|
||||
|
||||
time_interface.video[0].play = False
|
||||
time_interface.video[0].pause = False
|
||||
|
||||
iface.launch()
|
||||
time_interface.launch()
|
||||
@@ -7,7 +7,7 @@ from typing import Union
|
||||
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
||||
import matplotlib.pyplot as plt
|
||||
import PIL
|
||||
from mask_painter import mask_painter
|
||||
from .mask_painter import mask_painter
|
||||
|
||||
|
||||
class BaseSegmenter:
|
||||
@@ -78,7 +78,7 @@ class BaseSegmenter:
|
||||
|
||||
if __name__ == "__main__":
|
||||
# load and show an image
|
||||
image = cv2.imread('images/truck.jpg')
|
||||
image = cv2.imread('/hhd3/gaoshang/truck.jpg')
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
|
||||
|
||||
# initialise BaseSegmenter
|
||||
@@ -100,7 +100,7 @@ if __name__ == "__main__":
|
||||
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
|
||||
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
|
||||
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
||||
cv2.imwrite('images/truck_point.jpg', painted_image)
|
||||
cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
|
||||
|
||||
# both ------------------------
|
||||
mode = 'both'
|
||||
@@ -114,13 +114,15 @@ if __name__ == "__main__":
|
||||
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
|
||||
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
|
||||
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
||||
cv2.imwrite('images/truck_both.jpg', painted_image)
|
||||
cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
|
||||
|
||||
# mask only ------------------------
|
||||
mode = 'mask'
|
||||
mask_input = logits[np.argmax(scores), :, :]
|
||||
|
||||
prompts = {'mask_input': mask_input[None, :, :]}
|
||||
|
||||
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
|
||||
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
|
||||
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
||||
cv2.imwrite('images/truck_mask.jpg', painted_image)
|
||||
cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
|
||||
|
||||
@@ -7,11 +7,13 @@ from typing import Union
|
||||
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
||||
import matplotlib.pyplot as plt
|
||||
import PIL
|
||||
from tools.mask_painter import mask_painter as mask_painter2
|
||||
from base_segmenter import BaseSegmenter
|
||||
from painter import mask_painter, point_painter
|
||||
from .mask_painter import mask_painter as mask_painter2
|
||||
from .base_segmenter import BaseSegmenter
|
||||
from .painter import mask_painter, point_painter
|
||||
import os
|
||||
import requests
|
||||
import sys
|
||||
|
||||
|
||||
mask_color = 3
|
||||
mask_alpha = 0.7
|
||||
@@ -24,7 +26,6 @@ point_radius = 15
|
||||
contour_color = 2
|
||||
contour_width = 5
|
||||
|
||||
|
||||
def download_checkpoint(url, folder, filename):
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
filepath = os.path.join(folder, filename)
|
||||
@@ -38,94 +39,184 @@ def download_checkpoint(url, folder, filename):
|
||||
|
||||
return filepath
|
||||
|
||||
|
||||
def initialize():
|
||||
'''
|
||||
initialize sam controler
|
||||
'''
|
||||
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||
folder = "segmenter"
|
||||
SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
|
||||
download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
|
||||
|
||||
|
||||
model_type = 'vit_h'
|
||||
device = "cuda:0"
|
||||
sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
||||
return sam_controler
|
||||
|
||||
|
||||
def seg_again(sam_controler, image: np.ndarray):
|
||||
'''
|
||||
it is used when interact in video
|
||||
'''
|
||||
sam_controler.reset_image()
|
||||
sam_controler.set_image(image)
|
||||
return
|
||||
|
||||
|
||||
def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
||||
'''
|
||||
it is used in first frame in video
|
||||
return: mask, logit, painted image(mask+point)
|
||||
'''
|
||||
sam_controler.set_image(image)
|
||||
prompts = {
|
||||
'point_coords': points,
|
||||
'point_labels': labels,
|
||||
}
|
||||
masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
|
||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
|
||||
assert len(points)==len(labels)
|
||||
|
||||
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
|
||||
return mask, logit, painted_image
|
||||
|
||||
def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
if same:
|
||||
class SamControler():
|
||||
def __init__(self):
|
||||
'''
|
||||
true; loop in the same image
|
||||
initialize sam controler
|
||||
'''
|
||||
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||
folder ="segmenter/checkpoints"
|
||||
SAM_checkpoint= 'sam_vit_h_4b8939.pth'
|
||||
SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
|
||||
# SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||
model_type = 'vit_h'
|
||||
device = "cuda:0"
|
||||
self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
||||
|
||||
|
||||
def seg_again(self, image: np.ndarray):
|
||||
'''
|
||||
it is used when interact in video
|
||||
'''
|
||||
self.sam_controler.reset_image()
|
||||
self.sam_controler.set_image(image)
|
||||
return
|
||||
|
||||
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
||||
'''
|
||||
it is used in first frame in video
|
||||
return: mask, logit, painted image(mask+point)
|
||||
'''
|
||||
self.sam_controler.set_image(image)
|
||||
prompts = {
|
||||
'point_coords': points,
|
||||
'point_labels': labels,
|
||||
'mask_input': logits[None, :, :]
|
||||
}
|
||||
masks, scores, logits = sam_controler.predict(prompts, 'both', multimask)
|
||||
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
|
||||
assert len(points)==len(labels)
|
||||
|
||||
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
|
||||
|
||||
return mask, logit, painted_image
|
||||
else:
|
||||
'''
|
||||
loop in the different image, interact in the video
|
||||
'''
|
||||
if image is None:
|
||||
raise('Image error')
|
||||
|
||||
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
if same:
|
||||
'''
|
||||
true; loop in the same image
|
||||
'''
|
||||
prompts = {
|
||||
'point_coords': points,
|
||||
'point_labels': labels,
|
||||
'mask_input': logits[None, :, :]
|
||||
}
|
||||
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
|
||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
|
||||
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
|
||||
return mask, logit, painted_image
|
||||
else:
|
||||
seg_again(sam_controler, image)
|
||||
prompts = {
|
||||
'point_coords': points,
|
||||
'point_labels': labels,
|
||||
}
|
||||
masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
|
||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
|
||||
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
'''
|
||||
loop in the different image, interact in the video
|
||||
'''
|
||||
if image is None:
|
||||
raise('Image error')
|
||||
else:
|
||||
self.seg_again(image)
|
||||
prompts = {
|
||||
'point_coords': points,
|
||||
'point_labels': labels,
|
||||
}
|
||||
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
|
||||
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
|
||||
return mask, logit, painted_image
|
||||
return mask, logit, painted_image
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# def initialize():
|
||||
# '''
|
||||
# initialize sam controler
|
||||
# '''
|
||||
# checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||
# folder = "segmenter"
|
||||
# SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
|
||||
# download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
|
||||
|
||||
|
||||
# model_type = 'vit_h'
|
||||
# device = "cuda:0"
|
||||
# sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
||||
# return sam_controler
|
||||
|
||||
|
||||
# def seg_again(sam_controler, image: np.ndarray):
|
||||
# '''
|
||||
# it is used when interact in video
|
||||
# '''
|
||||
# sam_controler.reset_image()
|
||||
# sam_controler.set_image(image)
|
||||
# return
|
||||
|
||||
|
||||
# def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
||||
# '''
|
||||
# it is used in first frame in video
|
||||
# return: mask, logit, painted image(mask+point)
|
||||
# '''
|
||||
# sam_controler.set_image(image)
|
||||
# prompts = {
|
||||
# 'point_coords': points,
|
||||
# 'point_labels': labels,
|
||||
# }
|
||||
# masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
|
||||
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
|
||||
# assert len(points)==len(labels)
|
||||
|
||||
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
# painted_image = Image.fromarray(painted_image)
|
||||
|
||||
# return mask, logit, painted_image
|
||||
|
||||
# def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
# if same:
|
||||
# '''
|
||||
# true; loop in the same image
|
||||
# '''
|
||||
# prompts = {
|
||||
# 'point_coords': points,
|
||||
# 'point_labels': labels,
|
||||
# 'mask_input': logits[None, :, :]
|
||||
# }
|
||||
# masks, scores, logits = sam_controler.predict(prompts, 'both', multimask)
|
||||
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
|
||||
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
# painted_image = Image.fromarray(painted_image)
|
||||
|
||||
# return mask, logit, painted_image
|
||||
# else:
|
||||
# '''
|
||||
# loop in the different image, interact in the video
|
||||
# '''
|
||||
# if image is None:
|
||||
# raise('Image error')
|
||||
# else:
|
||||
# seg_again(sam_controler, image)
|
||||
# prompts = {
|
||||
# 'point_coords': points,
|
||||
# 'point_labels': labels,
|
||||
# }
|
||||
# masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
|
||||
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
|
||||
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
# painted_image = Image.fromarray(painted_image)
|
||||
|
||||
# return mask, logit, painted_image
|
||||
|
||||
|
||||
|
||||
|
||||
12
track_anything.py
Normal file
12
track_anything.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from tools.interact_tools import SamControler
|
||||
from tracker.xmem import XMem
|
||||
|
||||
|
||||
|
||||
|
||||
class TrackingAnything():
    """Bundle the SAM controller (and, later, the XMem tracker) behind one object."""

    def __init__(self, cfg):
        """Store ``cfg`` and create the SAM controller.

        Args:
            cfg: configuration object; it is only stored here, not inspected.
        """
        self.cfg = cfg
        self.samcontroler = SamControler()
        # The original line was `self.xmem =` with no right-hand side, which
        # is a SyntaxError and made the whole module unimportable. Bind the
        # attribute to None until the tracker is wired up.
        # TODO: construct the XMem tracker here once its constructor
        # arguments are settled.
        self.xmem = None
|
||||
Reference in New Issue
Block a user