diff --git a/README.md b/README.md index 2c47a02..a12e8cf 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ cd arXiv2020-RIFE pip3 install -r requirements.txt ``` -* Download the pretrained **HD** models from [here](https://drive.google.com/file/d/10-2AaFUyX-c7yCfubsxF2NTvM7DgvS8l/view?usp=sharing). (百度网盘链接:https://pan.baidu.com/s/1cJ7-dPuwR8THPUGWb207ZQ 密码:aa0w,把压缩包解开后放在 train_log/\*) +* Download the pretrained **HD** models from [here](https://drive.google.com/file/d/1APIzVeI-4ZZCEuIRE1m6WYfSCaOsi_7_/view?usp=sharing). (百度网盘链接:https://pan.baidu.com/share/init?surl=u6Q7-i4Hu4Vx9_5BJibPPA 密码:hfk3,把压缩包解开后放在 train_log/\*) * Unzip and move the pretrained parameters to train_log/\* diff --git a/inference_video.py b/inference_video.py index fd98e27..1befced 100644 --- a/inference_video.py +++ b/inference_video.py @@ -213,7 +213,7 @@ while True: I1 = pad_image(I1) I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False) I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) - ssim = ssim_matlab(I0_small, I1_small) + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) if ssim > 0.995: if skip_frame % 100 == 0: diff --git a/model/RIFE.py b/model/RIFE.py index 9ed110d..cdc8876 100644 --- a/model/RIFE.py +++ b/model/RIFE.py @@ -57,7 +57,15 @@ class Model: return pred, merged else: return pred +<<<<<<< HEAD ''' +======= + + def inference(self, img0, img1, scale=None): + imgs = torch.cat((img0, img1), 1) + flow, _ = self.flownet(torch.cat((img0, img1), 1)) + return self.predict(imgs, flow, training=False) +>>>>>>> origin/main def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): for param_group in self.optimG.param_groups: