Update train.py

This commit is contained in:
hzwer
2022-03-15 15:37:06 +08:00
committed by GitHub
parent 7788494700
commit f114857833

View File

@@ -21,10 +21,10 @@ log_path = 'train_log'
def get_learning_rate(step):
if step < 2000:
mul = step / 2000.
return 1e-4 * mul
return 3e-4 * mul
else:
mul = np.cos((step - 2000) / (args.epoch * args.step_per_epoch - 2000.) * math.pi) * 0.5 + 0.5
return (1e-4 - 1e-5) * mul + 1e-5
return (3e-4 - 3e-5) * mul + 3e-5
def flow2rgb(flow_map_np):
h, w, _ = flow_map_np.shape