mirror of
https://github.com/jasonppy/VoiceCraft.git
synced 2026-04-05 02:36:21 +02:00
add speech editing
This commit is contained in:
@@ -9,7 +9,6 @@ from data.tokenizer import (
|
||||
AudioTokenizer,
|
||||
TextTokenizer,
|
||||
)
|
||||
from IPython.display import display, Audio
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
@@ -25,7 +24,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="VoiceCraft Inference: see the script for more information on the options")
|
||||
description="VoiceCraft TTS Inference: see the script for more information on the options")
|
||||
|
||||
parser.add_argument("--model_name", type=str, default="giga330M.pth", choices=[
|
||||
"giga330M.pth", "gigaHalfLibri330M_TTSEnhanced_max16s.pth", "giga830M.pth"],
|
||||
@@ -34,15 +33,15 @@ def parse_arguments():
|
||||
default=16000, help="Audio sampling rate for the codec")
|
||||
parser.add_argument("--codec_sr", type=int, default=50,
|
||||
help="Sampling rate for the codec")
|
||||
parser.add_argument("--top_k", type=int, default=0,
|
||||
help="Top-k sampling value")
|
||||
parser.add_argument("--top_k", type=float, default=0,
|
||||
help="Top-k value")
|
||||
parser.add_argument("--top_p", type=float, default=0.9,
|
||||
help="Top-p sampling value")
|
||||
help="Top-p value")
|
||||
parser.add_argument("--temperature", type=float,
|
||||
default=1.0, help="Temperature for sampling")
|
||||
parser.add_argument("--silence_tokens", type=int, nargs="*",
|
||||
default=[1388, 1898, 131], help="Silence token IDs")
|
||||
parser.add_argument("--kvcache", type=int, default=1,
|
||||
parser.add_argument("--kvcache", type=int, default=1, choices=[0, 1],
|
||||
help="Key-value cache flag (0 or 1)")
|
||||
parser.add_argument("--stop_repetition", type=int,
|
||||
default=3, help="Stop repetition for generation")
|
||||
@@ -54,7 +53,6 @@ def parse_arguments():
|
||||
help="beam size for MFA alignment")
|
||||
parser.add_argument("--retry_beam_size", type=int, default=40,
|
||||
help="retry beam size for MFA alignment")
|
||||
|
||||
parser.add_argument("--output_dir", type=str, default="./generated_tts",
|
||||
help="directory to save generated audio")
|
||||
parser.add_argument("--original_audio", type=str,
|
||||
@@ -147,9 +145,6 @@ audio_dur = info.num_frames / info.sample_rate
|
||||
assert cut_off_sec < audio_dur, f"cut_off_sec {cut_off_sec} is larger than the audio duration {audio_dur}"
|
||||
prompt_end_frame = int(cut_off_sec * info.sample_rate)
|
||||
|
||||
# run the model to get the output
|
||||
|
||||
|
||||
def seed_everything(seed):
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
random.seed(seed)
|
||||
@@ -162,6 +157,7 @@ def seed_everything(seed):
|
||||
|
||||
seed_everything(seed)
|
||||
|
||||
# inference
|
||||
decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache,
|
||||
"codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr, "silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size}
|
||||
concated_audio, gen_audio = inference_one_sample(model, argparse.Namespace(
|
||||
|
||||
Reference in New Issue
Block a user