diff --git a/README.md b/README.md
index 02bb125..043e29c 100644
--- a/README.md
+++ b/README.md
@@ -17,5 +17,5 @@ pip3 install opencv-python
 
 ## Usage
 ```
-python3 inference.py --img img0.png img1.png
+python3 inference.py --img /path/to/image_0 /path/to/image_1
 ```
diff --git a/inference.py b/inference.py
index 9553f8a..e669fcb 100644
--- a/inference.py
+++ b/inference.py
@@ -1,16 +1,15 @@
 import cv2
 import torch
 import argparse
+from torch.nn import functional as F
 from model.RIFE import Model
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
-parser.add_argument('--img', dest='img', nargs=2)
+parser.add_argument('--img', dest='img', nargs=2, required=True)
 args = parser.parse_args()
 
-assert(len(args.img) == 2)
-
 model = Model()
 model.load_model('./train_log')
 model.eval()
@@ -18,9 +17,13 @@ model.device()
 img0 = cv2.imread(args.img[0])
 img1 = cv2.imread(args.img[1])
 h, w, _ = img0.shape
+# Round H and W UP to the next multiple of 32 so the pad amounts are never negative (negative F.pad values crop); the pad is sliced off again before saving.
+ph = ((h - 1) // 32 + 1) * 32
+pw = ((w - 1) // 32 + 1) * 32
+padding = (0, pw - w, 0, ph - h)
 img0 = torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.
 img1 = torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.
-imgs = torch.cat((img0, img1), 0).float()
+imgs = F.pad(torch.cat((img0, img1), 0).float(), padding)
 with torch.no_grad():
     res = model.inference(imgs.unsqueeze(0)) * 255
-cv2.imwrite('output.png', res[0].numpy().transpose(1, 2, 0))
+cv2.imwrite('output.png', res[0].cpu().numpy().transpose(1, 2, 0)[:h, :w])