From 579a105166ac7ddaf70b3b2bf29657818965ece8 Mon Sep 17 00:00:00 2001 From: memoryunreal <814514103@qq.com> Date: Tue, 18 Apr 2023 04:01:14 +0000 Subject: [PATCH] add args.mask_save = True, add interactive_state to record, remove memory print --li --- app.py | 192 +++++++++++++++------------- track_anything.py | 1 + tracker/inference/memory_manager.py | 6 +- 3 files changed, 104 insertions(+), 95 deletions(-) diff --git a/app.py b/app.py index 97a4ad8..760b9a4 100644 --- a/app.py +++ b/app.py @@ -18,6 +18,7 @@ import torch import concurrent.futures import queue +# download checkpoints def download_checkpoint(url, folder, filename): os.makedirs(folder, exist_ok=True) filepath = os.path.join(folder, filename) @@ -51,7 +52,8 @@ def get_prompt(click_state, click_input): "multimask_output":"True", } return prompt - + +# extract frames from upload video def get_frames_from_video(video_input, video_state): """ Args: @@ -86,6 +88,7 @@ def get_frames_from_video(video_input, video_state): } return video_state, gr.update(visible=True, maximum=len(frames), value=1) +# get the select frame from gradio slider def select_template(image_selection_slider, video_state): # images = video_state[1] @@ -100,6 +103,70 @@ def select_template(image_selection_slider, video_state): return video_state["painted_images"][image_selection_slider], video_state +# use sam to get the mask +def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData): + """ + Args: + template_frame: PIL.Image + point_prompt: flag for positive or negative button click + click_state: [[points], [labels]] + """ + if point_prompt == "Positive": + coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1]) + interactive_state["positive_click_times"] += 1 + else: + coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1]) + interactive_state["negative_click_times"] += 1 + + # prompt for sam model + prompt = get_prompt(click_state=click_state, click_input=coordinate) + + mask, logit, painted_image = model.first_frame_click( + image=video_state["origin_images"][video_state["select_frame_number"]], + points=np.array(prompt["input_point"]), + labels=np.array(prompt["input_label"]), + multimask=prompt["multimask_output"], + ) + video_state["masks"][video_state["select_frame_number"]] = mask + video_state["logits"][video_state["select_frame_number"]] = logit + video_state["painted_images"][video_state["select_frame_number"]] = painted_image + + return painted_image, video_state, interactive_state + +# tracking vos +def vos_tracking_video(video_state, interactive_state): + model.xmem.clear_memory() + following_frames = video_state["origin_images"][video_state["select_frame_number"]:] + template_mask = video_state["masks"][video_state["select_frame_number"]] + fps = video_state["fps"] + masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask) + + video_state["masks"][video_state["select_frame_number"]:] = masks + video_state["logits"][video_state["select_frame_number"]:] = logits + video_state["painted_images"][video_state["select_frame_number"]:] = painted_images + + video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video + interactive_state["inference_times"] += 1 + + print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"], + 
interactive_state["positive_click_times"]+interactive_state["negative_click_times"], + interactive_state["positive_click_times"], + interactive_state["negative_click_times"])) + + #### shanggao code for mask save + if interactive_state["mask_save"]: + if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])): + os.makedirs('./result/mask/{}'.format(video_state["video_name"].split('.')[0])) + i = 0 + print("save mask") + for mask in video_state["masks"]: + np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask) + i+=1 + # save_mask(video_state["masks"], video_state["video_name"]) + #### shanggao code for mask save + return video_output, video_state, interactive_state + +# generate video after vos inference def generate_video_from_frames(frames, output_path, fps=30): """ Generates a video from a list of frames. @@ -115,75 +182,6 @@ def generate_video_from_frames(frames, output_path, fps=30): torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264") return output_path -def sam_refine(video_state, point_prompt, click_state, evt:gr.SelectData): - """ - Args: - template_frame: PIL.Image - point_prompt: flag for positive or negative button click - click_state: [[points], [labels]] - """ - if point_prompt == "Positive": - coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1]) - else: - coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1]) - - # prompt for sam model - prompt = get_prompt(click_state=click_state, click_input=coordinate) - - mask, logit, painted_image = model.first_frame_click( - image=video_state["origin_images"][video_state["select_frame_number"]], - points=np.array(prompt["input_point"]), - labels=np.array(prompt["input_label"]), - multimask=prompt["multimask_output"], - ) - video_state["masks"][video_state["select_frame_number"]] = mask - video_state["logits"][video_state["select_frame_number"]] = logit - video_state["painted_images"][video_state["select_frame_number"]] = painted_image - - return painted_image, video_state - -def interactive_correction(video_state, point_prompt, click_state, select_correction_frame, evt: gr.SelectData): - """ - Args: - template_frame: PIL.Image - point_prompt: flag for positive or negative button click - click_state: [[points], [labels]] - """ - refine_image = video_state[1][select_correction_frame] - if point_prompt == "Positive": - coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1]) - else: - coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1]) - - # prompt for sam model - prompt = get_prompt(click_state=click_state, click_input=coordinate) - # model.samcontroler.seg_again(refine_image) - corrected_mask, corrected_logit, corrected_painted_image = model.first_frame_click( - image=refine_image, - points=np.array(prompt["input_point"]), - labels=np.array(prompt["input_label"]), - multimask=prompt["multimask_output"], - ) - return corrected_painted_image, [corrected_mask, corrected_logit, corrected_painted_image] - -def vos_tracking_video(video_state): - model.xmem.clear_memory() - following_frames = video_state["origin_images"][video_state["select_frame_number"]:] - template_mask = video_state["masks"][video_state["select_frame_number"]] - fps = video_state["fps"] - masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask) - - video_state["masks"][video_state["select_frame_number"]:] = masks - video_state["logits"][video_state["select_frame_number"]:] = 
logits - video_state["painted_images"][video_state["select_frame_number"]:] = painted_images - - video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video - - return video_output, video_state - - - - # check and download checkpoints if needed SAM_checkpoint = "sam_vit_h_4b8939.pth" sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" @@ -196,7 +194,8 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi # args, defined in track_anything.py args = parse_augment() args.port = 12212 -args.device = "cuda:2" +args.device = "cuda:4" +args.mask_save = True model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args) @@ -205,6 +204,12 @@ with gr.Blocks() as iface: state for """ click_state = gr.State([[],[]]) + interactive_state = gr.State({ + "inference_times": 0, + "negative_click_times" : 0, + "positive_click_times": 0, + "mask_save": args.mask_save + }) video_state = gr.State( { "video_name": "", @@ -217,20 +222,21 @@ with gr.Blocks() as iface: } ) - - - with gr.Row(): # for user video input with gr.Column(scale=1.0): - video_input = gr.Video().style(height=720) + video_input = gr.Video().style(height=360) with gr.Row(scale=1): # put the template frame under the radio button with gr.Column(scale=0.5): + # extract frames + with gr.Column(): + extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary") + # click points settins, negative or positive, mode continuous or single with gr.Row(): with gr.Row(scale=0.5): @@ -250,20 +256,13 @@ with gr.Blocks() as iface: template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame").style(height=360) image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", invisible=False) - # extract frames - with gr.Column(): - extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary") - + with gr.Column(scale=0.5): - - video_output = gr.Video().style(height=360) tracking_video_predict_button = gr.Button(value="Tracking") - - # first step: get the video information extract_frames_button.click( fn=get_frames_from_video, @@ -273,10 +272,6 @@ with gr.Blocks() as iface: outputs=[video_state, image_selection_slider], ) - - - - # second step: select images from slider image_selection_slider.release(fn=select_template, inputs=[image_selection_slider, video_state], @@ -285,17 +280,16 @@ with gr.Blocks() as iface: template_frame.select( fn=sam_refine, - inputs=[video_state, point_prompt, click_state], - outputs=[template_frame, video_state] + inputs=[video_state, point_prompt, click_state, interactive_state], + outputs=[template_frame, video_state, interactive_state] ) tracking_video_predict_button.click( fn=vos_tracking_video, - inputs=[video_state], - outputs=[video_output, video_state] + inputs=[video_state, interactive_state], + outputs=[video_output, video_state, interactive_state] ) - # clear input video_input.clear( @@ -308,11 +302,18 @@ with gr.Blocks() as iface: "select_frame_number": 0, "fps": 30 }, + { + "inference_times": 0, + "negative_click_times" : 0, + "positive_click_times": 0, + "mask_save": args.mask_save + }, [[],[]] ), [], [ video_state, + interactive_state, click_state, ], queue=False, @@ -328,11 +329,18 @@ with gr.Blocks() as iface: "select_frame_number": 0, "fps": 30 }, + { + "inference_times": 0, + "negative_click_times" : 
0, + "positive_click_times": 0, + "mask_save": args.mask_save + }, [[],[]] ), [], [ video_state, + interactive_state, click_state, ], diff --git a/track_anything.py b/track_anything.py index db0e2c4..78e2604 100644 --- a/track_anything.py +++ b/track_anything.py @@ -62,6 +62,7 @@ def parse_augment(): parser.add_argument('--sam_model_type', type=str, default="vit_h") parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications") parser.add_argument('--debug', action="store_true") + parser.add_argument('--mask_save', default=True) args = parser.parse_args() if args.debug: diff --git a/tracker/inference/memory_manager.py b/tracker/inference/memory_manager.py index adf6c85..d47d96e 100644 --- a/tracker/inference/memory_manager.py +++ b/tracker/inference/memory_manager.py @@ -182,7 +182,7 @@ class MemoryManager: if self.enable_long_term: # Do memory compressed if needed if self.work_mem.size >= self.max_work_elements: - print('remove memory') + # print('remove memory') # Remove obsolete features if needed if self.long_mem.size >= (self.max_long_elements-self.num_prototypes): self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes) @@ -239,8 +239,8 @@ class MemoryManager: # add to long-term memory self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None) - print(f'long memory size: {self.long_mem.size}') - print(f'work memory size: {self.work_mem.size}') + # print(f'long memory size: {self.long_mem.size}') + # print(f'work memory size: {self.work_mem.size}') def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, usage, candidate_value): # keys: 1*C*N
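
A minimal, self-contained sketch of the pattern the patch introduces in app.py: a per-session dict kept in a gr.State and passed through every callback so click/inference counters persist between events. The component and function names below are illustrative, not taken from app.py.

import gradio as gr

def make_state():
    # fresh counters, same keys as the interactive_state added by the patch
    return {"inference_times": 0, "positive_click_times": 0, "negative_click_times": 0}

def record_click(interactive_state, point_prompt):
    # mirrors what sam_refine now does with the counters
    if point_prompt == "Positive":
        interactive_state["positive_click_times"] += 1
    else:
        interactive_state["negative_click_times"] += 1
    # returning the dict writes it back into the gr.State for the next callback
    return interactive_state, str(interactive_state)

with gr.Blocks() as demo:
    interactive_state = gr.State(make_state())
    point_prompt = gr.Radio(choices=["Positive", "Negative"], value="Positive", label="Point prompt")
    click_button = gr.Button("Click")
    state_view = gr.Textbox(label="interactive_state")
    click_button.click(fn=record_click,
                       inputs=[interactive_state, point_prompt],
                       outputs=[interactive_state, state_view])

demo.launch()

Resetting the counters on clear/upload, as the patch does for video_input, then amounts to returning a fresh dict for the gr.State output.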
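
The mask_save branch added to vos_tracking_video writes one .npy file per frame under ./result/mask/<video name without extension>/, zero-padded so the files sort in frame order. The same loop, pulled out into a small helper for clarity (the helper name is mine, not part of the patch):

import os
import numpy as np

def save_masks(masks, video_name, result_dir="./result/mask"):
    # one .npy per frame, e.g. 00000.npy, 00001.npy, ...
    out_dir = os.path.join(result_dir, os.path.splitext(video_name)[0])
    os.makedirs(out_dir, exist_ok=True)
    for i, mask in enumerate(masks):
        np.save(os.path.join(out_dir, "{:05d}.npy".format(i)), mask)

# usage mirroring the patch:
# save_masks(video_state["masks"], video_state["video_name"])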
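
On the new command-line option: parser.add_argument('--mask_save', default=True) registers the flag with no type or action, so the default is the boolean True, but any value supplied on the command line arrives as a string (and a non-empty string such as "False" is still truthy). In this patch that is harmless because app.py immediately overrides it with args.mask_save = True. If the flag ever needs to be toggled from the CLI, a boolean-safe variant could look like the sketch below; this is an assumed alternative, not what the patch does.

import argparse

def str2bool(v):
    # accept --mask_save true/false (and a few common spellings)
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")

parser = argparse.ArgumentParser()
parser.add_argument("--mask_save", type=str2bool, default=True)
args = parser.parse_args()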