diff --git a/inference_img.py b/inference_img.py index 1e6ea25..24a9ab0 100644 --- a/inference_img.py +++ b/inference_img.py @@ -22,12 +22,19 @@ model = Model() model.load_model('./train_log', -1) model.eval() model.device() - -img0 = cv2.imread(args.img[0]) -img1 = cv2.imread(args.img[1]) -img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) -img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) +if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0) + +else: + img0 = cv2.imread(args.img[0]) + img1 = cv2.imread(args.img[1]) + img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) + n, c, h, w = img0.shape ph = ((h - 1) // 32 + 1) * 32 pw = ((w - 1) // 32 + 1) * 32 @@ -48,4 +55,7 @@ for i in range(args.exp): if not os.path.exists('output'): os.mkdir('output') for i in range(len(img_list)): - cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]) + if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): + cv2.imwrite('output/img{}.exr'.format(i), (img_list[i][0]).cpu().numpy().transpose(1, 2, 0)[:h, :w], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + else: + cv2.imwrite('output/img{}.png'.format(i), (img_list[i][0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w])