diff --git a/inference_mp4_4x_parallel.py b/inference_mp4_4x_parallel.py index e30a3a0..9b43178 100644 --- a/inference_mp4_4x_parallel.py +++ b/inference_mp4_4x_parallel.py @@ -49,20 +49,21 @@ def writeframe(I0, mid0, mid1, mid2, I1, p): if skip_frame % 100 == 0: print("Warning: Your video has {} static frames, skipping them may change the duration of the generated video.".format(skip_frame)) skip_frame += 1 - if args.png: - cv2.imwrite('output/{:0>7d}.png'.format(cnt), I0[i]) - cnt += 1 - cv2.imwrite('output/{:0>7d}.png'.format(cnt), mid0[i]) - cnt += 1 - cv2.imwrite('output/{:0>7d}.png'.format(cnt), mid1[i]) - cnt += 1 - cv2.imwrite('output/{:0>7d}.png'.format(cnt), mid2[i]) - cnt += 1 - else: - output.write(I0[i]) - output.write(mid0[i]) - output.write(mid1[i]) - output.write(mid2[i]) + continue + if args.png: + cv2.imwrite('output/{:0>7d}.png'.format(cnt), I0[i]) + cnt += 1 + cv2.imwrite('output/{:0>7d}.png'.format(cnt), mid0[i]) + cnt += 1 + cv2.imwrite('output/{:0>7d}.png'.format(cnt), mid1[i]) + cnt += 1 + cv2.imwrite('output/{:0>7d}.png'.format(cnt), mid2[i]) + cnt += 1 + else: + output.write(I0[i]) + output.write(mid0[i]) + output.write(mid1[i]) + output.write(mid2[i]) ph = ((h - 1) // 32 + 1) * 32 pw = ((w - 1) // 32 + 1) * 32 padding = (0, pw - w, 0, ph - h)