mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
gao
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import Union
|
||||
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
||||
import matplotlib.pyplot as plt
|
||||
import PIL
|
||||
from mask_painter import mask_painter
|
||||
from .mask_painter import mask_painter
|
||||
|
||||
|
||||
class BaseSegmenter:
|
||||
@@ -78,7 +78,7 @@ class BaseSegmenter:
|
||||
|
||||
if __name__ == "__main__":
|
||||
# load and show an image
|
||||
image = cv2.imread('images/truck.jpg')
|
||||
image = cv2.imread('/hhd3/gaoshang/truck.jpg')
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
|
||||
|
||||
# initialise BaseSegmenter
|
||||
@@ -100,7 +100,7 @@ if __name__ == "__main__":
|
||||
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
|
||||
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
|
||||
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
||||
cv2.imwrite('images/truck_point.jpg', painted_image)
|
||||
cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
|
||||
|
||||
# both ------------------------
|
||||
mode = 'both'
|
||||
@@ -114,13 +114,15 @@ if __name__ == "__main__":
|
||||
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
|
||||
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
|
||||
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
||||
cv2.imwrite('images/truck_both.jpg', painted_image)
|
||||
cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
|
||||
|
||||
# mask only ------------------------
|
||||
mode = 'mask'
|
||||
mask_input = logits[np.argmax(scores), :, :]
|
||||
|
||||
prompts = {'mask_input': mask_input[None, :, :]}
|
||||
|
||||
masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
|
||||
painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
|
||||
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
||||
cv2.imwrite('images/truck_mask.jpg', painted_image)
|
||||
cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
|
||||
|
||||
Reference in New Issue
Block a user