ShangGaoG
2023-04-14 17:07:01 +08:00

app.py

@@ -42,7 +42,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
 # args, defined in track_anything.py
 args = parse_augment()
-# args.port = 12213
+args.port = 12213
 model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
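
With args.port now set unconditionally, the demo always binds to 12213. A minimal sketch of how that port would typically be consumed, assuming app.py launches a Gradio interface named iface further down (the launch call is not shown in this diff):

import gradio as gr

args_port = 12213  # stands in for args.port from parse_augment()

# placeholder UI; the real app wires up video upload, frame selection, clicks, etc.
with gr.Blocks() as iface:
    gr.Markdown("Track-Anything demo (placeholder UI)")

# hard-coding args.port means the server always binds to the same port
iface.launch(server_name="0.0.0.0", server_port=args_port)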
@@ -107,6 +107,9 @@ def get_frames_from_video(video_input, play_state):
     key_frame_index = int(timestamp * fps)
     nearest_frame = frames[key_frame_index]
     frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
+    # set the image in SAM when the template frame is selected
+    model.samcontroler.set_image(np.asarray(nearest_frame))
     return frames, nearest_frame

 def inference_all(template_frame, point_prompt, click_state, evt:gr.SelectData):
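
The set_image call added above pushes the selected template frame into SAM once, so later per-click segmentation does not have to re-embed the image. A minimal sketch of the same set-image-once pattern using the upstream segment-anything SamPredictor directly (the repo wraps this in its samcontroler; the checkpoint path, dummy frame, and click coordinates below are placeholders):

import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# placeholder standing in for `nearest_frame` from get_frames_from_video
nearest_frame = np.zeros((480, 640, 3), dtype=np.uint8)

# checkpoint path is a placeholder; in the repo it is fetched by download_checkpoint
sam = sam_model_registry["vit_h"](checkpoint="./checkpoints/sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

# the expensive image embedding is computed once, when the template frame is chosen
predictor.set_image(nearest_frame)

# every later click reuses the cached embedding; only the point prompt changes
masks, scores, logits = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
    multimask_output=True,
)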
@@ -127,13 +130,10 @@ def inference_all(template_frame, point_prompt, click_state, evt:gr.SelectData):
     # default value
     # points = np.array([[evt.index[0],evt.index[1]]])
     # labels= np.array([1])
-    mask, logit, painted_image = model.inference_step(first_flag=True,
-        interact_flag=False,
+    mask, logit, painted_image = model.first_frame_click(
         image=np.asarray(template_frame),
-        same_image_flag=False,
         points=np.array(prompt["input_point"]),
         labels=np.array(prompt["input_label"]),
-        logits=None,
         multimask=prompt["multimask_output"]
     )
     return painted_image, click_state
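
After this hunk, the click handler only needs the frame, the accumulated click points and labels, and the multimask flag; the first_flag/interact_flag/same_image_flag/logits arguments of the old inference_step call are gone. A minimal sketch of how the handler might assemble that call, assuming a [points, labels] layout for click_state and "Positive"/"Negative" values for point_prompt (neither is shown in this diff), and with `model` being the module-level TrackingAnything instance created near the top of app.py:

import numpy as np
import gradio as gr

def inference_all(template_frame, point_prompt, click_state, evt: gr.SelectData):
    # record the clicked pixel; the "Positive"/"Negative" radio values are an assumption
    click_state[0].append([evt.index[0], evt.index[1]])
    click_state[1].append(1 if point_prompt == "Positive" else 0)

    prompt = {
        "input_point": click_state[0],
        "input_label": click_state[1],
        "multimask_output": True,
    }
    # `model` is the module-level TrackingAnything instance shown earlier in this diff
    mask, logit, painted_image = model.first_frame_click(
        image=np.asarray(template_frame),
        points=np.array(prompt["input_point"]),
        labels=np.array(prompt["input_label"]),
        multimask=prompt["multimask_output"],
    )
    return painted_image, click_state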