diff --git a/tools/interact_tools.py b/tools/interact_tools.py index ca50259..6d93cac 100644 --- a/tools/interact_tools.py +++ b/tools/interact_tools.py @@ -45,17 +45,24 @@ class SamControler(): self.sam_controler.set_image(image) return - def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): + + def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray,logits: np.ndarray=None, multimask=True): ''' it is used in first frame in video return: mask, logit, painted image(mask+point) ''' - # self.sam_controler.set_image(image) - self.sam_controler.seg_again(image) - prompts = { - 'point_coords': points, - 'point_labels': labels, - } + self.sam_controler.set_image(image) + if logits is None: + prompts = { + 'point_coords': points, + 'point_labels': labels, + } + else: + prompts = { + 'point_coords': points, + 'point_labels': labels, + 'mask_input': logits[None, :, :] + } masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]