From 7005b64b4301ae6549d9dd38fe9ba68bd9f69960 Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Mon, 16 Nov 2020 20:08:11 +0800 Subject: [PATCH] Add png output --- inference_mp4_2x.py | 39 +++++++++++++++++++++++++++------------ inference_mp4_4x.py | 44 +++++++++++++++++++++++++++++--------------- 2 files changed, 56 insertions(+), 27 deletions(-) diff --git a/inference_mp4_2x.py b/inference_mp4_2x.py index f866515..c5f4983 100644 --- a/inference_mp4_2x.py +++ b/inference_mp4_2x.py @@ -18,8 +18,9 @@ parser.add_argument('--montage', dest='montage', action='store_true', help='mont parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing') parser.add_argument('--fps', dest='fps', type=int, default=60) parser.add_argument('--model', dest='model', type=str, default='RIFE') +parser.add_argument('--png', dest='png', action='store_true', help='whether to output png format outputs') args = parser.parse_args() - + if args.model == '2F': from model.RIFE2F import Model else: @@ -34,7 +35,21 @@ fps = np.round(videoCapture.get(cv2.CAP_PROP_FPS)) success, frame = videoCapture.read() h, w, _ = frame.shape fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') -output = cv2.VideoWriter('{}_2x.mp4'.format(args.video[:-4]), fourcc, args.fps, (w, h)) +if args.png: + if not os.path.exists('output'): + os.mkdir('output') +else: + output = cv2.VideoWriter('{}_2x.mp4'.format(args.video[:-4]), fourcc, args.fps, (w, h)) + +cnt = 0 +def writeframe(frame): + global cnt + if args.png: + cv2.imwrite('output/{}.png'.format(cnt), frame) + cnt += 1 + else: + output.write(frame) + if args.montage: left = w // 4 w = w // 2 @@ -44,7 +59,7 @@ padding = (0, pw - w, 0, ph - h) tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) print('{}.mp4, {} frames in total, {}FPS to {}FPS'.format(args.video[:-4], tot_frame, fps, args.fps)) pbar = tqdm(total=tot_frame) -cnt = 1 +skip_frame = 1 if args.montage: frame = frame[:, left: left + w] while success: @@ -60,9 +75,9 @@ while success: p = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False) - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs().mean() if p < 1e-3 and args.skip: - if cnt % 100 == 0: - print("Warning: Your video has {} static frames, skipping them change the duration of the generated video.".format(cnt)) - cnt += 1 + if skip_frame % 100 == 0: + print("Warning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(skip_frame)) + skip_frame += 1 pbar.update(1) continue if p > 0.2: @@ -71,15 +86,15 @@ while success: mid1 = model.inference(I0, I1) mid1 = (((mid1[0] * 255.).cpu().detach().numpy().transpose(1, 2, 0))).astype('uint8') if args.montage: - output.write(np.concatenate((lastframe, lastframe), 1)) - output.write(np.concatenate((lastframe, mid1[:h, :w]), 1)) + writeframe(np.concatenate((lastframe, lastframe), 1)) + writeframe(np.concatenate((lastframe, mid1[:h, :w]), 1)) else: - output.write(lastframe) - output.write(mid1[:h, :w]) + writeframe(lastframe) + writeframe(mid1[:h, :w]) pbar.update(1) if args.montage: - output.write(np.concatenate((lastframe, lastframe), 1)) + writeframe(np.concatenate((lastframe, lastframe), 1)) else: - output.write(lastframe) + writeframe(lastframe) pbar.close() output.release() diff --git a/inference_mp4_4x.py b/inference_mp4_4x.py index 27d90b7..a0f4ba2 100644 --- a/inference_mp4_4x.py +++ b/inference_mp4_4x.py @@ -18,6 +18,7 @@ parser.add_argument('--montage', dest='montage', action='store_true', help='mont parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing') parser.add_argument('--fps', dest='fps', type=int, default=60) parser.add_argument('--model', dest='model', type=str, default='RIFE') +parser.add_argument('--png', dest='png', action='store_true', help='whether to output png format outputs') args = parser.parse_args() if args.model == '2F': @@ -34,7 +35,20 @@ fps = np.round(videoCapture.get(cv2.CAP_PROP_FPS)) success, frame = videoCapture.read() h, w, _ = frame.shape fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') -output = cv2.VideoWriter('{}_4x.mp4'.format(args.video[:-4]), fourcc, args.fps, (w, h)) +if args.png: + if not os.path.exists('output'): + os.mkdir('output') +else: + output = cv2.VideoWriter('{}_4x.mp4'.format(args.video[:-4]), fourcc, args.fps, (w, h)) + +cnt = 0 +def writeframe(frame): + global cnt + if args.png: + cv2.imwrite('output/{}.png'.format(cnt), frame) + cnt += 1 + else: + output.write(frame) if args.montage: left = w // 4 w = w // 2 @@ -44,7 +58,7 @@ padding = (0, pw - w, 0, ph - h) tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) print('{}.mp4, {} frames in total, {}FPS to {}FPS'.format(args.video[:-4], tot_frame, fps, args.fps)) pbar = tqdm(total=tot_frame) -cnt = 1 +skip_frame = 1 if args.montage: frame = frame[:, left: left + w] while success: @@ -60,9 +74,9 @@ while success: p = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False) - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs().mean() if p < 1e-3 and args.skip: - if cnt % 100 == 0: - print("Warning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(cnt)) - cnt += 1 + if skip_frame % 100 == 0: + print("Warning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(skip_frame)) + skip_frame += 1 pbar.update(1) continue if p > 0.2: @@ -76,19 +90,19 @@ while success: mid1 = (((mid1[0] * 255.).cpu().detach().numpy().transpose(1, 2, 0))).astype('uint8') mid2 = (((mid[1]* 255.).cpu().detach().numpy().transpose(1, 2, 0))).astype('uint8') if args.montage: - output.write(np.concatenate((lastframe, lastframe), 1)) - output.write(np.concatenate((lastframe, mid0[:h, :w]), 1)) - output.write(np.concatenate((lastframe, mid1[:h, :w]), 1)) - output.write(np.concatenate((lastframe, mid2[:h, :w]), 1)) + writeframe(np.concatenate((lastframe, lastframe), 1)) + writeframe(np.concatenate((lastframe, mid0[:h, :w]), 1)) + writeframe(np.concatenate((lastframe, mid1[:h, :w]), 1)) + writeframe(np.concatenate((lastframe, mid2[:h, :w]), 1)) else: - output.write(lastframe) - output.write(mid0[:h, :w]) - output.write(mid1[:h, :w]) - output.write(mid2[:h, :w]) + writeframe(lastframe) + writeframe(mid0[:h, :w]) + writeframe(mid1[:h, :w]) + writeframe(mid2[:h, :w]) pbar.update(1) if args.montage: - output.write(np.concatenate((lastframe, lastframe), 1)) + writeframe(np.concatenate((lastframe, lastframe), 1)) else: - output.write(lastframe) + writeframe(lastframe) pbar.close() output.release()