mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 00:17:46 +01:00
Fix warmup
This commit is contained in:
@@ -50,22 +50,6 @@ class Model:
|
||||
if rank == 0:
|
||||
torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))
|
||||
|
||||
'''
|
||||
def predict(self, imgs, flow, merged, training=True, flow_gt=None):
|
||||
img0 = imgs[:, :3]
|
||||
img1 = imgs[:, 3:]
|
||||
c0 = self.contextnet(img0, flow[:, :2])
|
||||
c1 = self.contextnet(img1, flow[:, 2:4])
|
||||
refine_output = self.unet(img0, img1, flow, merged, c0, c1, flow_gt)
|
||||
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
|
||||
pred = merged + res
|
||||
pred = torch.clamp(pred, 0, 1)
|
||||
if training:
|
||||
return pred, merged
|
||||
else:
|
||||
return pred
|
||||
'''
|
||||
|
||||
def inference(self, img0, img1, scale_list=[4, 2, 1], TTA=False):
|
||||
imgs = torch.cat((img0, img1), 1)
|
||||
flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list)
|
||||
|
||||
3
train.py
3
train.py
@@ -21,9 +21,10 @@ exp = os.path.abspath('.').split('/')[-1]
|
||||
def get_learning_rate(step):
|
||||
if step < 2000:
|
||||
mul = step / 2000.
|
||||
return 3e-4 * mul
|
||||
else:
|
||||
mul = np.cos((step - 2000) / (args.epoch * args.step_per_epoch - 2000.) * math.pi) * 0.5 + 0.5
|
||||
return (3e-4 - 3e-5) * mul + 3e-5
|
||||
return (3e-4 - 3e-5) * mul + 3e-5
|
||||
|
||||
def flow2rgb(flow_map_np):
|
||||
h, w, _ = flow_map_np.shape
|
||||
|
||||
Reference in New Issue
Block a user