Fix warmup

This commit is contained in:
hzwer
2021-09-02 16:42:52 +08:00
parent ecf9c879dc
commit ea632dd3fb
2 changed files with 2 additions and 17 deletions

View File

@@ -50,22 +50,6 @@ class Model:
if rank == 0:
torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))
'''
def predict(self, imgs, flow, merged, training=True, flow_gt=None):
img0 = imgs[:, :3]
img1 = imgs[:, 3:]
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
refine_output = self.unet(img0, img1, flow, merged, c0, c1, flow_gt)
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
pred = merged + res
pred = torch.clamp(pred, 0, 1)
if training:
return pred, merged
else:
return pred
'''
def inference(self, img0, img1, scale_list=[4, 2, 1], TTA=False):
imgs = torch.cat((img0, img1), 1)
flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list)

View File

@@ -21,6 +21,7 @@ exp = os.path.abspath('.').split('/')[-1]
def get_learning_rate(step):
if step < 2000:
mul = step / 2000.
return 3e-4 * mul
else:
mul = np.cos((step - 2000) / (args.epoch * args.step_per_epoch - 2000.) * math.pi) * 0.5 + 0.5
return (3e-4 - 3e-5) * mul + 3e-5