demo_version update with xmem initialize sam--li

This commit is contained in:
memoryunreal
2023-04-16 13:29:31 +00:00
parent f1811de58e
commit 71f6cb726d
2 changed files with 10 additions and 7 deletions

13
app.py
View File

@@ -71,7 +71,10 @@ def get_frames_from_video(video_input, play_state):
[[0:nearest_frame], [nearest_frame:], nearest_frame]
"""
video_path = video_input
try:
timestamp = play_state[1] - play_state[0]
except:
timestamp = 0.1
frames = []
try:
cap = cv2.VideoCapture(video_path)
@@ -196,7 +199,7 @@ def interactive_correction(video_state, point_prompt, click_state, select_correc
def correct_track(video_state, select_correction_frame, corrected_state, masks, logits, painted_images):
model.xmem.clear_memory()
# inference the following images
following_images = video_state[1][select_correction_frame+1:]
following_images = video_state[1][select_correction_frame:]
corrected_masks, corrected_logits, corrected_painted_images = model.generator(images=following_images, mask=corrected_state[0])
masks = masks[:select_correction_frame] + corrected_masks
logits = logits[:select_correction_frame] + corrected_logits
@@ -216,7 +219,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
# args, defined in track_anything.py
args = parse_augment()
args.port = 12212
args.port = 12315
args.device = "cuda:2"
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
@@ -328,11 +331,11 @@ with gr.Blocks() as iface:
# clear input
video_input.clear(
lambda: (None, [], [], [[], [], []],
None, "", "", "", "", "", "", "", [[], []],
lambda: ([], [], [[], [], []],
None, "", "", "", "", "", "", "", [[],[]],
None),
[],
[video_input, state, play_state, video_state,
[ state, play_state, video_state,
template_frame, video_output, image_output, origin_image, template_mask, painted_images, masks, logits, click_state,
select_correction_frame],
queue=False,

View File

@@ -12,7 +12,7 @@ class TrackingAnything():
def __init__(self, sam_checkpoint, xmem_checkpoint, args):
self.args = args
self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
self.xmem = BaseTracker(xmem_checkpoint, device=args.device, )
self.xmem = BaseTracker(xmem_checkpoint, device=args.device, sam_checkpoint=sam_checkpoint, model_type=args.sam_model_type)
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,