mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-15 16:07:51 +01:00
Merge branch 'master' of https://github.com/gaomingqi/VOS-Anything
This commit is contained in:
93
app.py
93
app.py
@@ -103,7 +103,7 @@ def get_frames_from_video(video_input, video_state):
|
||||
"fps": fps
|
||||
}
|
||||
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
|
||||
|
||||
operation_log = "Upload video already. Try click the image for adding targets to track and inpaint."
|
||||
model.samcontroler.sam_controler.reset_image()
|
||||
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
|
||||
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
|
||||
@@ -111,7 +111,8 @@ def get_frames_from_video(video_input, video_state):
|
||||
gr.update(visible=True), gr.update(visible=True), \
|
||||
gr.update(visible=True), gr.update(visible=True), \
|
||||
gr.update(visible=True), gr.update(visible=True), \
|
||||
gr.update(visible=True), gr.update(visible=True)
|
||||
gr.update(visible=True), gr.update(visible=True), \
|
||||
gr.update(visible=True, value=operation_log)
|
||||
|
||||
def run_example(example):
|
||||
return video_input
|
||||
@@ -130,15 +131,16 @@ def select_template(image_selection_slider, video_state, interactive_state):
|
||||
# update the masks when select a new template frame
|
||||
# if video_state["masks"][image_selection_slider] is not None:
|
||||
# video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
|
||||
operation_log = "Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider)
|
||||
|
||||
|
||||
return video_state["painted_images"][image_selection_slider], video_state, interactive_state
|
||||
return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
|
||||
|
||||
# set the tracking end frame
|
||||
def get_end_number(track_pause_number_slider, video_state, interactive_state):
|
||||
interactive_state["track_end_number"] = track_pause_number_slider
|
||||
operation_log = "Set the tracking finish at frame {}".format(track_pause_number_slider)
|
||||
|
||||
return video_state["painted_images"][track_pause_number_slider],interactive_state
|
||||
return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
|
||||
|
||||
def get_resize_ratio(resize_ratio_slider, interactive_state):
|
||||
interactive_state["resize_ratio"] = resize_ratio_slider
|
||||
@@ -161,6 +163,8 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
|
||||
interactive_state["negative_click_times"] += 1
|
||||
|
||||
# prompt for sam model
|
||||
model.samcontroler.sam_controler.reset_image()
|
||||
model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
|
||||
prompt = get_prompt(click_state=click_state, click_input=coordinate)
|
||||
|
||||
mask, logit, painted_image = model.first_frame_click(
|
||||
@@ -173,25 +177,31 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
|
||||
video_state["logits"][video_state["select_frame_number"]] = logit
|
||||
video_state["painted_images"][video_state["select_frame_number"]] = painted_image
|
||||
|
||||
return painted_image, video_state, interactive_state
|
||||
operation_log = "Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment"
|
||||
return painted_image, video_state, interactive_state, operation_log
|
||||
|
||||
def add_multi_mask(video_state, interactive_state, mask_dropdown):
|
||||
mask = video_state["masks"][video_state["select_frame_number"]]
|
||||
interactive_state["multi_mask"]["masks"].append(mask)
|
||||
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
|
||||
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
|
||||
select_frame = show_mask(video_state, interactive_state, mask_dropdown)
|
||||
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
|
||||
select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
|
||||
|
||||
operation_log = "Added a mask, use the mask select for target tracking or inpainting."
|
||||
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
|
||||
|
||||
def clear_click(video_state, click_state):
|
||||
click_state = [[],[]]
|
||||
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
||||
return template_frame, click_state
|
||||
operation_log = "Clear points history and refresh the image."
|
||||
return template_frame, click_state, operation_log
|
||||
|
||||
def remove_multi_mask(interactive_state):
|
||||
def remove_multi_mask(interactive_state, mask_dropdown):
|
||||
interactive_state["multi_mask"]["mask_names"]= []
|
||||
interactive_state["multi_mask"]["masks"] = []
|
||||
return interactive_state
|
||||
|
||||
operation_log = "Remove all mask, please add new masks"
|
||||
return interactive_state, gr.update(choices=[],value=[]), operation_log
|
||||
|
||||
def show_mask(video_state, interactive_state, mask_dropdown):
|
||||
mask_dropdown.sort()
|
||||
@@ -201,12 +211,13 @@ def show_mask(video_state, interactive_state, mask_dropdown):
|
||||
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
|
||||
mask = interactive_state["multi_mask"]["masks"][mask_number]
|
||||
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
|
||||
|
||||
return select_frame
|
||||
|
||||
operation_log = "Select {} for tracking or inpainting".format(mask_dropdown)
|
||||
return select_frame, operation_log
|
||||
|
||||
# tracking vos
|
||||
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
||||
|
||||
operation_log = "Track the selected masks, and then you can select the masks for inpainting."
|
||||
model.xmem.clear_memory()
|
||||
if interactive_state["track_end_number"]:
|
||||
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
||||
@@ -225,6 +236,12 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
||||
else:
|
||||
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
||||
fps = video_state["fps"]
|
||||
|
||||
# operation error
|
||||
if len(np.unique(template_mask))==1:
|
||||
template_mask[0][0]=1
|
||||
operation_log = "Error! Please add at least one mask to track by clicking the left image."
|
||||
# return video_output, video_state, interactive_state, operation_error
|
||||
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
||||
# clear GPU memory
|
||||
model.xmem.clear_memory()
|
||||
@@ -257,7 +274,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
||||
i+=1
|
||||
# save_mask(video_state["masks"], video_state["video_name"])
|
||||
#### shanggao code for mask save
|
||||
return video_output, video_state, interactive_state
|
||||
return video_output, video_state, interactive_state, operation_log
|
||||
|
||||
# extracting masks from mask_dropdown
|
||||
# def extract_sole_mask(video_state, mask_dropdown):
|
||||
@@ -267,6 +284,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
||||
|
||||
# inpaint
|
||||
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
||||
operation_log = "Removed the selected masks."
|
||||
|
||||
frames = np.asarray(video_state["origin_images"])
|
||||
fps = video_state["fps"]
|
||||
@@ -284,10 +302,15 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
|
||||
continue
|
||||
inpaint_masks[inpaint_masks==i] = 0
|
||||
# inpaint for videos
|
||||
inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
|
||||
|
||||
try:
|
||||
inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
|
||||
except:
|
||||
operation_log = "Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size."
|
||||
inpainted_frames = video_state["origin_images"]
|
||||
video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
|
||||
|
||||
return video_output
|
||||
return video_output, operation_log
|
||||
|
||||
|
||||
# generate video after vos inference
|
||||
@@ -341,8 +364,8 @@ folder ="./checkpoints"
|
||||
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
|
||||
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
||||
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
||||
# args.port = 12315
|
||||
# args.device = "cuda:2"
|
||||
args.port = 12211
|
||||
args.device = "cuda:2"
|
||||
# args.mask_save = True
|
||||
|
||||
# initialize sam, xmem, e2fgvi models
|
||||
@@ -395,8 +418,8 @@ with gr.Blocks() as iface:
|
||||
video_input = gr.Video(autosize=True)
|
||||
with gr.Column():
|
||||
video_info = gr.Textbox()
|
||||
resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
|
||||
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
|
||||
resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
|
||||
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.")
|
||||
resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
|
||||
|
||||
|
||||
@@ -436,6 +459,7 @@ with gr.Blocks() as iface:
|
||||
with gr.Row():
|
||||
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
|
||||
inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
|
||||
run_status = gr.Textbox(label="Operation log", visible=False)
|
||||
|
||||
# first step: get the video information
|
||||
extract_frames_button.click(
|
||||
@@ -445,16 +469,16 @@ with gr.Blocks() as iface:
|
||||
],
|
||||
outputs=[video_state, video_info, template_frame,
|
||||
image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
|
||||
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button]
|
||||
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status]
|
||||
)
|
||||
|
||||
# second step: select images from slider
|
||||
image_selection_slider.release(fn=select_template,
|
||||
inputs=[image_selection_slider, video_state, interactive_state],
|
||||
outputs=[template_frame, video_state, interactive_state], api_name="select_image")
|
||||
outputs=[template_frame, video_state, interactive_state, run_status], api_name="select_image")
|
||||
track_pause_number_slider.release(fn=get_end_number,
|
||||
inputs=[track_pause_number_slider, video_state, interactive_state],
|
||||
outputs=[template_frame, interactive_state], api_name="end_image")
|
||||
outputs=[template_frame, interactive_state, run_status], api_name="end_image")
|
||||
resize_ratio_slider.release(fn=get_resize_ratio,
|
||||
inputs=[resize_ratio_slider, interactive_state],
|
||||
outputs=[interactive_state], api_name="resize_ratio")
|
||||
@@ -463,41 +487,41 @@ with gr.Blocks() as iface:
|
||||
template_frame.select(
|
||||
fn=sam_refine,
|
||||
inputs=[video_state, point_prompt, click_state, interactive_state],
|
||||
outputs=[template_frame, video_state, interactive_state]
|
||||
outputs=[template_frame, video_state, interactive_state, run_status]
|
||||
)
|
||||
|
||||
# add different mask
|
||||
Add_mask_button.click(
|
||||
fn=add_multi_mask,
|
||||
inputs=[video_state, interactive_state, mask_dropdown],
|
||||
outputs=[interactive_state, mask_dropdown, template_frame, click_state]
|
||||
outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status]
|
||||
)
|
||||
|
||||
remove_mask_button.click(
|
||||
fn=remove_multi_mask,
|
||||
inputs=[interactive_state],
|
||||
outputs=[interactive_state]
|
||||
inputs=[interactive_state, mask_dropdown],
|
||||
outputs=[interactive_state, mask_dropdown, run_status]
|
||||
)
|
||||
|
||||
# tracking video from select image and mask
|
||||
tracking_video_predict_button.click(
|
||||
fn=vos_tracking_video,
|
||||
inputs=[video_state, interactive_state, mask_dropdown],
|
||||
outputs=[video_output, video_state, interactive_state]
|
||||
outputs=[video_output, video_state, interactive_state, run_status]
|
||||
)
|
||||
|
||||
# inpaint video from select image and mask
|
||||
inpaint_video_predict_button.click(
|
||||
fn=inpaint_video,
|
||||
inputs=[video_state, interactive_state, mask_dropdown],
|
||||
outputs=[video_output]
|
||||
outputs=[video_output, run_status]
|
||||
)
|
||||
|
||||
# click to get mask
|
||||
mask_dropdown.change(
|
||||
fn=show_mask,
|
||||
inputs=[video_state, interactive_state, mask_dropdown],
|
||||
outputs=[template_frame]
|
||||
outputs=[template_frame, run_status]
|
||||
)
|
||||
|
||||
# clear input
|
||||
@@ -529,7 +553,8 @@ with gr.Blocks() as iface:
|
||||
None,
|
||||
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
||||
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
||||
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), gr.update(visible=False) \
|
||||
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
|
||||
gr.update(visible=False), gr.update(visible=False)
|
||||
|
||||
),
|
||||
[],
|
||||
@@ -540,7 +565,7 @@ with gr.Blocks() as iface:
|
||||
video_output,
|
||||
template_frame,
|
||||
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
|
||||
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button
|
||||
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
|
||||
],
|
||||
queue=False,
|
||||
show_progress=False)
|
||||
@@ -549,7 +574,7 @@ with gr.Blocks() as iface:
|
||||
clear_button_click.click(
|
||||
fn = clear_click,
|
||||
inputs = [video_state, click_state,],
|
||||
outputs = [template_frame,click_state],
|
||||
outputs = [template_frame,click_state, run_status],
|
||||
)
|
||||
# set example
|
||||
gr.Markdown("## Examples")
|
||||
|
||||
Reference in New Issue
Block a user