Default parameters for template image segmentation for the first test - watchtower

memoryunreal
2023-04-14 02:27:39 +00:00
parent d3f737ede5
commit 2fb43bf75e
4 changed files with 93 additions and 31 deletions

app.py

@@ -5,11 +5,45 @@ import cv2
import time
from PIL import Image
import numpy as np
import os
import sys
sys.path.append(sys.path[0]+"/tracker")
sys.path.append(sys.path[0]+"/tracker/model")
from track_anything import TrackingAnything
from track_anything import parse_augment
import requests
def download_checkpoint(url, folder, filename):
    # download a checkpoint once and cache it locally
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)
    if not os.path.exists(filepath):
        print("Downloading checkpoints ...")
        response = requests.get(url, stream=True)
        with open(filepath, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
        print("Downloaded successfully!")
    return filepath
from tools.interact_tools import SamControler
# check and download checkpoints if needed
SAM_checkpoint = "sam_vit_h_4b8939.pth"
sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
xmem_checkpoint = "XMem-s012.pth"
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
folder = "./checkpoints"
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
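Because download_checkpoint checks for an existing file before fetching, these two calls are idempotent: the weights are downloaded on the first launch and simply reused on every launch after that.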
samc = SamControler()
# args, defined in track_anything.py
args = parse_augment()
args.port=12212
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
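The args.port override here is tied to the actual behavior change in this commit: the final iface.launch call at the end of the diff now reads the port from args instead of hard-coding 12200, so the app serves on 12212 by default.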
@@ -25,6 +59,22 @@ def play_video(play_state):
    play_state.append(time.time())
    return play_state
# convert points input to prompt state
def get_prompt(inputs, click_state):
    points = []
    labels = []
    for input in inputs:
        points.append(input[:2])
        labels.append(input[2])
    click_state[0] = points
    click_state[1] = labels  # store the collected labels; the prompt below reads them from click_state
    prompt = {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        "multimask_output": "True",
    }
    return prompt
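As an illustration, one positive click recorded as [x, y, label] flows through get_prompt as follows (a minimal sketch; the [points, labels] shape of click_state is assumed from the code above):

# hypothetical usage: one positive click at pixel (100, 200)
clicks = [[100, 200, 1]]   # each entry is [x, y, label]; label 1 marks a foreground click
click_state = [[], []]     # [points, labels], as maintained by the UI state
prompt = get_prompt(clicks, click_state)
# prompt == {"prompt_type": ["click"], "input_point": [[100, 200]],
#            "input_label": [1], "multimask_output": "True"}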
def get_frames_from_video(video_input, play_state):
    """
    Args:
@@ -56,6 +106,15 @@ def get_frames_from_video(video_input, play_state):
    frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
    return frames, nearest_frame
def inference_all(template_frame, evt: gr.SelectData):
    coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])  # string form of the click, currently unused
    # default value: a single positive click at the selected pixel
    points = np.array([[evt.index[0], evt.index[1]]])
    labels = np.array([1])
    mask, logit, painted_image = model.inference_step(
        first_flag=True,
        interact_flag=False,
        image=np.asarray(template_frame),
        same_image_flag=False,
        points=points,
        labels=labels,
        logits=None,
        multimask=True,
    )
    return painted_image
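The click position reaches inference_all through Gradio's SelectData event payload: when an Image component's select event fires, evt.index holds the (x, y) pixel that was clicked, which is why the wiring below lists only template_frame as an input. A minimal standalone sketch of the pattern (component and function names here are illustrative, not from the repo):

import gradio as gr

def on_click(image, evt: gr.SelectData):
    # evt.index is the (x, y) pixel position of the click on the image
    x, y = evt.index[0], evt.index[1]
    print(f"clicked at ({x}, {y})")
    return image

with gr.Blocks() as demo:
    img = gr.Image()
    # the event data is injected automatically; only the image is declared as input
    img.select(fn=on_click, inputs=[img], outputs=[img])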
with gr.Blocks() as iface:
    state = gr.State([])
@@ -125,6 +184,17 @@ with gr.Blocks() as iface:
        outputs=[video_state, template_frame],
    )
    template_frame.select(
        fn=inference_all,
        inputs=[
            template_frame
        ],
        outputs=[
            template_frame
        ]
    )
    # clear
    # clear_button_clike.click(
    #     lambda x: ([[], [], []], x, ""),
@@ -149,7 +219,7 @@ with gr.Blocks() as iface:
    )
iface.queue(concurrency_count=1)
-iface.launch(debug=True, enable_queue=True, server_port=12200, server_name="0.0.0.0")
+iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
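The new launch line only requires that the parsed arguments carry a port attribute; since app.py assigns args.port = 12212 explicitly, any argparse.Namespace would do, though the repo's parse_augment in track_anything.py presumably defines a --port flag. A hypothetical minimal stand-in, shown only to make the data flow concrete:

import argparse

def parse_augment():
    # hypothetical minimal version; the real parser in track_anything.py
    # defines additional options
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=12212, help="Gradio server port")
    return parser.parse_args()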