Files
AnimateDiff/webui.py
2023-07-16 22:08:25 -04:00

308 lines
17 KiB
Python

import argparse
import gradio as gr
import copy
import os
import glob
import utils as help
sep = '\\' if help.is_windows() else '/'
all_motion_model_opts = ["mm_sd_v14.ckpt", "mm_sd_v15.ckpt"]
def get_available_motion_models():
motion_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "Motion_Module"))
motion_model_opts = sorted([ckpt for ckpt in glob.glob(os.path.join(motion_model_opts_path, f"*.ckpt"))])
return motion_model_opts
def get_available_sd_models():
sd_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "StableDiffusion"))
sd_model_opts = sorted([safetensor.split(sep)[-1] for safetensor in glob.glob(os.path.join(sd_model_opts_path, f"*.safetensors"))])
return sd_model_opts
def get_available_db_models():
db_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "DreamBooth_LoRA"))
db_model_opts = sorted([safetensor.split(sep)[-1] for safetensor in glob.glob(os.path.join(db_model_opts_path, f"*.safetensors"))])
return db_model_opts
def get_db_config():
prompt_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "prompts")
return sorted([(prompt_yaml.split(sep)[-1]) for prompt_yaml in glob.glob(os.path.join(prompt_configs_path, f"*.yaml"))])
def get_sd_config():
inference_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "inference")
return sorted([(inference_yaml.split(sep)[-1]) for inference_yaml in glob.glob(os.path.join(inference_configs_path, f"*.yaml"))])
def set_motion_model(menu_opt: gr.SelectData):
model_name = menu_opt.value
motion_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "Motion_Module"))
motion_model_opts = sorted([ckpt for ckpt in glob.glob(os.path.join(motion_model_opts_path, f"*.ckpt"))])
motion_model_map = {"mm_sd_v14.ckpt": "1RqkQuGPaCO5sGZ6V6KZ-jUWmsRu48Kdq",
"mm_sd_v15.ckpt": "1ql0g_Ys4UCz2RnokYlBjyOYPbttbIpbu"}
if not os.path.join(motion_model_opts_path, model_name) in motion_model_opts: # download
help.download_from_drive_gdown(motion_model_map[model_name], os.path.join(motion_model_opts_path, model_name))
return gr.update(value=os.path.join(motion_model_opts_path, model_name)) # model path
def set_sd_model(menu_opt: gr.SelectData):
model_name = menu_opt.value
sd_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "StableDiffusion"))
return gr.update(value=os.path.join(sd_model_opts_path, model_name)), gr.update(value=os.path.join(sd_model_opts_path, model_name)) # sd path, pretrained path
def set_db_model(menu_opt: gr.SelectData):
model_name = menu_opt.value
db_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "DreamBooth_LoRA"))
return gr.update(value=os.path.join(db_model_opts_path, model_name)) # db path
def update_available_sd_models():
sd_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "StableDiffusion"))
sd_model_opts = sorted([safetensor.split(sep)[-1] for safetensor in glob.glob(os.path.join(sd_model_opts_path, f"*.safetensors"))])
return gr.update(choices=sd_model_opts)
def update_available_db_models():
db_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "DreamBooth_LoRA"))
db_model_opts = sorted([safetensor.split(sep)[-1] for safetensor in glob.glob(os.path.join(db_model_opts_path, f"*.safetensors"))])
return gr.update(choices=db_model_opts)
def update_sd_config():
inference_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "inference")
return gr.update(choices=sorted([(inference_yaml.split(sep)[-1]) for inference_yaml in glob.glob(os.path.join(inference_configs_path, f"*.yaml"))]))
def update_db_config():
prompt_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "prompts")
return gr.update(choices=sorted([(prompt_yaml.split(sep)[-1]) for prompt_yaml in glob.glob(os.path.join(prompt_configs_path, f"*.yaml"))]))
def load_db_config(filename: gr.SelectData):
filename = filename.value
global prompt_config_dict
prompt_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "prompts")
# populate the dictionary
prompt_config_dict = help.yaml_to_dict(os.path.join(prompt_configs_path, f"{filename}"))
name_only = list(prompt_config_dict.keys())[0]
help.verbose_print(f"Config Key Name:\t{name_only}")
# return populated UI components
config_name = name_only
motion_model_path = list(prompt_config_dict[name_only]["motion_module"])[0]
base_path = str(prompt_config_dict[name_only]["base"])
db_path = str(prompt_config_dict[name_only]["path"])
steps = int(prompt_config_dict[name_only]["steps"])
guidance_scale = float(prompt_config_dict[name_only]["guidance_scale"])
lora_alpha = float(prompt_config_dict[name_only]["lora_alpha"]) if "lora_alpha" in prompt_config_dict[name_only] else 1.0
seed_list = list(prompt_config_dict[name_only]["seed"])
prompt_list = list(prompt_config_dict[name_only]["prompt"])
n_prompt_list = list(prompt_config_dict[name_only]["n_prompt"])
seed1 = str(seed_list[0]) if len(seed_list) > 0 else "-1"
prompt1 = str(prompt_list[0]) if len(prompt_list) > 0 else ""
n_prompt1 = str(n_prompt_list[0]) if len(n_prompt_list) > 0 else ""
seed2 = str(seed_list[1]) if len(seed_list) > 1 else "-1"
prompt2 = str(prompt_list[1]) if len(prompt_list) > 1 else ""
n_prompt2 = str(n_prompt_list[1]) if len(n_prompt_list) > 1 else ""
seed3 = str(seed_list[2]) if len(seed_list) > 2 else "-1"
prompt3 = str(prompt_list[2]) if len(prompt_list) > 2 else ""
n_prompt3 = str(n_prompt_list[2]) if len(n_prompt_list) > 2 else ""
seed4 = str(seed_list[3]) if len(seed_list) > 3 else "-1"
prompt4 = str(prompt_list[3]) if len(prompt_list) > 3 else ""
n_prompt4 = str(n_prompt_list[3]) if len(n_prompt_list) > 3 else ""
help.verbose_print(f"Done Loading Prompt Config!")
motion_model_dropdown = gr.update(value=motion_model_path.split(sep)[-1])
sd_model_dropdown = gr.update(value=base_path.split(sep)[-1])
db_model_dropdown = gr.update(value=db_path.split(sep)[-1])
pretrained_model_path = gr.update(value=base_path)
return config_name, motion_model_path, base_path, db_path, steps, guidance_scale, lora_alpha, \
seed1, prompt1, n_prompt1, seed2, prompt2, n_prompt2, seed3, prompt3, n_prompt3, seed4, prompt4, n_prompt4, \
motion_model_dropdown, sd_model_dropdown, db_model_dropdown, pretrained_model_path
def save_db_config(filename):
global prompt_config_dict
prompt_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "prompts")
help.dict_to_yaml(copy.deepcopy(prompt_config_dict), os.path.join(prompt_configs_path, f"{filename}.yaml"))
help.verbose_print(f"Done Creating NEW Prompt Config!")
def save_prompt_dict(config_name, motion_model_path, base_path, db_path, steps, guidance_scale, lora_alpha,
seed1, prompt1, n_prompt1, seed2, prompt2, n_prompt2, seed3, prompt3, n_prompt3, seed4, prompt4, n_prompt4):
global prompt_config_dict
prompt_config_dict[config_name] = {}
prompt_config_dict[config_name]["base"] = base_path
prompt_config_dict[config_name]["path"] = db_path
prompt_config_dict[config_name]["motion_module"] = []
prompt_config_dict[config_name]["motion_module"].append(motion_model_path)
prompt_config_dict[config_name]["seed"] = [0] * 4
prompt_config_dict[config_name]["steps"] = steps
prompt_config_dict[config_name]["guidance_scale"] = guidance_scale
prompt_config_dict[config_name]["lora_alpha"] = lora_alpha
prompt_config_dict[config_name]["prompt"] = [""]*4
prompt_config_dict[config_name]["n_prompt"] = [""]*4
prompt_config_dict[config_name]["seed"][0] = int(seed1)
prompt_config_dict[config_name]["prompt"][0] = prompt1
prompt_config_dict[config_name]["n_prompt"][0] = n_prompt1
prompt_config_dict[config_name]["seed"][1] = int(seed2)
prompt_config_dict[config_name]["prompt"][1] = prompt2
prompt_config_dict[config_name]["n_prompt"][1] = n_prompt2
prompt_config_dict[config_name]["seed"][2] = int(seed3)
prompt_config_dict[config_name]["prompt"][2] = prompt3
prompt_config_dict[config_name]["n_prompt"][2] = n_prompt3
prompt_config_dict[config_name]["seed"][3] = int(seed4)
prompt_config_dict[config_name]["prompt"][3] = prompt4
prompt_config_dict[config_name]["n_prompt"][3] = n_prompt4
prompt_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "prompts")
help.dict_to_yaml(copy.deepcopy(prompt_config_dict), os.path.join(prompt_configs_path, f"{config_name}.yaml"))
help.verbose_print(f"Done Updating Prompt Config!")
def animate(pretrained_model_path, frame_count, width, height, inference_yaml_select, prompt_yaml_select):
prompt_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "prompts")
inference_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "inference")
command_str = f"python -m scripts.animate --config {os.path.join(prompt_configs_path, prompt_yaml_select)}"
if pretrained_model_path is not None and len(pretrained_model_path) > 0:
command_str += f" --pretrained_model_path {pretrained_model_path}"
command_str += f" --L {frame_count}"
command_str += f" --W {width}"
command_str += f" --H {height}"
if inference_yaml_select is not None and len(inference_yaml_select) > 0:
command_str += f" --inference_config {os.path.join(inference_configs_path, inference_yaml_select)}"
help.verbose_print(f"Running Command:\t{command_str}")
for line in help.execute(command_str.split(" ")):
help.verbose_print(line)
help.verbose_print(f"Done Generating!")
def build_ui():
with gr.Blocks() as demo:
with gr.Tab("Model Selection & Setup"):
with gr.Row():
motion_model_dropdown = gr.Dropdown(interactive=True, label="Select Motion Model", info="Downloads model if not present", choices=all_motion_model_opts)
sd_model_dropdown = gr.Dropdown(interactive=True, label="Select Stable Diffusion Model", info="At user/s discretion to download", choices=get_available_sd_models())
db_model_dropdown = gr.Dropdown(interactive=True, label="Select LoRA/Dreambooth Model", info="At user/s discretion to download", choices=get_available_db_models())
with gr.Row():
pretrained_model_path = gr.Textbox(info="Pretrained Model Path", interactive=True, show_label=False)
with gr.Row():
frame_count = gr.Slider(info="Total Frames", minimum=0, maximum=1000, step=1, value=16, show_label=False)
width = gr.Slider(info="Width", minimum=0, maximum=4096, step=1, value=512, show_label=False)
height = gr.Slider(info="Height", minimum=0, maximum=4096, step=1, value=512, show_label=False)
with gr.Row():
inference_yaml_select = gr.Dropdown(info='YAML Select', interactive=True, choices=get_sd_config(), show_label=False)
animate_button = gr.Button(value="Generate", variant='primary')
with gr.Tab("LoRA/Dreambooth Prompt Config"):
with gr.Row():
config_save = gr.Button(value="Apply & Save Settings", variant='primary')
create_prompt_yaml = gr.Button(value="Create Prompt Config", variant='secondary')
with gr.Row():
prompt_yaml_select = gr.Dropdown(info='YAML Select', interactive=True, choices=get_db_config(), show_label=False)
config_name = gr.Textbox(info="Config Name", interactive=True, show_label=False)
motion_model_path = gr.Textbox(info="Motion Model Path", interactive=True, show_label=False)
with gr.Row():
base_path = gr.Textbox(info="Base Model Path", interactive=True, show_label=False)
db_path = gr.Textbox(info="LoRA/Dreambooth Path", interactive=True, show_label=False)
with gr.Row():
steps = gr.Slider(info="Steps", minimum=0, maximum=1000, step=1, value=25, show_label=False)
with gr.Row():
guidance_scale = gr.Slider(info="Guidance Scale", minimum=0.0, maximum=100.0, step=0.05, value=6.5, show_label=False)
lora_alpha = gr.Slider(info="LoRA Alpha", minimum=0.0, maximum=1.0, step=0.025, value=1.0, show_label=False)
with gr.Accordion("Prompt 1", visible=True, open=False):
with gr.Column():
seed1 = gr.Textbox(info="Seed", interactive=True, show_label=False)
prompt1 = gr.Textbox(info="Prompt", interactive=True, show_label=False)
n_prompt1 = gr.Textbox(info="Negative Prompt", interactive=True, show_label=False)
with gr.Accordion("Prompt 2", visible=True, open=False):
with gr.Column():
seed2 = gr.Textbox(info="Seed", interactive=True, show_label=False)
prompt2 = gr.Textbox(info="Prompt", interactive=True, show_label=False)
n_prompt2 = gr.Textbox(info="Negative Prompt", interactive=True, show_label=False)
with gr.Accordion("Prompt 3", visible=True, open=False):
with gr.Column():
seed3 = gr.Textbox(info="Seed", interactive=True, show_label=False)
prompt3 = gr.Textbox(info="Prompt", interactive=True, show_label=False)
n_prompt3 = gr.Textbox(info="Negative Prompt", interactive=True, show_label=False)
with gr.Accordion("Prompt 4", visible=True, open=False):
with gr.Column():
seed4 = gr.Textbox(info="Seed", interactive=True, show_label=False)
prompt4 = gr.Textbox(info="Prompt", interactive=True, show_label=False)
n_prompt4 = gr.Textbox(info="Negative Prompt", interactive=True, show_label=False)
motion_model_dropdown.select(fn=set_motion_model, inputs=[], outputs=[motion_model_path])
sd_model_dropdown.select(fn=set_sd_model, inputs=[], outputs=[base_path, pretrained_model_path])
db_model_dropdown.select(fn=set_db_model, inputs=[], outputs=[db_path])
prompt_yaml_select.select(fn=load_db_config, inputs=[],
outputs=[config_name, motion_model_path, base_path, db_path, steps, guidance_scale, lora_alpha,
seed1, prompt1, n_prompt1, seed2, prompt2, n_prompt2, seed3, prompt3, n_prompt3,
seed4, prompt4, n_prompt4, motion_model_dropdown, sd_model_dropdown,
db_model_dropdown, pretrained_model_path]).then(
fn=update_db_config, inputs=[], outputs=[prompt_yaml_select])
create_prompt_yaml.click(fn=save_db_config, inputs=[config_name], outputs=[])
config_save.click(fn=save_prompt_dict, inputs=[config_name, motion_model_path, base_path, db_path, steps, guidance_scale, lora_alpha,
seed1, prompt1, n_prompt1, seed2, prompt2, n_prompt2, seed3, prompt3, n_prompt3, seed4, prompt4, n_prompt4],
outputs=[])
animate_button.click(fn=animate, inputs=[pretrained_model_path, frame_count, width, height, inference_yaml_select, prompt_yaml_select], outputs=[])
return demo
def UI(**kwargs):
# Show the interface
launch_kwargs = {}
if not kwargs.get('username', None) == '':
launch_kwargs['auth'] = (
kwargs.get('username', None),
kwargs.get('password', None),
)
if kwargs.get('server_port', 0) > 0:
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
if kwargs.get('share', True):
launch_kwargs['share'] = True
print(launch_kwargs)
demo.queue().launch(**launch_kwargs)
if __name__ == "__main__":
# init client & server connection
HOST = "127.0.0.1"
parser = argparse.ArgumentParser()
parser.add_argument(
'--username', type=str, default='', help='Username for authentication'
)
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port',
type=int,
default=0,
help='Port to run the server listener on',
)
parser.add_argument(
'--share',
action='store_true',
help='Share live gradio link',
)
args = parser.parse_args()
demo = build_ui()
global prompt_config_dict
prompt_config_dict = {}
help.verbose_print(f"Motion models available to use:\t{get_available_motion_models()}")
help.verbose_print(f"Stable Diffusion models available to use:\t{get_available_sd_models()}")
help.verbose_print(f"LoRA/Dreambooth models available to use:\t{get_available_db_models()}")
UI(
username=args.username,
password=args.password,
server_port=args.server_port,
share=args.share,
)