mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 17:56:15 +02:00
support sparsectrl
This commit is contained in:
@@ -86,8 +86,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
motion_module_decoder_only = False,
|
||||
motion_module_type = None,
|
||||
motion_module_kwargs = {},
|
||||
unet_use_cross_frame_attention = None,
|
||||
unet_use_temporal_attention = None,
|
||||
unet_use_cross_frame_attention = False,
|
||||
unet_use_temporal_attention = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -324,6 +324,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
|
||||
# support controlnet
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet3DConditionOutput, Tuple]:
|
||||
r"""
|
||||
@@ -414,11 +419,25 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# support controlnet
|
||||
down_block_res_samples = list(down_block_res_samples)
|
||||
if down_block_additional_residuals is not None:
|
||||
for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
|
||||
if down_block_additional_residual.dim() == 4: # broadcast
|
||||
down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
|
||||
down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
|
||||
|
||||
# mid
|
||||
sample = self.mid_block(
|
||||
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
# support controlnet
|
||||
if mid_block_additional_residual is not None:
|
||||
if mid_block_additional_residual.dim() == 4: # broadcast
|
||||
mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
|
||||
sample = sample + mid_block_additional_residual
|
||||
|
||||
# up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
@@ -459,7 +478,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
|
||||
if subfolder is not None:
|
||||
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
||||
print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
|
||||
print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")
|
||||
|
||||
config_file = os.path.join(pretrained_model_path, 'config.json')
|
||||
if not os.path.isfile(config_file):
|
||||
@@ -489,9 +508,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
m, u = model.load_state_dict(state_dict, strict=False)
|
||||
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
||||
# print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
|
||||
|
||||
params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
|
||||
print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
|
||||
params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
|
||||
print(f"### Motion Module Parameters: {sum(params) / 1e6} M")
|
||||
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user