mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
neg or positive
This commit is contained in:
@@ -46,28 +46,38 @@ class SamControler():
|
||||
return
|
||||
|
||||
|
||||
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray,logits: np.ndarray=None, multimask=True):
|
||||
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
||||
'''
|
||||
it is used in first frame in video
|
||||
return: mask, logit, painted image(mask+point)
|
||||
'''
|
||||
# self.sam_controler.set_image(image)
|
||||
origal_image = self.sam_controler.orignal_image
|
||||
if logits is None:
|
||||
neg_flag = labels[-1]
|
||||
if neg_flag==1:
|
||||
#find neg
|
||||
prompts = {
|
||||
'point_coords': points,
|
||||
'point_labels': labels,
|
||||
}
|
||||
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
else:
|
||||
prompts = {
|
||||
'point_coords': points,
|
||||
'point_labels': labels,
|
||||
'mask_input': logits[None, :, :]
|
||||
'mask_input': logit[None, :, :]
|
||||
}
|
||||
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
|
||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
else:
|
||||
#find positive
|
||||
prompts = {
|
||||
'point_coords': points,
|
||||
'point_labels': labels,
|
||||
}
|
||||
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
|
||||
|
||||
assert len(points)==len(labels)
|
||||
|
||||
|
||||
@@ -28,8 +28,8 @@ class TrackingAnything():
|
||||
mask, logit, painted_image = self.xmem.track(image, logit)
|
||||
return mask, logit, painted_image
|
||||
|
||||
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels,logits, multimask)
|
||||
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
||||
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
|
||||
return mask, logit, painted_image
|
||||
|
||||
def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
|
||||
Reference in New Issue
Block a user