From 9ea44e7e7bff45d98db229776e5689699817f2b8 Mon Sep 17 00:00:00 2001
From: hzwer <598460606@163.com>
Date: Sun, 22 Nov 2020 10:32:13 +0800
Subject: [PATCH] Use thread-safe queue

Replace the shared list buffer with queue.Queue. A single writer thread
is started before the main loop and drains the queue continuously; a
None sentinel plus task_done()/Queue.join() guarantees every frame is
written before the video writer is released. (Queue has no clear()
method, and spawning a new consumer thread every 100 frames would break
frame ordering.)
---
 inference_video.py          | 41 ++++++++++++++++++++++---------------
 inference_video_parallel.py | 25 +++++++++++++------------
 2 files changed, 38 insertions(+), 28 deletions(-)

diff --git a/inference_video.py b/inference_video.py
index 1894571..4837eb3 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -7,6 +7,7 @@ from tqdm import tqdm
 from torch.nn import functional as F
 import warnings
 import _thread
+from queue import Queue
 
 warnings.filterwarnings("ignore")
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -40,7 +41,6 @@ if args.fps is None:
 success, frame = videoCapture.read()
 h, w, _ = frame.shape
 fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
-buffer = []
 if args.png:
     if not os.path.exists('vid_out'):
         os.mkdir('vid_out')
@@ -51,12 +51,17 @@ else:
 cnt = 0
 def clear_buffer(user_args, buffer):
     global cnt
-    for i in buffer:
+    while True:
+        item = buffer.get()
+        if item is None:
+            buffer.task_done()
+            break
         if user_args.png:
-            cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), i)
+            cv2.imwrite('vid_out/{:0>7d}.png'.format(cnt), item)
             cnt += 1
         else:
-            vid_out.write(i)
+            vid_out.write(item)
+        buffer.task_done()
 
 if args.montage:
     left = w // 4
@@ -70,6 +75,8 @@ pbar = tqdm(total=tot_frame)
 skip_frame = 1
 if args.montage:
     frame = frame[:, left: left + w]
+buffer = Queue()
+_thread.start_new_thread(clear_buffer, (args, buffer))
 while success:
     lastframe = frame
     success, frame = videoCapture.read()
@@ -101,27 +108,25 @@ while success:
         mid0 = (((mid[0] * 255.).byte().cpu().detach().numpy().transpose(1, 2, 0)))
         mid2 = (((mid[1] * 255.).byte().cpu().detach().numpy().transpose(1, 2, 0)))
         if args.montage:
-            buffer.append(np.concatenate((lastframe, lastframe), 1))
+            buffer.put(np.concatenate((lastframe, lastframe), 1))
             if args.exp == 4:
-                buffer.append(np.concatenate((lastframe, mid0[:h, :w]), 1))
-            buffer.append(np.concatenate((lastframe, mid1[:h, :w]), 1))
+                buffer.put(np.concatenate((lastframe, mid0[:h, :w]), 1))
+            buffer.put(np.concatenate((lastframe, mid1[:h, :w]), 1))
             if args.exp == 4:
-                buffer.append(np.concatenate((lastframe, mid2[:h, :w]), 1))
+                buffer.put(np.concatenate((lastframe, mid2[:h, :w]), 1))
         else:
-            buffer.append(lastframe)
+            buffer.put(lastframe)
             if args.exp == 4:
-                buffer.append(mid0[:h, :w])
-            buffer.append(mid1[:h, :w])
+                buffer.put(mid0[:h, :w])
+            buffer.put(mid1[:h, :w])
             if args.exp == 4:
-                buffer.append(mid2[:h, :w])
+                buffer.put(mid2[:h, :w])
         pbar.update(1)
-        if len(buffer) > 100:
-            _thread.start_new_thread(clear_buffer, (args, buffer))
-            buffer = []
 if args.montage:
-    buffer.append(np.concatenate((lastframe, lastframe), 1))
+    buffer.put(np.concatenate((lastframe, lastframe), 1))
 else:
-    buffer.append(lastframe)
-_thread.start_new_thread(clear_buffer, (args, buffer))
+    buffer.put(lastframe)
+buffer.put(None)
+buffer.join()
 pbar.close()
 if not vid_out is None:
diff --git a/inference_video_parallel.py b/inference_video_parallel.py
index 2ebf8b3..d04e917 100644
--- a/inference_video_parallel.py
+++ b/inference_video_parallel.py
@@ -7,6 +7,7 @@ from tqdm import tqdm
 from torch.nn import functional as F
 import warnings
 import _thread
+from queue import Queue
 
 warnings.filterwarnings("ignore")
 
@@ -50,7 +51,7 @@ else:
 
 cnt = 0
 skip_frame = 1
-buffer = []
+buffer = Queue()
 
 def write_frame(i0, infs, i1, p, user_args):
     global skip_frame, cnt
@@ -72,18 +73,23 @@ def write_frame(i0, infs, i1, p, user_args):
                 continue
         # Write results.
-        buffer.append(i0[i])
+        buffer.put(i0[i])
         for inf in infs:
-            buffer.append(inf[i])
+            buffer.put(inf[i])
 
 def clear_buffer(user_args, buffer):
     global cnt
-    for i in buffer:
+    while True:
+        item = buffer.get()
+        if item is None:
+            buffer.task_done()
+            break
         if user_args.png:
-            cv2.imwrite('output/{:0>7d}.png'.format(cnt), i)
+            cv2.imwrite('output/{:0>7d}.png'.format(cnt), item)
             cnt += 1
         else:
-            vid_out.write(i)
+            vid_out.write(item)
+        buffer.task_done()
 
 def make_inference(model, I0, I1, exp):
     middle = model.inference(I0, I1)
     if exp == 1:
@@ -101,6 +107,7 @@ tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
 print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps))
 pbar = tqdm(total=tot_frame)
 img_list = [frame]
+_thread.start_new_thread(clear_buffer, (args, buffer))
 while success:
     success, frame = videoCapture.read()
     if success:
@@ -122,10 +129,8 @@ while success:
         write_frame(I0, inferences, I1, p.mean(3).mean(2).mean(1), args)
         pbar.update(4)
         img_list = img_list[-1:]
-        if len(buffer) > 100:
-            _thread.start_new_thread(clear_buffer, (args, buffer))
-            buffer = []
-_thread.start_new_thread(clear_buffer, (args, buffer))
+buffer.put(None)
+buffer.join()
 pbar.close()
 if not vid_out is None:
     vid_out.release()