fix script

This commit is contained in:
Yuwei
2024-07-17 08:05:50 +00:00
parent 786a99cc7f
commit 88068d9b6d
12 changed files with 711 additions and 2 deletions

View File

@@ -17,7 +17,7 @@ from animatediff.models.unet import UNet3DConditionModel
from animatediff.models.sparse_controlnet import SparseControlNetModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.util import load_weights
from animatediff.utils.util import load_weights, auto_download
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
@@ -66,9 +66,11 @@ def main(args):
controlnet_config = OmegaConf.load(model_config.controlnet_config)
controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {}))
auto_download(model_config.controlnet_path, is_dreambooth_lora=False)
print(f"loading controlnet checkpoint from {model_config.controlnet_path} ...")
controlnet_state_dict = torch.load(model_config.controlnet_path, map_location="cpu")
controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name}
controlnet_state_dict.pop("animatediff_config", "")
controlnet.load_state_dict(controlnet_state_dict)
controlnet.cuda()
@@ -181,7 +183,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pretrained-model-path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5",)
parser.add_argument("--pretrained-model-path", type=str, default="runwayml/stable-diffusion-v1-5")
parser.add_argument("--inference-config", type=str, default="configs/inference/inference-v1.yaml")
parser.add_argument("--config", type=str, required=True)