from multiprocessing.pool import Pool

import matplotlib

from utils.pl_utils import data_loader
from utils.training_utils import RSQRTSchedule
from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
from modules.fastspeech.pe import PitchExtractor

matplotlib.use('Agg')
import os
import numpy as np
from tqdm import tqdm
import torch.distributed as dist

from tasks.base_task import BaseTask
from utils.hparams import hparams
from utils.text_encoder import TokenTextEncoder
import json

import torch
import torch.optim
import torch.utils.data
import utils


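# TtsTask is the shared base for the text-to-speech tasks in this repo: it owns the
# phone encoder, dataloader construction, and the vocoder / pitch-extractor setup used
# at test time. It builds no acoustic model itself, so concrete subclasses presumably
# supply the model-specific pieces.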
class TtsTask(BaseTask):
    def __init__(self, *args, **kwargs):
        self.vocoder = None
        self.phone_encoder = self.build_phone_encoder(hparams['binary_data_dir'])
        self.padding_idx = self.phone_encoder.pad()
        self.eos_idx = self.phone_encoder.eos()
        self.seg_idx = self.phone_encoder.seg()
        self.saving_result_pool = None
        self.saving_results_futures = None
        self.stats = {}
        super().__init__(*args, **kwargs)

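    # Optimisation setup: AdamW with the learning rate taken from hparams, paired with
    # RSQRTSchedule from utils.training_utils (presumably the Transformer-style warmup
    # followed by inverse-square-root decay that its name suggests).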
    def build_scheduler(self, optimizer):
        return RSQRTSchedule(optimizer)

    def build_optimizer(self, model):
        self.optimizer = optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hparams['lr'])
        return optimizer

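    # Builds a DataLoader over pre-computed batches of dataset indices:
    #   * the max_tokens / max_sentences budgets are scaled by the number of visible GPUs;
    #   * utils.batch_by_size groups dataset.ordered_indices() under those budgets
    #     (fairseq-style dynamic batching, judging by the signature);
    #   * endless=True repeats the (re)shuffled batch list 1000 times;
    #   * under DDP each rank keeps a strided slice of every batch, and batches whose
    #     length is not divisible by the world size are dropped.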
    def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
                         required_batch_size_multiple=-1, endless=False, batch_by_size=True):
        devices_cnt = torch.cuda.device_count()
        if devices_cnt == 0:
            devices_cnt = 1
        if required_batch_size_multiple == -1:
            required_batch_size_multiple = devices_cnt

        def shuffle_batches(batches):
            np.random.shuffle(batches)
            return batches

        if max_tokens is not None:
            max_tokens *= devices_cnt
        if max_sentences is not None:
            max_sentences *= devices_cnt
        indices = dataset.ordered_indices()
        if batch_by_size:
            batch_sampler = utils.batch_by_size(
                indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
        else:
            batch_sampler = []
            for i in range(0, len(indices), max_sentences):
                batch_sampler.append(indices[i:i + max_sentences])

        if shuffle:
            batches = shuffle_batches(list(batch_sampler))
            if endless:
                batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
        else:
            batches = batch_sampler
            if endless:
                batches = [b for _ in range(1000) for b in batches]
        num_workers = dataset.num_workers
        if self.trainer.use_ddp:
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collater,
                                           batch_sampler=batches,
                                           num_workers=num_workers,
                                           pin_memory=False)

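    # The phone vocabulary is read from phone_set.json inside the binarized data
    # directory (written by the dataset binarizer, as far as one can tell) and wrapped
    # in a TokenTextEncoder; out-of-vocabulary phones are replaced by ','.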
    def build_phone_encoder(self, data_dir):
        phone_list_file = os.path.join(data_dir, 'phone_set.json')
        with open(phone_list_file) as f:
            phone_list = json.load(f)
        return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')

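    # Inference-time setup: a worker pool for saving results asynchronously, the
    # vocoder selected via hparams, and, if pe_enable is set, a pre-trained
    # PitchExtractor loaded from pe_ckpt and switched to eval mode.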
    def test_start(self):
        self.saving_result_pool = Pool(8)
        self.saving_results_futures = []
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
        if hparams.get('pe_enable') is not None and hparams['pe_enable']:
            self.pe = PitchExtractor().cuda()
            utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()

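    # Waits for all asynchronous result-saving jobs to finish before the test run ends.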
    def test_end(self, outputs):
        self.saving_result_pool.close()
        [f.get() for f in tqdm(self.saving_results_futures)]
        self.saving_result_pool.join()
        return {}

    ##########
    # utils
    ##########
    def weights_nonzero_speech(self, target):
        # target : B x T x mel
        # Assign weight 1.0 to all labels except for padding (id=0).
        dim = target.size(-1)
        return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)


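# BaseTask.start() drives training/testing from hparams; in practice the concrete TTS
# subclasses, rather than this base class, are presumably the intended entry points.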
if __name__ == '__main__':
    TtsTask.start()