Failed version: using a queue with concurrent.futures

memoryunreal
2023-04-15 19:59:58 +00:00
parent d90121faad
commit 11c1ef7afb
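
Editor's note, not part of the commit: the one-line message above is the only explanation of the approach, so here is a minimal, self-contained sketch of the pattern being attempted, with illustrative names (producer, consume) and a single True sentinel instead of the commit's per-frame False flags. One worker fills a queue.Queue with painted frames while a plain generator drains it. The sketch also points at why the version below fails: parallel_tracking submits the generator function update_gradio_image to the executor, which only creates a generator object that nothing ever iterates, and parallel_tracking itself returns nothing for Gradio to display; update_gradio_image also calls .empty() on result_queue, which in the diff is a plain dict of queues, not a queue.

import concurrent.futures
import queue
import time

def producer(result_queue, done_queue):
    # Stands in for vos_tracking_image: push one painted frame per step.
    for i in range(5):
        time.sleep(0.2)                      # simulate per-frame tracking work
        result_queue.put(f"painted frame {i}")
    done_queue.put(True)                     # sentinel: producer is finished

def consume(result_queue, done_queue):
    # Plain generator that drains results until the producer is done AND
    # the queue is empty. For Gradio streaming, the event handler itself
    # must be such a generator; handing it to executor.submit() does not
    # iterate it, which is what the commit's parallel_tracking does.
    done = False
    while not (done and result_queue.empty()):
        if not done and not done_queue.empty():
            done = done_queue.get()
        try:
            yield result_queue.get(timeout=0.1)
        except queue.Empty:
            continue

result_queue, done_queue = queue.Queue(), queue.Queue()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    executor.submit(producer, result_queue, done_queue)
    for item in consume(result_queue, done_queue):
        print(item)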

app.py (169 lines changed)

@@ -15,6 +15,8 @@ import requests
 import json
 import torchvision
 import torch
+import concurrent.futures
+import queue

 def download_checkpoint(url, folder, filename):
     os.makedirs(folder, exist_ok=True)
@@ -32,25 +34,6 @@ def download_checkpoint(url, folder, filename):
     return filepath

-# check and download checkpoints if needed
-SAM_checkpoint = "sam_vit_h_4b8939.pth"
-sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
-xmem_checkpoint = "XMem-s012.pth"
-xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
-folder = "./checkpoints"
-SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
-xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
-
-# args, defined in track_anything.py
-args = parse_augment()
-args.port = 12213
-model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)

 def pause_video(play_state):
     print("user pause_video")
     play_state.append(time.time())
@@ -138,11 +121,9 @@ def generate_video_from_frames(frames, output_path, fps=30):
     torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
     return output_path

-# def get_video_from_frames():
-#     return video_output

-def inference_all(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
+def sam_refine(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
     """
     Args:
         template_frame: PIL.Image
@@ -170,29 +151,74 @@ def inference_all(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
         multimask=prompt["multimask_output"],
     )
-    return painted_image, click_state, logit, mask
+    yield painted_image, click_state, logit, mask

-def vos_tracking(video_state, template_mask):
+def vos_tracking_video(video_state, template_mask):
     masks, logits, painted_images = model.generator(images=video_state[1], mask=template_mask)
     video_output = generate_video_from_frames(painted_images, output_path="./output.mp4")
     return video_output

-# upload file
-# def upload_callback(image_input, state):
-#     state = [] + [('Image size: ' + str(image_input.size), None)]
-#     click_state = [[], [], []]
-#     res = 1024
-#     width, height = image_input.size
-#     ratio = min(1.0 * res / max(width, height), 1.0)
-#     if ratio < 1.0:
-#         image_input = image_input.resize((int(width * ratio), int(height * ratio)))
-#         print('Scaling input image to {}'.format(image_input.size))
-#     model.segmenter.image = None
-#     model.segmenter.image_embedding = None
-#     model.segmenter.set_image(image_input)
-#     return state, state, image_input, click_state, image_input

+def vos_tracking_image(video_state, template_mask, result_queue, done_queue):
+    images = video_state[1]
+    images = images[:5]
+    for i in range(len(images)):
+        if i == 0:
+            mask, logit, painted_image = model.xmem.track(images[i], template_mask)
+            result_queue['images'].put(images[i])
+            result_queue['masks'].put(mask)
+            result_queue['logits'].put(logit)
+            result_queue['painted'].put(painted_image)
+        else:
+            mask, logit, painted_image = model.xmem.track(images[i])
+            result_queue['images'].put(images[i])
+            result_queue['masks'].put(mask)
+            result_queue['logits'].put(logit)
+            result_queue['painted'].put(painted_image)
+        done_queue.put(False)
+        time.sleep(1)
+    done_queue.put(True)

+def update_gradio_image(result_queue, done_queue):
+    print("update_gradio_image")
+    while True:
+        if not done_queue.empty():
+            if done_queue.get():
+                break
+        if not result_queue.empty():
+            image = result_queue['images'].get()
+            mask = result_queue['masks'].get()
+            logit = result_queue['logits'].get()
+            painted_image = result_queue['painted'].get()
+            yield painted_image

+def parallel_tracking(video_state, template_mask):
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        executor.submit(vos_tracking_image, video_state, template_mask, result_queue, done_queue)
+        executor.submit(update_gradio_image, result_queue, done_queue)

+# check and download checkpoints if needed
+SAM_checkpoint = "sam_vit_h_4b8939.pth"
+sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+xmem_checkpoint = "XMem-s012.pth"
+xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
+folder = "./checkpoints"
+SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
+xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)

+# args, defined in track_anything.py
+args = parse_augment()
+args.port = 12214
+model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)

+result_queue = {"images": queue.Queue(),
+                "masks": queue.Queue(),
+                "logits": queue.Queue(),
+                "painted": queue.Queue()}
+done_queue = queue.Queue()

 with gr.Blocks() as iface:
     """
@@ -205,6 +231,9 @@ with gr.Blocks() as iface:
     logits = gr.State([])
     origin_image = gr.State(None)
     template_mask = gr.State(None)
+    # queue value for image refresh, origin image, mask, logits, painted image

     with gr.Row():
@@ -248,41 +277,21 @@ with gr.Blocks() as iface:
         # for intermedia result check and correction
         # intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
         video_output = gr.Video().style(height=360)
-        tracking_video_predict_button = gr.Button(value="Tracking")
+        tracking_video_predict_button = gr.Button(value="Video")
-        # seg_automask_video_points_per_batch = gr.Slider(
-        #     minimum=0,
-        #     maximum=64,
-        #     step=2,
-        #     value=64,
-        #     label="Points per Batch",
-        # )
+        image_output = gr.Image(type="pil", interactive=True, elem_id="image_output").style(height=360)
+        tracking_image_predict_button = gr.Button(value="Tracking")
+    template_frame.select(
+        fn=sam_refine,
+        inputs=[
+            origin_image, point_prompt, click_state, logits
+        ],
+        outputs=[
+            template_frame, click_state, logits, template_mask
+        ]
+    )
-    # Display the first frame
-    # with gr.Column():
-    #     first_frame = gr.Image(type="pil", interactive=True, elem_id="first_frame")
-    #     seg_automask_firstframe = gr.Button(value="Find target")
-    #     video_input = gr.inputs.Video(type="mp4")
-    #     output = gr.outputs.Image(type="pil")
-    #     gr.Interface(fn=capture_frame, inputs=seg_automask_video_file, outputs=first_frame)
-    # seg_automask_video_predict.click(
-    #     fn=automask_video_app,
-    #     inputs=[
-    #         seg_automask_video_file,
-    #         seg_automask_video_model_type,
-    #         seg_automask_video_points_per_side,
-    #         seg_automask_video_points_per_batch,
-    #         seg_automask_video_min_area,
-    #     ],
-    #     outputs=[output_video],
-    # )
     template_select_button.click(
         fn=get_frames_from_video,
         inputs=[
@@ -290,23 +299,19 @@ with gr.Blocks() as iface:
             play_state
         ],
         outputs=[video_state, template_frame, origin_image],
-        )
+    )

-    template_frame.select(
-        fn=inference_all,
-        inputs=[
-            origin_image, point_prompt, click_state, logits
-        ],
-        outputs=[
-            template_frame, click_state, logits, template_mask
-        ]
-    )

     tracking_video_predict_button.click(
-        fn=vos_tracking,
+        fn=vos_tracking_video,
         inputs=[video_state, template_mask],
         outputs=[video_output]
     )
+    tracking_image_predict_button.click(
+        fn=parallel_tracking,
+        inputs=[video_state, template_mask],
+        outputs=[image_output]
+    )

     # clear
     # clear_button_clike.click(
     #     lambda x: ([[], [], []], x, ""),
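
Editor's note, not part of the commit: a hypothetical repair sketch, assuming a Gradio version that supports generator event handlers (which typically also requires enabling the queue, e.g. iface.queue(), before launch). Binding the generator directly to the click removes the need for the executor and the hand-rolled queues; tracking_image_stream is an illustrative name, the other identifiers come from the diff above.

def tracking_image_stream(video_state, template_mask):
    # Track the first few frames and stream each painted frame to the UI.
    images = video_state[1][:5]
    for i, image in enumerate(images):
        if i == 0:
            # First frame: initialize XMem with the template mask.
            mask, logit, painted_image = model.xmem.track(image, template_mask)
        else:
            mask, logit, painted_image = model.xmem.track(image)
        yield painted_image  # each yield refreshes image_output

tracking_image_predict_button.click(
    fn=tracking_image_stream,
    inputs=[video_state, template_mask],
    outputs=[image_output],
)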