mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
negative mask input li
This commit is contained in:
15
app.py
15
app.py
@@ -42,7 +42,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
|
||||
|
||||
# args, defined in track_anything.py
|
||||
args = parse_augment()
|
||||
args.port = 12213
|
||||
args.port = 12212
|
||||
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
||||
|
||||
|
||||
@@ -109,10 +109,10 @@ def get_frames_from_video(video_input, play_state):
|
||||
frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
|
||||
|
||||
# set image in sam when select the template frame
|
||||
model.samcontroler.set_image(np.asarray(nearest_frame))
|
||||
model.samcontroler.sam_controler.set_image(np.asarray(nearest_frame))
|
||||
return frames, nearest_frame
|
||||
|
||||
def inference_all(template_frame, point_prompt, click_state, evt:gr.SelectData):
|
||||
def inference_all(template_frame, point_prompt, click_state, logit, evt:gr.SelectData):
|
||||
"""
|
||||
Args:
|
||||
template_frame: PIL.Image
|
||||
@@ -136,7 +136,7 @@ def inference_all(template_frame, point_prompt, click_state, evt:gr.SelectData):
|
||||
labels=np.array(prompt["input_label"]),
|
||||
multimask=prompt["multimask_output"]
|
||||
)
|
||||
return painted_image, click_state
|
||||
return painted_image, click_state, logit
|
||||
|
||||
with gr.Blocks() as iface:
|
||||
"""
|
||||
@@ -146,7 +146,7 @@ with gr.Blocks() as iface:
|
||||
play_state = gr.State([])
|
||||
video_state = gr.State([[],[],[]])
|
||||
click_state = gr.State([[],[]])
|
||||
|
||||
logits = gr.State([])
|
||||
with gr.Row():
|
||||
|
||||
# for user video input
|
||||
@@ -226,7 +226,8 @@ with gr.Blocks() as iface:
|
||||
fn=get_frames_from_video,
|
||||
inputs=[
|
||||
video_input,
|
||||
play_state
|
||||
play_state,
|
||||
logits
|
||||
],
|
||||
outputs=[video_state, template_frame],
|
||||
)
|
||||
@@ -234,7 +235,7 @@ with gr.Blocks() as iface:
|
||||
template_frame.select(
|
||||
fn=inference_all,
|
||||
inputs=[
|
||||
template_frame, point_prompt, click_state
|
||||
template_frame, point_prompt, click_state, logits
|
||||
],
|
||||
outputs=[
|
||||
template_frame, click_state
|
||||
|
||||
Reference in New Issue
Block a user