# Mirror of https://github.com/gaomingqi/Track-Anything.git
# Synced 2025-12-16 08:27:49 +01:00 (104 lines, 3.2 KiB, Python)
import torch

from ..fbrs.inference import clicker
from ..fbrs.inference.predictors import get_predictor


class InteractiveController:
    """Stateful driver for click-based interactive segmentation of one image.

    The controller records user clicks, asks the predictor for a new
    probability map after each click, keeps enough history to undo clicks,
    and writes finished objects into an integer label mask.

    Internal bookkeeping:
        states: one snapshot (clicker state + predictor state) per entry in
            ``probs_history``, taken *before* the corresponding click.
        probs_history: list of ``(total, additive)`` pairs; ``total`` is the
            probability carried over from partially finished parts of the
            object and ``additive`` is the prediction for the current clicks.
    """

    def __init__(self, net, device, predictor_params, prob_thresh=0.5):
        """Create a controller and build its first predictor.

        Args:
            net: Model to run; it is moved to ``device`` with ``net.to``.
            device: Device passed to ``net.to`` and to ``get_predictor``.
            predictor_params: Keyword arguments forwarded to
                ``get_predictor`` every time the predictor is rebuilt.
            prob_thresh: Probabilities strictly above this value count as
                object pixels in ``finish_object``.
        """
        self.net = net.to(device)
        self.prob_thresh = prob_thresh
        self.clicker = clicker.Clicker()
        self.states = []
        self.probs_history = []
        self.object_count = 0
        self._result_mask = None

        self.image = None
        self.predictor = None
        self.device = device
        self.predictor_params = predictor_params
        # Must come last: it reads net, device, predictor_params and image.
        self.reset_predictor()

    def set_image(self, image):
        """Start working on ``image`` and discard all previous results.

        Args:
            image: Object exposing ``.shape``; the last two entries of the
                shape are used as the size of the label mask (presumably
                height and width -- confirm against the caller).
        """
        self.image = image
        # NOTE(review): uint8 limits the mask to 255 distinct object ids.
        self._result_mask = torch.zeros(image.shape[-2:], dtype=torch.uint8)
        self.object_count = 0
        self.reset_last_object()

    def add_click(self, x, y, is_positive):
        """Register a click at ``(x, y)`` and store the new prediction.

        Args:
            x: Horizontal click coordinate.
            y: Vertical click coordinate.
            is_positive: Passed through to ``clicker.Click``.
        """
        # Snapshot first so that undo_click can return to this exact point.
        self.states.append({
            'clicker': self.clicker.get_state(),
            'predictor': self.predictor.get_states(),
        })

        # The clicker takes coordinates in (row, column) order, i.e. (y, x).
        click = clicker.Click(is_positive=is_positive, coords=(y, x))
        self.clicker.add_click(click)
        pred = self.predictor.get_prediction(self.clicker)
        # Hand cached, currently unused GPU memory back after inference.
        torch.cuda.empty_cache()

        if self.probs_history:
            # Keep the total accumulated by earlier partial finishes.
            carried_total = self.probs_history[-1][0]
        else:
            carried_total = torch.zeros_like(pred)
        self.probs_history.append((carried_total, pred))

    def undo_click(self):
        """Revert the most recent history entry; do nothing if there is none."""
        if not self.states:
            return

        prev_state = self.states.pop()
        self.clicker.set_state(prev_state['clicker'])
        self.predictor.set_states(prev_state['predictor'])
        self.probs_history.pop()

    def partially_finish_object(self):
        """Freeze the current prediction and continue with fresh clicks.

        The current probability becomes the new ``total`` with an all-zero
        ``additive`` part, the click list is cleared and the predictor is
        rebuilt. Does nothing when no click has been made yet.
        """
        object_prob = self.current_object_prob
        if object_prob is None:
            return

        self.probs_history.append((object_prob, torch.zeros_like(object_prob)))
        # Duplicate the last snapshot so states and probs_history stay the
        # same length and undo_click keeps popping them in lockstep.
        self.states.append(self.states[-1])

        self.clicker.reset_clicks()
        self.reset_predictor()

    def finish_object(self):
        """Write the current object into the label mask under a new id.

        Pixels whose probability exceeds ``prob_thresh`` receive the next
        object id; afterwards all per-object state is cleared. Does nothing
        when no click has been made yet.
        """
        object_prob = self.current_object_prob
        if object_prob is None:
            return

        self.object_count += 1
        # The label mask is created by a plain torch.zeros call while the
        # prediction may live on another device; a boolean index must be on
        # the same device as the tensor it indexes. No-op if they match.
        object_mask = (object_prob > self.prob_thresh).to(self._result_mask.device)
        self._result_mask[object_mask] = self.object_count
        self.reset_last_object()

    def reset_last_object(self):
        """Drop clicks, undo history and predictor state of the current object."""
        self.states = []
        self.probs_history = []
        self.clicker.reset_clicks()
        self.reset_predictor()

    def reset_predictor(self, predictor_params=None):
        """Rebuild the predictor, optionally with new parameters.

        Args:
            predictor_params: If not ``None``, replaces the stored keyword
                arguments used for this and all later rebuilds.
        """
        if predictor_params is not None:
            self.predictor_params = predictor_params
        self.predictor = get_predictor(self.net, device=self.device,
                                       **self.predictor_params)
        # A fresh predictor does not know the image yet.
        if self.image is not None:
            self.predictor.set_input_image(self.image)

    @property
    def current_object_prob(self):
        """Element-wise maximum of the latest ``(total, additive)`` pair.

        Returns ``None`` when the history is empty.
        """
        if not self.probs_history:
            return None

        current_prob_total, current_prob_additive = self.probs_history[-1]
        return torch.maximum(current_prob_total, current_prob_additive)

    @property
    def is_incomplete_mask(self):
        """``True`` while the current object has at least one history entry."""
        return len(self.probs_history) > 0

    @property
    def result_mask(self):
        """Independent copy of the label mask.

        Raises ``AttributeError`` if ``set_image`` has not been called yet,
        because the mask is still ``None`` at that point.
        """
        return self._result_mask.clone()