diff --git a/inference_mp4_2x.py b/inference_mp4_2x.py index c1e536a..5de612e 100644 --- a/inference_mp4_2x.py +++ b/inference_mp4_2x.py @@ -43,6 +43,7 @@ padding = (0, pw - w, 0, ph - h) tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) print('{}.mp4, {} frames in total, {}FPS to {}FPS'.format(args.video[:-4], tot_frame, fps, args.fps)) pbar = tqdm(total=tot_frame) +cnt = 0 if args.montage: frame = frame[:, left: left + w] while success: @@ -55,8 +56,14 @@ while success: I1 = torch.from_numpy(np.transpose(frame, (2,0,1)).astype("float32") / 255.).to(device).unsqueeze(0) I0 = F.pad(I0, padding) I1 = F.pad(I1, padding) - if (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False) - - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs().mean() > 0.2: + p = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False) + - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs().mean() + if p < 0.01: + print("Warning: Your video has {} static frames, it may change the duration of the generated video.".format(cnt)) + cnt += 1 + pbar.update(1) + continue + if p > 0.2: mid1 = lastframe else: mid1 = model.inference(I0, I1) @@ -66,7 +73,9 @@ while success: output.write(np.concatenate((lastframe, mid1[:h, :w]), 1)) else: output.write(lastframe) + output.write(mid0[:h, :w]) output.write(mid1[:h, :w]) + output.write(mid2[:h, :w]) pbar.update(1) if args.montage: output.write(np.concatenate((lastframe, lastframe), 1)) diff --git a/inference_mp4_4x.py b/inference_mp4_4x.py index 20dd593..0ec65be 100644 --- a/inference_mp4_4x.py +++ b/inference_mp4_4x.py @@ -43,6 +43,7 @@ padding = (0, pw - w, 0, ph - h) tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) print('{}.mp4, {} frames in total, {}FPS to {}FPS'.format(args.video[:-4], tot_frame, fps, args.fps)) pbar = tqdm(total=tot_frame) +cnt = 0 if args.montage: frame = frame[:, left: left + w] while success: @@ -55,8 +56,14 @@ while success: I1 = torch.from_numpy(np.transpose(frame, (2,0,1)).astype("float32") / 255.).to(device).unsqueeze(0) I0 = F.pad(I0, padding) I1 = F.pad(I1, padding) - if (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False) - - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs().mean() > 0.2: + p = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False) + - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs().mean() + if p < 0.01: + print("Warning: Your video has {} static frames, it may change the duration of the generated video.".format(cnt)) + cnt += 1 + pbar.update(1) + continue + if p > 0.2: mid0 = lastframe mid1 = lastframe mid2 = frame