Merge pull request #27 from liuhaozhe6788/develop

误差曲线更新脚本修改误差序列文件的路径并处理了异常
This commit is contained in:
liuhaozhe6788
2023-01-01 12:22:24 +08:00
committed by GitHub

View File

@@ -1,16 +1,17 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
import argparse
def main(module_name):
if module_name == "synthesizer":
if module_name == "syn":
# function to update the data
def my_function(i):
# get data
train_loss_arr = np.load("src/synthesizer_loss/synthesizer_train_loss.npy")
dev_loss_arr = np.load("src/synthesizer_loss/synthesizer_dev_loss.npy")
train_loss_arr = np.load("synthesizer_loss/synthesizer_train_loss.npy")
dev_loss_arr = np.load("synthesizer_loss/synthesizer_dev_loss.npy")
# clear axis
ax.cla()
# plot cpu
@@ -33,12 +34,12 @@ def main(module_name):
ani = FuncAnimation(fig, my_function, interval=1000)
plt.show()
elif module_name == "vocoder":
elif module_name == "voc":
# function to update the data
def my_function(i):
# get data
train_loss_arr = np.load("src/vocoder_loss/vocoder_train_loss.npy")
dev_loss_arr = np.load("src/vocoder_loss/vocoder_dev_loss.npy")
train_loss_arr = np.load("vocoder_loss/vocoder_train_loss.npy")
dev_loss_arr = np.load("vocoder_loss/vocoder_dev_loss.npy")
# clear axis
ax.cla()
# plot cpu
@@ -61,4 +62,17 @@ def main(module_name):
ani = FuncAnimation(fig, my_function, interval=1000)
plt.show()
main("synthesizer")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("model", type=str, help= \
"The model to show plot, model name is syn or voc")
args = parser.parse_args()
arg_dict = vars(args)
try:
main(arg_dict["model"])
except Exception as e:
print("Caught exception: %s" % repr(e))
print("Restarting\n")