neg or positive

This commit is contained in:
ShangGaoG
2023-04-14 19:25:56 +08:00
parent 23926d2c6f
commit a6f78aff3b
2 changed files with 16 additions and 6 deletions

View File

@@ -46,28 +46,38 @@ class SamControler():
return return
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray,logits: np.ndarray=None, multimask=True): def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
''' '''
it is used in first frame in video it is used in first frame in video
return: mask, logit, painted image(mask+point) return: mask, logit, painted image(mask+point)
''' '''
# self.sam_controler.set_image(image) # self.sam_controler.set_image(image)
origal_image = self.sam_controler.orignal_image origal_image = self.sam_controler.orignal_image
if logits is None: neg_flag = labels[-1]
if neg_flag==1:
#find neg
prompts = { prompts = {
'point_coords': points, 'point_coords': points,
'point_labels': labels, 'point_labels': labels,
} }
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
else:
prompts = { prompts = {
'point_coords': points, 'point_coords': points,
'point_labels': labels, 'point_labels': labels,
'mask_input': logits[None, :, :] 'mask_input': logit[None, :, :]
} }
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask) masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
else:
#find positive
prompts = {
'point_coords': points,
'point_labels': labels,
}
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
assert len(points)==len(labels) assert len(points)==len(labels)

View File

@@ -28,8 +28,8 @@ class TrackingAnything():
mask, logit, painted_image = self.xmem.track(image, logit) mask, logit, painted_image = self.xmem.track(image, logit)
return mask, logit, painted_image return mask, logit, painted_image
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels,logits, multimask) mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
return mask, logit, painted_image return mask, logit, painted_image
def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):