mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
get video from frames -li
This commit is contained in:
56
app.py
56
app.py
@@ -42,7 +42,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
|
|||||||
|
|
||||||
# args, defined in track_anything.py
|
# args, defined in track_anything.py
|
||||||
args = parse_augment()
|
args = parse_augment()
|
||||||
args.port = 12212
|
args.port = 12213
|
||||||
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
||||||
|
|
||||||
|
|
||||||
@@ -102,17 +102,21 @@ def get_frames_from_video(video_input, play_state):
|
|||||||
print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
|
print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
|
||||||
|
|
||||||
for index, frame in enumerate(frames):
|
for index, frame in enumerate(frames):
|
||||||
frames[index] = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
frames[index] = np.asarray(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
|
||||||
|
|
||||||
key_frame_index = int(timestamp * fps)
|
key_frame_index = int(timestamp * fps)
|
||||||
nearest_frame = frames[key_frame_index]
|
nearest_frame = frames[key_frame_index]
|
||||||
frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
|
frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
|
||||||
|
|
||||||
# set image in sam when select the template frame
|
# set image in sam when select the template frame
|
||||||
model.samcontroler.sam_controler.set_image(np.asarray(nearest_frame))
|
model.samcontroler.sam_controler.set_image(nearest_frame)
|
||||||
return frames, nearest_frame
|
return frames, nearest_frame, nearest_frame
|
||||||
|
|
||||||
def inference_all(template_frame, point_prompt, click_state, logit, evt:gr.SelectData):
|
# def get_video_from_frames():
|
||||||
|
|
||||||
|
# return video_output
|
||||||
|
|
||||||
|
def inference_all(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
template_frame: PIL.Image
|
template_frame: PIL.Image
|
||||||
@@ -130,13 +134,34 @@ def inference_all(template_frame, point_prompt, click_state, logit, evt:gr.Selec
|
|||||||
# default value
|
# default value
|
||||||
# points = np.array([[evt.index[0],evt.index[1]]])
|
# points = np.array([[evt.index[0],evt.index[1]]])
|
||||||
# labels= np.array([1])
|
# labels= np.array([1])
|
||||||
|
if len(logit)==0:
|
||||||
|
logit = None
|
||||||
|
|
||||||
mask, logit, painted_image = model.first_frame_click(
|
mask, logit, painted_image = model.first_frame_click(
|
||||||
image=np.asarray(template_frame),
|
image=origin_frame,
|
||||||
points=np.array(prompt["input_point"]),
|
points=np.array(prompt["input_point"]),
|
||||||
labels=np.array(prompt["input_label"]),
|
labels=np.array(prompt["input_label"]),
|
||||||
multimask=prompt["multimask_output"]
|
logits=logit,
|
||||||
|
multimask=prompt["multimask_output"],
|
||||||
|
|
||||||
)
|
)
|
||||||
return painted_image, click_state, logit
|
return painted_image, click_state, logit, mask
|
||||||
|
|
||||||
|
# upload file
|
||||||
|
# def upload_callback(image_input, state):
|
||||||
|
# state = [] + [('Image size: ' + str(image_input.size), None)]
|
||||||
|
# click_state = [[], [], []]
|
||||||
|
# res = 1024
|
||||||
|
# width, height = image_input.size
|
||||||
|
# ratio = min(1.0 * res / max(width, height), 1.0)
|
||||||
|
# if ratio < 1.0:
|
||||||
|
# image_input = image_input.resize((int(width * ratio), int(height * ratio)))
|
||||||
|
# print('Scaling input image to {}'.format(image_input.size))
|
||||||
|
# model.segmenter.image = None
|
||||||
|
# model.segmenter.image_embedding = None
|
||||||
|
# model.segmenter.set_image(image_input)
|
||||||
|
# return state, state, image_input, click_state, image_input
|
||||||
|
|
||||||
|
|
||||||
with gr.Blocks() as iface:
|
with gr.Blocks() as iface:
|
||||||
"""
|
"""
|
||||||
@@ -147,6 +172,9 @@ with gr.Blocks() as iface:
|
|||||||
video_state = gr.State([[],[],[]])
|
video_state = gr.State([[],[],[]])
|
||||||
click_state = gr.State([[],[]])
|
click_state = gr.State([[],[]])
|
||||||
logits = gr.State([])
|
logits = gr.State([])
|
||||||
|
origin_image = gr.State(None)
|
||||||
|
template_mask = gr.State(None)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|
||||||
# for user video input
|
# for user video input
|
||||||
@@ -188,6 +216,7 @@ with gr.Blocks() as iface:
|
|||||||
|
|
||||||
# for intermedia result check and correction
|
# for intermedia result check and correction
|
||||||
intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
|
intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
|
||||||
|
tracking_video_predict = gr.Button(value="Tracking")
|
||||||
|
|
||||||
# seg_automask_video_points_per_batch = gr.Slider(
|
# seg_automask_video_points_per_batch = gr.Slider(
|
||||||
# minimum=0,
|
# minimum=0,
|
||||||
@@ -197,7 +226,7 @@ with gr.Blocks() as iface:
|
|||||||
# label="Points per Batch",
|
# label="Points per Batch",
|
||||||
# )
|
# )
|
||||||
|
|
||||||
seg_automask_video_predict = gr.Button(value="Generator")
|
|
||||||
|
|
||||||
|
|
||||||
# Display the first frame
|
# Display the first frame
|
||||||
@@ -226,19 +255,18 @@ with gr.Blocks() as iface:
|
|||||||
fn=get_frames_from_video,
|
fn=get_frames_from_video,
|
||||||
inputs=[
|
inputs=[
|
||||||
video_input,
|
video_input,
|
||||||
play_state,
|
play_state
|
||||||
logits
|
|
||||||
],
|
],
|
||||||
outputs=[video_state, template_frame],
|
outputs=[video_state, template_frame, origin_image],
|
||||||
)
|
)
|
||||||
|
|
||||||
template_frame.select(
|
template_frame.select(
|
||||||
fn=inference_all,
|
fn=inference_all,
|
||||||
inputs=[
|
inputs=[
|
||||||
template_frame, point_prompt, click_state, logits
|
origin_image, point_prompt, click_state, logits
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
template_frame, click_state
|
template_frame, click_state, logits, template_mask
|
||||||
]
|
]
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user