negative mask input li

This commit is contained in:
memoryunreal
2023-04-14 09:13:21 +00:00
parent 5abfd0eb8e
commit b17b35ad75

15
app.py
View File

@@ -42,7 +42,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
# args, defined in track_anything.py # args, defined in track_anything.py
args = parse_augment() args = parse_augment()
args.port = 12213 args.port = 12212
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args) model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
@@ -109,10 +109,10 @@ def get_frames_from_video(video_input, play_state):
frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame] frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
# set image in sam when select the template frame # set image in sam when select the template frame
model.samcontroler.set_image(np.asarray(nearest_frame)) model.samcontroler.sam_controler.set_image(np.asarray(nearest_frame))
return frames, nearest_frame return frames, nearest_frame
def inference_all(template_frame, point_prompt, click_state, evt:gr.SelectData): def inference_all(template_frame, point_prompt, click_state, logit, evt:gr.SelectData):
""" """
Args: Args:
template_frame: PIL.Image template_frame: PIL.Image
@@ -136,7 +136,7 @@ def inference_all(template_frame, point_prompt, click_state, evt:gr.SelectData):
labels=np.array(prompt["input_label"]), labels=np.array(prompt["input_label"]),
multimask=prompt["multimask_output"] multimask=prompt["multimask_output"]
) )
return painted_image, click_state return painted_image, click_state, logit
with gr.Blocks() as iface: with gr.Blocks() as iface:
""" """
@@ -146,7 +146,7 @@ with gr.Blocks() as iface:
play_state = gr.State([]) play_state = gr.State([])
video_state = gr.State([[],[],[]]) video_state = gr.State([[],[],[]])
click_state = gr.State([[],[]]) click_state = gr.State([[],[]])
logits = gr.State([])
with gr.Row(): with gr.Row():
# for user video input # for user video input
@@ -226,7 +226,8 @@ with gr.Blocks() as iface:
fn=get_frames_from_video, fn=get_frames_from_video,
inputs=[ inputs=[
video_input, video_input,
play_state play_state,
logits
], ],
outputs=[video_state, template_frame], outputs=[video_state, template_frame],
) )
@@ -234,7 +235,7 @@ with gr.Blocks() as iface:
template_frame.select( template_frame.select(
fn=inference_all, fn=inference_all,
inputs=[ inputs=[
template_frame, point_prompt, click_state template_frame, point_prompt, click_state, logits
], ],
outputs=[ outputs=[
template_frame, click_state template_frame, click_state