# TTS/spkr-attr/cv_data_processing.py
# Prepare Common Voice (CV) data for speaker-attribute embedding classification.

import argparse
import json
import os
import pickle
import random
import subprocess
from argparse import RawTextHelpFormatter

import numpy as np
import pandas as pd
from pydub import AudioSegment
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.utils import shuffle
from tqdm import tqdm

def load_df(filename, n):
    """Load a CV TSV metadata file, optionally keeping a random subset.

    :param filename: path to the tab-separated metadata file.
    :param n: "All" to keep every record, otherwise the number of rows
        to keep after shuffling (string or int).
    :return: the loaded (and possibly subsampled) DataFrame.
    """
    frame = pd.read_csv(filename, sep="\t")
    if n != "All":
        frame = shuffle(frame).head(n=int(n))
    return frame

def analyze_df(df, label):
    """Collect the unique speaker ids for every usable value of *label*.

    Rows where the label is missing, or equal to the placeholder
    "other", are ignored when enumerating the label values.

    :param df: CV metadata DataFrame with "client_id" and *label* columns.
    :param label: column name of the speaker attribute (e.g. "age").
    :return: dict mapping each label value to an array of client_ids.
    """
    label_dict = {}
    # notna() is an alias of notnull() in pandas, so a single mask is
    # enough (the original applied both).
    usable = df[df[label].notna()]
    usable = usable[usable[label] != "other"][label]
    for value in usable.unique():
        speakers = df[df[label] == value]["client_id"].unique()
        label_dict[value] = speakers
        print(f'"{value}" unique speakers no.: {len(speakers)}')
    return label_dict

def train_test_split(df, label, label_dict, split=0.1):
    """Split *df* into train/test sets with disjoint speakers per label.

    For each label value, the first (1 - split) fraction of its speakers
    goes to train and the remainder to test, so no speaker appears in
    both sets.  Rows labelled "other" are dropped from both sets.

    :param df: full metadata DataFrame.
    :param label: attribute column used for the per-label split.
    :param label_dict: mapping label value -> array of client_ids
        (as produced by analyze_df).
    :param split: fraction of speakers reserved for the test set.
    :return: (train_df, test_df) tuple of DataFrames.
    """
    print(len(label_dict.keys()), label_dict.keys())
    # Drop "other" rows once up front instead of re-filtering the whole
    # accumulated frame on every loop iteration (same result, less work).
    usable = df[df[label] != "other"]
    train_parts = []
    test_parts = []
    for value in label_dict:
        spkrs = label_dict[value]
        cut = int(len(spkrs) * (1 - split))
        train_parts.append(usable[usable.client_id.isin(spkrs[:cut])])
        test_parts.append(usable[usable.client_id.isin(spkrs[cut:])])
    # Concatenate once at the end (the per-iteration pd.concat was
    # quadratic); guard the empty case, which pd.concat would reject.
    train = pd.concat(train_parts) if train_parts else pd.DataFrame()
    test = pd.concat(test_parts) if test_parts else pd.DataFrame()
    return train, test

def mp3_to_wav(mp3_list, data_path, data_split_path, json_file):
    """Convert every mp3 in *mp3_list* to wav and emit a TTS-style CSV.

    Each converted wav is written next to its source mp3 under
    *data_path*.  The CSV at *data_split_path* lists one wav per line
    with dummy gender/text fields, and a dataset config JSON is written
    last via write_config_dataset.
    """
    wav_paths = []
    for mp3_name in tqdm(mp3_list):
        audio = AudioSegment.from_mp3(f"{data_path}/{mp3_name}")
        wav_path = f'{data_path}/{mp3_name.split(".mp3")[0]}.wav'
        wav_paths.append(wav_path)
        audio.export(wav_path, format="wav")
    with open(f"{data_split_path}", "w") as csv_out:
        csv_out.write("wav_filename|gender|text|speaker_name\n")
        for idx, wav_path in enumerate(wav_paths):
            csv_out.write(f"{wav_path}|m|blabla|ID_{idx}\n")
    write_config_dataset(data_path, data_split_path, json_file)

def write_config_dataset(data_path, data_split_path, json_path):
    """Write a minimal VITS dataset-config JSON pointing at the split CSV.

    The meta-file path is made absolute relative to the current working
    directory so downstream tools can be launched from anywhere.
    """
    meta_file_train = os.path.join(os.getcwd(), data_split_path)
    dataset_entry = {
        "name": "brspeech",
        "path": data_path,
        "meta_file_train": meta_file_train,
        "language": "en",
        "meta_file_val": "null",
        "meta_file_attn_mask": "",
    }
    config = {"model": "vits", "datasets": [dataset_entry]}
    with open(json_path, "w") as outfile:
        json.dump(config, outfile)

def compute_speaker_emb(tts_root_dir, spkr_emb_model, spkr_emb_config, config_dataset, out_emb_json):
    """Run TTS's compute_embeddings.py over the dataset config.

    Shells out to the embedding-extraction script and echoes both the
    command line and the script's stdout.
    """
    script = f"{tts_root_dir}/TTS/bin/compute_embeddings.py"
    cmd = [
        "python",
        script,
        "--no_eval",
        "True",
        spkr_emb_model,
        spkr_emb_config,
        config_dataset,
        "--output_path",
        out_emb_json,
    ]
    print(" ".join(cmd))
    print(subprocess.check_output(cmd).decode("utf-8"))

