support sparsectrl

This commit is contained in:
Yuwei Guo
2023-12-15 20:55:51 +08:00
parent 6c8a01b148
commit 401bc45697
7 changed files with 697 additions and 45 deletions

View File

@@ -28,7 +28,8 @@ from diffusers.utils import deprecate, logging, BaseOutput
from einops import rearrange
from ..models.unet import UNet3DConditionModel
from ..models.sparse_controlnet import SparseControlNetModel
import pdb
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -55,6 +56,7 @@ class AnimationPipeline(DiffusionPipeline):
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
controlnet: Union[SparseControlNetModel, None] = None,
):
super().__init__()
@@ -112,6 +114,7 @@ class AnimationPipeline(DiffusionPipeline):
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
controlnet=controlnet,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@@ -330,6 +333,12 @@ class AnimationPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
# support controlnet
controlnet_images: torch.FloatTensor = None,
controlnet_image_index: list = [0],
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
**kwargs,
):
# Default height and width to unet
@@ -391,15 +400,43 @@ class AnimationPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
down_block_additional_residuals = mid_block_additional_residual = None
if (getattr(self, "controlnet", None) != None) and (controlnet_images != None):
assert controlnet_images.dim() == 5
controlnet_noisy_latents = latent_model_input
controlnet_prompt_embeds = text_embeddings
controlnet_images = controlnet_images.to(latents.device)
controlnet_cond_shape = list(controlnet_images.shape)
controlnet_cond_shape[2] = video_length
controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device)
controlnet_conditioning_mask_shape = list(controlnet_cond.shape)
controlnet_conditioning_mask_shape[1] = 1
controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device)
assert controlnet_images.shape[2] >= len(controlnet_image_index)
controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)]
controlnet_conditioning_mask[:,:,controlnet_image_index] = 1
down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
controlnet_noisy_latents, t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=controlnet_cond,
conditioning_mask=controlnet_conditioning_mask,
conditioning_scale=controlnet_conditioning_scale,
guess_mode=False, return_dict=False,
)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
# noise_pred = []
# import pdb
# pdb.set_trace()
# for batch_idx in range(latent_model_input.shape[0]):
# noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype)
# noise_pred.append(noise_pred_single)
# noise_pred = torch.cat(noise_pred)
noise_pred = self.unet(
latent_model_input, t,
encoder_hidden_states=text_embeddings,
down_block_additional_residuals = down_block_additional_residuals,
mid_block_additional_residual = mid_block_additional_residual,
).sample.to(dtype=latents_dtype)
# perform guidance
if do_classifier_free_guidance: