mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 17:56:15 +02:00
fix model
This commit is contained in:
@@ -238,7 +238,7 @@ class PositionalEncoding(nn.Module):
|
||||
pe = torch.zeros(1, max_len, d_model)
|
||||
pe[0, :, 0::2] = torch.sin(position * div_term)
|
||||
pe[0, :, 1::2] = torch.cos(position * div_term)
|
||||
self.register_buffer('pe', pe)
|
||||
self.register_buffer('pe', pe, persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.pe[:, :x.size(1)]
|
||||
@@ -251,7 +251,7 @@ class VersatileAttention(CrossAttention):
|
||||
attention_mode = None,
|
||||
cross_frame_attention_mode = None,
|
||||
temporal_position_encoding = False,
|
||||
temporal_position_encoding_max_len = 24,
|
||||
temporal_position_encoding_max_len = 32,
|
||||
*args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -475,16 +475,77 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
return UNet3DConditionOutput(sample=sample)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
|
||||
if subfolder is not None:
|
||||
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
||||
print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")
|
||||
def from_pretrained_2d(cls, pretrained_model_name_or_path, unet_additional_kwargs={}, **kwargs):
|
||||
from diffusers import __version__
|
||||
from diffusers.utils import DIFFUSERS_CACHE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, is_safetensors_available
|
||||
from diffusers.modeling_utils import load_state_dict
|
||||
print(f"loaded 3D unet's pretrained weights from {pretrained_model_name_or_path} ...")
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
"file_type": "model",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
model_file = None
|
||||
if is_safetensors_available():
|
||||
try:
|
||||
model_file = cls._get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=SAFETENSORS_WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
model_file = cls._get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
config, unused_kwargs = cls.load_config(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
device_map=device_map,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
config_file = os.path.join(pretrained_model_path, 'config.json')
|
||||
if not os.path.isfile(config_file):
|
||||
raise RuntimeError(f"{config_file} does not exist")
|
||||
with open(config_file, "r") as f:
|
||||
config = json.load(f)
|
||||
config["_class_name"] = cls.__name__
|
||||
config["down_block_types"] = [
|
||||
"CrossAttnDownBlock3D",
|
||||
@@ -499,12 +560,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
"CrossAttnUpBlock3D"
|
||||
]
|
||||
|
||||
from diffusers.utils import WEIGHTS_NAME
|
||||
model = cls.from_config(config, **unet_additional_kwargs)
|
||||
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
||||
if not os.path.isfile(model_file):
|
||||
raise RuntimeError(f"{model_file} does not exist")
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
model = cls.from_config(config, **unused_kwargs, **unet_additional_kwargs)
|
||||
state_dict = load_state_dict(model_file)
|
||||
|
||||
m, u = model.load_state_dict(state_dict, strict=False)
|
||||
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
||||
|
||||
Reference in New Issue
Block a user