solve image_change

This commit is contained in:
ShangGaoG
2023-04-14 17:43:05 +08:00
parent cddb447e3f
commit 0000a2a103
2 changed files with 4 additions and 2 deletions

View File

@@ -31,6 +31,7 @@ class BaseSegmenter:
def set_image(self, image: np.ndarray):
# PIL.open(image_path) 3channel: RGB
# image embedding: avoid encode the same image multiple times
self.orignal_image = image
if self.embedded:
print('repeat embedding, please reset_image.')
return

View File

@@ -52,7 +52,7 @@ class SamControler():
return: mask, logit, painted image(mask+point)
'''
# self.sam_controler.set_image(image)
origal_image = self.sam_controler.orignal_image
if logits is None:
prompts = {
'point_coords': points,
@@ -79,6 +79,7 @@ class SamControler():
return mask, logit, painted_image
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
origal_image = self.sam_controler.orignal_image
if same:
'''
true; loop in the same image
@@ -112,7 +113,7 @@ class SamControler():
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = mask_painter(origal_image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image)