diff --git a/app.py b/app.py index ade5c50..2a601af 100644 --- a/app.py +++ b/app.py @@ -103,7 +103,7 @@ def get_frames_from_video(video_input, video_state): "fps": fps } video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size) - + operation_log = "Upload video already. Try click the image for adding targets to track and inpaint." model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.set_image(video_state["origin_images"][0]) return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \ @@ -111,7 +111,8 @@ def get_frames_from_video(video_input, video_state): gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \ - gr.update(visible=True), gr.update(visible=True) + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True, value=operation_log) def run_example(example): return video_input @@ -130,15 +131,16 @@ def select_template(image_selection_slider, video_state, interactive_state): # update the masks when select a new template frame # if video_state["masks"][image_selection_slider] is not None: # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider]) + operation_log = "Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider) - - return video_state["painted_images"][image_selection_slider], video_state, interactive_state + return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log # set the tracking end frame def get_end_number(track_pause_number_slider, video_state, interactive_state): interactive_state["track_end_number"] = track_pause_number_slider + operation_log = "Set the tracking finish at frame {}".format(track_pause_number_slider) - return video_state["painted_images"][track_pause_number_slider],interactive_state + return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log def get_resize_ratio(resize_ratio_slider, interactive_state): interactive_state["resize_ratio"] = resize_ratio_slider @@ -161,6 +163,8 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr interactive_state["negative_click_times"] += 1 # prompt for sam model + model.samcontroler.sam_controler.reset_image() + model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]]) prompt = get_prompt(click_state=click_state, click_input=coordinate) mask, logit, painted_image = model.first_frame_click( @@ -173,25 +177,31 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr video_state["logits"][video_state["select_frame_number"]] = logit video_state["painted_images"][video_state["select_frame_number"]] = painted_image - return painted_image, video_state, interactive_state + operation_log = "Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment" + return painted_image, video_state, interactive_state, operation_log def add_multi_mask(video_state, interactive_state, mask_dropdown): mask = video_state["masks"][video_state["select_frame_number"]] interactive_state["multi_mask"]["masks"].append(mask) interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) - select_frame = show_mask(video_state, interactive_state, mask_dropdown) - return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]] + select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown) + + operation_log = "Added a mask, use the mask select for target tracking or inpainting." + return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log def clear_click(video_state, click_state): click_state = [[],[]] template_frame = video_state["origin_images"][video_state["select_frame_number"]] - return template_frame, click_state + operation_log = "Clear points history and refresh the image." + return template_frame, click_state, operation_log -def remove_multi_mask(interactive_state): +def remove_multi_mask(interactive_state, mask_dropdown): interactive_state["multi_mask"]["mask_names"]= [] interactive_state["multi_mask"]["masks"] = [] - return interactive_state + + operation_log = "Remove all mask, please add new masks" + return interactive_state, gr.update(choices=[],value=[]), operation_log def show_mask(video_state, interactive_state, mask_dropdown): mask_dropdown.sort() @@ -201,12 +211,13 @@ def show_mask(video_state, interactive_state, mask_dropdown): mask_number = int(mask_dropdown[i].split("_")[1]) - 1 mask = interactive_state["multi_mask"]["masks"][mask_number] select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2) - - return select_frame + + operation_log = "Select {} for tracking or inpainting".format(mask_dropdown) + return select_frame, operation_log # tracking vos def vos_tracking_video(video_state, interactive_state, mask_dropdown): - + operation_log = "Track the selected masks, and then you can select the masks for inpainting." model.xmem.clear_memory() if interactive_state["track_end_number"]: following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] @@ -225,6 +236,12 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown): else: template_mask = video_state["masks"][video_state["select_frame_number"]] fps = video_state["fps"] + + # operation error + if len(np.unique(template_mask))==1: + template_mask[0][0]=1 + operation_log = "Error! Please add at least one mask to track by clicking the left image." + # return video_output, video_state, interactive_state, operation_error masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask) # clear GPU memory model.xmem.clear_memory() @@ -257,7 +274,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown): i+=1 # save_mask(video_state["masks"], video_state["video_name"]) #### shanggao code for mask save - return video_output, video_state, interactive_state + return video_output, video_state, interactive_state, operation_log # extracting masks from mask_dropdown # def extract_sole_mask(video_state, mask_dropdown): @@ -267,6 +284,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown): # inpaint def inpaint_video(video_state, interactive_state, mask_dropdown): + operation_log = "Removed the selected masks." frames = np.asarray(video_state["origin_images"]) fps = video_state["fps"] @@ -284,10 +302,15 @@ def inpaint_video(video_state, interactive_state, mask_dropdown): continue inpaint_masks[inpaint_masks==i] = 0 # inpaint for videos - inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3 + + try: + inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3 + except: + operation_log = "Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size." + inpainted_frames = video_state["origin_images"] video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video - return video_output + return video_output, operation_log # generate video after vos inference @@ -341,8 +364,8 @@ folder ="./checkpoints" SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint) xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint) e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint) -# args.port = 12315 -# args.device = "cuda:2" +args.port = 12211 +args.device = "cuda:2" # args.mask_save = True # initialize sam, xmem, e2fgvi models @@ -395,8 +418,8 @@ with gr.Blocks() as iface: video_input = gr.Video(autosize=True) with gr.Column(): video_info = gr.Textbox() - resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \ - Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.") + resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \ + Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.") resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True) @@ -436,6 +459,7 @@ with gr.Blocks() as iface: with gr.Row(): tracking_video_predict_button = gr.Button(value="Tracking", visible=False) inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False) + run_status = gr.Textbox(label="Operation log", visible=False) # first step: get the video information extract_frames_button.click( @@ -445,16 +469,16 @@ with gr.Blocks() as iface: ], outputs=[video_state, video_info, template_frame, image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame, - tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button] + tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status] ) # second step: select images from slider image_selection_slider.release(fn=select_template, inputs=[image_selection_slider, video_state, interactive_state], - outputs=[template_frame, video_state, interactive_state], api_name="select_image") + outputs=[template_frame, video_state, interactive_state, run_status], api_name="select_image") track_pause_number_slider.release(fn=get_end_number, inputs=[track_pause_number_slider, video_state, interactive_state], - outputs=[template_frame, interactive_state], api_name="end_image") + outputs=[template_frame, interactive_state, run_status], api_name="end_image") resize_ratio_slider.release(fn=get_resize_ratio, inputs=[resize_ratio_slider, interactive_state], outputs=[interactive_state], api_name="resize_ratio") @@ -463,41 +487,41 @@ with gr.Blocks() as iface: template_frame.select( fn=sam_refine, inputs=[video_state, point_prompt, click_state, interactive_state], - outputs=[template_frame, video_state, interactive_state] + outputs=[template_frame, video_state, interactive_state, run_status] ) # add different mask Add_mask_button.click( fn=add_multi_mask, inputs=[video_state, interactive_state, mask_dropdown], - outputs=[interactive_state, mask_dropdown, template_frame, click_state] + outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status] ) remove_mask_button.click( fn=remove_multi_mask, - inputs=[interactive_state], - outputs=[interactive_state] + inputs=[interactive_state, mask_dropdown], + outputs=[interactive_state, mask_dropdown, run_status] ) # tracking video from select image and mask tracking_video_predict_button.click( fn=vos_tracking_video, inputs=[video_state, interactive_state, mask_dropdown], - outputs=[video_output, video_state, interactive_state] + outputs=[video_output, video_state, interactive_state, run_status] ) # inpaint video from select image and mask inpaint_video_predict_button.click( fn=inpaint_video, inputs=[video_state, interactive_state, mask_dropdown], - outputs=[video_output] + outputs=[video_output, run_status] ) # click to get mask mask_dropdown.change( fn=show_mask, inputs=[video_state, interactive_state, mask_dropdown], - outputs=[template_frame] + outputs=[template_frame, run_status] ) # clear input @@ -529,7 +553,8 @@ with gr.Blocks() as iface: None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ - gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), gr.update(visible=False) \ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \ + gr.update(visible=False), gr.update(visible=False) ), [], @@ -540,7 +565,7 @@ with gr.Blocks() as iface: video_output, template_frame, tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click, - Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button + Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status ], queue=False, show_progress=False) @@ -549,7 +574,7 @@ with gr.Blocks() as iface: clear_button_click.click( fn = clear_click, inputs = [video_state, click_state,], - outputs = [template_frame,click_state], + outputs = [template_frame,click_state, run_status], ) # set example gr.Markdown("## Examples")