mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
Merge branch 'master' of github.com:gaomingqi/Track-Anything
This commit is contained in:
@@ -31,6 +31,7 @@ class BaseSegmenter:
|
|||||||
def set_image(self, image: np.ndarray):
|
def set_image(self, image: np.ndarray):
|
||||||
# PIL.open(image_path) 3channel: RGB
|
# PIL.open(image_path) 3channel: RGB
|
||||||
# image embedding: avoid encode the same image multiple times
|
# image embedding: avoid encode the same image multiple times
|
||||||
|
self.orignal_image = image
|
||||||
if self.embedded:
|
if self.embedded:
|
||||||
print('repeat embedding, please reset_image.')
|
print('repeat embedding, please reset_image.')
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -46,28 +46,38 @@ class SamControler():
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray,logits: np.ndarray=None, multimask=True):
|
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
||||||
'''
|
'''
|
||||||
it is used in first frame in video
|
it is used in first frame in video
|
||||||
return: mask, logit, painted image(mask+point)
|
return: mask, logit, painted image(mask+point)
|
||||||
'''
|
'''
|
||||||
# self.sam_controler.set_image(image)
|
# self.sam_controler.set_image(image)
|
||||||
|
origal_image = self.sam_controler.orignal_image
|
||||||
if logits is None:
|
neg_flag = labels[-1]
|
||||||
|
if neg_flag==1:
|
||||||
|
#find neg
|
||||||
prompts = {
|
prompts = {
|
||||||
'point_coords': points,
|
'point_coords': points,
|
||||||
'point_labels': labels,
|
'point_labels': labels,
|
||||||
}
|
}
|
||||||
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
||||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||||
else:
|
|
||||||
prompts = {
|
prompts = {
|
||||||
'point_coords': points,
|
'point_coords': points,
|
||||||
'point_labels': labels,
|
'point_labels': labels,
|
||||||
'mask_input': logits[None, :, :]
|
'mask_input': logit[None, :, :]
|
||||||
}
|
}
|
||||||
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
|
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
|
||||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||||
|
else:
|
||||||
|
#find positive
|
||||||
|
prompts = {
|
||||||
|
'point_coords': points,
|
||||||
|
'point_labels': labels,
|
||||||
|
}
|
||||||
|
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
||||||
|
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||||
|
|
||||||
|
|
||||||
assert len(points)==len(labels)
|
assert len(points)==len(labels)
|
||||||
|
|
||||||
@@ -79,6 +89,7 @@ class SamControler():
|
|||||||
return mask, logit, painted_image
|
return mask, logit, painted_image
|
||||||
|
|
||||||
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||||
|
origal_image = self.sam_controler.orignal_image
|
||||||
if same:
|
if same:
|
||||||
'''
|
'''
|
||||||
true; loop in the same image
|
true; loop in the same image
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
import sys
|
||||||
|
sys.path.append("/hhd3/gaoshang/Track-Anything/tracker")
|
||||||
|
import PIL
|
||||||
from tools.interact_tools import SamControler
|
from tools.interact_tools import SamControler
|
||||||
from tracker.base_tracker import BaseTracker
|
from tracker.base_tracker import BaseTracker
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -25,18 +28,31 @@ class TrackingAnything():
|
|||||||
mask, logit, painted_image = self.xmem.track(image, logit)
|
mask, logit, painted_image = self.xmem.track(image, logit)
|
||||||
return mask, logit, painted_image
|
return mask, logit, painted_image
|
||||||
|
|
||||||
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
||||||
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels,logits, multimask)
|
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
|
||||||
return mask, logit, painted_image
|
return mask, logit, painted_image
|
||||||
|
|
||||||
def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||||
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
||||||
return mask, logit, painted_image
|
return mask, logit, painted_image
|
||||||
|
|
||||||
def generator(self, image: np.ndarray, logits:np.ndarray):
|
def generator(self, images: list, mask:np.ndarray):
|
||||||
mask, logit, painted_image = self.xmem.track(image, logits)
|
|
||||||
return mask, logit, painted_image
|
|
||||||
|
|
||||||
|
masks = []
|
||||||
|
logits = []
|
||||||
|
painted_images = []
|
||||||
|
for i in range(len(images)):
|
||||||
|
if i ==0:
|
||||||
|
|
||||||
|
mask, logit, painted_image = self.xmem.track(images[i], mask)
|
||||||
|
|
||||||
|
else:
|
||||||
|
mask, logit, painted_image = self.xmem.track(images[i])
|
||||||
|
masks.append(mask)
|
||||||
|
logits.append(logit)
|
||||||
|
painted_images.append(painted_image)
|
||||||
|
return masks, logits, painted_images
|
||||||
|
|
||||||
|
|
||||||
def parse_augment():
|
def parse_augment():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@@ -48,4 +64,27 @@ def parse_augment():
|
|||||||
|
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(args)
|
print(args)
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
masks = None
|
||||||
|
logits = None
|
||||||
|
painted_images = None
|
||||||
|
images = []
|
||||||
|
image = np.array(PIL.Image.open('/hhd3/gaoshang/truck.jpg'))
|
||||||
|
args = parse_augment()
|
||||||
|
# images.append(np.ones((20,20,3)).astype('uint8'))
|
||||||
|
# images.append(np.ones((20,20,3)).astype('uint8'))
|
||||||
|
images.append(image)
|
||||||
|
images.append(image)
|
||||||
|
|
||||||
|
mask = np.zeros_like(image)[:,:,0]
|
||||||
|
mask[0,0]= 1
|
||||||
|
trackany = TrackingAnything('/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth','/ssd1/gaomingqi/checkpoints/XMem-s012.pth', args)
|
||||||
|
masks, logits ,painted_images= trackany.generator(images, mask)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user