mirror of
https://github.com/AIGC-Audio/AudioGPT.git
synced 2025-12-17 20:37:55 +01:00
28 lines
976 B
Python
28 lines
976 B
Python
from utils.hparams import hparams
|
|
|
|
|
|
class RSQRTSchedule(object):
    """Linear-warmup then inverse-square-root learning-rate schedule.

    For ``num_updates`` n, warmup length w and hidden size d the learning rate is::

        lr(n) = max(constant_lr * min(n / w, 1) * max(w, n) ** -0.5 * d ** -0.5, 1e-7)

    i.e. it ramps up linearly for the first ``w`` updates and then decays as
    ``n ** -0.5``. The result is written into every entry of
    ``optimizer.param_groups``.
    """

    def __init__(self, optimizer, lr=None, warmup_updates=None, hidden_size=None):
        """Attach the schedule to ``optimizer`` and apply the lr for update 0.

        Args:
            optimizer: object exposing ``param_groups``, a sequence of dicts with
                an ``'lr'`` key (presumably a ``torch.optim.Optimizer`` — confirm
                against callers).
            lr: base learning rate; defaults to ``hparams['lr']``.
            warmup_updates: number of linear-warmup updates; defaults to
                ``hparams['warmup_updates']``. A value ``<= 0`` disables warmup.
            hidden_size: model width used for the ``d ** -0.5`` scaling; defaults
                to ``hparams['hidden_size']``.
        """
        super().__init__()
        self.optimizer = optimizer
        # Explicit arguments win; the global hparams are only consulted when an
        # argument is omitted, so the schedule can be built without global config.
        self.constant_lr = hparams['lr'] if lr is None else lr
        self.warmup_updates = (
            hparams['warmup_updates'] if warmup_updates is None else warmup_updates
        )
        self.hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
        # step(0) sets self.lr and every param group's lr.
        self.step(0)

    def step(self, num_updates):
        """Compute the lr for ``num_updates``, apply it to all param groups, and return it."""
        if self.warmup_updates > 0:
            warmup = min(num_updates / self.warmup_updates, 1.0)
            rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5
        else:
            # No warmup configured: the original formula would divide by zero
            # (both in ``n / 0`` and in ``0 ** -0.5``). Skip the ramp and decay
            # from the first update instead, clamping the base to 1.
            warmup = 1.0
            rsqrt_decay = max(num_updates, 1) ** -0.5
        rsqrt_hidden = self.hidden_size ** -0.5
        # Floor the lr so it never reaches exactly zero (e.g. at num_updates == 0).
        self.lr = max(self.constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr

    def get_lr(self):
        """Return the lr currently stored in the optimizer's first param group."""
        return self.optimizer.param_groups[0]['lr']
|