fix model

This commit is contained in:
Yuwei
2024-07-17 08:03:42 +00:00
parent cf80ddeb47
commit 786a99cc7f
3 changed files with 140 additions and 71 deletions

View File

@@ -238,7 +238,7 @@ class PositionalEncoding(nn.Module):
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
self.register_buffer('pe', pe, persistent=False)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
@@ -251,7 +251,7 @@ class VersatileAttention(CrossAttention):
attention_mode = None,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
temporal_position_encoding_max_len = 32,
*args, **kwargs
):
super().__init__(*args, **kwargs)

View File

@@ -475,16 +475,77 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
return UNet3DConditionOutput(sample=sample)
@classmethod
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")
def from_pretrained_2d(cls, pretrained_model_name_or_path, unet_additional_kwargs={}, **kwargs):
from diffusers import __version__
from diffusers.utils import DIFFUSERS_CACHE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, is_safetensors_available
from diffusers.modeling_utils import load_state_dict
print(f"loaded 3D unet's pretrained weights from {pretrained_model_name_or_path} ...")
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "pytorch",
}
model_file = None
if is_safetensors_available():
try:
model_file = cls._get_model_file(
pretrained_model_name_or_path,
weights_name=SAFETENSORS_WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
except:
pass
if model_file is None:
model_file = cls._get_model_file(
pretrained_model_name_or_path,
weights_name=WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
config, unused_kwargs = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
)
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
config["_class_name"] = cls.__name__
config["down_block_types"] = [
"CrossAttnDownBlock3D",
@@ -499,12 +560,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
"CrossAttnUpBlock3D"
]
from diffusers.utils import WEIGHTS_NAME
model = cls.from_config(config, **unet_additional_kwargs)
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
if not os.path.isfile(model_file):
raise RuntimeError(f"{model_file} does not exist")
state_dict = torch.load(model_file, map_location="cpu")
model = cls.from_config(config, **unused_kwargs, **unet_additional_kwargs)
state_dict = load_state_dict(model_file)
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")