2023-04-26 11:24:21 +08:00
|
|
|
import PIL
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
2023-04-14 04:02:02 +08:00
|
|
|
from tools.interact_tools import SamControler
|
2023-04-14 04:40:07 +08:00
|
|
|
from tracker.base_tracker import BaseTracker
|
2023-04-25 13:38:16 +00:00
|
|
|
from inpainter.base_inpainter import BaseInpainter
|
2023-04-14 04:40:07 +08:00
|
|
|
import numpy as np
|
2023-04-14 02:27:39 +00:00
|
|
|
import argparse
|
2023-04-27 20:43:10 +00:00
|
|
|
import cv2
|
2023-04-14 04:02:02 +08:00
|
|
|
|
2023-04-27 20:43:10 +00:00
|
|
|
def read_image_from_userfolder(image_path):
    """Load an image file from disk and return it in RGB channel order.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The decoded image, converted from OpenCV's BGR order to RGB.

    Raises:
        OSError: If OpenCV cannot read or decode ``image_path``.
    """
    # cv2.imread signals failure by returning None rather than raising; without
    # this check the error would surface as an opaque assertion in cvtColor.
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        raise OSError(f"could not read image: {image_path!r}")
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
|
2023-04-14 04:02:02 +08:00
|
|
|
|
2023-04-27 20:43:10 +00:00
|
|
|
def save_image_to_userfolder(video_state, index, image, type: bool):
    """Write one frame into the per-user image folder under /tmp and return its path.

    The file is written to
    ``/tmp/<user_name>/<originimages|paintedimages>/<video_name>/<index:08d>.png``.

    Args:
        video_state: Mapping providing at least ``"user_name"`` and ``"video_name"``.
        index: Frame index, used as an 8-digit zero-padded file name.
        image: Image array handed directly to ``cv2.imwrite``.
        type: True to save under ``originimages``, False under ``paintedimages``.
            (This shadows the builtin name but is kept because callers pass it
            by keyword.)

    Returns:
        The path the image was written to.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the write failed (presumably,
            for example, when the target folder does not exist).
    """
    folder = "originimages" if type else "paintedimages"
    image_path = "/tmp/{}/{}/{}/{:08d}.png".format(
        video_state["user_name"], folder, video_state["video_name"], index
    )
    # cv2.imwrite returns False on failure instead of raising; surface it so
    # callers are not handed a path to a file that was never created.
    if not cv2.imwrite(image_path, image):
        raise OSError(f"failed to write image to {image_path!r}")
    return image_path
|
2023-04-14 04:02:02 +08:00
|
|
|
class TrackingAnything:
    """Bundle the SAM click segmenter, XMem tracker and E2FGVI inpainter.

    One instance of each model wrapper is constructed up front so that the
    interactive first-frame segmentation and whole-video tracking can share them.
    """

    def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
        """Load all three models.

        Args:
            sam_checkpoint: Checkpoint path passed to ``SamControler``.
            xmem_checkpoint: Checkpoint path passed to ``BaseTracker``.
            e2fgvi_checkpoint: Checkpoint path passed to ``BaseInpainter``.
            args: Parsed options; ``args.device`` and ``args.sam_model_type``
                are read here.
        """
        self.args = args
        self.sam_checkpoint = sam_checkpoint
        self.xmem_checkpoint = xmem_checkpoint
        self.e2fgvi_checkpoint = e2fgvi_checkpoint
        self.samcontroler = SamControler(self.sam_checkpoint, args.sam_model_type, args.device)
        self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device)
        self.baseinpainter = BaseInpainter(self.e2fgvi_checkpoint, args.device)

    def first_frame_click(self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True):
        """Segment the object indicated by user clicks on a single frame.

        Thin delegation to ``SamControler.first_frame_click``. ``points`` and
        ``labels`` are presumably click coordinates and their positive/negative
        labels (confirm against ``SamControler``).

        Returns:
            ``(mask, logit, painted_image)`` exactly as produced by the controller.
        """
        mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
        return mask, logit, painted_image

    def generator(self, images: list, template_mask: np.ndarray, video_state: dict):
        """Track ``template_mask`` through every frame in ``images``.

        Args:
            images: Image file paths, one per frame, read with
                ``read_image_from_userfolder``.
            template_mask: Mask for the first frame; it is passed to the
                tracker only on frame 0.
            video_state: Currently unused; kept for API compatibility with
                callers that still pass it.

        Returns:
            ``(masks, logits)``: two lists with one tracker output per frame.
        """
        masks = []
        logits = []
        for frame_index, image_path in enumerate(tqdm(images, desc="Tracking image")):
            frame = read_image_from_userfolder(image_path)
            if frame_index == 0:
                # The first frame seeds the tracker with the user-provided mask;
                # later frames are tracked without an explicit mask.
                mask, logit = self.xmem.track(frame, template_mask)
            else:
                mask, logit = self.xmem.track(frame)
            masks.append(mask)
            logits.append(logit)
        return masks, logits
|
2023-04-14 19:11:44 +08:00
|
|
|
|
2023-04-14 04:40:07 +08:00
|
|
|
|
2023-04-14 02:27:39 +00:00
|
|
|
def _str2bool(value):
    """Convert a command-line token such as "true"/"false"/"1"/"0" into a bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_augment():
    """Parse the command-line options shared by the tracking entry points.

    Returns:
        The ``argparse.Namespace`` of parsed options. It is also printed when
        ``--debug`` is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--sam_model_type', type=str, default="vit_h")
    parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications")
    parser.add_argument('--debug', action="store_true")
    # Without a converter, "--mask_save False" would yield the truthy string
    # "False". _str2bool parses the value properly, and nargs='?' with
    # const=True also lets a bare "--mask_save" switch it on.
    parser.add_argument('--mask_save', type=_str2bool, nargs='?', const=True, default=False,
                        help="save masks; takes an optional true/false value")
    parser.add_argument('--sequence', default="", help="sequence name")
    parser.add_argument('--votdir', default="", help="vot workspace directory")
    parser.add_argument('--davisdir', default="", help="davis workspace directory")
    args = parser.parse_args()

    if args.debug:
        print(args)

    return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: track a one-pixel template mask across two copies of
    # the same image.
    # NOTE(review): these are developer-machine paths; adjust them locally.
    image_path = '/hhd3/gaoshang/truck.jpg'
    sam_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
    xmem_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
    # NOTE(review): the original smoke test omitted this required constructor
    # argument; this file name is a placeholder for an E2FGVI checkpoint.
    e2fgvi_checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth'

    # Parse options first so that --help works without touching the disk.
    args = parse_augment()

    # generator() loads every frame itself via read_image_from_userfolder,
    # so it must be given file paths rather than decoded arrays.
    images = [image_path, image_path]

    # Build a template mask matching the image's height and width, with a
    # single foreground pixel.
    image = read_image_from_userfolder(image_path)
    mask = np.zeros(image.shape[:2], dtype=np.uint8)
    mask[0, 0] = 1

    trackany = TrackingAnything(sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args)
    # generator() returns (masks, logits); it does not read video_state.
    masks, logits = trackany.generator(images, mask, {})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|