demo version: update with XMem initialization — sam-li

This commit is contained in:
memoryunreal
2023-04-16 13:29:31 +00:00
parent f1811de58e
commit 71f6cb726d
2 changed files with 10 additions and 7 deletions

15
app.py
View File

@@ -71,7 +71,10 @@ def get_frames_from_video(video_input, play_state):
[[0:nearest_frame], [nearest_frame:], nearest_frame] [[0:nearest_frame], [nearest_frame:], nearest_frame]
""" """
video_path = video_input video_path = video_input
timestamp = play_state[1] - play_state[0] try:
timestamp = play_state[1] - play_state[0]
except:
timestamp = 0.1
frames = [] frames = []
try: try:
cap = cv2.VideoCapture(video_path) cap = cv2.VideoCapture(video_path)
@@ -196,7 +199,7 @@ def interactive_correction(video_state, point_prompt, click_state, select_correc
def correct_track(video_state, select_correction_frame, corrected_state, masks, logits, painted_images): def correct_track(video_state, select_correction_frame, corrected_state, masks, logits, painted_images):
model.xmem.clear_memory() model.xmem.clear_memory()
# inference the following images # inference the following images
following_images = video_state[1][select_correction_frame+1:] following_images = video_state[1][select_correction_frame:]
corrected_masks, corrected_logits, corrected_painted_images = model.generator(images=following_images, mask=corrected_state[0]) corrected_masks, corrected_logits, corrected_painted_images = model.generator(images=following_images, mask=corrected_state[0])
masks = masks[:select_correction_frame] + corrected_masks masks = masks[:select_correction_frame] + corrected_masks
logits = logits[:select_correction_frame] + corrected_logits logits = logits[:select_correction_frame] + corrected_logits
@@ -216,7 +219,7 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
# args, defined in track_anything.py # args, defined in track_anything.py
args = parse_augment() args = parse_augment()
args.port = 12212 args.port = 12315
args.device = "cuda:2" args.device = "cuda:2"
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args) model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
@@ -328,11 +331,11 @@ with gr.Blocks() as iface:
# clear input # clear input
video_input.clear( video_input.clear(
lambda: (None, [], [], [[], [], []], lambda: ([], [], [[], [], []],
None, "", "", "", "", "", "", "", [[], []], None, "", "", "", "", "", "", "", [[],[]],
None), None),
[], [],
[video_input, state, play_state, video_state, [ state, play_state, video_state,
template_frame, video_output, image_output, origin_image, template_mask, painted_images, masks, logits, click_state, template_frame, video_output, image_output, origin_image, template_mask, painted_images, masks, logits, click_state,
select_correction_frame], select_correction_frame],
queue=False, queue=False,

View File

@@ -12,7 +12,7 @@ class TrackingAnything():
def __init__(self, sam_checkpoint, xmem_checkpoint, args): def __init__(self, sam_checkpoint, xmem_checkpoint, args):
self.args = args self.args = args
self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device) self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
self.xmem = BaseTracker(xmem_checkpoint, device=args.device, ) self.xmem = BaseTracker(xmem_checkpoint, device=args.device, sam_checkpoint=sam_checkpoint, model_type=args.sam_model_type)
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray, def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,