Files

211 lines
6.3 KiB
Python
Raw Permalink Normal View History

2019-07-19 08:46:23 +02:00
# visualisation tools for mimic2
2018-10-02 14:00:02 +02:00
import argparse
import csv
2021-04-12 11:47:39 +02:00
import os
2018-10-02 14:00:02 +02:00
import random
2021-04-12 11:47:39 +02:00
from statistics import StatisticsError, mean, median, mode, stdev
import matplotlib.pyplot as plt
import seaborn as sns
2018-10-02 14:00:02 +02:00
from text.cmudict import CMUDict
2021-04-12 11:47:39 +02:00
2018-10-02 14:00:02 +02:00
def get_audio_seconds(frames, frame_shift_ms=12.5):
    """Convert a frame count into a duration in seconds.

    Args:
        frames: Number of frames.
        frame_shift_ms: Milliseconds accounted for by each frame. Defaults to
            12.5, the value that used to be hard-coded in this function.

    Returns:
        ``frames * frame_shift_ms`` expressed in seconds.
    """
    return (frames * frame_shift_ms) / 1000
2018-10-02 14:00:02 +02:00
def append_data_statistics(meta_data):
    """Add audio-length summary statistics to every bucket of ``meta_data``.

    Each value of ``meta_data`` is a dict whose ``"data"`` entry is a list of
    samples carrying an ``"audio_len"`` field. The keys ``"mean"``,
    ``"median"``, ``"mode"`` and ``"std"`` are written into that same dict.

    Args:
        meta_data: Mapping of bucket key to ``{"data": [...]}``; modified in
            place.

    Returns:
        The same ``meta_data`` object that was passed in.
    """
    for bucket in meta_data.values():
        lengths = [sample["audio_len"] for sample in bucket["data"]]

        average = mean(lengths)

        # The mode is taken over lengths rounded to two decimals; when the
        # statistics module cannot produce one, fall back to the first
        # (unrounded) length.
        try:
            most_common = mode([round(length, 2) for length in lengths])
        except StatisticsError:
            most_common = lengths[0]

        middle = median(lengths)

        # stdev raises StatisticsError for fewer than two samples.
        try:
            spread = stdev(lengths)
        except StatisticsError:
            spread = 0

        bucket.update(mean=average, median=middle, mode=most_common, std=spread)
    return meta_data
