mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 00:17:46 +01:00
Fix benchmark
This commit is contained in:
@@ -12,7 +12,7 @@ from model.RIFE import Model
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model = Model()
|
||||
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'))
|
||||
model.load_model('train_log')
|
||||
model.eval()
|
||||
model.device()
|
||||
|
||||
@@ -23,15 +23,15 @@ for i in name:
|
||||
i1 = cv2.imread('other-data/{}/frame11.png'.format(i)).transpose(2, 0, 1) / 255.
|
||||
gt = cv2.imread('other-gt-interp/{}/frame10i11.png'.format(i))
|
||||
h, w = i0.shape[1], i0.shape[2]
|
||||
imgs = torch.zeros([1, 6, 480, 640])
|
||||
imgs = torch.zeros([1, 6, 480, 640]).to(device)
|
||||
ph = (480 - h) // 2
|
||||
pw = (640 - w) // 2
|
||||
imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float()
|
||||
imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float()
|
||||
imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float().to(device)
|
||||
imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float().to(device)
|
||||
I0 = imgs[:, :3]
|
||||
I2 = imgs[:, 3:]
|
||||
pred = model.inference(I0, I2)
|
||||
out = pred[0].cpu().numpy().transpose(1, 2, 0)
|
||||
out = pred[0].detach().cpu().numpy().transpose(1, 2, 0)
|
||||
out = np.round(out[:h, :w] * 255)
|
||||
IE_list.append(np.abs((out - gt * 1.0)).mean())
|
||||
print(np.mean(IE_list))
|
||||
|
||||
@@ -13,7 +13,7 @@ from model.RIFE import Model
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model = Model()
|
||||
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log'))
|
||||
model.load_model('train_log')
|
||||
model.eval()
|
||||
model.device()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user