set the max memory usage 90 and add highlighted text -- li

This commit is contained in:
memoryunreal
2023-04-26 20:44:48 +00:00
parent 35850bdb73
commit 6d3925046a
2 changed files with 39 additions and 36 deletions

72
app.py
View File

@@ -16,6 +16,7 @@ import torch
from tools.interact_tools import SamControler
from tracker.base_tracker import BaseTracker
from tools.painter import mask_painter
import psutil
try:
from mmcv.cnn import ConvModule
except:
@@ -69,6 +70,7 @@ def get_prompt(click_state, click_input):
}
return prompt
# extract frames from upload video
def get_frames_from_video(video_input, video_state):
"""
@@ -80,13 +82,20 @@ def get_frames_from_video(video_input, video_state):
"""
video_path = video_input
frames = []
operation_log = [("",""),("Upload video already. Try click the image for adding targets to track and inpaint.","Normal")]
try:
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
while cap.isOpened():
ret, frame = cap.read()
if ret == True:
current_memory_usage = psutil.virtual_memory().percent
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
if current_memory_usage > 90:
operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")]
print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
break
else:
break
except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
@@ -103,11 +112,10 @@ def get_frames_from_video(video_input, video_state):
"fps": fps
}
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
operation_log = "Upload video already. Try click the image for adding targets to track and inpaint."
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True),\
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \
@@ -131,14 +139,14 @@ def select_template(image_selection_slider, video_state, interactive_state):
# update the masks when select a new template frame
# if video_state["masks"][image_selection_slider] is not None:
# video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
operation_log = "Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider)
operation_log = [("",""), ("Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider),"Normal")]
return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
# set the tracking end frame
def get_end_number(track_pause_number_slider, video_state, interactive_state):
interactive_state["track_end_number"] = track_pause_number_slider
operation_log = "Set the tracking finish at frame {}".format(track_pause_number_slider)
operation_log = [("",""),("Set the tracking finish at frame {}".format(track_pause_number_slider),"Normal")]
return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
@@ -177,30 +185,33 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
video_state["logits"][video_state["select_frame_number"]] = logit
video_state["painted_images"][video_state["select_frame_number"]] = painted_image
operation_log = "Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment"
operation_log = [("",""), ("Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment","Normal")]
return painted_image, video_state, interactive_state, operation_log
def add_multi_mask(video_state, interactive_state, mask_dropdown):
mask = video_state["masks"][video_state["select_frame_number"]]
interactive_state["multi_mask"]["masks"].append(mask)
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
try:
mask = video_state["masks"][video_state["select_frame_number"]]
interactive_state["multi_mask"]["masks"].append(mask)
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
operation_log = "Added a mask, use the mask select for target tracking or inpainting."
operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
except:
operation_log = [("Please click the left image to generate mask.", "Error"), ("","")]
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
def clear_click(video_state, click_state):
click_state = [[],[]]
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
operation_log = "Clear points history and refresh the image."
operation_log = [("",""), ("Clear points history and refresh the image.","Normal")]
return template_frame, click_state, operation_log
def remove_multi_mask(interactive_state, mask_dropdown):
interactive_state["multi_mask"]["mask_names"]= []
interactive_state["multi_mask"]["masks"] = []
operation_log = "Remove all mask, please add new masks"
operation_log = [("",""), ("Remove all mask, please add new masks","Normal")]
return interactive_state, gr.update(choices=[],value=[]), operation_log
def show_mask(video_state, interactive_state, mask_dropdown):
@@ -212,12 +223,12 @@ def show_mask(video_state, interactive_state, mask_dropdown):
mask = interactive_state["multi_mask"]["masks"][mask_number]
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
operation_log = "Select {} for tracking or inpainting".format(mask_dropdown)
operation_log = [("",""), ("Select {} for tracking or inpainting".format(mask_dropdown),"Normal")]
return select_frame, operation_log
# tracking vos
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
operation_log = "Track the selected masks, and then you can select the masks for inpainting."
operation_log = [("",""), ("Track the selected masks, and then you can select the masks for inpainting.","Normal")]
model.xmem.clear_memory()
if interactive_state["track_end_number"]:
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@@ -240,7 +251,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
# operation error
if len(np.unique(template_mask))==1:
template_mask[0][0]=1
operation_log = "Error! Please add at least one mask to track by clicking the left image."
operation_log = [("Error! Please add at least one mask to track by clicking the left image.","Error"), ("","")]
# return video_output, video_state, interactive_state, operation_error
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
# clear GPU memory
@@ -284,7 +295,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
# inpaint
def inpaint_video(video_state, interactive_state, mask_dropdown):
operation_log = "Removed the selected masks."
operation_log = [("",""), ("Removed the selected masks.","Normal")]
frames = np.asarray(video_state["origin_images"])
fps = video_state["fps"]
@@ -306,7 +317,7 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
try:
inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
except:
operation_log = "Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size."
operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
inpainted_frames = video_state["origin_images"]
video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
@@ -364,7 +375,7 @@ folder ="./checkpoints"
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
# args.port = 12212
# args.port = 12214
# args.device = "cuda:2"
# args.mask_save = True
@@ -439,13 +450,7 @@ with gr.Blocks() as iface:
label="Point Prompt",
interactive=True,
visible=False)
click_mode = gr.Radio(
choices=["Continuous", "Single"],
value="Continuous",
label="Clicking Mode",
interactive=True,
visible=False)
with gr.Row():
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
@@ -453,13 +458,12 @@ with gr.Blocks() as iface:
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
with gr.Column():
run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=False)
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
video_output = gr.Video(autosize=True, visible=False).style(height=360)
with gr.Row():
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
run_status = gr.Textbox(label="Operation log", visible=False)
# first step: get the video information
extract_frames_button.click(
@@ -468,7 +472,7 @@ with gr.Blocks() as iface:
video_input, video_state
],
outputs=[video_state, video_info, template_frame,
image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame,
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status]
)
@@ -551,7 +555,7 @@ with gr.Blocks() as iface:
[[],[]],
None,
None,
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False)
@@ -564,7 +568,7 @@ with gr.Blocks() as iface:
click_state,
video_output,
template_frame,
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
],
queue=False,
@@ -589,7 +593,5 @@ with gr.Blocks() as iface:
# cache_examples=True,
)
iface.queue(concurrency_count=1)
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
# iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
iface.launch(debug=True, enable_queue=True)

View File

@@ -13,4 +13,5 @@ matplotlib
pyyaml
av
openmim
tqdm
tqdm
psutil