diff --git a/track_anything.py b/track_anything.py index 5a218d8..2618425 100644 --- a/track_anything.py +++ b/track_anything.py @@ -25,8 +25,8 @@ class TrackingAnything(): mask, logit, painted_image = self.xmem.track(image, logit) return mask, logit, painted_image - def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): - mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask) + def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): + mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels,logits, multimask) return mask, logit, painted_image def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):