diff --git a/gradio_app.py b/gradio_app.py index 956175c..cab37c4 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -6,62 +6,63 @@ from data.tokenizer import ( TextTokenizer, ) from models import voicecraft -import whisper -from whisper.tokenizer import get_tokenizer import os import io -whisper_model = None -voicecraft_model = None -device = "cuda" if torch.cuda.is_available() else "cpu" +def load_models(whisper_model_choice, voicecraft_model_choice): + whisper_model, voicecraft_model = None, None + if whisper_model_choice is not None: + import whisper + from whisper.tokenizer import get_tokenizer + whisper_model = { + "model": whisper.load_model(whisper_model_choice), + "tokenizer": get_tokenizer(multilingual=False) + } + + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + device = "cuda" if torch.cuda.is_available() else "cpu" + + voicecraft_name = f"{voicecraft_model_choice}.pth" + ckpt_fn = f"./pretrained_models/{voicecraft_name}" + encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th" + if not os.path.exists(ckpt_fn): + os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true") + os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}") + if not os.path.exists(encodec_fn): + os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th") + os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th") + + ckpt = torch.load(ckpt_fn, map_location="cpu") + model = voicecraft.VoiceCraft(ckpt["config"]) + model.load_state_dict(ckpt["model"]) + model.to(device) + model.eval() + voicecraft_model = { + "ckpt": ckpt, + "model": model, + "text_tokenizer": TextTokenizer(backend="espeak"), + "audio_tokenizer": AudioTokenizer(signature=encodec_fn) + } + + return [ + whisper_model, + voicecraft_model, + gr.Audio(interactive=True), + ] -def load_models(input_audio, transcribe_btn, run_btn, rerun_btn): - def impl(whisper_model_choice, voicecraft_model_choice): - global whisper_model, voicecraft_model - whisper_model = whisper.load_model(whisper_model_choice) - - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - os.environ["CUDA_VISIBLE_DEVICES"] = "0" - - voicecraft_name = f"{voicecraft_model_choice}.pth" - ckpt_fn = f"./pretrained_models/{voicecraft_name}" - encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th" - if not os.path.exists(ckpt_fn): - os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true") - os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}") - if not os.path.exists(encodec_fn): - os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th") - os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th") - - voicecraft_model = {} - voicecraft_model["ckpt"] = torch.load(ckpt_fn, map_location="cpu") - voicecraft_model["model"] = voicecraft.VoiceCraft(voicecraft_model["ckpt"]["config"]) - voicecraft_model["model"].load_state_dict(voicecraft_model["ckpt"]["model"]) - voicecraft_model["model"].to(device) - voicecraft_model["model"].eval() - - voicecraft_model["text_tokenizer"] = TextTokenizer(backend="espeak") - voicecraft_model["audio_tokenizer"] = AudioTokenizer(signature=encodec_fn) - - return [ - input_audio.update(interactive=True), - transcribe_btn.update(interactive=True), - run_btn.update(interactive=True), - rerun_btn.update(interactive=True) - ] - return impl - - -def transcribe(audio_path): - tokenizer = get_tokenizer(multilingual=False) +def transcribe(whisper_model, audio_path): + if whisper_model is None: + raise gr.Error("Whisper model not loaded") + number_tokens = [ i - for i in range(tokenizer.eot) - if all(c in "0123456789" for c in tokenizer.decode([i]).removeprefix(" ")) + for i in range(whisper_model["tokenizer"].eot) + if all(c in "0123456789" for c in whisper_model["tokenizer"].decode([i]).removeprefix(" ")) ] - result = whisper_model.transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True) + result = whisper_model["model"].transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True) words = [word_info for segment in result["segments"] for word_info in segment["words"]] transcript = result["text"] @@ -69,12 +70,12 @@ def transcribe(audio_path): transcript_with_end_time = " ".join([f"{word['word']} {word['end']}" for word in words]) choices = [f"{word['start']} {word['word']} {word['end']}" for word in words] - edit_from_word = gr.Dropdown(label="First word to edit", value=choices[0], choices=choices, interactive=True) - edit_to_word = gr.Dropdown(label="Last word to edit", value=choices[-1], choices=choices, interactive=True) return [ transcript, transcript_with_start_time, transcript_with_end_time, - edit_from_word, edit_to_word, words + gr.Dropdown(value=choices[0], choices=choices, interactive=True), # edit_from_word + gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # edit_to_word + words ] @@ -86,11 +87,16 @@ def get_output_audio(audio_tensors, codec_audio_sr): return buffer.read() -def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature, +def run(voicecraft_model, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature, stop_repetition, sample_batch_size, kvcache, silence_tokens, audio_path, word_info, transcript, smart_transcript, mode, prompt_end_time, edit_start_time, edit_end_time, split_text, selected_sentence, previous_audio_tensors): + if voicecraft_model is None: + raise gr.Error("VoiceCraft model not loaded") + if smart_transcript and (word_info is None): + raise gr.Error("Can't use smart transcript: whisper transcript not found") + if mode == "Long TTS": if split_text == "Newline": sentences = transcript.split('\n') @@ -104,6 +110,7 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe else: sentences = [transcript.replace("\n", " ")] + device = "cuda" if torch.cuda.is_available() else "cpu" info = torchaudio.info(audio_path) audio_dur = info.num_frames / info.sample_rate @@ -116,7 +123,7 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe if mode != "Edit": from inference_tts_scale import inference_one_sample - if smart_transcript: + if smart_transcript: target_transcript = "" for word in word_info: if word["end"] < prompt_end_time: @@ -175,8 +182,7 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe if mode != "Rerun": output_audio = get_output_audio(audio_tensors, codec_audio_sr) sentences = [f"{idx}: {text}" for idx, text in enumerate(sentences)] - component = gr.Dropdown(label="Sentence", choices=sentences, value=sentences[0], - info="Select sentence you want to regenerate") + component = gr.Dropdown(choices=sentences, value=sentences[0]) return output_audio, inference_transcript, component, audio_tensors else: previous_audio_tensors[selected_sentence_idx] = audio_tensors[0] @@ -185,29 +191,25 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe return output_audio, inference_transcript, sentence_audio, previous_audio_tensors -def update_input_audio(prompt_end_time, edit_start_time, edit_end_time): - def impl(audio_path): - info = torchaudio.info(audio_path) - max_time = round(info.num_frames / info.sample_rate, 2) - return [ - prompt_end_time.update(maximum=max_time, value=max_time), - edit_start_time.update(maximum=max_time, value=0), - edit_end_time.update(maximum=max_time, value=max_time), - ] - return impl +def update_input_audio(audio_path): + info = torchaudio.info(audio_path) + max_time = round(info.num_frames / info.sample_rate, 2) + return [ + gr.Slider(maximum=max_time, value=max_time), + gr.Slider(maximum=max_time, value=0), + gr.Slider(maximum=max_time, value=max_time), + ] -def change_mode(prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls): - def impl(mode): - return [ - prompt_end_time.update(visible=mode != "Edit"), - split_text.update(visible=mode == "Long TTS"), - edit_word_mode.update(visible=mode == "Edit"), - segment_control.update(visible=mode == "Edit"), - precise_segment_control.update(visible=mode == "Edit"), - long_tts_controls.update(visible=mode == "Long TTS"), - ] - return impl +def change_mode(mode): + return [ + gr.Slider(visible=mode != "Edit"), + gr.Radio(visible=mode == "Long TTS"), + gr.Radio(visible=mode == "Edit"), + gr.Row(visible=mode == "Edit"), + gr.Accordion(visible=mode == "Edit"), + gr.Group(visible=mode == "Long TTS"), + ] def load_sentence(selected_sentence, codec_audio_sr, audio_tensors): @@ -218,28 +220,27 @@ def load_sentence(selected_sentence, codec_audio_sr, audio_tensors): return get_output_audio([audio_tensors[selected_sentence_idx]], codec_audio_sr) -def update_bound_word(is_first_word, edit_time): - def impl(selected_word, edit_word_mode): - word_start_time = float(selected_word.split(' ')[0]) - word_end_time = float(selected_word.split(' ')[-1]) - if edit_word_mode == "Replace half": - bound_time = (word_start_time + word_end_time) / 2 - elif is_first_word: - bound_time = word_start_time - else: - bound_time = word_end_time +def update_bound_word(is_first_word, selected_word, edit_word_mode): + if selected_word is None: + return None - return edit_time.update(value=bound_time) - return impl + word_start_time = float(selected_word.split(' ')[0]) + word_end_time = float(selected_word.split(' ')[-1]) + if edit_word_mode == "Replace half": + bound_time = (word_start_time + word_end_time) / 2 + elif is_first_word: + bound_time = word_start_time + else: + bound_time = word_end_time + + return bound_time -def update_bound_words(edit_start_time, edit_end_time): - def impl(from_selected_word, to_selected_word, edit_word_mode): - return [ - update_bound_word(True, edit_start_time)(from_selected_word, edit_word_mode), - update_bound_word(True, edit_end_time)(to_selected_word, edit_word_mode), - ] - return impl +def update_bound_words(from_selected_word, to_selected_word, edit_word_mode): + return [ + update_bound_word(True, from_selected_word, edit_word_mode), + update_bound_word(False, to_selected_word, edit_word_mode), + ] smart_transcript_info = """ @@ -251,6 +252,7 @@ If disabled, you should write the target transcript yourself:
- In Long TTS select split by newline (SENTENCE SPLIT WON'T WORK) and start each line with a prompt transcript.
- In Edit mode write full prompt
""" + demo_text = { "TTS": { "smart": "I cannot believe that the same model can also do text to speech synthesis as well!", @@ -269,7 +271,9 @@ demo_text = { "But when I had approached so near to them, the common If some sentences sound odd, just rerun TTS on them, no need to generate the whole text again!" } } + all_demo_texts = {vv for k, v in demo_text.items() for kk, vv in v.items()} + demo_words = [ '0.0 But 0.12', '0.12 when 0.26', '0.26 I 0.44', '0.44 had 0.6', '0.6 approached 0.94', '0.94 so 1.42', '1.42 near 1.78', '1.78 to 2.02', '2.02 them, 2.24', '2.52 the 2.58', '2.58 common 2.9', '2.9 object, 3.3', @@ -278,19 +282,17 @@ demo_words = [ ] -def update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time): - def impl(mode, smart_transcript, edit_word_mode): - if transcript.value not in all_demo_texts: - return [transcript, edit_from_word, edit_to_word, prompt_end_time] - - replace_half = edit_word_mode == "Replace half" - return [ - transcript.update(value=demo_text[mode]["smart" if smart_transcript else "regular"]), - edit_from_word.update(value="0.26 I 0.44" if replace_half else "0.44 had 0.6"), - edit_to_word.update(value="3.72 which 3.78" if replace_half else "2.9 object, 3.3"), - prompt_end_time.update(value=3.01), - ] - return impl +def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time): + if transcript not in all_demo_texts: + return transcript, edit_from_word, edit_to_word, prompt_end_time + + replace_half = edit_word_mode == "Replace half" + return [ + demo_text[mode]["smart" if smart_transcript else "regular"], + "0.26 I 0.44" if replace_half else "0.44 had 0.6", + "3.72 which 3.78" if replace_half else "2.9 object, 3.3", + 3.01, + ] with gr.Blocks() as app: @@ -302,7 +304,7 @@ with gr.Blocks() as app: with gr.Row(): voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"]) whisper_model_choice = gr.Radio(label="Whisper model", value="base.en", - choices=["tiny.en", "base.en", "small.en", "medium.en", "large"]) + choices=[None, "tiny.en", "base.en", "small.en", "medium.en", "large"]) with gr.Row(): with gr.Column(scale=2): @@ -315,7 +317,7 @@ with gr.Blocks() as app: with gr.Accordion("Word end time", open=False): transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word") - transcribe_btn = gr.Button(value="Transcribe", interactive=False) + transcribe_btn = gr.Button(value="Transcribe") with gr.Column(scale=3): with gr.Group(): @@ -338,7 +340,7 @@ with gr.Blocks() as app: edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=60, step=0.01, value=0) edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=60, step=0.01, value=60) - run_btn = gr.Button(value="Run", interactive=False) + run_btn = gr.Button(value="Run") with gr.Column(scale=2): output_audio = gr.Audio(label="Output Audio") @@ -349,7 +351,7 @@ with gr.Blocks() as app: sentence_selector = gr.Dropdown(label="Sentence", value=None, info="Select sentence you want to regenerate") sentence_audio = gr.Audio(label="Sentence Audio", scale=2) - rerun_btn = gr.Button(value="Rerun", interactive=False) + rerun_btn = gr.Button(value="Rerun") with gr.Row(): with gr.Accordion("VoiceCraft config", open=False): @@ -369,34 +371,40 @@ with gr.Blocks() as app: silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]") + whisper_model = gr.State() + voicecraft_model = gr.State() audio_tensors = gr.State() word_info = gr.State() - mode.change(fn=update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time), - inputs=[mode, smart_transcript, edit_word_mode], + + mode.change(fn=update_demo, + inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time], outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time]) - edit_word_mode.change(fn=update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time), - inputs=[mode, smart_transcript, edit_word_mode], + edit_word_mode.change(fn=update_demo, + inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time], outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time]) - smart_transcript.change(fn=update_demo(transcript, edit_from_word, edit_to_word, prompt_end_time), - inputs=[mode, smart_transcript, edit_word_mode], + smart_transcript.change(fn=update_demo, + inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time], outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time]) - load_models_btn.click(fn=load_models(input_audio, transcribe_btn, run_btn, rerun_btn), + load_models_btn.click(fn=load_models, inputs=[whisper_model_choice, voicecraft_model_choice], - outputs=[input_audio, transcribe_btn, run_btn, rerun_btn]) - - input_audio.change(fn=update_input_audio(prompt_end_time, edit_start_time, edit_end_time), - inputs=[input_audio], outputs=[prompt_end_time, edit_start_time, edit_end_time]) - transcribe_btn.click(fn=transcribe, inputs=[input_audio], + outputs=[whisper_model, voicecraft_model, input_audio]) + + input_audio.change(fn=update_input_audio, + inputs=[input_audio], + outputs=[prompt_end_time, edit_start_time, edit_end_time]) + transcribe_btn.click(fn=transcribe, + inputs=[whisper_model, input_audio], outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time, edit_from_word, edit_to_word, word_info]) - mode.change(fn=change_mode(prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls), - inputs=[mode], outputs=[prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls]) + mode.change(fn=change_mode, + inputs=[mode], + outputs=[prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls]) run_btn.click(fn=run, inputs=[ - left_margin, right_margin, + voicecraft_model, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature, stop_repetition, sample_batch_size, @@ -405,14 +413,14 @@ with gr.Blocks() as app: mode, prompt_end_time, edit_start_time, edit_end_time, split_text, sentence_selector, audio_tensors ], - outputs=[ - output_audio, inference_transcript, sentence_selector, audio_tensors - ]) + outputs=[output_audio, inference_transcript, sentence_selector, audio_tensors]) - sentence_selector.change(fn=load_sentence, inputs=[sentence_selector, codec_audio_sr, audio_tensors], outputs=[sentence_audio]) + sentence_selector.change(fn=load_sentence, + inputs=[sentence_selector, codec_audio_sr, audio_tensors], + outputs=[sentence_audio]) rerun_btn.click(fn=run, inputs=[ - left_margin, right_margin, + voicecraft_model, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature, stop_repetition, sample_batch_size, @@ -421,16 +429,17 @@ with gr.Blocks() as app: gr.State(value="Rerun"), prompt_end_time, edit_start_time, edit_end_time, split_text, sentence_selector, audio_tensors ], - outputs=[ - output_audio, inference_transcript, sentence_audio, audio_tensors - ]) + outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors]) - edit_word_mode.change(fn=update_bound_words(edit_start_time, edit_end_time), - inputs=[edit_from_word, edit_to_word, edit_word_mode], outputs=[edit_start_time, edit_end_time]) - edit_from_word.change(fn=update_bound_word(True, edit_start_time), - inputs=[edit_from_word, edit_word_mode], outputs=[edit_start_time]) - edit_to_word.change(fn=update_bound_word(False, edit_end_time), - inputs=[edit_to_word, edit_word_mode], outputs=[edit_end_time]) + edit_from_word.change(fn=update_bound_word, + inputs=[gr.State(True), edit_from_word, edit_word_mode], + outputs=[edit_start_time]) + edit_to_word.change(fn=update_bound_word, + inputs=[gr.State(False), edit_to_word, edit_word_mode], + outputs=[edit_end_time]) + edit_word_mode.change(fn=update_bound_words, + inputs=[edit_from_word, edit_to_word, edit_word_mode], + outputs=[edit_start_time, edit_end_time]) if __name__ == "__main__":