diff --git a/.gitignore b/.gitignore index c7064c1..343bdb4 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ *.pkl output/* *.mp4 + +test/ +.idea/ \ No newline at end of file diff --git a/inference_mp4_2x.py b/inference_mp4_2x.py index 7b1941b..5dcd4f1 100644 --- a/inference_mp4_2x.py +++ b/inference_mp4_2x.py @@ -18,6 +18,7 @@ parser.add_argument('--montage', dest='montage', action='store_true', help='mont parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing') parser.add_argument('--fps', dest='fps', type=int, default=60) parser.add_argument('--png', dest='png', action='store_true', help='whether to output png format outputs') +parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='output video extension') args = parser.parse_args() from model.RIFE import Model @@ -35,7 +36,8 @@ if args.png: if not os.path.exists('output'): os.mkdir('output') else: - output = cv2.VideoWriter('{}_2x.mp4'.format(args.video[:-4]), fourcc, args.fps, (w, h)) + video_path_wo_ext, ext = os.path.splitext(args.video) + output = cv2.VideoWriter('{}_2x.{}'.format(video_path_wo_ext, args.ext), fourcc, args.fps, (w, h)) cnt = 0 def writeframe(frame): @@ -53,7 +55,7 @@ ph = ((h - 1) // 32 + 1) * 32 pw = ((w - 1) // 32 + 1) * 32 padding = (0, pw - w, 0, ph - h) tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) -print('{}.mp4, {} frames in total, {}FPS to {}FPS'.format(args.video[:-4], tot_frame, fps, args.fps)) +print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps)) pbar = tqdm(total=tot_frame) skip_frame = 1 if args.montage: diff --git a/inference_mp4_4x.py b/inference_mp4_4x.py index 76ca4fa..4135592 100644 --- a/inference_mp4_4x.py +++ b/inference_mp4_4x.py @@ -18,6 +18,7 @@ parser.add_argument('--montage', dest='montage', action='store_true', help='mont parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing') parser.add_argument('--fps', dest='fps', type=int, default=60) parser.add_argument('--png', dest='png', action='store_true', help='whether to output png format outputs') +parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='output video extension') args = parser.parse_args() from model.RIFE import Model @@ -35,7 +36,8 @@ if args.png: if not os.path.exists('output'): os.mkdir('output') else: - output = cv2.VideoWriter('{}_4x.mp4'.format(args.video[:-4]), fourcc, args.fps, (w, h)) + video_path_wo_ext, ext = os.path.splitext(args.video) + output = cv2.VideoWriter('{}_4x.{}'.format(video_path_wo_ext, args.ext), fourcc, args.fps, (w, h)) cnt = 0 def writeframe(frame): @@ -52,7 +54,7 @@ ph = ((h - 1) // 32 + 1) * 32 pw = ((w - 1) // 32 + 1) * 32 padding = (0, pw - w, 0, ph - h) tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) -print('{}.mp4, {} frames in total, {}FPS to {}FPS'.format(args.video[:-4], tot_frame, fps, args.fps)) +print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps)) pbar = tqdm(total=tot_frame) skip_frame = 1 if args.montage: diff --git a/inference_mp4_4x_parallel.py b/inference_mp4_4x_parallel.py index 759ff8f..715e9f0 100644 --- a/inference_mp4_4x_parallel.py +++ b/inference_mp4_4x_parallel.py @@ -17,6 +17,7 @@ parser.add_argument('--video', dest='video', required=True) parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing') parser.add_argument('--fps', dest='fps', type=int, default=60) parser.add_argument('--png', dest='png', action='store_true', help='whether to output png format outputs') +parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='output video extension') args = parser.parse_args() from model.RIFE import Model @@ -34,7 +35,8 @@ if args.png: if not os.path.exists('output'): os.mkdir('output') else: - output = cv2.VideoWriter('{}_4x.mp4'.format(args.video[:-4]), fourcc, args.fps, (w, h)) + video_path_wo_ext, ext = os.path.splitext(args.video) + output = cv2.VideoWriter('{}_4x.{}'.format(video_path_wo_ext, args.ext), fourcc, args.fps, (w, h)) cnt = 0 skip_frame = 1 @@ -64,14 +66,15 @@ def writeframe(I0, mid0, mid1, mid2, I1, p): output.write(mid0[i]) output.write(mid1[i]) output.write(mid2[i]) + + ph = ((h - 1) // 32 + 1) * 32 pw = ((w - 1) // 32 + 1) * 32 padding = (0, pw - w, 0, ph - h) tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) -print('{}.mp4, {} frames in total, {}FPS to {}FPS'.format(args.video[:-4], tot_frame, fps, args.fps)) +print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps)) pbar = tqdm(total=tot_frame) -img_list = [] -img_list.append(frame) +img_list = [frame] while success: success, frame = videoCapture.read() if success: @@ -86,11 +89,11 @@ while success: mid1 = model.inference(I0, I1) mid0 = model.inference(I0, mid1) mid2 = model.inference(mid1, I1) - I0 = (((I0[:, :, :h, :w] * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1))).astype('uint8') - I1 = (((I1[:, :, :h, :w] * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1))).astype('uint8') - mid0 = (((mid0[:, :, :h, :w] * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1))).astype('uint8') - mid1 = (((mid1[:, :, :h, :w] * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1))).astype('uint8') - mid2 = (((mid2[:, :, :h, :w] * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1))).astype('uint8') + I0 = ((I0[:, :, :h, :w] * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1)).astype('uint8') + I1 = ((I1[:, :, :h, :w] * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1)).astype('uint8') + mid0 = ((mid0[:, :, :h, :w] * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1)).astype('uint8') + mid1 = ((mid1[:, :, :h, :w] * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1)).astype('uint8') + mid2 = ((mid2[:, :, :h, :w] * 255.).cpu().detach().numpy().transpose(0, 2, 3, 1)).astype('uint8') writeframe(I0, mid0, mid1, mid2, I1, p.mean(3).mean(2).mean(1)) pbar.update(4) img_list = img_list[-1:]