From 1a219cf6da69df070cf03966f06433c9a0362de2 Mon Sep 17 00:00:00 2001
From: Stepan Zuev
Date: Wed, 3 Apr 2024 20:24:34 +0300
Subject: [PATCH] bugfixes, seed support, better ui

---
 gradio_app.py | 136 ++++++++++++++++++++++++++++++++++----------------
 1 file changed, 93 insertions(+), 43 deletions(-)

diff --git a/gradio_app.py b/gradio_app.py
index 4124444..707defa 100644
--- a/gradio_app.py
+++ b/gradio_app.py
@@ -8,11 +8,24 @@ from data.tokenizer import (
 from models import voicecraft
 import os
 import io
+import numpy as np
+import random
 
 whisper_model, voicecraft_model = None, None
 
+def seed_everything(seed):
+    if seed != -1:
+        os.environ['PYTHONHASHSEED'] = str(seed)
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+
+
 def load_models(whisper_model_choice, voicecraft_model_choice):
     global whisper_model, voicecraft_model
 
@@ -50,12 +63,13 @@ def load_models(whisper_model_choice, voicecraft_model_choice):
         "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
     }
 
-    return gr.Audio(interactive=True)
+    return gr.Accordion()
 
 
-def transcribe(audio_path):
+def transcribe(seed, audio_path):
     if whisper_model is None:
         raise gr.Error("Whisper model not loaded")
+    seed_everything(seed)
 
     number_tokens = [
         i
@@ -73,6 +87,7 @@ def transcribe(audio_path):
     return [
         transcript, transcript_with_start_time, transcript_with_end_time,
+        gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # prompt_to_word
         gr.Dropdown(value=choices[0], choices=choices, interactive=True), # edit_from_word
         gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # edit_to_word
         words
@@ -87,7 +102,7 @@ def get_output_audio(audio_tensors, codec_audio_sr):
     return buffer.read()
 
 
-def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
+def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
         stop_repetition, sample_batch_size, kvcache, silence_tokens,
         audio_path, word_info, transcript, smart_transcript,
         mode, prompt_end_time, edit_start_time, edit_end_time,
@@ -97,6 +112,7 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe
     if smart_transcript and (word_info is None):
         raise gr.Error("Can't use smart transcript: whisper transcript not found")
 
+    seed_everything(seed)
     if mode == "Long TTS":
         if split_text == "Newline":
             sentences = transcript.split('\n')
@@ -192,6 +208,9 @@ def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, tempe
 
 def update_input_audio(audio_path):
+    if audio_path is None:
+        return 0, 0, 0
+
     info = torchaudio.info(audio_path)
     max_time = round(info.num_frames / info.sample_rate, 2)
     return [
@@ -202,12 +221,12 @@ def update_input_audio(audio_path):
 
 
 def change_mode(mode):
+    # tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor
     return [
-        gr.Slider(visible=mode != "Edit"),
-        gr.Radio(visible=mode == "Long TTS"),
+        gr.Group(visible=mode != "Edit"),
+        gr.Group(visible=mode == "Edit"),
         gr.Radio(visible=mode == "Edit"),
-        gr.Row(visible=mode == "Edit"),
-        gr.Accordion(visible=mode == "Edit"),
+        gr.Radio(visible=mode == "Long TTS"),
         gr.Group(visible=mode == "Long TTS"),
     ]
 
@@ -253,6 +272,8 @@ If disabled, you should write the target transcript yourself:<br>
- In Edit mode write full prompt
""" +demo_original_transcript = " But when I had approached so near to them, the common object, which the sense deceives, lost not by distance any of its marks." + demo_text = { "TTS": { "smart": "I cannot believe that the same model can also do text to speech synthesis as well!", @@ -281,17 +302,35 @@ demo_words = [ '5.74 by 6.08', '6.08 distance 6.36', '6.36 any 6.92', '6.92 of 7.12', '7.12 its 7.26', '7.26 marks. 7.54' ] +demo_word_info = [ + {'word': ' But', 'start': 0.0, 'end': 0.12}, {'word': ' when', 'start': 0.12, 'end': 0.26}, + {'word': ' I', 'start': 0.26, 'end': 0.44}, {'word': ' had', 'start': 0.44, 'end': 0.6}, + {'word': ' approached', 'start': 0.6, 'end': 0.94}, {'word': ' so', 'start': 0.94, 'end': 1.42}, + {'word': ' near', 'start': 1.42, 'end': 1.78}, {'word': ' to', 'start': 1.78, 'end': 2.02}, + {'word': ' them,', 'start': 2.02, 'end': 2.24}, {'word': ' the', 'start': 2.52, 'end': 2.58}, + {'word': ' common', 'start': 2.58, 'end': 2.9}, {'word': ' object,', 'start': 2.9, 'end': 3.3}, + {'word': ' which', 'start': 3.72, 'end': 3.78}, {'word': ' the', 'start': 3.78, 'end': 3.98}, + {'word': ' sense', 'start': 3.98, 'end': 4.18}, {'word': ' deceives,', 'start': 4.18, 'end': 4.88}, + {'word': ' lost', 'start': 5.06, 'end': 5.26}, {'word': ' not', 'start': 5.26, 'end': 5.74}, + {'word': ' by', 'start': 5.74, 'end': 6.08}, {'word': ' distance', 'start': 6.08, 'end': 6.36}, + {'word': ' any', 'start': 6.36, 'end': 6.92}, {'word': ' of', 'start': 6.92, 'end': 7.12}, + {'word': ' its', 'start': 7.12, 'end': 7.26}, {'word': ' marks.', 'start': 7.26, 'end': 7.54} +] -def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time): + +def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word): if transcript not in all_demo_texts: - return transcript, edit_from_word, edit_to_word, prompt_end_time + return transcript, edit_from_word, edit_to_word replace_half = edit_word_mode == "Replace half" + change_edit_from_word = edit_from_word == demo_words[2] or edit_from_word == demo_words[3] + change_edit_to_word = edit_to_word == demo_words[11] or edit_to_word == demo_words[12] + demo_edit_from_word_value = demo_words[2] if replace_half else demo_words[3] + demo_edit_to_word_value = demo_words[12] if replace_half else demo_words[11] return [ demo_text[mode]["smart" if smart_transcript else "regular"], - "0.26 I 0.44" if replace_half else "0.44 had 0.6", - "3.72 which 3.78" if replace_half else "2.9 object, 3.3", - 3.01, + demo_edit_from_word_value if change_edit_from_word else edit_from_word, + demo_edit_to_word_value if change_edit_to_word else edit_to_word, ] @@ -300,7 +339,7 @@ with gr.Blocks() as app: with gr.Column(scale=2): load_models_btn = gr.Button(value="Load models") with gr.Column(scale=5): - with gr.Accordion("Select models", open=False): + with gr.Accordion("Select models", open=False) as models_selector: with gr.Row(): voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"]) whisper_model_choice = gr.Radio(label="Whisper model", value="base.en", @@ -308,9 +347,9 @@ with gr.Blocks() as app: with gr.Row(): with gr.Column(scale=2): - input_audio = gr.Audio(value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=False) + input_audio = gr.Audio(value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath") with gr.Group(): - original_transcript = gr.Textbox(label="Original 
transcript", lines=5, interactive=False, + original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript, interactive=False, info="Use whisper model to get the transcript. Fix it if necessary.") with gr.Accordion("Word start time", open=False): transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word") @@ -325,20 +364,26 @@ with gr.Blocks() as app: with gr.Row(): smart_transcript = gr.Checkbox(label="Smart transcript", value=True) with gr.Accordion(label="?", open=False): - info = gr.HTML(value=smart_transcript_info) - mode = gr.Radio(label="Mode", choices=["TTS", "Edit", "Long TTS"], value="TTS") + info = gr.Markdown(value=smart_transcript_info) + + with gr.Row(): + mode = gr.Radio(label="Mode", choices=["TTS", "Edit", "Long TTS"], value="TTS") + split_text = gr.Radio(label="Split text", choices=["Newline", "Sentence"], value="Newline", + info="Split text into parts and run TTS for each part.", visible=False) + edit_word_mode = gr.Radio(label="Edit word mode", choices=["Replace half", "Replace all"], value="Replace half", + info="What to do with first and last word", visible=False) - prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.01, value=3.01) - split_text = gr.Radio(label="Split text", choices=["Newline", "Sentence"], value="Newline", visible=False, - info="Split text into parts and run TTS for each part.") - edit_word_mode = gr.Radio(label="Edit word mode", choices=["Replace half", "Replace all"], value="Replace half", visible=False, - info="What to do with first and last word") - with gr.Row(visible=False) as segment_control: - edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, interactive=True) - edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, interactive=True) - with gr.Accordion("Precise segment control", open=False, visible=False) as precise_segment_control: - edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=60, step=0.01, value=0) - edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=60, step=0.01, value=60) + with gr.Group() as tts_mode_controls: + prompt_to_word = gr.Dropdown(label="Last word in prompt", choices=demo_words, value=demo_words[10], interactive=True) + prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.01, value=3.01) + + with gr.Group(visible=False) as edit_mode_controls: + with gr.Row(): + edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, value=demo_words[2], interactive=True) + edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, value=demo_words[12], interactive=True) + with gr.Row(): + edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.93, step=0.01, value=0.35) + edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.93, step=0.01, value=3.75) run_btn = gr.Button(value="Run") @@ -347,7 +392,7 @@ with gr.Blocks() as app: with gr.Accordion("Inference transcript", open=False): inference_transcript = gr.Textbox(label="Inference transcript", lines=5, interactive=False, info="Inference was performed on this transcript.") - with gr.Group(visible=False) as long_tts_controls: + with gr.Group(visible=False) as long_tts_sentence_editor: sentence_selector = gr.Dropdown(label="Sentence", value=None, info="Select sentence you want to regenerate") sentence_audio = gr.Audio(label="Sentence Audio", scale=2) @@ -355,6 +400,7 @@ with 
 
     with gr.Row():
         with gr.Accordion("VoiceCraft config", open=False):
+            seed = gr.Number(label="seed", value=-1, precision=0)
             left_margin = gr.Number(label="left_margin", value=0.08)
             right_margin = gr.Number(label="right_margin", value=0.08)
             codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000)
@@ -372,37 +418,38 @@ with gr.Blocks() as app:
 
     audio_tensors = gr.State()
-    word_info = gr.State()
+    word_info = gr.State(value=demo_word_info)
 
     mode.change(fn=update_demo,
-                inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time],
-                outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time])
+                inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
+                outputs=[transcript, edit_from_word, edit_to_word])
     edit_word_mode.change(fn=update_demo,
-                          inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time],
-                          outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time])
+                          inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
+                          outputs=[transcript, edit_from_word, edit_to_word])
     smart_transcript.change(fn=update_demo,
-                            inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word, prompt_end_time],
-                            outputs=[transcript, edit_from_word, edit_to_word, prompt_end_time])
+                            inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
+                            outputs=[transcript, edit_from_word, edit_to_word])
 
     load_models_btn.click(fn=load_models,
                           inputs=[whisper_model_choice, voicecraft_model_choice],
-                          outputs=[input_audio])
+                          outputs=[models_selector])
 
-    input_audio.change(fn=update_input_audio,
+    input_audio.upload(fn=update_input_audio,
                        inputs=[input_audio],
                        outputs=[prompt_end_time, edit_start_time, edit_end_time])
     transcribe_btn.click(fn=transcribe,
-                         inputs=[input_audio],
-                         outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time, edit_from_word, edit_to_word, word_info])
+                         inputs=[seed, input_audio],
+                         outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time,
+                                  prompt_to_word, edit_from_word, edit_to_word, word_info])
 
     mode.change(fn=change_mode,
                 inputs=[mode],
-                outputs=[prompt_end_time, split_text, edit_word_mode, segment_control, precise_segment_control, long_tts_controls])
+                outputs=[tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor])
 
     run_btn.click(fn=run,
                   inputs=[
-                      left_margin, right_margin,
+                      seed, left_margin, right_margin,
                       codec_audio_sr, codec_sr,
                       top_k, top_p, temperature,
                       stop_repetition, sample_batch_size,
@@ -418,7 +465,7 @@ with gr.Blocks() as app:
                      outputs=[sentence_audio])
     rerun_btn.click(fn=run,
                     inputs=[
-                        left_margin, right_margin,
+                        seed, left_margin, right_margin,
                         codec_audio_sr, codec_sr,
                         top_k, top_p, temperature,
                         stop_repetition, sample_batch_size,
@@ -429,6 +476,9 @@ with gr.Blocks() as app:
                     ],
                     outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors])
 
+    prompt_to_word.change(fn=update_bound_word,
+                          inputs=[gr.State(False), prompt_to_word, gr.State("Replace all")],
+                          outputs=[prompt_end_time])
     edit_from_word.change(fn=update_bound_word,
                           inputs=[gr.State(True), edit_from_word, edit_word_mode],
                           outputs=[edit_start_time])
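
Note on the seeding behavior introduced above: with the UI default `seed = -1` no RNG is touched and generation stays nondeterministic; any other value pins Python's, NumPy's and torch's generators, so two runs with the same seed follow the same sampling path. A minimal standalone sketch of that contract (a trimmed-down copy of the patch's `seed_everything`, with the CUDA/cuDNN lines dropped so it also runs on CPU-only machines):

    import os
    import random

    import numpy as np
    import torch

    def seed_everything(seed):
        # Same convention as the patch: -1 means "do not seed anything".
        if seed != -1:
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    seed_everything(42)
    first = torch.rand(3)
    seed_everything(42)
    second = torch.rand(3)
    assert torch.equal(first, second)  # same seed, same draws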