diff --git a/inference_mp4_2x.py b/inference_mp4_2x.py index e7ccaeb..9df0dbd 100644 --- a/inference_mp4_2x.py +++ b/inference_mp4_2x.py @@ -47,13 +47,13 @@ while success: mid = model.inference(I0, I1) mid = ((mid[0].cpu().detach().numpy().transpose(1, 2, 0))*255.).astype('uint8') if args.montage: - output.write(torch.cat((lastframe, lastframe), 2)) - output.write(torch.cat((lastframe, mid[:h, :w]), 2)) + output.write(np.concatenate((lastframe, lastframe), 1)) + output.write(np.concatenate((lastframe, mid[:h, :w]), 1)) else: output.write(lastframe) output.write(mid[:h, :w]) if args.montage: - output.write(torch.cat((frame, frame), 2)) + output.write(np.concatenate((frame, frame), 1)) else: output.write(frame) output.release()