diff --git a/app.py b/app.py index a3d3fcd..5b66e92 100644 --- a/app.py +++ b/app.py @@ -16,6 +16,7 @@ import torch from tools.interact_tools import SamControler from tracker.base_tracker import BaseTracker from tools.painter import mask_painter +import psutil try: from mmcv.cnn import ConvModule except: @@ -69,6 +70,7 @@ def get_prompt(click_state, click_input): } return prompt + # extract frames from upload video def get_frames_from_video(video_input, video_state): """ @@ -80,13 +82,20 @@ def get_frames_from_video(video_input, video_state): """ video_path = video_input frames = [] + + operation_log = [("",""),("Upload video already. Try click the image for adding targets to track and inpaint.","Normal")] try: cap = cv2.VideoCapture(video_path) fps = cap.get(cv2.CAP_PROP_FPS) while cap.isOpened(): ret, frame = cap.read() if ret == True: + current_memory_usage = psutil.virtual_memory().percent frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + if current_memory_usage > 90: + operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")] + print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.") + break else: break except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e: @@ -103,11 +112,10 @@ def get_frames_from_video(video_input, video_state): "fps": fps } video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size) - operation_log = "Upload video already. Try click the image for adding targets to track and inpaint." model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.set_image(video_state["origin_images"][0]) return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \ - gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True),\ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \ @@ -131,14 +139,14 @@ def select_template(image_selection_slider, video_state, interactive_state): # update the masks when select a new template frame # if video_state["masks"][image_selection_slider] is not None: # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider]) - operation_log = "Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider) + operation_log = [("",""), ("Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider),"Normal")] return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log # set the tracking end frame def get_end_number(track_pause_number_slider, video_state, interactive_state): interactive_state["track_end_number"] = track_pause_number_slider - operation_log = "Set the tracking finish at frame {}".format(track_pause_number_slider) + operation_log = [("",""),("Set the tracking finish at frame {}".format(track_pause_number_slider),"Normal")] return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log @@ -177,30 +185,33 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr video_state["logits"][video_state["select_frame_number"]] = logit video_state["painted_images"][video_state["select_frame_number"]] = painted_image - operation_log = "Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment" + operation_log = [("",""), ("Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment","Normal")] return painted_image, video_state, interactive_state, operation_log def add_multi_mask(video_state, interactive_state, mask_dropdown): - mask = video_state["masks"][video_state["select_frame_number"]] - interactive_state["multi_mask"]["masks"].append(mask) - interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) - mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) - select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown) + try: + mask = video_state["masks"][video_state["select_frame_number"]] + interactive_state["multi_mask"]["masks"].append(mask) + interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) + mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) + select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown) - operation_log = "Added a mask, use the mask select for target tracking or inpainting." + operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")] + except: + operation_log = [("Please click the left image to generate mask.", "Error"), ("","")] return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log def clear_click(video_state, click_state): click_state = [[],[]] template_frame = video_state["origin_images"][video_state["select_frame_number"]] - operation_log = "Clear points history and refresh the image." + operation_log = [("",""), ("Clear points history and refresh the image.","Normal")] return template_frame, click_state, operation_log def remove_multi_mask(interactive_state, mask_dropdown): interactive_state["multi_mask"]["mask_names"]= [] interactive_state["multi_mask"]["masks"] = [] - operation_log = "Remove all mask, please add new masks" + operation_log = [("",""), ("Remove all mask, please add new masks","Normal")] return interactive_state, gr.update(choices=[],value=[]), operation_log def show_mask(video_state, interactive_state, mask_dropdown): @@ -212,12 +223,12 @@ def show_mask(video_state, interactive_state, mask_dropdown): mask = interactive_state["multi_mask"]["masks"][mask_number] select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2) - operation_log = "Select {} for tracking or inpainting".format(mask_dropdown) + operation_log = [("",""), ("Select {} for tracking or inpainting".format(mask_dropdown),"Normal")] return select_frame, operation_log # tracking vos def vos_tracking_video(video_state, interactive_state, mask_dropdown): - operation_log = "Track the selected masks, and then you can select the masks for inpainting." + operation_log = [("",""), ("Track the selected masks, and then you can select the masks for inpainting.","Normal")] model.xmem.clear_memory() if interactive_state["track_end_number"]: following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] @@ -240,7 +251,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown): # operation error if len(np.unique(template_mask))==1: template_mask[0][0]=1 - operation_log = "Error! Please add at least one mask to track by clicking the left image." + operation_log = [("Error! Please add at least one mask to track by clicking the left image.","Error"), ("","")] # return video_output, video_state, interactive_state, operation_error masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask) # clear GPU memory @@ -284,7 +295,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown): # inpaint def inpaint_video(video_state, interactive_state, mask_dropdown): - operation_log = "Removed the selected masks." + operation_log = [("",""), ("Removed the selected masks.","Normal")] frames = np.asarray(video_state["origin_images"]) fps = video_state["fps"] @@ -306,7 +317,7 @@ def inpaint_video(video_state, interactive_state, mask_dropdown): try: inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3 except: - operation_log = "Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size." + operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")] inpainted_frames = video_state["origin_images"] video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video @@ -364,7 +375,7 @@ folder ="./checkpoints" SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint) xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint) e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint) -# args.port = 12212 +# args.port = 12214 # args.device = "cuda:2" # args.mask_save = True @@ -439,13 +450,7 @@ with gr.Blocks() as iface: label="Point Prompt", interactive=True, visible=False) - click_mode = gr.Radio( - choices=["Continuous", "Single"], - value="Continuous", - label="Clicking Mode", - interactive=True, - visible=False) - with gr.Row(): + remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False) clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160) Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False) template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360) @@ -453,13 +458,12 @@ with gr.Blocks() as iface: track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False) with gr.Column(): + run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=False) mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False) - remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False) video_output = gr.Video(autosize=True, visible=False).style(height=360) with gr.Row(): tracking_video_predict_button = gr.Button(value="Tracking", visible=False) inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False) - run_status = gr.Textbox(label="Operation log", visible=False) # first step: get the video information extract_frames_button.click( @@ -468,7 +472,7 @@ with gr.Blocks() as iface: video_input, video_state ], outputs=[video_state, video_info, template_frame, - image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame, + image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status] ) @@ -551,7 +555,7 @@ with gr.Blocks() as iface: [[],[]], None, None, - gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ + gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False) @@ -564,7 +568,7 @@ with gr.Blocks() as iface: click_state, video_output, template_frame, - tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click, + tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status ], queue=False, @@ -589,7 +593,5 @@ with gr.Blocks() as iface: # cache_examples=True, ) iface.queue(concurrency_count=1) -iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0") - - - +# iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0") +iface.launch(debug=True, enable_queue=True) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 410219f..16ec760 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ matplotlib pyyaml av openmim -tqdm \ No newline at end of file +tqdm +psutil \ No newline at end of file