Stop using the deprecated `.update()` calls, improve error handling, and allow using VoiceCraft without Whisper

This commit is contained in:
Stepan Zuev
2024-04-03 05:01:55 +03:00
parent 5cef625c1b
commit 74fa65979d

View File

@@ -6,62 +6,63 @@ from data.tokenizer import (
TextTokenizer, TextTokenizer,
) )
from models import voicecraft from models import voicecraft
import whisper
from whisper.tokenizer import get_tokenizer
import os import os
import io import io
whisper_model = None def load_models(whisper_model_choice, voicecraft_model_choice):
voicecraft_model = None whisper_model, voicecraft_model = None, None
device = "cuda" if torch.cuda.is_available() else "cpu" if whisper_model_choice is not None:
import whisper
from whisper.tokenizer import get_tokenizer
whisper_model = {
"model": whisper.load_model(whisper_model_choice),
"tokenizer": get_tokenizer(multilingual=False)
}
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = "cuda" if torch.cuda.is_available() else "cpu"
voicecraft_name = f"{voicecraft_model_choice}.pth"
ckpt_fn = f"./pretrained_models/{voicecraft_name}"
encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
if not os.path.exists(ckpt_fn):
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}")
if not os.path.exists(encodec_fn):
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th")
ckpt = torch.load(ckpt_fn, map_location="cpu")
model = voicecraft.VoiceCraft(ckpt["config"])
model.load_state_dict(ckpt["model"])
model.to(device)
model.eval()
voicecraft_model = {
"ckpt": ckpt,
"model": model,
"text_tokenizer": TextTokenizer(backend="espeak"),
"audio_tokenizer": AudioTokenizer(signature=encodec_fn)
}
return [
whisper_model,
voicecraft_model,
gr.Audio(interactive=True),
]
def load_models(input_audio, transcribe_btn, run_btn, rerun_btn): def transcribe(whisper_model, audio_path):
def impl(whisper_model_choice, voicecraft_model_choice): if whisper_model is None:
global whisper_model, voicecraft_model raise gr.Error("Whisper model not loaded")
whisper_model = whisper.load_model(whisper_model_choice)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
voicecraft_name = f"{voicecraft_model_choice}.pth"
ckpt_fn = f"./pretrained_models/{voicecraft_name}"
encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
if not os.path.exists(ckpt_fn):
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}")
if not os.path.exists(encodec_fn):
os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th")
voicecraft_model = {}
voicecraft_model["ckpt"] = torch.load(ckpt_fn, map_location="cpu")
voicecraft_model["model"] = voicecraft.VoiceCraft(voicecraft_model["ckpt"]["config"])
voicecraft_model["model"].load_state_dict(voicecraft_model["ckpt"]["model"])
voicecraft_model["model"].to(device)
voicecraft_model["model"].eval()
voicecraft_model["text_tokenizer"] = TextTokenizer(backend="espeak")
voicecraft_model["audio_tokenizer"] = AudioTokenizer(signature=encodec_fn)
return [
input_audio.update(interactive=True),
transcribe_btn.update(interactive=True),
run_btn.update(interactive=True),
rerun_btn.update(interactive=True)
]
return impl
def transcribe(audio_path):
tokenizer = get_tokenizer(multilingual=False)
number_tokens = [ number_tokens = [
i i
for i in range(tokenizer.eot) for i in range(whisper_model["tokenizer"].eot)
if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" ")) if all(c in "0123456789" for c in whisper_model["tokenizer"].decode([i]).removeprefix(" "))
] ]
result = whisper_model.transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True) result = whisper_model["model"].transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True)
words = [word_info for segment in result["segments"] for word_info in segment["words"]] words = [word_info for segment in result["segments"] for word_info in segment["words"]]
transcript = result["text"] transcript = result["text"]
@@ -69,12 +70,12 @@ def transcribe(audio_path):
transcript_with_end_time = " ".join([f"{word['word']} {word['end']}" for word in words]) transcript_with_end_time = " ".join([f"{word['word']} {word['end']}" for word in words])
choices = [f"{word['start']} {word['word']} {word['end']}" for word in words] choices = [f"{word['start']} {word['word']} {word['end']}" for word in words]
edit_from_word = gr.Dropdown(label="First word to edit", value=choices[0], choices=choices, interactive=True)
edit_to_word = gr.Dropdown(label="Last word to edit", value=choices[-1], choices=choices, interactive=True)
return [ return [
transcript, transcript_with_start_time, transcript_with_end_time, transcript, transcript_with_start_time, transcript_with_end_time,
edit_from_word, edit_to_word, words gr.Dropdown(value=choices[0], choices=choices, interactive=True), # edit_from_word
gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # edit_to_word
words
] ]
@@ -86,11 +87,16 @@ def get_output_audio(audio_tensors, codec_audio_sr):
return buffer.read() return buffer.read()
def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature, def run(voicecraft_model, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
stop_repetition, sample_batch_size, kvcache, silence_tokens, stop_repetition, sample_batch_size, kvcache, silence_tokens,
audio_path, word_info, transcript, smart_transcript, audio_path, word_info, transcript, smart_transcript,
mode, prompt_end_time, edit_start_time, edit_end_time, mode, prompt_end_time, edit_start_time, edit_end_time,
split_text, selected_sentence, previous_audio_tensors): split_text, selected_sentence, previous_audio_tensors):
if voicecraft_model is None:
raise gr.Error("VoiceCraft model not loaded")
if smart_transcript and (word_info is None):
raise gr.Error("Can't use smart transcript: whisper transcript not found")
if mode == "Long TTS": if mode == "Long TTS":
if split_text == "Newline": if split_text == "Newline":
sentences = transcript.split('\n') sentences = transcript.split('\n')
@@ -104,6 +110,7 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe
else: else:
sentences = [transcript.replace("\n", " ")] sentences = [transcript.replace("\n", " ")]
device = "cuda" if torch.cuda.is_available() else "cpu"
info = torchaudio.info(audio_path) info = torchaudio.info(audio_path)
audio_dur = info.num_frames / info.sample_rate audio_dur = info.num_frames / info.sample_rate
@@ -175,8 +182,7 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe
if mode != "Rerun": if mode != "Rerun":
output_audio = get_output_audio(audio_tensors, codec_audio_sr) output_audio = get_output_audio(audio_tensors, codec_audio_sr)
sentences = [f"{idx}: {text}" for idx, text in enumerate(sentences)] sentences = [f"{idx}: {text}" for idx, text in enumerate(sentences)]
component = gr.Dropdown(label="Sentence", choices=sentences, value=sentences[0], component = gr.Dropdown(choices=sentences, value=sentences[0])
info="Select sentence you want to regenerate")
return output_audio, inference_transcript, component, audio_tensors return output_audio, inference_transcript, component, audio_tensors
else: else:
previous_audio_tensors[selected_sentence_idx] = audio_tensors[0] previous_audio_tensors[selected_sentence_idx] = audio_tensors[0]
@@ -185,29 +191,25 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe
return output_audio, inference_transcript, sentence_audio, previous_audio_tensors return output_audio, inference_transcript, sentence_audio, previous_audio_tensors
def update_input_audio(prompt_end_time, edit_start_time, edit_end_time): def update_input_audio(audio_path):
def impl(audio_path): info = torchaudio.info(audio_path)
info = torchaudio.info(audio_path) max_time = round(info.num_frames / info.sample_rate, 2)
max_time = round(info.num_frames / info.sample_rate, 2) return [
return [ gr.Slider(maximum=max_time, value=max_time),
prompt_end_time.update(maximum=max_time, value=max_time), gr.Slider(maximum=max_time, value=0),
edit_start_time.update(maximum=max_time, value=0), gr.Slider(maximum=max_time, value=max_time),
edit_end_time.update(maximum=max_time, value=max_time), ]
]
return impl
def change_mode(prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls): def change_mode(mode):
def impl(mode): return [
return [ gr.Slider(visible=mode != "Edit"),
prompt_end_time.update(visible=mode != "Edit"), gr.Radio(visible=mode == "Long TTS"),
split_text.update(visible=mode == "Long TTS"), gr.Radio(visible=mode == "Edit"),
edit_word_mode.update(visible=mode == "Edit"), gr.Row(visible=mode == "Edit"),
segment_control.update(visible=mode == "Edit"), gr.Accordion(visible=mode == "Edit"),
precise_segment_control.update(visible=mode == "Edit"), gr.Group(visible=mode == "Long TTS"),
long_tts_controls.update(visible=mode == "Long TTS"), ]
]
return impl
def load_sentence(selected_sentence, codec_audio_sr, audio_tensors): def load_sentence(selected_sentence, codec_audio_sr, audio_tensors):
@@ -218,28 +220,27 @@ def load_sentence(selected_sentence, codec_audio_sr, audio_tensors):
return get_output_audio([audio_tensors[selected_sentence_idx]], codec_audio_sr) return get_output_audio([audio_tensors[selected_sentence_idx]], codec_audio_sr)
def update_bound_word(is_first_word, edit_time): def update_bound_word(is_first_word, selected_word, edit_word_mode):
def impl(selected_word, edit_word_mode): if selected_word is None:
word_start_time = float(selected_word.split(' ')[0]) return None
word_end_time = float(selected_word.split(' ')[-1])
if edit_word_mode == "Replace half":
bound_time = (word_start_time + word_end_time) / 2
elif is_first_word:
bound_time = word_start_time
else:
bound_time = word_end_time
return edit_time.update(value=bound_time) word_start_time = float(selected_word.split(' ')[0])
return impl word_end_time = float(selected_word.split(' ')[-1])
if edit_word_mode == "Replace half":
bound_time = (word_start_time + word_end_time) / 2
elif is_first_word:
bound_time = word_start_time
else:
bound_time = word_end_time
return bound_time
def update_bound_words(edit_start_time, edit_end_time): def update_bound_words(from_selected_word, to_selected_word, edit_word_mode):
def impl(from_selected_word, to_selected_word, edit_word_mode): return [
return [ update_bound_word(True, from_selected_word, edit_word_mode),
update_bound_word(True, edit_start_time)(from_selected_word, edit_word_mode), update_bound_word(False, to_selected_word, edit_word_mode),
update_bound_word(True, edit_end_time)(to_selected_word, edit_word_mode), ]
]
return impl
smart_transcript_info = """ smart_transcript_info = """
@@ -251,6 +252,7 @@ If disabled, you should write the target transcript yourself:</br>
- In Long TTS select split by newline (<b>SENTENCE SPLIT WON'T WORK</b>) and start each line with a prompt transcript.</br> - In Long TTS select split by newline (<b>SENTENCE SPLIT WON'T WORK</b>) and start each line with a prompt transcript.</br>
- In Edit mode write full prompt</br> - In Edit mode write full prompt</br>
""" """
demo_text = { demo_text = {
"TTS": { "TTS": {
"smart": "I cannot believe that the same model can also do text to speech synthesis as well!", "smart": "I cannot believe that the same model can also do text to speech synthesis as well!",
@@ -269,7 +271,9 @@ demo_text = {
"But when I had approached so near to them, the common If some sentences sound odd, just rerun TTS on them, no need to generate the whole text again!" "But when I had approached so near to them, the common If some sentences sound odd, just rerun TTS on them, no need to generate the whole text again!"
} }
} }
all_demo_texts = {vv for k, v in demo_text.items() for kk, vv in v.items()} all_demo_texts = {vv for k, v in demo_text.items() for kk, vv in v.items()}
demo_words = [ demo_words = [
'0.0 But 0.12', '0.12 when 0.26', '0.26 I 0.44', '0.44 had 0.6', '0.6 approached 0.94', '0.94 so 1.42', '0.0 But 0.12', '0.12 when 0.26', '0.26 I 0.44', '0.44 had 0.6', '0.6 approached 0.94', '0.94 so 1.42',
'1.42 near 1.78', '1.78 to 2.02', '2.02 them, 2.24', '2.52 the 2.58', '2.58 common 2.9', '2.9 object, 3.3', '1.42 near 1.78', '1.78 to 2.02', '2.02 them, 2.24', '2.52 the 2.58', '2.58 common 2.9', '2.9 object, 3.3',
@@ -278,19 +282,17 @@ demo_words = [
] ]
def update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time): def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time):
def impl(mode, smart_transcript, edit_word_mode): if transcript not in all_demo_texts:
if transcript.value not in all_demo_texts: return transcript, edit_from_word, edit_to_word, prompt_end_time
return [transcript, edit_from_word, edit_to_word, prompt_end_time]
replace_half = edit_word_mode == "Replace half" replace_half = edit_word_mode == "Replace half"
return [ return [
transcript.update(value=demo_text[mode]["smart" if smart_transcript else "regular"]), demo_text[mode]["smart" if smart_transcript else "regular"],
edit_from_word.update(value="0.26 I 0.44" if replace_half else "0.44 had 0.6"), "0.26 I 0.44" if replace_half else "0.44 had 0.6",
edit_to_word.update(value="3.72 which 3.78" if replace_half else "2.9 object, 3.3"), "3.72 which 3.78" if replace_half else "2.9 object, 3.3",
prompt_end_time.update(value=3.01), 3.01,
] ]
return impl
with gr.Blocks() as app: with gr.Blocks() as app:
@@ -302,7 +304,7 @@ with gr.Blocks() as app:
with gr.Row(): with gr.Row():
voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"]) voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"])
whisper_model_choice = gr.Radio(label="Whisper model", value="base.en", whisper_model_choice = gr.Radio(label="Whisper model", value="base.en",
choices=["tiny.en", "base.en", "small.en", "medium.en", "large"]) choices=[None, "tiny.en", "base.en", "small.en", "medium.en", "large"])
with gr.Row(): with gr.Row():
with gr.Column(scale=2): with gr.Column(scale=2):
@@ -315,7 +317,7 @@ with gr.Blocks() as app:
with gr.Accordion("Word end time", open=False): with gr.Accordion("Word end time", open=False):
transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word") transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word")
transcribe_btn = gr.Button(value="Transcribe", interactive=False) transcribe_btn = gr.Button(value="Transcribe")
with gr.Column(scale=3): with gr.Column(scale=3):
with gr.Group(): with gr.Group():
@@ -338,7 +340,7 @@ with gr.Blocks() as app:
edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=60, step=0.01, value=0) edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=60, step=0.01, value=0)
edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=60, step=0.01, value=60) edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=60, step=0.01, value=60)
run_btn = gr.Button(value="Run", interactive=False) run_btn = gr.Button(value="Run")
with gr.Column(scale=2): with gr.Column(scale=2):
output_audio = gr.Audio(label="Output Audio") output_audio = gr.Audio(label="Output Audio")
@@ -349,7 +351,7 @@ with gr.Blocks() as app:
sentence_selector = gr.Dropdown(label="Sentence", value=None, sentence_selector = gr.Dropdown(label="Sentence", value=None,
info="Select sentence you want to regenerate") info="Select sentence you want to regenerate")
sentence_audio = gr.Audio(label="Sentence Audio", scale=2) sentence_audio = gr.Audio(label="Sentence Audio", scale=2)
rerun_btn = gr.Button(value="Rerun", interactive=False) rerun_btn = gr.Button(value="Rerun")
with gr.Row(): with gr.Row():
with gr.Accordion("VoiceCraft config", open=False): with gr.Accordion("VoiceCraft config", open=False):
@@ -369,34 +371,40 @@ with gr.Blocks() as app:
silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]") silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]")
whisper_model = gr.State()
voicecraft_model = gr.State()
audio_tensors = gr.State() audio_tensors = gr.State()
word_info = gr.State() word_info = gr.State()
mode.change(fn=update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time),
inputs=[mode, smart_transcript, edit_word_mode], mode.change(fn=update_demo,
inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time],
outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time]) outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time])
edit_word_mode.change(fn=update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time), edit_word_mode.change(fn=update_demo,
inputs=[mode, smart_transcript, edit_word_mode], inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time],
outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time]) outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time])
smart_transcript.change(fn=update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time), smart_transcript.change(fn=update_demo,
inputs=[mode, smart_transcript, edit_word_mode], inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time],
outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time]) outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time])
load_models_btn.click(fn=load_models(input_audio, transcribe_btn, run_btn, rerun_btn), load_models_btn.click(fn=load_models,
inputs=[whisper_model_choice, voicecraft_model_choice], inputs=[whisper_model_choice, voicecraft_model_choice],
outputs=[input_audio, transcribe_btn, run_btn, rerun_btn]) outputs=[whisper_model, voicecraft_model, input_audio])
input_audio.change(fn=update_input_audio(prompt_end_time, edit_start_time, edit_end_time), input_audio.change(fn=update_input_audio,
inputs=[input_audio], outputs=[prompt_end_time, edit_start_time, edit_end_time]) inputs=[input_audio],
transcribe_btn.click(fn=transcribe, inputs=[input_audio], outputs=[prompt_end_time, edit_start_time, edit_end_time])
transcribe_btn.click(fn=transcribe,
inputs=[whisper_model, input_audio],
outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time, edit_from_word, edit_to_word, word_info]) outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time, edit_from_word, edit_to_word, word_info])
mode.change(fn=change_mode(prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls), mode.change(fn=change_mode,
inputs=[mode], outputs=[prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls]) inputs=[mode],
outputs=[prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls])
run_btn.click(fn=run, run_btn.click(fn=run,
inputs=[ inputs=[
left_margin, right_margin, voicecraft_model, left_margin, right_margin,
codec_audio_sr, codec_sr, codec_audio_sr, codec_sr,
top_k, top_p, temperature, top_k, top_p, temperature,
stop_repetition, sample_batch_size, stop_repetition, sample_batch_size,
@@ -405,14 +413,14 @@ with gr.Blocks() as app:
mode, prompt_end_time, edit_start_time, edit_end_time, mode, prompt_end_time, edit_start_time, edit_end_time,
split_text, sentence_selector, audio_tensors split_text, sentence_selector, audio_tensors
], ],
outputs=[ outputs=[output_audio, inference_transcript, sentence_selector, audio_tensors])
output_audio, inference_transcript, sentence_selector, audio_tensors
])
sentence_selector.change(fn=load_sentence, inputs=[sentence_selector, codec_audio_sr, audio_tensors], outputs=[sentence_audio]) sentence_selector.change(fn=load_sentence,
inputs=[sentence_selector, codec_audio_sr, audio_tensors],
outputs=[sentence_audio])
rerun_btn.click(fn=run, rerun_btn.click(fn=run,
inputs=[ inputs=[
left_margin, right_margin, voicecraft_model, left_margin, right_margin,
codec_audio_sr, codec_sr, codec_audio_sr, codec_sr,
top_k, top_p, temperature, top_k, top_p, temperature,
stop_repetition, sample_batch_size, stop_repetition, sample_batch_size,
@@ -421,16 +429,17 @@ with gr.Blocks() as app:
gr.State(value="Rerun"), prompt_end_time, edit_start_time, edit_end_time, gr.State(value="Rerun"), prompt_end_time, edit_start_time, edit_end_time,
split_text, sentence_selector, audio_tensors split_text, sentence_selector, audio_tensors
], ],
outputs=[ outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors])
output_audio, inference_transcript, sentence_audio, audio_tensors
])
edit_word_mode.change(fn=update_bound_words(edit_start_time, edit_end_time), edit_from_word.change(fn=update_bound_word,
inputs=[edit_from_word, edit_to_word, edit_word_mode], outputs=[edit_start_time, edit_end_time]) inputs=[gr.State(True), edit_from_word, edit_word_mode],
edit_from_word.change(fn=update_bound_word(True, edit_start_time), outputs=[edit_start_time])
inputs=[edit_from_word, edit_word_mode], outputs=[edit_start_time]) edit_to_word.change(fn=update_bound_word,
edit_to_word.change(fn=update_bound_word(False, edit_end_time), inputs=[gr.State(False), edit_to_word, edit_word_mode],
inputs=[edit_to_word, edit_word_mode], outputs=[edit_end_time]) outputs=[edit_end_time])
edit_word_mode.change(fn=update_bound_words,
inputs=[edit_from_word, edit_to_word, edit_word_mode],
outputs=[edit_start_time, edit_end_time])
if __name__ == "__main__": if __name__ == "__main__":