mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 08:27:49 +01:00
solve image_change
This commit is contained in:
@@ -31,6 +31,7 @@ class BaseSegmenter:
|
||||
def set_image(self, image: np.ndarray):
|
||||
# PIL.open(image_path) 3channel: RGB
|
||||
# image embedding: avoid encode the same image multiple times
|
||||
self.orignal_image = image
|
||||
if self.embedded:
|
||||
print('repeat embedding, please reset_image.')
|
||||
return
|
||||
|
||||
@@ -52,7 +52,7 @@ class SamControler():
|
||||
return: mask, logit, painted image(mask+point)
|
||||
'''
|
||||
# self.sam_controler.set_image(image)
|
||||
|
||||
origal_image = self.sam_controler.orignal_image
|
||||
if logits is None:
|
||||
prompts = {
|
||||
'point_coords': points,
|
||||
@@ -79,6 +79,7 @@ class SamControler():
|
||||
return mask, logit, painted_image
|
||||
|
||||
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
origal_image = self.sam_controler.orignal_image
|
||||
if same:
|
||||
'''
|
||||
true; loop in the same image
|
||||
@@ -112,7 +113,7 @@ class SamControler():
|
||||
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
|
||||
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
painted_image = mask_painter(origal_image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
|
||||
Reference in New Issue
Block a user