mirror of
https://github.com/voice-cloning-app/Voice-Cloning-App.git
synced 2026-02-24 12:10:29 +01:00
291 lines
7.7 KiB
Python
291 lines
7.7 KiB
Python
import torch
|
|
import random
|
|
import os
|
|
import unicodedata
|
|
from PIL import Image
|
|
|
|
from dataset import CHARACTER_ENCODING
|
|
from dataset.utils import get_invalid_characters
|
|
from training import BASE_SYMBOLS, SEED, TRAIN_FILE, VALIDATION_FILE
|
|
from training.tacotron2_model.utils import get_mask_from_lengths
|
|
from training.clean_text import clean_text
|
|
|
|
CHECKPOINT_SIZE_MB = 333
|
|
BATCH_SIZE_PER_GB = 2.5
|
|
LEARNING_RATE_PER_64 = 4e-4
|
|
MAXIMUM_LEARNING_RATE = 4e-4
|
|
EARLY_STOPPING_WINDOW = 10
|
|
EARLY_STOPPING_MIN_DIFFERENCE = 0.0005
|
|
|
|
|
|
def get_gpu_memory(gpu_index):
    """
    Get available memory of a GPU.

    Parameters
    ----------
    gpu_index : int
        Index of GPU

    Returns
    -------
    int
        Available GPU memory in GB
    """
    total_bytes = torch.cuda.get_device_properties(gpu_index).total_memory
    allocated_bytes = torch.cuda.memory_allocated(gpu_index)
    # Floor-divide bytes down to whole gigabytes
    return (total_bytes - allocated_bytes) // 1024 // 1024 // 1024
|
|
|
|
|
|
def get_available_memory():
    """
    Get available GPU memory in GB.

    Returns
    -------
    int
        Available GPU memory in GB
    """
    # Sum free memory across every visible CUDA device (0 if none)
    return sum(get_gpu_memory(gpu) for gpu in range(torch.cuda.device_count()))
|
|
|
|
|
|
def get_batch_size(available_memory_gb):
    """
    Calulate batch size.

    Parameters
    ----------
    available_memory_gb : int
        Available GPU memory in GB

    Returns
    -------
    int
        Batch size
    """
    # Scale linearly with free memory, then truncate to a whole batch count
    scaled = available_memory_gb * BATCH_SIZE_PER_GB
    return int(scaled)
|
|
|
|
|
|
def get_learning_rate(batch_size):
    """
    Calulate learning rate.

    Parameters
    ----------
    batch_size : int
        Batch size

    Returns
    -------
    float
        Learning rate
    """
    # Adam's learning rate scales with sqrt(batch_size); cap at the maximum
    scaled = LEARNING_RATE_PER_64 * (batch_size / 64) ** 0.5
    return scaled if scaled < MAXIMUM_LEARNING_RATE else MAXIMUM_LEARNING_RATE
|
|
|
|
|
|
def load_labels_file(filepath):
    """
    Load labels file

    Parameters
    ----------
    filepath : str
        Path to text file

    Returns
    -------
    list
        List of samples
    """
    samples = []
    with open(filepath, encoding=CHARACTER_ENCODING) as f:
        # Each line is pipe-delimited, e.g. "clip.wav|transcript"
        for line in f:
            samples.append(line.strip().split("|"))
    return samples
|
|
|
|
|
|
def validate_dataset(filepaths_and_text, dataset_directory, symbols):
    """
    Validates dataset has required files and a valid character set

    Parameters
    ----------
    filepaths_and_text : list
        List of samples
    dataset_directory : str
        Path to dataset audio directory
    symbols : list
        List of supported symbols

    Raises
    -------
    AssertionError
        If files are missing or invalid characters are found
    """
    existing_wavs = os.listdir(dataset_directory)
    missing_files = set()
    invalid_characters = set()

    for filename, text in filepaths_and_text:
        # Normalise the transcript first, but keep invalid characters so
        # they can be reported rather than silently dropped
        cleaned = clean_text(text, remove_invalid_characters=False)
        if filename not in existing_wavs:
            missing_files.add(filename)
        bad_chars = get_invalid_characters(cleaned, symbols)
        if bad_chars:
            invalid_characters.update(bad_chars)

    assert not missing_files, f"Missing files: {(',').join(missing_files)}"
    assert (
        not invalid_characters
    ), f"Invalid characters in text (for alphabet): {','.join([f'{c} ({unicodedata.name(c)})' for c in invalid_characters])}"
|
|
|
|
|
|
def train_test_split(filepaths_and_text, train_size):
    """
    Split dataset into train & test data

    Parameters
    ----------
    filepaths_and_text : list
        List of samples
    train_size : float
        Percentage of entries to use for training (rest used for testing)

    Returns
    -------
    (list, list)
        List of train and test samples
    """
    # Truncating split point: the first train_size fraction goes to training
    cutoff = int(train_size * len(filepaths_and_text))
    train_files, test_files = filepaths_and_text[:cutoff], filepaths_and_text[cutoff:]
    print(f"{len(train_files)} train files, {len(test_files)} test files")
    return train_files, test_files
