From 11c1ef7afb7beef4bf8bd56e18a2c5a7322e8f9b Mon Sep 17 00:00:00 2001
From: memoryunreal <814514103@qq.com>
Date: Sat, 15 Apr 2023 19:59:58 +0000
Subject: [PATCH] failed version for using queue with concurrent.futures

---
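Notes (commentary here, between the "---" and the diffstat, is ignored by
git am). Why this version fails, as visible in the diff below:

1. The consumer is never driven. update_gradio_image is a generator
   function, so executor.submit(update_gradio_image, ...) only creates a
   generator object and never iterates it; nothing is ever yielded back to
   Gradio, and parallel_tracking itself returns None, so image_output
   never refreshes.

2. result_queue is a plain dict holding four queue.Queue objects, so the
   result_queue.empty() call in update_gradio_image raises AttributeError;
   the check would have to target a member queue, e.g.
   result_queue['painted'].empty().

3. Both queues are module-level globals, so two concurrent requests would
   interleave their frames.

A minimal sketch of a pattern that should stream frames to Gradio. It
assumes model.xmem.track keeps the signature used in the diff; the names
track_frames and frame_queue and the None sentinel are illustrative, not
code from this repository:

    import concurrent.futures
    import queue

    def track_frames(images, template_mask, frame_queue):
        # Producer: run the tracker frame by frame and push each painted
        # frame onto the queue as soon as it is ready.
        for i, image in enumerate(images):
            if i == 0:
                mask, logit, painted_image = model.xmem.track(image, template_mask)
            else:
                mask, logit, painted_image = model.xmem.track(image)
            frame_queue.put(painted_image)
        frame_queue.put(None)  # sentinel: tracking finished

    def parallel_tracking(video_state, template_mask):
        # Consumer: a generator that Gradio iterates itself, refreshing
        # image_output on every yield. A local queue per call avoids the
        # shared globals; the blocking get() replaces empty()/sleep polling.
        frame_queue = queue.Queue()
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(track_frames, video_state[1], template_mask, frame_queue)
            while True:
                painted_image = frame_queue.get()
                if painted_image is None:
                    break
                yield painted_image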
 app.py | 169 +++++++++++++++++++++++++++++----------------------------
 1 file changed, 87 insertions(+), 82 deletions(-)

diff --git a/app.py b/app.py
index ad9b84f..bab88fc 100644
--- a/app.py
+++ b/app.py
@@ -15,6 +15,8 @@ import requests
 import json
 import torchvision
 import torch
+import concurrent.futures
+import queue
 
 def download_checkpoint(url, folder, filename):
     os.makedirs(folder, exist_ok=True)
@@ -32,25 +34,6 @@ def download_checkpoint(url, folder, filename):
 
     return filepath
 
-
-# check and download checkpoints if needed
-SAM_checkpoint = "sam_vit_h_4b8939.pth"
-sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
-xmem_checkpoint = "XMem-s012.pth"
-xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
-folder ="./checkpoints"
-SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
-xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
-
-# args, defined in track_anything.py
-args = parse_augment()
-args.port = 12213
-model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
-
-
-
-
-
 def pause_video(play_state):
     print("user pause_video")
     play_state.append(time.time())
@@ -138,11 +121,9 @@ def generate_video_from_frames(frames, output_path, fps=30):
     torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
     return output_path
 
-# def get_video_from_frames():
-#     return video_output
 
-def inference_all(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
+def sam_refine(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
     """
     Args:
         template_frame: PIL.Image
@@ -170,29 +151,74 @@ def inference_all(origin_frame, point_prompt, click_state, logit, evt:gr.SelectD
         multimask=prompt["multimask_output"],
     )
 
-    return painted_image, click_state, logit, mask
+    yield painted_image, click_state, logit, mask
 
-def vos_tracking(video_state, template_mask):
+def vos_tracking_video(video_state, template_mask):
     masks, logits, painted_images = model.generator(images=video_state[1], mask=template_mask)
     video_output = generate_video_from_frames(painted_images, output_path="./output.mp4")
     return video_output
 
-# upload file
-# def upload_callback(image_input, state):
-#     state = [] + [('Image size: ' + str(image_input.size), None)]
-#     click_state = [[], [], []]
-#     res = 1024
-#     width, height = image_input.size
-#     ratio = min(1.0 * res / max(width, height), 1.0)
-#     if ratio < 1.0:
-#         image_input = image_input.resize((int(width * ratio), int(height * ratio)))
-#         print('Scaling input image to {}'.format(image_input.size))
-#     model.segmenter.image = None
-#     model.segmenter.image_embedding = None
-#     model.segmenter.set_image(image_input)
-#     return state, state, image_input, click_state, image_input
+def vos_tracking_image(video_state, template_mask, result_queue, done_queue):
+    images = video_state[1]
+    images = images[:5]
+    for i in range(len(images)):
+        if i ==0:
+            mask, logit, painted_image = model.xmem.track(images[i], template_mask)
+            result_queue['images'].put(images[i])
+            result_queue['masks'].put(mask)
+            result_queue['logits'].put(logit)
+            result_queue['painted'].put(painted_image)
+
+        else:
+            mask, logit, painted_image = model.xmem.track(images[i])
+            result_queue['images'].put(images[i])
+            result_queue['masks'].put(mask)
+            result_queue['logits'].put(logit)
+            result_queue['painted'].put(painted_image)
+            done_queue.put(False)
+        time.sleep(1)
+    done_queue.put(True)
+
+def update_gradio_image(result_queue, done_queue):
+    print("update_gradio_image")
+    while True:
+        if not done_queue.empty():
+            if done_queue.get():
+                break
+        if not result_queue.empty():
+            image = result_queue['images'].get()
+            mask = result_queue['masks'].get()
+            logit = result_queue['logits'].get()
+            painted_image = result_queue['painted'].get()
+            yield painted_image
+
+def parallel_tracking(video_state, template_mask):
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        executor.submit(vos_tracking_image, video_state, template_mask, result_queue, done_queue)
+        executor.submit(update_gradio_image, result_queue, done_queue)
+
+
+
+
+# check and download checkpoints if needed
+SAM_checkpoint = "sam_vit_h_4b8939.pth"
+sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+xmem_checkpoint = "XMem-s012.pth"
+xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
+folder ="./checkpoints"
+SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
+xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
+
+# args, defined in track_anything.py
+args = parse_augment()
+args.port = 12214
+model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
+result_queue = {"images": queue.Queue(),
+                "masks": queue.Queue(),
+                "logits": queue.Queue(),
+                "painted": queue.Queue()}
+done_queue = queue.Queue()
 
 with gr.Blocks() as iface:
     """
@@ -205,6 +231,9 @@ with gr.Blocks() as iface:
     logits = gr.State([])
     origin_image = gr.State(None)
     template_mask = gr.State(None)
+    # queue value for image refresh, origin image, mask, logits, painted image
+
+
 
     with gr.Row():
@@ -248,41 +277,21 @@ with gr.Blocks() as iface:
             # for intermedia result check and correction
             # intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
             video_output = gr.Video().style(height=360)
-            tracking_video_predict_button = gr.Button(value="Tracking")
+            tracking_video_predict_button = gr.Button(value="Video")
 
-    # seg_automask_video_points_per_batch = gr.Slider(
-    #     minimum=0,
-    #     maximum=64,
-    #     step=2,
-    #     value=64,
-    #     label="Points per Batch",
-    # )
+            image_output = gr.Image(type="pil", interactive=True, elem_id="image_output").style(height=360)
+            tracking_image_predict_button = gr.Button(value="Tracking")
 
+    template_frame.select(
+        fn=sam_refine,
+        inputs=[
+            origin_image, point_prompt, click_state, logits
+        ],
+        outputs=[
+            template_frame, click_state, logits, template_mask
+        ]
+    )
 
-
-
-    # Display the first frame
-    # with gr.Column():
-    #     first_frame = gr.Image(type="pil", interactive=True, elem_id="first_frame")
-    #     seg_automask_firstframe = gr.Button(value="Find target")
-
-    # video_input = gr.inputs.Video(type="mp4")
-
-    # output = gr.outputs.Image(type="pil")
-
-    # gr.Interface(fn=capture_frame, inputs=seg_automask_video_file, outputs=first_frame)
-
-    # seg_automask_video_predict.click(
-    #     fn=automask_video_app,
-    #     inputs=[
-    #         seg_automask_video_file,
-    #         seg_automask_video_model_type,
-    #         seg_automask_video_points_per_side,
-    #         seg_automask_video_points_per_batch,
-    #         seg_automask_video_min_area,
-    #     ],
-    #     outputs=[output_video],
-    # )
     template_select_button.click(
         fn=get_frames_from_video,
         inputs=[
             video_input,
             play_state
         ],
         outputs=[video_state, template_frame, origin_image],
-    )
+    )
 
-    template_frame.select(
-        fn=inference_all,
-        inputs=[
-            origin_image, point_prompt, click_state, logits
-        ],
-        outputs=[
-            template_frame, click_state, logits, template_mask
-        ]
-
-    )
     tracking_video_predict_button.click(
-        fn=vos_tracking,
+        fn=vos_tracking_video,
         inputs=[video_state, template_mask],
         outputs=[video_output]
     )
+    tracking_image_predict_button.click(
+        fn=parallel_tracking,
+        inputs=[video_state, template_mask],
+        outputs=[image_output]
+    )
 
     # clear
     # clear_button_clike.click(
     #     lambda x: ([[], [], []], x, ""),
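Postscript (editor's note, outside the diff): even with a generator-based
callback, Gradio 3.x only streams intermediate yields when the app's queue
is enabled, which may also be why the return -> yield change in sam_refine
shows no incremental updates. A hedged sketch of the launch wiring, using
the iface name from the diff above; whether launch appears exactly like
this in app.py is an assumption, since it is not part of this patch:

    iface.queue()                        # required for streaming generator outputs
    iface.launch(server_port=args.port)  # args.port is 12214 after this patch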