# Copyright 2023-2024 The Alibaba Fundamental Vision Team Authors. All rights reserved.
# The implementation is adopted from HighCWu,
# made publicly available under the Apache 2.0 License at https://github.com/HighCWu/ControlLoRA

import os
from dataclasses import dataclass
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.cross_attention import CrossAttention, LoRALinearLayer
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.resnet import (Downsample2D, Upsample2D, downsample_2d,
                                     partial, upsample_2d)
from diffusers.models.unet_2d_blocks import \
    get_down_block as get_down_block_default
from diffusers.utils.outputs import BaseOutput

from .sd_lora import LoRACrossAttnProcessor


@dataclass
class ControlLoRAOutput(BaseOutput):
    control_states: Tuple[torch.FloatTensor]


class ControlLoRATuner(ModelMixin, ConfigMixin):
    """ The implementation of the control lora module.
    This module encodes the control condition and uses lora layers to perform efficient tuning.
    """

    @staticmethod
    def tune(
        model: nn.Module,
        tuner_config=None,
        pretrained_tuner=None,
    ):
        tuner = ControlLoRATuner.from_config(tuner_config)
        if pretrained_tuner is not None and os.path.exists(pretrained_tuner):
            tuner.load_state_dict(
                torch.load(pretrained_tuner, map_location='cpu'), strict=True)

        tune_layers_list = list(
            [list(layer_list) for layer_list in tuner.lora_layers])

        assert hasattr(model, 'unet')
        unet = model.unet
        tuner.to(unet.device)
        tune_attn_procs = tuner.set_tune_layers(unet, tune_layers_list)
        unet.set_attn_processor(tune_attn_procs)
        return tuner

    def set_tune_layers(self, unet, tune_layers_list):
        n_ch = len(unet.config.block_out_channels)
        control_ids = [i for i in range(n_ch)]
        tune_attn_procs = {}

        for name in unet.attn_processors.keys():
            if name.startswith('mid_block'):
                control_id = control_ids[-1]
            elif name.startswith('up_blocks'):
                block_id = int(name[len('up_blocks.')])
                control_id = list(reversed(control_ids))[block_id]
            elif name.startswith('down_blocks'):
                block_id = int(name[len('down_blocks.')])
                control_id = control_ids[block_id]

            tune_layers = tune_layers_list[control_id]
            if len(tune_layers) != 0:
                tune_layer = tune_layers.pop(0)
                tune_attn_procs[name] = tune_layer
        return tune_attn_procs
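
    # Illustrative note (not from the original source): for a stable-diffusion
    # style UNet, `unet.attn_processors` keys look roughly like
    # 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor', so the
    # loop above reads the digit after 'down_blocks.'/'up_blocks.' as the block
    # id, maps it to a control branch (reversed for up blocks, the last branch
    # for the mid block) and pops the next unused lora processor from that
    # branch. The exact key format depends on the installed diffusers version,
    # so treat the string above as a sketch rather than a guarantee.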

    @register_to_config
    def __init__(self,
                 in_channels: int = 3,
                 down_block_types: Tuple[str] = (
                     'SimpleDownEncoderBlock2D',
                     'SimpleDownEncoderBlock2D',
                     'SimpleDownEncoderBlock2D',
                     'SimpleDownEncoderBlock2D',
                 ),
                 block_out_channels: Tuple[int] = (32, 64, 128, 256),
                 layers_per_block: int = 1,
                 act_fn: str = 'silu',
                 norm_num_groups: int = 32,
                 lora_pre_down_block_types: Tuple[str] = (
                     None,
                     'SimpleDownEncoderBlock2D',
                     'SimpleDownEncoderBlock2D',
                     'SimpleDownEncoderBlock2D',
                 ),
                 lora_pre_down_layers_per_block: int = 1,
                 lora_pre_conv_skipped: bool = False,
                 lora_pre_conv_types: Tuple[str] = (
                     'SimpleDownEncoderBlock2D',
                     'SimpleDownEncoderBlock2D',
                     'SimpleDownEncoderBlock2D',
                     'SimpleDownEncoderBlock2D',
                 ),
                 lora_pre_conv_layers_per_block: int = 1,
                 lora_pre_conv_layers_kernel_size: int = 1,
                 lora_block_in_channels: Tuple[int] = (256, 256, 256, 256),
                 lora_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
                 lora_cross_attention_dims: Tuple[List[int]] = ([
                     None, 768, None, 768, None, 768, None, 768, None, 768
                 ], [None, 768, None, 768, None, 768, None, 768, None, 768], [
                     None, 768, None, 768, None, 768, None, 768, None, 768
                 ], [None, 768]),
                 lora_rank: int = 4,
                 lora_control_rank: int = None,
                 lora_post_add: bool = False,
                 lora_concat_hidden: bool = False,
                 lora_control_channels: Tuple[int] = (None, None, None, None),
                 lora_control_self_add: bool = True,
                 lora_key_states_skipped: bool = False,
                 lora_value_states_skipped: bool = False,
                 lora_output_states_skipped: bool = False,
                 lora_control_version: int = 1):
        """ Initialize a control lora module instance.

        Args:
            in_channels (`int`): The number of channels of the input conditional data.
            down_block_types (Tuple[str], *optional*):
                The down block types used to downsample the conditional data.
            block_out_channels (Tuple[int], *optional*, defaults to (32, 64, 128, 256)):
                The number of output channels for every down block.
            layers_per_block (`int`, *optional*, defaults to 1):
                The number of layers in every block.
            act_fn (`str`, *optional*, defaults to silu):
                The activation function.
            norm_num_groups (`int`, *optional*, defaults to 32):
                The number of groups for the norm operation.
            lora_pre_down_block_types (Tuple[str], *optional*):
                The block types of the pre down blocks.
            lora_pre_down_layers_per_block (`int`, *optional*, defaults to 1):
                The number of layers in every pre down block.
            lora_pre_conv_skipped (`bool`, *optional*, defaults to False):
                Set to `True` to skip the conv in the pre downsample.
            lora_pre_conv_types (Tuple[str], *optional*):
                The block types of the pre conv blocks.
            lora_pre_conv_layers_per_block (`int`, *optional*, defaults to 1):
                The number of layers in every pre conv block.
            lora_pre_conv_layers_kernel_size (`int`, *optional*, defaults to 1):
                The conv kernel size of the pre conv blocks.
            lora_block_in_channels (Tuple[int], *optional*, defaults to (256, 256, 256, 256)):
                The number of input channels for every lora block.
            lora_block_out_channels (Tuple[int], *optional*, defaults to (320, 640, 1280, 1280)):
                The number of output channels for every lora block.
            lora_rank (int, *optional*, defaults to 4):
                The rank of the lora blocks.
            lora_control_rank (int, *optional*, defaults to None):
                The rank of the control projection; falls back to `lora_rank` when None.
            lora_post_add (`bool`, *optional*, defaults to False):
                Set to `True` to conduct the weighted adding operation after lora.
            lora_concat_hidden (`bool`, *optional*, defaults to False):
                Set to `True` to concat the hidden embedding with the control states.
            lora_control_channels (Tuple[int], *optional*, defaults to (None, None, None, None)):
                The number of control channels.
            lora_control_self_add (`bool`, *optional*, defaults to True):
                Set to `True` to add the control states back onto their lora projection.
            lora_key_states_skipped (`bool`, *optional*, defaults to False):
                Set to `True` to skip lora on the key states.
            lora_value_states_skipped (`bool`, *optional*, defaults to False):
                Set to `True` to skip lora on the value states.
            lora_output_states_skipped (`bool`, *optional*, defaults to False):
                Set to `True` to skip lora on the output states.
            lora_control_version (int, *optional*, defaults to 1):
                The lora attn version to use: ControlLoRACrossAttnProcessor (1) vs ControlLoRACrossAttnProcessorV2 (2).
        """

        super().__init__()
        lora_control_cls = ControlLoRACrossAttnProcessor
        if lora_control_version == 2:
            lora_control_cls = ControlLoRACrossAttnProcessorV2

        assert lora_block_in_channels[0] == block_out_channels[-1]

        if lora_pre_conv_skipped:
            lora_control_channels = lora_block_in_channels
            lora_control_self_add = False

        self.layers_per_block = layers_per_block
        self.lora_pre_down_layers_per_block = lora_pre_down_layers_per_block
        self.lora_pre_conv_layers_per_block = lora_pre_conv_layers_per_block

        self.conv_in = torch.nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1)

        self.down_blocks = nn.ModuleList([])
        self.pre_lora_layers = nn.ModuleList([])
        self.lora_layers = nn.ModuleList([])

        # pre_down
        pre_down_blocks = []
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            pre_down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attn_num_head_channels=None,
                temb_channels=None,
            )
            pre_down_blocks.append(pre_down_block)
        self.down_blocks.append(nn.Sequential(*pre_down_blocks))
        self.pre_lora_layers.append(
            get_down_block(
                lora_pre_conv_types[0],
                num_layers=self.lora_pre_conv_layers_per_block,
                in_channels=lora_block_in_channels[0],
                out_channels=(
                    lora_block_out_channels[0] if lora_control_channels[0] is
                    None else lora_control_channels[0]),
                add_downsample=False,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attn_num_head_channels=None,
                temb_channels=None,
                resnet_kernel_size=lora_pre_conv_layers_kernel_size,
            ) if not lora_pre_conv_skipped else nn.Identity())
        self.lora_layers.append(
            nn.ModuleList([
                lora_control_cls(
                    lora_block_out_channels[0],
                    cross_attention_dim=cross_attention_dim,
                    rank=lora_rank,
                    control_rank=lora_control_rank,
                    post_add=lora_post_add,
                    concat_hidden=lora_concat_hidden,
                    control_channels=lora_control_channels[0],
                    control_self_add=lora_control_self_add,
                    key_states_skipped=lora_key_states_skipped,
                    value_states_skipped=lora_value_states_skipped,
                    output_states_skipped=lora_output_states_skipped)
                for cross_attention_dim in lora_cross_attention_dims[0]
            ]))

        # down
        output_channel = lora_block_in_channels[0]
        for i, down_block_type in enumerate(lora_pre_down_block_types):
            if i == 0:
                continue
            input_channel = output_channel
            output_channel = lora_block_in_channels[i]

            down_block = get_down_block(
                down_block_type,
                num_layers=self.lora_pre_down_layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=True,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attn_num_head_channels=None,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

            self.pre_lora_layers.append(
                get_down_block(
                    lora_pre_conv_types[i],
                    num_layers=self.lora_pre_conv_layers_per_block,
                    in_channels=output_channel,
                    out_channels=(
                        lora_block_out_channels[i] if lora_control_channels[i]
                        is None else lora_control_channels[i]),
                    add_downsample=False,
                    resnet_eps=1e-6,
                    downsample_padding=0,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    attn_num_head_channels=None,
                    temb_channels=None,
                    resnet_kernel_size=lora_pre_conv_layers_kernel_size,
                ) if not lora_pre_conv_skipped else nn.Identity())
            self.lora_layers.append(
                nn.ModuleList([
                    lora_control_cls(
                        lora_block_out_channels[i],
                        cross_attention_dim=cross_attention_dim,
                        rank=lora_rank,
                        control_rank=lora_control_rank,
                        post_add=lora_post_add,
                        concat_hidden=lora_concat_hidden,
                        control_channels=lora_control_channels[i],
                        control_self_add=lora_control_self_add,
                        key_states_skipped=lora_key_states_skipped,
                        value_states_skipped=lora_value_states_skipped,
                        output_states_skipped=lora_output_states_skipped)
                    for cross_attention_dim in lora_cross_attention_dims[i]
                ]))

    def forward(self,
                x: torch.FloatTensor,
                return_dict: bool = True) -> Union[ControlLoRAOutput, Tuple]:
        lora_layer: ControlLoRACrossAttnProcessor

        orig_dtype = x.dtype
        dtype = self.conv_in.weight.dtype

        h = x.to(dtype)
        h = self.conv_in(h)
        control_states_list = []

        # down
        for down_block, pre_lora_layer, lora_layer_list in zip(
                self.down_blocks, self.pre_lora_layers, self.lora_layers):
            h = down_block(h)
            control_states = pre_lora_layer(h)
            if isinstance(control_states, tuple):
                control_states = control_states[0]
            control_states = control_states.to(orig_dtype)
            for lora_layer in lora_layer_list:
                lora_layer.inject_control_states(control_states)
            control_states_list.append(control_states)

        if not return_dict:
            return tuple(control_states_list)

        return ControlLoRAOutput(control_states=tuple(control_states_list))
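

# A minimal usage sketch (illustrative, not part of the original module),
# assuming `pipeline` is a wrapper exposing a diffusers UNet as `pipeline.unet`
# and `tuner_config` is a dict of the constructor arguments documented above:
#
#   tuner = ControlLoRATuner.tune(pipeline, tuner_config=tuner_config)
#   control = torch.randn(1, 3, 512, 512)       # conditioning image
#   states = tuner(control).control_states      # one tensor per control branch
#
# `tune` registers the control-aware lora processors on the UNet; every later
# `tuner(control)` call refreshes the control states those processors consume
# during the next UNet forward pass.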


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift='default',
    resnet_kernel_size=3,
):
    down_block_type = down_block_type[7:] if down_block_type.startswith(
        'UNetRes') else down_block_type
    if down_block_type == 'SimpleDownEncoderBlock2D':
        return SimpleDownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_downsample=add_downsample,
            convnet_eps=resnet_eps,
            convnet_act_fn=resnet_act_fn,
            convnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            convnet_time_scale_shift=resnet_time_scale_shift,
            convnet_kernel_size=resnet_kernel_size)
    else:
        return get_down_block_default(
            down_block_type,
            num_layers,
            in_channels,
            out_channels,
            temb_channels,
            add_downsample,
            resnet_eps,
            resnet_act_fn,
            attn_num_head_channels,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            downsample_padding=downsample_padding,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            # resnet_kernel_size=resnet_kernel_size
        )


class ControlLoRACrossAttnProcessor(LoRACrossAttnProcessor):

    def __init__(self,
                 hidden_size,
                 cross_attention_dim=None,
                 rank=4,
                 control_rank=None,
                 post_add=False,
                 concat_hidden=False,
                 control_channels=None,
                 control_self_add=True,
                 key_states_skipped=False,
                 value_states_skipped=False,
                 output_states_skipped=False,
                 **kwargs):
        super().__init__(
            hidden_size,
            cross_attention_dim,
            rank,
            post_add=post_add,
            key_states_skipped=key_states_skipped,
            value_states_skipped=value_states_skipped,
            output_states_skipped=output_states_skipped)

        control_rank = rank if control_rank is None else control_rank
        control_channels = hidden_size if control_channels is None else control_channels
        self.concat_hidden = concat_hidden
        self.control_self_add = control_self_add if control_channels is None else False
        self.control_states: torch.Tensor = None

        self.to_control = LoRALinearLayer(
            control_channels + (hidden_size if concat_hidden else 0),
            hidden_size, control_rank)
        self.pre_loras: List[LoRACrossAttnProcessor] = []
        self.post_loras: List[LoRACrossAttnProcessor] = []

    def inject_pre_lora(self, lora_layer):
        self.pre_loras.append(lora_layer)

    def inject_post_lora(self, lora_layer):
        self.post_loras.append(lora_layer)

    def inject_control_states(self, control_states):
        self.control_states = control_states

    def process_control_states(self, hidden_states, scale=1.0):
        control_states = self.control_states.to(hidden_states.dtype)
        if hidden_states.ndim == 3 and control_states.ndim == 4:
            batch, _, height, width = control_states.shape
            control_states = control_states.permute(0, 2, 3, 1).reshape(
                batch, height * width, -1)
            self.control_states = control_states
        _control_states = control_states
        if self.concat_hidden:
            b1, b2 = control_states.shape[0], hidden_states.shape[0]
            if b1 != b2:
                control_states = control_states[:, None].repeat(
                    1, b2 // b1, *([1] * (len(control_states.shape) - 1)))
                control_states = control_states.view(-1,
                                                     *control_states.shape[2:])
            _control_states = torch.cat([hidden_states, control_states], -1)
        _control_states = scale * self.to_control(_control_states)
        if self.control_self_add:
            control_states = control_states + _control_states
        else:
            control_states = _control_states

        return control_states
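
    # Shape sketch (illustrative, not from the original source): given a control
    # map of shape (B, C_ctrl, H, W) whose spatial size matches the attention
    # sequence, `process_control_states` flattens it to (B, H*W, C_ctrl) and
    # caches that layout. If the hidden batch is a multiple of the control batch
    # (e.g. duplicated prompts for classifier-free guidance), the control tokens
    # are repeated to match. With `concat_hidden` they are concatenated with the
    # hidden states to (B, H*W, C_ctrl + C_hid) before the low-rank `to_control`
    # projection maps them back to (B, H*W, hidden_size); with
    # `control_self_add` the original control tokens are added on top of that
    # projection.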

    def __call__(self,
                 attn: CrossAttention,
                 hidden_states,
                 encoder_hidden_states=None,
                 attention_mask=None,
                 scale=1.0):
        pre_lora: LoRACrossAttnProcessor
        post_lora: LoRACrossAttnProcessor
        assert self.control_states is not None

        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask=attention_mask,
            target_length=sequence_length,
            batch_size=batch_size)
        query = attn.to_q(hidden_states)
        for pre_lora in self.pre_loras:
            lora_in = query if pre_lora.post_add else hidden_states
            if isinstance(pre_lora, ControlLoRACrossAttnProcessor):
                lora_in = lora_in + pre_lora.process_control_states(
                    hidden_states, scale)
            query = query + scale * pre_lora.to_q_lora(lora_in)
        query = query + scale * self.to_q_lora(
            (query if self.post_add else hidden_states)
            + self.process_control_states(hidden_states, scale))
        for post_lora in self.post_loras:
            lora_in = query if post_lora.post_add else hidden_states
            if isinstance(post_lora, ControlLoRACrossAttnProcessor):
                lora_in = lora_in + post_lora.process_control_states(
                    hidden_states, scale)
            query = query + scale * post_lora.to_q_lora(lora_in)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states)
        for pre_lora in self.pre_loras:
            if not pre_lora.key_states_skipped:
                key = key + scale * pre_lora.to_k_lora(
                    key if pre_lora.post_add else encoder_hidden_states)
        if not self.key_states_skipped:
            key = key + scale * self.to_k_lora(
                key if self.post_add else encoder_hidden_states)
        for post_lora in self.post_loras:
            if not post_lora.key_states_skipped:
                key = key + scale * post_lora.to_k_lora(
                    key if post_lora.post_add else encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        for pre_lora in self.pre_loras:
            if not pre_lora.value_states_skipped:
                value = value + pre_lora.to_v_lora(
                    value if pre_lora.post_add else encoder_hidden_states)
        if not self.value_states_skipped:
            value = value + scale * self.to_v_lora(
                value if self.post_add else encoder_hidden_states)
        for post_lora in self.post_loras:
            if not post_lora.value_states_skipped:
                value = value + post_lora.to_v_lora(
                    value if post_lora.post_add else encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        out = attn.to_out[0](hidden_states)
        for pre_lora in self.pre_loras:
            if not pre_lora.output_states_skipped:
                out = out + scale * pre_lora.to_out_lora(
                    out if pre_lora.post_add else hidden_states)
        out = out + scale * self.to_out_lora(
            out if self.post_add else hidden_states)
        for post_lora in self.post_loras:
            if not post_lora.output_states_skipped:
                out = out + scale * post_lora.to_out_lora(
                    out if post_lora.post_add else hidden_states)
        hidden_states = out
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class ControlLoRACrossAttnProcessorV2(LoRACrossAttnProcessor):

    def __init__(self,
                 hidden_size,
                 cross_attention_dim=None,
                 rank=4,
                 control_rank=None,
                 control_channels=None,
                 **kwargs):
        super().__init__(
            hidden_size,
            cross_attention_dim,
            rank,
            post_add=False,
            key_states_skipped=True,
            value_states_skipped=True,
            output_states_skipped=False)

        control_rank = rank if control_rank is None else control_rank
        control_channels = hidden_size if control_channels is None else control_channels
        self.concat_hidden = True
        self.control_self_add = False
        self.control_states: torch.Tensor = None

        self.to_control = LoRALinearLayer(hidden_size + control_channels,
                                          hidden_size, control_rank)
        self.to_control_out = LoRALinearLayer(hidden_size + control_channels,
                                              hidden_size, control_rank)
        self.pre_loras: List[LoRACrossAttnProcessor] = []
        self.post_loras: List[LoRACrossAttnProcessor] = []

    def inject_pre_lora(self, lora_layer):
        self.pre_loras.append(lora_layer)

    def inject_post_lora(self, lora_layer):
        self.post_loras.append(lora_layer)

    def inject_control_states(self, control_states):
        self.control_states = control_states

    def process_control_states(self, hidden_states, scale=1.0, is_out=False):
        control_states = self.control_states.to(hidden_states.dtype)
        if hidden_states.ndim == 3 and control_states.ndim == 4:
            batch, _, height, width = control_states.shape
            control_states = control_states.permute(0, 2, 3, 1).reshape(
                batch, height * width, -1)
            self.control_states = control_states
        _control_states = control_states
        if self.concat_hidden:
            b1, b2 = control_states.shape[0], hidden_states.shape[0]
            if b1 != b2:
                control_states = control_states[:, None].repeat(
                    1, b2 // b1, *([1] * (len(control_states.shape) - 1)))
                control_states = control_states.view(-1,
                                                     *control_states.shape[2:])
            _control_states = torch.cat([hidden_states, control_states], -1)
        _control_states = scale * (self.to_control_out
                                   if is_out else self.to_control)(
                                       _control_states)
        if self.control_self_add:
            control_states = control_states + _control_states
        else:
            control_states = _control_states

        return control_states

    def __call__(self,
                 attn: CrossAttention,
                 hidden_states,
                 encoder_hidden_states=None,
                 attention_mask=None,
                 scale=1.0):
        pre_lora: LoRACrossAttnProcessor
        post_lora: LoRACrossAttnProcessor
        assert self.control_states is not None

        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask=attention_mask,
            target_length=sequence_length,
            batch_size=batch_size)
        for pre_lora in self.pre_loras:
            if isinstance(pre_lora, ControlLoRACrossAttnProcessorV2):
                hidden_states = hidden_states + pre_lora.process_control_states(
                    hidden_states, scale)
        hidden_states = hidden_states + self.process_control_states(
            hidden_states, scale)
        for post_lora in self.post_loras:
            if isinstance(post_lora, ControlLoRACrossAttnProcessorV2):
                hidden_states = hidden_states + post_lora.process_control_states(
                    hidden_states, scale)
        query = attn.to_q(hidden_states)
        for pre_lora in self.pre_loras:
            lora_in = query if pre_lora.post_add else hidden_states
            query = query + scale * pre_lora.to_q_lora(lora_in)
        query = query + scale * self.to_q_lora(
            query if self.post_add else hidden_states)
        for post_lora in self.post_loras:
            lora_in = query if post_lora.post_add else hidden_states
            query = query + scale * post_lora.to_q_lora(lora_in)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states)
        for pre_lora in self.pre_loras:
            if not pre_lora.key_states_skipped:
                key = key + scale * pre_lora.to_k_lora(
                    key if pre_lora.post_add else encoder_hidden_states)
        if not self.key_states_skipped:
            key = key + scale * self.to_k_lora(
                key if self.post_add else encoder_hidden_states)
        for post_lora in self.post_loras:
            if not post_lora.key_states_skipped:
                key = key + scale * post_lora.to_k_lora(
                    key if post_lora.post_add else encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        for pre_lora in self.pre_loras:
            if not pre_lora.value_states_skipped:
                value = value + pre_lora.to_v_lora(
                    value if pre_lora.post_add else encoder_hidden_states)
        if not self.value_states_skipped:
            value = value + scale * self.to_v_lora(
                value if self.post_add else encoder_hidden_states)
        for post_lora in self.post_loras:
            if not post_lora.value_states_skipped:
                value = value + post_lora.to_v_lora(
                    value if post_lora.post_add else encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        for pre_lora in self.pre_loras:
            if isinstance(pre_lora, ControlLoRACrossAttnProcessorV2):
                hidden_states = hidden_states + pre_lora.process_control_states(
                    hidden_states, scale, is_out=True)
        hidden_states = hidden_states + self.process_control_states(
            hidden_states, scale, is_out=True)
        for post_lora in self.post_loras:
            if isinstance(post_lora, ControlLoRACrossAttnProcessorV2):
                hidden_states = hidden_states + post_lora.process_control_states(
                    hidden_states, scale, is_out=True)
        out = attn.to_out[0](hidden_states)
        for pre_lora in self.pre_loras:
            if not pre_lora.output_states_skipped:
                out = out + scale * pre_lora.to_out_lora(
                    out if pre_lora.post_add else hidden_states)
        out = out + scale * self.to_out_lora(
            out if self.post_add else hidden_states)
        for post_lora in self.post_loras:
            if not post_lora.output_states_skipped:
                out = out + scale * post_lora.to_out_lora(
                    out if post_lora.post_add else hidden_states)
        hidden_states = out
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class ConvBlock2D(nn.Module):

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_kernel_size=3,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity='swish',
        time_embedding_norm='default',
        kernel=None,
        output_scale_factor=1.0,
        up=False,
        down=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(
            num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=conv_kernel_size,
            stride=1,
            padding=conv_kernel_size // 2)

        if temb_channels is not None:
            if self.time_embedding_norm == 'default':
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == 'scale_shift':
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(
                    f'unknown time_embedding_norm : {self.time_embedding_norm} '
                )

            self.time_emb_proj = torch.nn.Linear(temb_channels,
                                                 time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(
            num_groups=groups_out,
            num_channels=out_channels,
            eps=eps,
            affine=True)
        self.dropout = torch.nn.Dropout(dropout)

        if non_linearity == 'swish':
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == 'mish':
            self.nonlinearity = nn.Mish()
        elif non_linearity == 'silu':
            self.nonlinearity = nn.SiLU()

        self.upsample = self.downsample = None
        if self.up:
            if kernel == 'fir':
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == 'sde_vp':
                self.upsample = partial(
                    F.interpolate, scale_factor=2.0, mode='nearest')
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == 'fir':
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == 'sde_vp':
                self.downsample = partial(
                    F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(
                    in_channels, use_conv=False, padding=1, name='op')

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes.
            # see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            _ = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            _ = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None,
                                                               None]

        if temb is not None and self.time_embedding_norm == 'default':
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == 'scale_shift':
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        output_tensor = self.dropout(hidden_states)

        return output_tensor
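

# Illustrative note (not part of the original source): ConvBlock2D mirrors the
# structure of diffusers' ResnetBlock2D but drops the residual shortcut, so
# with `temb_channels=None` it reduces to norm -> act -> (optional resample) ->
# conv -> norm -> act -> dropout. A quick shape check, kept as a comment:
#
#   block = ConvBlock2D(in_channels=32, out_channels=64, temb_channels=None)
#   y = block(torch.randn(2, 32, 16, 16), temb=None)   # -> (2, 64, 16, 16)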


class SimpleDownEncoderBlock2D(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        convnet_eps: float = 1e-6,
        convnet_time_scale_shift: str = 'default',
        convnet_act_fn: str = 'swish',
        convnet_groups: int = 32,
        convnet_pre_norm: bool = True,
        convnet_kernel_size: int = 3,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        convnets = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            convnets.append(
                ConvBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=convnet_eps,
                    groups=convnet_groups,
                    dropout=dropout,
                    time_embedding_norm=convnet_time_scale_shift,
                    non_linearity=convnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=convnet_pre_norm,
                    conv_kernel_size=convnet_kernel_size,
                ))
        in_channels = in_channels if num_layers == 0 else out_channels

        self.convnets = nn.ModuleList(convnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList([
                Downsample2D(
                    in_channels,
                    use_conv=True,
                    out_channels=out_channels,
                    padding=downsample_padding,
                    name='op')
            ])
        else:
            self.downsamplers = None

    def forward(self, hidden_states):
        for convnet in self.convnets:
            hidden_states = convnet(hidden_states, temb=None)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states
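

# A hypothetical smoke test for the encoder block above, kept as a comment so
# importing this module stays side-effect free:
#
#   block = SimpleDownEncoderBlock2D(
#       in_channels=3, out_channels=32, convnet_groups=1)
#   out = block(torch.randn(1, 3, 64, 64))   # -> (1, 32, 32, 32)
#
# `convnet_groups=1` is only needed because GroupNorm requires the number of
# input channels (here 3) to be divisible by the group count; inside the tuner
# the default channel counts are all multiples of `norm_num_groups`.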