diff --git a/tools/base_segmenter.py b/tools/base_segmenter.py index b94f9ba..42e017c 100644 --- a/tools/base_segmenter.py +++ b/tools/base_segmenter.py @@ -30,7 +30,7 @@ class BaseSegmenter: @torch.no_grad() def set_image(self, image: np.ndarray): # PIL.open(image_path) 3channel: RGB - # image embedding + # image embedding: avoid encode the same image multiple times if self.embedded: print('repeat embedding, please reset_image.') return @@ -40,7 +40,7 @@ class BaseSegmenter: @torch.no_grad() def reset_image(self): - # reset + # reset image embeding self.predictor.reset_image() self.embedded = False @@ -71,7 +71,7 @@ class BaseSegmenter: mask_input=prompts['mask_input'], multimask_output=multimask) else: - raise("Not implement!") + raise("Not implement now!") # masks (n, h, w), scores (n,), logits (n, 256, 256) return masks, scores, logits