diff --git a/gradio_app.py b/gradio_app.py index cab37c4..4124444 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -10,8 +10,12 @@ import os import io +whisper_model, voicecraft_model = None, None + + def load_models(whisper_model_choice, voicecraft_model_choice): - whisper_model, voicecraft_model = None, None + global whisper_model, voicecraft_model + if whisper_model_choice is not None: import whisper from whisper.tokenizer import get_tokenizer @@ -46,14 +50,10 @@ def load_models(whisper_model_choice, voicecraft_model_choice): "audio_tokenizer": AudioTokenizer(signature=encodec_fn) } - return [ - whisper_model, - voicecraft_model, - gr.Audio(interactive=True), - ] + return gr.Audio(interactive=True) -def transcribe(whisper_model, audio_path): +def transcribe(audio_path): if whisper_model is None: raise gr.Error("Whisper model not loaded") @@ -87,7 +87,7 @@ def get_output_audio(audio_tensors, codec_audio_sr): return buffer.read() -def run(voicecraft_model, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature, +def run(left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature, stop_repetition, sample_batch_size, kvcache, silence_tokens, audio_path, word_info, transcript, smart_transcript, mode, prompt_end_time, edit_start_time, edit_end_time, @@ -371,8 +371,6 @@ with gr.Blocks() as app: silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]") - whisper_model = gr.State() - voicecraft_model = gr.State() audio_tensors = gr.State() word_info = gr.State() @@ -389,13 +387,13 @@ with gr.Blocks() as app: load_models_btn.click(fn=load_models, inputs=[whisper_model_choice, voicecraft_model_choice], - outputs=[whisper_model, voicecraft_model, input_audio]) + outputs=[input_audio]) input_audio.change(fn=update_input_audio, inputs=[input_audio], outputs=[prompt_end_time, edit_start_time, edit_end_time]) transcribe_btn.click(fn=transcribe, - inputs=[whisper_model, input_audio], + inputs=[input_audio], outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time, edit_from_word, edit_to_word, word_info]) mode.change(fn=change_mode, @@ -404,7 +402,7 @@ with gr.Blocks() as app: run_btn.click(fn=run, inputs=[ - voicecraft_model, left_margin, right_margin, + left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature, stop_repetition, sample_batch_size, @@ -420,7 +418,7 @@ with gr.Blocks() as app: outputs=[sentence_audio]) rerun_btn.click(fn=run, inputs=[ - voicecraft_model, left_margin, right_margin, + left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature, stop_repetition, sample_batch_size,