This commit is contained in:
ShangGaoG
2023-04-12 19:05:11 +08:00
parent a9219ebe2f
commit a060359343

View File

@@ -30,7 +30,7 @@ class BaseSegmenter:
@torch.no_grad() @torch.no_grad()
def set_image(self, image: np.ndarray): def set_image(self, image: np.ndarray):
# PIL.open(image_path) 3channel: RGB # PIL.open(image_path) 3channel: RGB
# image embedding # image embedding: avoid encode the same image multiple times
if self.embedded: if self.embedded:
print('repeat embedding, please reset_image.') print('repeat embedding, please reset_image.')
return return
@@ -40,7 +40,7 @@ class BaseSegmenter:
@torch.no_grad() @torch.no_grad()
def reset_image(self): def reset_image(self):
# reset # reset image embeding
self.predictor.reset_image() self.predictor.reset_image()
self.embedded = False self.embedded = False
@@ -71,7 +71,7 @@ class BaseSegmenter:
mask_input=prompts['mask_input'], mask_input=prompts['mask_input'],
multimask_output=multimask) multimask_output=multimask)
else: else:
raise("Not implement!") raise("Not implement now!")
# masks (n, h, w), scores (n,), logits (n, 256, 256) # masks (n, h, w), scores (n,), logits (n, 256, 256)
return masks, scores, logits return masks, scores, logits