gaomingqi
2023-04-14 12:37:46 +08:00
5 changed files with 136 additions and 48 deletions

app.py

@@ -5,11 +5,45 @@ import cv2
 import time
 from PIL import Image
 import numpy as np
+import os
+import sys
+sys.path.append(sys.path[0]+"/tracker")
+sys.path.append(sys.path[0]+"/tracker/model")
+from track_anything import TrackingAnything
+from track_anything import parse_augment
+import requests
+
+def download_checkpoint(url, folder, filename):
+    os.makedirs(folder, exist_ok=True)
+    filepath = os.path.join(folder, filename)
+    if not os.path.exists(filepath):
+        print("download checkpoints ......")
+        response = requests.get(url, stream=True)
+        with open(filepath, "wb") as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                if chunk:
+                    f.write(chunk)
+        print("download successfully!")
+    return filepath
+
-from tools.interact_tools import SamControler
+# check and download checkpoints if needed
+SAM_checkpoint = "sam_vit_h_4b8939.pth"
+sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+xmem_checkpoint = "XMem-s012.pth"
+xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
+folder = "./checkpoints"
+SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
+xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
+
-samc = SamControler()
+# args, defined in track_anything.py
+args = parse_augment()
+args.port = 12212
+
+model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
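Note on download_checkpoint above: the helper only hits the network when the target file is missing, so repeated app startups reuse the cached weights; there is no raise_for_status() check, though, so a failed or interrupted download can leave a partial file that later runs treat as a valid checkpoint. A minimal sketch of the intended idempotent behavior, using only names from this hunk:

    # first call downloads the weights into ./checkpoints;
    # the second finds the file on disk and returns immediately
    path = download_checkpoint(sam_checkpoint_url, folder, "sam_vit_h_4b8939.pth")
    path = download_checkpoint(sam_checkpoint_url, folder, "sam_vit_h_4b8939.pth")  # no re-download
    # path == "./checkpoints/sam_vit_h_4b8939.pth"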
@@ -25,6 +59,23 @@ def play_video(play_state):
     play_state.append(time.time())
     return play_state
 
+# convert points input to prompt state
+def get_prompt(inputs, click_state):
+    points = []
+    labels = []
+    for input in inputs:
+        points.append(input[:2])
+        labels.append(input[2])
+    click_state[0] = points
+    prompt = {
+        "prompt_type": ["click"],
+        "input_point": click_state[0],
+        "input_label": click_state[1],
+        "multimask_output": "True",
+    }
+    return prompt
+
 def get_frames_from_video(video_input, play_state):
     """
     Args:
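Note on get_prompt above: the labels list is collected locally but never written back, so prompt["input_label"] only reflects whatever click_state[1] already held. A minimal sketch of the behavior as committed (click triples assumed to be [x, y, label]):

    clicks = [[100, 200, 1], [300, 400, 0]]   # hypothetical clicks
    click_state = [[], []]
    prompt = get_prompt(clicks, click_state)
    # prompt["input_point"] == [[100, 200], [300, 400]]
    # prompt["input_label"] == []  (labels is computed but dropped)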
@@ -56,13 +107,28 @@ def get_frames_from_video(video_input, play_state):
     frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
     return frames, nearest_frame
 
+def inference_all(template_frame, evt: gr.SelectData):
+    coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
+    # default value
+    points = np.array([[evt.index[0], evt.index[1]]])
+    labels = np.array([1])
+    mask, logit, painted_image = model.inference_step(first_flag=True, interact_flag=False, image=np.asarray(template_frame), same_image_flag=False, points=points, labels=labels, logits=None, multimask=True)
+    return painted_image
+
 with gr.Blocks() as iface:
+    """
+    state for
+    """
     state = gr.State([])
     play_state = gr.State([])
     video_state = gr.State([[],[],[]])
+    click_state = gr.State([[],[]])
 
     with gr.Row():
+        # for user video input
         with gr.Column(scale=1.0):
             video_input = gr.Video().style(height=720)
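Note on inference_all above: in Gradio 3.x a .select handler receives the click through gr.SelectData, so evt.index[0] and evt.index[1] are the x and y pixel coordinates inside the displayed image; the function turns a single click into one positive SAM point, and the coordinate string it builds is currently unused. A sketch of the arrays produced for a hypothetical click at (120, 80):

    points = np.array([[120, 80]])  # one clicked pixel, (x, y)
    labels = np.array([1])          # 1 marks a positive (foreground) click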
@@ -70,26 +136,45 @@ with gr.Blocks() as iface:
             video_input.play(fn=play_video, inputs=play_state, outputs=play_state)
             video_input.pause(fn=pause_video, inputs=play_state, outputs=play_state)
 
             with gr.Row():
-                with gr.Column(scale=0.5):
-                    template_frame = gr.Image(type="pil", interactive=True, elem_id="template_frame")
-                    with gr.Column():
-                        template_select_button = gr.Button(value="Template select", interactive=True, variant="primary")
-                with gr.Column(scale=0.5):
-                    with gr.Row(scale=0.4):
-                        clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
-                        clear_button_image = gr.Button(value="Clear Image", interactive=True)
+                with gr.Row(scale=1):
+                    # put the template frame under the radio button
+                    with gr.Column(scale=0.5):
+                        # click points settings, negative or positive, mode continuous or single
+                        with gr.Row():
+                            with gr.Row(scale=0.5):
+                                point_prompt = gr.Radio(
+                                    choices=["Positive", "Negative"],
+                                    value="Positive",
+                                    label="Point Prompt",
+                                    interactive=True)
+                                click_mode = gr.Radio(
+                                    choices=["Continuous", "Single"],
+                                    value="Continuous",
+                                    label="Clicking Mode",
+                                    interactive=True)
+                            with gr.Row(scale=0.5):
+                                clear_button_clike = gr.Button(value="Clear Clicks", interactive=True).style(height=160)
+                                clear_button_image = gr.Button(value="Clear Image", interactive=True)
+                        template_frame = gr.Image(type="pil", interactive=True, elem_id="template_frame").style(height=360)
+                        with gr.Column():
+                            template_select_button = gr.Button(value="Template select", interactive=True, variant="primary")
+
+                    with gr.Column(scale=0.5):
+                        # for intermedia result check and correction
+                        intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
 
                         # seg_automask_video_points_per_batch = gr.Slider(
                         #     minimum=0,
                         #     maximum=64,
                         #     step=2,
                         #     value=64,
                         #     label="Points per Batch",
                         # )
                         seg_automask_video_predict = gr.Button(value="Generator")
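Note on the radios above: point_prompt and click_mode are declared here but not yet wired into the click handler; inference_all still hard-codes a positive label. A sketch of the mapping a follow-up would presumably need (the 1/0 convention matches SAM's point labels; the wiring itself is not in this diff, and point_prompt_value is a hypothetical name for the radio's string value inside a callback):

    # hypothetical: a gr.Radio value arrives in callbacks as its string choice
    label = 1 if point_prompt_value == "Positive" else 0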
@@ -125,6 +210,17 @@ with gr.Blocks() as iface:
         outputs=[video_state, template_frame],
     )
 
+    template_frame.select(
+        fn=inference_all,
+        inputs=[
+            template_frame
+        ],
+        outputs=[
+            template_frame
+        ]
+    )
+
     # clear
     # clear_button_clike.click(
     #     lambda x: ([[], [], []], x, ""),
@@ -149,7 +245,7 @@ with gr.Blocks() as iface:
     )
 
 iface.queue(concurrency_count=1)
-iface.launch(debug=True, enable_queue=True, server_port=12200, server_name="0.0.0.0")
+iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
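Note on the launch change: the port now comes from the CLI (--port, default 6080 in parse_augment), but app.py overrides it immediately after parsing, so the flag is currently shadowed:

    args = parse_augment()   # --port parsed here (default 6080)
    args.port = 12212        # ...then overridden unconditionally in app.py
    # iface.launch(server_port=args.port) therefore always binds 12212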

Binary file not shown (image, 2.8 MiB before and after).

tools/interact_tools.py

@@ -26,33 +26,14 @@ point_radius = 15
 contour_color = 2
 contour_width = 5
 
-def download_checkpoint(url, folder, filename):
-    os.makedirs(folder, exist_ok=True)
-    filepath = os.path.join(folder, filename)
-    if not os.path.exists(filepath):
-        print("download sam checkpoints ......")
-        response = requests.get(url, stream=True)
-        with open(filepath, "wb") as f:
-            for chunk in response.iter_content(chunk_size=8192):
-                if chunk:
-                    f.write(chunk)
-        print("download successfully!")
-    return filepath
-
 class SamControler():
-    def __init__(self, sam_checkpoint, model_type, device):
+    def __init__(self, SAM_checkpoint, model_type, device):
         '''
         initialize sam controler
         '''
-        checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
-        folder = "checkpoints"
-        SAM_checkpoint = 'sam_vit_h_4b8939.pth'
-        SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
-        # SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
-        model_type = 'vit_h'
-        device = "cuda:0"
         self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
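Note on this hunk: with the in-class download and the hard-coded 'vit_h' / "cuda:0" removed, SamControler no longer overwrites its constructor arguments; the caller now decides checkpoint, model type, and device. A minimal usage sketch (the checkpoint path is illustrative):

    samc = SamControler("./checkpoints/sam_vit_h_4b8939.pth", "vit_h", "cuda:0")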

track_anything.py

@@ -1,14 +1,15 @@
 from tools.interact_tools import SamControler
 from tracker.base_tracker import BaseTracker
 import numpy as np
+import argparse
 
 class TrackingAnything():
-    def __init__(self, cfg):
-        self.cfg = cfg
-        self.samcontroler = SamControler(cfg.sam_checkpoint, cfg.model_type, cfg.device)
-        self.xmem = BaseTracker(cfg.device, cfg.xmem_checkpoint)
+    def __init__(self, sam_checkpoint, xmem_checkpoint, args):
+        self.args = args
+        self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
+        self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
 
     def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
@@ -25,4 +26,14 @@ class TrackingAnything():
         return mask, logit, painted_image
 
+def parse_augment():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--device', type=str, default="cuda:0")
+    parser.add_argument('--sam_model_type', type=str, default="vit_h")
+    parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications")
+    parser.add_argument('--debug', action="store_true")
+    args = parser.parse_args()
+
+    if args.debug:
+        print(args)
+    return args
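Note on this file: TrackingAnything's constructor no longer takes a config object; the two checkpoints are passed explicitly and everything else rides on the argparse namespace. A minimal sketch of how the pieces compose (paths illustrative, matching the defaults wired up in app.py):

    args = parse_augment()  # e.g. launched as: python app.py --device cuda:0 --debug
    model = TrackingAnything("./checkpoints/sam_vit_h_4b8939.pth",
                             "./checkpoints/XMem-s012.pth", args)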

tracker/base_tracker.py

@@ -21,7 +21,7 @@ from tools.painter import mask_painter
 class BaseTracker:
-    def __init__(self, device, xmem_checkpoint) -> None:
+    def __init__(self, xmem_checkpoint, device) -> None:
         """
         device: model device
         xmem_checkpoint: checkpoint of XMem model
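Note on this signature swap: any caller still passing (device, checkpoint) positionally would break silently, since both arguments are strings. The updated call site in track_anything.py passes device by keyword, which is the safer pattern:

    # keyword argument survives future reorderings of __init__
    tracker = BaseTracker(xmem_checkpoint, device=args.device)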