diff --git a/tools/base_segmenter.py b/tools/base_segmenter.py index c5758a8..f5d5c4a 100644 --- a/tools/base_segmenter.py +++ b/tools/base_segmenter.py @@ -31,6 +31,7 @@ class BaseSegmenter: def set_image(self, image: np.ndarray): # PIL.open(image_path) 3channel: RGB # image embedding: avoid encode the same image multiple times + self.orignal_image = image if self.embedded: print('repeat embedding, please reset_image.') return diff --git a/tools/interact_tools.py b/tools/interact_tools.py index 2cf95df..c1edb12 100644 --- a/tools/interact_tools.py +++ b/tools/interact_tools.py @@ -52,7 +52,7 @@ class SamControler(): return: mask, logit, painted image(mask+point) ''' # self.sam_controler.set_image(image) - + origal_image = self.sam_controler.orignal_image if logits is None: prompts = { 'point_coords': points, @@ -79,6 +79,7 @@ class SamControler(): return mask, logit, painted_image def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): + origal_image = self.sam_controler.orignal_image if same: ''' true; loop in the same image @@ -112,7 +113,7 @@ class SamControler(): masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] - painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) + painted_image = mask_painter(origal_image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) painted_image = Image.fromarray(painted_image)