def process_meta_data(path):
    """Load a pipe-delimited metadata file and group its rows by text length.

    Every non-blank row must have at least four fields: field 2 is parsed as
    an integer frame count and field 3 is the utterance text. Rows are
    bucketed by the character count of the utterance, and each bucket is then
    annotated by ``append_data_statistics``.

    Args:
        path: Path of the UTF-8 encoded metadata file.

    Returns:
        A dict mapping character count to a bucket dict. Each bucket has a
        ``"data"`` list of per-row dicts (``"utt"``, ``"frames"``,
        ``"audio_len"``, ``"row"``) plus whatever
        ``append_data_statistics`` adds.

    Raises:
        IndexError: If a non-blank row has fewer than four fields.
        ValueError: If field 2 of a row is not an integer.
    """
    meta_data = {}
    # newline="" hands line-ending handling to the csv module, as its
    # documentation recommends for files passed to csv.reader.
    with open(path, "r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f, delimiter="|"):
            if not row:
                # csv.reader yields an empty list for a blank line; skipping
                # it avoids an IndexError on e.g. a trailing empty line.
                continue
            frames = int(row[2])
            utt = row[3]
            bucket = meta_data.setdefault(len(utt), {"data": []})
            bucket["data"].append(
                {
                    "utt": utt,
                    "frames": frames,
                    "audio_len": get_audio_seconds(frames),
                    # Only the first four fields are carried over.
                    "row": "|".join(row[:4]),
                }
            )
    return append_data_statistics(meta_data)
def get_data_points(meta_data):
    """Flatten bucketed statistics into parallel lists ready for plotting.

    Args:
        meta_data: Mapping of bucket key to a dict holding ``"mean"``,
            ``"mode"``, ``"median"``, ``"std"`` and a ``"data"`` list.

    Returns:
        A dict with ``"x"`` (the bucket keys, as a list) and the lists
        ``"y_avg"``, ``"y_mode"``, ``"y_median"``, ``"y_std"`` and
        ``"y_num_samples"``, all in the iteration order of ``meta_data``.
    """
    buckets = list(meta_data.values())
    return {
        # A real list of keys, not the mapping itself: the y lists are plain
        # sequences and the x values must line up with them one-to-one.
        "x": list(meta_data),
        "y_avg": [bucket["mean"] for bucket in buckets],
        "y_mode": [bucket["mode"] for bucket in buckets],
        "y_median": [bucket["median"] for bucket in buckets],
        "y_std": [bucket["std"] for bucket in buckets],
        "y_num_samples": [len(bucket["data"]) for bucket in buckets],
    }
def save_training(file_path, meta_data):
    """Write every stored row of ``meta_data`` to ``file_path`` in random order.

    Args:
        file_path: Destination file; opened in ``"w+"`` mode as UTF-8, so any
            existing content is replaced.
        meta_data: Mapping whose values hold a ``"data"`` list of dicts, each
            with a ``"row"`` string.
    """
    lines = [
        sample["row"] + "\n"
        for bucket in meta_data.values()
        for sample in bucket["data"]
    ]
    # Shuffle so the output is not grouped by bucket.
    random.shuffle(lines)
    with open(file_path, "w+", encoding="utf-8") as out_file:
        out_file.writelines(lines)
def plot(meta_data, save_path=None):
    """Draw one figure per statistic, each plotted against character length.

    Five figures are created, in this order: mean, mode, median, standard
    deviation and number of samples. Points are drawn with the ``"ro"``
    format (red circles).

    Args:
        meta_data: Bucketed statistics as accepted by ``get_data_points``.
        save_path: Optional directory. When truthy, each figure is also saved
            there under a fixed base name (no extension is supplied, so the
            format is whatever ``savefig`` defaults to).
    """
    graph_data = get_data_points(meta_data)
    # plt.plot needs x to be a sequence matching y in length. list() turns a
    # mapping into its keys and copies an existing sequence unchanged.
    x = list(graph_data["x"])

    # (key in graph_data, y-axis label, output file base name)
    charts = (
        ("y_avg", "avg seconds", "char_len_vs_avg_secs"),
        ("y_mode", "mode seconds", "char_len_vs_mode_secs"),
        ("y_median", "median seconds", "char_len_vs_med_secs"),
        ("y_std", "standard deviation", "char_len_vs_std"),
        ("y_num_samples", "number of samples", "char_len_vs_num_samples"),
    )
    for key, ylabel, name in charts:
        plt.figure()
        plt.plot(x, graph_data[key], "ro")
        plt.xlabel("character lengths", fontsize=30)
        plt.ylabel(ylabel, fontsize=30)
        if save_path:
            plt.savefig(os.path.join(save_path, name))
def plot_phonemes(train_path, cmu_dict_path, save_path):
    """Count phoneme occurrences in the training text and draw them as a bar plot.

    Field 3 of every pipe-delimited row in ``train_path`` is split on
    whitespace and each word is looked up in a ``CMUDict`` built from
    ``cmu_dict_path``. Words whose lookup result is falsy are tallied under
    the ``"None"`` label; otherwise the first lookup entry is split on
    whitespace and each resulting symbol is counted.

    Args:
        train_path: Path of the UTF-8 encoded, pipe-delimited training file.
        cmu_dict_path: Path handed to ``CMUDict``.
        save_path: Optional directory; when truthy the figure is saved there
            as ``phoneme_dist``.
    """
    pronouncing_dict = CMUDict(cmu_dict_path)
    # "None" is inserted first so it is always the first bar.
    counts = {"None": 0}
    with open(train_path, "r", encoding="utf-8") as f:
        for row in csv.reader(f, delimiter="|"):
            for word in row[3].split():
                pronunciations = pronouncing_dict.lookup(word)
                if not pronunciations:
                    counts["None"] += 1
                    continue
                # Only the first returned entry is used.
                for phoneme in pronunciations[0].split():
                    counts[phoneme] = counts.get(phoneme, 0) + 1

    labels = list(counts)
    totals = list(counts.values())

    plt.figure()
    # NOTE(review): this rcParams change happens after the figure above was
    # created, so it presumably only affects figures created later - confirm
    # whether the (50, 20) size was meant for this figure.
    plt.rcParams["figure.figsize"] = (50, 20)
    barplot = sns.barplot(x=labels, y=totals)
    if save_path:
        barplot.get_figure().savefig(os.path.join(save_path, "phoneme_dist"))
def main():
    """Command-line entry point: parse arguments, build statistics, show charts.

    ``--train_file_path`` is required. ``--save_to`` optionally names a
    directory the charts are saved into, and ``--cmu_dict_path`` optionally
    enables the phoneme distribution plot. All figures are displayed at the
    end with ``plt.show()``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--train_file_path",
        required=True,
        help="this is the path to the train.txt file that the preprocess.py script creates",
    )
    arg_parser.add_argument("--save_to", help="path to save charts of data to")
    arg_parser.add_argument("--cmu_dict_path", help="give cmudict-0.7b to see phoneme distribution")
    options = arg_parser.parse_args()

    stats = process_meta_data(options.train_file_path)

    # Figure size used for the per-statistic charts.
    plt.rcParams["figure.figsize"] = (10, 5)
    plot(stats, save_path=options.save_to)

    if options.cmu_dict_path:
        # Wider default figure for the phoneme bar plot.
        plt.rcParams["figure.figsize"] = (30, 10)
        plot_phonemes(options.train_file_path, options.cmu_dict_path, options.save_to)

    plt.show()


# Run the tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()