mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 09:46:36 +02:00
fix script
This commit is contained in:
@@ -17,7 +17,7 @@ from animatediff.models.unet import UNet3DConditionModel
|
||||
from animatediff.models.sparse_controlnet import SparseControlNetModel
|
||||
from animatediff.pipelines.pipeline_animation import AnimationPipeline
|
||||
from animatediff.utils.util import save_videos_grid
|
||||
from animatediff.utils.util import load_weights
|
||||
from animatediff.utils.util import load_weights, auto_download
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from einops import rearrange, repeat
|
||||
@@ -66,9 +66,11 @@ def main(args):
|
||||
controlnet_config = OmegaConf.load(model_config.controlnet_config)
|
||||
controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {}))
|
||||
|
||||
auto_download(model_config.controlnet_path, is_dreambooth_lora=False)
|
||||
print(f"loading controlnet checkpoint from {model_config.controlnet_path} ...")
|
||||
controlnet_state_dict = torch.load(model_config.controlnet_path, map_location="cpu")
|
||||
controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
|
||||
controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name}
|
||||
controlnet_state_dict.pop("animatediff_config", "")
|
||||
controlnet.load_state_dict(controlnet_state_dict)
|
||||
controlnet.cuda()
|
||||
@@ -181,7 +183,7 @@ def main(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pretrained-model-path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5",)
|
||||
parser.add_argument("--pretrained-model-path", type=str, default="runwayml/stable-diffusion-v1-5")
|
||||
parser.add_argument("--inference-config", type=str, default="configs/inference/inference-v1.yaml")
|
||||
parser.add_argument("--config", type=str, required=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user