2023-04-13 13:40:10 +00:00
import gradio as gr
from demo import automask_image_app , automask_video_app , sahi_autoseg_app
import argparse
import cv2
import time
2023-04-13 18:00:48 +00:00
from PIL import Image
import numpy as np
2023-04-14 02:27:39 +00:00
import os
import sys
sys . path . append ( sys . path [ 0 ] + " /tracker " )
sys . path . append ( sys . path [ 0 ] + " /tracker/model " )
from track_anything import TrackingAnything
from track_anything import parse_augment
import requests
2023-04-14 08:24:57 +00:00
import json
2023-04-14 13:26:26 +00:00
import torchvision
import torch
2023-04-15 19:59:58 +00:00
import concurrent . futures
import queue
2023-04-19 11:34:14 +00:00
from tools . painter import mask_painter , point_painter
2023-04-18 04:01:14 +00:00
# download checkpoints
2023-04-14 02:27:39 +00:00
def download_checkpoint(url, folder, filename):
    """Download a checkpoint into ``folder`` unless the file already exists.

    Args:
        url: Address the checkpoint is fetched from.
        folder: Directory that receives the file; created when missing.
        filename: Name of the file inside ``folder``.

    Returns:
        The path ``os.path.join(folder, filename)``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)

    if not os.path.exists(filepath):
        print(" download checkpoints ...... ")
        # Stream into a temporary sibling file first: an interrupted or failed
        # download must not leave a truncated file at `filepath`, because the
        # existence check above would accept it on the next run.
        partial_path = filepath + ".part"
        try:
            with requests.get(url, stream=True) as response:
                # Refuse to store an HTTP error page as if it were weights.
                response.raise_for_status()
                with open(partial_path, " wb ") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            os.replace(partial_path, filepath)
        finally:
            # Only present here when the download did not complete.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print(" download successfully! ")

    return filepath
# convert points input to prompt state
2023-04-14 08:24:57 +00:00
def get_prompt(click_state, click_input):
    """Append newly clicked points to ``click_state`` and build a SAM prompt.

    Args:
        click_state: Two-item list ``[points, labels]``; both inner lists are
            extended in place.
        click_input: JSON text holding a list of ``[x, y, label]`` entries.

    Returns:
        dict carrying the prompt type, the accumulated points, the
        accumulated labels and the multimask flag.
    """
    clicks = json.loads(click_input)
    point_history = click_state[0]
    label_history = click_state[1]

    for click in clicks:
        point_history.append(click[:2])
        label_history.append(click[2])

    # Store the lists back so click_state refers to the extended sequences.
    click_state[0], click_state[1] = point_history, label_history

    return {
        " prompt_type ": [" click "],
        " input_point ": click_state[0],
        " input_label ": click_state[1],
        " multimask_output ": " True ",
    }
2023-04-18 04:01:14 +00:00
# extract frames from upload video
2023-04-17 21:13:45 +00:00
def get_frames_from_video(video_input, video_state):
    """Decode every frame of the uploaded video and build a fresh video state.

    Args:
        video_input: Path of the video file to read.
        video_state: Previous state; it is not read, a new dict replaces it.

    Returns:
        A 13-item tuple: the new ``video_state`` dict, a summary string, two
        slider updates (made visible, ``maximum`` set to the frame count,
        ``value`` reset to 1) and nine updates that only set ``visible=True``.
    """
    video_path = video_input
    frames = []
    # Same value the initial video_state in this file uses; keeps `fps` bound
    # when the capture fails before the frame rate could be read.
    fps = 30
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Frames are converted from BGR to RGB before being stored.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
        print(" read_frame_source: {} error. {} \n ".format(video_path, str(e)))
    finally:
        # Always give the capture handle back, whether or not reading failed.
        if cap is not None:
            cap.release()

    # initialize video_state
    video_state = {
        " video_name ": os.path.split(video_path)[-1],
        " origin_images ": frames,
        " painted_images ": frames.copy(),
        " masks ": [None] * len(frames),
        " logits ": [None] * len(frames),
        " select_frame_number ": 0,
        " fps ": fps,
    }
    video_info = " Video Name: {} , FPS: {} , Total Frames: {} ".format(
        video_state[" video_name "], video_state[" fps "], len(frames)
    )
    return (
        video_state,
        video_info,
        gr.update(visible=True, maximum=len(frames), value=1),
        gr.update(visible=True, maximum=len(frames), value=1),
        # One independent update per remaining output component.
        *[gr.update(visible=True) for _ in range(9)],
    )
2023-04-17 21:13:45 +00:00
2023-04-18 04:01:14 +00:00
# get the select frame from gradio slider
2023-04-19 11:34:14 +00:00
def select_template(image_selection_slider, video_state, interactive_state):
    """Make the frame chosen on the slider the current template frame.

    Args:
        image_selection_slider: Slider value; one greater than the list index
            of the frame it selects.
        video_state: Video state dict; its selected frame number is updated.
        interactive_state: Returned unchanged.

    Returns:
        Tuple of the painted image for the selected frame, ``video_state``
        and ``interactive_state``.
    """
    # The slider starts at 1 while the frame lists are indexed from 0.
    frame_index = image_selection_slider - 1
    video_state[" select_frame_number "] = frame_index

    # A new template frame was picked: reset SAM and give it that frame.
    sam = model.samcontroler.sam_controler
    sam.reset_image()
    sam.set_image(video_state[" origin_images "][frame_index])

    selected_painted = video_state[" painted_images "][frame_index]
    return selected_painted, video_state, interactive_state
def get_end_number(track_pause_number_slider, interactive_state):
    """Store the end-frame slider value in the interactive state.

    Args:
        track_pause_number_slider: Value of the "Track end frames" slider.
        interactive_state: State dict that receives the value.

    Returns:
        The same ``interactive_state`` object, updated in place.
    """
    end_frame = track_pause_number_slider
    interactive_state[" track_end_number "] = end_frame
    return interactive_state
2023-04-14 13:26:26 +00:00
2023-04-18 04:01:14 +00:00
# use sam to get the mask
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt: gr.SelectData):
    """Add the clicked point to the prompt and recompute the frame's mask.

    Args:
        video_state: Video state dict; mask, logit and painted image of the
            selected frame are overwritten.
        point_prompt: Value of the positive/negative radio button.
        click_state: ``[[points], [labels]]`` accumulated so far.
        interactive_state: State dict whose click counters are incremented.
        evt: Select event; ``evt.index`` supplies the clicked coordinates.

    Returns:
        Tuple of the painted image, ``video_state`` and ``interactive_state``.
    """
    x, y = evt.index[0], evt.index[1]
    if point_prompt != " Positive ":
        coordinate = " [[ {} , {} ,0]] ".format(x, y)
        interactive_state[" negative_click_times "] += 1
    else:
        coordinate = " [[ {} , {} ,1]] ".format(x, y)
        interactive_state[" positive_click_times "] += 1

    # prompt for sam model
    prompt = get_prompt(click_state=click_state, click_input=coordinate)

    frame_index = video_state[" select_frame_number "]
    mask, logit, painted_image = model.first_frame_click(
        image=video_state[" origin_images "][frame_index],
        points=np.array(prompt[" input_point "]),
        labels=np.array(prompt[" input_label "]),
        multimask=prompt[" multimask_output "],
    )

    # Keep the results for the selected frame.
    video_state[" masks "][frame_index] = mask
    video_state[" logits "][frame_index] = logit
    video_state[" painted_images "][frame_index] = painted_image

    return painted_image, video_state, interactive_state
2023-04-16 08:53:28 +00:00
2023-04-19 11:34:14 +00:00
def add_multi_mask(video_state, interactive_state, mask_dropdown):
    """Store the selected frame's mask as a new named entry.

    Args:
        video_state: Video state dict holding per-frame masks.
        interactive_state: State dict whose multi-mask lists are appended to.
        mask_dropdown: Current dropdown selection (not used here).

    Returns:
        Tuple of ``interactive_state`` and a dropdown update listing every
        stored mask name.
    """
    current_mask = video_state[" masks "][video_state[" select_frame_number "]]
    multi_mask = interactive_state[" multi_mask "]

    multi_mask[" masks "].append(current_mask)
    # The name carries the 1-based position of the mask just stored.
    new_name = " mask_ {} ".format(len(multi_mask[" masks "]))
    multi_mask[" mask_names "].append(new_name)

    return interactive_state, gr.update(choices=multi_mask[" mask_names "])
def remove_multi_mask(interactive_state):
    """Drop every stored mask and mask name.

    Args:
        interactive_state: State dict holding the multi-mask lists.

    Returns:
        The same ``interactive_state`` object with both lists emptied.
    """
    stored = interactive_state[" multi_mask "]
    # Bind fresh lists rather than clearing the old ones in place.
    stored[" mask_names "], stored[" masks "] = [], []
    return interactive_state
def show_mask(video_state, interactive_state, mask_dropdown):
    """Paint the masks chosen in the dropdown onto the selected frame.

    Args:
        video_state: Video state dict supplying the original frames.
        interactive_state: State dict holding the stored masks.
        mask_dropdown: Selected mask names; sorted in place.

    Returns:
        The selected original frame with each chosen mask painted over it,
        or the untouched frame when nothing is chosen.
    """
    mask_dropdown.sort()
    preview = video_state[" origin_images "][video_state[" select_frame_number "]]

    for mask_name in mask_dropdown:
        # The number embedded in the name is 1-based; the list index is not.
        mask_number = int(mask_name.split(" _ ")[1]) - 1
        mask = interactive_state[" multi_mask "][" masks "][mask_number]
        preview = mask_painter(preview, mask.astype(' uint8 '), mask_color=mask_number + 2)

    return preview
2023-04-18 04:01:14 +00:00
# tracking vos
2023-04-19 11:34:14 +00:00
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
    """Track the template mask through the video and render the result.

    Args:
        video_state: Video state dict; masks, logits and painted images of
            the tracked frame range are overwritten.
        interactive_state: State dict supplying the stored masks, the optional
            end frame and the mask-save flag; its inference counter is
            incremented.
        mask_dropdown: Names of the stored masks to track; sorted in place
            when used.

    Returns:
        Tuple of the rendered video path, ``video_state`` and
        ``interactive_state``.
    """
    model.xmem.clear_memory()

    start = video_state[" select_frame_number "]
    # The end-frame key is only written once the end slider has been used, and
    # the initial state in this file is seeded under a different key
    # (" track_end_num "), so a missing key must not raise. A missing, None
    # or 0 value means "track to the last frame".
    end = interactive_state.get(" track_end_number ") or None
    span = slice(start, end)
    following_frames = video_state[" origin_images "][span]

    stored_masks = interactive_state[" multi_mask "][" masks "]
    # Stored masks can only be combined when at least one of them is selected;
    # otherwise fall back to the mask of the selected frame.
    if stored_masks and mask_dropdown:
        mask_dropdown.sort()
        template_mask = stored_masks[int(mask_dropdown[0].split(" _ ")[1]) - 1]
        for mask_name in mask_dropdown[1:]:
            mask_number = int(mask_name.split(" _ ")[1]) - 1
            template_mask = np.clip(
                template_mask + stored_masks[mask_number] * (mask_number + 1),
                0,
                mask_number + 1,
            )
        video_state[" masks "][start] = template_mask
    else:
        template_mask = video_state[" masks "][start]

    fps = video_state[" fps "]
    masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)

    # Write the results back over the same frame range that was tracked.
    video_state[" masks "][span] = masks
    video_state[" logits "][span] = logits
    video_state[" painted_images "][span] = painted_images

    # The uploaded video's name is reused for the output file.
    video_output = generate_video_from_frames(
        video_state[" painted_images "],
        output_path=" ./result/ {} ".format(video_state[" video_name "]),
        fps=fps,
    )
    interactive_state[" inference_times "] += 1

    print(" For generating this tracking result, inference times: {} , click times: {} , positive: {} , negative: {} ".format(
        interactive_state[" inference_times "],
        interactive_state[" positive_click_times "] + interactive_state[" negative_click_times "],
        interactive_state[" positive_click_times "],
        interactive_state[" negative_click_times "]))

    # Optionally dump one .npy file per frame mask.
    if interactive_state[" mask_save "]:
        mask_dir = ' ./result/mask/ {} '.format(video_state[" video_name "].split(' . ')[0])
        os.makedirs(mask_dir, exist_ok=True)
        print(" save mask ")
        for i, mask in enumerate(video_state[" masks "]):
            np.save(os.path.join(mask_dir, ' {:05d} .npy '.format(i)), mask)

    return video_output, video_state, interactive_state
# generate video after vos inference
def generate_video_from_frames(frames, output_path, fps=30):
    """
    Generates a video from a list of frames.

    Args:
        frames (list of numpy arrays): The frames to include in the video.
        output_path (str): The path to save the generated video.
        fps (int, optional): The frame rate of the output video. Defaults to 30.

    Returns:
        str: ``output_path``, unchanged.
    """
    frames = torch.from_numpy(np.asarray(frames))
    output_dir = os.path.dirname(output_path)
    # A bare file name has no directory part, and os.makedirs("") raises.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torchvision.io.write_video(output_path, frames, fps=fps, video_codec=" libx264 ")
    return output_path
2023-04-15 19:59:58 +00:00
# check and download checkpoints if needed
# Checkpoint sources and the local folder they are stored in.
folder = " ./checkpoints "
sam_checkpoint_url = " https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth "
xmem_checkpoint_url = " https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth "

# Fetch each checkpoint if needed; both names end up holding local file paths.
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, " sam_vit_h_4b8939.pth ")
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, " XMem-s012.pth ")

# args, defined in track_anything.py; the assignments below overwrite the
# port, device and mask_save values that parse_augment returned.
args = parse_augment()
args.port = 12212
args.device = " cuda:1 "
args.mask_save = True

model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
2023-04-13 13:40:10 +00:00
with gr . Blocks ( ) as iface :
# Gradio UI: state objects first, then the layout, then the event wiring.
2023-04-14 04:10:42 +00:00
"""
state for
"""
# click_state holds [[points], [labels]], the shape get_prompt reads and extends.
click_state = gr . State ( [ [ ] , [ ] ] )
2023-04-18 04:01:14 +00:00
# NOTE(review): the end-frame entry is seeded below as " track_end_num ", but
# get_end_number writes and vos_tracking_video reads " track_end_number ".
interactive_state = gr . State ( {
" inference_times " : 0 ,
" negative_click_times " : 0 ,
" positive_click_times " : 0 ,
2023-04-19 11:34:14 +00:00
" mask_save " : args . mask_save ,
" multi_mask " : {
" mask_names " : [ ] ,
" masks " : [ ]
} ,
" track_end_num " : None
}
)
2023-04-17 21:13:45 +00:00
# Placeholder video state; get_frames_from_video returns a replacement dict.
video_state = gr . State (
{
" video_name " : " " ,
" origin_images " : None ,
" painted_images " : None ,
" masks " : None ,
" logits " : None ,
" select_frame_number " : 0 ,
" fps " : 30
}
)
2023-04-13 13:40:10 +00:00
# Layout. Apart from the video input, the info box and the "Get video info"
# button, every widget is created with visible=False; get_frames_from_video
# returns the updates that make them visible.
with gr . Row ( ) :
2023-04-14 04:10:42 +00:00
# for user video input
2023-04-19 11:34:14 +00:00
with gr . Column ( ) :
with gr . Row ( scale = 0.4 ) :
video_input = gr . Video ( autosize = True )
video_info = gr . Textbox ( )
2023-04-13 18:00:48 +00:00
2023-04-13 13:40:10 +00:00
2023-04-14 04:10:42 +00:00
with gr . Row ( scale = 1 ) :
# put the template frame under the radio button
2023-04-19 11:34:14 +00:00
with gr . Column ( scale = 0.4 ) :
2023-04-18 04:01:14 +00:00
# extract frames
with gr . Column ( ) :
extract_frames_button = gr . Button ( value = " Get video info " , interactive = True , variant = " primary " )
2023-04-14 04:10:42 +00:00
# click point settings: negative or positive, mode continuous or single
with gr . Row ( ) :
2023-04-19 11:34:14 +00:00
with gr . Row ( scale = 0.4 ) :
2023-04-14 04:10:42 +00:00
point_prompt = gr . Radio (
choices = [ " Positive " , " Negative " ] ,
value = " Positive " ,
label = " Point Prompt " ,
2023-04-19 11:34:14 +00:00
interactive = True ,
visible = False )
2023-04-14 04:10:42 +00:00
# NOTE(review): click_mode is not passed as an input to any handler in this block.
click_mode = gr . Radio (
choices = [ " Continuous " , " Single " ] ,
value = " Continuous " ,
label = " Clicking Mode " ,
2023-04-19 11:34:14 +00:00
interactive = True ,
visible = False )
2023-04-14 04:10:42 +00:00
with gr . Row ( scale = 0.5 ) :
2023-04-19 11:34:14 +00:00
clear_button_click = gr . Button ( value = " Clear Clicks " , interactive = True , visible = False ) . style ( height = 160 )
Add_mask_button = gr . Button ( value = " Add mask " , interactive = True , visible = False )
template_frame = gr . Image ( type = " pil " , interactive = True , elem_id = " template_frame " , visible = False ) . style ( height = 360 )
# Both sliders start at 1; get_frames_from_video sets their maximum to the frame count.
image_selection_slider = gr . Slider ( minimum = 1 , maximum = 100 , step = 1 , value = 1 , label = " Image Selection " , visible = False )
track_pause_number_slider = gr . Slider ( minimum = 1 , maximum = 100 , step = 1 , value = 1 , label = " Track end frames " , visible = False )
2023-04-14 04:10:42 +00:00
2023-04-19 11:34:14 +00:00
with gr . Column ( scale = 0.4 ) :
mask_dropdown = gr . Dropdown ( multiselect = True , label = " Mask_select " , info = " . " , visible = False )
remove_mask_button = gr . Button ( value = " Remove mask " , interactive = True , visible = False )
video_output = gr . Video ( autosize = True , visible = False ) . style ( height = 360 )
tracking_video_predict_button = gr . Button ( value = " Tracking " , visible = False )
2023-04-13 18:00:48 +00:00
2023-04-17 21:13:45 +00:00
# first step: get the video information
# The 13 outputs line up one-to-one with the tuple get_frames_from_video returns.
extract_frames_button . click (
2023-04-13 18:00:48 +00:00
fn = get_frames_from_video ,
inputs = [
2023-04-17 21:13:45 +00:00
video_input , video_state
2023-04-13 18:00:48 +00:00
] ,
2023-04-19 11:34:14 +00:00
outputs = [ video_state , video_info , image_selection_slider , track_pause_number_slider , point_prompt , click_mode , clear_button_click , Add_mask_button , template_frame ,
tracking_video_predict_button , video_output , mask_dropdown , remove_mask_button ]
2023-04-15 19:59:58 +00:00
)
2023-04-13 13:40:10 +00:00
2023-04-17 21:13:45 +00:00
# second step: select images from slider
# Both sliders are wired to their release event.
image_selection_slider . release ( fn = select_template ,
2023-04-19 11:34:14 +00:00
inputs = [ image_selection_slider , video_state , interactive_state ] ,
outputs = [ template_frame , video_state , interactive_state ] , api_name = " select_image " )
track_pause_number_slider . release ( fn = get_end_number ,
inputs = [ track_pause_number_slider , interactive_state ] ,
outputs = [ interactive_state ] , api_name = " end_image " )
2023-04-17 21:13:45 +00:00
2023-04-19 11:34:14 +00:00
# click select image to get mask using sam
2023-04-17 21:13:45 +00:00
template_frame . select (
fn = sam_refine ,
2023-04-18 04:01:14 +00:00
inputs = [ video_state , point_prompt , click_state , interactive_state ] ,
outputs = [ template_frame , video_state , interactive_state ]
2023-04-17 21:13:45 +00:00
)
2023-04-19 11:34:14 +00:00
# add different mask
Add_mask_button . click (
fn = add_multi_mask ,
inputs = [ video_state , interactive_state , mask_dropdown ] ,
outputs = [ interactive_state , mask_dropdown ]
)
# NOTE(review): mask_dropdown is not among the outputs here, so this handler
# cannot update the dropdown's choices after the stored masks are removed.
remove_mask_button . click (
fn = remove_multi_mask ,
inputs = [ interactive_state ] ,
outputs = [ interactive_state ]
)
# tracking video from select image and mask
2023-04-14 13:26:26 +00:00
tracking_video_predict_button . click (
2023-04-15 19:59:58 +00:00
fn = vos_tracking_video ,
2023-04-19 11:34:14 +00:00
inputs = [ video_state , interactive_state , mask_dropdown ] ,
2023-04-18 04:01:14 +00:00
outputs = [ video_output , video_state , interactive_state ]
2023-04-16 08:53:28 +00:00
)
2023-04-17 21:13:45 +00:00
2023-04-19 11:34:14 +00:00
# click to get mask
mask_dropdown . change (
fn = show_mask ,
inputs = [ video_state , interactive_state , mask_dropdown ] ,
outputs = [ template_frame ]
)
2023-04-16 08:53:28 +00:00
# clear input
# Resets video_state, interactive_state and click_state when the video is cleared.
# NOTE(review): the reset video_state has no " video_name " entry, unlike the
# initial one, and " track_end_num " is 0 here but None initially.
2023-04-13 18:16:04 +00:00
video_input . clear (
2023-04-17 21:13:45 +00:00
lambda : (
{
" origin_images " : None ,
" painted_images " : None ,
" masks " : None ,
" logits " : None ,
" select_frame_number " : 0 ,
" fps " : 30
} ,
2023-04-18 04:01:14 +00:00
{
" inference_times " : 0 ,
" negative_click_times " : 0 ,
" positive_click_times " : 0 ,
2023-04-19 11:34:14 +00:00
" mask_save " : args . mask_save ,
" multi_mask " : {
" mask_names " : [ ] ,
" masks " : [ ]
2023-04-17 21:13:45 +00:00
} ,
2023-04-19 11:34:14 +00:00
" track_end_num " : 0
2023-04-18 04:01:14 +00:00
} ,
2023-04-17 21:13:45 +00:00
[ [ ] , [ ] ]
2023-04-19 11:34:14 +00:00
) ,
2023-04-17 21:13:45 +00:00
[ ] ,
[
video_state ,
2023-04-18 04:01:14 +00:00
interactive_state ,
2023-04-17 21:13:45 +00:00
click_state ,
] ,
queue = False ,
2023-04-19 11:34:14 +00:00
show_progress = False )
2023-04-17 21:13:45 +00:00
2023-04-19 11:34:14 +00:00
# points clear
# Only click_state is reset; masks already stored in video_state are left as they are.
clear_button_click . click (
2023-04-16 08:53:28 +00:00
lambda : ( [ [ ] , [ ] ] ) ,
[ ] ,
[ click_state ] ,
queue = False ,
show_progress = False
)
2023-04-13 13:40:10 +00:00
# Enable the request queue with a concurrency count of 1, then start the
# server on the configured port.
iface.queue(concurrency_count=1)
iface.launch(
    server_name=" 0.0.0.0 ",
    server_port=args.port,
    enable_queue=True,
    debug=True,
)
2023-04-13 13:40:10 +00:00