This commit is contained in:
gaomingqi
2023-04-26 17:14:42 +08:00

93
app.py
View File

@@ -103,7 +103,7 @@ def get_frames_from_video(video_input, video_state):
"fps": fps
}
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
operation_log = "Upload video already. Try click the image for adding targets to track and inpaint."
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
@@ -111,7 +111,8 @@ def get_frames_from_video(video_input, video_state):
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True)
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True, value=operation_log)
def run_example(example):
return video_input
@@ -130,15 +131,16 @@ def select_template(image_selection_slider, video_state, interactive_state):
# update the masks when select a new template frame
# if video_state["masks"][image_selection_slider] is not None:
# video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
operation_log = "Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider)
return video_state["painted_images"][image_selection_slider], video_state, interactive_state
return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
# set the tracking end frame
def get_end_number(track_pause_number_slider, video_state, interactive_state):
interactive_state["track_end_number"] = track_pause_number_slider
operation_log = "Set the tracking finish at frame {}".format(track_pause_number_slider)
return video_state["painted_images"][track_pause_number_slider],interactive_state
return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
def get_resize_ratio(resize_ratio_slider, interactive_state):
interactive_state["resize_ratio"] = resize_ratio_slider
@@ -161,6 +163,8 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
interactive_state["negative_click_times"] += 1
# prompt for sam model
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
prompt = get_prompt(click_state=click_state, click_input=coordinate)
mask, logit, painted_image = model.first_frame_click(
@@ -173,25 +177,31 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
video_state["logits"][video_state["select_frame_number"]] = logit
video_state["painted_images"][video_state["select_frame_number"]] = painted_image
return painted_image, video_state, interactive_state
operation_log = "Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment"
return painted_image, video_state, interactive_state, operation_log
def add_multi_mask(video_state, interactive_state, mask_dropdown):
mask = video_state["masks"][video_state["select_frame_number"]]
interactive_state["multi_mask"]["masks"].append(mask)
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
select_frame = show_mask(video_state, interactive_state, mask_dropdown)
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
operation_log = "Added a mask, use the mask select for target tracking or inpainting."
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
def clear_click(video_state, click_state):
click_state = [[],[]]
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
return template_frame, click_state
operation_log = "Clear points history and refresh the image."
return template_frame, click_state, operation_log
def remove_multi_mask(interactive_state):
def remove_multi_mask(interactive_state, mask_dropdown):
interactive_state["multi_mask"]["mask_names"]= []
interactive_state["multi_mask"]["masks"] = []
return interactive_state
operation_log = "Remove all mask, please add new masks"
return interactive_state, gr.update(choices=[],value=[]), operation_log
def show_mask(video_state, interactive_state, mask_dropdown):
mask_dropdown.sort()
@@ -201,12 +211,13 @@ def show_mask(video_state, interactive_state, mask_dropdown):
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
mask = interactive_state["multi_mask"]["masks"][mask_number]
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
return select_frame
operation_log = "Select {} for tracking or inpainting".format(mask_dropdown)
return select_frame, operation_log
# tracking vos
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
operation_log = "Track the selected masks, and then you can select the masks for inpainting."
model.xmem.clear_memory()
if interactive_state["track_end_number"]:
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@@ -225,6 +236,12 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
else:
template_mask = video_state["masks"][video_state["select_frame_number"]]
fps = video_state["fps"]
# operation error
if len(np.unique(template_mask))==1:
template_mask[0][0]=1
operation_log = "Error! Please add at least one mask to track by clicking the left image."
# return video_output, video_state, interactive_state, operation_error
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
# clear GPU memory
model.xmem.clear_memory()
@@ -257,7 +274,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
i+=1
# save_mask(video_state["masks"], video_state["video_name"])
#### shanggao code for mask save
return video_output, video_state, interactive_state
return video_output, video_state, interactive_state, operation_log
# extracting masks from mask_dropdown
# def extract_sole_mask(video_state, mask_dropdown):
@@ -267,6 +284,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
# inpaint
def inpaint_video(video_state, interactive_state, mask_dropdown):
operation_log = "Removed the selected masks."
frames = np.asarray(video_state["origin_images"])
fps = video_state["fps"]
@@ -284,10 +302,15 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
continue
inpaint_masks[inpaint_masks==i] = 0
# inpaint for videos
inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
try:
inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
except:
operation_log = "Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size."
inpainted_frames = video_state["origin_images"]
video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
return video_output
return video_output, operation_log
# generate video after vos inference
@@ -341,8 +364,8 @@ folder ="./checkpoints"
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
# args.port = 12315
# args.device = "cuda:2"
args.port = 12211
args.device = "cuda:2"
# args.mask_save = True
# initialize sam, xmem, e2fgvi models
@@ -395,8 +418,8 @@ with gr.Blocks() as iface:
video_input = gr.Video(autosize=True)
with gr.Column():
video_info = gr.Textbox()
resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.")
resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
@@ -436,6 +459,7 @@ with gr.Blocks() as iface:
with gr.Row():
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
run_status = gr.Textbox(label="Operation log", visible=False)
# first step: get the video information
extract_frames_button.click(
@@ -445,16 +469,16 @@ with gr.Blocks() as iface:
],
outputs=[video_state, video_info, template_frame,
image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button]
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status]
)
# second step: select images from slider
image_selection_slider.release(fn=select_template,
inputs=[image_selection_slider, video_state, interactive_state],
outputs=[template_frame, video_state, interactive_state], api_name="select_image")
outputs=[template_frame, video_state, interactive_state, run_status], api_name="select_image")
track_pause_number_slider.release(fn=get_end_number,
inputs=[track_pause_number_slider, video_state, interactive_state],
outputs=[template_frame, interactive_state], api_name="end_image")
outputs=[template_frame, interactive_state, run_status], api_name="end_image")
resize_ratio_slider.release(fn=get_resize_ratio,
inputs=[resize_ratio_slider, interactive_state],
outputs=[interactive_state], api_name="resize_ratio")
@@ -463,41 +487,41 @@ with gr.Blocks() as iface:
template_frame.select(
fn=sam_refine,
inputs=[video_state, point_prompt, click_state, interactive_state],
outputs=[template_frame, video_state, interactive_state]
outputs=[template_frame, video_state, interactive_state, run_status]
)
# add different mask
Add_mask_button.click(
fn=add_multi_mask,
inputs=[video_state, interactive_state, mask_dropdown],
outputs=[interactive_state, mask_dropdown, template_frame, click_state]
outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status]
)
remove_mask_button.click(
fn=remove_multi_mask,
inputs=[interactive_state],
outputs=[interactive_state]
inputs=[interactive_state, mask_dropdown],
outputs=[interactive_state, mask_dropdown, run_status]
)
# tracking video from select image and mask
tracking_video_predict_button.click(
fn=vos_tracking_video,
inputs=[video_state, interactive_state, mask_dropdown],
outputs=[video_output, video_state, interactive_state]
outputs=[video_output, video_state, interactive_state, run_status]
)
# inpaint video from select image and mask
inpaint_video_predict_button.click(
fn=inpaint_video,
inputs=[video_state, interactive_state, mask_dropdown],
outputs=[video_output]
outputs=[video_output, run_status]
)
# click to get mask
mask_dropdown.change(
fn=show_mask,
inputs=[video_state, interactive_state, mask_dropdown],
outputs=[template_frame]
outputs=[template_frame, run_status]
)
# clear input
@@ -529,7 +553,8 @@ with gr.Blocks() as iface:
None,
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), gr.update(visible=False) \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False)
),
[],
@@ -540,7 +565,7 @@ with gr.Blocks() as iface:
video_output,
template_frame,
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
],
queue=False,
show_progress=False)
@@ -549,7 +574,7 @@ with gr.Blocks() as iface:
clear_button_click.click(
fn = clear_click,
inputs = [video_state, click_state,],
outputs = [template_frame,click_state],
outputs = [template_frame,click_state, run_status],
)
# set example
gr.Markdown("## Examples")