diff --git a/inference_video_parallel.py b/inference_video_parallel.py index becdfd5..44578c9 100644 --- a/inference_video_parallel.py +++ b/inference_video_parallel.py @@ -6,6 +6,8 @@ import numpy as np from tqdm import tqdm from torch.nn import functional as F import warnings +import _thread + warnings.filterwarnings("ignore") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -48,11 +50,10 @@ else: cnt = 0 skip_frame = 1 - +buffer = [] def write_frame(i0, infs, i1, p, user_args): - global skip_frame, cnt - + global skip_frame, cnt for i in range(i0.shape[0]): # A video transition occurs. if p[i] > 0.2: @@ -70,18 +71,19 @@ def write_frame(i0, infs, i1, p, user_args): skip_frame += 1 continue - # Write results. - if user_args.png: - cv2.imwrite('output/{:0>7d}.png'.format(cnt), i0[i]) - cnt += 1 - for inf in infs: - cv2.imwrite('output/{:0>7d}.png'.format(cnt), inf[i]) - cnt += 1 - else: - vid_out.write(i0[i]) - for inf in infs: - vid_out.write(inf[i]) + # Write results. + buffer.append(i0[i]) + for inf in infs: + buffer.append(inf[i]) +def clear_buffer(user_args, buffer): + global cnt + for i in buffer: + if user_args.png: + cv2.imwrite('output/{:0>7d}.png'.format(cnt), i) + cnt += i + else: + vid_out.write(i) def make_inference(model, I0, I1, exp): middle = model.inference(I0, I1) @@ -120,5 +122,9 @@ while success: write_frame(I0, inferences, I1, p.mean(3).mean(2).mean(1), args) pbar.update(4) img_list = img_list[-1:] + if len(buffer) > 100: + _thread.start_new_thread(clear_buffer, (args, buffer)) + buffer = [] +_thread.start_new_thread(clear_buffer, (args, buffer)) pbar.close() vid_out.release()