diff --git a/TTS/bin/convert_melgan_torch_to_tf.py b/TTS/bin/convert_melgan_torch_to_tf.py index 43581348..c1fb8498 100644 --- a/TTS/bin/convert_melgan_torch_to_tf.py +++ b/TTS/bin/convert_melgan_torch_to_tf.py @@ -6,7 +6,7 @@ import numpy as np import tensorflow as tf import torch -from TTS.utils.io import load_config +from TTS.utils.io import load_config, load_fsspec from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import ( compare_torch_tf, convert_tf_name, @@ -33,7 +33,7 @@ num_speakers = 0 # init torch model model = setup_generator(c) -checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) +checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu")) state_dict = checkpoint["model"] model.load_state_dict(state_dict) model.remove_weight_norm() diff --git a/TTS/bin/convert_tacotron2_torch_to_tf.py b/TTS/bin/convert_tacotron2_torch_to_tf.py index a6fb5d9b..78c6b362 100644 --- a/TTS/bin/convert_tacotron2_torch_to_tf.py +++ b/TTS/bin/convert_tacotron2_torch_to_tf.py @@ -13,7 +13,7 @@ from TTS.tts.tf.models.tacotron2 import Tacotron2 from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf from TTS.tts.tf.utils.generic_utils import save_checkpoint from TTS.tts.utils.text.symbols import phonemes, symbols -from TTS.utils.io import load_config +from TTS.utils.io import load_config, load_fsspec sys.path.append("/home/erogol/Projects") os.environ["CUDA_VISIBLE_DEVICES"] = "" @@ -32,7 +32,7 @@ num_speakers = 0 # init torch model model = setup_model(c) -checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) +checkpoint = load_fsspec(args.torch_model_path, map_location=torch.device("cpu")) state_dict = checkpoint["model"] model.load_state_dict(state_dict) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 1cbc5516..debe5933 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -16,6 +16,7 @@ from TTS.tts.models import setup_model from TTS.tts.utils.speakers import get_speaker_manager from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters +from TTS.utils.io import load_fsspec use_cuda = torch.cuda.is_available() @@ -239,7 +240,7 @@ def main(args): # pylint: disable=redefined-outer-name model = setup_model(c) # restore model - checkpoint = torch.load(args.checkpoint_path, map_location="cpu") + checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu") model.load_state_dict(checkpoint["model"]) if use_cuda: diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 2bb5bfc7..43867239 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -17,6 +17,7 @@ from TTS.trainer import init_training from TTS.tts.datasets import load_meta_data from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict +from TTS.utils.io import load_fsspec from TTS.utils.radam import RAdam from TTS.utils.training import NoamLR, check_update @@ -169,7 +170,7 @@ def main(args): # pylint: disable=redefined-outer-name raise Exception("The %s not is a loss supported" % c.loss) if args.restore_path: - checkpoint = torch.load(args.restore_path) + checkpoint = load_fsspec(args.restore_path) try: model.load_state_dict(checkpoint["model"]) diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index ecbe1f9a..ea98f431 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -3,6 +3,7 @@ import os import re from typing import Dict +import fsspec import yaml from coqpit import Coqpit @@ -13,7 +14,7 @@ from TTS.utils.generic_utils import find_module def read_json_with_comments(json_path): """for backward compat.""" # fallback to json - with open(json_path, "r", encoding="utf-8") as f: + with fsspec.open(json_path, "r", encoding="utf-8") as f: input_str = f.read() # handle comments input_str = re.sub(r"\\\n", "", input_str) @@ -76,13 +77,12 @@ def load_config(config_path: str) -> None: config_dict = {} ext = os.path.splitext(config_path)[1] if ext in (".yml", ".yaml"): - with open(config_path, "r", encoding="utf-8") as f: + with fsspec.open(config_path, "r", encoding="utf-8") as f: data = yaml.safe_load(f) elif ext == ".json": try: - with open(config_path, "r", encoding="utf-8") as f: - input_str = f.read() - data = json.loads(input_str) + with fsspec.open(config_path, "r", encoding="utf-8") as f: + data = json.load(f) except json.decoder.JSONDecodeError: # backwards compat. data = read_json_with_comments(config_path) diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index 669437b9..0ec7f758 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -225,8 +225,10 @@ class BaseTrainingConfig(Coqpit): num_eval_loader_workers (int): Number of workers for evaluation time dataloader. output_path (str): - Path for training output folder. The nonexist part of the given path is created automatically. - All training outputs are saved there. + Path for training output folder, either a local file path or other + URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or + S3 (s3://) paths. The nonexist part of the given path is created + automatically. All training artefacts are saved there. """ model: str = None diff --git a/TTS/speaker_encoder/models/lstm.py b/TTS/speaker_encoder/models/lstm.py index 7e39087a..de5bb007 100644 --- a/TTS/speaker_encoder/models/lstm.py +++ b/TTS/speaker_encoder/models/lstm.py @@ -2,6 +2,8 @@ import numpy as np import torch from torch import nn +from TTS.utils.io import load_fsspec + class LSTMWithProjection(nn.Module): def __init__(self, input_size, hidden_size, proj_size): @@ -120,7 +122,7 @@ class LSTMSpeakerEncoder(nn.Module): # pylint: disable=unused-argument, redefined-builtin def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if use_cuda: self.cuda() diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py index f52bb4d5..f121631b 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/speaker_encoder/models/resnet.py @@ -2,6 +2,8 @@ import numpy as np import torch import torch.nn as nn +from TTS.utils.io import load_fsspec + class SELayer(nn.Module): def __init__(self, channel, reduction=8): @@ -201,7 +203,7 @@ class ResNetSpeakerEncoder(nn.Module): return embeddings def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) self.load_state_dict(state["model"]) if use_cuda: self.cuda() diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py index fb61e48e..1981fbe9 100644 --- a/TTS/speaker_encoder/utils/generic_utils.py +++ b/TTS/speaker_encoder/utils/generic_utils.py @@ -6,11 +6,11 @@ import re from multiprocessing import Manager import numpy as np -import torch from scipy import signal from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder +from TTS.utils.io import save_fsspec class Storage(object): @@ -198,7 +198,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s "loss": model_loss, "date": datetime.date.today().strftime("%B %d, %Y"), } - torch.save(state, checkpoint_path) + save_fsspec(state, checkpoint_path) def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step): @@ -216,5 +216,5 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path bestmodel_path = "best_model.pth.tar" bestmodel_path = os.path.join(out_path, bestmodel_path) print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) - torch.save(state, bestmodel_path) + save_fsspec(state, bestmodel_path) return best_loss diff --git a/TTS/speaker_encoder/utils/io.py b/TTS/speaker_encoder/utils/io.py index 0479f1af..7a3aadc9 100644 --- a/TTS/speaker_encoder/utils/io.py +++ b/TTS/speaker_encoder/utils/io.py @@ -1,7 +1,7 @@ import datetime import os -import torch +from TTS.utils.io import save_fsspec def save_checkpoint(model, optimizer, model_loss, out_path, current_step): @@ -17,7 +17,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step): "loss": model_loss, "date": datetime.date.today().strftime("%B %d, %Y"), } - torch.save(state, checkpoint_path) + save_fsspec(state, checkpoint_path) def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step): @@ -34,5 +34,5 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_s bestmodel_path = "best_model.pth.tar" bestmodel_path = os.path.join(out_path, bestmodel_path) print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) - torch.save(state, bestmodel_path) + save_fsspec(state, bestmodel_path) return best_loss diff --git a/TTS/trainer.py b/TTS/trainer.py index 903aee5f..3ac83601 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -import glob import importlib import logging import os @@ -12,7 +11,9 @@ import traceback from argparse import Namespace from dataclasses import dataclass, field from typing import Dict, List, Tuple, Union +from urllib.parse import urlparse +import fsspec import torch from coqpit import Coqpit from torch import nn @@ -29,13 +30,13 @@ from TTS.utils.distribute import init_distributed from TTS.utils.generic_utils import ( KeepAverage, count_parameters, - create_experiment_folder, + get_experiment_folder_path, get_git_branch, remove_experiment_folder, set_init_dict, to_cuda, ) -from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint +from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint from TTS.utils.logging import ConsoleLogger, TensorboardLogger from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data @@ -173,7 +174,6 @@ class Trainer: self.best_loss = float("inf") self.train_loader = None self.eval_loader = None - self.output_audio_path = os.path.join(output_path, "test_audios") self.keep_avg_train = None self.keep_avg_eval = None @@ -309,7 +309,7 @@ class Trainer: return obj print(" > Restoring from %s ..." % os.path.basename(restore_path)) - checkpoint = torch.load(restore_path) + checkpoint = load_fsspec(restore_path) try: print(" > Restoring Model...") model.load_state_dict(checkpoint["model"]) @@ -776,7 +776,7 @@ class Trainer: """🏃 train -> evaluate -> test for the number of epochs.""" if self.restore_step != 0 or self.args.best_path: print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...") - self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"] + self.best_loss = load_fsspec(self.args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {self.best_loss}.") self.total_steps_done = self.restore_step @@ -834,9 +834,16 @@ class Trainer: @staticmethod def _setup_logger_config(log_file: str) -> None: - logging.basicConfig( - level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()] - ) + handlers = [logging.StreamHandler()] + + # Only add a log file if the output location is local due to poor + # support for writing logs to file-like objects. + parsed_url = urlparse(log_file) + if not parsed_url.scheme or parsed_url.scheme == "file": + schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path) + handlers.append(logging.FileHandler(schemeless_path)) + + logging.basicConfig(level=logging.INFO, format="", handlers=handlers) @staticmethod def _is_apex_available() -> bool: @@ -926,22 +933,27 @@ def init_arguments(): return parser -def get_last_checkpoint(path): +def get_last_checkpoint(path: str) -> Tuple[str, str]: """Get latest checkpoint or/and best model in path. It is based on globbing for `*.pth.tar` and the RegEx `(checkpoint|best_model)_([0-9]+)`. Args: - path (list): Path to files to be compared. + path: Path to files to be compared. Raises: ValueError: If no checkpoint or best_model files are found. Returns: - last_checkpoint (str): Last checkpoint filename. + Path to the last checkpoint + Path to best checkpoint """ - file_names = glob.glob(os.path.join(path, "*.pth.tar")) + fs = fsspec.get_mapper(path).fs + file_names = fs.glob(os.path.join(path, "*.pth.tar")) + scheme = urlparse(path).scheme + if scheme: # scheme is not preserved in fs.glob, add it back + file_names = [scheme + "://" + file_name for file_name in file_names] last_models = {} last_model_nums = {} for key in ["checkpoint", "best_model"]: @@ -963,7 +975,7 @@ def get_last_checkpoint(path): key_file_names = [fn for fn in file_names if key in fn] if last_model is None and len(key_file_names) > 0: last_model = max(key_file_names, key=os.path.getctime) - last_model_num = torch.load(last_model)["step"] + last_model_num = load_fsspec(last_model)["step"] if last_model is not None: last_models[key] = last_model @@ -1030,12 +1042,11 @@ def process_args(args, config=None): print(" > Mixed precision mode is ON") experiment_path = args.continue_path if not experiment_path: - experiment_path = create_experiment_folder(config.output_path, config.run_name) + experiment_path = get_experiment_folder_path(config.output_path, config.run_name) audio_path = os.path.join(experiment_path, "test_audios") # setup rank 0 process in distributed training tb_logger = None if args.rank == 0: - os.makedirs(audio_path, exist_ok=True) new_fields = {} if args.restore_path: new_fields["restore_path"] = args.restore_path @@ -1047,8 +1058,6 @@ def process_args(args, config=None): used_characters = parse_symbols() new_fields["characters"] = used_characters copy_model_files(config, experiment_path, new_fields) - os.chmod(audio_path, 0o775) - os.chmod(experiment_path, 0o775) tb_logger = TensorboardLogger(experiment_path, model_name=config.model) # write model desc to tensorboard tb_logger.tb_add_text("model-config", f"
{config.to_json()}", 0)
diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py
index 879ecae4..fb2fa697 100644
--- a/TTS/tts/models/align_tts.py
+++ b/TTS/tts/models/align_tts.py
@@ -16,6 +16,7 @@ from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
@dataclass
@@ -389,7 +390,7 @@ class AlignTTS(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
- state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py
index b7056e06..2d2cc111 100644
--- a/TTS/tts/models/base_tacotron.py
+++ b/TTS/tts/models/base_tacotron.py
@@ -13,6 +13,7 @@ from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.text import make_symbols
from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
@@ -113,7 +114,7 @@ class BaseTacotron(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
- state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if "r" in state:
self.decoder.set_r(state["r"])
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index b3bceb09..1c631c8e 100755
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -14,6 +14,7 @@ from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
class GlowTTS(BaseTTS):
@@ -382,7 +383,7 @@ class GlowTTS(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
- state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py
index 8f14d610..33b9cb66 100644
--- a/TTS/tts/models/speedy_speech.py
+++ b/TTS/tts/models/speedy_speech.py
@@ -14,6 +14,7 @@ from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
@dataclass
@@ -306,7 +307,7 @@ class SpeedySpeech(BaseTTS):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
- state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
diff --git a/TTS/tts/tf/utils/generic_utils.py b/TTS/tts/tf/utils/generic_utils.py
index 91434a38..681a9457 100644
--- a/TTS/tts/tf/utils/generic_utils.py
+++ b/TTS/tts/tf/utils/generic_utils.py
@@ -2,6 +2,7 @@ import datetime
import importlib
import pickle
+import fsspec
import numpy as np
import tensorflow as tf
@@ -16,11 +17,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
"r": r,
}
state.update(kwargs)
- pickle.dump(state, open(output_path, "wb"))
+ with fsspec.open(output_path, "wb") as f:
+ pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
- checkpoint = pickle.load(open(checkpoint_path, "rb"))
+ with fsspec.open(checkpoint_path, "rb") as f:
+ checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:
diff --git a/TTS/tts/tf/utils/io.py b/TTS/tts/tf/utils/io.py
index b2345b00..de6acff9 100644
--- a/TTS/tts/tf/utils/io.py
+++ b/TTS/tts/tf/utils/io.py
@@ -1,6 +1,7 @@
import datetime
import pickle
+import fsspec
import tensorflow as tf
@@ -14,11 +15,13 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwa
"r": r,
}
state.update(kwargs)
- pickle.dump(state, open(output_path, "wb"))
+ with fsspec.open(output_path, "wb") as f:
+ pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
- checkpoint = pickle.load(open(checkpoint_path, "rb"))
+ with fsspec.open(checkpoint_path, "rb") as f:
+ checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:
diff --git a/TTS/tts/tf/utils/tflite.py b/TTS/tts/tf/utils/tflite.py
index 9701d591..2f76aa50 100644
--- a/TTS/tts/tf/utils/tflite.py
+++ b/TTS/tts/tf/utils/tflite.py
@@ -1,3 +1,4 @@
+import fsspec
import tensorflow as tf
@@ -14,7 +15,7 @@ def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None:
# same model binary if outputpath is provided
- with open(output_path, "wb") as f:
+ with fsspec.open(output_path, "wb") as f:
f.write(tflite_model)
return None
return tflite_model
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index a8c9e0f6..ed14cd8e 100755
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -3,6 +3,7 @@ import os
import random
from typing import Any, Dict, List, Tuple, Union
+import fsspec
import numpy as np
import torch
from coqpit import Coqpit
@@ -84,12 +85,12 @@ class SpeakerManager:
@staticmethod
def _load_json(json_file_path: str) -> Dict:
- with open(json_file_path) as f:
+ with fsspec.open(json_file_path, "r") as f:
return json.load(f)
@staticmethod
def _save_json(json_file_path: str, data: dict) -> None:
- with open(json_file_path, "w") as f:
+ with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
@property
@@ -294,9 +295,10 @@ def _set_file_path(path):
Intended to band aid the different paths returned in restored and continued training."""
path_restore = os.path.join(os.path.dirname(path), "speakers.json")
path_continue = os.path.join(path, "speakers.json")
- if os.path.exists(path_restore):
+ fs = fsspec.get_mapper(path).fs
+ if fs.exists(path_restore):
return path_restore
- if os.path.exists(path_continue):
+ if fs.exists(path_continue):
return path_continue
raise FileNotFoundError(f" [!] `speakers.json` not found in {path}")
@@ -307,7 +309,7 @@ def load_speaker_mapping(out_path):
json_file = out_path
else:
json_file = _set_file_path(out_path)
- with open(json_file) as f:
+ with fsspec.open(json_file, "r") as f:
return json.load(f)
@@ -315,7 +317,7 @@ def save_speaker_mapping(out_path, speaker_mapping):
"""Saves speaker mapping if not yet present."""
if out_path is not None:
speakers_json_path = _set_file_path(out_path)
- with open(speakers_json_path, "w") as f:
+ with fsspec.open(speakers_json_path, "w") as f:
json.dump(speaker_mapping, f, indent=4)
diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py
index e7c57529..287143e5 100644
--- a/TTS/utils/generic_utils.py
+++ b/TTS/utils/generic_utils.py
@@ -1,15 +1,14 @@
# -*- coding: utf-8 -*-
import datetime
-import glob
import importlib
import os
import re
-import shutil
import subprocess
import sys
from pathlib import Path
from typing import Dict
+import fsspec
import torch
@@ -58,23 +57,22 @@ def get_commit_hash():
return commit
-def create_experiment_folder(root_path, model_name):
- """Create a folder with the current date and time"""
+def get_experiment_folder_path(root_path, model_name):
+ """Get an experiment folder path with the current date and time"""
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
- os.makedirs(output_folder, exist_ok=True)
print(" > Experiment folder: {}".format(output_folder))
return output_folder
def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder"""
-
- checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
+ fs = fsspec.get_mapper(experiment_path).fs
+ checkpoint_files = fs.glob(experiment_path + "/*.pth.tar")
if not checkpoint_files:
- if os.path.exists(experiment_path):
- shutil.rmtree(experiment_path, ignore_errors=True)
+ if fs.exists(experiment_path):
+ fs.rm(experiment_path, recursive=True)
print(" ! Run is removed from {}".format(experiment_path))
else:
print(" ! Run is kept in {}".format(experiment_path))
diff --git a/TTS/utils/io.py b/TTS/utils/io.py
index 871cff6c..f634b023 100644
--- a/TTS/utils/io.py
+++ b/TTS/utils/io.py
@@ -1,9 +1,11 @@
import datetime
-import glob
+import json
import os
import pickle as pickle_tts
-from shutil import copyfile
+import shutil
+from typing import Any
+import fsspec
import torch
from coqpit import Coqpit
@@ -24,7 +26,7 @@ class AttrDict(dict):
self.__dict__ = self
-def copy_model_files(config, out_path, new_fields):
+def copy_model_files(config: Coqpit, out_path, new_fields):
"""Copy config.json and other model files to training folder and add
new fields.
@@ -37,23 +39,40 @@ def copy_model_files(config, out_path, new_fields):
copy_config_path = os.path.join(out_path, "config.json")
# add extra information fields
config.update(new_fields, allow_new=True)
- config.save_json(copy_config_path)
+ # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
+ with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
+ json.dump(config.to_dict(), f, indent=4)
+
# copy model stats file if available
if config.audio.stats_path is not None:
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
- if not os.path.exists(copy_stats_path):
- copyfile(
- config.audio.stats_path,
- copy_stats_path,
- )
+ filesystem = fsspec.get_mapper(copy_stats_path).fs
+ if not filesystem.exists(copy_stats_path):
+ with fsspec.open(config.audio.stats_path, "rb") as source_file:
+ with fsspec.open(copy_stats_path, "wb") as target_file:
+ shutil.copyfileobj(source_file, target_file)
+
+
+def load_fsspec(path: str, **kwargs) -> Any:
+ """Like torch.load but can load from other locations (e.g. s3:// , gs://).
+
+ Args:
+ path: Any path or url supported by fsspec.
+ **kwargs: Keyword arguments forwarded to torch.load.
+
+ Returns:
+ Object stored in path.
+ """
+ with fsspec.open(path, "rb") as f:
+ return torch.load(f, **kwargs)
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
try:
- state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
except ModuleNotFoundError:
pickle_tts.Unpickler = RenamingUnpickler
- state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
model.load_state_dict(state["model"])
if use_cuda:
model.cuda()
@@ -62,6 +81,18 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pyli
return model, state
+def save_fsspec(state: Any, path: str, **kwargs):
+ """Like torch.save but can save to other locations (e.g. s3:// , gs://).
+
+ Args:
+ state: State object to save
+ path: Any path or url supported by fsspec.
+ **kwargs: Keyword arguments forwarded to torch.save.
+ """
+ with fsspec.open(path, "wb") as f:
+ torch.save(state, f, **kwargs)
+
+
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
if hasattr(model, "module"):
model_state = model.module.state_dict()
@@ -90,7 +121,7 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_pat
"date": datetime.date.today().strftime("%B %d, %Y"),
}
state.update(kwargs)
- torch.save(state, output_path)
+ save_fsspec(state, output_path)
def save_checkpoint(
@@ -147,18 +178,16 @@ def save_best_model(
model_loss=current_loss,
**kwargs,
)
+ fs = fsspec.get_mapper(out_path).fs
# only delete previous if current is saved successfully
if not keep_all_best or (current_step < keep_after):
- model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar"))
+ model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar"))
for model_name in model_names:
- if os.path.basename(model_name) == best_model_name:
- continue
- os.remove(model_name)
- # create symlink to best model for convinience
- link_name = "best_model.pth.tar"
- link_path = os.path.join(out_path, link_name)
- if os.path.islink(link_path) or os.path.isfile(link_path):
- os.remove(link_path)
- os.symlink(best_model_name, os.path.join(out_path, link_name))
+ if os.path.basename(model_name) != best_model_name:
+ fs.rm(model_name)
+ # create a shortcut which always points to the currently best model
+ shortcut_name = "best_model.pth.tar"
+ shortcut_path = os.path.join(out_path, shortcut_name)
+ fs.copy(checkpoint_path, shortcut_path)
best_loss = current_loss
return best_loss
diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py
index 39176155..f203c533 100644
--- a/TTS/vocoder/models/gan.py
+++ b/TTS/vocoder/models/gan.py
@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
@@ -222,7 +223,7 @@ class GAN(BaseVocoder):
checkpoint_path (str): Checkpoint file path.
eval (bool, optional): If true, load the model for inference. If falseDefaults to False.
"""
- state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# band-aid for older than v0.0.15 GAN models
if "model_disc" in state:
self.model_g.load_checkpoint(config, checkpoint_path, eval)
diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py
index f606c649..2260b781 100644
--- a/TTS/vocoder/models/hifigan_generator.py
+++ b/TTS/vocoder/models/hifigan_generator.py
@@ -5,6 +5,8 @@ import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, weight_norm
+from TTS.utils.io import load_fsspec
+
LRELU_SLOPE = 0.1
@@ -275,7 +277,7 @@ class HifiganGenerator(torch.nn.Module):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
- state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py
index dabb4baa..e60baa9d 100644
--- a/TTS/vocoder/models/melgan_generator.py
+++ b/TTS/vocoder/models/melgan_generator.py
@@ -2,6 +2,7 @@ import torch
from torch import nn
from torch.nn.utils import weight_norm
+from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack
@@ -86,7 +87,7 @@ class MelganGenerator(nn.Module):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
- state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py
index 788856cc..b8e78d03 100644
--- a/TTS/vocoder/models/parallel_wavegan_generator.py
+++ b/TTS/vocoder/models/parallel_wavegan_generator.py
@@ -3,6 +3,7 @@ import math
import numpy as np
import torch
+from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample
@@ -154,7 +155,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
- state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index d2983be2..5dc878d7 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -11,6 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
@@ -220,7 +221,7 @@ class Wavegrad(BaseModel):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
- state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py
index c2e47120..8a968019 100644
--- a/TTS/vocoder/models/wavernn.py
+++ b/TTS/vocoder/models/wavernn.py
@@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_fsspec
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.models.base_vocoder import BaseVocoder
@@ -545,7 +546,7 @@ class Wavernn(BaseVocoder):
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
- state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
diff --git a/TTS/vocoder/tf/utils/io.py b/TTS/vocoder/tf/utils/io.py
index 7e236db2..3de8adab 100644
--- a/TTS/vocoder/tf/utils/io.py
+++ b/TTS/vocoder/tf/utils/io.py
@@ -1,6 +1,7 @@
import datetime
import pickle
+import fsspec
import tensorflow as tf
@@ -13,12 +14,14 @@ def save_checkpoint(model, current_step, epoch, output_path, **kwargs):
"date": datetime.date.today().strftime("%B %d, %Y"),
}
state.update(kwargs)
- pickle.dump(state, open(output_path, "wb"))
+ with fsspec.open(output_path, "wb") as f:
+ pickle.dump(state, f)
def load_checkpoint(model, checkpoint_path):
"""Load TF Vocoder model"""
- checkpoint = pickle.load(open(checkpoint_path, "rb"))
+ with fsspec.open(checkpoint_path, "rb") as f:
+ checkpoint = pickle.load(f)
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:
diff --git a/TTS/vocoder/tf/utils/tflite.py b/TTS/vocoder/tf/utils/tflite.py
index e0c630b9..876739fd 100644
--- a/TTS/vocoder/tf/utils/tflite.py
+++ b/TTS/vocoder/tf/utils/tflite.py
@@ -1,3 +1,4 @@
+import fsspec
import tensorflow as tf
@@ -14,7 +15,7 @@ def convert_melgan_to_tflite(model, output_path=None, experimental_converter=Tru
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None:
# same model binary if outputpath is provided
- with open(output_path, "wb") as f:
+ with fsspec.open(output_path, "wb") as f:
f.write(tflite_model)
return None
return tflite_model
diff --git a/requirements.txt b/requirements.txt
index d5624c3b..b92947a0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -24,3 +24,4 @@ mecab-python3==1.0.3
unidic-lite==1.0.8
# gruut+supported langs
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=1.2.0
+fsspec>=2021.04.0
diff --git a/tests/tts_tests/test_tacotron2_train_fsspec_path.py b/tests/tts_tests/test_tacotron2_train_fsspec_path.py
new file mode 100644
index 00000000..9e4ee102
--- /dev/null
+++ b/tests/tts_tests/test_tacotron2_train_fsspec_path.py
@@ -0,0 +1,55 @@
+import glob
+import os
+import shutil
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs import Tacotron2Config
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+config = Tacotron2Config(
+ r=5,
+ batch_size=8,
+ eval_batch_size=8,
+ num_loader_workers=0,
+ num_eval_loader_workers=0,
+ text_cleaner="english_cleaners",
+ use_phonemes=False,
+ phoneme_language="en-us",
+ phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
+ run_eval=True,
+ test_delay_epochs=-1,
+ epochs=1,
+ print_step=1,
+ test_sentences=[
+ "Be a voice, not an echo.",
+ ],
+ print_eval=True,
+ max_decoder_steps=50,
+)
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+ f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path file://{config_path} "
+ f"--coqpit.output_path file://{output_path} "
+ "--coqpit.datasets.0.name ljspeech "
+ "--coqpit.datasets.0.meta_file_train metadata.csv "
+ "--coqpit.datasets.0.meta_file_val metadata.csv "
+ "--coqpit.datasets.0.path tests/data/ljspeech "
+ "--coqpit.test_delay_epochs 0 "
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# restore the model and continue training for one more epoch
+command_train = (
+ f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path file://{continue_path} "
+)
+run_cli(command_train)
+shutil.rmtree(continue_path)