mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 01:36:20 +02:00
BIN
__assets__/ui/first-tab.png
Normal file
BIN
__assets__/ui/first-tab.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 48 KiB |
BIN
__assets__/ui/second-tab.png
Normal file
BIN
__assets__/ui/second-tab.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 84 KiB |
@@ -19,3 +19,4 @@ dependencies:
|
||||
- omegaconf
|
||||
- safetensors
|
||||
- gradio
|
||||
- pyyaml
|
||||
|
||||
41
utils.py
Normal file
41
utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import copy
|
||||
import os
|
||||
import subprocess as sub
|
||||
import gdown
|
||||
import yaml
|
||||
|
||||
def yaml_to_dict(filename):
|
||||
with open(filename, 'r') as file:
|
||||
data = yaml.safe_load(file)
|
||||
return copy.deepcopy(data)
|
||||
|
||||
def dict_to_yaml(data, filename):
|
||||
with open(filename, 'w') as file:
|
||||
yaml.safe_dump(data, file)
|
||||
|
||||
def download_from_drive_gdown(file_id, output_path):
|
||||
url = f'https://drive.google.com/uc?id={file_id}'
|
||||
gdown.download(url, output_path, quiet=False)
|
||||
|
||||
def verbose_print(text):
|
||||
print(f"{text}")
|
||||
|
||||
def create_dirs(arb_path):
|
||||
if not os.path.exists(arb_path):
|
||||
os.makedirs(arb_path)
|
||||
|
||||
def make_all_dirs(list_of_paths):
|
||||
for path in list_of_paths:
|
||||
create_dirs(path)
|
||||
|
||||
def execute(cmd):
|
||||
popen = sub.Popen(cmd, stdout=sub.PIPE, universal_newlines=True)
|
||||
for stdout_line in iter(popen.stdout.readline, ""):
|
||||
yield stdout_line
|
||||
popen.stdout.close()
|
||||
return_code = popen.wait()
|
||||
if return_code:
|
||||
raise sub.CalledProcessError(return_code, cmd)
|
||||
|
||||
def is_windows():
|
||||
return os.name == 'nt'
|
||||
307
webui.py
Normal file
307
webui.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import argparse
|
||||
import gradio as gr
|
||||
import copy
|
||||
import os
|
||||
import glob
|
||||
|
||||
import utils as help
|
||||
|
||||
sep = '\\' if help.is_windows() else '/'
|
||||
all_motion_model_opts = ["mm_sd_v14.ckpt", "mm_sd_v15.ckpt"]
|
||||
|
||||
def get_available_motion_models():
|
||||
motion_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "Motion_Module"))
|
||||
motion_model_opts = sorted([ckpt for ckpt in glob.glob(os.path.join(motion_model_opts_path, f"*.ckpt"))])
|
||||
return motion_model_opts
|
||||
|
||||
def get_available_sd_models():
|
||||
sd_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "StableDiffusion"))
|
||||
sd_model_opts = sorted([safetensor.split(sep)[-1] for safetensor in glob.glob(os.path.join(sd_model_opts_path, f"*.safetensors"))])
|
||||
return sd_model_opts
|
||||
|
||||
def get_available_db_models():
|
||||
db_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "DreamBooth_LoRA"))
|
||||
db_model_opts = sorted([safetensor.split(sep)[-1] for safetensor in glob.glob(os.path.join(db_model_opts_path, f"*.safetensors"))])
|
||||
return db_model_opts
|
||||
|
||||
def get_db_config():
|
||||
prompt_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "prompts")
|
||||
return sorted([(prompt_yaml.split(sep)[-1]) for prompt_yaml in glob.glob(os.path.join(prompt_configs_path, f"*.yaml"))])
|
||||
|
||||
def get_sd_config():
|
||||
inference_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "inference")
|
||||
return sorted([(inference_yaml.split(sep)[-1]) for inference_yaml in glob.glob(os.path.join(inference_configs_path, f"*.yaml"))])
|
||||
|
||||
def set_motion_model(menu_opt: gr.SelectData):
|
||||
model_name = menu_opt.value
|
||||
motion_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "Motion_Module"))
|
||||
motion_model_opts = sorted([ckpt for ckpt in glob.glob(os.path.join(motion_model_opts_path, f"*.ckpt"))])
|
||||
motion_model_map = {"mm_sd_v14.ckpt": "1RqkQuGPaCO5sGZ6V6KZ-jUWmsRu48Kdq",
|
||||
"mm_sd_v15.ckpt": "1ql0g_Ys4UCz2RnokYlBjyOYPbttbIpbu"}
|
||||
if not os.path.join(motion_model_opts_path, model_name) in motion_model_opts: # download
|
||||
help.download_from_drive_gdown(motion_model_map[model_name], os.path.join(motion_model_opts_path, model_name))
|
||||
return gr.update(value=os.path.join(motion_model_opts_path, model_name)) # model path
|
||||
|
||||
def set_sd_model(menu_opt: gr.SelectData):
|
||||
model_name = menu_opt.value
|
||||
sd_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "StableDiffusion"))
|
||||
return gr.update(value=os.path.join(sd_model_opts_path, model_name)), gr.update(value=os.path.join(sd_model_opts_path, model_name)) # sd path, pretrained path
|
||||
|
||||
def set_db_model(menu_opt: gr.SelectData):
|
||||
model_name = menu_opt.value
|
||||
db_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "DreamBooth_LoRA"))
|
||||
return gr.update(value=os.path.join(db_model_opts_path, model_name)) # db path
|
||||
|
||||
def update_available_sd_models():
|
||||
sd_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "StableDiffusion"))
|
||||
sd_model_opts = sorted([safetensor.split(sep)[-1] for safetensor in glob.glob(os.path.join(sd_model_opts_path, f"*.safetensors"))])
|
||||
return gr.update(choices=sd_model_opts)
|
||||
|
||||
def update_available_db_models():
|
||||
db_model_opts_path = os.path.join(os.getcwd(), os.path.join("models", "DreamBooth_LoRA"))
|
||||
db_model_opts = sorted([safetensor.split(sep)[-1] for safetensor in glob.glob(os.path.join(db_model_opts_path, f"*.safetensors"))])
|
||||
return gr.update(choices=db_model_opts)
|
||||
|
||||
def update_sd_config():
|
||||
inference_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "inference")
|
||||
return gr.update(choices=sorted([(inference_yaml.split(sep)[-1]) for inference_yaml in glob.glob(os.path.join(inference_configs_path, f"*.yaml"))]))
|
||||
|
||||
def update_db_config():
|
||||
prompt_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "prompts")
|
||||
return gr.update(choices=sorted([(prompt_yaml.split(sep)[-1]) for prompt_yaml in glob.glob(os.path.join(prompt_configs_path, f"*.yaml"))]))
|
||||
|
||||
def load_db_config(filename: gr.SelectData):
|
||||
filename = filename.value
|
||||
|
||||
global prompt_config_dict
|
||||
prompt_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "prompts")
|
||||
# populate the dictionary
|
||||
prompt_config_dict = help.yaml_to_dict(os.path.join(prompt_configs_path, f"{filename}"))
|
||||
|
||||
name_only = list(prompt_config_dict.keys())[0]
|
||||
help.verbose_print(f"Config Key Name:\t{name_only}")
|
||||
|
||||
# return populated UI components
|
||||
config_name = name_only
|
||||
|
||||
motion_model_path = list(prompt_config_dict[name_only]["motion_module"])[0]
|
||||
|
||||
|
||||
base_path = str(prompt_config_dict[name_only]["base"])
|
||||
db_path = str(prompt_config_dict[name_only]["path"])
|
||||
steps = int(prompt_config_dict[name_only]["steps"])
|
||||
guidance_scale = float(prompt_config_dict[name_only]["guidance_scale"])
|
||||
lora_alpha = float(prompt_config_dict[name_only]["lora_alpha"]) if "lora_alpha" in prompt_config_dict[name_only] else 1.0
|
||||
|
||||
seed_list = list(prompt_config_dict[name_only]["seed"])
|
||||
prompt_list = list(prompt_config_dict[name_only]["prompt"])
|
||||
n_prompt_list = list(prompt_config_dict[name_only]["n_prompt"])
|
||||
|
||||
seed1 = str(seed_list[0]) if len(seed_list) > 0 else "-1"
|
||||
prompt1 = str(prompt_list[0]) if len(prompt_list) > 0 else ""
|
||||
n_prompt1 = str(n_prompt_list[0]) if len(n_prompt_list) > 0 else ""
|
||||
seed2 = str(seed_list[1]) if len(seed_list) > 1 else "-1"
|
||||
prompt2 = str(prompt_list[1]) if len(prompt_list) > 1 else ""
|
||||
n_prompt2 = str(n_prompt_list[1]) if len(n_prompt_list) > 1 else ""
|
||||
seed3 = str(seed_list[2]) if len(seed_list) > 2 else "-1"
|
||||
prompt3 = str(prompt_list[2]) if len(prompt_list) > 2 else ""
|
||||
n_prompt3 = str(n_prompt_list[2]) if len(n_prompt_list) > 2 else ""
|
||||
seed4 = str(seed_list[3]) if len(seed_list) > 3 else "-1"
|
||||
prompt4 = str(prompt_list[3]) if len(prompt_list) > 3 else ""
|
||||
n_prompt4 = str(n_prompt_list[3]) if len(n_prompt_list) > 3 else ""
|
||||
help.verbose_print(f"Done Loading Prompt Config!")
|
||||
|
||||
motion_model_dropdown = gr.update(value=motion_model_path.split(sep)[-1])
|
||||
sd_model_dropdown = gr.update(value=base_path.split(sep)[-1])
|
||||
db_model_dropdown = gr.update(value=db_path.split(sep)[-1])
|
||||
pretrained_model_path = gr.update(value=base_path)
|
||||
|
||||
return config_name, motion_model_path, base_path, db_path, steps, guidance_scale, lora_alpha, \
|
||||
seed1, prompt1, n_prompt1, seed2, prompt2, n_prompt2, seed3, prompt3, n_prompt3, seed4, prompt4, n_prompt4, \
|
||||
motion_model_dropdown, sd_model_dropdown, db_model_dropdown, pretrained_model_path
|
||||
|
||||
def save_db_config(filename):
|
||||
global prompt_config_dict
|
||||
prompt_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "prompts")
|
||||
help.dict_to_yaml(copy.deepcopy(prompt_config_dict), os.path.join(prompt_configs_path, f"{filename}.yaml"))
|
||||
help.verbose_print(f"Done Creating NEW Prompt Config!")
|
||||
|
||||
def save_prompt_dict(config_name, motion_model_path, base_path, db_path, steps, guidance_scale, lora_alpha,
|
||||
seed1, prompt1, n_prompt1, seed2, prompt2, n_prompt2, seed3, prompt3, n_prompt3, seed4, prompt4, n_prompt4):
|
||||
global prompt_config_dict
|
||||
prompt_config_dict[config_name] = {}
|
||||
prompt_config_dict[config_name]["base"] = base_path
|
||||
prompt_config_dict[config_name]["path"] = db_path
|
||||
|
||||
prompt_config_dict[config_name]["motion_module"] = []
|
||||
prompt_config_dict[config_name]["motion_module"].append(motion_model_path)
|
||||
|
||||
prompt_config_dict[config_name]["seed"] = [0] * 4
|
||||
prompt_config_dict[config_name]["steps"] = steps
|
||||
prompt_config_dict[config_name]["guidance_scale"] = guidance_scale
|
||||
prompt_config_dict[config_name]["lora_alpha"] = lora_alpha
|
||||
|
||||
|
||||
prompt_config_dict[config_name]["prompt"] = [""]*4
|
||||
prompt_config_dict[config_name]["n_prompt"] = [""]*4
|
||||
|
||||
prompt_config_dict[config_name]["seed"][0] = int(seed1)
|
||||
prompt_config_dict[config_name]["prompt"][0] = prompt1
|
||||
prompt_config_dict[config_name]["n_prompt"][0] = n_prompt1
|
||||
prompt_config_dict[config_name]["seed"][1] = int(seed2)
|
||||
prompt_config_dict[config_name]["prompt"][1] = prompt2
|
||||
prompt_config_dict[config_name]["n_prompt"][1] = n_prompt2
|
||||
prompt_config_dict[config_name]["seed"][2] = int(seed3)
|
||||
prompt_config_dict[config_name]["prompt"][2] = prompt3
|
||||
prompt_config_dict[config_name]["n_prompt"][2] = n_prompt3
|
||||
prompt_config_dict[config_name]["seed"][3] = int(seed4)
|
||||
prompt_config_dict[config_name]["prompt"][3] = prompt4
|
||||
prompt_config_dict[config_name]["n_prompt"][3] = n_prompt4
|
||||
|
||||
prompt_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "prompts")
|
||||
help.dict_to_yaml(copy.deepcopy(prompt_config_dict), os.path.join(prompt_configs_path, f"{config_name}.yaml"))
|
||||
help.verbose_print(f"Done Updating Prompt Config!")
|
||||
|
||||
def animate(pretrained_model_path, frame_count, width, height, inference_yaml_select, prompt_yaml_select):
|
||||
prompt_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "prompts")
|
||||
inference_configs_path = os.path.join(os.path.join(os.getcwd(), "configs"), "inference")
|
||||
|
||||
command_str = f"python -m scripts.animate --config {os.path.join(prompt_configs_path, prompt_yaml_select)}"
|
||||
if pretrained_model_path is not None and len(pretrained_model_path) > 0:
|
||||
command_str += f" --pretrained_model_path {pretrained_model_path}"
|
||||
command_str += f" --L {frame_count}"
|
||||
command_str += f" --W {width}"
|
||||
command_str += f" --H {height}"
|
||||
if inference_yaml_select is not None and len(inference_yaml_select) > 0:
|
||||
command_str += f" --inference_config {os.path.join(inference_configs_path, inference_yaml_select)}"
|
||||
|
||||
help.verbose_print(f"Running Command:\t{command_str}")
|
||||
for line in help.execute(command_str.split(" ")):
|
||||
help.verbose_print(line)
|
||||
help.verbose_print(f"Done Generating!")
|
||||
|
||||
def build_ui():
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Tab("Model Selection & Setup"):
|
||||
with gr.Row():
|
||||
motion_model_dropdown = gr.Dropdown(interactive=True, label="Select Motion Model", info="Downloads model if not present", choices=all_motion_model_opts)
|
||||
sd_model_dropdown = gr.Dropdown(interactive=True, label="Select Stable Diffusion Model", info="At user/s discretion to download", choices=get_available_sd_models())
|
||||
db_model_dropdown = gr.Dropdown(interactive=True, label="Select LoRA/Dreambooth Model", info="At user/s discretion to download", choices=get_available_db_models())
|
||||
with gr.Row():
|
||||
pretrained_model_path = gr.Textbox(info="Pretrained Model Path", interactive=True, show_label=False)
|
||||
with gr.Row():
|
||||
frame_count = gr.Slider(info="Total Frames", minimum=0, maximum=1000, step=1, value=16, show_label=False)
|
||||
width = gr.Slider(info="Width", minimum=0, maximum=4096, step=1, value=512, show_label=False)
|
||||
height = gr.Slider(info="Height", minimum=0, maximum=4096, step=1, value=512, show_label=False)
|
||||
with gr.Row():
|
||||
inference_yaml_select = gr.Dropdown(info='YAML Select', interactive=True, choices=get_sd_config(), show_label=False)
|
||||
animate_button = gr.Button(value="Generate", variant='primary')
|
||||
|
||||
with gr.Tab("LoRA/Dreambooth Prompt Config"):
|
||||
with gr.Row():
|
||||
config_save = gr.Button(value="Apply & Save Settings", variant='primary')
|
||||
create_prompt_yaml = gr.Button(value="Create Prompt Config", variant='secondary')
|
||||
with gr.Row():
|
||||
prompt_yaml_select = gr.Dropdown(info='YAML Select', interactive=True, choices=get_db_config(), show_label=False)
|
||||
config_name = gr.Textbox(info="Config Name", interactive=True, show_label=False)
|
||||
motion_model_path = gr.Textbox(info="Motion Model Path", interactive=True, show_label=False)
|
||||
with gr.Row():
|
||||
base_path = gr.Textbox(info="Base Model Path", interactive=True, show_label=False)
|
||||
db_path = gr.Textbox(info="LoRA/Dreambooth Path", interactive=True, show_label=False)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(info="Steps", minimum=0, maximum=1000, step=1, value=25, show_label=False)
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(info="Guidance Scale", minimum=0.0, maximum=100.0, step=0.05, value=6.5, show_label=False)
|
||||
lora_alpha = gr.Slider(info="LoRA Alpha", minimum=0.0, maximum=1.0, step=0.025, value=1.0, show_label=False)
|
||||
with gr.Accordion("Prompt 1", visible=True, open=False):
|
||||
with gr.Column():
|
||||
seed1 = gr.Textbox(info="Seed", interactive=True, show_label=False)
|
||||
prompt1 = gr.Textbox(info="Prompt", interactive=True, show_label=False)
|
||||
n_prompt1 = gr.Textbox(info="Negative Prompt", interactive=True, show_label=False)
|
||||
with gr.Accordion("Prompt 2", visible=True, open=False):
|
||||
with gr.Column():
|
||||
seed2 = gr.Textbox(info="Seed", interactive=True, show_label=False)
|
||||
prompt2 = gr.Textbox(info="Prompt", interactive=True, show_label=False)
|
||||
n_prompt2 = gr.Textbox(info="Negative Prompt", interactive=True, show_label=False)
|
||||
with gr.Accordion("Prompt 3", visible=True, open=False):
|
||||
with gr.Column():
|
||||
seed3 = gr.Textbox(info="Seed", interactive=True, show_label=False)
|
||||
prompt3 = gr.Textbox(info="Prompt", interactive=True, show_label=False)
|
||||
n_prompt3 = gr.Textbox(info="Negative Prompt", interactive=True, show_label=False)
|
||||
with gr.Accordion("Prompt 4", visible=True, open=False):
|
||||
with gr.Column():
|
||||
seed4 = gr.Textbox(info="Seed", interactive=True, show_label=False)
|
||||
prompt4 = gr.Textbox(info="Prompt", interactive=True, show_label=False)
|
||||
n_prompt4 = gr.Textbox(info="Negative Prompt", interactive=True, show_label=False)
|
||||
|
||||
motion_model_dropdown.select(fn=set_motion_model, inputs=[], outputs=[motion_model_path])
|
||||
sd_model_dropdown.select(fn=set_sd_model, inputs=[], outputs=[base_path, pretrained_model_path])
|
||||
db_model_dropdown.select(fn=set_db_model, inputs=[], outputs=[db_path])
|
||||
prompt_yaml_select.select(fn=load_db_config, inputs=[],
|
||||
outputs=[config_name, motion_model_path, base_path, db_path, steps, guidance_scale, lora_alpha,
|
||||
seed1, prompt1, n_prompt1, seed2, prompt2, n_prompt2, seed3, prompt3, n_prompt3,
|
||||
seed4, prompt4, n_prompt4, motion_model_dropdown, sd_model_dropdown,
|
||||
db_model_dropdown, pretrained_model_path]).then(
|
||||
fn=update_db_config, inputs=[], outputs=[prompt_yaml_select])
|
||||
create_prompt_yaml.click(fn=save_db_config, inputs=[config_name], outputs=[])
|
||||
config_save.click(fn=save_prompt_dict, inputs=[config_name, motion_model_path, base_path, db_path, steps, guidance_scale, lora_alpha,
|
||||
seed1, prompt1, n_prompt1, seed2, prompt2, n_prompt2, seed3, prompt3, n_prompt3, seed4, prompt4, n_prompt4],
|
||||
outputs=[])
|
||||
animate_button.click(fn=animate, inputs=[pretrained_model_path, frame_count, width, height, inference_yaml_select, prompt_yaml_select], outputs=[])
|
||||
return demo
|
||||
|
||||
def UI(**kwargs):
|
||||
# Show the interface
|
||||
launch_kwargs = {}
|
||||
if not kwargs.get('username', None) == '':
|
||||
launch_kwargs['auth'] = (
|
||||
kwargs.get('username', None),
|
||||
kwargs.get('password', None),
|
||||
)
|
||||
if kwargs.get('server_port', 0) > 0:
|
||||
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
|
||||
if kwargs.get('share', True):
|
||||
launch_kwargs['share'] = True
|
||||
|
||||
print(launch_kwargs)
|
||||
demo.queue().launch(**launch_kwargs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# init client & server connection
|
||||
HOST = "127.0.0.1"
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--username', type=str, default='', help='Username for authentication'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--password', type=str, default='', help='Password for authentication'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--server_port',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Port to run the server listener on',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--share',
|
||||
action='store_true',
|
||||
help='Share live gradio link',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
demo = build_ui()
|
||||
|
||||
global prompt_config_dict
|
||||
prompt_config_dict = {}
|
||||
|
||||
help.verbose_print(f"Motion models available to use:\t{get_available_motion_models()}")
|
||||
help.verbose_print(f"Stable Diffusion models available to use:\t{get_available_sd_models()}")
|
||||
help.verbose_print(f"LoRA/Dreambooth models available to use:\t{get_available_db_models()}")
|
||||
|
||||
UI(
|
||||
username=args.username,
|
||||
password=args.password,
|
||||
server_port=args.server_port,
|
||||
share=args.share,
|
||||
)
|
||||
Reference in New Issue
Block a user