mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 08:27:49 +01:00
solve image_change
This commit is contained in:
@@ -31,6 +31,7 @@ class BaseSegmenter:
|
|||||||
def set_image(self, image: np.ndarray):
|
def set_image(self, image: np.ndarray):
|
||||||
# PIL.open(image_path) 3channel: RGB
|
# PIL.open(image_path) 3channel: RGB
|
||||||
# image embedding: avoid encode the same image multiple times
|
# image embedding: avoid encode the same image multiple times
|
||||||
|
self.orignal_image = image
|
||||||
if self.embedded:
|
if self.embedded:
|
||||||
print('repeat embedding, please reset_image.')
|
print('repeat embedding, please reset_image.')
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ class SamControler():
|
|||||||
return: mask, logit, painted image(mask+point)
|
return: mask, logit, painted image(mask+point)
|
||||||
'''
|
'''
|
||||||
# self.sam_controler.set_image(image)
|
# self.sam_controler.set_image(image)
|
||||||
|
origal_image = self.sam_controler.orignal_image
|
||||||
if logits is None:
|
if logits is None:
|
||||||
prompts = {
|
prompts = {
|
||||||
'point_coords': points,
|
'point_coords': points,
|
||||||
@@ -79,6 +79,7 @@ class SamControler():
|
|||||||
return mask, logit, painted_image
|
return mask, logit, painted_image
|
||||||
|
|
||||||
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||||
|
origal_image = self.sam_controler.orignal_image
|
||||||
if same:
|
if same:
|
||||||
'''
|
'''
|
||||||
true; loop in the same image
|
true; loop in the same image
|
||||||
@@ -112,7 +113,7 @@ class SamControler():
|
|||||||
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
||||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||||
|
|
||||||
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
painted_image = mask_painter(origal_image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||||
painted_image = Image.fromarray(painted_image)
|
painted_image = Image.fromarray(painted_image)
|
||||||
|
|||||||
Reference in New Issue
Block a user