get video from frames -li

This commit is contained in:
memoryunreal
2023-04-14 11:27:13 +00:00
parent 2c9b0e58a4
commit 9df1500007

app.py (58 lines changed)

@@ -42,7 +42,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
 # args, defined in track_anything.py
 args = parse_augment()
-args.port = 12212
+args.port = 12213
 model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
@@ -102,17 +102,21 @@ def get_frames_from_video(video_input, play_state):
print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
for index, frame in enumerate(frames):
frames[index] = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
frames[index] = np.asarray(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
key_frame_index = int(timestamp * fps)
nearest_frame = frames[key_frame_index]
frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
# set image in sam when select the template frame
model.samcontroler.sam_controler.set_image(np.asarray(nearest_frame))
return frames, nearest_frame
model.samcontroler.sam_controler.set_image(nearest_frame)
return frames, nearest_frame, nearest_frame
def inference_all(template_frame, point_prompt, click_state, logit, evt:gr.SelectData):
# def get_video_from_frames():
# return video_output
def inference_all(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
"""
Args:
template_frame: PIL.Image
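
Note on the change above: round-tripping each frame through PIL and back with np.asarray leaves every entry of frames, and therefore the nearest_frame handed to set_image, as a uint8 RGB numpy array instead of a PIL image. A minimal sketch of the equivalent transform, assuming standard BGR frames from OpenCV (to_rgb_array is illustrative):

import cv2
import numpy as np
from PIL import Image

def to_rgb_array(frame: np.ndarray) -> np.ndarray:
    # OpenCV decodes video frames in BGR channel order; convert to RGB.
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # The PIL round-trip used in the diff returns the same uint8 HxWx3
    # array, so returning rgb directly would be equivalent.
    return np.asarray(Image.fromarray(rgb))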
@@ -130,13 +134,34 @@ def inference_all(template_frame, point_prompt, click_state, logit, evt:gr.Selec
     # default value
     # points = np.array([[evt.index[0],evt.index[1]]])
     # labels= np.array([1])
+    if len(logit)==0:
+        logit = None
     mask, logit, painted_image = model.first_frame_click(
-        image=np.asarray(template_frame),
-        points=np.array(prompt["input_point"]),
+        image=origin_frame,
+        points=np.array(prompt["input_point"]),
         labels=np.array(prompt["input_label"]),
-        multimask=prompt["multimask_output"]
+        logits=logit,
+        multimask=prompt["multimask_output"],
     )
-    return painted_image, click_state, logit
+    return painted_image, click_state, logit, mask
+
+# upload file
+# def upload_callback(image_input, state):
+#     state = [] + [('Image size: ' + str(image_input.size), None)]
+#     click_state = [[], [], []]
+#     res = 1024
+#     width, height = image_input.size
+#     ratio = min(1.0 * res / max(width, height), 1.0)
+#     if ratio < 1.0:
+#         image_input = image_input.resize((int(width * ratio), int(height * ratio)))
+#         print('Scaling input image to {}'.format(image_input.size))
+#     model.segmenter.image = None
+#     model.segmenter.image_embedding = None
+#     model.segmenter.set_image(image_input)
+#     return state, state, image_input, click_state, image_input
 
 with gr.Blocks() as iface:
     """
@@ -147,6 +172,9 @@ with gr.Blocks() as iface:
     video_state = gr.State([[],[],[]])
     click_state = gr.State([[],[]])
     logits = gr.State([])
+    origin_image = gr.State(None)
+    template_mask = gr.State(None)
 
     with gr.Row():
         # for user video input
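
The two new gr.State slots hold per-session values between events: origin_image keeps the raw selected frame for inference_all, and template_mask keeps the mask produced by the first click for later tracking. gr.State is standard Gradio; a self-contained sketch of the pattern with an illustrative counter:

import gradio as gr

with gr.Blocks() as demo:
    counter = gr.State(0)  # per-session, not shared across users
    button = gr.Button("Increment")
    label = gr.Textbox()

    def bump(n):
        # Whatever a callback returns into a State output is passed
        # back as the matching input on the next event.
        return n + 1, "clicked {} times".format(n + 1)

    button.click(fn=bump, inputs=[counter], outputs=[counter, label])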
@@ -188,6 +216,7 @@ with gr.Blocks() as iface:
             # for intermedia result check and correction
             intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
+            tracking_video_predict = gr.Button(value="Tracking")
 
             # seg_automask_video_points_per_batch = gr.Slider(
             #     minimum=0,
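
This hunk adds the Tracking button but not its handler; the wiring would presumably follow the same click pattern as the other buttons. A sketch with a hypothetical track_all callback and a hypothetical output component (neither appears in this commit):

tracking_video_predict.click(
    fn=track_all,                         # hypothetical handler
    inputs=[video_state, template_mask],  # frames plus the mask from the first click
    outputs=[video_output],               # hypothetical video component
)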
@@ -197,7 +226,7 @@ with gr.Blocks() as iface:
             #     label="Points per Batch",
             # )
             seg_automask_video_predict = gr.Button(value="Generator")
 
     # Display the first frame
@@ -226,19 +255,18 @@ with gr.Blocks() as iface:
         fn=get_frames_from_video,
         inputs=[
             video_input,
-            play_state,
-            logits
+            play_state
         ],
-        outputs=[video_state, template_frame],
+        outputs=[video_state, template_frame, origin_image],
     )
 
     template_frame.select(
         fn=inference_all,
         inputs=[
-            template_frame, point_prompt, click_state, logits
+            origin_image, point_prompt, click_state, logits
         ],
         outputs=[
-            template_frame, click_state
+            template_frame, click_state, logits, template_mask
         ]
     )
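
Note that inference_all never receives the click coordinates through inputs: they arrive via the evt: gr.SelectData argument, whose evt.index holds the (x, y) pixel clicked on the image. A self-contained sketch of that mechanism (the .select/SelectData API is standard Gradio; the echo callback is illustrative):

import gradio as gr

with gr.Blocks() as demo:
    img = gr.Image()
    coords = gr.Textbox()

    def on_select(evt: gr.SelectData):
        # For an image, evt.index is the [x, y] position of the click.
        return "clicked at {}".format(evt.index)

    img.select(fn=on_select, inputs=None, outputs=[coords])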