multi-mask success -- li

This commit is contained in:
memoryunreal
2023-04-19 17:06:19 +00:00
parent c975670db3
commit 08bfb27bc1

36
app.py
View File

@@ -87,7 +87,10 @@ def get_frames_from_video(video_input, video_state):
"fps": fps
}
video_info = "Video Name: {}, FPS: {}, Total Frames: {}".format(video_state["video_name"], video_state["fps"], len(frames))
return video_state, video_info, gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=1), \
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=1), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \
@@ -148,9 +151,10 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
def add_multi_mask(video_state, interactive_state, mask_dropdown):
mask = video_state["masks"][video_state["select_frame_number"]]
interactive_state["multi_mask"]["masks"].append(mask)
interactive_state["multi_mask"]["mask_names"].append("mask_{}".format(len(interactive_state["multi_mask"]["masks"])))
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"])
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown)
def remove_multi_mask(interactive_state):
@@ -178,9 +182,10 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
if interactive_state["multi_mask"]["masks"]:
# if mask_dropdown:
if len(mask_dropdown) == 0:
mask_dropdown = ["mask_001"]
mask_dropdown.sort()
template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1]
template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
for i in range(1,len(mask_dropdown)):
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
@@ -255,9 +260,9 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
# args, defined in track_anything.py
args = parse_augment()
args.port = 12212
args.device = "cuda:1"
args.mask_save = True
# args.port = 12315
# args.device = "cuda:1"
# args.mask_save = True
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
@@ -331,7 +336,7 @@ with gr.Blocks() as iface:
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
with gr.Column(scale=0.4):
mask_dropdown = gr.Dropdown(multiselect=True, label="Mask_select", info=".", visible=False)
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask_select", info=".", visible=False)
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
video_output = gr.Video(autosize=True, visible=False).style(height=360)
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
@@ -342,7 +347,8 @@ with gr.Blocks() as iface:
inputs=[
video_input, video_state
],
outputs=[video_state, video_info, image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
outputs=[video_state, video_info, template_frame,
image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button]
)
@@ -410,13 +416,21 @@ with gr.Blocks() as iface:
},
"track_end_num": 0
},
[[],[]]
[[],[]],
None,
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False) \
),
[],
[
video_state,
interactive_state,
click_state,
template_frame,
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button
],
queue=False,
show_progress=False)