Merge branch 'master' of gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib into master-gitlab

mulin.lyh
2023-09-11 13:55:34 +08:00
26 changed files with 7928 additions and 9 deletions


@@ -48,7 +48,7 @@ ENV SETUPTOOLS_USE_DISTUTILS=stdlib
RUN CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" pip install --no-cache-dir 'git+https://github.com/facebookresearch/detectron2.git'
# torchmetrics==0.11.4 for ofa
-RUN pip install --no-cache-dir tiktoken torchmetrics==0.11.4 'transformers<4.31.0' transformers_stream_generator 'protobuf<=3.20.0' bitsandbytes basicsr
+RUN pip install --no-cache-dir tiktoken torchmetrics==0.11.4 transformers_stream_generator 'protobuf<=3.20.0' bitsandbytes basicsr
COPY docker/scripts/install_flash_attension.sh /tmp/install_flash_attension.sh
RUN if [ "$USE_GPU" = "True" ] ; then \
bash /tmp/install_flash_attension.sh; \


@@ -339,6 +339,7 @@ class Pipelines(object):
image_colorization = 'unet-image-colorization'
image_style_transfer = 'AAMS-style-transfer'
image_super_resolution = 'rrdb-image-super-resolution'
+image_super_resolution_pasd = 'image-super-resolution-pasd'
image_debanding = 'rrdb-image-debanding'
face_image_generation = 'gan-face-image-generation'
product_retrieval_embedding = 'resnet50-product-retrieval-embedding'
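
For context, a hedged sketch of how a pipeline registered under this name is typically constructed through the ModelScope factory; the task string and model id below are placeholders, not values taken from this commit.

from modelscope.pipelines import pipeline

# Placeholders: the task this pipeline is registered under and a published
# PASD model id are not visible in this diff. pipeline() normally resolves
# the concrete 'image-super-resolution-pasd' implementation from the
# selected model's own configuration.
pasd = pipeline('image-super-resolution', model='<pasd-model-id>')
result = pasd('low_res.png')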


@@ -14,10 +14,11 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
image_quality_assessment_degradation,
image_quality_assessment_man, image_quality_assessment_mos,
image_reid_person, image_restoration,
-image_semantic_segmentation, image_to_image_generation,
-image_to_image_translation, language_guided_video_summarization,
-movie_scene_segmentation, object_detection,
-panorama_depth_estimation, pedestrian_attribute_recognition,
+image_semantic_segmentation, image_super_resolution_pasd,
+image_to_image_generation, image_to_image_translation,
+language_guided_video_summarization, movie_scene_segmentation,
+object_detection, panorama_depth_estimation,
+pedestrian_attribute_recognition,
pointcloud_sceneflow_estimation, product_retrieval_embedding,
referring_video_object_segmentation,
robust_image_classification, salient_detection,


@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .unet_2d_condition import UNet2DConditionModel
from .controlnet import ControlNetModel
else:
_import_structure = {
'unet_2d_condition': ['UNet2DConditionModel'],
'controlnet': ['ControlNetModel']
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
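
A minimal sketch of what this lazy `__init__` enables; the package path is an assumption, since the diff header does not show file names.

# Hypothetical import path for the new PASD sub-package; the actual location
# of this __init__.py is not shown in the diff.
from modelscope.models.cv.image_super_resolution_pasd import (
    ControlNetModel, UNet2DConditionModel)

# LazyImportModule defers importing controlnet.py / unet_2d_condition.py
# until one of the exported names is first accessed.
print(ControlNetModel.__name__, UNet2DConditionModel.__name__)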


@@ -0,0 +1,421 @@
# Part of the implementation is borrowed and modified from diffusers,
# publicly available at https://github.com/huggingface/diffusers/tree/main/src/diffusers/models/attention.py
import math
from typing import Callable, Optional
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
from diffusers.utils import maybe_allow_in_graph
from torch import nn
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (`int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (`bool`, *optional*, defaults to `False`):
Configure if the attentions should contain a bias parameter.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
pixelwise_cross_attention_dim: Optional[int] = None,
activation_fn: str = 'geglu',
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = 'layer_norm',
final_dropout: bool = False,
use_pixelwise_attention: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (
num_embeds_ada_norm is not None) and norm_type == 'ada_norm_zero'
self.use_ada_layer_norm = (num_embeds_ada_norm
is not None) and norm_type == 'ada_norm'
if norm_type in ('ada_norm',
'ada_norm_zero') and num_embeds_ada_norm is None:
raise ValueError(
f'`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to'
f' define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}.'
)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim
if only_cross_attention else None,
upcast_attention=upcast_attention,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm else nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine))
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim
if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 2+. pixelwise-Attn
if use_pixelwise_attention:
self.norm2_plus = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm else nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine))
self.attn2_plus = Attention(
query_dim=dim,
cross_attention_dim=pixelwise_cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2_plus = None
self.attn2_plus = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout)
def forward(
self,
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_pixelwise_hidden_states=None,
encoder_pixelwise_attention_mask=None,
timestep=None,
cross_attention_kwargs=None,
class_labels=None,
):
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states,
timestep,
class_labels,
hidden_dtype=hidden_states.dtype)
else:
norm_hidden_states = self.norm1(hidden_states)
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states
if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep)
if self.use_ada_layer_norm else self.norm2(hidden_states))
# TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
# prepare attention mask here
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 2+. pixelwise-Attention
if self.attn2_plus is not None:
norm_hidden_states = (
self.norm2_plus(hidden_states, timestep)
if self.use_ada_layer_norm else self.norm2_plus(hidden_states))
attn_output = self.attn2_plus(
norm_hidden_states,
encoder_hidden_states=encoder_pixelwise_hidden_states,
attention_mask=encoder_pixelwise_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (
1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = ff_output + hidden_states
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = 'geglu',
final_dropout: bool = False,
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
if activation_fn == 'gelu':
act_fn = GELU(dim, inner_dim)
if activation_fn == 'gelu-approximate':
act_fn = GELU(dim, inner_dim, approximate='tanh')
elif activation_fn == 'geglu':
act_fn = GEGLU(dim, inner_dim)
elif activation_fn == 'geglu-approximate':
act_fn = ApproximateGELU(dim, inner_dim)
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class GELU(nn.Module):
r"""
GELU activation function with tanh approximation support with `approximate="tanh"`.
"""
def __init__(self, dim_in: int, dim_out: int, approximate: str = 'none'):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
self.approximate = approximate
def gelu(self, gate):
if gate.device.type != 'mps':
return F.gelu(gate, approximate=self.approximate)
# mps: gelu is not implemented for float16
return F.gelu(
gate.to(dtype=torch.float32),
approximate=self.approximate).to(dtype=gate.dtype)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states
class GEGLU(nn.Module):
r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def gelu(self, gate):
if gate.device.type != 'mps':
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
class ApproximateGELU(nn.Module):
"""
The approximate form of Gaussian Error Linear Unit (GELU)
For more details, see section 2: https://arxiv.org/abs/1606.08415
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
def forward(self, x):
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
class AdaLayerNorm(nn.Module):
"""
Norm layer modified to incorporate timestep embeddings.
"""
def __init__(self, embedding_dim, num_embeddings):
super().__init__()
self.emb = nn.Embedding(num_embeddings, embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
def forward(self, x, timestep):
emb = self.linear(self.silu(self.emb(timestep)))
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x
class AdaLayerNormZero(nn.Module):
"""
Norm layer adaptive layer norm zero (adaLN-Zero).
"""
def __init__(self, embedding_dim, num_embeddings):
super().__init__()
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings,
embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
self.norm = nn.LayerNorm(
embedding_dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, timestep, class_labels, hidden_dtype=None):
emb = self.linear(
self.silu(
self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(
6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaGroupNorm(nn.Module):
"""
GroupNorm layer modified to incorporate timestep embeddings.
"""
def __init__(self,
embedding_dim: int,
out_dim: int,
num_groups: int,
act_fn: Optional[str] = None,
eps: float = 1e-5):
super().__init__()
self.num_groups = num_groups
self.eps = eps
self.act = None
if act_fn == 'swish':
self.act = lambda x: F.silu(x)
elif act_fn == 'mish':
self.act = nn.Mish()
elif act_fn == 'silu':
self.act = nn.SiLU()
elif act_fn == 'gelu':
self.act = nn.GELU()
self.linear = nn.Linear(embedding_dim, out_dim * 2)
def forward(self, x, emb):
if self.act:
emb = self.act(emb)
emb = self.linear(emb)
emb = emb[:, :, None, None]
scale, shift = emb.chunk(2, dim=1)
x = F.group_norm(x, self.num_groups, eps=self.eps)
x = x * (1 + scale) + shift
return x
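
A minimal smoke-test sketch for the block defined above; the sizes are illustrative assumptions, and `BasicTransformerBlock` is assumed importable from this module.

import torch

# Illustrative sizes only; not taken from any model config in this commit.
dim, heads, head_dim = 320, 8, 40
block = BasicTransformerBlock(
    dim=dim,
    num_attention_heads=heads,
    attention_head_dim=head_dim,
    cross_attention_dim=768,            # text-encoder width (assumed)
    pixelwise_cross_attention_dim=dim,  # width of the extra pixelwise context
    use_pixelwise_attention=True,
)

hidden = torch.randn(2, 64 * 64, dim)   # (batch, tokens, dim)
text_ctx = torch.randn(2, 77, 768)      # context for attn2 (cross-attention)
pix_ctx = torch.randn(2, 64 * 64, dim)  # context for attn2_plus

out = block(
    hidden,
    encoder_hidden_states=text_ctx,
    encoder_pixelwise_hidden_states=pix_ctx,
)
assert out.shape == hidden.shape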


@@ -0,0 +1,738 @@
# Part of the implementation is borrowed and modified from diffusers,
# publicly available at https://github.com/huggingface/diffusers/tree/main/src/diffusers/models/controlnet.py
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin # , ControlNetModel
from diffusers.models.attention_processor import (AttentionProcessor,
AttnProcessor)
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.utils import BaseOutput, logging
from torch import nn
from torch.nn import functional as F
from torchvision import utils
from modelscope.models.cv.super_resolution.rrdbnet_arch import RRDB
from .unet_2d_blocks import (CrossAttnDownBlock2D, DownBlock2D,
UNetMidBlock2DCrossAttn, get_down_block)
from .unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class ControlNetOutput(BaseOutput):
controlnet_cond_mid: torch.Tensor
down_block_res_samples: Tuple[torch.Tensor]
mid_block_res_sample: torch.Tensor
class ControlNetConditioningEmbedding(nn.Module):
"""
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
model) to encode image-space conditions ... into feature maps ..."
"""
def __init__(
self,
conditioning_embedding_channels: int,
conditioning_channels: int = 3,
block_out_channels: Tuple[int] = (16, 32, 96, 256),
return_rgbs: bool = True,
use_rrdb: bool = False,
):
super().__init__()
self.return_rgbs = return_rgbs
self.use_rrdb = use_rrdb
self.conv_in = nn.Conv2d(
conditioning_channels,
block_out_channels[0],
kernel_size=3,
padding=1)
if self.use_rrdb:
num_rrdb_block = 2
layers = (
RRDB(block_out_channels[0], block_out_channels[0])
for i in range(num_rrdb_block))
self.preprocesser = nn.Sequential(*layers)
self.blocks = nn.ModuleList([])
if return_rgbs:
self.to_rgbs = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
channel_in = block_out_channels[i]
channel_out = block_out_channels[i + 1]
self.blocks.append(
nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(
nn.Conv2d(
channel_in,
channel_out,
kernel_size=3,
padding=1,
stride=2))
if return_rgbs:
self.to_rgbs.append(
nn.Conv2d(channel_out, 3, kernel_size=3, padding=1))
self.conv_out = zero_module(
nn.Conv2d(
block_out_channels[-1],
conditioning_embedding_channels,
kernel_size=3,
padding=1))
def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)
if self.use_rrdb:
embedding = self.preprocesser(embedding)
out_rgbs = []
for i, block in enumerate(self.blocks):
embedding = block(embedding)
embedding = F.silu(embedding)
if i % 2 and self.return_rgbs: # 0
out_rgbs.append(self.to_rgbs[i // 2](embedding))
embedding = self.conv_out(embedding)
if self.return_rgbs:
return embedding, out_rgbs
else:
return embedding
class ControlNetModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 4,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
'CrossAttnDownBlock2D',
'CrossAttnDownBlock2D',
'CrossAttnDownBlock2D',
'DownBlock2D',
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = 'silu',
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = 'default',
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = 'rgb',
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16,
32,
96,
256),
global_pool_conditions: bool = False,
return_rgbs: bool = False,
use_rrdb: bool = False):
super().__init__()
# Check inputs
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f'Must provide the same number of `block_out_channels` as `down_block_types`. \
`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.'
)
if not isinstance(
only_cross_attention,
bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f'Must provide the same number of `only_cross_attention` as `down_block_types`. \
`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.'
)
if not isinstance(
attention_head_dim,
int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f'Must provide the same number of `attention_head_dim` as `down_block_types`. \
`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.'
)
# input
self.return_rgbs = return_rgbs
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=conv_in_kernel,
padding=conv_in_padding)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos,
freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds,
time_embed_dim)
elif class_embed_type == 'timestep':
self.class_embedding = TimestepEmbedding(timestep_input_dim,
time_embed_dim)
elif class_embed_type == 'identity':
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == 'projection':
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(
projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
return_rgbs=return_rgbs,
use_rrdb=use_rrdb,
)
self.down_blocks = nn.ModuleList([])
self.controlnet_down_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention
] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim, ) * len(down_block_types)
# down
output_channel = block_out_channels[0]
controlnet_block = nn.Conv2d(
output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
controlnet_block = nn.Conv2d(
output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
if not is_final_block:
controlnet_block = nn.Conv2d(
output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
# mid
mid_block_channel = block_out_channels[-1]
controlnet_block = nn.Conv2d(
mid_block_channel, mid_block_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_mid_block = controlnet_block
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
controlnet_conditioning_channel_order: str = 'rgb',
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32,
96, 256),
load_weights_from_unet: bool = True,
):
r"""
Instantiate Controlnet class from UNet2DConditionModel.
Parameters:
unet (`UNet2DConditionModel`):
UNet model whose weights are copied to the ControlNet. Note that all configuration options are also
copied where applicable.
"""
controlnet = cls(
in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
down_block_types=unet.config.down_block_types,
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
downsample_padding=unet.config.downsample_padding,
mid_block_scale_factor=unet.config.mid_block_scale_factor,
act_fn=unet.config.act_fn,
norm_num_groups=unet.config.norm_num_groups,
norm_eps=unet.config.norm_eps,
cross_attention_dim=unet.config.cross_attention_dim,
attention_head_dim=unet.config.attention_head_dim,
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
num_class_embeds=unet.config.num_class_embeds,
upcast_attention=unet.config.upcast_attention,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.
projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order=
controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=
conditioning_embedding_out_channels,
)
if load_weights_from_unet:
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
controlnet.time_embedding.load_state_dict(
unet.time_embedding.state_dict())
if controlnet.class_embedding:
controlnet.class_embedding.load_state_dict(
unet.class_embedding.state_dict())
controlnet.down_blocks.load_state_dict(
unet.down_blocks.state_dict())
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
# `sr_model` is not defined in this class's __init__; guard the lookup so
# `from_unet` still works for models without a super-resolution branch.
if getattr(controlnet, 'sr_model', None) is not None:
load_net = torch.load(
'annotator/ckpts/RealESRNet_x4plus.pth',
map_location=lambda storage, loc: storage)
if 'params_ema' in load_net:
load_net = load_net['params_ema']
elif 'params' in load_net:
load_net = load_net['params']
# remove unnecessary 'module.'
for k, v in deepcopy(load_net).items():
if k.startswith('module.'):
load_net[k[7:]] = v
load_net.pop(k)
controlnet.sr_model.load_state_dict(load_net, strict=True)
return controlnet
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module,
processors: Dict[str,
AttentionProcessor]):
if hasattr(module, 'set_processor'):
processors[f'{name}.processor'] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f'{name}.{sub_name}', child,
processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor,
Dict[str,
AttentionProcessor]]):
r"""
Parameters:
processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to
the corresponding cross-attention processor.
This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f'A dict of processors was passed, but the number of processors {len(processor)} does not match the'
f' number of attention layers: {count}. Please make sure to pass {count} processor classes.'
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module,
processor):
if hasattr(module, 'set_processor'):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f'{name}.processor'))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f'{name}.{sub_name}', child,
processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(AttnProcessor())
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor into slices to compute attention
in several steps. This is useful for saving some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, 'set_attention_slice'):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == 'auto':
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == 'max':
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(
slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f'You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different'
f' attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.'
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(
f'size {size} has to be smaller or equal to {dim}.')
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module,
slice_size: List[int]):
if hasattr(module, 'set_attention_slice'):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.FloatTensor,
fg_mask: Optional[torch.FloatTensor] = None,
conditioning_scale_fg: float = 1.0,
conditioning_scale_bg: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == 'rgb':
# in rgb order by default
...
elif channel_order == 'bgr':
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
else:
raise ValueError(
f'unknown `controlnet_conditioning_channel_order`: {channel_order}'
)
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == 'mps'
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps],
dtype=dtype,
device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError(
'class_labels should be provided when num_class_embeds > 0'
)
if self.config.class_embed_type == 'timestep':
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond_mid = None
if self.return_rgbs:
controlnet_cond, controlnet_cond_mid = self.controlnet_cond_embedding(
controlnet_cond)
else:
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
# 3. down
down_block_res_samples = (sample, )
for downsample_block in self.down_blocks:
if hasattr(downsample_block, 'has_cross_attention'
) and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(
hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
# 5. Control net blocks
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(
down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (
down_block_res_sample, )
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(
-1, 0, len(down_block_res_samples) + 1,
device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale_fg
down_block_res_samples = [
sample * scale
for sample, scale in zip(down_block_res_samples, scales)
]
mid_block_res_sample = mid_block_res_sample * scales[
-1] # last one
else:
if fg_mask is None:
down_block_res_samples = [
sample * conditioning_scale_fg
for sample in down_block_res_samples
]
mid_block_res_sample = mid_block_res_sample * conditioning_scale_fg
else:
down_block_masks = [
torch.zeros_like(sample) + conditioning_scale_bg
for i, sample in enumerate(down_block_res_samples)
]
mid_block_mask = torch.zeros_like(
mid_block_res_sample) + conditioning_scale_bg
for i, sample in enumerate(down_block_masks):
tmp_mask = F.interpolate(
fg_mask,
size=sample.shape[-2:]).repeat(sample.shape[0],
sample.shape[1], 1,
1).bool()
down_block_masks[i] = sample.masked_fill(
tmp_mask, conditioning_scale_fg)
tmp_mask = F.interpolate(
fg_mask, size=mid_block_mask.shape[-2:]).repeat(
mid_block_mask.shape[0], mid_block_mask.shape[1], 1,
1).bool()
mid_block_mask = mid_block_mask.masked_fill(
tmp_mask, conditioning_scale_fg)
down_block_res_samples = [
sample * down_block_mask for sample, down_block_mask in
zip(down_block_res_samples, down_block_masks)
]
mid_block_res_sample = mid_block_res_sample * mid_block_mask
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True)
for sample in down_block_res_samples
]
mid_block_res_sample = torch.mean(
mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
return (controlnet_cond_mid, down_block_res_samples,
mid_block_res_sample)
return ControlNetOutput(
controlnet_cond_mid=controlnet_cond_mid,
down_block_res_samples=down_block_res_samples,
mid_block_res_sample=mid_block_res_sample)
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
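
A hedged smoke-test sketch of this ControlNet variant with its default configuration; the tensor shapes are illustrative assumptions (a 64x64 latent, a 512x512 RGB condition, and a 1280-wide text context matching the default `cross_attention_dim`), and it assumes `ControlNetModel` above plus the sibling `unet_2d_blocks` module are importable.

import torch

controlnet = ControlNetModel()        # default SD-style configuration
controlnet.eval()

latents = torch.randn(1, 4, 64, 64)   # noisy latent sample
cond = torch.randn(1, 3, 512, 512)    # image-space condition (RGB order)
text_ctx = torch.randn(1, 77, 1280)   # matches the default cross_attention_dim

with torch.no_grad():
    out = controlnet(
        sample=latents,
        timestep=torch.tensor([500]),
        encoder_hidden_states=text_ctx,
        controlnet_cond=cond,
        conditioning_scale_fg=1.0,
        conditioning_scale_bg=1.0,
    )

# The paired UNet consumes these residuals; controlnet_cond_mid stays None
# unless the model is built with return_rgbs=True.
down_residuals = out.down_block_res_samples
mid_residual = out.mid_block_res_sample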


@@ -0,0 +1,159 @@
# Part of the implementation is borrowed and modified from StableSR,
# publicly available at https://github.com/IceClear/StableSR/blob/main/scripts/wavelet_color_fix.py
import torch
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint)
from PIL import Image
from safetensors import safe_open
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import ToPILImage, ToTensor
def adain_color_fix(target: Image, source: Image):
# Convert images to tensors
to_tensor = ToTensor()
target_tensor = to_tensor(target).unsqueeze(0)
source_tensor = to_tensor(source).unsqueeze(0)
# Apply adaptive instance normalization
result_tensor = adaptive_instance_normalization(target_tensor,
source_tensor)
# Convert tensor back to image
to_image = ToPILImage()
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
return result_image
def wavelet_color_fix(target: Image, source: Image):
# Convert images to tensors
to_tensor = ToTensor()
target_tensor = to_tensor(target).unsqueeze(0)
source_tensor = to_tensor(source).unsqueeze(0)
# Apply wavelet reconstruction
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
# Convert tensor back to image
to_image = ToPILImage()
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
return result_image
def calc_mean_std(feat: Tensor, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
b, c = size[:2]
feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
"""Adaptive instance normalization.
Adjust the reference features to have color and illumination similar
to those in the degraded features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degraded features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat
- content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
def wavelet_blur(image: Tensor, radius: int):
"""
Apply wavelet blur to the input tensor.
"""
# input shape: (1, 3, H, W)
# convolution kernel
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
# add channel dimensions to the kernel to make it a 4D tensor
kernel = kernel[None, None]
# repeat the kernel across all input channels
kernel = kernel.repeat(3, 1, 1, 1)
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
# apply convolution
output = F.conv2d(image, kernel, groups=3, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels=5):
"""
Apply wavelet decomposition to the input tensor.
This function only returns the low frequency & the high frequency.
"""
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2**i
low_freq = wavelet_blur(image, radius)
high_freq += (image - low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
"""
Apply wavelet decomposition, so that the content will have the same color as the style.
"""
# calculate the wavelet decomposition of the content feature
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq
# calculate the wavelet decomposition of the style feature
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq
# reconstruct the content feature with the style's high frequency
return content_high_freq + style_low_freq
def load_dreambooth_lora(unet, vae=None, model_path=None, model_base=''):
if model_path is None:
return unet
if model_path.endswith('.ckpt'):
base_state_dict = torch.load(model_path)['state_dict']
elif model_path.endswith('.safetensors'):
state_dict = {}
with safe_open(model_path, framework='pt', device='cpu') as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
is_lora = all('lora' in k for k in state_dict.keys())
if not is_lora:
base_state_dict = state_dict
else:
base_state_dict = {}
with safe_open(model_base, framework='pt', device='cpu') as f:
for key in f.keys():
base_state_dict[key] = f.get_tensor(key)
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
base_state_dict, unet.config)
unet.load_state_dict(converted_unet_checkpoint, strict=False)
if vae is not None:
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
base_state_dict, vae.config)
vae.load_state_dict(converted_vae_checkpoint)
return unet, vae
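
A minimal usage sketch for the color-fix helpers above; the file names are placeholders and the functions are assumed importable from this module.

from PIL import Image

# Placeholder paths: 'restored.png' is a PASD output, 'input.png' the
# original low-resolution image upscaled to the same size.
restored = Image.open('restored.png').convert('RGB')
reference = Image.open('input.png').convert('RGB').resize(restored.size)

# Transfer the low-frequency color/illumination of the reference onto the
# restored image while keeping the restored high-frequency detail.
fixed = wavelet_color_fix(restored, reference)
fixed.save('restored_colorfix.png')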


@@ -0,0 +1,365 @@
# Part of the implementation is borrowed and modified from diffusers,
# publicly available at https://github.com/huggingface/diffusers/tree/main/src/diffusers/models/transformer_2d.py
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed
from diffusers.utils import BaseOutput, deprecate
from torch import nn
from .attention import BasicTransformerBlock
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or
`(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
for the unnoised latent pixels.
"""
sample: torch.FloatTensor
class Transformer2DModel(ModelMixin, ConfigMixin):
"""
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
embeddings) inputs.
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
transformer action. Finally, reshape to image.
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
classes of unnoised image.
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than `num_embeds_ada_norm` steps.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
pixelwise_cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = 'geglu',
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = 'layer_norm',
norm_elementwise_affine: bool = True,
use_pixelwise_attention: bool = False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# 1. Transformer2DModel can process both standard continuous images of shape
# `(batch_size, num_channels, width, height)`
# as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels
is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
self.is_input_patches = in_channels is not None and patch_size is not None
if norm_type == 'layer_norm' and num_embeds_ada_norm is not None:
deprecation_message = (
f'The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or'
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
' Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect'
' results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it'
' would be very nice if you could open a Pull request for the `transformer/config.json` file'
)
deprecate(
'norm_type!=num_embeds_ada_norm',
'1.0.0',
deprecation_message,
standard_warn=False)
norm_type = 'ada_norm'
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f'Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make'
' sure that either `in_channels` or `num_vector_embeds` is None.'
)
elif self.is_input_vectorized and self.is_input_patches:
raise ValueError(
f'Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make'
' sure that either `num_vector_embeds` or `num_patches` is None.'
)
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
raise ValueError(
f'Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:'
f' {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None.'
)
# 2. Define input layers
if self.is_input_continuous:
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(
num_groups=norm_num_groups,
num_channels=in_channels,
eps=1e-6,
affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
if use_pixelwise_attention:
self.proj_in_plus = nn.Linear(
pixelwise_cross_attention_dim,
pixelwise_cross_attention_dim)
else:
self.proj_in_plus = None
else:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
if use_pixelwise_attention:
self.proj_in_plus = nn.Conv2d(
pixelwise_cross_attention_dim,
pixelwise_cross_attention_dim,
kernel_size=1,
stride=1,
padding=0)
else:
self.proj_in_plus = None
elif self.is_input_vectorized:
assert sample_size is not None, 'Transformer2DModel over discrete input must provide sample_size'
assert num_vector_embeds is not None, 'Transformer2DModel over discrete input must provide num_embed'
self.height = sample_size
self.width = sample_size
self.num_vector_embeds = num_vector_embeds
self.num_latent_pixels = self.height * self.width
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=num_vector_embeds,
embed_dim=inner_dim,
height=self.height,
width=self.width)
elif self.is_input_patches:
assert sample_size is not None, 'Transformer2DModel over patched input must provide sample_size'
self.height = sample_size
self.width = sample_size
self.patch_size = patch_size
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
pixelwise_cross_attention_dim=pixelwise_cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
use_pixelwise_attention=use_pixelwise_attention,
) for d in range(num_layers)
])
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels)
else:
self.proj_out = nn.Conv2d(
inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
elif self.is_input_patches:
self.norm_out = nn.LayerNorm(
inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(
inner_dim, patch_size * patch_size * self.out_channels)
def forward(
self,
hidden_states,
encoder_hidden_states=None,
encoder_pixelwise_hidden_states=None,
timestep=None,
class_labels=None,
cross_attention_kwargs=None,
return_dict: bool = True,
):
"""
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, or
`torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input hidden_states.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
conditioning.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# 1. Input
if self.is_input_continuous:
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
batch, height * width, inner_dim)
if self.proj_in_plus is not None:
encoder_pixelwise_hidden_states = self.proj_in_plus(
encoder_pixelwise_hidden_states)
pixelwise_inner_dim = encoder_pixelwise_hidden_states.shape[
1]
encoder_pixelwise_hidden_states = encoder_pixelwise_hidden_states.permute(
0, 2, 3, 1).reshape(batch, height * width,
pixelwise_inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
if self.proj_in_plus is not None:
pixelwise_inner_dim = encoder_pixelwise_hidden_states.shape[
1]
encoder_pixelwise_hidden_states = encoder_pixelwise_hidden_states.permute(
0, 2, 3, 1).reshape(batch, height * width,
pixelwise_inner_dim)
encoder_pixelwise_hidden_states = self.proj_in_plus(
encoder_pixelwise_hidden_states)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
hidden_states = self.pos_embed(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_pixelwise_hidden_states=encoder_pixelwise_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width,
inner_dim).permute(
0, 3, 1,
2).contiguous()
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width,
inner_dim).permute(
0, 3, 1,
2).contiguous()
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
logits = self.out(hidden_states)
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
elif self.is_input_patches:
# TODO: cleanup!
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(
2, dim=1)
hidden_states = self.norm_out(hidden_states) * (
1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
# unpatchify
height = width = int(hidden_states.shape[1]**0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size,
self.out_channels))
hidden_states = torch.einsum('nhwpqc->nchpwq', hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size,
width * self.patch_size))
if not return_dict:
return (output, )
return Transformer2DModelOutput(sample=output)
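The unpatchify step above packs patch tokens back into an image grid via an einsum; as a standalone illustration with dummy shapes (none of the sizes below come from the model), the same rearrangement looks like this:

import torch

# hypothetical sizes: 2 samples, a 4 x 4 grid of 2 x 2 patches, 4 output channels
batch, patch_size, out_channels = 2, 2, 4
height = width = 4
tokens = torch.randn(batch, height * width, patch_size * patch_size * out_channels)
x = tokens.reshape(-1, height, width, patch_size, patch_size, out_channels)
x = torch.einsum('nhwpqc->nchpwq', x)
image = x.reshape(-1, out_channels, height * patch_size, width * patch_size)
print(image.shape)  # torch.Size([2, 4, 8, 8])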

File diff suppressed because it is too large

View File

@@ -0,0 +1,887 @@
# Part of the implementation is borrowed and modified from diffusers,
# publicly available at https://github.com/huggingface/diffusers/tree/main/src/diffusers/models/unet_2d_condition.py
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.models import ModelMixin
from diffusers.models.attention_processor import (AttentionProcessor,
AttnProcessor)
from diffusers.models.embeddings import (GaussianFourierProjection,
TextTimeEmbedding, TimestepEmbedding,
Timesteps)
from diffusers.utils import BaseOutput, logging
from .unet_2d_blocks import (CrossAttnDownBlock2D, CrossAttnUpBlock2D,
DownBlock2D, UNetMidBlock2DCrossAttn,
UNetMidBlock2DSimpleCrossAttn, UpBlock2D,
get_down_block, get_up_block)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class UNet2DConditionOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: torch.FloatTensor
class UNet2DConditionModel(ModelMixin, ConfigMixin,
UNet2DConditionLoadersMixin):
r"""
UNet2DConditionModel is a conditional 2D UNet that takes a noisy sample, a conditional state, and a timestep
and returns a sample-shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
mid block layer if `None`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D",
"CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
The tuple of upsample blocks to use.
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
Whether to include self-attention in the basic transformer blocks, see
[`~models.attention.BasicTransformerBlock`].
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, it will skip the normalization and activation layers in post-processing
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
encoder_hid_dim (`int`, *optional*, defaults to None):
If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to None):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, default to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
time_embedding_dim (`int`, *optional*, default to `None`):
An optional override for the dimension of the projected time embedding.
time_embedding_act_fn (`str`, *optional*, default to `None`):
Optional activation function applied to the time embeddings once, before they are passed to the rest
of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
timestep_post_act (`str`, *optional*, default to `None`):
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, default to `None`):
The dimension of `cond_proj` layer in timestep embedding.
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
`only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
default to `False`.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
'CrossAttnDownBlock2D',
'CrossAttnDownBlock2D',
'CrossAttnDownBlock2D',
'DownBlock2D',
),
mid_block_type: Optional[str] = 'UNetMidBlock2DCrossAttn',
up_block_types: Tuple[str] = ('UpBlock2D', 'CrossAttnUpBlock2D',
'CrossAttnUpBlock2D',
'CrossAttnUpBlock2D'),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = 'silu',
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
encoder_hid_dim: Optional[int] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = 'default',
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = 'positional',
time_embedding_dim: Optional[int] = None,
time_embedding_act_fn: Optional[str] = None,
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None,
addition_embed_type_num_heads=64,
):
super().__init__()
self.sample_size = sample_size
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f'Must provide the same number of `down_block_types` as `up_block_types`. \
`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.'
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f'Must provide the same number of `block_out_channels` as `down_block_types`. \
`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.'
)
if not isinstance(
only_cross_attention,
bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f'Must provide the same number of `only_cross_attention` as `down_block_types`. \
`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.'
)
if not isinstance(
attention_head_dim,
int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f'Must provide the same number of `attention_head_dim` as `down_block_types`. \
`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.'
)
if isinstance(
cross_attention_dim,
list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f'Must provide the same number of `cross_attention_dim` as `down_block_types`. \
`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.'
)
if not isinstance(
layers_per_block,
int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f'Must provide the same number of `layers_per_block` as `down_block_types`. \
`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}.'
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=conv_in_kernel,
padding=conv_in_padding)
# time
if time_embedding_type == 'fourier':
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
if time_embed_dim % 2 != 0:
raise ValueError(
f'`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.'
)
self.time_proj = GaussianFourierProjection(
time_embed_dim // 2,
set_W_to_weight=False,
log=False,
flip_sin_to_cos=flip_sin_to_cos)
timestep_input_dim = time_embed_dim
elif time_embedding_type == 'positional':
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos,
freq_shift)
timestep_input_dim = block_out_channels[0]
else:
raise ValueError(
f'{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`.'
)
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
post_act_fn=timestep_post_act,
cond_proj_dim=time_cond_proj_dim,
)
if encoder_hid_dim is not None:
self.encoder_hid_proj = nn.Linear(encoder_hid_dim,
cross_attention_dim)
else:
self.encoder_hid_proj = None
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds,
time_embed_dim)
elif class_embed_type == 'timestep':
self.class_embedding = TimestepEmbedding(
timestep_input_dim, time_embed_dim, act_fn=act_fn)
elif class_embed_type == 'identity':
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == 'projection':
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(
projection_class_embeddings_input_dim, time_embed_dim)
elif class_embed_type == 'simple_projection':
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
)
self.class_embedding = nn.Linear(
projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
if addition_embed_type == 'text':
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim,
time_embed_dim,
num_heads=addition_embed_type_num_heads)
elif addition_embed_type is not None:
raise ValueError(
f"addition_embed_type: {addition_embed_type} must be None or 'text'."
)
if time_embedding_act_fn is None:
self.time_embed_act = None
elif time_embedding_act_fn == 'swish':
self.time_embed_act = lambda x: F.silu(x)
elif time_embedding_act_fn == 'mish':
self.time_embed_act = nn.Mish()
elif time_embedding_act_fn == 'silu':
self.time_embed_act = nn.SiLU()
elif time_embedding_act_fn == 'gelu':
self.time_embed_act = nn.GELU()
else:
raise ValueError(
f'Unsupported activation function: {time_embedding_act_fn}')
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = only_cross_attention
only_cross_attention = [only_cross_attention
] * len(down_block_types)
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = False
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim, ) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (
cross_attention_dim, ) * len(down_block_types)
if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)
if class_embeddings_concat:
# The time embeddings are concatenated with the class embeddings. The dimension of the
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
# regular time embeddings
blocks_time_embed_dim = time_embed_dim * 2
else:
blocks_time_embed_dim = time_embed_dim
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=blocks_time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim[i],
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
)
self.down_blocks.append(down_block)
# mid
if mid_block_type == 'UNetMidBlock2DCrossAttn':
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
elif mid_block_type == 'UNetMidBlock2DSimpleCrossAttn':
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim[-1],
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif mid_block_type is None:
self.mid_block = None
else:
raise ValueError(f'unknown mid_block_type : {mid_block_type}')
# count how many layers upsample the images
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
only_cross_attention = list(reversed(only_cross_attention))
reversed_pixelwise_cross_attention_dim = [-1, 1280, 640, 320]
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(
i + 1,
len(block_out_channels) - 1)]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=reversed_layers_per_block[i] + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=blocks_time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
pixelwise_cross_attention_dim=
reversed_pixelwise_cross_attention_dim[i],
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_num_groups is not None:
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0],
num_groups=norm_num_groups,
eps=norm_eps)
if act_fn == 'swish':
self.conv_act = lambda x: F.silu(x)
elif act_fn == 'mish':
self.conv_act = nn.Mish()
elif act_fn == 'silu':
self.conv_act = nn.SiLU()
elif act_fn == 'gelu':
self.conv_act = nn.GELU()
else:
raise ValueError(f'Unsupported activation function: {act_fn}')
else:
self.conv_norm_out = None
self.conv_act = None
conv_out_padding = (conv_out_kernel - 1) // 2
self.conv_out = nn.Conv2d(
block_out_channels[0],
out_channels,
kernel_size=conv_out_kernel,
padding=conv_out_padding)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module,
processors: Dict[str,
AttentionProcessor]):
if hasattr(module, 'set_processor'):
processors[f'{name}.processor'] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f'{name}.{sub_name}', child,
processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor,
Dict[str,
AttentionProcessor]]):
r"""
Parameters:
processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
The instantiated processor class, or a dictionary of processor classes, that will be set as the processor
of **all** `Attention` layers.
If `processor` is a dict, each key needs to define the path to
the corresponding cross attention processor.
This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f'A dict of processors was passed, but the number of processors {len(processor)} does not match the'
f' number of attention layers: {count}. Please make sure to pass {count} processor classes.'
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module,
processor):
if hasattr(module, 'set_processor'):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f'{name}.processor'))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f'{name}.{sub_name}', child,
processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(AttnProcessor())
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor into slices and computes attention
in several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, 'set_attention_slice'):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == 'auto':
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == 'max':
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(
slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f'You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different'
f' attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.'
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(
f'size {size} has to be smaller or equal to {dim}.')
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module,
slice_size: List[int]):
if hasattr(module, 'set_attention_slice'):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D,
CrossAttnUpBlock2D, UpBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# By default samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info(
'Forward upsample size to force interpolation output size.')
forward_upsample_size = True
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == 'mps'
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps],
dtype=dtype,
device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError(
'class_labels should be provided when num_class_embeds > 0'
)
if self.config.class_embed_type == 'timestep':
class_labels = self.time_proj(class_labels)
# `Timesteps` does not contain any weights and will always return f32 tensors
# there might be better ways to encapsulate this.
class_labels = class_labels.to(dtype=sample.dtype)
class_emb = self.class_embedding(class_labels).to(
dtype=sample.dtype)
if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
else:
emb = emb + class_emb
if self.config.addition_embed_type == 'text':
aug_emb = self.add_embedding(encoder_hidden_states)
emb = emb + aug_emb
if self.time_embed_act is not None:
emb = self.time_embed_act(emb)
if self.encoder_hid_proj is not None:
encoder_hidden_states = self.encoder_hid_proj(
encoder_hidden_states)
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample, )
for downsample_block in self.down_blocks:
if hasattr(downsample_block, 'has_cross_attention'
) and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(
hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
if down_block_additional_residuals is not None:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
down_block_res_samples, down_block_additional_residuals):
down_block_res_sample = down_block_res_sample + down_block_additional_residual
new_down_block_res_samples = new_down_block_res_samples + (
down_block_res_sample, )
down_block_res_samples = new_down_block_res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
if mid_block_additional_residual is not None:
sample = sample + mid_block_additional_residual
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[:-len(
upsample_block.resnets)]
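# the additional pixel-aware residuals for this resolution are also routed into the up block as pixelwise conditioning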
down_block_additional_residual = down_block_additional_residuals[
-len(upsample_block.resnets):]
down_block_additional_residuals = down_block_additional_residuals[:-len(
upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, 'has_cross_attention'
) and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
encoder_pixelwise_hidden_states_tuple=
down_block_additional_residual,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
)
else:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size)
# 6. post-process
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample, )
return UNet2DConditionOutput(sample=sample)
@classmethod
def from_pretrained_(cls, pretrained_model_path, subfolder=None, **kwargs):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path,
subfolder)
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f'{config_file} does not exist')
with open(config_file, 'r') as f:
config = json.load(f)
from diffusers.utils import WEIGHTS_NAME
model = cls.from_config(config)
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
if not os.path.isfile(model_file):
raise RuntimeError(f'{model_file} does not exist')
state_dict = torch.load(model_file, map_location='cpu')
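# ensure the newly added attn2_plus layers keep the model's freshly initialized weights
# instead of (possibly missing) checkpoint values before the non-strict load below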
for k, v in model.state_dict().items():
if 'attn2_plus' in k:
state_dict.update({k: v})
model.load_state_dict(state_dict, strict=False)
return model
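For orientation, two small sketches of logic from the class above, written purely as an illustration: the first uses only `torch` and the `Timesteps` class imported in this file to show how a scalar timestep is broadcast and embedded and how a binary attention mask becomes an additive bias; the second shows the attention-processor and attention-slicing helpers on the stock `diffusers.UNet2DConditionModel`, which exposes the same interface, with a deliberately tiny toy config (every value below is illustrative, not taken from any checkpoint).

import torch
from diffusers import UNet2DConditionModel as DiffusersUNet
from diffusers.models.embeddings import Timesteps

# timestep broadcasting / embedding and attention-mask bias, mirroring `forward`
sample = torch.randn(4, 4, 64, 64)                           # (batch, channels, h, w)
time_proj = Timesteps(320, flip_sin_to_cos=True, downscale_freq_shift=0)
timesteps = torch.tensor(50)[None].expand(sample.shape[0])   # scalar timestep -> shape (4,)
t_emb = time_proj(timesteps).to(dtype=sample.dtype)          # (4, 320), computed in fp32
attention_mask = torch.ones(4, 77)                           # 1 = keep, 0 = drop
bias = ((1 - attention_mask.to(sample.dtype)) * -10000.0).unsqueeze(1)  # (4, 1, 77)

# attention processors and slicing on a toy stock-diffusers UNet
unet = DiffusersUNet(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=('CrossAttnDownBlock2D', 'DownBlock2D'),
    up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D'),
    cross_attention_dim=32,
    attention_head_dim=8,
)
print(len(unet.attn_processors))   # one processor per Attention layer, keyed by weight name
unet.set_attention_slice('auto')   # halve each sliceable head dim, so attention runs in two steps
unet.set_default_attn_processor()  # restore the default processors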

View File

@@ -9,6 +9,8 @@ import torch
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.compatible_with_transformers import \
compatible_position_ids
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
@@ -40,6 +42,8 @@ class ReferringVideoObjectSegmentation(TorchModel):
params_dict = torch.load(model_path, map_location='cpu')
if 'model_state_dict' in params_dict.keys():
params_dict = params_dict['model_state_dict']
compatible_position_ids(
params_dict, 'transformer.text_encoder.embeddings.position_ids')
self.model.load_state_dict(params_dict, strict=True)
self.set_postprocessor(self.cfg.pipeline.dataset_name)

View File

@@ -15,6 +15,9 @@ import torch.utils.checkpoint as checkpoint
from torch import nn
from transformers import BertConfig, BertForMaskedLM
from modelscope.utils.compatible_with_transformers import \
compatible_position_ids
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
@@ -290,6 +293,8 @@ class TEAM(nn.Module):
self.text_tensor_fc = nn.Linear(1024, 768)
params = torch.load(pretrained, 'cpu')
compatible_position_ids(params,
'text_model.bert.embeddings.position_ids')
self.load_state_dict(params, strict=True)
def get_feature(self, text_data=None, text_mask=None, img_tensor=None):

View File

@@ -12,6 +12,8 @@ from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.compatible_with_transformers import \
compatible_position_ids
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
@@ -47,6 +49,9 @@ class StarForTextToSql(TorchModel):
open(
os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'rb'),
map_location=self.device)
compatible_position_ids(
check_point['model'],
'encoder.input_layer.plm_model.embeddings.position_ids')
self.model.load_state_dict(check_point['model'])
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:

View File

@@ -22,6 +22,8 @@ from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.nlp.unite.configuration import InputFormat
from modelscope.outputs.nlp_outputs import TranslationEvaluationOutput
from modelscope.utils.compatible_with_transformers import \
compatible_position_ids
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
@@ -305,6 +307,8 @@ class UniTEForTranslationEvaluation(TorchModel):
self.encoder.pooler = None
else:
state_dict = torch.load(path, map_location=device)
compatible_position_ids(state_dict,
'encoder.embeddings.position_ids')
self.load_state_dict(state_dict)
logger.info('Loading checkpoint parameters from %s' % path)
return

View File

@@ -15,6 +15,8 @@ from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.models.builder import MODELS
from modelscope.outputs import DialogueUserSatisfactionEstimationModelOutput
from modelscope.utils.compatible_with_transformers import \
compatible_position_ids
from modelscope.utils.constant import ModelFile, Tasks
from .transformer import TransformerEncoder
@@ -45,8 +47,10 @@ class UserSatisfactionEstimation(TorchModel):
self.device = device
self.model = self.init_model()
model_ckpt = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
self.model.load_state_dict(
torch.load(model_ckpt, map_location=torch.device('cpu')))
stats_dict = torch.load(model_ckpt, map_location=torch.device('cpu'))
compatible_position_ids(stats_dict,
'private.bert.embeddings.position_ids')
self.model.load_state_dict(stats_dict)
def init_model(self):
configs = {

View File

@@ -727,6 +727,7 @@ TASK_OUTPUTS = {
# {"output_img": np.array with shape (h, w, 3)}
Tasks.skin_retouching: [OutputKeys.OUTPUT_IMG],
Tasks.image_super_resolution: [OutputKeys.OUTPUT_IMG],
Tasks.image_super_resolution_pasd: [OutputKeys.OUTPUT_IMG],
Tasks.image_colorization: [OutputKeys.OUTPUT_IMG],
Tasks.image_color_enhancement: [OutputKeys.OUTPUT_IMG],
Tasks.image_denoising: [OutputKeys.OUTPUT_IMG],

View File

@@ -182,6 +182,10 @@ TASK_INPUTS = {
InputType.IMAGE,
Tasks.crowd_counting:
InputType.IMAGE,
Tasks.image_super_resolution_pasd: {
'image': InputType.IMAGE,
'prompt': InputType.TEXT,
},
Tasks.image_inpainting: {
'img': InputType.IMAGE,
'mask': InputType.IMAGE,

View File

@@ -38,6 +38,7 @@ if TYPE_CHECKING:
from .image_semantic_segmentation_pipeline import ImageSemanticSegmentationPipeline
from .image_style_transfer_pipeline import ImageStyleTransferPipeline
from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
from .image_super_resolution_pasd_pipeline import ImageSuperResolutionPASDPipeline
from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline
from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
from .image_inpainting_pipeline import ImageInpaintingPipeline
@@ -151,6 +152,8 @@ else:
['ImageSemanticSegmentationPipeline'],
'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'],
'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'],
'image_super_resolution_pasd_pipeline':
['ImageSuperResolutionPASDPipeline'],
'image_to_image_translation_pipeline':
['Image2ImageTranslationPipeline'],
'product_retrieval_embedding_pipeline':
@@ -185,8 +188,9 @@ else:
['FaceAttributeRecognitionPipeline'],
'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'],
'hand_static_pipeline': ['HandStaticPipeline'],
'referring_video_object_segmentation_pipeline':
['ReferringVideoObjectSegmentationPipeline'],
'referring_video_object_segmentation_pipeline': [
'ReferringVideoObjectSegmentationPipeline'
],
'language_guided_video_summarization_pipeline': [
'LanguageGuidedVideoSummarizationPipeline'
],

View File

@@ -0,0 +1,229 @@
# Copyright © Alibaba, Inc. and its affiliates.
import os
import tempfile
from typing import Any, Dict, Optional, Union
import numpy as np
import PIL
import torch
from diffusers import AutoencoderKL, UniPCMultistepScheduler
from torchvision.models import ResNet50_Weights, resnet50
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from modelscope.metainfo import Pipelines
from modelscope.models.cv.image_portrait_enhancement.retinaface import \
detection
from modelscope.models.cv.image_super_resolution_pasd import (
ControlNetModel, UNet2DConditionModel)
from modelscope.models.cv.image_super_resolution_pasd.misc import (
load_dreambooth_lora, wavelet_color_fix)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.multi_modal.diffusers_wrapped.pasd_pipeline import \
PixelAwareStableDiffusionPipeline
from modelscope.preprocessors.image import load_image
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import create_device
@PIPELINES.register_module(
Tasks.image_super_resolution_pasd,
module_name=Pipelines.image_super_resolution_pasd)
class ImageSuperResolutionPASDPipeline(Pipeline):
""" Pixel-Aware Stable Diffusion for Realistic Image Super-Resolution Pipeline.
Example:
>>> import cv2
>>> from modelscope.outputs import OutputKeys
>>> from modelscope.pipelines import pipeline
>>> from modelscope.utils.constant import Tasks
>>> input_location = 'example_image.png'
>>> prompt = ''  # optional text prompt; may be left empty
>>> input = {
>>> 'image': input_location,
>>> 'upscale': 2,
>>> 'prompt': prompt,
>>> 'fidelity_scale_fg': 1.5,
>>> 'fidelity_scale_bg': 0.7
>>> }
>>> pasd = pipeline(Tasks.image_super_resolution_pasd, model='damo/PASD_image_super_resolutions')
>>> output = pasd(input)[OutputKeys.OUTPUT_IMG]
>>> cv2.imwrite('result.png', output)
"""
def __init__(self, model: str, device_name: str = 'cuda', **kwargs):
"""
Use `model` to create an image super-resolution pipeline for prediction.
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
torch_dtype = kwargs.get('torch_dtype', torch.float16)
self.device = create_device(device_name)
self.config = Config.from_file(
os.path.join(model, ModelFile.CONFIGURATION))
cfg = self.config.model_cfg
dreambooth_lora_ckpt = cfg['dreambooth_lora_ckpt']
tiled_size = cfg['tiled_size']
self.process_size = cfg['process_size']
scheduler = UniPCMultistepScheduler.from_pretrained(
model, subfolder='scheduler')
text_encoder = CLIPTextModel.from_pretrained(
model, subfolder='text_encoder')
tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer')
vae = AutoencoderKL.from_pretrained(model, subfolder='vae')
feature_extractor = CLIPImageProcessor.from_pretrained(
f'{model}/feature_extractor')
unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet')
controlnet = ControlNetModel.from_pretrained(
model, subfolder='controlnet')
unet, vae = load_dreambooth_lora(unet, vae,
f'{model}/{dreambooth_lora_ckpt}')
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
controlnet.requires_grad_(False)
text_encoder.to(self.device, dtype=torch_dtype)
vae.to(self.device, dtype=torch_dtype)
unet.to(self.device, dtype=torch_dtype)
controlnet.to(self.device, dtype=torch_dtype)
self.pipeline = PixelAwareStableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
safety_checker=None,
requires_safety_checker=False,
)
self.pipeline._init_tiled_vae(decoder_tile_size=tiled_size)
self.pipeline.enable_model_cpu_offload()
self.weights = ResNet50_Weights.DEFAULT
self.resnet_preprocess = self.weights.transforms()
self.resnet = resnet50(weights=self.weights)
self.resnet.eval()
self.threshold = 0.8
detector_model_path = f'{model}/RetinaFace-R50.pth'
self.face_detector = detection.RetinaFaceDetection(
detector_model_path, self.device)
def preprocess(self, input: Input):
return input
def forward(self, inputs: Dict[str, Any]):
if not isinstance(inputs, dict):
raise ValueError(
f'Expected the input to be a dictionary, but got {type(inputs)}'
)
num_inference_steps = inputs.get('num_inference_steps', 20)
guidance_scale = inputs.get('guidance_scale', 7.5)
added_prompt = inputs.get(
'added_prompt',
'clean, high-resolution, 8k, best quality, masterpiece, extremely detailed'
)
negative_prompt = inputs.get(
'negative_prompt',
'dotted, noise, blur, lowres, smooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, \
fewer digits, cropped, worst quality, low quality')
eta = inputs.get('eta', 0.0)
prompt = inputs.get('prompt', '')
upscale = inputs.get('upscale', 2)
fidelity_scale_fg = inputs.get('fidelity_scale_fg', 1.5)
fidelity_scale_bg = inputs.get('fidelity_scale_bg', 0.7)
input_image = load_image(inputs['image']).convert('RGB')
with torch.no_grad():
generator = torch.Generator(device=self.device)
batch = self.resnet_preprocess(input_image).unsqueeze(0)
prediction = self.resnet(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = self.weights.meta['categories'][class_id]
if score >= 0.1:
prompt += f'{category_name}' if prompt == '' else f', {category_name}'
prompt = added_prompt if prompt == '' else f'{prompt}, {added_prompt}'
ori_width, ori_height = input_image.size
resize_flag = False
rscale = upscale
if ori_width < self.process_size // rscale or ori_height < self.process_size // rscale:
scale = (self.process_size // rscale) / min(
ori_width, ori_height)
tmp_image = input_image.resize(
(int(scale * ori_width), int(scale * ori_height)))
input_image = tmp_image
resize_flag = True
input_image = input_image.resize(
(input_image.size[0] * rscale, input_image.size[1] * rscale))
input_image = input_image.resize(
(input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
width, height = input_image.size
fg_mask = None
if fidelity_scale_fg != fidelity_scale_bg:
fg_mask = torch.zeros([1, 1, height, width])
facebs, _ = self.face_detector.detect(np.array(input_image))
for fb in facebs:
if fb[-1] < self.threshold:
continue
fb = list(map(int, fb))
fg_mask[:, :, fb[1]:fb[3], fb[0]:fb[2]] = 1
fg_mask = fg_mask.to(self.device)
if fg_mask is None:
fidelity_scale = min(
max(fidelity_scale_fg, fidelity_scale_bg), 1)
fidelity_scale_fg = fidelity_scale_bg = fidelity_scale
try:
image = self.pipeline(
prompt,
input_image,
num_inference_steps=num_inference_steps,
generator=generator,
height=height,
width=width,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
conditioning_scale_fg=fidelity_scale_fg,
conditioning_scale_bg=fidelity_scale_bg,
fg_mask=fg_mask,
eta=eta,
).images[0]
image = wavelet_color_fix(image, input_image)
if resize_flag:
image = image.resize(
(ori_width * rscale, ori_height * rscale))
except Exception as e:
print(e)
image = PIL.Image.new('RGB', (512, 512), (0, 0, 0))
return {'result': image}
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
result = np.array(inputs['result'])
return {OutputKeys.OUTPUT_IMG: result[:, :, ::-1]}
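Two of the preprocessing steps in `forward` above are easy to get wrong when reusing this pipeline, so here is a standalone sketch of the upscale-then-snap-to-multiple-of-8 resize and of how face boxes in `(x1, y1, x2, y2, score)` order become the foreground fidelity mask. All values are hypothetical; no detector or diffusion model is involved.

import torch
from PIL import Image

# upscale, then snap both sides to a multiple of 8, as the pipeline does
img = Image.new('RGB', (123, 77))
rscale = 2
img = img.resize((img.size[0] * rscale, img.size[1] * rscale))
img = img.resize((img.size[0] // 8 * 8, img.size[1] // 8 * 8))
width, height = img.size  # (240, 152)

# hypothetical detection output: one face box with confidence 0.99
facebs = [[40.0, 30.0, 120.0, 110.0, 0.99]]
threshold = 0.8
fg_mask = torch.zeros([1, 1, height, width])
for fb in facebs:
    if fb[-1] < threshold:
        continue
    fb = list(map(int, fb))
    fg_mask[:, :, fb[1]:fb[3], fb[0]:fb[2]] = 1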

View File

@@ -0,0 +1,126 @@
# The implementation is adopted from stable-diffusion-webui, made publicly available under the Apache 2.0 License
# at https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/devices.py
import contextlib
import sys
import torch
if sys.platform == 'darwin':
from modules import mac_specific
def has_mps() -> bool:
if sys.platform != 'darwin':
return False
else:
return mac_specific.has_mps
def get_cuda_device_string():
return 'cuda'
def get_optimal_device_name():
if torch.cuda.is_available():
return get_cuda_device_string()
if has_mps():
return 'mps'
return 'cpu'
def get_optimal_device():
return torch.device(get_optimal_device_name())
def get_device_for(task):
return get_optimal_device()
def torch_gc():
if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
if has_mps():
mac_specific.torch_mps_gc()
def enable_tf32():
if torch.cuda.is_available():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
if any(
torch.cuda.get_device_capability(devid) == (7, 5)
for devid in range(0, torch.cuda.device_count())):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
enable_tf32()
cpu = torch.device('cpu')
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device(
'cuda')
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
unet_needs_upcast = False
def cond_cast_unet(input):
return input.to(dtype_unet) if unet_needs_upcast else input
def cond_cast_float(input):
return input.float() if unet_needs_upcast else input
def randn(seed, shape):
torch.manual_seed(seed)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
return torch.randn(shape, device=device)
def autocast(disable=False):
if disable:
return contextlib.nullcontext()
return torch.autocast('cuda')
def without_autocast(disable=False):
return torch.autocast('cuda', enabled=False) if torch.is_autocast_enabled() and \
not disable else contextlib.nullcontext()
class NansException(Exception):
pass
def test_for_nans(x, where):
if not torch.all(torch.isnan(x)).item():
return
if where == 'unet':
message = 'A tensor with all NaNs was produced in Unet.'
elif where == 'vae':
message = 'A tensor with all NaNs was produced in VAE.'
else:
message = 'A tensor with all NaNs was produced.'
message += ' Use --disable-nan-check commandline argument to disable this check.'
raise NansException(message)
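Worth noting about `test_for_nans` above: it only raises when every element of the tensor is NaN, not when a single NaN appears. A tiny usage sketch, assuming the `test_for_nans` and `NansException` definitions above are in scope:

import torch

x = torch.full((2, 3), float('nan'))
try:
    test_for_nans(x, 'vae')              # all-NaN tensor -> raises NansException
except NansException as err:
    print(err)

test_for_nans(torch.zeros(2, 3), 'unet')  # finite values -> returns silently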

File diff suppressed because it is too large

View File

@@ -0,0 +1,762 @@
# Part of the implementation is borrowed and modified from multidiffusion-upscaler-for-automatic1111, publicly available
# at https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/blob/main/scripts/vae_optimize.py
import gc
import math
from time import time
import torch
import torch.nn.functional as F
import torch.version
from tqdm import tqdm
from .devices import device, get_optimal_device, test_for_nans, torch_gc
sd_flag = False
def get_recommend_encoder_tile_size():
if torch.cuda.is_available():
total_memory = torch.cuda.get_device_properties(
device).total_memory // 2**20
if total_memory > 16 * 1000:
ENCODER_TILE_SIZE = 3072
elif total_memory > 12 * 1000:
ENCODER_TILE_SIZE = 2048
elif total_memory > 8 * 1000:
ENCODER_TILE_SIZE = 1536
else:
ENCODER_TILE_SIZE = 960
else:
ENCODER_TILE_SIZE = 512
return ENCODER_TILE_SIZE
def get_recommend_decoder_tile_size():
if torch.cuda.is_available():
total_memory = torch.cuda.get_device_properties(
device).total_memory // 2**20
if total_memory > 30 * 1000:
DECODER_TILE_SIZE = 256
elif total_memory > 16 * 1000:
DECODER_TILE_SIZE = 192
elif total_memory > 12 * 1000:
DECODER_TILE_SIZE = 128
elif total_memory > 8 * 1000:
DECODER_TILE_SIZE = 96
else:
DECODER_TILE_SIZE = 64
else:
DECODER_TILE_SIZE = 64
return DECODER_TILE_SIZE
if 'global const':
DEFAULT_ENABLED = False
DEFAULT_MOVE_TO_GPU = False
DEFAULT_FAST_ENCODER = True
DEFAULT_FAST_DECODER = True
DEFAULT_COLOR_FIX = 0
DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
# inplace version of silu
def inplace_nonlinearity(x):
# Test: fix for Nans
return F.silu(x, inplace=True)
# extracted from ldm.modules.diffusionmodules.model
# from diffusers lib
def attn_forward_new(self, h_):
batch_size, channel, height, width = h_.shape
hidden_states = h_.view(batch_size, channel,
height * width).transpose(1, 2)
attention_mask = None
encoder_hidden_states = None
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = self.prepare_attention_mask(attention_mask,
sequence_length, batch_size)
query = self.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif self.norm_cross:
encoder_hidden_states = self.norm_encoder_hidden_states(
encoder_hidden_states)
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
query = self.head_to_batch_dim(query)
key = self.head_to_batch_dim(key)
value = self.head_to_batch_dim(value)
attention_probs = self.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = self.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width)
return hidden_states
def attn2task(task_queue, net):
task_queue.append(('store_res', lambda x: x))
task_queue.append(('pre_norm', net.group_norm))
task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
task_queue.append(['add_res', None])
def resblock2task(queue, block):
"""
Turn a ResNetBlock into a sequence of tasks and append to the task queue
@param queue: the target task queue
@param block: ResNetBlock
"""
if block.in_channels != block.out_channels:
if sd_flag:
if block.use_conv_shortcut:
queue.append(('store_res', block.conv_shortcut))
else:
queue.append(('store_res', block.nin_shortcut))
else:
if block.use_in_shortcut:
queue.append(('store_res', block.conv_shortcut))
else:
queue.append(('store_res', block.nin_shortcut))
else:
queue.append(('store_res', lambda x: x))
queue.append(('pre_norm', block.norm1))
queue.append(('silu', inplace_nonlinearity))
queue.append(('conv1', block.conv1))
queue.append(('pre_norm', block.norm2))
queue.append(('silu', inplace_nonlinearity))
queue.append(('conv2', block.conv2))
queue.append(['add_res', None])
def build_sampling(task_queue, net, is_decoder):
"""
Build the sampling part of a task queue
@param task_queue: the target task queue
@param net: the network
@param is_decoder: True when building the decoder, False when building the encoder
"""
if is_decoder:
if sd_flag:
resblock2task(task_queue, net.mid.block_1)
attn2task(task_queue, net.mid.attn_1)
print(task_queue)
resblock2task(task_queue, net.mid.block_2)
resolution_iter = reversed(range(net.num_resolutions))
block_ids = net.num_res_blocks + 1
condition = 0
module = net.up
func_name = 'upsample'
else:
resblock2task(task_queue, net.mid_block.resnets[0])
attn2task(task_queue, net.mid_block.attentions[0])
resblock2task(task_queue, net.mid_block.resnets[1])
resolution_iter = (range(len(net.up_blocks))
) # net.num_resolutions = 3
block_ids = 2 + 1
condition = len(net.up_blocks) - 1
module = net.up_blocks
func_name = 'upsamplers'
else:
resolution_iter = range(net.num_resolutions)
block_ids = net.num_res_blocks
condition = net.num_resolutions - 1
module = net.down
func_name = 'downsample'
for i_level in resolution_iter:
for i_block in range(block_ids):
if sd_flag:
resblock2task(task_queue, module[i_level].block[i_block])
else:
resblock2task(task_queue, module[i_level].resnets[i_block])
if i_level != condition:
if sd_flag:
task_queue.append(
(func_name, getattr(module[i_level], func_name)))
else:
task_queue.append((func_name, module[i_level].upsamplers[0]))
if not is_decoder:
if sd_flag:
resblock2task(task_queue, net.mid.block_1)
attn2task(task_queue, net.mid.attn_1)
resblock2task(task_queue, net.mid.block_2)
else:
resblock2task(task_queue, net.mid_block.resnets[0])
attn2task(task_queue, net.mid_block.attentions[0])
resblock2task(task_queue, net.mid_block.resnets[1])
def build_task_queue(net, is_decoder):
"""
Build a single task queue for the encoder or decoder
@param net: the VAE decoder or encoder network
@param is_decoder: True when building the decoder, False when building the encoder
@return: the task queue
"""
task_queue = []
task_queue.append(('conv_in', net.conv_in))
# construct the sampling part of the task queue
# because encoder and decoder share the same architecture, we extract the sampling part
build_sampling(task_queue, net, is_decoder)
if is_decoder and not sd_flag:
net.give_pre_end = False
net.tanh_out = False
if not is_decoder or not net.give_pre_end:
if sd_flag:
task_queue.append(('pre_norm', net.norm_out))
else:
task_queue.append(('pre_norm', net.conv_norm_out))
task_queue.append(('silu', inplace_nonlinearity))
task_queue.append(('conv_out', net.conv_out))
if is_decoder and net.tanh_out:
task_queue.append(('tanh', torch.tanh))
return task_queue
def clone_task_queue(task_queue):
"""
Clone a task queue
@param task_queue: the task queue to be cloned
@return: the cloned task queue
"""
return [[item for item in task] for task in task_queue]
def get_var_mean(input, num_groups, eps=1e-6):
"""
Get mean and var for group norm
"""
b, c = input.size(0), input.size(1)
channel_in_group = int(c / num_groups)
input_reshaped = input.contiguous().view(1, int(b * num_groups),
channel_in_group,
*input.size()[2:])
var, mean = torch.var_mean(
input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
return var, mean
def custom_group_norm(input,
num_groups,
mean,
var,
weight=None,
bias=None,
eps=1e-6):
"""
Custom group norm with fixed mean and var
@param input: input tensor
@param num_groups: number of groups. by default, num_groups = 32
@param mean: mean, must be pre-calculated by get_var_mean
@param var: var, must be pre-calculated by get_var_mean
@param weight: weight, should be fetched from the original group norm
@param bias: bias, should be fetched from the original group norm
@param eps: epsilon, by default, eps = 1e-6 to match the original group norm
@return: normalized tensor
"""
b, c = input.size(0), input.size(1)
channel_in_group = int(c / num_groups)
input_reshaped = input.contiguous().view(1, int(b * num_groups),
channel_in_group,
*input.size()[2:])
out = F.batch_norm(
input_reshaped,
mean,
var,
weight=None,
bias=None,
training=False,
momentum=0,
eps=eps)
out = out.view(b, c, *input.size()[2:])
# post affine transform
if weight is not None:
out *= weight.view(1, -1, 1, 1)
if bias is not None:
out += bias.view(1, -1, 1, 1)
return out
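# Sanity-check sketch (illustration only, not part of the original code): when the
# statistics come from the full tensor itself, custom_group_norm should reproduce
# torch.nn.functional.group_norm; the tiled path differs only in that mean/var are
# aggregated across tiles first.
#
#     x = torch.randn(1, 64, 32, 32)
#     var, mean = get_var_mean(x, num_groups=32)
#     out = custom_group_norm(x, 32, mean, var, eps=1e-6)
#     ref = F.group_norm(x, num_groups=32, eps=1e-6)
#     assert torch.allclose(out, ref, atol=1e-4)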
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
"""
Crop the valid region from the tile
@param x: input tile
@param input_bbox: original input bounding box
@param target_bbox: output bounding box
@param is_decoder: True when cropping a decoder tile (coordinates scaled by 8), False for an encoder tile (scaled by 1/8)
@return: cropped tile
"""
padded_bbox = [i * 8 if is_decoder else i // 8 for i in input_bbox]
margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
return x[:, :, margin[2]:x.size(2) + margin[3],
margin[0]:x.size(3) + margin[1]]
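# Worked example (illustration only): for a decoder tile whose padded input bbox is
# [0, 86, 0, 86] in latent space, the tile decodes to 688x688 pixels; if the target
# output bbox is [0, 600, 0, 600], the margins are [0, -88, 0, -88] and the slice keeps
# x[:, :, 0:688-88, 0:688-88], i.e. exactly the 600x600 region that belongs to this tile.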
# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
def perfcount(fn):
def wrapper(*args, **kwargs):
ts = time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(device)
torch_gc()
gc.collect()
ret = fn(*args, **kwargs)
torch_gc()
gc.collect()
if torch.cuda.is_available():
vram = torch.cuda.max_memory_allocated(device) / 2**20
torch.cuda.reset_peak_memory_stats(device)
print(
f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB'
)
else:
print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')
return ret
return wrapper
# copy end :)
class GroupNormParam:
def __init__(self):
self.var_list = []
self.mean_list = []
self.pixel_list = []
self.weight = None
self.bias = None
def add_tile(self, tile, layer):
var, mean = get_var_mean(tile, 32)
# For very large images the variance can exceed the float16 range;
# in that case, recompute the statistics on a float32 copy
if var.dtype == torch.float16 and var.isinf().any():
fp32_tile = tile.float()
var, mean = get_var_mean(fp32_tile, 32)
# ============= DEBUG: test for infinite =============
# if torch.isinf(var).any():
# print('var: ', var)
# ====================================================
self.var_list.append(var)
self.mean_list.append(mean)
self.pixel_list.append(tile.shape[2] * tile.shape[3])
if hasattr(layer, 'weight'):
self.weight = layer.weight
self.bias = layer.bias
else:
self.weight = None
self.bias = None
def summary(self):
"""
summarize the mean and var across tiles and return a function
that applies group norm to each tile
"""
if len(self.var_list) == 0:
return None
var = torch.vstack(self.var_list)
mean = torch.vstack(self.mean_list)
max_value = max(self.pixel_list)
pixels = torch.tensor(
self.pixel_list, dtype=torch.float32, device=device) / max_value
sum_pixels = torch.sum(pixels)
pixels = pixels.unsqueeze(1) / sum_pixels
var = torch.sum(var * pixels, dim=0)
mean = torch.sum(mean * pixels, dim=0)
return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)
@staticmethod
def from_tile(tile, norm):
"""
create a function from a single tile without summary
"""
var, mean = get_var_mean(tile, 32)
if var.dtype == torch.float16 and var.isinf().any():
fp32_tile = tile.float()
var, mean = get_var_mean(fp32_tile, 32)
# on Apple MPS devices, convert back to float16
if var.device.type == 'mps':
# clamp to avoid overflow
var = torch.clamp(var, 0, 60000)
var = var.half()
mean = mean.half()
if hasattr(norm, 'weight'):
weight = norm.weight
bias = norm.bias
else:
weight = None
bias = None
def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
return group_norm_func
class VAEHook:
def __init__(self,
net,
tile_size,
is_decoder,
fast_decoder,
fast_encoder,
color_fix,
to_gpu=False):
self.net = net # encoder | decoder
self.tile_size = tile_size
self.is_decoder = is_decoder
self.fast_mode = (fast_encoder and not is_decoder) or (fast_decoder
and is_decoder)
self.color_fix = color_fix and not is_decoder
self.to_gpu = to_gpu
self.pad = 11 if is_decoder else 32
def __call__(self, x):
B, C, H, W = x.shape
original_device = next(self.net.parameters()).device
try:
if self.to_gpu:
self.net.to(get_optimal_device())
if max(H, W) <= self.pad * 2 + self.tile_size:
print(
'[Tiled VAE]: the input is small enough that tiling is unnecessary.'
)
return self.net.original_forward(x)
else:
return self.vae_tile_forward(x)
finally:
self.net.to(original_device)
def get_best_tile_size(self, lowerbound, upperbound):
"""
Get the best tile size for GPU memory
"""
divider = 32
while divider >= 2:
remainder = lowerbound % divider
if remainder == 0:
return lowerbound
candidate = lowerbound - remainder + divider
if candidate <= upperbound:
return candidate
divider //= 2
return lowerbound
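# Worked example (illustration only): with lowerbound=53 and upperbound=96 the loop
# tries divider=32 first: 53 % 32 == 21, so candidate = 53 - 21 + 32 = 64 <= 96 and 64
# is returned, the smallest multiple of 32 that covers the requested size and still
# fits the bound. If no divider yields a fitting candidate, lowerbound is returned as-is.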
def split_tiles(self, h, w):
"""
Tool function to split the image into tiles
@param h: height of the image
@param w: width of the image
@return: tile_input_bboxes, tile_output_bboxes
"""
tile_input_bboxes, tile_output_bboxes = [], []
tile_size = self.tile_size
pad = self.pad
num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
# If either tile count is 0, force it to 1;
# this handles very long or thin images
num_height_tiles = max(num_height_tiles, 1)
num_width_tiles = max(num_width_tiles, 1)
# Suggestions from https://github.com/Kahsolt: auto shrink the tile size
real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)
print(
f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles}={num_height_tiles*num_width_tiles} tiles.',
f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}'
)
for i in range(num_height_tiles):
for j in range(num_width_tiles):
# bbox: [x1, x2, y1, y2]
# the padding is unnecessary at the image borders, so the tile grid starts from (pad, pad)
input_bbox = [
pad + j * real_tile_width,
min(pad + (j + 1) * real_tile_width, w),
pad + i * real_tile_height,
min(pad + (i + 1) * real_tile_height, h),
]
# if the output bbox is close to the image boundary, we extend it to the image boundary
output_bbox = [
input_bbox[0] if input_bbox[0] > pad else 0,
input_bbox[1] if input_bbox[1] < w - pad else w,
input_bbox[2] if input_bbox[2] > pad else 0,
input_bbox[3] if input_bbox[3] < h - pad else h,
]
# scale to get the final output bbox
output_bbox = [
x * 8 if self.is_decoder else x // 8 for x in output_bbox
]
tile_output_bboxes.append(output_bbox)
# unconditionally expand the input bbox by pad pixels, clamped to the image bounds
tile_input_bboxes.append([
max(0, input_bbox[0] - pad),
min(w, input_bbox[1] + pad),
max(0, input_bbox[2] - pad),
min(h, input_bbox[3] + pad),
])
return tile_input_bboxes, tile_output_bboxes
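# Worked example (illustration only): for the decoder (pad=11) with tile_size=96 and a
# 128x128 latent, ceil((128 - 22) / 96) = 2 tiles per side; the real tile size shrinks
# to ceil(106 / 2) = 53 and is rounded up to 64 by get_best_tile_size, so the latent is
# covered by a 2x2 grid of 64-pixel tiles. Each input bbox is then expanded by the
# 11-pixel pad (clamped to the image), while the output bboxes stay disjoint and are
# scaled by 8 into image coordinates.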
@torch.no_grad()
def estimate_group_norm(self, z, task_queue, color_fix):
device = z.device
tile = z
last_id = len(task_queue) - 1
while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
last_id -= 1
if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
raise ValueError('No group norm found in the task queue')
# estimate until the last group norm
for i in range(last_id + 1):
task = task_queue[i]
if task[0] == 'pre_norm':
group_norm_func = GroupNormParam.from_tile(tile, task[1])
task_queue[i] = ('apply_norm', group_norm_func)
if i == last_id:
return True
tile = group_norm_func(tile)
elif task[0] == 'store_res':
task_id = i + 1
while task_id < last_id and task_queue[task_id][0] != 'add_res':
task_id += 1
if task_id >= last_id:
continue
task_queue[task_id][1] = task[1](tile)
elif task[0] == 'add_res':
tile += task[1].to(device)
task[1] = None
elif color_fix and task[0] == 'downsample':
for j in range(i, last_id + 1):
if task_queue[j][0] == 'store_res':
task_queue[j] = ('store_res_cpu', task_queue[j][1])
return True
else:
tile = task[1](tile)
try:
test_for_nans(tile, 'vae')
except Exception as e:
print(
f'{e}. Nan detected in fast mode estimation. Fast mode disabled.'
)
return False
raise IndexError('Should not reach here')
@perfcount
@torch.no_grad()
def vae_tile_forward(self, z):
"""
Run the VAE encoder or decoder on its input tile by tile.
@param z: input latent (decoder) or image (encoder)
@return: decoded image (decoder) or encoded latent (encoder)
"""
device = next(self.net.parameters()).device
net = self.net
tile_size = self.tile_size
is_decoder = self.is_decoder
z = z.detach() # detach the input to avoid backprop
N, height, width = z.shape[0], z.shape[2], z.shape[3]
net.last_z_shape = z.shape
# Split the input into tiles and build a task queue for each tile
print(
f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}'
)
in_bboxes, out_bboxes = self.split_tiles(height, width)
# Prepare the tiles by splitting the input
tiles = []
for input_bbox in in_bboxes:
tile = z[:, :, input_bbox[2]:input_bbox[3],
input_bbox[0]:input_bbox[1]].cpu()
tiles.append(tile)
num_tiles = len(tiles)
num_completed = 0
# Build task queues
single_task_queue = build_task_queue(net, is_decoder)
if self.fast_mode:
# Fast mode: downsample the input image to the tile size,
# then estimate the group norm parameters on the downsampled image
scale_factor = tile_size / max(height, width)
z = z.to(device)
downsampled_z = F.interpolate(
z, scale_factor=scale_factor, mode='nearest-exact')
# use nearest-exact to keep the statistics as close as possible
print(
f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on \
{downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')
# ======= Special thanks to @Kahsolt for distribution shift issue ======= #
# The downsampling will heavily distort its mean and std, so we need to recover it.
std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
std_new, mean_new = torch.std_mean(
downsampled_z, dim=[0, 2, 3], keepdim=True)
downsampled_z = (downsampled_z
- mean_new) / std_new * std_old + mean_old
del std_old, mean_old, std_new, mean_new
# occasionally std_new is so small or so large that the result exceeds the float16 range,
# so clamp the downsampled tensor back into z's range.
downsampled_z = torch.clamp_(
downsampled_z, min=z.min(), max=z.max())
estimate_task_queue = clone_task_queue(single_task_queue)
if self.estimate_group_norm(
downsampled_z, estimate_task_queue,
color_fix=self.color_fix):
single_task_queue = estimate_task_queue
del downsampled_z
task_queues = [
clone_task_queue(single_task_queue) for _ in range(num_tiles)
]
# Dummy result
result = None
result_approx = None
# Free memory of input latent tensor
del z
# Task queue execution
pbar = tqdm(
total=num_tiles * len(task_queues[0]),
desc=
f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: "
)
# execute the tasks back and forth when switching tiles so that we always
# keep one tile on the GPU and reduce unnecessary data transfers
forward = True
interrupted = False
while True:
group_norm_param = GroupNormParam()
for i in range(num_tiles) if forward else reversed(
range(num_tiles)):
tile = tiles[i].to(device)
input_bbox = in_bboxes[i]
task_queue = task_queues[i]
interrupted = False
while len(task_queue) > 0:
# DEBUG: current task
task = task_queue.pop(0)
if task[0] == 'pre_norm':
group_norm_param.add_tile(tile, task[1])
break
elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
task_id = 0
res = task[1](tile)
if not self.fast_mode or task[0] == 'store_res_cpu':
res = res.cpu()
while task_queue[task_id][0] != 'add_res':
task_id += 1
task_queue[task_id][1] = res
elif task[0] == 'add_res':
tile += task[1].to(device)
task[1] = None
else:
tile = task[1](tile)
pbar.update(1)
if interrupted:
break
# check for NaNs in the tile.
# If there are NaNs, we abort the process to save user's time
test_for_nans(tile, 'vae')
if len(task_queue) == 0:
tiles[i] = None
num_completed += 1
if result is None:  # NOTE: dim C varies across cases, so the result buffer can only be allocated dynamically
result = torch.zeros(
(N, tile.shape[1],
height * 8 if is_decoder else height // 8,
width * 8 if is_decoder else width // 8),
device=device,
requires_grad=False)
result[:, :, out_bboxes[i][2]:out_bboxes[i][3],
out_bboxes[i][0]:out_bboxes[i]
[1]] = crop_valid_region(tile, in_bboxes[i],
out_bboxes[i], is_decoder)
del tile
elif i == num_tiles - 1 and forward:
forward = False
tiles[i] = tile
elif i == 0 and not forward:
forward = True
tiles[i] = tile
else:
tiles[i] = tile.cpu()
del tile
if interrupted:
break
if num_completed == num_tiles:
break
# insert the group norm task to the head of each task queue
group_norm_func = group_norm_param.summary()
if group_norm_func is not None:
for i in range(num_tiles):
task_queue = task_queues[i]
task_queue.insert(0, ('apply_norm', group_norm_func))
# Done!
pbar.close()
return result if result is not None else result_approx.to(device)
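# Illustrative usage sketch (an assumption for documentation purposes, not part of the
# original file): the hook is intended to replace the forward of a VAE encoder/decoder
# so that large inputs are processed tile by tile, roughly:
#
#     decoder = vae.decoder                       # e.g. a diffusers AutoencoderKL decoder
#     decoder.original_forward = decoder.forward  # keep the untiled path for small inputs
#     decoder.forward = VAEHook(decoder, tile_size=64, is_decoder=True,
#                               fast_decoder=True, fast_encoder=False,
#                               color_fix=False)
#
# After patching, decoding a latent dispatches to vae_tile_forward above whenever the
# input is larger than tile_size + 2 * pad.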

View File

@@ -0,0 +1,16 @@
import transformers
from packaging import version
def compatible_position_ids(state_dict, position_id_key):
"""Transformers no longer expect position_ids after transformers==4.31
https://github.com/huggingface/transformers/pull/24505
Args:
position_id_key (str): position_ids key,
such as(encoder.embeddings.position_ids)
"""
transformer_version = version.parse('.'.join(
transformers.__version__.split('.')[:2]))
if transformer_version >= version.parse('4.31.0'):
del state_dict[position_id_key]
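# Illustrative usage (an assumption, not part of the original file): drop the stale
# buffer before loading an older checkpoint; the key is removed only when the installed
# transformers is >= 4.31.
#
#     state_dict = torch.load('pytorch_model.bin', map_location='cpu')
#     compatible_position_ids(state_dict, 'encoder.embeddings.position_ids')
#     model.load_state_dict(state_dict)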

View File

@@ -75,6 +75,7 @@ class CVTasks(object):
# image editing
skin_retouching = 'skin-retouching'
image_super_resolution = 'image-super-resolution'
image_super_resolution_pasd = 'image-super-resolution-pasd'
image_debanding = 'image-debanding'
image_colorization = 'image-colorization'
image_color_enhancement = 'image-color-enhancement'

View File

@@ -0,0 +1,43 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest
import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class ImageSuperResolutionPASDTest(unittest.TestCase):
def setUp(self) -> None:
self.model_id = 'damo/PASD_image_super_resolutions'
self.img = 'data/test/images/dogs.jpg'
self.input = {
'image': self.img,
'prompt': '',
'upscale': 1,
'fidelity_scale_fg': 1.5,
'fidelity_scale_bg': 0.7
}
self.task = Tasks.image_super_resolution_pasd
def pipeline_inference(self, pipeline: Pipeline, input: dict):
result = pipeline(input)
if result is not None:
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
print(f'Output written to {osp.abspath("result.png")}')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
super_resolution = pipeline(
Tasks.image_super_resolution_pasd, model=self.model_id)
self.pipeline_inference(super_resolution, self.input)
if __name__ == '__main__':
unittest.main()

View File

@@ -7,6 +7,7 @@ isolated: # test cases that may require excessive anmount of GPU memory or run
- test_dialog_modeling.py
- test_csanmt_translation.py
- test_image_super_resolution.py
- test_image_super_resolution_pasd.py
- test_easycv_trainer.py
- test_segformer.py
- test_segmentation_pipeline.py