select template frame

memoryunreal committed on 2023-04-13 18:00:48 +00:00
parent 8900150c39
commit c4d1fdf924

app.py (88 lines changed)

@@ -3,18 +3,28 @@ from demo import automask_image_app, automask_video_app, sahi_autoseg_app
 import argparse
 import cv2
 import time
-def pause_video():
-    print(time.time())
-def play_video():
-    print("play video")
-    print(time.time)
-def get_frames_from_video(video_path, timestamp):
+from PIL import Image
+import numpy as np
+
+def pause_video(play_state):
+    print("user pause_video")
+    play_state.append(time.time())
+    return play_state
+def play_video(play_state):
+    print("user play_video")
+    play_state.append(time.time())
+    return play_state
+
+def get_frames_from_video(video_input, play_state):
     """
     Args:
         video_path:str
         timestamp:float64
-    return [[0:nearest_frame-1], [nearest_frame+1], nearest_frame]
+    Return
+        [[0:nearest_frame-1], [nearest_frame+1], nearest_frame]
     """
+    video_path = video_input
+    timestamp = play_state[1] - play_state[0]
     frames = []
     try:
         cap = cv2.VideoCapture(video_path)
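The reworked handlers store wall-clock times in the shared play_state list, so get_frames_from_video can recover the paused position as play_state[1] - play_state[0]. A minimal sketch of that timing contract outside Gradio, assuming exactly one play followed by one pause (the sleep stands in for the user letting the video run):

import time

def play_video(play_state):
    # record the wall-clock time at which playback started
    play_state.append(time.time())
    return play_state

def pause_video(play_state):
    # record the wall-clock time at which playback was paused
    play_state.append(time.time())
    return play_state

play_state = play_video([])
time.sleep(1.5)                            # stands in for the user watching
play_state = pause_video(play_state)
timestamp = play_state[1] - play_state[0]  # elapsed play time, roughly 1.5 s
print("select frame near t = {:.2f}s".format(timestamp))

Note that replaying or pausing more than once appends further entries, which the [0]/[1] indexing in this commit ignores.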
@@ -28,51 +38,43 @@ def get_frames_from_video(video_path, timestamp):
     except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
         print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
+
+    for frame in frames:
+        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
     key_frame_index = int(timestamp * fps)
     nearest_frame = frames[key_frame_index]
     frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
-    return frames
+    return frames, nearest_frame
 
 with gr.Blocks() as iface:
+    state = gr.State([])
+    play_state = gr.State([])
+    video_state = gr.State([[],[],[]])
     with gr.Row():
         with gr.Column(scale=1.0):
-            seg_automask_video_file = gr.Video().style(height=720)
-            seg_automask_video_file.play(fn=play_video)
-            seg_automask_video_file.pause(fn=pause_video)
+            video_input = gr.Video().style(height=720)
+
+            # listen to the user action for play and pause input video
+            video_input.play(fn=play_video, inputs=play_state, outputs=play_state)
+            video_input.pause(fn=pause_video, inputs=play_state, outputs=play_state)
     with gr.Row():
-        with gr.Column():
-            seg_automask_video_model_type = gr.Dropdown(
-                choices=[
-                    "vit_h",
-                    "vit_l",
-                    "vit_b",
-                ],
-                value="vit_l",
-                label="Model Type",
-            )
-            seg_automask_video_min_area = gr.Number(
-                value=1000,
-                label="Min Area",
-            )
         with gr.Row():
             with gr.Column():
-                seg_automask_video_points_per_side = gr.Slider(
-                    minimum=0,
-                    maximum=32,
-                    step=2,
-                    value=16,
-                    label="Points per Side",
-                )
+                template_frame = gr.Image(type="pil", interactive=True, elem_id="template_frame")
+            with gr.Column():
+                template_select_button = gr.Button(value="Template select", interactive=True, variant="primary")
 
-                seg_automask_video_points_per_batch = gr.Slider(
-                    minimum=0,
-                    maximum=64,
-                    step=2,
-                    value=64,
-                    label="Points per Batch",
-                )
+                # seg_automask_video_points_per_batch = gr.Slider(
+                #     minimum=0,
+                #     maximum=64,
+                #     step=2,
+                #     value=64,
+                #     label="Points per Batch",
+                # )
 
                 seg_automask_video_predict = gr.Button(value="Generator")
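get_frames_from_video maps the elapsed play time to a frame index through the capture's frame rate and splits the decoded frames around it. A self-contained sketch of the same selection step, with select_template_frame as a hypothetical name and a clamp added as a safety assumption (timestamp * fps can otherwise run past the last frame); it also converts to RGB while decoding, since the "for frame in frames" loop above only rebinds the loop variable and leaves the list entries as BGR arrays:

import cv2
from PIL import Image

def select_template_frame(video_path, timestamp):
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes BGR; convert to RGB before wrapping as a PIL image
        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    cap.release()
    # clamp so a pause near the end of the clip cannot index past the last frame
    key_frame_index = min(int(timestamp * fps), len(frames) - 1)
    nearest_frame = frames[key_frame_index]
    # same [before, from-index-on, nearest] layout that video_state stores
    return [frames[:key_frame_index], frames[key_frame_index:], nearest_frame], nearest_frame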
@@ -99,6 +101,14 @@ with gr.Blocks() as iface:
     #     ],
     #     outputs=[output_video],
     # )
 
+    template_select_button.click(
+        fn=get_frames_from_video,
+        inputs=[
+            video_input,
+            play_state
+        ],
+        outputs=[video_state, template_frame],
+    )
 iface.queue(concurrency_count=1)
 iface.launch(debug=True, enable_queue=True, server_port=12212, server_name="0.0.0.0")
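The click wiring relies on gr.State to carry Python objects between callbacks without rendering them: a handler receives the current value through inputs and writes the new value back through outputs, which is how play_state, video_state, and template_frame stay in sync here. A minimal, self-contained sketch of that round-trip pattern (toy component names, not the ones in this commit):

import time
import gradio as gr

with gr.Blocks() as demo:
    clicks = gr.State([])                  # hidden, per-session Python list
    button = gr.Button(value="Record click time")
    count = gr.Number(label="Clicks so far")

    def record(clicks):
        clicks.append(time.time())
        # first output rewrites the hidden state, second updates the visible UI
        return clicks, len(clicks)

    button.click(fn=record, inputs=clicks, outputs=[clicks, count])

demo.launch()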