Mirror of https://github.com/AIGC-Audio/AudioGPT.git
Synced 2025-12-23 15:19:38 +01:00
130 lines (5.8 KiB) · Python · Executable file
from tasks.tts.fs2 import FastSpeech2Task
|
|
from modules.syntaspeech.multi_window_disc import Discriminator
|
|
from utils.hparams import hparams
|
|
from torch import nn
|
|
import torch
|
|
import torch.optim
|
|
import torch.utils.data
|
|
import utils
|
|
|
|
|
|
class FastSpeech2AdvTask(FastSpeech2Task):
    """FastSpeech2 training task with an adversarial mel discriminator.

    Optimizer 0 trains the generator (``self.model``) on the losses returned by
    ``run_model``. Once adversarial training has started, it also gets
    least-squares (MSE) adversarial terms that push the discriminator's score
    on generated mels towards 1. Optimizer 1 trains ``self.mel_disc`` to score
    ground-truth mels towards 1 and generated mels towards 0.
    """

    def build_model(self):
        """Build the generator, optionally warm-start it, then build the discriminator.

        Returns:
            The generator model (``self.model``).
        """
        self.build_tts_model()
        if hparams['load_ckpt'] != '':
            # strict=False tolerates missing/unexpected keys in the checkpoint.
            self.load_ckpt(hparams['load_ckpt'], strict=False)
        utils.print_arch(self.model, 'Generator')
        self.build_disc_model()
        # Keep an already-chosen parameter list (if something set gen_params
        # earlier); otherwise optimize every generator parameter.
        if not hasattr(self, 'gen_params'):
            self.gen_params = list(self.model.parameters())
        return self.model

    def build_disc_model(self):
        """Create the mel discriminator and collect its parameters.

        ``disc_win_num`` selects how many of the window lengths 32/64/128
        (presumably in frames) are passed to the discriminator. The frequency
        size follows ``audio_num_mel_bins`` when the config defines it and
        falls back to 80 otherwise.
        """
        disc_win_num = hparams['disc_win_num']
        h = hparams['mel_disc_hidden_size']
        # Previously hard-coded to 80. Reading it from the config keeps the
        # discriminator input size consistent with the configured mel bins.
        num_mel_bins = hparams.get('audio_num_mel_bins', 80)
        self.mel_disc = Discriminator(
            time_lengths=[32, 64, 128][:disc_win_num],
            freq_length=num_mel_bins, hidden_size=h, kernel=(3, 3)
        )
        self.disc_params = list(self.mel_disc.parameters())
        utils.print_arch(self.mel_disc, model_name='Mel Disc')

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """Compute the losses for one optimizer step.

        Args:
            sample: batch dict; ``sample['mels']`` holds the ground-truth mels.
            batch_idx: index of the batch (not used here).
            optimizer_idx: 0 for the generator step, anything else for the
                discriminator step.

        Returns:
            ``None`` when no loss was produced for this step. For example, the
            discriminator step returns ``None`` before adversarial training
            starts or when ``global_step`` is not a multiple of
            ``disc_interval``.

            Otherwise returns ``(total_loss, log_outputs)``. ``total_loss`` is
            the weighted sum of the loss entries: ``lambda_mel_adv`` for the
            generator's adversarial terms 'a'/'ac', 1 for everything else.
            ``log_outputs['bs']`` (the batch size) is added afterwards and is
            not part of the sum.
        """
        log_outputs = {}
        loss_weights = {}
        # Adversarial training is active only when mel_gan is enabled, the
        # warm-up of disc_start_steps is over, and the adversarial weight is > 0.
        disc_start = hparams['mel_gan'] and self.global_step >= hparams["disc_start_steps"] and \
            hparams['lambda_mel_adv'] > 0
        if optimizer_idx == 0:
            #######################
            #      Generator      #
            #######################
            log_outputs, model_out = self.run_model(self.model, sample, return_output=True)
            # Cache detached tensor outputs so the discriminator step can reuse
            # them without back-propagating into the generator.
            self.model_out = {k: v.detach() for k, v in model_out.items() if isinstance(v, torch.Tensor)}
            if disc_start:
                # Optional conditioning input for the discriminator; also
                # stored for the discriminator step.
                self.disc_cond = disc_cond = self.model_out['decoder_inp'].detach() \
                    if hparams['use_cond_disc'] else None
                if hparams['mel_loss_no_noise']:
                    # NOTE(review): presumably add_mel_loss writes its loss
                    # entries into log_outputs — confirm in FastSpeech2Task.
                    self.add_mel_loss(model_out['mel_out_nonoise'], sample['mels'], log_outputs)
                mel_p = model_out['mel_out']
                if hasattr(self.model, 'out2mel'):
                    mel_p = self.model.out2mel(mel_p)
                o_ = self.mel_disc(mel_p, disc_cond)
                p_, pc_ = o_['y'], o_['y_c']
                # Generator wants its outputs scored as real (target 1).
                if p_ is not None:
                    log_outputs['a'] = self.mse_loss_fn(p_, p_.new_ones(p_.size()))
                    loss_weights['a'] = hparams['lambda_mel_adv']
                if pc_ is not None:
                    log_outputs['ac'] = self.mse_loss_fn(pc_, pc_.new_ones(pc_.size()))
                    loss_weights['ac'] = hparams['lambda_mel_adv']
        else:
            #######################
            #    Discriminator    #
            #######################
            if disc_start and self.global_step % hparams['disc_interval'] == 0:
                if hparams['rerun_gen']:
                    # Recompute the generator output without gradient tracking.
                    with torch.no_grad():
                        _, model_out = self.run_model(self.model, sample, return_output=True)
                else:
                    # Reuse the detached output cached by the generator step.
                    model_out = self.model_out
                mel_g = sample['mels']
                mel_p = model_out['mel_out']
                if hasattr(self.model, 'out2mel'):
                    mel_p = self.model.out2mel(mel_p)

                o = self.mel_disc(mel_g, self.disc_cond)
                p, pc = o['y'], o['y_c']
                o_ = self.mel_disc(mel_p, self.disc_cond)
                p_, pc_ = o_['y'], o_['y_c']

                # Real mels -> target 1, generated mels -> target 0.
                if p_ is not None:
                    log_outputs["r"] = self.mse_loss_fn(p, p.new_ones(p.size()))
                    log_outputs["f"] = self.mse_loss_fn(p_, p_.new_zeros(p_.size()))

                if pc_ is not None:
                    log_outputs["rc"] = self.mse_loss_fn(pc, pc.new_ones(pc.size()))
                    log_outputs["fc"] = self.mse_loss_fn(pc_, pc_.new_zeros(pc_.size()))

        if len(log_outputs) == 0:
            return None
        total_loss = sum(loss_weights.get(k, 1) * v for k, v in log_outputs.items())

        # Added after summing so the batch size is logged but not optimized.
        log_outputs['bs'] = sample['mels'].shape[0]
        return total_loss, log_outputs

    def configure_optimizers(self):
        """Create AdamW optimizers for the generator and the discriminator.

        Also builds the matching schedulers and stores them in
        ``self.scheduler``.

        Returns:
            ``[optimizer_gen, optimizer_disc]``. ``optimizer_disc`` is ``None``
            when the discriminator has no parameters.
        """
        if not hasattr(self, 'gen_params'):
            self.gen_params = list(self.model.parameters())
        adam_betas = (hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])
        optimizer_gen = torch.optim.AdamW(
            self.gen_params,
            lr=hparams['lr'],
            betas=adam_betas,
            weight_decay=hparams['weight_decay'])
        if len(self.disc_params) > 0:
            optimizer_disc = torch.optim.AdamW(
                self.disc_params,
                lr=hparams['disc_lr'],
                betas=adam_betas,
                **hparams["discriminator_optimizer_params"])
        else:
            optimizer_disc = None
        self.scheduler = self.build_scheduler({'gen': optimizer_gen, 'disc': optimizer_disc})
        return [optimizer_gen, optimizer_disc]

    def build_scheduler(self, optimizer):
        """Build one LR scheduler per optimizer.

        Args:
            optimizer: dict with keys 'gen' and 'disc'; 'disc' may be ``None``.

        Returns:
            Dict with keys 'gen' and 'disc'. The generator uses the parent
            task's scheduler. The discriminator uses ``StepLR`` configured by
            ``discriminator_scheduler_params``, or ``None`` when there is no
            discriminator optimizer.
        """
        return {
            "gen": super().build_scheduler(optimizer['gen']),
            "disc": torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer["disc"],
                **hparams["discriminator_scheduler_params"]) if optimizer["disc"] is not None else None,
        }

    def on_before_optimization(self, opt_idx):
        """Clip the gradient norm of the parameters owned by optimizer ``opt_idx``."""
        if opt_idx == 0:
            nn.utils.clip_grad_norm_(self.gen_params, hparams['generator_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.disc_params, hparams["discriminator_grad_norm"])

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Advance the scheduler belonging to the optimizer that just stepped.

        The discriminator scheduler is driven by the number of steps since
        ``disc_start_steps``, clamped to at least 1. A ``None`` scheduler is
        skipped instead of raising; ``build_scheduler`` produces one for
        'disc' when there are no discriminator parameters.
        """
        if optimizer_idx == 0:
            scheduler = self.scheduler['gen']
            step = self.global_step
        else:
            scheduler = self.scheduler['disc']
            step = max(self.global_step - hparams["disc_start_steps"], 1)
        if scheduler is not None:
            scheduler.step(step)