mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
demo_version update with xmem initialize sam--li
This commit is contained in:
15
app.py
15
app.py
@@ -71,7 +71,10 @@ def get_frames_from_video(video_input, play_state):
|
|||||||
[[0:nearest_frame], [nearest_frame:], nearest_frame]
|
[[0:nearest_frame], [nearest_frame:], nearest_frame]
|
||||||
"""
|
"""
|
||||||
video_path = video_input
|
video_path = video_input
|
||||||
timestamp = play_state[1] - play_state[0]
|
try:
|
||||||
|
timestamp = play_state[1] - play_state[0]
|
||||||
|
except:
|
||||||
|
timestamp = 0.1
|
||||||
frames = []
|
frames = []
|
||||||
try:
|
try:
|
||||||
cap = cv2.VideoCapture(video_path)
|
cap = cv2.VideoCapture(video_path)
|
||||||
@@ -196,7 +199,7 @@ def interactive_correction(video_state, point_prompt, click_state, select_correc
|
|||||||
def correct_track(video_state, select_correction_frame, corrected_state, masks, logits, painted_images):
|
def correct_track(video_state, select_correction_frame, corrected_state, masks, logits, painted_images):
|
||||||
model.xmem.clear_memory()
|
model.xmem.clear_memory()
|
||||||
# inference the following images
|
# inference the following images
|
||||||
following_images = video_state[1][select_correction_frame+1:]
|
following_images = video_state[1][select_correction_frame:]
|
||||||
corrected_masks, corrected_logits, corrected_painted_images = model.generator(images=following_images, mask=corrected_state[0])
|
corrected_masks, corrected_logits, corrected_painted_images = model.generator(images=following_images, mask=corrected_state[0])
|
||||||
masks = masks[:select_correction_frame] + corrected_masks
|
masks = masks[:select_correction_frame] + corrected_masks
|
||||||
logits = logits[:select_correction_frame] + corrected_logits
|
logits = logits[:select_correction_frame] + corrected_logits
|
||||||
@@ -216,7 +219,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
|
|||||||
|
|
||||||
# args, defined in track_anything.py
|
# args, defined in track_anything.py
|
||||||
args = parse_augment()
|
args = parse_augment()
|
||||||
args.port = 12212
|
args.port = 12315
|
||||||
args.device = "cuda:2"
|
args.device = "cuda:2"
|
||||||
|
|
||||||
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
||||||
@@ -328,11 +331,11 @@ with gr.Blocks() as iface:
|
|||||||
|
|
||||||
# clear input
|
# clear input
|
||||||
video_input.clear(
|
video_input.clear(
|
||||||
lambda: (None, [], [], [[], [], []],
|
lambda: ([], [], [[], [], []],
|
||||||
None, "", "", "", "", "", "", "", [[], []],
|
None, "", "", "", "", "", "", "", [[],[]],
|
||||||
None),
|
None),
|
||||||
[],
|
[],
|
||||||
[video_input, state, play_state, video_state,
|
[ state, play_state, video_state,
|
||||||
template_frame, video_output, image_output, origin_image, template_mask, painted_images, masks, logits, click_state,
|
template_frame, video_output, image_output, origin_image, template_mask, painted_images, masks, logits, click_state,
|
||||||
select_correction_frame],
|
select_correction_frame],
|
||||||
queue=False,
|
queue=False,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ class TrackingAnything():
|
|||||||
def __init__(self, sam_checkpoint, xmem_checkpoint, args):
|
def __init__(self, sam_checkpoint, xmem_checkpoint, args):
|
||||||
self.args = args
|
self.args = args
|
||||||
self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
|
self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
|
||||||
self.xmem = BaseTracker(xmem_checkpoint, device=args.device, )
|
self.xmem = BaseTracker(xmem_checkpoint, device=args.device, sam_checkpoint=sam_checkpoint, model_type=args.sam_model_type)
|
||||||
|
|
||||||
|
|
||||||
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
||||||
|
|||||||
Reference in New Issue
Block a user