mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 08:27:49 +01:00
none mask
This commit is contained in:
@@ -45,17 +45,24 @@ class SamControler():
|
|||||||
self.sam_controler.set_image(image)
|
self.sam_controler.set_image(image)
|
||||||
return
|
return
|
||||||
|
|
||||||
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
|
||||||
|
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray,logits: np.ndarray=None, multimask=True):
|
||||||
'''
|
'''
|
||||||
it is used in first frame in video
|
it is used in first frame in video
|
||||||
return: mask, logit, painted image(mask+point)
|
return: mask, logit, painted image(mask+point)
|
||||||
'''
|
'''
|
||||||
# self.sam_controler.set_image(image)
|
self.sam_controler.set_image(image)
|
||||||
self.sam_controler.seg_again(image)
|
if logits is None:
|
||||||
prompts = {
|
prompts = {
|
||||||
'point_coords': points,
|
'point_coords': points,
|
||||||
'point_labels': labels,
|
'point_labels': labels,
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
prompts = {
|
||||||
|
'point_coords': points,
|
||||||
|
'point_labels': labels,
|
||||||
|
'mask_input': logits[None, :, :]
|
||||||
|
}
|
||||||
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
||||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user