This commit is contained in:
ShangGaoG
2023-04-12 18:41:55 +08:00
parent 9b9bc62425
commit a9219ebe2f

View File

@@ -65,11 +65,13 @@ class BaseSegmenter:
elif mode == 'mask':
masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
multimask_output=multimask)
else: # both
elif mode == 'both': # both
masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
point_labels=prompts['point_labels'],
mask_input=prompts['mask_input'],
multimask_output=multimask)
else:
raise("Not implement!")
# masks (n, h, w), scores (n,), logits (n, 256, 256)
return masks, scores, logits