# Track-Anything/track_anything.py (28 lines, 1.1 KiB, Python)
# Copied from the repository web viewer; blame timestamp 2023-04-14 04:02:02 +08:00
from tools.interact_tools import SamControler
# blame: 2023-04-14 04:40:07 +08:00
from tracker.base_tracker import BaseTracker
import numpy as np
# blame: 2023-04-14 04:02:02 +08:00
class TrackingAnything:
    """Pair a SAM-based click controller with a ``BaseTracker``.

    First-frame clicks and interactive refinements are handled by
    ``self.samcontroler`` (a ``SamControler``); all other frames are handed
    to ``self.xmem`` (a ``BaseTracker``).
    """

    def __init__(self, cfg):
        """Build both components from ``cfg``.

        Args:
            cfg: Configuration object. This constructor reads
                ``cfg.sam_checkpoint``, ``cfg.model_type``, ``cfg.device`` and
                ``cfg.xmem_checkpoint``.
        """
        self.cfg = cfg
        self.samcontroler = SamControler(cfg.sam_checkpoint, cfg.model_type, cfg.device)
        self.xmem = BaseTracker(cfg.device, cfg.xmem_checkpoint)

    def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
                       same_image_flag: bool, points: np.ndarray, labels: np.ndarray,
                       logits: np.ndarray = None, multimask=True):
        """Run one segmentation or tracking step on ``image``.

        ``first_flag`` takes precedence over ``interact_flag``. When neither
        is set, the frame is passed to the tracker.

        Args:
            first_flag: Handle ``image`` as the first annotated frame via
                ``SamControler.first_frame_click``.
            interact_flag: Refine the current frame via
                ``SamControler.interact_loop``.
            image: Frame to process.
            same_image_flag: Forwarded to ``interact_loop`` only.
            points: Click coordinates, forwarded to the SAM controller.
            labels: Click labels, forwarded to the SAM controller.
            logits: Logits from a previous step. Forwarded to
                ``interact_loop`` and to the tracker. May be ``None``.
            multimask: Forwarded to the SAM controller calls.

        Returns:
            ``(mask, logit, painted_image)`` as produced by whichever
            component handled this step.
        """
        if first_flag:
            mask, logit, painted_image = self.samcontroler.first_frame_click(
                image, points, labels, multimask)
            return mask, logit, painted_image
        if interact_flag:
            mask, logit, painted_image = self.samcontroler.interact_loop(
                image, same_image_flag, points, labels, logits, multimask)
            return mask, logit, painted_image
        # Tracking path. No local ``logit`` exists here yet, so the
        # caller-supplied ``logits`` is the only prior-step output available.
        # NOTE(review): confirm that BaseTracker.track expects logits (and
        # tolerates None) as its second argument.
        mask, logit, painted_image = self.xmem.track(image, logits)
        return mask, logit, painted_image