add speech editing

This commit is contained in:
Pranay Gosar
2024-04-23 18:38:09 -05:00
parent b8bb2ab592
commit 59877c085e
2 changed files with 225 additions and 11 deletions

View File

@@ -9,7 +9,6 @@ from data.tokenizer import (
AudioTokenizer,
TextTokenizer,
)
from IPython.display import display, Audio
import argparse
import random
import numpy as np
@@ -25,7 +24,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
def parse_arguments():
parser = argparse.ArgumentParser(
description="VoiceCraft Inference: see the script for more information on the options")
description="VoiceCraft TTS Inference: see the script for more information on the options")
parser.add_argument("--model_name", type=str, default="giga330M.pth", choices=[
"giga330M.pth", "gigaHalfLibri330M_TTSEnhanced_max16s.pth", "giga830M.pth"],
@@ -34,15 +33,15 @@ def parse_arguments():
default=16000, help="Audio sampling rate for the codec")
parser.add_argument("--codec_sr", type=int, default=50,
help="Sampling rate for the codec")
parser.add_argument("--top_k", type=int, default=0,
help="Top-k sampling value")
parser.add_argument("--top_k", type=float, default=0,
help="Top-k value")
parser.add_argument("--top_p", type=float, default=0.9,
help="Top-p sampling value")
help="Top-p value")
parser.add_argument("--temperature", type=float,
default=1.0, help="Temperature for sampling")
parser.add_argument("--silence_tokens", type=int, nargs="*",
default=[1388, 1898, 131], help="Silence token IDs")
parser.add_argument("--kvcache", type=int, default=1,
parser.add_argument("--kvcache", type=int, default=1, choices=[0, 1],
help="Key-value cache flag (0 or 1)")
parser.add_argument("--stop_repetition", type=int,
default=3, help="Stop repetition for generation")
@@ -54,7 +53,6 @@ def parse_arguments():
help="beam size for MFA alignment")
parser.add_argument("--retry_beam_size", type=int, default=40,
help="retry beam size for MFA alignment")
parser.add_argument("--output_dir", type=str, default="./generated_tts",
help="directory to save generated audio")
parser.add_argument("--original_audio", type=str,
@@ -147,9 +145,6 @@ audio_dur = info.num_frames / info.sample_rate
assert cut_off_sec < audio_dur, f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
prompt_end_frame = int(cut_off_sec * info.sample_rate)
# run the model to get the output
def seed_everything(seed):
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
@@ -162,6 +157,7 @@ def seed_everything(seed):
seed_everything(seed)
# inference
decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache,
"codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr, "silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size}
concated_audio, gen_audio = inference_one_sample(model, argparse.Namespace(