From 83834c1cdf295bd86bcc64f98e69651ba03600b7 Mon Sep 17 00:00:00 2001 From: memoryunreal <814514103@qq.com> Date: Fri, 14 Apr 2023 08:50:49 +0000 Subject: [PATCH] initialize the sam set_image -- li --- app.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index 5541933..9efc241 100644 --- a/app.py +++ b/app.py @@ -42,7 +42,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi # args, defined in track_anything.py args = parse_augment() -# args.port = 12213 +args.port = 12213 model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args) @@ -107,6 +107,9 @@ def get_frames_from_video(video_input, play_state): key_frame_index = int(timestamp * fps) nearest_frame = frames[key_frame_index] frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame] + + # set image in sam when select the template frame + model.samcontroler.set_image(np.asarray(nearest_frame)) return frames, nearest_frame def inference_all(template_frame, point_prompt, click_state, evt:gr.SelectData): @@ -127,13 +130,10 @@ def inference_all(template_frame, point_prompt, click_state, evt:gr.SelectData): # default value # points = np.array([[evt.index[0],evt.index[1]]]) # labels= np.array([1]) - mask, logit, painted_image = model.inference_step(first_flag=True, - interact_flag=False, + mask, logit, painted_image = model.first_frame_click( image=np.asarray(template_frame), - same_image_flag=False, points=np.array(prompt["input_point"]), labels=np.array(prompt["input_label"]), - logits=None, multimask=prompt["multimask_output"] ) return painted_image, click_state