This commit is contained in:
XDUWQ
2023-08-31 10:40:59 +08:00
parent d1478db172
commit c217d29309
2 changed files with 7 additions and 5 deletions

View File

@@ -10,7 +10,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.cross_attention import CrossAttention, LoRALinearLayer
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRALinearLayer
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.resnet import (Downsample2D, Upsample2D, downsample_2d,
partial, upsample_2d)
@@ -467,7 +468,7 @@ class ControlLoRACrossAttnProcessor(LoRACrossAttnProcessor):
return control_states
def __call__(self,
attn: CrossAttention,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
@@ -619,7 +620,7 @@ class ControlLoRACrossAttnProcessorV2(LoRACrossAttnProcessor):
return control_states
def __call__(self,
attn: CrossAttention,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,

View File

@@ -8,7 +8,8 @@ from typing import List, Tuple, Union
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.cross_attention import CrossAttention, LoRALinearLayer
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRALinearLayer
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.outputs import BaseOutput
@@ -84,7 +85,7 @@ class LoRACrossAttnProcessor(nn.Module):
self.output_states_skipped = is_skipped
def __call__(self,
attn: CrossAttention,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,