mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 00:17:50 +01:00
Implement multiple click prompts - li
This commit is contained in:
49
app.py
49
app.py
@@ -12,6 +12,7 @@ sys.path.append(sys.path[0]+"/tracker/model")
|
||||
from track_anything import TrackingAnything
|
||||
from track_anything import parse_augment
|
||||
import requests
|
||||
import json
|
||||
|
||||
def download_checkpoint(url, folder, filename):
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
@@ -41,8 +42,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
|
||||
|
||||
# args, defined in track_anything.py
|
||||
args = parse_augment()
|
||||
args.port=12212
|
||||
|
||||
# args.port = 12213
|
||||
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
||||
|
||||
|
||||
@@ -60,13 +60,15 @@ def play_video(play_state):
|
||||
return play_state
|
||||
|
||||
# convert points input to prompt state
|
||||
def get_prompt(inputs, click_state):
|
||||
points = []
|
||||
labels = []
|
||||
def get_prompt(click_state, click_input):
|
||||
inputs = json.loads(click_input)
|
||||
points = click_state[0]
|
||||
labels = click_state[1]
|
||||
for input in inputs:
|
||||
points.append(input[:2])
|
||||
labels.append(input[2])
|
||||
click_state[0] = points
|
||||
click_state[1] = labels
|
||||
prompt = {
|
||||
"prompt_type":["click"],
|
||||
"input_point":click_state[0],
|
||||
@@ -107,15 +109,34 @@ def get_frames_from_video(video_input, play_state):
|
||||
frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
|
||||
return frames, nearest_frame
|
||||
|
||||
def inference_all(template_frame, click_state, evt:gr.SelectData):
|
||||
coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
|
||||
def inference_all(template_frame, point_prompt, click_state, evt:gr.SelectData):
|
||||
"""
|
||||
Args:
|
||||
template_frame: PIL.Image
|
||||
point_prompt: flag for positive or negative button click
|
||||
click_state: [[points], [labels]]
|
||||
"""
|
||||
if point_prompt == "Positive":
|
||||
coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
|
||||
else:
|
||||
coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
|
||||
|
||||
# prompt for sam model
|
||||
prompt = get_prompt(click_state=click_state, click_input=coordinate)
|
||||
|
||||
# default value
|
||||
points = np.array([[evt.index[0],evt.index[1]]])
|
||||
labels= np.array([1])
|
||||
mask, logit, painted_image = model.inference_step(first_flag=True, interact_flag=False, image=np.asarray(template_frame), same_image_flag=False,points=points, labels=labels,logits=None,multimask=True)
|
||||
return painted_image
|
||||
|
||||
# points = np.array([[evt.index[0],evt.index[1]]])
|
||||
# labels= np.array([1])
|
||||
mask, logit, painted_image = model.inference_step(first_flag=True,
|
||||
interact_flag=False,
|
||||
image=np.asarray(template_frame),
|
||||
same_image_flag=False,
|
||||
points=np.array(prompt["input_point"]),
|
||||
labels=np.array(prompt["input_label"]),
|
||||
logits=None,
|
||||
multimask=prompt["multimask_output"]
|
||||
)
|
||||
return painted_image, click_state
|
||||
|
||||
with gr.Blocks() as iface:
|
||||
"""
|
||||
@@ -213,10 +234,10 @@ with gr.Blocks() as iface:
|
||||
template_frame.select(
|
||||
fn=inference_all,
|
||||
inputs=[
|
||||
template_frame
|
||||
template_frame, point_prompt, click_state
|
||||
],
|
||||
outputs=[
|
||||
template_frame
|
||||
template_frame, click_state
|
||||
]
|
||||
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user