Merge remote-tracking branch 'origin/main' into main

# Conflicts:
#	model/RIFE.py
This commit is contained in:
hzwer
2021-08-13 16:34:55 +08:00
3 changed files with 10 additions and 2 deletions

View File

@@ -35,7 +35,7 @@ cd arXiv2020-RIFE
pip3 install -r requirements.txt pip3 install -r requirements.txt
``` ```
* Download the pretrained **HD** models from [here](https://drive.google.com/file/d/10-2AaFUyX-c7yCfubsxF2NTvM7DgvS8l/view?usp=sharing). (百度网盘链接:https://pan.baidu.com/s/1cJ7-dPuwR8THPUGWb207ZQ 密码:aa0w,把压缩包解开后放在 train_log/\*) * Download the pretrained **HD** models from [here](https://drive.google.com/file/d/1APIzVeI-4ZZCEuIRE1m6WYfSCaOsi_7_/view?usp=sharing). (百度网盘链接:https://pan.baidu.com/share/init?surl=u6Q7-i4Hu4Vx9_5BJibPPA 密码:hfk3,把压缩包解开后放在 train_log/\*)
* Unzip and move the pretrained parameters to train_log/\* * Unzip and move the pretrained parameters to train_log/\*

View File

@@ -213,7 +213,7 @@ while True:
I1 = pad_image(I1) I1 = pad_image(I1)
I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False) I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False) I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
ssim = ssim_matlab(I0_small, I1_small) ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
if ssim > 0.995: if ssim > 0.995:
if skip_frame % 100 == 0: if skip_frame % 100 == 0:

View File

@@ -57,7 +57,15 @@ class Model:
return pred, merged return pred, merged
else: else:
return pred return pred
<<<<<<< HEAD
''' '''
=======
def inference(self, img0, img1, scale=None):
imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(torch.cat((img0, img1), 1))
return self.predict(imgs, flow, training=False)
>>>>>>> origin/main
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups: for param_group in self.optimG.param_groups: