Mirror of https://github.com/gaomingqi/Track-Anything.git (synced 2025-12-14 15:37:50 +01:00)
Failed version using a queue with concurrent.futures
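
The diff below hands results from a worker thread to the UI through queues: a producer pushes per-frame tracking output into queue.Queue objects while a consumer drains them. A minimal, self-contained sketch of that pattern, with illustrative names rather than the repository's code, and a single sentinel instead of the commit's separate done_queue so the completion signal cannot race ahead of the last result:

import concurrent.futures
import queue
import time

SENTINEL = object()  # pushed onto the queue to signal "no more results"

def producer(frames, results):
    # Stand-in for per-frame tracking work (the role model.xmem.track plays in the app).
    for frame in frames:
        results.put(f"painted_{frame}")
        time.sleep(0.1)  # simulate per-frame compute time
    results.put(SENTINEL)

def consumer(results):
    # Drain results as they arrive; stop at the sentinel.
    while True:
        item = results.get()
        if item is SENTINEL:
            return
        yield item

results = queue.Queue()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    executor.submit(producer, range(5), results)
    for painted in consumer(results):
        print(painted)  # the app would hand each frame to the UI here instead
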
app.py (169 changed lines)

@@ -15,6 +15,8 @@ import requests
 import json
 import torchvision
 import torch
+import concurrent.futures
+import queue

 def download_checkpoint(url, folder, filename):
     os.makedirs(folder, exist_ok=True)
@@ -32,25 +34,6 @@ def download_checkpoint(url, folder, filename):
     return filepath


-# check and download checkpoints if needed
-SAM_checkpoint = "sam_vit_h_4b8939.pth"
-sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
-xmem_checkpoint = "XMem-s012.pth"
-xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
-folder ="./checkpoints"
-SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
-xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
-
-# args, defined in track_anything.py
-args = parse_augment()
-args.port = 12213
-model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
-
-
-
-
-
-
 def pause_video(play_state):
     print("user pause_video")
     play_state.append(time.time())
@@ -138,11 +121,9 @@ def generate_video_from_frames(frames, output_path, fps=30):
     torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
     return output_path

-# def get_video_from_frames():

-# return video_output

-def inference_all(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
+def sam_refine(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
     """
     Args:
         template_frame: PIL.Image
@@ -170,29 +151,74 @@ def inference_all(origin_frame, point_prompt, click_state, logit, evt:gr.SelectD
         multimask=prompt["multimask_output"],

         )
-    return painted_image, click_state, logit, mask
+    yield painted_image, click_state, logit, mask

-def vos_tracking(video_state, template_mask):
+def vos_tracking_video(video_state, template_mask):

     masks, logits, painted_images = model.generator(images=video_state[1], mask=template_mask)
     video_output = generate_video_from_frames(painted_images, output_path="./output.mp4")
     return video_output

-# upload file
-# def upload_callback(image_input, state):
-#     state = [] + [('Image size: ' + str(image_input.size), None)]
-#     click_state = [[], [], []]
-#     res = 1024
-#     width, height = image_input.size
-#     ratio = min(1.0 * res / max(width, height), 1.0)
-#     if ratio < 1.0:
-#         image_input = image_input.resize((int(width * ratio), int(height * ratio)))
-#         print('Scaling input image to {}'.format(image_input.size))
-#     model.segmenter.image = None
-#     model.segmenter.image_embedding = None
-#     model.segmenter.set_image(image_input)
-#     return state, state, image_input, click_state, image_input
+def vos_tracking_image(video_state, template_mask, result_queue, done_queue):
+    images = video_state[1]
+    images = images[:5]
+    for i in range(len(images)):
+        if i ==0:
+            mask, logit, painted_image = model.xmem.track(images[i], template_mask)
+            result_queue['images'].put(images[i])
+            result_queue['masks'].put(mask)
+            result_queue['logits'].put(logit)
+            result_queue['painted'].put(painted_image)
+
+        else:
+            mask, logit, painted_image = model.xmem.track(images[i])
+            result_queue['images'].put(images[i])
+            result_queue['masks'].put(mask)
+            result_queue['logits'].put(logit)
+            result_queue['painted'].put(painted_image)
+        done_queue.put(False)
+        time.sleep(1)
+    done_queue.put(True)
+
+def update_gradio_image(result_queue, done_queue):
+    print("update_gradio_image")
+    while True:
+        if not done_queue.empty():
+            if done_queue.get():
+                break
+        if not result_queue.empty():
+            image = result_queue['images'].get()
+            mask = result_queue['masks'].get()
+            logit = result_queue['logits'].get()
+            painted_image = result_queue['painted'].get()
+            yield painted_image
+
+def parallel_tracking(video_state, template_mask):
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        executor.submit(vos_tracking_image, video_state, template_mask, result_queue, done_queue)
+        executor.submit(update_gradio_image, result_queue, done_queue)
+
+
+
+
+# check and download checkpoints if needed
+SAM_checkpoint = "sam_vit_h_4b8939.pth"
+sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+xmem_checkpoint = "XMem-s012.pth"
+xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
+folder ="./checkpoints"
+SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
+xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
+
+# args, defined in track_anything.py
+args = parse_augment()
+args.port = 12214
+model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
+result_queue = {"images": queue.Queue(),
+                "masks": queue.Queue(),
+                "logits": queue.Queue(),
+                "painted": queue.Queue()}
+done_queue = queue.Queue()


 with gr.Blocks() as iface:
     """
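
A plausible reason this version fails, judging only from the hunk above: update_gradio_image contains yield, so executor.submit(update_gradio_image, ...) merely creates a generator object that is never iterated, and parallel_tracking itself returns None, leaving Gradio with nothing to show in image_output (and even if it were iterated, result_queue is a dict, which has no .empty() method). One way to restructure it, sketched under the assumption that the handler itself should be the generator Gradio iterates; this is not the repository's eventual fix:

def parallel_tracking(video_state, template_mask):
    # The producer still runs on a worker thread and fills result_queue/done_queue
    # as in the diff; this handler is now itself a generator, so Gradio iterates it
    # and every yielded frame can refresh image_output.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(vos_tracking_image, video_state, template_mask,
                                 result_queue, done_queue)
        # Keep draining painted frames until the producer has finished and the
        # queue is empty; done_queue is no longer needed for termination.
        while not (future.done() and result_queue['painted'].empty()):
            try:
                yield result_queue['painted'].get(timeout=0.5)
            except queue.Empty:
                continue

The other entries in result_queue ('images', 'masks', 'logits') still fill up in this sketch; a real version would either drain them as well or stop producing them.
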
@@ -205,6 +231,9 @@ with gr.Blocks() as iface:
     logits = gr.State([])
     origin_image = gr.State(None)
     template_mask = gr.State(None)
+    # queue value for image refresh, origin image, mask, logits, painted image
+
+

     with gr.Row():

@@ -248,41 +277,21 @@ with gr.Blocks() as iface:
             # for intermedia result check and correction
             # intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
             video_output = gr.Video().style(height=360)
-            tracking_video_predict_button = gr.Button(value="Tracking")
+            tracking_video_predict_button = gr.Button(value="Video")

-            # seg_automask_video_points_per_batch = gr.Slider(
-            #     minimum=0,
-            #     maximum=64,
-            #     step=2,
-            #     value=64,
-            #     label="Points per Batch",
-            # )
+            image_output = gr.Image(type="pil", interactive=True, elem_id="image_output").style(height=360)
+            tracking_image_predict_button = gr.Button(value="Tracking")

+    template_frame.select(
+        fn=sam_refine,
+        inputs=[
+            origin_image, point_prompt, click_state, logits
+        ],
+        outputs=[
+            template_frame, click_state, logits, template_mask
+        ]
+    )

-
-
-
-            # Display the first frame
-            # with gr.Column():
-            #     first_frame = gr.Image(type="pil", interactive=True, elem_id="first_frame")
-            #     seg_automask_firstframe = gr.Button(value="Find target")
-
-            # video_input = gr.inputs.Video(type="mp4")
-
-            # output = gr.outputs.Image(type="pil")
-
-            # gr.Interface(fn=capture_frame, inputs=seg_automask_video_file, outputs=first_frame)
-
-            # seg_automask_video_predict.click(
-            #     fn=automask_video_app,
-            #     inputs=[
-            #         seg_automask_video_file,
-            #         seg_automask_video_model_type,
-            #         seg_automask_video_points_per_side,
-            #         seg_automask_video_points_per_batch,
-            #         seg_automask_video_min_area,
-            #     ],
-            #     outputs=[output_video],
-            # )
     template_select_button.click(
         fn=get_frames_from_video,
         inputs=[
@@ -290,23 +299,19 @@ with gr.Blocks() as iface:
             play_state
         ],
         outputs=[video_state, template_frame, origin_image],
     )
-    )
-
-    template_frame.select(
-        fn=inference_all,
-        inputs=[
-            origin_image, point_prompt, click_state, logits
-        ],
-        outputs=[
-            template_frame, click_state, logits, template_mask
-        ]
-
-    )
+
     tracking_video_predict_button.click(
-        fn=vos_tracking,
+        fn=vos_tracking_video,
         inputs=[video_state, template_mask],
         outputs=[video_output]
     )
+    tracking_image_predict_button.click(
+        fn=parallel_tracking,
+        inputs=[video_state, template_mask],
+        outputs=[image_output]
+    )
+
     # clear
     # clear_button_clike.click(
     #     lambda x: ([[], [], []], x, ""),
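
On the Gradio side of the wiring above, a click handler only streams intermediate results if the bound function is a generator (and, on older Gradio versions, the app's queue is enabled); a plain return updates the output once. A small, self-contained illustration with hypothetical component names, unrelated to the app:

import time
import gradio as gr

def count_up(n):
    # A generator handler: each yield streams a new value to the bound output
    # instead of waiting for a single return at the end.
    for i in range(int(n)):
        time.sleep(0.5)
        yield f"step {i + 1} of {int(n)}"

with gr.Blocks() as demo:
    steps = gr.Number(value=5, label="steps")
    progress = gr.Textbox(label="progress")
    run = gr.Button("Run")
    run.click(fn=count_up, inputs=[steps], outputs=[progress])

demo.queue()   # older Gradio versions need the queue for generator handlers
demo.launch()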