multiple click prompt realize - li

This commit is contained in:
memoryunreal
2023-04-14 08:24:57 +00:00
parent f4dde35968
commit 2a4289b150

49
app.py
View File

@@ -12,6 +12,7 @@ sys.path.append(sys.path[0]+"/tracker/model")
from track_anything import TrackingAnything
from track_anything import parse_augment
import requests
import json
def download_checkpoint(url, folder, filename):
os.makedirs(folder, exist_ok=True)
@@ -41,8 +42,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
# args, defined in track_anything.py
args = parse_augment()
args.port=12212
# args.port = 12213
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
@@ -60,13 +60,15 @@ def play_video(play_state):
return play_state
# convert points input to prompt state
def get_prompt(inputs, click_state):
points = []
labels = []
def get_prompt(click_state, click_input):
inputs = json.loads(click_input)
points = click_state[0]
labels = click_state[1]
for input in inputs:
points.append(input[:2])
labels.append(input[2])
click_state[0] = points
click_state[1] = labels
prompt = {
"prompt_type":["click"],
"input_point":click_state[0],
@@ -107,15 +109,34 @@ def get_frames_from_video(video_input, play_state):
frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
return frames, nearest_frame
def inference_all(template_frame, click_state, evt:gr.SelectData):
coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
def inference_all(template_frame, point_prompt, click_state, evt:gr.SelectData):
"""
Args:
template_frame: PIL.Image
point_prompt: flag for positive or negative button click
click_state: [[points], [labels]]
"""
if point_prompt == "Positive":
coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
else:
coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
# prompt for sam model
prompt = get_prompt(click_state=click_state, click_input=coordinate)
# default value
points = np.array([[evt.index[0],evt.index[1]]])
labels= np.array([1])
mask, logit, painted_image = model.inference_step(first_flag=True, interact_flag=False, image=np.asarray(template_frame), same_image_flag=False,points=points, labels=labels,logits=None,multimask=True)
return painted_image
# points = np.array([[evt.index[0],evt.index[1]]])
# labels= np.array([1])
mask, logit, painted_image = model.inference_step(first_flag=True,
interact_flag=False,
image=np.asarray(template_frame),
same_image_flag=False,
points=np.array(prompt["input_point"]),
labels=np.array(prompt["input_label"]),
logits=None,
multimask=prompt["multimask_output"]
)
return painted_image, click_state
with gr.Blocks() as iface:
"""
@@ -213,10 +234,10 @@ with gr.Blocks() as iface:
template_frame.select(
fn=inference_all,
inputs=[
template_frame
template_frame, point_prompt, click_state
],
outputs=[
template_frame
template_frame, click_state
]
)