2023-04-13 13:40:10 +00:00
import gradio as gr
import argparse
2023-04-25 13:38:16 +00:00
import gdown
2023-04-13 13:40:10 +00:00
import cv2
2023-04-13 18:00:48 +00:00
import numpy as np
2023-04-14 02:27:39 +00:00
import os
import sys
sys . path . append ( sys . path [ 0 ] + " /tracker " )
sys . path . append ( sys . path [ 0 ] + " /tracker/model " )
from track_anything import TrackingAnything
from track_anything import parse_augment
import requests
2023-04-14 08:24:57 +00:00
import json
2023-04-14 13:26:26 +00:00
import torchvision
import torch
2023-04-25 13:38:16 +00:00
from tools . painter import mask_painter
2023-04-18 04:01:14 +00:00
# download checkpoints
2023-04-14 02:27:39 +00:00
def download_checkpoint(url, folder, filename):
    """Download a checkpoint over HTTP unless it is already on disk.

    The payload is streamed into a temporary ``.part`` file and only moved
    to its final name once the transfer has finished, so an interrupted
    download is never mistaken for a complete checkpoint on the next run.

    Args:
        url: Address the checkpoint is fetched from.
        folder: Directory holding the checkpoint; created when missing.
        filename: Name of the checkpoint file inside ``folder``.

    Returns:
        ``os.path.join(folder, filename)``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)

    if os.path.exists(filepath):
        return filepath

    print(" download checkpoints ...... ")
    partial_path = filepath + ".part"
    try:
        with requests.get(url, stream=True) as response:
            # Fail loudly instead of saving an HTML error page as a checkpoint.
            response.raise_for_status()
            # NOTE(review): the mode literal is reproduced exactly as SOURCE
            # shows it; the padding inside the quotes looks like an
            # extraction artifact -- confirm against the real file.
            with open(partial_path, " wb ") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        os.replace(partial_path, filepath)
    finally:
        # Only present when the transfer failed before the rename.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print(" download successfully! ")
    return filepath
2023-04-25 13:38:16 +00:00
def download_checkpoint_from_google_drive(file_id, folder, filename):
    """Fetch a checkpoint from Google Drive via ``gdown`` if it is not cached.

    Args:
        file_id: Google Drive file id, inserted into the download URL.
        folder: Directory holding the checkpoint; created when missing.
        filename: Name of the checkpoint file inside ``folder``.

    Returns:
        ``os.path.join(folder, filename)``.
    """
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)

    # Already cached: nothing to download.
    if os.path.exists(filepath):
        return filepath

    print(
        " Downloading checkpoints from Google Drive... tips: If you cannot see the progress bar, please try to download it manuall "
        "and put it in the checkpointes directory . E2FGVI - HQ - CVPR22 . pth : https : / / github . com / MCG - NKU / E2FGVI ( E2FGVI - HQ model ) "
    )
    url = f" https://drive.google.com/uc?id= {file_id} "
    gdown.download(url, filepath, quiet=False)
    print(" Downloaded successfully! ")
    return filepath
2023-04-14 02:27:39 +00:00
# convert points input to prompt state
2023-04-14 08:24:57 +00:00
def get_prompt(click_state, click_input):
    """Append newly clicked points to the click state and build a SAM prompt.

    Args:
        click_state: Two-element list ``[points, labels]``; both inner lists
            are extended in place.
        click_input: JSON text encoding a list of ``[x, y, label]`` entries.

    Returns:
        Prompt dict carrying the accumulated points and labels.
    """
    clicks = json.loads(click_input)
    points, labels = click_state[0], click_state[1]
    for click in clicks:
        points.append(click[:2])
        labels.append(click[2])
    click_state[0], click_state[1] = points, labels

    return {
        " prompt_type ": [" click "],
        " input_point ": click_state[0],
        " input_label ": click_state[1],
        " multimask_output ": " True ",
    }
2023-04-18 04:01:14 +00:00
# extract frames from the uploaded video
2023-04-17 21:13:45 +00:00
def get_frames_from_video(video_input, video_state):
    """Decode a video into RGB frames and build a fresh video state.

    Args:
        video_input: Path of the video file, handed to ``cv2.VideoCapture``.
        video_state: Previous state; not read, a new dict replaces it.

    Returns:
        A 15-tuple: the new video state, an info string, the first frame,
        two slider updates whose maximum is the frame count, and ten
        updates that make hidden components visible.

    Raises:
        gr.Error: If no frame could be decoded from ``video_input``.
    """
    video_path = video_input
    frames = []
    fps = None
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; the rest of the app works on RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
        print(" read_frame_source: {} error. {} \n ".format(video_path, str(e)))
    finally:
        # Release the capture handle even when decoding failed part-way.
        if cap is not None:
            cap.release()

    if not frames:
        # Without this guard the code below fails with a bare IndexError.
        raise gr.Error("Could not read any frame from the uploaded video.")

    first_frame = frames[0]
    image_size = (first_frame.shape[0], first_frame.shape[1])

    # initialize video_state
    video_state = {
        " video_name ": os.path.split(video_path)[-1],
        " origin_images ": frames,
        " painted_images ": frames.copy(),
        # NOTE: every slot starts out referencing the same zero array; the
        # visible callers replace entries rather than editing them in place.
        " masks ": [np.zeros(image_size, np.uint8)] * len(frames),
        " logits ": [None] * len(frames),
        " select_frame_number ": 0,
        " fps ": fps,
    }
    video_info = " Video Name: {} , FPS: {} , Total Frames: {} , Image Size: {} ".format(
        video_state[" video_name "], video_state[" fps "], len(frames), image_size
    )

    # Point SAM at the first frame so clicks work right after loading.
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state[" origin_images "][0])

    return (
        video_state,
        video_info,
        video_state[" origin_images "][0],
        gr.update(visible=True, maximum=len(frames), value=1),
        gr.update(visible=True, maximum=len(frames), value=len(frames)),
        *(gr.update(visible=True) for _ in range(10)),
    )
2023-04-17 21:13:45 +00:00
2023-04-24 17:07:46 +00:00
def run_example(example):
    """Return the selected example unchanged.

    The function is registered with ``gr.Examples`` using ``video_input``
    as both input and output, so it must hand back the example value
    rather than the ``video_input`` component object itself.

    Args:
        example: Value of the chosen example (a video path).

    Returns:
        ``example``.
    """
    return example
2023-04-18 04:01:14 +00:00
# get the selected frame from the gradio slider
2023-04-19 11:34:14 +00:00
def select_template(image_selection_slider, video_state, interactive_state):
    """Make the frame chosen on the slider the current template frame.

    Args:
        image_selection_slider: 1-based frame number from the slider.
        video_state: Video state dict; its selected frame number is updated.
        interactive_state: Passed through unchanged.

    Returns:
        ``(painted frame at the selection, video_state, interactive_state)``.
    """
    frame_index = image_selection_slider - 1
    video_state[" select_frame_number "] = frame_index

    # A new template frame means SAM has to be pointed at that image.
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state[" origin_images "][frame_index])

    painted = video_state[" painted_images "][frame_index]
    return painted, video_state, interactive_state
2023-04-24 17:07:46 +00:00
# set the tracking end frame
2023-04-25 13:38:16 +00:00
def get_end_number(track_pause_number_slider, video_state, interactive_state):
    """Record the frame at which tracking stops and preview that frame.

    The slider is 1-based (minimum 1, maximum equal to the frame count).
    The raw value is stored because tracking uses it as an exclusive slice
    end; the preview needs the matching 0-based index, ``value - 1``.
    Indexing with the raw value showed the following frame and raised
    ``IndexError`` at the slider maximum.

    Args:
        track_pause_number_slider: 1-based end frame from the slider.
        video_state: Video state dict holding the painted frames.
        interactive_state: Interaction state; its end number is updated.

    Returns:
        ``(painted frame at the end position, interactive_state)``.
    """
    interactive_state[" track_end_number "] = track_pause_number_slider
    preview = video_state[" painted_images "][track_pause_number_slider - 1]
    return preview, interactive_state
def get_resize_ratio(resize_ratio_slider, interactive_state):
    """Store the resize ratio chosen on the slider in the interaction state.

    Args:
        resize_ratio_slider: Ratio value taken from the slider.
        interactive_state: Interaction state dict, updated in place.

    Returns:
        The same ``interactive_state`` object.
    """
    interactive_state.update({" resize_ratio ": resize_ratio_slider})
    return interactive_state
2023-04-14 13:26:26 +00:00
2023-04-18 04:01:14 +00:00
# use sam to get the mask
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt: gr.SelectData):
    """Turn a click on the template frame into a refined SAM mask.

    Args:
        video_state: Video state dict; mask, logit and painted image of the
            selected frame are overwritten.
        point_prompt: Value of the point-prompt radio; the positive choice
            yields label 1, anything else label 0.
        click_state: ``[[points], [labels]]`` accumulated so far.
        interactive_state: Interaction state; the click counter is bumped.
        evt: Select event whose ``index`` holds the clicked coordinates.

    Returns:
        ``(painted_image, video_state, interactive_state)``.
    """
    click_x, click_y = evt.index[0], evt.index[1]
    if point_prompt == " Positive ":
        coordinate = " [[ {} , {} ,1]] ".format(click_x, click_y)
        counter_key = " positive_click_times "
    else:
        coordinate = " [[ {} , {} ,0]] ".format(click_x, click_y)
        counter_key = " negative_click_times "
    interactive_state[counter_key] += 1

    # prompt for sam model
    prompt = get_prompt(click_state=click_state, click_input=coordinate)

    frame_index = video_state[" select_frame_number "]
    mask, logit, painted_image = model.first_frame_click(
        image=video_state[" origin_images "][frame_index],
        points=np.array(prompt[" input_point "]),
        labels=np.array(prompt[" input_label "]),
        multimask=prompt[" multimask_output "],
    )
    video_state[" masks "][frame_index] = mask
    video_state[" logits "][frame_index] = logit
    video_state[" painted_images "][frame_index] = painted_image

    return painted_image, video_state, interactive_state
2023-04-16 08:53:28 +00:00
2023-04-19 11:34:14 +00:00
def add_multi_mask(video_state, interactive_state, mask_dropdown):
    """Save the selected frame's mask as a new named mask and select it.

    Args:
        video_state: Video state dict providing the current frame's mask.
        interactive_state: Interaction state; the mask and its generated
            name are appended to its multi-mask lists.
        mask_dropdown: Currently selected mask names; the new name is
            appended in place.

    Returns:
        ``(interactive_state, dropdown update, painted frame, empty click
        state)``.
    """
    current_mask = video_state[" masks "][video_state[" select_frame_number "]]
    multi_mask = interactive_state[" multi_mask "]
    multi_mask[" masks "].append(current_mask)

    # Names are numbered from 1 by position in the mask list.
    mask_name = " mask_ {:03d} ".format(len(multi_mask[" masks "]))
    multi_mask[" mask_names "].append(mask_name)
    mask_dropdown.append(mask_name)

    select_frame = show_mask(video_state, interactive_state, mask_dropdown)
    dropdown_update = gr.update(choices=multi_mask[" mask_names "], value=mask_dropdown)
    return interactive_state, dropdown_update, select_frame, [[], []]
2023-04-19 11:34:14 +00:00
2023-04-20 11:53:19 +00:00
def clear_click(video_state, click_state):
    """Discard all clicks and show the unpainted selected frame again.

    Args:
        video_state: Video state dict.
        click_state: Previous clicks; ignored and replaced.

    Returns:
        ``(original selected frame, fresh empty click state)``.
    """
    frame_index = video_state[" select_frame_number "]
    return video_state[" origin_images "][frame_index], [[], []]
2023-04-19 11:34:14 +00:00
def remove_multi_mask(interactive_state):
    """Forget every stored mask and mask name.

    Args:
        interactive_state: Interaction state dict, updated in place.

    Returns:
        The same ``interactive_state`` object.
    """
    multi_mask = interactive_state[" multi_mask "]
    # Rebind to new lists rather than clearing the old ones in place.
    multi_mask[" mask_names "], multi_mask[" masks "] = [], []
    return interactive_state
def show_mask(video_state, interactive_state, mask_dropdown):
    """Paint every selected stored mask onto the selected original frame.

    Args:
        video_state: Video state dict providing the selected frame.
        interactive_state: Interaction state holding the stored masks.
        mask_dropdown: Selected mask names; sorted in place.

    Returns:
        The selected frame with each chosen mask painted on top.
    """
    mask_dropdown.sort()
    painted = video_state[" origin_images "][video_state[" select_frame_number "]]
    for mask_name in mask_dropdown:
        # The numeric suffix of the name is 1-based; the mask list is 0-based.
        mask_index = int(mask_name.split(" _ ")[1]) - 1
        mask = interactive_state[" multi_mask "][" masks "][mask_index]
        painted = mask_painter(painted, mask.astype(' uint8 '), mask_color=mask_index + 2)
    return painted
2023-04-18 04:01:14 +00:00
# tracking vos
2023-04-19 11:34:14 +00:00
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
    """Propagate the template mask through the video and render the result.

    Args:
        video_state: Video state dict; masks, logits and painted images of
            the tracked span are overwritten.
        interactive_state: Interaction state (stored masks, end frame,
            counters, mask-save flag).
        mask_dropdown: Names of the stored masks to track.

    Returns:
        ``(output video path, video_state, interactive_state)``.
    """
    model.xmem.clear_memory()

    start = video_state[" select_frame_number "]
    # A falsy end number means "track through the last frame".
    end = interactive_state[" track_end_number "] or None
    span = slice(start, end)
    following_frames = video_state[" origin_images "][span]

    stored_masks = interactive_state[" multi_mask "][" masks "]
    if stored_masks:
        if len(mask_dropdown) == 0:
            mask_dropdown = [" mask_001 "]
        mask_dropdown.sort()
        # Each stored mask is scaled by its 1-based number to form a label map.
        first_number = int(mask_dropdown[0].split(" _ ")[1])
        template_mask = stored_masks[first_number - 1] * first_number
        for mask_name in mask_dropdown[1:]:
            mask_number = int(mask_name.split(" _ ")[1]) - 1
            template_mask = np.clip(
                template_mask + stored_masks[mask_number] * (mask_number + 1),
                0,
                mask_number + 1,
            )
        video_state[" masks "][start] = template_mask
    else:
        template_mask = video_state[" masks "][start]

    fps = video_state[" fps "]
    masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
    video_state[" masks "][span] = masks
    video_state[" logits "][span] = logits
    video_state[" painted_images "][span] = painted_images

    # The output video is named after the input video.
    video_output = generate_video_from_frames(
        video_state[" painted_images "],
        output_path=" ./result/track/ {} ".format(video_state[" video_name "]),
        fps=fps,
    )
    interactive_state[" inference_times "] += 1

    positive_clicks = interactive_state[" positive_click_times "]
    negative_clicks = interactive_state[" negative_click_times "]
    print(" For generating this tracking result, inference times: {} , click times: {} , positive: {} , negative: {} ".format(
        interactive_state[" inference_times "],
        positive_clicks + negative_clicks,
        positive_clicks,
        negative_clicks,
    ))

    # Optionally dump every frame's mask as a numbered .npy file.
    if interactive_state[" mask_save "]:
        mask_dir = ' ./result/mask/ {} '.format(video_state[" video_name "].split(' . ')[0])
        if not os.path.exists(mask_dir):
            os.makedirs(mask_dir)
        print(" save mask ")
        for index, mask in enumerate(video_state[" masks "]):
            np.save(os.path.join(mask_dir, ' {:05d} .npy '.format(index)), mask)

    return video_output, video_state, interactive_state
2023-04-25 13:38:16 +00:00
# extracting masks from mask_dropdown
# def extract_sole_mask(video_state, mask_dropdown):
# combined_masks =
# unique_masks = np.unique(combined_masks)
# return 0
# inpaint
def inpaint_video(video_state, interactive_state, mask_dropdown):
    """Inpaint the regions covered by the selected masks in every frame.

    Any label that does not belong to a selected mask is zeroed, whatever
    its value. The previous loop only visited labels ``1..K`` with ``K``
    derived from the number of distinct values, so with non-contiguous
    labels an unselected label could survive and be inpainted as well.

    Args:
        video_state: Video state dict providing frames, masks, fps and name.
        interactive_state: Interaction state providing the resize ratio.
        mask_dropdown: Names of the masks to inpaint; when empty, the first
            mask name is used.

    Returns:
        Path of the written inpainted video.
    """
    frames = np.asarray(video_state[" origin_images "])
    fps = video_state[" fps "]
    # np.array always copies, so zeroing labels never touches stored masks.
    inpaint_masks = np.array(video_state[" masks "])

    if len(mask_dropdown) == 0:
        mask_dropdown = [" mask_001 "]
    # convert mask_dropdown to mask numbers
    inpaint_mask_numbers = [int(mask_name.split(" _ ")[1]) for mask_name in mask_dropdown]

    # Keep only the selected labels; background (0) stays 0 either way.
    inpaint_masks[~np.isin(inpaint_masks, inpaint_mask_numbers)] = 0

    # inpaint for videos
    inpainted_frames = model.baseinpainter.inpaint(
        frames, inpaint_masks, ratio=interactive_state[" resize_ratio "]
    )
    # The output video is named after the input video.
    video_output = generate_video_from_frames(
        inpainted_frames,
        output_path=" ./result/inpaint/ {} ".format(video_state[" video_name "]),
        fps=fps,
    )
    return video_output
2023-04-18 04:01:14 +00:00
# generate video after vos inference
def generate_video_from_frames(frames, output_path, fps=30):
    """Write a sequence of frames to a video file.

    Args:
        frames: Frames to encode; converted with ``np.asarray`` and then
            ``torch.from_numpy``.
        output_path: Destination file; its parent directory is created when
            the path has one.
        fps: Frame rate of the output video. Defaults to 30.

    Returns:
        ``output_path``.
    """
    frames = torch.from_numpy(np.asarray(frames))

    # A bare filename has no directory part, and os.makedirs("") would raise.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): the codec literal is reproduced exactly as SOURCE shows
    # it; the padding inside the quotes looks like an extraction artifact --
    # confirm against the real file.
    torchvision.io.write_video(output_path, frames, fps=fps, video_codec=" libx264 ")
    return output_path
2023-04-15 19:59:58 +00:00
# check and download checkpoints if needed
# Checkpoint sources; missing files are downloaded into `folder` on start-up.
folder = " ./checkpoints "
sam_checkpoint_url = " https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth "
xmem_checkpoint_url = " https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth "
e2fgvi_checkpoint_id = " 10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3 "

# Each name ends up bound to the local path of its checkpoint.
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, " sam_vit_h_4b8939.pth ")
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, " XMem-s012.pth ")
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, " E2FGVI-HQ-CVPR22.pth ")

# args, defined in track_anything.py
# (overrides kept for local debugging: args.port = 12315,
#  args.device = "cuda:2", args.mask_save = True)
args = parse_augment()

# initialize sam, xmem, e2fgvi models
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args)
2023-04-13 13:40:10 +00:00
# NOTE(review): SOURCE carries no indentation, so the nesting of the layout
# containers below is reconstructed -- confirm it against the real file.
# String literals are reproduced exactly as SOURCE shows them; the padding
# inside the quotes looks like an extraction artifact -- confirm as well.


def _initial_video_state():
    """Return a fresh video state with no video loaded."""
    return {
        " video_name ": " ",
        " origin_images ": None,
        " painted_images ": None,
        " masks ": None,
        " inpaint_masks ": None,
        " logits ": None,
        " select_frame_number ": 0,
        " fps ": 30,
    }


def _initial_interactive_state():
    """Return a fresh interaction state with zeroed counters and no masks."""
    return {
        " inference_times ": 0,
        " negative_click_times ": 0,
        " positive_click_times ": 0,
        " mask_save ": args.mask_save,
        " multi_mask ": {
            " mask_names ": [],
            " masks ": [],
        },
        " track_end_number ": None,
        " resize_ratio ": 1,
    }


def _reset_on_clear():
    """Return the values that reset the session when the video is cleared.

    The states use the same factories as the initial ``gr.State`` values,
    so a cleared session has the same keys as a new one. Previously the
    reset video state had no video-name entry and the end number was 0
    instead of None.
    """
    return (
        _initial_video_state(),
        _initial_interactive_state(),
        [[], []],
        None,
        None,
        *(gr.update(visible=False) for _ in range(10)),
        gr.update(visible=False, value=[]),
        gr.update(visible=False),
        gr.update(visible=False),
    )


with gr.Blocks() as iface:
    # per-session state
    click_state = gr.State([[], []])
    interactive_state = gr.State(_initial_interactive_state())
    video_state = gr.State(_initial_video_state())

    with gr.Row():
        # for user video input
        with gr.Column():
            with gr.Row(scale=0.4):
                video_input = gr.Video(autosize=True)
                with gr.Column():
                    video_info = gr.Textbox()
                    # Bound to its own name: reusing `video_info` here made
                    # the video details overwrite this tip and left the
                    # textbox above permanently empty.
                    resize_info = gr.Textbox(value=(
                        " If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. "
                        "Alternatively , you can use the resize ratio slider to scale down the original image to around 360 P resolution for faster processing . "
                    ))
                    resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label=" Resize ratio ", visible=True)

            with gr.Row():
                # put the template frame under the radio button
                with gr.Column():
                    # extract frames
                    with gr.Column():
                        extract_frames_button = gr.Button(value=" Get video info ", interactive=True, variant=" primary ")

                    # click point settings: negative or positive, continuous or single mode
                    with gr.Row():
                        with gr.Row():
                            point_prompt = gr.Radio(
                                choices=[" Positive ", " Negative "],
                                value=" Positive ",
                                label=" Point Prompt ",
                                interactive=True,
                                visible=False)
                            click_mode = gr.Radio(
                                choices=[" Continuous ", " Single "],
                                value=" Continuous ",
                                label=" Clicking Mode ",
                                interactive=True,
                                visible=False)
                        with gr.Row():
                            clear_button_click = gr.Button(value=" Clear Clicks ", interactive=True, visible=False).style(height=160)
                            Add_mask_button = gr.Button(value=" Add mask ", interactive=True, visible=False)
                    template_frame = gr.Image(type=" pil ", interactive=True, elem_id=" template_frame ", visible=False).style(height=360)
                    image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label=" Image Selection ", visible=False)
                    track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label=" Track end frames ", visible=False)

                with gr.Column():
                    mask_dropdown = gr.Dropdown(multiselect=True, value=[], label=" Mask_select ", info=" . ", visible=False)
                    remove_mask_button = gr.Button(value=" Remove mask ", interactive=True, visible=False)
                    video_output = gr.Video(autosize=True, visible=False).style(height=360)
                    with gr.Row():
                        tracking_video_predict_button = gr.Button(value=" Tracking ", visible=False)
                        inpaint_video_predict_button = gr.Button(value=" Inpaint ", visible=False)

    # first step: get the video information
    extract_frames_button.click(
        fn=get_frames_from_video,
        inputs=[
            video_input, video_state
        ],
        outputs=[video_state, video_info, template_frame,
                 image_selection_slider, track_pause_number_slider, point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
                 tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button]
    )

    # second step: select images from slider
    image_selection_slider.release(fn=select_template,
                                   inputs=[image_selection_slider, video_state, interactive_state],
                                   outputs=[template_frame, video_state, interactive_state], api_name=" select_image ")
    track_pause_number_slider.release(fn=get_end_number,
                                      inputs=[track_pause_number_slider, video_state, interactive_state],
                                      outputs=[template_frame, interactive_state], api_name=" end_image ")
    resize_ratio_slider.release(fn=get_resize_ratio,
                                inputs=[resize_ratio_slider, interactive_state],
                                outputs=[interactive_state], api_name=" resize_ratio ")

    # click select image to get mask using sam
    template_frame.select(
        fn=sam_refine,
        inputs=[video_state, point_prompt, click_state, interactive_state],
        outputs=[template_frame, video_state, interactive_state]
    )

    # add different mask
    Add_mask_button.click(
        fn=add_multi_mask,
        inputs=[video_state, interactive_state, mask_dropdown],
        outputs=[interactive_state, mask_dropdown, template_frame, click_state]
    )

    remove_mask_button.click(
        fn=remove_multi_mask,
        inputs=[interactive_state],
        outputs=[interactive_state]
    )

    # tracking video from select image and mask
    tracking_video_predict_button.click(
        fn=vos_tracking_video,
        inputs=[video_state, interactive_state, mask_dropdown],
        outputs=[video_output, video_state, interactive_state]
    )

    # inpaint video from select image and mask
    inpaint_video_predict_button.click(
        fn=inpaint_video,
        inputs=[video_state, interactive_state, mask_dropdown],
        outputs=[video_output]
    )

    # click to get mask
    mask_dropdown.change(
        fn=show_mask,
        inputs=[video_state, interactive_state, mask_dropdown],
        outputs=[template_frame]
    )

    # clear input: reset all state and hide the interactive components again
    video_input.clear(
        _reset_on_clear,
        [],
        [
            video_state,
            interactive_state,
            click_state,
            video_output,
            template_frame,
            tracking_video_predict_button, image_selection_slider, track_pause_number_slider, point_prompt, click_mode, clear_button_click,
            Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button
        ],
        queue=False,
        show_progress=False)

    # points clear
    clear_button_click.click(
        fn=clear_click,
        inputs=[video_state, click_state],
        outputs=[template_frame, click_state],
    )

    # set example
    gr.Markdown(" ## Examples ")
    gr.Examples(
        examples=[
            os.path.join(os.path.dirname(__file__), " ./test_sample/ ", test_sample)
            for test_sample in [" test-sample8.mp4 ", " test-sample4.mp4 ", " test-sample2.mp4 ", " test-sample13.mp4 "]
        ],
        fn=run_example,
        inputs=[
            video_input
        ],
        outputs=[video_input],
        # cache_examples=True,
    )

iface.queue(concurrency_count=1)
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name=" 0.0.0.0 ")
2023-04-13 13:40:10 +00:00