mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 08:27:45 +01:00
Fix warmup
This commit is contained in:
@@ -50,22 +50,6 @@ class Model:
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))
|
torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))
|
||||||
|
|
||||||
'''
|
|
||||||
def predict(self, imgs, flow, merged, training=True, flow_gt=None):
|
|
||||||
img0 = imgs[:, :3]
|
|
||||||
img1 = imgs[:, 3:]
|
|
||||||
c0 = self.contextnet(img0, flow[:, :2])
|
|
||||||
c1 = self.contextnet(img1, flow[:, 2:4])
|
|
||||||
refine_output = self.unet(img0, img1, flow, merged, c0, c1, flow_gt)
|
|
||||||
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
|
|
||||||
pred = merged + res
|
|
||||||
pred = torch.clamp(pred, 0, 1)
|
|
||||||
if training:
|
|
||||||
return pred, merged
|
|
||||||
else:
|
|
||||||
return pred
|
|
||||||
'''
|
|
||||||
|
|
||||||
def inference(self, img0, img1, scale_list=[4, 2, 1], TTA=False):
|
def inference(self, img0, img1, scale_list=[4, 2, 1], TTA=False):
|
||||||
imgs = torch.cat((img0, img1), 1)
|
imgs = torch.cat((img0, img1), 1)
|
||||||
flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list)
|
flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list)
|
||||||
|
|||||||
3
train.py
3
train.py
@@ -21,9 +21,10 @@ exp = os.path.abspath('.').split('/')[-1]
|
|||||||
def get_learning_rate(step):
|
def get_learning_rate(step):
|
||||||
if step < 2000:
|
if step < 2000:
|
||||||
mul = step / 2000.
|
mul = step / 2000.
|
||||||
|
return 3e-4 * mul
|
||||||
else:
|
else:
|
||||||
mul = np.cos((step - 2000) / (args.epoch * args.step_per_epoch - 2000.) * math.pi) * 0.5 + 0.5
|
mul = np.cos((step - 2000) / (args.epoch * args.step_per_epoch - 2000.) * math.pi) * 0.5 + 0.5
|
||||||
return (3e-4 - 3e-5) * mul + 3e-5
|
return (3e-4 - 3e-5) * mul + 3e-5
|
||||||
|
|
||||||
def flow2rgb(flow_map_np):
|
def flow2rgb(flow_map_np):
|
||||||
h, w, _ = flow_map_np.shape
|
h, w, _ = flow_map_np.shape
|
||||||
|
|||||||
Reference in New Issue
Block a user