mirror of https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 08:27:49 +01:00
get video from frames -li
This commit is contained in:
app.py (56 changed lines)
@@ -42,7 +42,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
 
 # args, defined in track_anything.py
 args = parse_augment()
-args.port = 12212
+args.port = 12213
 model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
 
 
@@ -102,17 +102,21 @@ def get_frames_from_video(video_input, play_state):
         print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
 
     for index, frame in enumerate(frames):
-        frames[index] = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+        frames[index] = np.asarray(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
 
     key_frame_index = int(timestamp * fps)
     nearest_frame = frames[key_frame_index]
     frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
 
     # set image in sam when select the template frame
-    model.samcontroler.sam_controler.set_image(np.asarray(nearest_frame))
-    return frames, nearest_frame
+    model.samcontroler.sam_controler.set_image(nearest_frame)
+    return frames, nearest_frame, nearest_frame
 
-def inference_all(template_frame, point_prompt, click_state, logit, evt:gr.SelectData):
+# def get_video_from_frames():
+
+#     return video_output
+
+def inference_all(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
     """
     Args:
         template_frame: PIL.Image
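The hunk above switches the per-frame representation from PIL.Image to numpy arrays and returns the key frame twice (once for display, once kept as the untouched origin image). A minimal sketch of the resulting decode-and-select pipeline, assuming cv2/numpy/PIL as in the file; the helper name and the standalone VideoCapture loop are illustrative, not the repo's exact code:

import cv2
import numpy as np
from PIL import Image

def frames_as_rgb_arrays(video_path, timestamp):
    # Decode every frame; convert OpenCV's BGR to RGB, stored as uint8 ndarrays
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(np.asarray(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))))
    cap.release()
    # Key frame nearest to the pause timestamp (seconds * frames per second)
    key_frame_index = int(timestamp * fps)
    nearest_frame = frames[key_frame_index]
    return frames, nearest_frame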
@@ -130,13 +134,34 @@ def inference_all(template_frame, point_prompt, click_state, logit, evt:gr.Selec
     # default value
     # points = np.array([[evt.index[0],evt.index[1]]])
     # labels= np.array([1])
     if len(logit)==0:
         logit = None
 
     mask, logit, painted_image = model.first_frame_click(
-        image=np.asarray(template_frame),
+        image=origin_frame,
         points=np.array(prompt["input_point"]),
         labels=np.array(prompt["input_label"]),
-        multimask=prompt["multimask_output"]
+        logits=logit,
+        multimask=prompt["multimask_output"],
 
     )
-    return painted_image, click_state, logit
+    return painted_image, click_state, logit, mask
 
+# upload file
+# def upload_callback(image_input, state):
+#     state = [] + [('Image size: ' + str(image_input.size), None)]
+#     click_state = [[], [], []]
+#     res = 1024
+#     width, height = image_input.size
+#     ratio = min(1.0 * res / max(width, height), 1.0)
+#     if ratio < 1.0:
+#         image_input = image_input.resize((int(width * ratio), int(height * ratio)))
+#         print('Scaling input image to {}'.format(image_input.size))
+#     model.segmenter.image = None
+#     model.segmenter.image_embedding = None
+#     model.segmenter.set_image(image_input)
+#     return state, state, image_input, click_state, image_input
+
 
 with gr.Blocks() as iface:
     """
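The first_frame_click change threads the previous low-resolution logits back into the segmenter, which is how SAM supports iterative point-click refinement. A sketch of that pattern using segment_anything's SamPredictor directly; first_frame_click is this repo's own wrapper and its internals may differ, and the zero-filled frame stands in for the selected template frame:

import numpy as np
from segment_anything import SamPredictor, sam_model_registry

origin_frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for the template frame

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)
predictor.set_image(origin_frame)

# First click: no prior mask available yet
masks, scores, logits = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
best = int(np.argmax(scores))

# Second click: feed the previous low-res logits back as mask_input to refine
masks, scores, logits = predictor.predict(
    point_coords=np.array([[320, 240], [400, 260]]),
    point_labels=np.array([1, 1]),
    mask_input=logits[best][None, :, :],
    multimask_output=False,
)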
@@ -147,6 +172,9 @@ with gr.Blocks() as iface:
     video_state = gr.State([[],[],[]])
     click_state = gr.State([[],[]])
     logits = gr.State([])
+    origin_image = gr.State(None)
+    template_mask = gr.State(None)
+
     with gr.Row():
 
         # for user video input
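origin_image and template_mask follow Gradio's gr.State pattern: per-session values that event handlers can read and write without rendering them. A minimal sketch of that pattern, assuming gradio and numpy; the segment handler below is illustrative, not the app's actual callback:

import gradio as gr
import numpy as np

with gr.Blocks() as demo:
    origin_image = gr.State(None)   # untouched RGB frame kept for the segmenter
    template_mask = gr.State(None)  # latest mask produced by a click handler
    frame_view = gr.Image()
    segment_btn = gr.Button("Segment")

    def segment(frame, mask):
        # Illustrative handler: would run SAM on `frame` and paint the result
        new_mask = None if frame is None else np.zeros(frame.shape[:2], dtype=bool)
        return frame, new_mask

    segment_btn.click(segment,
                      inputs=[origin_image, template_mask],
                      outputs=[frame_view, template_mask])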
@@ -188,6 +216,7 @@ with gr.Blocks() as iface:
 
                 # for intermedia result check and correction
                 intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
+                tracking_video_predict = gr.Button(value="Tracking")
 
                 # seg_automask_video_points_per_batch = gr.Slider(
                 #     minimum=0,
@@ -197,7 +226,7 @@ with gr.Blocks() as iface:
                 #     label="Points per Batch",
                 # )
 
                 seg_automask_video_predict = gr.Button(value="Generator")
 
 
                 # Display the first frame
@@ -226,19 +255,18 @@ with gr.Blocks() as iface:
             fn=get_frames_from_video,
             inputs=[
                 video_input,
-                play_state,
-                logits
+                play_state
             ],
-            outputs=[video_state, template_frame],
+            outputs=[video_state, template_frame, origin_image],
         )
 
         template_frame.select(
             fn=inference_all,
             inputs=[
-                template_frame, point_prompt, click_state, logits
+                origin_image, point_prompt, click_state, logits
             ],
             outputs=[
-                template_frame, click_state
+                template_frame, click_state, logits, template_mask
             ]
 
         )
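The rewired select event now feeds the clean origin_image state into inference_all instead of the painted template_frame, so repeated clicks segment the original pixels rather than earlier overlays. A hedged sketch of that wiring, assuming gradio's SelectData event; the handler body is illustrative:

import gradio as gr

with gr.Blocks() as demo:
    origin_image = gr.State(None)
    click_state = gr.State([[], []])
    template_frame = gr.Image(type="pil")

    def on_click(origin_frame, clicks, evt: gr.SelectData):
        clicks[0].append(evt.index)  # evt.index is the clicked (x, y) pixel
        # ... would call model.first_frame_click on origin_frame here
        return origin_frame, clicks

    template_frame.select(on_click,
                          inputs=[origin_image, click_state],
                          outputs=[template_frame, click_state])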