neg or positive

This commit is contained in:
ShangGaoG
2023-04-14 19:25:56 +08:00
parent 23926d2c6f
commit a6f78aff3b
2 changed files with 16 additions and 6 deletions

View File

@@ -46,28 +46,38 @@ class SamControler():
return
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray,logits: np.ndarray=None, multimask=True):
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
'''
it is used in first frame in video
return: mask, logit, painted image(mask+point)
'''
# self.sam_controler.set_image(image)
origal_image = self.sam_controler.orignal_image
if logits is None:
neg_flag = labels[-1]
if neg_flag==1:
#find neg
prompts = {
'point_coords': points,
'point_labels': labels,
}
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
else:
prompts = {
'point_coords': points,
'point_labels': labels,
'mask_input': logits[None, :, :]
'mask_input': logit[None, :, :]
}
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
else:
#find positive
prompts = {
'point_coords': points,
'point_labels': labels,
}
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
assert len(points)==len(labels)

View File

@@ -28,8 +28,8 @@ class TrackingAnything():
mask, logit, painted_image = self.xmem.track(image, logit)
return mask, logit, painted_image
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels,logits, multimask)
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
return mask, logit, painted_image
def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):