mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
initialize the sam set_image -- li
This commit is contained in:
10
app.py
10
app.py
@@ -42,7 +42,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
|
||||
|
||||
# args, defined in track_anything.py
|
||||
args = parse_augment()
|
||||
# args.port = 12213
|
||||
args.port = 12213
|
||||
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
||||
|
||||
|
||||
@@ -107,6 +107,9 @@ def get_frames_from_video(video_input, play_state):
|
||||
key_frame_index = int(timestamp * fps)
|
||||
nearest_frame = frames[key_frame_index]
|
||||
frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
|
||||
|
||||
# set image in sam when select the template frame
|
||||
model.samcontroler.set_image(np.asarray(nearest_frame))
|
||||
return frames, nearest_frame
|
||||
|
||||
def inference_all(template_frame, point_prompt, click_state, evt:gr.SelectData):
|
||||
@@ -127,13 +130,10 @@ def inference_all(template_frame, point_prompt, click_state, evt:gr.SelectData):
|
||||
# default value
|
||||
# points = np.array([[evt.index[0],evt.index[1]]])
|
||||
# labels= np.array([1])
|
||||
mask, logit, painted_image = model.inference_step(first_flag=True,
|
||||
interact_flag=False,
|
||||
mask, logit, painted_image = model.first_frame_click(
|
||||
image=np.asarray(template_frame),
|
||||
same_image_flag=False,
|
||||
points=np.array(prompt["input_point"]),
|
||||
labels=np.array(prompt["input_label"]),
|
||||
logits=None,
|
||||
multimask=prompt["multimask_output"]
|
||||
)
|
||||
return painted_image, click_state
|
||||
|
||||
Reference in New Issue
Block a user