diff --git a/app.py b/app.py
index b4a5c0b..94a46d6 100644
--- a/app.py
+++ b/app.py
@@ -42,7 +42,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
 
 # args, defined in track_anything.py
 args = parse_augment()
-args.port = 12212
+args.port = 12213
 
 model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
 
@@ -102,17 +102,21 @@ def get_frames_from_video(video_input, play_state):
         print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
 
     for index, frame in enumerate(frames):
-        frames[index] = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+        frames[index] = np.asarray(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
 
     key_frame_index = int(timestamp * fps)
     nearest_frame = frames[key_frame_index]
     frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
     # set image in sam when select the template frame
-    model.samcontroler.sam_controler.set_image(np.asarray(nearest_frame))
-    return frames, nearest_frame
+    model.samcontroler.sam_controler.set_image(nearest_frame)
+    return frames, nearest_frame, nearest_frame
 
-def inference_all(template_frame, point_prompt, click_state, logit, evt:gr.SelectData):
+# def get_video_from_frames():
+
+#     return video_output
+
+def inference_all(origin_frame, point_prompt, click_state, logit, evt:gr.SelectData):
     """
     Args:
         template_frame: PIL.Image
@@ -130,13 +134,34 @@ def inference_all(template_frame, point_prompt, click_state, logit, evt:gr.Selec
     # default value
     # points = np.array([[evt.index[0],evt.index[1]]])
     # labels= np.array([1])
+    if len(logit)==0:
+        logit = None
+
     mask, logit, painted_image = model.first_frame_click(
-        image=np.asarray(template_frame),
-        points=np.array(prompt["input_point"]),
+        image=origin_frame,
+        points=np.array(prompt["input_point"]),
         labels=np.array(prompt["input_label"]),
-        multimask=prompt["multimask_output"]
+        logits=logit,
+        multimask=prompt["multimask_output"],
+
     )
-    return painted_image, click_state, logit
+    return painted_image, click_state, logit, mask
+
+# upload file
+# def upload_callback(image_input, state):
+#     state = [] + [('Image size: ' + str(image_input.size), None)]
+#     click_state = [[], [], []]
+#     res = 1024
+#     width, height = image_input.size
+#     ratio = min(1.0 * res / max(width, height), 1.0)
+#     if ratio < 1.0:
+#         image_input = image_input.resize((int(width * ratio), int(height * ratio)))
+#         print('Scaling input image to {}'.format(image_input.size))
+#     model.segmenter.image = None
+#     model.segmenter.image_embedding = None
+#     model.segmenter.set_image(image_input)
+#     return state, state, image_input, click_state, image_input
+
 
 with gr.Blocks() as iface:
     """
@@ -147,6 +172,9 @@ with gr.Blocks() as iface:
     video_state = gr.State([[],[],[]])
     click_state = gr.State([[],[]])
     logits = gr.State([])
+    origin_image = gr.State(None)
+    template_mask = gr.State(None)
+
 
     with gr.Row():
         # for user video input
@@ -188,6 +216,7 @@ with gr.Blocks() as iface:
 
                 # for intermedia result check and correction
                 intermedia_image = gr.Image(type="pil", interactive=True, elem_id="intermedia_frame").style(height=360)
+                tracking_video_predict = gr.Button(value="Tracking")
 
                 # seg_automask_video_points_per_batch = gr.Slider(
                 #     minimum=0,
@@ -197,7 +226,7 @@ with gr.Blocks() as iface:
                 #     label="Points per Batch",
                 # )
 
-        seg_automask_video_predict = gr.Button(value="Generator")
+
 
     # Display the first frame
@@ -226,19 +255,18 @@ with gr.Blocks() as iface:
         fn=get_frames_from_video,
         inputs=[
             video_input,
-            play_state,
-            logits
+            play_state
         ],
-        outputs=[video_state, template_frame],
+        outputs=[video_state, template_frame, origin_image],
     )
 
     template_frame.select(
         fn=inference_all,
        inputs=[
-            template_frame, point_prompt, click_state, logits
+            origin_image, point_prompt, click_state, logits
        ],
        outputs=[
-            template_frame, click_state
+            template_frame, click_state, logits, template_mask
        ]
    )
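
The rewiring above relies on two Gradio mechanisms: gr.State values (origin_image, template_mask) that shuttle data between callbacks without rendering it in the UI, and the .select event, whose handler receives the clicked pixel through a gr.SelectData argument (evt.index is the (x, y) coordinate). Below is a minimal, self-contained sketch of that pattern only, not the repository's code; the handler name on_click and its echo-back behaviour are placeholders for what model.first_frame_click does in the diff.

import gradio as gr

def on_click(origin_frame, click_state, evt: gr.SelectData):
    # evt.index is the (x, y) pixel the user clicked on the displayed image
    click_state[0].append(evt.index)   # collected point coordinates
    click_state[1].append(1)           # positive-point label, standing in for point_prompt
    # the real app would run SAM here (model.first_frame_click); this sketch
    # just returns the untouched frame so the wiring stays runnable
    return origin_frame, click_state

with gr.Blocks() as demo:
    origin_image = gr.State(None)        # unmodified frame, reused for every prediction
    click_state = gr.State([[], []])     # accumulated points and labels
    template_frame = gr.Image(interactive=True)

    # clicking the image fires the handler; State inputs/outputs round-trip invisibly
    template_frame.select(
        fn=on_click,
        inputs=[origin_image, click_state],
        outputs=[template_frame, click_state],
    )

demo.launch()

Keeping the raw frame in origin_image (instead of reading back the painted template_frame) is what lets every new click re-run segmentation against the clean image rather than against an image that already has masks drawn on it.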