diff --git a/inference_video.py b/inference_video.py index eaf2edc..6ce6f19 100644 --- a/inference_video.py +++ b/inference_video.py @@ -8,7 +8,7 @@ from torch.nn import functional as F import warnings import _thread import skvideo.io -from queue import Queue +from queue import Queue, Empty warnings.filterwarnings("ignore") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -51,13 +51,13 @@ else: video_path_wo_ext, ext = os.path.splitext(args.video) vid_out = cv2.VideoWriter('{}_{}X_{}fps.{}'.format(video_path_wo_ext, args.exp, int(np.round(args.fps)), args.ext), fourcc, args.fps, (w, h)) -cnt = 0 -def clear_buffer(user_args, buffer): - global cnt +def clear_buffer(user_args): + cnt = 0 while True: - item = buffer.get() - if item is None: - break + try: + item = buffer.get(timeout=2) + except Empty: + return if user_args.png: cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item[:, :, ::-1]) cnt += 1 @@ -76,7 +76,7 @@ skip_frame = 1 if args.montage: lastframe = lastframe[:, left: left + w] buffer = Queue() -_thread.start_new_thread(clear_buffer, (args, buffer)) +_thread.start_new_thread(clear_buffer, (args, )) for frame in videogen: if args.montage: frame = frame[:, left: left + w] @@ -124,9 +124,6 @@ if args.montage: buffer.put(np.concatenate((lastframe, lastframe), 1)) else: buffer.put(lastframe) -import time -while(not buffer.empty()): - time.sleep(0.1) pbar.close() if not vid_out is None: vid_out.release() diff --git a/inference_video_parallel.py b/inference_video_parallel.py index 04456cb..432661d 100644 --- a/inference_video_parallel.py +++ b/inference_video_parallel.py @@ -8,7 +8,7 @@ from torch.nn import functional as F import warnings import _thread import skvideo.io -from queue import Queue +from queue import Queue, Empty warnings.filterwarnings("ignore") @@ -80,9 +80,10 @@ def write_frame(i0, infs, i1, p, user_args): def clear_buffer(user_args): global cnt while True: - item = buffer.get() - if item is None: - break + try: + item = buffer.get(timeout=2) + except Empty: + return if user_args.png: cv2.imwrite('output/{:0>7d}.png'.format(cnt), item[:, :, ::-1]) cnt += 1 @@ -125,9 +126,6 @@ for frame in videogen: pbar.update(4) img_list = img_list[-1:] buffer.put(img_list[0]) -import time -while(not buffer.empty()): - time.sleep(0.1) pbar.close() if not vid_out is None: vid_out.release()