mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
note
This commit is contained in:
@@ -30,7 +30,7 @@ class BaseSegmenter:
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def set_image(self, image: np.ndarray):
|
def set_image(self, image: np.ndarray):
|
||||||
# PIL.open(image_path) 3channel: RGB
|
# PIL.open(image_path) 3channel: RGB
|
||||||
# image embedding
|
# image embedding: avoid encode the same image multiple times
|
||||||
if self.embedded:
|
if self.embedded:
|
||||||
print('repeat embedding, please reset_image.')
|
print('repeat embedding, please reset_image.')
|
||||||
return
|
return
|
||||||
@@ -40,7 +40,7 @@ class BaseSegmenter:
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def reset_image(self):
|
def reset_image(self):
|
||||||
# reset
|
# reset image embeding
|
||||||
self.predictor.reset_image()
|
self.predictor.reset_image()
|
||||||
self.embedded = False
|
self.embedded = False
|
||||||
|
|
||||||
@@ -71,7 +71,7 @@ class BaseSegmenter:
|
|||||||
mask_input=prompts['mask_input'],
|
mask_input=prompts['mask_input'],
|
||||||
multimask_output=multimask)
|
multimask_output=multimask)
|
||||||
else:
|
else:
|
||||||
raise("Not implement!")
|
raise("Not implement now!")
|
||||||
# masks (n, h, w), scores (n,), logits (n, 256, 256)
|
# masks (n, h, w), scores (n,), logits (n, 256, 256)
|
||||||
return masks, scores, logits
|
return masks, scores, logits
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user