|
|
|
|
|
|
def load_symbols(alphabet_file):
    """
    Get alphabet and punctuation for a given alphabet file.

    Parameters
    ----------
    alphabet_file : str
        Path to alphabnet file

    Returns
    -------
    list
        List of symbols (punctuation + alphabet)
    """
    symbols = BASE_SYMBOLS.copy()

    with open(alphabet_file, encoding=CHARACTER_ENCODING) as f:
        raw_lines = f.readlines()

    # Skip blank lines and "#"-prefixed comment lines (comment check is on
    # the raw, unstripped line, matching the original behaviour)
    entries = [line.strip() for line in raw_lines if line.strip() and not line.startswith("#")]

    for entry in entries:
        if entry not in symbols:
            symbols.append(entry)

    return symbols
|
|
|
|
|
|
def check_early_stopping(validation_losses):
    """
    Decide whether to stop training depending on validation losses.

    Parameters
    ----------
    validation_losses : list
        List of validation loss scores

    Returns
    -------
    bool
        True if training should stop, otherwise False
    """
    # Not enough history yet to judge a plateau
    if len(validation_losses) < EARLY_STOPPING_WINDOW:
        return False
    # Stop when loss has barely moved across the most recent window
    window = validation_losses[-EARLY_STOPPING_WINDOW:]
    spread = max(window) - min(window)
    return spread < EARLY_STOPPING_MIN_DIFFERENCE
|
|
|
|
|
|
def calc_avgmax_attention(mel_lengths, text_lengths, alignment):
    """
    Calculate Average Max Attention for Tacotron2 Alignment.
    Roughly represents how well the model is linking the text to the audio.
    Low values during training typically result in unstable speech during inference.

    Parameters
    ----------
    mel_lengths : torch.Tensor
        lengths of each mel in the batch
    text_lengths : torch.Tensor
        lengths of each text in the batch
    alignment : torch.Tensor
        alignments from model of shape [B, mel_length, text_length]

    Returns
    -------
    float
        average max attention
    """
    # Build per-sequence validity masks on the same device as the alignment.
    # NOTE(review): assumes get_mask_from_lengths returns True at valid
    # positions (so ~mask below selects padding) — confirm against
    # training.tacotron2_model.utils.
    mel_mask = get_mask_from_lengths(mel_lengths, device=alignment.device)
    txt_mask = get_mask_from_lengths(text_lengths, device=alignment.device)
    # [B, mel_T, 1] * [B, 1, txt_T] -> [B, mel_T, txt_T]
    attention_mask = txt_mask.unsqueeze(1) & mel_mask.unsqueeze(2)

    # Zero out attention weights that fall in padded mel/text positions so
    # padding cannot contribute to the per-frame max below. Uses .data, so
    # this is detached from autograd (metric only, no gradient flow).
    alignment = alignment.data.masked_fill(~attention_mask, 0.0)
    # [B, mel_T, txt_T]
    # Per mel frame, take the strongest text attention (amax over text axis),
    # sum over frames, normalise by each sample's true mel length
    # (mel_lengths cast to alignment's dtype/device via .to), then average
    # over the batch and return as a Python float.
    avg_prob = alignment.data.amax(dim=2).sum(1).div(mel_lengths.to(alignment)).mean().item()
    return avg_prob
|
|
|
|
|
|
def generate_timelapse_gif(folder, output_path):
    """
    Generates a GIF timelapse from a folder of images.

    Parameters
    ----------
    folder : str
        Path to folder of images
    output_path : str
        Path to save resulting GIF to

    Raises
    ------
    ValueError
        If the folder contains no images
    """
    # Filenames are expected to look like "<prefix>_<step>.<ext>"; sort by the
    # numeric step so frames play in training order, not lexicographic order.
    images = sorted(os.listdir(folder), key=lambda filename: int(filename.split("_")[1].split(".")[0]))
    # Fix: an empty folder previously crashed with a bare IndexError on
    # frames[0]; fail early with a descriptive error instead.
    if not images:
        raise ValueError(f"No images found in folder: {folder}")
    frames = [Image.open(os.path.join(folder, image)) for image in images]
    # duration is per-frame display time in ms; loop=0 means loop forever
    frames[0].save(output_path, format="GIF", append_images=frames[1:], save_all=True, duration=200, loop=0)
|
|
|
|
|
|
def create_trainlist_vallist_files(folder, metadata_path, train_size=0.8):
    """
    Creates trainlist & vallist files for compatibility with other notebooks.

    Parameters
    ----------
    folder : str
        Destination folder
    metadata_path : str
        Path to metadata file
    train_size : float (optional)
        Percentage of samples to use for training (default is 80%/0.8)
    """
    # Seed for a reproducible shuffle before splitting
    random.seed(SEED)
    samples = load_labels_file(metadata_path)
    random.shuffle(samples)
    train_files, test_files = train_test_split(samples, train_size)

    # Write each split as pipe-delimited "filename|text" lines
    for list_name, rows in ((TRAIN_FILE, train_files), (VALIDATION_FILE, test_files)):
        with open(os.path.join(folder, list_name), "w", encoding=CHARACTER_ENCODING) as f:
            f.writelines(f"{row[0]}|{row[1]}\n" for row in rows)
|