set the max memory usage 90 and add highlighted text -- li

This commit is contained in:
memoryunreal
2023-04-26 20:44:48 +00:00
parent 35850bdb73
commit 6d3925046a
2 changed files with 39 additions and 36 deletions

72
app.py
View File

@@ -16,6 +16,7 @@ import torch
from tools.interact_tools import SamControler from tools.interact_tools import SamControler
from tracker.base_tracker import BaseTracker from tracker.base_tracker import BaseTracker
from tools.painter import mask_painter from tools.painter import mask_painter
import psutil
try: try:
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
except: except:
@@ -69,6 +70,7 @@ def get_prompt(click_state, click_input):
} }
return prompt return prompt
# extract frames from upload video # extract frames from upload video
def get_frames_from_video(video_input, video_state): def get_frames_from_video(video_input, video_state):
""" """
@@ -80,13 +82,20 @@ def get_frames_from_video(video_input, video_state):
""" """
video_path = video_input video_path = video_input
frames = [] frames = []
operation_log = [("",""),("Upload video already. Try click the image for adding targets to track and inpaint.","Normal")]
try: try:
cap = cv2.VideoCapture(video_path) cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS) fps = cap.get(cv2.CAP_PROP_FPS)
while cap.isOpened(): while cap.isOpened():
ret, frame = cap.read() ret, frame = cap.read()
if ret == True: if ret == True:
current_memory_usage = psutil.virtual_memory().percent
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
if current_memory_usage > 90:
operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")]
print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
break
else: else:
break break
except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e: except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
@@ -103,11 +112,10 @@ def get_frames_from_video(video_input, video_state):
"fps": fps "fps": fps
} }
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size) video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
operation_log = "Upload video already. Try click the image for adding targets to track and inpaint."
model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0]) model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \ return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True),\
gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \
@@ -131,14 +139,14 @@ def select_template(image_selection_slider, video_state, interactive_state):
# update the masks when select a new template frame # update the masks when select a new template frame
# if video_state["masks"][image_selection_slider] is not None: # if video_state["masks"][image_selection_slider] is not None:
# video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider]) # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
operation_log = "Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider) operation_log = [("",""), ("Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider),"Normal")]
return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
# set the tracking end frame # set the tracking end frame
def get_end_number(track_pause_number_slider, video_state, interactive_state): def get_end_number(track_pause_number_slider, video_state, interactive_state):
interactive_state["track_end_number"] = track_pause_number_slider interactive_state["track_end_number"] = track_pause_number_slider
operation_log = "Set the tracking finish at frame {}".format(track_pause_number_slider) operation_log = [("",""),("Set the tracking finish at frame {}".format(track_pause_number_slider),"Normal")]
return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
@@ -177,30 +185,33 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
video_state["logits"][video_state["select_frame_number"]] = logit video_state["logits"][video_state["select_frame_number"]] = logit
video_state["painted_images"][video_state["select_frame_number"]] = painted_image video_state["painted_images"][video_state["select_frame_number"]] = painted_image
operation_log = "Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment" operation_log = [("",""), ("Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment","Normal")]
return painted_image, video_state, interactive_state, operation_log return painted_image, video_state, interactive_state, operation_log
def add_multi_mask(video_state, interactive_state, mask_dropdown): def add_multi_mask(video_state, interactive_state, mask_dropdown):
mask = video_state["masks"][video_state["select_frame_number"]] try:
interactive_state["multi_mask"]["masks"].append(mask) mask = video_state["masks"][video_state["select_frame_number"]]
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) interactive_state["multi_mask"]["masks"].append(mask)
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown) mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
operation_log = "Added a mask, use the mask select for target tracking or inpainting." operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
except:
operation_log = [("Please click the left image to generate mask.", "Error"), ("","")]
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
def clear_click(video_state, click_state): def clear_click(video_state, click_state):
click_state = [[],[]] click_state = [[],[]]
template_frame = video_state["origin_images"][video_state["select_frame_number"]] template_frame = video_state["origin_images"][video_state["select_frame_number"]]
operation_log = "Clear points history and refresh the image." operation_log = [("",""), ("Clear points history and refresh the image.","Normal")]
return template_frame, click_state, operation_log return template_frame, click_state, operation_log
def remove_multi_mask(interactive_state, mask_dropdown): def remove_multi_mask(interactive_state, mask_dropdown):
interactive_state["multi_mask"]["mask_names"]= [] interactive_state["multi_mask"]["mask_names"]= []
interactive_state["multi_mask"]["masks"] = [] interactive_state["multi_mask"]["masks"] = []
operation_log = "Remove all mask, please add new masks" operation_log = [("",""), ("Remove all mask, please add new masks","Normal")]
return interactive_state, gr.update(choices=[],value=[]), operation_log return interactive_state, gr.update(choices=[],value=[]), operation_log
def show_mask(video_state, interactive_state, mask_dropdown): def show_mask(video_state, interactive_state, mask_dropdown):
@@ -212,12 +223,12 @@ def show_mask(video_state, interactive_state, mask_dropdown):
mask = interactive_state["multi_mask"]["masks"][mask_number] mask = interactive_state["multi_mask"]["masks"][mask_number]
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2) select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
operation_log = "Select {} for tracking or inpainting".format(mask_dropdown) operation_log = [("",""), ("Select {} for tracking or inpainting".format(mask_dropdown),"Normal")]
return select_frame, operation_log return select_frame, operation_log
# tracking vos # tracking vos
def vos_tracking_video(video_state, interactive_state, mask_dropdown): def vos_tracking_video(video_state, interactive_state, mask_dropdown):
operation_log = "Track the selected masks, and then you can select the masks for inpainting." operation_log = [("",""), ("Track the selected masks, and then you can select the masks for inpainting.","Normal")]
model.xmem.clear_memory() model.xmem.clear_memory()
if interactive_state["track_end_number"]: if interactive_state["track_end_number"]:
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@@ -240,7 +251,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
# operation error # operation error
if len(np.unique(template_mask))==1: if len(np.unique(template_mask))==1:
template_mask[0][0]=1 template_mask[0][0]=1
operation_log = "Error! Please add at least one mask to track by clicking the left image." operation_log = [("Error! Please add at least one mask to track by clicking the left image.","Error"), ("","")]
# return video_output, video_state, interactive_state, operation_error # return video_output, video_state, interactive_state, operation_error
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask) masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
# clear GPU memory # clear GPU memory
@@ -284,7 +295,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
# inpaint # inpaint
def inpaint_video(video_state, interactive_state, mask_dropdown): def inpaint_video(video_state, interactive_state, mask_dropdown):
operation_log = "Removed the selected masks." operation_log = [("",""), ("Removed the selected masks.","Normal")]
frames = np.asarray(video_state["origin_images"]) frames = np.asarray(video_state["origin_images"])
fps = video_state["fps"] fps = video_state["fps"]
@@ -306,7 +317,7 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
try: try:
inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3 inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
except: except:
operation_log = "Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size." operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
inpainted_frames = video_state["origin_images"] inpainted_frames = video_state["origin_images"]
video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
@@ -364,7 +375,7 @@ folder ="./checkpoints"
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint) SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint) xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint) e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
# args.port = 12212 # args.port = 12214
# args.device = "cuda:2" # args.device = "cuda:2"
# args.mask_save = True # args.mask_save = True
@@ -439,13 +450,7 @@ with gr.Blocks() as iface:
label="Point Prompt", label="Point Prompt",
interactive=True, interactive=True,
visible=False) visible=False)
click_mode = gr.Radio( remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
choices=["Continuous", "Single"],
value="Continuous",
label="Clicking Mode",
interactive=True,
visible=False)
with gr.Row():
clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160) clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False) Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360) template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
@@ -453,13 +458,12 @@ with gr.Blocks() as iface:
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False) track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
with gr.Column(): with gr.Column():
run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=False)
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False) mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
video_output = gr.Video(autosize=True, visible=False).style(height=360) video_output = gr.Video(autosize=True, visible=False).style(height=360)
with gr.Row(): with gr.Row():
tracking_video_predict_button = gr.Button(value="Tracking", visible=False) tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False) inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
run_status = gr.Textbox(label="Operation log", visible=False)
# first step: get the video information # first step: get the video information
extract_frames_button.click( extract_frames_button.click(
@@ -468,7 +472,7 @@ with gr.Blocks() as iface:
video_input, video_state video_input, video_state
], ],
outputs=[video_state, video_info, template_frame, outputs=[video_state, video_info, template_frame,
image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame, image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame,
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status] tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status]
) )
@@ -551,7 +555,7 @@ with gr.Blocks() as iface:
[[],[]], [[],[]],
None, None,
None, None,
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False) gr.update(visible=False), gr.update(visible=False)
@@ -564,7 +568,7 @@ with gr.Blocks() as iface:
click_state, click_state,
video_output, video_output,
template_frame, template_frame,
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click, tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
], ],
queue=False, queue=False,
@@ -589,7 +593,5 @@ with gr.Blocks() as iface:
# cache_examples=True, # cache_examples=True,
) )
iface.queue(concurrency_count=1) iface.queue(concurrency_count=1)
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0") # iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
iface.launch(debug=True, enable_queue=True)

View File

@@ -13,4 +13,5 @@ matplotlib
pyyaml pyyaml
av av
openmim openmim
tqdm tqdm
psutil