mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 09:46:36 +02:00
support sparsectrl
This commit is contained in:
@@ -28,7 +28,8 @@ from diffusers.utils import deprecate, logging, BaseOutput
|
||||
from einops import rearrange
|
||||
|
||||
from ..models.unet import UNet3DConditionModel
|
||||
|
||||
from ..models.sparse_controlnet import SparseControlNetModel
|
||||
import pdb
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -55,6 +56,7 @@ class AnimationPipeline(DiffusionPipeline):
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
controlnet: Union[SparseControlNetModel, None] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -112,6 +114,7 @@ class AnimationPipeline(DiffusionPipeline):
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
@@ -330,6 +333,12 @@ class AnimationPipeline(DiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
|
||||
# support controlnet
|
||||
controlnet_images: torch.FloatTensor = None,
|
||||
controlnet_image_index: list = [0],
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
|
||||
**kwargs,
|
||||
):
|
||||
# Default height and width to unet
|
||||
@@ -391,15 +400,43 @@ class AnimationPipeline(DiffusionPipeline):
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
down_block_additional_residuals = mid_block_additional_residual = None
|
||||
if (getattr(self, "controlnet", None) != None) and (controlnet_images != None):
|
||||
assert controlnet_images.dim() == 5
|
||||
|
||||
controlnet_noisy_latents = latent_model_input
|
||||
controlnet_prompt_embeds = text_embeddings
|
||||
|
||||
controlnet_images = controlnet_images.to(latents.device)
|
||||
|
||||
controlnet_cond_shape = list(controlnet_images.shape)
|
||||
controlnet_cond_shape[2] = video_length
|
||||
controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device)
|
||||
|
||||
controlnet_conditioning_mask_shape = list(controlnet_cond.shape)
|
||||
controlnet_conditioning_mask_shape[1] = 1
|
||||
controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device)
|
||||
|
||||
assert controlnet_images.shape[2] >= len(controlnet_image_index)
|
||||
controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)]
|
||||
controlnet_conditioning_mask[:,:,controlnet_image_index] = 1
|
||||
|
||||
down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
|
||||
controlnet_noisy_latents, t,
|
||||
encoder_hidden_states=controlnet_prompt_embeds,
|
||||
controlnet_cond=controlnet_cond,
|
||||
conditioning_mask=controlnet_conditioning_mask,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
guess_mode=False, return_dict=False,
|
||||
)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
|
||||
# noise_pred = []
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
# for batch_idx in range(latent_model_input.shape[0]):
|
||||
# noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype)
|
||||
# noise_pred.append(noise_pred_single)
|
||||
# noise_pred = torch.cat(noise_pred)
|
||||
noise_pred = self.unet(
|
||||
latent_model_input, t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
down_block_additional_residuals = down_block_additional_residuals,
|
||||
mid_block_additional_residual = mid_block_additional_residual,
|
||||
).sample.to(dtype=latents_dtype)
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
|
||||
Reference in New Issue
Block a user