mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
demo_version update with xmem initialize sam--li
This commit is contained in:
11
app.py
11
app.py
@@ -71,7 +71,10 @@ def get_frames_from_video(video_input, play_state):
|
||||
[[0:nearest_frame], [nearest_frame:], nearest_frame]
|
||||
"""
|
||||
video_path = video_input
|
||||
try:
|
||||
timestamp = play_state[1] - play_state[0]
|
||||
except:
|
||||
timestamp = 0.1
|
||||
frames = []
|
||||
try:
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
@@ -196,7 +199,7 @@ def interactive_correction(video_state, point_prompt, click_state, select_correc
|
||||
def correct_track(video_state, select_correction_frame, corrected_state, masks, logits, painted_images):
|
||||
model.xmem.clear_memory()
|
||||
# inference the following images
|
||||
following_images = video_state[1][select_correction_frame+1:]
|
||||
following_images = video_state[1][select_correction_frame:]
|
||||
corrected_masks, corrected_logits, corrected_painted_images = model.generator(images=following_images, mask=corrected_state[0])
|
||||
masks = masks[:select_correction_frame] + corrected_masks
|
||||
logits = logits[:select_correction_frame] + corrected_logits
|
||||
@@ -216,7 +219,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
|
||||
|
||||
# args, defined in track_anything.py
|
||||
args = parse_augment()
|
||||
args.port = 12212
|
||||
args.port = 12315
|
||||
args.device = "cuda:2"
|
||||
|
||||
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
||||
@@ -328,11 +331,11 @@ with gr.Blocks() as iface:
|
||||
|
||||
# clear input
|
||||
video_input.clear(
|
||||
lambda: (None, [], [], [[], [], []],
|
||||
lambda: ([], [], [[], [], []],
|
||||
None, "", "", "", "", "", "", "", [[],[]],
|
||||
None),
|
||||
[],
|
||||
[video_input, state, play_state, video_state,
|
||||
[ state, play_state, video_state,
|
||||
template_frame, video_output, image_output, origin_image, template_mask, painted_images, masks, logits, click_state,
|
||||
select_correction_frame],
|
||||
queue=False,
|
||||
|
||||
@@ -12,7 +12,7 @@ class TrackingAnything():
|
||||
def __init__(self, sam_checkpoint, xmem_checkpoint, args):
|
||||
self.args = args
|
||||
self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
|
||||
self.xmem = BaseTracker(xmem_checkpoint, device=args.device, )
|
||||
self.xmem = BaseTracker(xmem_checkpoint, device=args.device, sam_checkpoint=sam_checkpoint, model_type=args.sam_model_type)
|
||||
|
||||
|
||||
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
||||
|
||||
Reference in New Issue
Block a user