Fix benchmark

This commit is contained in:
hzwer
2021-02-25 15:45:39 +08:00
parent 485d1a77a0
commit bdda49d81f
2 changed files with 6 additions and 6 deletions

View File

@@ -12,7 +12,7 @@ from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model() model = Model()
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log')) model.load_model('train_log')
model.eval() model.eval()
model.device() model.device()
@@ -23,15 +23,15 @@ for i in name:
i1 = cv2.imread('other-data/{}/frame11.png'.format(i)).transpose(2, 0, 1) / 255. i1 = cv2.imread('other-data/{}/frame11.png'.format(i)).transpose(2, 0, 1) / 255.
gt = cv2.imread('other-gt-interp/{}/frame10i11.png'.format(i)) gt = cv2.imread('other-gt-interp/{}/frame10i11.png'.format(i))
h, w = i0.shape[1], i0.shape[2] h, w = i0.shape[1], i0.shape[2]
imgs = torch.zeros([1, 6, 480, 640]) imgs = torch.zeros([1, 6, 480, 640]).to(device)
ph = (480 - h) // 2 ph = (480 - h) // 2
pw = (640 - w) // 2 pw = (640 - w) // 2
imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float() imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float().to(device)
imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float() imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float().to(device)
I0 = imgs[:, :3] I0 = imgs[:, :3]
I2 = imgs[:, 3:] I2 = imgs[:, 3:]
pred = model.inference(I0, I2) pred = model.inference(I0, I2)
out = pred[0].cpu().numpy().transpose(1, 2, 0) out = pred[0].detach().cpu().numpy().transpose(1, 2, 0)
out = np.round(out[:h, :w] * 255) out = np.round(out[:h, :w] * 255)
IE_list.append(np.abs((out - gt * 1.0)).mean()) IE_list.append(np.abs((out - gt * 1.0)).mean())
print(np.mean(IE_list)) print(np.mean(IE_list))

View File

@@ -13,7 +13,7 @@ from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model() model = Model()
model.load_model(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'train_log')) model.load_model('train_log')
model.eval() model.eval()
model.device() model.device()