From d7f2684303c27ac4560ed12a5710fd5a97c75c93 Mon Sep 17 00:00:00 2001 From: memoryunreal <814514103@qq.com> Date: Wed, 19 Apr 2023 11:34:14 +0000 Subject: [PATCH] beta -version -li --- app.py | 212 +++++++++++++++++++++++++++------------- app_test.py | 61 ++++++++---- tools/interact_tools.py | 134 ++++++++++++------------- track_anything.py | 27 ++--- 4 files changed, 266 insertions(+), 168 deletions(-) diff --git a/app.py b/app.py index 760b9a4..bbf4644 100644 --- a/app.py +++ b/app.py @@ -17,7 +17,7 @@ import torchvision import torch import concurrent.futures import queue - +from tools.painter import mask_painter, point_painter # download checkpoints def download_checkpoint(url, folder, filename): os.makedirs(folder, exist_ok=True) @@ -84,12 +84,18 @@ def get_frames_from_video(video_input, video_state): "masks": [None]*len(frames), "logits": [None]*len(frames), "select_frame_number": 0, - "fps": 30 + "fps": fps } - return video_state, gr.update(visible=True, maximum=len(frames), value=1) + video_info = "Video Name: {}, FPS: {}, Total Frames: {}".format(video_state["video_name"], video_state["fps"], len(frames)) + return video_state, video_info, gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=1), \ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True), gr.update(visible=True), \ + gr.update(visible=True) # get the select frame from gradio slider -def select_template(image_selection_slider, video_state): +def select_template(image_selection_slider, video_state, interactive_state): # images = video_state[1] image_selection_slider -= 1 @@ -100,8 +106,14 @@ def select_template(image_selection_slider, video_state): model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider]) + # # clear multi mask + # interactive_state["multi_mask"] = {"masks":[], "mask_names":[]} - return video_state["painted_images"][image_selection_slider], video_state + return video_state["painted_images"][image_selection_slider], video_state, interactive_state + +def get_end_number(track_pause_number_slider, interactive_state): + interactive_state["track_end_number"] = track_pause_number_slider + return interactive_state # use sam to get the mask def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData): @@ -133,17 +145,59 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr return painted_image, video_state, interactive_state +def add_multi_mask(video_state, interactive_state, mask_dropdown): + mask = video_state["masks"][video_state["select_frame_number"]] + interactive_state["multi_mask"]["masks"].append(mask) + interactive_state["multi_mask"]["mask_names"].append("mask_{}".format(len(interactive_state["multi_mask"]["masks"]))) + + return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"]) + + +def remove_multi_mask(interactive_state): + interactive_state["multi_mask"]["mask_names"]= [] + interactive_state["multi_mask"]["masks"] = [] + return interactive_state + +def show_mask(video_state, interactive_state, mask_dropdown): + mask_dropdown.sort() + select_frame = video_state["origin_images"][video_state["select_frame_number"]] + + for i in range(len(mask_dropdown)): + mask_number = int(mask_dropdown[i].split("_")[1]) - 1 + mask = interactive_state["multi_mask"]["masks"][mask_number] + select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2) + + return select_frame + # tracking vos -def vos_tracking_video(video_state, interactive_state): +def vos_tracking_video(video_state, interactive_state, mask_dropdown): model.xmem.clear_memory() - following_frames = video_state["origin_images"][video_state["select_frame_number"]:] - template_mask = video_state["masks"][video_state["select_frame_number"]] + if interactive_state["track_end_number"]: + following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] + else: + following_frames = video_state["origin_images"][video_state["select_frame_number"]:] + + if interactive_state["multi_mask"]["masks"]: + # if mask_dropdown: + mask_dropdown.sort() + template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] + for i in range(1,len(mask_dropdown)): + mask_number = int(mask_dropdown[i].split("_")[1]) - 1 + template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1) + video_state["masks"][video_state["select_frame_number"]]= template_mask + else: + template_mask = video_state["masks"][video_state["select_frame_number"]] fps = video_state["fps"] masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask) - video_state["masks"][video_state["select_frame_number"]:] = masks - video_state["logits"][video_state["select_frame_number"]:] = logits - video_state["painted_images"][video_state["select_frame_number"]:] = painted_images + if interactive_state["track_end_number"]: + video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks + video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits + video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images + else: + video_state["masks"][video_state["select_frame_number"]:] = masks + video_state["logits"][video_state["select_frame_number"]:] = logits + video_state["painted_images"][video_state["select_frame_number"]:] = painted_images video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video interactive_state["inference_times"] += 1 @@ -152,7 +206,7 @@ def vos_tracking_video(video_state, interactive_state): interactive_state["positive_click_times"]+interactive_state["negative_click_times"], interactive_state["positive_click_times"], interactive_state["negative_click_times"])) - + #### shanggao code for mask save if interactive_state["mask_save"]: if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])): @@ -176,6 +230,14 @@ def generate_video_from_frames(frames, output_path, fps=30): output_path (str): The path to save the generated video. fps (int, optional): The frame rate of the output video. Defaults to 30. """ + # height, width, layers = frames[0].shape + # fourcc = cv2.VideoWriter_fourcc(*"mp4v") + # video = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) + # print(output_path) + # for frame in frames: + # video.write(frame) + + # video.release() frames = torch.from_numpy(np.asarray(frames)) if not os.path.exists(os.path.dirname(output_path)): os.makedirs(os.path.dirname(output_path)) @@ -194,7 +256,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi # args, defined in track_anything.py args = parse_augment() args.port = 12212 -args.device = "cuda:4" +args.device = "cuda:1" args.mask_save = True model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args) @@ -208,8 +270,15 @@ with gr.Blocks() as iface: "inference_times": 0, "negative_click_times" : 0, "positive_click_times": 0, - "mask_save": args.mask_save - }) + "mask_save": args.mask_save, + "multi_mask": { + "mask_names": [], + "masks": [] + }, + "track_end_num": None + } + ) + video_state = gr.State( { "video_name": "", @@ -225,43 +294,47 @@ with gr.Blocks() as iface: with gr.Row(): # for user video input - with gr.Column(scale=1.0): - video_input = gr.Video().style(height=360) + with gr.Column(): + with gr.Row(scale=0.4): + video_input = gr.Video(autosize=True) + video_info = gr.Textbox() with gr.Row(scale=1): # put the template frame under the radio button - with gr.Column(scale=0.5): + with gr.Column(scale=0.4): # extract frames with gr.Column(): extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary") # click points settins, negative or positive, mode continuous or single with gr.Row(): - with gr.Row(scale=0.5): + with gr.Row(scale=0.4): point_prompt = gr.Radio( choices=["Positive", "Negative"], value="Positive", label="Point Prompt", - interactive=True) + interactive=True, + visible=False) click_mode = gr.Radio( choices=["Continuous", "Single"], value="Continuous", label="Clicking Mode", - interactive=True) + interactive=True, + visible=False) with gr.Row(scale=0.5): - clear_button_clike = gr.Button(value="Clear Clicks", interactive=True).style(height=160) - clear_button_image = gr.Button(value="Clear Image", interactive=True) - template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame").style(height=360) - image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", invisible=False) - - - + clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160) + Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False) + template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360) + image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", visible=False) + track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False) - with gr.Column(scale=0.5): - video_output = gr.Video().style(height=360) - tracking_video_predict_button = gr.Button(value="Tracking") + with gr.Column(scale=0.4): + mask_dropdown = gr.Dropdown(multiselect=True, label="Mask_select", info=".", visible=False) + remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False) + video_output = gr.Video(autosize=True, visible=False).style(height=360) + tracking_video_predict_button = gr.Button(value="Tracking", visible=False) # first step: get the video information extract_frames_button.click( @@ -269,27 +342,51 @@ with gr.Blocks() as iface: inputs=[ video_input, video_state ], - outputs=[video_state, image_selection_slider], + outputs=[video_state, video_info, image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame, + tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button] ) # second step: select images from slider image_selection_slider.release(fn=select_template, - inputs=[image_selection_slider, video_state], - outputs=[template_frame, video_state], api_name="select_image") + inputs=[image_selection_slider, video_state, interactive_state], + outputs=[template_frame, video_state, interactive_state], api_name="select_image") + track_pause_number_slider.release(fn=get_end_number, + inputs=[track_pause_number_slider, interactive_state], + outputs=[interactive_state], api_name="end_image") - + # click select image to get mask using sam template_frame.select( fn=sam_refine, inputs=[video_state, point_prompt, click_state, interactive_state], outputs=[template_frame, video_state, interactive_state] ) + # add different mask + Add_mask_button.click( + fn=add_multi_mask, + inputs=[video_state, interactive_state, mask_dropdown], + outputs=[interactive_state, mask_dropdown] + ) + + remove_mask_button.click( + fn=remove_multi_mask, + inputs=[interactive_state], + outputs=[interactive_state] + ) + + # tracking video from select image and mask tracking_video_predict_button.click( fn=vos_tracking_video, - inputs=[video_state, interactive_state], + inputs=[video_state, interactive_state, mask_dropdown], outputs=[video_output, video_state, interactive_state] ) + # click to get mask + mask_dropdown.change( + fn=show_mask, + inputs=[video_state, interactive_state, mask_dropdown], + outputs=[template_frame] + ) # clear input video_input.clear( @@ -306,10 +403,15 @@ with gr.Blocks() as iface: "inference_times": 0, "negative_click_times" : 0, "positive_click_times": 0, - "mask_save": args.mask_save + "mask_save": args.mask_save, + "multi_mask": { + "mask_names": [], + "masks": [] + }, + "track_end_num": 0 }, [[],[]] - ), + ), [], [ video_state, @@ -317,38 +419,10 @@ with gr.Blocks() as iface: click_state, ], queue=False, - show_progress=False - ) - clear_button_image.click( - lambda: ( - { - "origin_images": None, - "painted_images": None, - "masks": None, - "logits": None, - "select_frame_number": 0, - "fps": 30 - }, - { - "inference_times": 0, - "negative_click_times" : 0, - "positive_click_times": 0, - "mask_save": args.mask_save - }, - [[],[]] - ), - [], - [ - video_state, - interactive_state, - click_state, - ], + show_progress=False) - queue=False, - show_progress=False - - ) - clear_button_clike.click( + # points clear + clear_button_click.click( lambda: ([[],[]]), [], [click_state], diff --git a/app_test.py b/app_test.py index 80af21c..cd10fe7 100644 --- a/app_test.py +++ b/app_test.py @@ -1,23 +1,46 @@ +# import gradio as gr + +# def update_iframe(slider_value): +# return f''' +# +# +# ''' + +# iface = gr.Interface( +# fn=update_iframe, +# inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50), +# outputs=gr.outputs.HTML(), +# allow_flagging=False, +# ) + +# iface.launch(server_name='0.0.0.0', server_port=12212) + import gradio as gr -def update_iframe(slider_value): - return f''' - - - ''' -iface = gr.Interface( - fn=update_iframe, - inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50), - outputs=gr.outputs.HTML(), - allow_flagging=False, -) +def change_mask(drop): + return gr.update(choices=["hello", "kitty"]) -iface.launch(server_name='0.0.0.0', server_port=12212) +with gr.Blocks() as iface: + drop = gr.Dropdown( + choices=["cat", "dog", "bird"], label="Animal", info="Will add more animals later!" + ) + radio = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?") + multi_drop = gr.Dropdown( + ["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True, label="Activity", info="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed auctor, nisl eget ultricies aliquam, nunc nisl aliquet nunc, eget aliquam nisl nunc vel nisl." + ) + + multi_drop.change( + fn=change_mask, + inputs = multi_drop, + outputs=multi_drop + ) + +iface.launch(server_name='0.0.0.0', server_port=1223) \ No newline at end of file diff --git a/tools/interact_tools.py b/tools/interact_tools.py index 0df422d..daecc73 100644 --- a/tools/interact_tools.py +++ b/tools/interact_tools.py @@ -37,16 +37,16 @@ class SamControler(): self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device) - def seg_again(self, image: np.ndarray): - ''' - it is used when interact in video - ''' - self.sam_controler.reset_image() - self.sam_controler.set_image(image) - return + # def seg_again(self, image: np.ndarray): + # ''' + # it is used when interact in video + # ''' + # self.sam_controler.reset_image() + # self.sam_controler.set_image(image) + # return - def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): + def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3): ''' it is used in first frame in video return: mask, logit, painted image(mask+point) @@ -88,47 +88,47 @@ class SamControler(): return mask, logit, painted_image - def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): - origal_image = self.sam_controler.orignal_image - if same: - ''' - true; loop in the same image - ''' - prompts = { - 'point_coords': points, - 'point_labels': labels, - 'mask_input': logits[None, :, :] - } - masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask) - mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + # def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): + # origal_image = self.sam_controler.orignal_image + # if same: + # ''' + # true; loop in the same image + # ''' + # prompts = { + # 'point_coords': points, + # 'point_labels': labels, + # 'mask_input': logits[None, :, :] + # } + # masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask) + # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] - painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) - painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) - painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) - painted_image = Image.fromarray(painted_image) + # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) + # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) + # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) + # painted_image = Image.fromarray(painted_image) - return mask, logit, painted_image - else: - ''' - loop in the different image, interact in the video - ''' - if image is None: - raise('Image error') - else: - self.seg_again(image) - prompts = { - 'point_coords': points, - 'point_labels': labels, - } - masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) - mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + # return mask, logit, painted_image + # else: + # ''' + # loop in the different image, interact in the video + # ''' + # if image is None: + # raise('Image error') + # else: + # self.seg_again(image) + # prompts = { + # 'point_coords': points, + # 'point_labels': labels, + # } + # masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) + # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] - painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) - painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) - painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) - painted_image = Image.fromarray(painted_image) + # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) + # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) + # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) + # painted_image = Image.fromarray(painted_image) - return mask, logit, painted_image + # return mask, logit, painted_image @@ -226,31 +226,31 @@ class SamControler(): -if __name__ == "__main__": - points = np.array([[500, 375], [1125, 625]]) - labels = np.array([1, 1]) - image = cv2.imread('/hhd3/gaoshang/truck.jpg') - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) +# if __name__ == "__main__": +# points = np.array([[500, 375], [1125, 625]]) +# labels = np.array([1, 1]) +# image = cv2.imread('/hhd3/gaoshang/truck.jpg') +# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - sam_controler = initialize() - mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True) - painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8) - painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) - cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image) - cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image) - painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg') +# sam_controler = initialize() +# mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True) +# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8) +# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) +# cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image) +# cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image) +# painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg') - mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True) - painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8) - painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) - cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image) - painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg') +# mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True) +# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8) +# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) +# cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image) +# painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg') - mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True) - painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8) - painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) - cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image) - painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg') +# mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True) +# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8) +# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) +# cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image) +# painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg') diff --git a/track_anything.py b/track_anything.py index 78e2604..ab6c1f5 100644 --- a/track_anything.py +++ b/track_anything.py @@ -15,26 +15,26 @@ class TrackingAnything(): self.xmem = BaseTracker(xmem_checkpoint, device=args.device) - def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray, - same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): - if first_flag: - mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask) - return mask, logit, painted_image + # def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray, + # same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): + # if first_flag: + # mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask) + # return mask, logit, painted_image - if interact_flag: - mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask) - return mask, logit, painted_image + # if interact_flag: + # mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask) + # return mask, logit, painted_image - mask, logit, painted_image = self.xmem.track(image, logit) - return mask, logit, painted_image + # mask, logit, painted_image = self.xmem.track(image, logit) + # return mask, logit, painted_image def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask) return mask, logit, painted_image - def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): - mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask) - return mask, logit, painted_image + # def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): + # mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask) + # return mask, logit, painted_image def generator(self, images: list, template_mask:np.ndarray): @@ -53,6 +53,7 @@ class TrackingAnything(): masks.append(mask) logits.append(logit) painted_images.append(painted_image) + print("tracking image {}".format(i)) return masks, logits, painted_images