def compose_dataset(embeddings_json, df, label, out_array_path):
    """Pair speaker embeddings with one-hot labels and save them as .npy.

    Also computes per-class balancing weights (inverse class frequency,
    scaled by len(values)/2) and pickles them next to the arrays.

    :param embeddings_json: JSON produced by compute_embeddings.py,
        keyed by wav path with an "embedding" entry per file.
    :param df: metadata DataFrame with "path" (mp3 filename) and *label*
        columns.
    :param label: attribute column to one-hot encode.
    :param out_array_path: prefix for the output files.
    """
    with open(embeddings_json) as f:
        embs = json.load(f)

    # Build a path -> label lookup once instead of filtering the whole
    # DataFrame per row (the original was O(n^2) in len(df)).
    # Assumes "path" values are unique, as in CV metadata — TODO confirm.
    label_by_path = dict(zip(df["path"], df[label]))

    emb_list = []
    raw_labels = []
    for mp3_path in tqdm(df.path):
        wav_key = mp3_path.split(".mp3")[0] + ".wav"
        emb_list.append(embs[wav_key]["embedding"])
        raw_labels.append(label_by_path[mp3_path])

    values = np.array(raw_labels)
    # np.unique(return_inverse=True) assigns integer codes in
    # sorted-class order, exactly like sklearn's LabelEncoder, and
    # indexing np.eye reproduces OneHotEncoder's dense float output.
    # This avoids OneHotEncoder(sparse=False), whose `sparse` argument
    # was removed in sklearn 1.4 (renamed to sparse_output in 1.2).
    classes, integer_encoded = np.unique(values, return_inverse=True)
    print(np.unique(values, return_counts=True), np.unique(integer_encoded))
    onehot = np.eye(len(classes))[integer_encoded]

    # Shuffle embeddings and labels together so pairs stay aligned.
    paired = list(zip(emb_list, onehot))
    random.shuffle(paired)
    data, labels = zip(*paired)

    data_name = f"{out_array_path}_data.npy"
    label_name = f"{out_array_path}_labels.npy"
    np.save(data_name, data)
    np.save(label_name, labels)

    # Inverse-frequency class weights; counts come back in the same
    # sorted order as the integer codes above.
    _, counts = np.unique(values, return_counts=True)
    weight = {}
    for code in np.unique(integer_encoded):
        weight[code] = (1 / counts[code]) * (len(values) / 2.0)
    print(weight)
    with open(f"{out_array_path}-weights.pkl", "wb") as f:
        pickle.dump(weight, f)
    print(f"Data: {np.array(data).shape} ,{data_name} \n Labels: {np.array(labels).shape} , {label_name}")

def main():
    """CLI entry point: prepare CV data for speaker-attribute classification.

    Pipeline per split (train/test): convert mp3 clips to wav and emit a
    TTS CSV + dataset config, compute speaker embeddings via TTS, then
    compose embedding/label arrays and class weights.
    """
    parser = argparse.ArgumentParser(
        description="A script to prepare CV data for speaker embedding classification.\n"
        "Example runs:\n"
        "python cv_data_processing.py --data /datasets/cv/8.0/en/train.tsv --attribute age --out_dir result --num_rec 100 --tts_root_dir /mount-storage/TTS/TTS --spkr_emb_model models/model_se.pth.tar --spkr_emb_config models/config_se.json",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("--data", help="Full path of CV data in tsv format", required=True)
    parser.add_argument(
        "--num_rec", help="Number of records to use out of --data. Supply All to use all of the records", required=True
    )
    parser.add_argument("--attribute", help="Speaker attribute to sample from", required=True)
    parser.add_argument("--out_dir", required=True)
    parser.add_argument("--spkr_emb_model", required=True)
    parser.add_argument("--spkr_emb_config", required=True)
    parser.add_argument("--tts_root_dir", required=True)
    args = parser.parse_args()

    # CV keeps the audio clips in a "clips" folder next to the tsv file.
    data_path = os.path.join(os.path.dirname(args.data), "clips")
    os.makedirs(args.out_dir, exist_ok=True)

    df = load_df(args.data, args.num_rec)
    print(f"Data header: {list(df)}")
    # Explicit validation instead of `assert`, which is stripped under -O.
    if args.attribute not in list(df):
        raise SystemExit(f"--attribute {args.attribute!r} is not a column of {args.data}")
    label_dict = analyze_df(df, args.attribute)
    train_df, test_df = train_test_split(df, args.attribute, label_dict)

    for split in ["train", "test"]:
        df_subset = train_df if split == "train" else test_df
        tts_csv = os.path.join(args.out_dir, f"{args.attribute}_{split}_tts.csv")
        config_dataset = os.path.join(args.out_dir, f"{args.attribute}_{split}_config_dataset.json")
        mp3_to_wav(df_subset["path"], data_path, tts_csv, config_dataset)
        out_emb_json = os.path.join(args.out_dir, f"{args.attribute}_{split}_spkr_embs.json")
        compute_speaker_emb(args.tts_root_dir, args.spkr_emb_model, args.spkr_emb_config, config_dataset, out_emb_json)
        out_array_path = os.path.join(args.out_dir, f"{args.attribute}_{split}")
        compose_dataset(out_emb_json, df_subset, args.attribute, out_array_path)
    print("Done.")


if __name__ == "__main__":
    main()