This commit is contained in:
gaomingqi
2023-04-26 17:14:42 +08:00

91
app.py
View File

@@ -103,7 +103,7 @@ def get_frames_from_video(video_input, video_state):
"fps": fps "fps": fps
} }
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size) video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
operation_log = "Upload video already. Try click the image for adding targets to track and inpaint."
model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0]) model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \ return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
@@ -111,7 +111,8 @@ def get_frames_from_video(video_input, video_state):
gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True) gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True, value=operation_log)
def run_example(example): def run_example(example):
return video_input return video_input
@@ -130,15 +131,16 @@ def select_template(image_selection_slider, video_state, interactive_state):
# update the masks when select a new template frame # update the masks when select a new template frame
# if video_state["masks"][image_selection_slider] is not None: # if video_state["masks"][image_selection_slider] is not None:
# video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider]) # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
operation_log = "Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider)
return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
return video_state["painted_images"][image_selection_slider], video_state, interactive_state
# set the tracking end frame # set the tracking end frame
def get_end_number(track_pause_number_slider, video_state, interactive_state): def get_end_number(track_pause_number_slider, video_state, interactive_state):
interactive_state["track_end_number"] = track_pause_number_slider interactive_state["track_end_number"] = track_pause_number_slider
operation_log = "Set the tracking finish at frame {}".format(track_pause_number_slider)
return video_state["painted_images"][track_pause_number_slider],interactive_state return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
def get_resize_ratio(resize_ratio_slider, interactive_state): def get_resize_ratio(resize_ratio_slider, interactive_state):
interactive_state["resize_ratio"] = resize_ratio_slider interactive_state["resize_ratio"] = resize_ratio_slider
@@ -161,6 +163,8 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
interactive_state["negative_click_times"] += 1 interactive_state["negative_click_times"] += 1
# prompt for sam model # prompt for sam model
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
prompt = get_prompt(click_state=click_state, click_input=coordinate) prompt = get_prompt(click_state=click_state, click_input=coordinate)
mask, logit, painted_image = model.first_frame_click( mask, logit, painted_image = model.first_frame_click(
@@ -173,25 +177,31 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
video_state["logits"][video_state["select_frame_number"]] = logit video_state["logits"][video_state["select_frame_number"]] = logit
video_state["painted_images"][video_state["select_frame_number"]] = painted_image video_state["painted_images"][video_state["select_frame_number"]] = painted_image
return painted_image, video_state, interactive_state operation_log = "Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment"
return painted_image, video_state, interactive_state, operation_log
def add_multi_mask(video_state, interactive_state, mask_dropdown): def add_multi_mask(video_state, interactive_state, mask_dropdown):
mask = video_state["masks"][video_state["select_frame_number"]] mask = video_state["masks"][video_state["select_frame_number"]]
interactive_state["multi_mask"]["masks"].append(mask) interactive_state["multi_mask"]["masks"].append(mask)
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))) mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
select_frame = show_mask(video_state, interactive_state, mask_dropdown) select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
operation_log = "Added a mask, use the mask select for target tracking or inpainting."
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
def clear_click(video_state, click_state): def clear_click(video_state, click_state):
click_state = [[],[]] click_state = [[],[]]
template_frame = video_state["origin_images"][video_state["select_frame_number"]] template_frame = video_state["origin_images"][video_state["select_frame_number"]]
return template_frame, click_state operation_log = "Clear points history and refresh the image."
return template_frame, click_state, operation_log
def remove_multi_mask(interactive_state): def remove_multi_mask(interactive_state, mask_dropdown):
interactive_state["multi_mask"]["mask_names"]= [] interactive_state["multi_mask"]["mask_names"]= []
interactive_state["multi_mask"]["masks"] = [] interactive_state["multi_mask"]["masks"] = []
return interactive_state
operation_log = "Remove all mask, please add new masks"
return interactive_state, gr.update(choices=[],value=[]), operation_log
def show_mask(video_state, interactive_state, mask_dropdown): def show_mask(video_state, interactive_state, mask_dropdown):
mask_dropdown.sort() mask_dropdown.sort()
@@ -202,11 +212,12 @@ def show_mask(video_state, interactive_state, mask_dropdown):
mask = interactive_state["multi_mask"]["masks"][mask_number] mask = interactive_state["multi_mask"]["masks"][mask_number]
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2) select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
return select_frame operation_log = "Select {} for tracking or inpainting".format(mask_dropdown)
return select_frame, operation_log
# tracking vos # tracking vos
def vos_tracking_video(video_state, interactive_state, mask_dropdown): def vos_tracking_video(video_state, interactive_state, mask_dropdown):
operation_log = "Track the selected masks, and then you can select the masks for inpainting."
model.xmem.clear_memory() model.xmem.clear_memory()
if interactive_state["track_end_number"]: if interactive_state["track_end_number"]:
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@@ -225,6 +236,12 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
else: else:
template_mask = video_state["masks"][video_state["select_frame_number"]] template_mask = video_state["masks"][video_state["select_frame_number"]]
fps = video_state["fps"] fps = video_state["fps"]
# operation error
if len(np.unique(template_mask))==1:
template_mask[0][0]=1
operation_log = "Error! Please add at least one mask to track by clicking the left image."
# return video_output, video_state, interactive_state, operation_error
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask) masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
# clear GPU memory # clear GPU memory
model.xmem.clear_memory() model.xmem.clear_memory()
@@ -257,7 +274,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
i+=1 i+=1
# save_mask(video_state["masks"], video_state["video_name"]) # save_mask(video_state["masks"], video_state["video_name"])
#### shanggao code for mask save #### shanggao code for mask save
return video_output, video_state, interactive_state return video_output, video_state, interactive_state, operation_log
# extracting masks from mask_dropdown # extracting masks from mask_dropdown
# def extract_sole_mask(video_state, mask_dropdown): # def extract_sole_mask(video_state, mask_dropdown):
@@ -267,6 +284,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
# inpaint # inpaint
def inpaint_video(video_state, interactive_state, mask_dropdown): def inpaint_video(video_state, interactive_state, mask_dropdown):
operation_log = "Removed the selected masks."
frames = np.asarray(video_state["origin_images"]) frames = np.asarray(video_state["origin_images"])
fps = video_state["fps"] fps = video_state["fps"]
@@ -284,10 +302,15 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
continue continue
inpaint_masks[inpaint_masks==i] = 0 inpaint_masks[inpaint_masks==i] = 0
# inpaint for videos # inpaint for videos
inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
try:
inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
except:
operation_log = "Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size."
inpainted_frames = video_state["origin_images"]
video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
return video_output return video_output, operation_log
# generate video after vos inference # generate video after vos inference
@@ -341,8 +364,8 @@ folder ="./checkpoints"
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint) SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint) xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint) e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
# args.port = 12315 args.port = 12211
# args.device = "cuda:2" args.device = "cuda:2"
# args.mask_save = True # args.mask_save = True
# initialize sam, xmem, e2fgvi models # initialize sam, xmem, e2fgvi models
@@ -395,8 +418,8 @@ with gr.Blocks() as iface:
video_input = gr.Video(autosize=True) video_input = gr.Video(autosize=True)
with gr.Column(): with gr.Column():
video_info = gr.Textbox() video_info = gr.Textbox()
resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \ resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.") Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.")
resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True) resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
@@ -436,6 +459,7 @@ with gr.Blocks() as iface:
with gr.Row(): with gr.Row():
tracking_video_predict_button = gr.Button(value="Tracking", visible=False) tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False) inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
run_status = gr.Textbox(label="Operation log", visible=False)
# first step: get the video information # first step: get the video information
extract_frames_button.click( extract_frames_button.click(
@@ -445,16 +469,16 @@ with gr.Blocks() as iface:
], ],
outputs=[video_state, video_info, template_frame, outputs=[video_state, video_info, template_frame,
image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame, image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button] tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status]
) )
# second step: select images from slider # second step: select images from slider
image_selection_slider.release(fn=select_template, image_selection_slider.release(fn=select_template,
inputs=[image_selection_slider, video_state, interactive_state], inputs=[image_selection_slider, video_state, interactive_state],
outputs=[template_frame, video_state, interactive_state], api_name="select_image") outputs=[template_frame, video_state, interactive_state, run_status], api_name="select_image")
track_pause_number_slider.release(fn=get_end_number, track_pause_number_slider.release(fn=get_end_number,
inputs=[track_pause_number_slider, video_state, interactive_state], inputs=[track_pause_number_slider, video_state, interactive_state],
outputs=[template_frame, interactive_state], api_name="end_image") outputs=[template_frame, interactive_state, run_status], api_name="end_image")
resize_ratio_slider.release(fn=get_resize_ratio, resize_ratio_slider.release(fn=get_resize_ratio,
inputs=[resize_ratio_slider, interactive_state], inputs=[resize_ratio_slider, interactive_state],
outputs=[interactive_state], api_name="resize_ratio") outputs=[interactive_state], api_name="resize_ratio")
@@ -463,41 +487,41 @@ with gr.Blocks() as iface:
template_frame.select( template_frame.select(
fn=sam_refine, fn=sam_refine,
inputs=[video_state, point_prompt, click_state, interactive_state], inputs=[video_state, point_prompt, click_state, interactive_state],
outputs=[template_frame, video_state, interactive_state] outputs=[template_frame, video_state, interactive_state, run_status]
) )
# add different mask # add different mask
Add_mask_button.click( Add_mask_button.click(
fn=add_multi_mask, fn=add_multi_mask,
inputs=[video_state, interactive_state, mask_dropdown], inputs=[video_state, interactive_state, mask_dropdown],
outputs=[interactive_state, mask_dropdown, template_frame, click_state] outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status]
) )
remove_mask_button.click( remove_mask_button.click(
fn=remove_multi_mask, fn=remove_multi_mask,
inputs=[interactive_state], inputs=[interactive_state, mask_dropdown],
outputs=[interactive_state] outputs=[interactive_state, mask_dropdown, run_status]
) )
# tracking video from select image and mask # tracking video from select image and mask
tracking_video_predict_button.click( tracking_video_predict_button.click(
fn=vos_tracking_video, fn=vos_tracking_video,
inputs=[video_state, interactive_state, mask_dropdown], inputs=[video_state, interactive_state, mask_dropdown],
outputs=[video_output, video_state, interactive_state] outputs=[video_output, video_state, interactive_state, run_status]
) )
# inpaint video from select image and mask # inpaint video from select image and mask
inpaint_video_predict_button.click( inpaint_video_predict_button.click(
fn=inpaint_video, fn=inpaint_video,
inputs=[video_state, interactive_state, mask_dropdown], inputs=[video_state, interactive_state, mask_dropdown],
outputs=[video_output] outputs=[video_output, run_status]
) )
# click to get mask # click to get mask
mask_dropdown.change( mask_dropdown.change(
fn=show_mask, fn=show_mask,
inputs=[video_state, interactive_state, mask_dropdown], inputs=[video_state, interactive_state, mask_dropdown],
outputs=[template_frame] outputs=[template_frame, run_status]
) )
# clear input # clear input
@@ -529,7 +553,8 @@ with gr.Blocks() as iface:
None, None,
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), gr.update(visible=False) \ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False)
), ),
[], [],
@@ -540,7 +565,7 @@ with gr.Blocks() as iface:
video_output, video_output,
template_frame, template_frame,
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click, tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
], ],
queue=False, queue=False,
show_progress=False) show_progress=False)
@@ -549,7 +574,7 @@ with gr.Blocks() as iface:
clear_button_click.click( clear_button_click.click(
fn = clear_click, fn = clear_click,
inputs = [video_state, click_state,], inputs = [video_state, click_state,],
outputs = [template_frame,click_state], outputs = [template_frame,click_state, run_status],
) )
# set example # set example
gr.Markdown("## Examples") gr.Markdown("## Examples")