mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 16:37:51 +01:00
Update and rename HD_multi.py to HD_multi_4X.py
This commit is contained in:
@@ -31,7 +31,7 @@ name_list = [
|
|||||||
('HD_dataset/HD544p_GT/Sintel_Temple1_1280x544.yuv', 544, 1280),
|
('HD_dataset/HD544p_GT/Sintel_Temple1_1280x544.yuv', 544, 1280),
|
||||||
('HD_dataset/HD544p_GT/Sintel_Temple2_1280x544.yuv', 544, 1280),
|
('HD_dataset/HD544p_GT/Sintel_Temple2_1280x544.yuv', 544, 1280),
|
||||||
]
|
]
|
||||||
def inference(I0, I1, pad, multi=3):
|
def inference(I0, I1, pad, multi=2):
|
||||||
img = [I0, I1]
|
img = [I0, I1]
|
||||||
for i in range(multi):
|
for i in range(multi):
|
||||||
res = [I0]
|
res = [I0]
|
||||||
@@ -56,14 +56,14 @@ for data in name_list:
|
|||||||
_, lastframe = Reader.read()
|
_, lastframe = Reader.read()
|
||||||
# fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
|
# fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
|
||||||
# video = cv2.VideoWriter(name + '.mp4', fourcc, 30, (w, h))
|
# video = cv2.VideoWriter(name + '.mp4', fourcc, 30, (w, h))
|
||||||
for index in range(0, 100, 8):
|
for index in range(0, 100, 4):
|
||||||
gt = []
|
gt = []
|
||||||
if 'yuv' in name:
|
if 'yuv' in name:
|
||||||
IMAGE1, success1 = Reader.read(index)
|
IMAGE1, success1 = Reader.read(index)
|
||||||
IMAGE2, success2 = Reader.read(index + 8)
|
IMAGE2, success2 = Reader.read(index + 4)
|
||||||
if not success2:
|
if not success2:
|
||||||
break
|
break
|
||||||
for i in range(1, 8):
|
for i in range(1, 4):
|
||||||
tmp, _ = Reader.read(index + i)
|
tmp, _ = Reader.read(index + i)
|
||||||
gt.append(tmp)
|
gt.append(tmp)
|
||||||
else:
|
else:
|
||||||
@@ -82,7 +82,7 @@ for data in name_list:
|
|||||||
I1 = pader(I1)
|
I1 = pader(I1)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pred = inference(I0, I1, pad)
|
pred = inference(I0, I1, pad)
|
||||||
for i in range(8 - 1):
|
for i in range(4 - 1):
|
||||||
out = (np.round(pred[i].detach().cpu().numpy().transpose(1, 2, 0) * 255)).astype('uint8')
|
out = (np.round(pred[i].detach().cpu().numpy().transpose(1, 2, 0) * 255)).astype('uint8')
|
||||||
if 'yuv' in name:
|
if 'yuv' in name:
|
||||||
diff_rgb = 128.0 + rgb2yuv(gt[i] / 255.)[:, :, 0] * 255 - rgb2yuv(out / 255.)[:, :, 0] * 255
|
diff_rgb = 128.0 + rgb2yuv(gt[i] / 255.)[:, :, 0] * 255 - rgb2yuv(out / 255.)[:, :, 0] * 255
|
||||||
Reference in New Issue
Block a user