This commit is contained in:
ShangGaoG
2023-04-14 19:11:44 +08:00
parent fcecc735fc
commit 23926d2c6f

View File

@@ -1,3 +1,6 @@
import sys
sys.path.append("/hhd3/gaoshang/Track-Anything/tracker")
import PIL
from tools.interact_tools import SamControler
from tracker.base_tracker import BaseTracker
import numpy as np
@@ -33,9 +36,22 @@ class TrackingAnything():
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
return mask, logit, painted_image
def generator(self, image: np.ndarray, logits:np.ndarray):
mask, logit, painted_image = self.xmem.track(image, logits)
return mask, logit, painted_image
def generator(self, images: list, mask:np.ndarray):
masks = []
logits = []
painted_images = []
for i in range(len(images)):
if i ==0:
mask, logit, painted_image = self.xmem.track(images[i], mask)
else:
mask, logit, painted_image = self.xmem.track(images[i])
masks.append(mask)
logits.append(logit)
painted_images.append(painted_image)
return masks, logits, painted_images
def parse_augment():
@@ -49,3 +65,26 @@ def parse_augment():
if args.debug:
print(args)
return args
if __name__ == "__main__":
masks = None
logits = None
painted_images = None
images = []
image = np.array(PIL.Image.open('/hhd3/gaoshang/truck.jpg'))
args = parse_augment()
# images.append(np.ones((20,20,3)).astype('uint8'))
# images.append(np.ones((20,20,3)).astype('uint8'))
images.append(image)
images.append(image)
mask = np.zeros_like(image)[:,:,0]
mask[0,0]= 1
trackany = TrackingAnything('/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth','/ssd1/gaomingqi/checkpoints/XMem-s012.pth', args)
masks, logits ,painted_images= trackany.generator(images, mask)