diff --git a/benchmark/MiddleBury_Other.py b/benchmark/MiddleBury_Other.py index 7c0f77f..a6676ce 100644 --- a/benchmark/MiddleBury_Other.py +++ b/benchmark/MiddleBury_Other.py @@ -12,7 +12,7 @@ from model.RIFE import Model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = Model() -model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log')) +model.load_model('train_log') model.eval() model.device() @@ -23,15 +23,15 @@ for i in name: i1 = cv2.imread('other-data/{}/frame11.png'.format(i)).transpose(2, 0, 1) / 255. gt = cv2.imread('other-gt-interp/{}/frame10i11.png'.format(i)) h, w = i0.shape[1], i0.shape[2] - imgs = torch.zeros([1, 6, 480, 640]) + imgs = torch.zeros([1, 6, 480, 640]).to(device) ph = (480 - h) // 2 pw = (640 - w) // 2 - imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float() - imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float() + imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float().to(device) + imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float().to(device) I0 = imgs[:, :3] I2 = imgs[:, 3:] pred = model.inference(I0, I2) - out = pred[0].cpu().numpy().transpose(1, 2, 0) + out = pred[0].detach().cpu().numpy().transpose(1, 2, 0) out = np.round(out[:h, :w] * 255) IE_list.append(np.abs((out - gt * 1.0)).mean()) print(np.mean(IE_list)) diff --git a/benchmark/Vimeo90K.py b/benchmark/Vimeo90K.py index cd70e48..ccc7d28 100644 --- a/benchmark/Vimeo90K.py +++ b/benchmark/Vimeo90K.py @@ -13,7 +13,7 @@ from model.RIFE import Model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = Model() -model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log')) +model.load_model('train_log') model.eval() model.device()