Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-25 12:39:25 +01:00)

Commit: Merge branch 'master' of gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib into master-gitlab
@@ -48,7 +48,7 @@ ENV SETUPTOOLS_USE_DISTUTILS=stdlib
 RUN CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" pip install --no-cache-dir 'git+https://github.com/facebookresearch/detectron2.git'

 # torchmetrics==0.11.4 for ofa
-RUN pip install --no-cache-dir tiktoken torchmetrics==0.11.4 'transformers<4.31.0' transformers_stream_generator 'protobuf<=3.20.0' bitsandbytes basicsr
+RUN pip install --no-cache-dir tiktoken torchmetrics==0.11.4 transformers_stream_generator 'protobuf<=3.20.0' bitsandbytes basicsr
 COPY docker/scripts/install_flash_attension.sh /tmp/install_flash_attension.sh
 RUN if [ "$USE_GPU" = "True" ] ; then \
     bash /tmp/install_flash_attension.sh; \
@@ -339,6 +339,7 @@ class Pipelines(object):
     image_colorization = 'unet-image-colorization'
     image_style_transfer = 'AAMS-style-transfer'
     image_super_resolution = 'rrdb-image-super-resolution'
+    image_super_resolution_pasd = 'image-super-resolution-pasd'
     image_debanding = 'rrdb-image-debanding'
     face_image_generation = 'gan-face-image-generation'
     product_retrieval_embedding = 'resnet50-product-retrieval-embedding'
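A minimal usage sketch for the new pipeline key (not part of the diff; the task string and model id below are assumptions for illustration, not values confirmed by this commit):

# Assumes the standard modelscope pipeline() entry point; the string mirrors the
# Pipelines constant added above, and '<pasd-model-id>' is a placeholder hub id.
from modelscope.pipelines import pipeline

pasd_sr = pipeline('image-super-resolution-pasd', model='<pasd-model-id>')
result = pasd_sr('low_res_input.png')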
@@ -14,10 +14,11 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
                image_quality_assessment_degradation,
                image_quality_assessment_man, image_quality_assessment_mos,
                image_reid_person, image_restoration,
-               image_semantic_segmentation, image_to_image_generation,
-               image_to_image_translation, language_guided_video_summarization,
-               movie_scene_segmentation, object_detection,
-               panorama_depth_estimation, pedestrian_attribute_recognition,
+               image_semantic_segmentation, image_super_resolution_pasd,
+               image_to_image_generation, image_to_image_translation,
+               language_guided_video_summarization, movie_scene_segmentation,
+               object_detection, panorama_depth_estimation,
+               pedestrian_attribute_recognition,
                pointcloud_sceneflow_estimation, product_retrieval_embedding,
                referring_video_object_segmentation,
                robust_image_classification, salient_detection,
modelscope/models/cv/image_super_resolution_pasd/__init__.py (new file, 24 lines)
@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .unet_2d_condition import UNet2DConditionModel
    from .controlnet import ControlNetModel

else:
    _import_structure = {
        'unet_2d_condition': ['UNet2DConditionModel'],
        'controlnet': ['ControlNetModel']
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
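# Illustrative note (not part of this diff): with the LazyImportModule wiring above,
# importing this package stays cheap; the heavy submodules (and their torch/diffusers
# dependencies) are only loaded when one of the registered names is first accessed,
# e.g. from client code:
#
#     import modelscope.models.cv.image_super_resolution_pasd as pasd  # no heavy imports yet
#     controlnet_cls = pasd.ControlNetModel  # first attribute access imports .controlnet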
modelscope/models/cv/image_super_resolution_pasd/attention.py (new file, 421 lines)
@@ -0,0 +1,421 @@
# Part of the implementation is borrowed and modified from diffusers,
# publicly available at https://github.com/huggingface/diffusers/tree/main/src/diffusers/models/attention.py
import math
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
from diffusers.utils import maybe_allow_in_graph
from torch import nn


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        pixelwise_cross_attention_dim: Optional[int] = None,
        activation_fn: str = 'geglu',
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = 'layer_norm',
        final_dropout: bool = False,
        use_pixelwise_attention: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (
            num_embeds_ada_norm is not None) and norm_type == 'ada_norm_zero'
        self.use_ada_layer_norm = (num_embeds_ada_norm
                                   is not None) and norm_type == 'ada_norm'

        if norm_type in ('ada_norm',
                         'ada_norm_zero') and num_embeds_ada_norm is None:
            raise ValueError(
                f'`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to'
                f' define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}.'
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(
                dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim
            if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm else nn.LayerNorm(
                    dim, elementwise_affine=norm_elementwise_affine))
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim
                if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 2+. pixelwise-Attn
        if use_pixelwise_attention:
            self.norm2_plus = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm else nn.LayerNorm(
                    dim, elementwise_affine=norm_elementwise_affine))
            self.attn2_plus = Attention(
                query_dim=dim,
                cross_attention_dim=pixelwise_cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2_plus = None
            self.attn2_plus = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(
            dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_pixelwise_hidden_states=None,
        encoder_pixelwise_attention_mask=None,
        timestep=None,
        cross_attention_kwargs=None,
        class_labels=None,
    ):
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states,
                timestep,
                class_labels,
                hidden_dtype=hidden_states.dtype)
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states
            if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep)
                if self.use_ada_layer_norm else self.norm2(hidden_states))
            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
            # prepare attention mask here

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 2+. pixelwise-Attention
        if self.attn2_plus is not None:
            norm_hidden_states = (
                self.norm2_plus(hidden_states, timestep)
                if self.use_ada_layer_norm else self.norm2_plus(hidden_states))

            attn_output = self.attn2_plus(
                norm_hidden_states,
                encoder_hidden_states=encoder_pixelwise_hidden_states,
                attention_mask=encoder_pixelwise_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (
                1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = 'geglu',
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == 'gelu':
            act_fn = GELU(dim, inner_dim)
        if activation_fn == 'gelu-approximate':
            act_fn = GELU(dim, inner_dim, approximate='tanh')
        elif activation_fn == 'geglu':
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == 'geglu-approximate':
            act_fn = ApproximateGELU(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class GELU(nn.Module):
    r"""
    GELU activation function with tanh approximation support with `approximate="tanh"`.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = 'none'):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate):
        if gate.device.type != 'mps':
            return F.gelu(gate, approximate=self.approximate)
        # mps: gelu is not implemented for float16
        return F.gelu(
            gate.to(dtype=torch.float32),
            approximate=self.approximate).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != 'mps':
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class ApproximateGELU(nn.Module):
    """
    The approximate form of Gaussian Error Linear Unit (GELU)

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


class AdaLayerNorm(nn.Module):
    """
    Norm layer modified to incorporate timestep embeddings.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x


class AdaLayerNormZero(nn.Module):
    """
    Norm layer adaptive layer norm zero (adaLN-Zero).
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings,
                                                   embedding_dim)

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(
            embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, class_labels, hidden_dtype=None):
        emb = self.linear(
            self.silu(
                self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(
            6, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaGroupNorm(nn.Module):
    """
    GroupNorm layer modified to incorporate timestep embeddings.
    """

    def __init__(self,
                 embedding_dim: int,
                 out_dim: int,
                 num_groups: int,
                 act_fn: Optional[str] = None,
                 eps: float = 1e-5):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.act = None
        if act_fn == 'swish':
            self.act = lambda x: F.silu(x)
        elif act_fn == 'mish':
            self.act = nn.Mish()
        elif act_fn == 'silu':
            self.act = nn.SiLU()
        elif act_fn == 'gelu':
            self.act = nn.GELU()

        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x, emb):
        if self.act:
            emb = self.act(emb)
        emb = self.linear(emb)
        emb = emb[:, :, None, None]
        scale, shift = emb.chunk(2, dim=1)

        x = F.group_norm(x, self.num_groups, eps=self.eps)
        x = x * (1 + scale) + shift
        return x
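# Illustrative sketch (not part of this diff): the FeedForward/GEGLU stack above is
# shape-preserving apart from the internal hidden expansion; the toy sizes below are
# arbitrary, not anything prescribed by the commit.
if __name__ == '__main__':
    ff = FeedForward(dim=64, mult=4, activation_fn='geglu')
    x = torch.randn(2, 16, 64)  # (batch, tokens, channels)
    print(ff(x).shape)          # torch.Size([2, 16, 64]); GEGLU expands to 256 internally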
modelscope/models/cv/image_super_resolution_pasd/controlnet.py (new file, 738 lines)
@@ -0,0 +1,738 @@
# Part of the implementation is borrowed and modified from diffusers,
# publicly available at https://github.com/huggingface/diffusers/tree/main/src/diffusers/models/controlnet.py
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin  # , ControlNetModel
from diffusers.models.attention_processor import (AttentionProcessor,
                                                   AttnProcessor)
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.utils import BaseOutput, logging
from torch import nn
from torch.nn import functional as F
from torchvision import utils

from modelscope.models.cv.super_resolution.rrdbnet_arch import RRDB
from .unet_2d_blocks import (CrossAttnDownBlock2D, DownBlock2D,
                             UNetMidBlock2DCrossAttn, get_down_block)
from .unet_2d_condition import UNet2DConditionModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class ControlNetOutput(BaseOutput):
    controlnet_cond_mid: torch.Tensor
    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor


class ControlNetConditioningEmbedding(nn.Module):
    """
    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
        return_rgbs: bool = True,
        use_rrdb: bool = False,
    ):
        super().__init__()

        self.return_rgbs = return_rgbs
        self.use_rrdb = use_rrdb

        self.conv_in = nn.Conv2d(
            conditioning_channels,
            block_out_channels[0],
            kernel_size=3,
            padding=1)

        if self.use_rrdb:
            num_rrdb_block = 2
            layers = (
                RRDB(block_out_channels[0], block_out_channels[0])
                for i in range(num_rrdb_block))
            self.preprocesser = nn.Sequential(*layers)

        self.blocks = nn.ModuleList([])
        if return_rgbs:
            self.to_rgbs = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            self.blocks.append(
                nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(
                nn.Conv2d(
                    channel_in,
                    channel_out,
                    kernel_size=3,
                    padding=1,
                    stride=2))

            if return_rgbs:
                self.to_rgbs.append(
                    nn.Conv2d(channel_out, 3, kernel_size=3, padding=1))

        self.conv_out = zero_module(
            nn.Conv2d(
                block_out_channels[-1],
                conditioning_embedding_channels,
                kernel_size=3,
                padding=1))

    def forward(self, conditioning):
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        if self.use_rrdb:
            embedding = self.preprocesser(embedding)

        out_rgbs = []
        for i, block in enumerate(self.blocks):
            embedding = block(embedding)
            embedding = F.silu(embedding)

            if i % 2 and self.return_rgbs:  # 0
                out_rgbs.append(self.to_rgbs[i // 2](embedding))

        embedding = self.conv_out(embedding)

        if self.return_rgbs:
            return embedding, out_rgbs
        else:
            return embedding


class ControlNetModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            'CrossAttnDownBlock2D',
            'CrossAttnDownBlock2D',
            'CrossAttnDownBlock2D',
            'DownBlock2D',
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = 'silu',
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = 'default',
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = 'rgb',
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16,
                                                                     32,
                                                                     96,
                                                                     256),
        global_pool_conditions: bool = False,
        return_rgbs: bool = False,
        use_rrdb: bool = False):
        super().__init__()

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f'Must provide the same number of `block_out_channels` as `down_block_types`. \
                `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.'
            )

        if not isinstance(
                only_cross_attention,
                bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f'Must provide the same number of `only_cross_attention` as `down_block_types`. \
                `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.'
            )

        if not isinstance(
                attention_head_dim,
                int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f'Must provide the same number of `attention_head_dim` as `down_block_types`. \
                `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.'
            )

        # input
        self.return_rgbs = return_rgbs
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=conv_in_kernel,
            padding=conv_in_padding)

        # time
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos,
                                   freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds,
                                                time_embed_dim)
        elif class_embed_type == 'timestep':
            self.class_embedding = TimestepEmbedding(timestep_input_dim,
                                                     time_embed_dim)
        elif class_embed_type == 'identity':
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == 'projection':
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(
                projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        # control net conditioning embedding
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
            return_rgbs=return_rgbs,
            use_rrdb=use_rrdb,
        )

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention
                                    ] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim, ) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        controlnet_block = nn.Conv2d(
            output_channel, output_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_down_blocks.append(controlnet_block)

        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
                controlnet_block = nn.Conv2d(
                    output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

            if not is_final_block:
                controlnet_block = nn.Conv2d(
                    output_channel, output_channel, kernel_size=1)
                controlnet_block = zero_module(controlnet_block)
                self.controlnet_down_blocks.append(controlnet_block)

        # mid
        mid_block_channel = block_out_channels[-1]

        controlnet_block = nn.Conv2d(
            mid_block_channel, mid_block_channel, kernel_size=1)
        controlnet_block = zero_module(controlnet_block)
        self.controlnet_mid_block = controlnet_block

        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attention_head_dim[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        controlnet_conditioning_channel_order: str = 'rgb',
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32,
                                                                     96, 256),
        load_weights_from_unet: bool = True,
    ):
        r"""
        Instantiate Controlnet class from UNet2DConditionModel.

        Parameters:
            unet (`UNet2DConditionModel`):
                UNet model which weights are copied to the ControlNet. Note that all configuration options are also
                copied where applicable.
        """
        controlnet = cls(
            in_channels=unet.config.in_channels,
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            down_block_types=unet.config.down_block_types,
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
            downsample_padding=unet.config.downsample_padding,
            mid_block_scale_factor=unet.config.mid_block_scale_factor,
            act_fn=unet.config.act_fn,
            norm_num_groups=unet.config.norm_num_groups,
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            attention_head_dim=unet.config.attention_head_dim,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            num_class_embeds=unet.config.num_class_embeds,
            upcast_attention=unet.config.upcast_attention,
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
            projection_class_embeddings_input_dim=unet.config.
            projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order=
            controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels=
            conditioning_embedding_out_channels,
        )

        if load_weights_from_unet:
            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
            controlnet.time_embedding.load_state_dict(
                unet.time_embedding.state_dict())

            if controlnet.class_embedding:
                controlnet.class_embedding.load_state_dict(
                    unet.class_embedding.state_dict())

            controlnet.down_blocks.load_state_dict(
                unet.down_blocks.state_dict())
            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

            if controlnet.sr_model is not None:
                load_net = torch.load(
                    'annotator/ckpts/RealESRNet_x4plus.pth',
                    map_location=lambda storage, loc: storage)
                if 'params_ema' in load_net:
                    load_net = load_net['params_ema']
                elif 'params' in load_net:
                    load_net = load_net['params']
                # remove unnecessary 'module.'
                for k, v in deepcopy(load_net).items():
                    if k.startswith('module.'):
                        load_net[k[7:]] = v
                        load_net.pop(k)
                controlnet.sr_model.load_state_dict(load_net, strict=True)

        return controlnet

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module,
                                        processors: Dict[str,
                                                         AttentionProcessor]):
            if hasattr(module, 'set_processor'):
                processors[f'{name}.processor'] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f'{name}.{sub_name}', child,
                                            processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor,
                                                  Dict[str,
                                                       AttentionProcessor]]):
        r"""
        Parameters:
            `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                of **all** `Attention` layers.
                In case `processor` is a dict, the key needs to define the path to
                the corresponding cross attention processor.
                This is strongly recommended when setting trainable attention processors.:

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f'A dict of processors was passed, but the number of processors {len(processor)} does not match the'
                f' number of attention layers: {count}. Please make sure to pass {count} processor classes.'
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module,
                                        processor):
            if hasattr(module, 'set_processor'):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f'{name}.processor'))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f'{name}.{sub_name}', child,
                                            processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, 'set_attention_slice'):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == 'auto':
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == 'max':
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(
            slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f'You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different'
                f' attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.'
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(
                    f'size {size} has to be smaller or equal to {dim}.')

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module,
                                             slice_size: List[int]):
            if hasattr(module, 'set_attention_slice'):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        fg_mask: Optional[torch.FloatTensor] = None,
        conditioning_scale_fg: float = 1.0,
        conditioning_scale_bg: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        # check channel order
        channel_order = self.config.controlnet_conditioning_channel_order

        if channel_order == 'rgb':
            # in rgb order by default
            ...
        elif channel_order == 'bgr':
            controlnet_cond = torch.flip(controlnet_cond, dims=[1])
        else:
            raise ValueError(
                f'unknown `controlnet_conditioning_channel_order`: {channel_order}'
            )

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == 'mps'
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps],
                                     dtype=dtype,
                                     device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError(
                    'class_labels should be provided when num_class_embeds > 0'
                )

            if self.config.class_embed_type == 'timestep':
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)

        controlnet_cond_mid = None
        if self.return_rgbs:
            controlnet_cond, controlnet_cond_mid = self.controlnet_cond_embedding(
                controlnet_cond)
        else:
            controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

        sample = sample + controlnet_cond

        # 3. down
        down_block_res_samples = (sample, )
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, 'has_cross_attention'
                       ) and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        # 5. Control net blocks

        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(
                down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (
                down_block_res_sample, )

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(
                -1, 0, len(down_block_res_samples) + 1,
                device=sample.device)  # 0.1 to 1.0

            scales = scales * conditioning_scale_fg
            down_block_res_samples = [
                sample * scale
                for sample, scale in zip(down_block_res_samples, scales)
            ]
            mid_block_res_sample = mid_block_res_sample * scales[
                -1]  # last one
        else:
            if fg_mask is None:
                down_block_res_samples = [
                    sample * conditioning_scale_fg
                    for sample in down_block_res_samples
                ]
                mid_block_res_sample = mid_block_res_sample * conditioning_scale_fg
            else:
                down_block_masks = [
                    torch.zeros_like(sample) + conditioning_scale_bg
                    for i, sample in enumerate(down_block_res_samples)
                ]
                mid_block_mask = torch.zeros_like(
                    mid_block_res_sample) + conditioning_scale_bg

                for i, sample in enumerate(down_block_masks):
                    tmp_mask = F.interpolate(
                        fg_mask,
                        size=sample.shape[-2:]).repeat(sample.shape[0],
                                                       sample.shape[1], 1,
                                                       1).bool()
                    down_block_masks[i] = sample.masked_fill(
                        tmp_mask, conditioning_scale_fg)

                tmp_mask = F.interpolate(
                    fg_mask, size=mid_block_mask.shape[-2:]).repeat(
                        mid_block_mask.shape[0], mid_block_mask.shape[1], 1,
                        1).bool()
                mid_block_mask = mid_block_mask.masked_fill(
                    tmp_mask, conditioning_scale_fg)

                down_block_res_samples = [
                    sample * down_block_mask for sample, down_block_mask in
                    zip(down_block_res_samples, down_block_masks)
                ]
                mid_block_res_sample = mid_block_res_sample * mid_block_mask

        if self.config.global_pool_conditions:
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True)
                for sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(
                mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (controlnet_cond_mid, down_block_res_samples,
                    mid_block_res_sample)

        return ControlNetOutput(
            controlnet_cond_mid=controlnet_cond_mid,
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample)


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
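# Illustrative sketch (not part of this diff): the conditioning embedding maps an image-space
# condition down to the latent resolution consumed by the down blocks, and zero_module() keeps
# the controlnet contribution a no-op at initialization. The sizes below are toy values.
if __name__ == '__main__':
    cond_embed = ControlNetConditioningEmbedding(
        conditioning_embedding_channels=320, return_rgbs=True, use_rrdb=False)
    cond = torch.randn(1, 3, 512, 512)
    embedding, rgbs = cond_embed(cond)
    print(embedding.shape, len(rgbs))  # expected (1, 320, 64, 64) and 3 intermediate RGB maps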
modelscope/models/cv/image_super_resolution_pasd/misc.py (new file, 159 lines)
@@ -0,0 +1,159 @@
# Part of the implementation is borrowed and modified from StableSR,
# publicly available at https://github.com/IceClear/StableSR/blob/main/scripts/wavelet_color_fix.py
import torch
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint)
from PIL import Image
from safetensors import safe_open
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import ToPILImage, ToTensor


def adain_color_fix(target: Image, source: Image):
    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = to_tensor(target).unsqueeze(0)
    source_tensor = to_tensor(source).unsqueeze(0)

    # Apply adaptive instance normalization
    result_tensor = adaptive_instance_normalization(target_tensor,
                                                    source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image


def wavelet_color_fix(target: Image, source: Image):
    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = to_tensor(target).unsqueeze(0)
    source_tensor = to_tensor(source).unsqueeze(0)

    # Apply wavelet reconstruction
    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image


def calc_mean_std(feat: Tensor, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.
    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
    feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
    """Adaptive instance normalization.
    Adjust the reference features to have the similar color and illuminations
    as those in the degradate features.
    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degradate features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat
                       - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)


def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.
    """
    # input shape: (1, 3, H, W)
    # convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add channel dimensions to the kernel to make it a 4D tensor
    kernel = kernel[None, None]
    # repeat the kernel across all input channels
    kernel = kernel.repeat(3, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # apply convolution
    output = F.conv2d(image, kernel, groups=3, dilation=radius)
    return output


def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the low frequency & the high frequency.
    """
    high_freq = torch.zeros_like(image)
    for i in range(levels):
        radius = 2**i
        low_freq = wavelet_blur(image, radius)
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq


def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """
    Apply wavelet decomposition, so that the content will have the same color as the style.
    """
    # calculate the wavelet decomposition of the content feature
    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
    del content_low_freq
    # calculate the wavelet decomposition of the style feature
    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
    del style_high_freq
    # reconstruct the content feature with the style's high frequency
    return content_high_freq + style_low_freq


def load_dreambooth_lora(unet, vae=None, model_path=None, model_base=''):
    if model_path is None:
        return unet

    if model_path.endswith('.ckpt'):
        base_state_dict = torch.load(model_path)['state_dict']
    elif model_path.endswith('.safetensors'):
        state_dict = {}
        with safe_open(model_path, framework='pt', device='cpu') as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)

        is_lora = all('lora' in k for k in state_dict.keys())
        if not is_lora:
            base_state_dict = state_dict
        else:
            base_state_dict = {}
            with safe_open(model_base, framework='pt', device='cpu') as f:
                for key in f.keys():
                    base_state_dict[key] = f.get_tensor(key)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        base_state_dict, unet.config)
    unet.load_state_dict(converted_unet_checkpoint, strict=False)

    if vae is not None:
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(
            base_state_dict, vae.config)
        vae.load_state_dict(converted_vae_checkpoint)

    return unet, vae
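# Illustrative sketch (not part of this diff): wavelet_color_fix() transfers the low-frequency
# (color/illumination) component of a source image onto a target image, which is how PASD-style
# pipelines typically re-attach the original colors to a super-resolved output. The file names
# below are placeholders.
if __name__ == '__main__':
    sr_output = Image.open('sr_output.png').convert('RGB')
    original = Image.open('low_res_input.png').convert('RGB').resize(sr_output.size)
    fixed = wavelet_color_fix(target=sr_output, source=original)
    fixed.save('sr_output_colorfix.png')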
@@ -0,0 +1,365 @@
|
||||
# Part of the implementation is borrowed and modified from diffusers,
|
||||
# publicly available at https://github.com/huggingface/diffusers/tree/main/src/diffusers/models/transformer_2d.py
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models import ModelMixin
|
||||
from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed
|
||||
from diffusers.utils import BaseOutput, deprecate
|
||||
from torch import nn
|
||||
|
||||
from .attention import BasicTransformerBlock
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or
|
||||
`(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
|
||||
for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
|
||||
embeddings) inputs.
|
||||
|
||||
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
|
||||
transformer action. Finally, reshape to image.
|
||||
|
||||
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
|
||||
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
|
||||
classes of unnoised image.
|
||||
|
||||
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
|
||||
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
||||
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
||||
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
||||
up to, but not more than, `num_embeds_ada_norm` steps.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
pixelwise_cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
activation_fn: str = 'geglu',
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
use_linear_projection: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = 'layer_norm',
|
||||
norm_elementwise_affine: bool = True,
|
||||
use_pixelwise_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_linear_projection = use_linear_projection
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Transformer2DModel can process both standard continuous images of shape
|
||||
# `(batch_size, num_channels, width, height)`
|
||||
# as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
self.is_input_continuous = (in_channels
|
||||
is not None) and (patch_size is None)
|
||||
self.is_input_vectorized = num_vector_embeds is not None
|
||||
self.is_input_patches = in_channels is not None and patch_size is not None
|
||||
|
||||
if norm_type == 'layer_norm' and num_embeds_ada_norm is not None:
|
||||
deprecation_message = (
|
||||
f'The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or'
|
||||
" incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
' Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect'
|
||||
' results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it'
|
||||
' would be very nice if you could open a Pull request for the `transformer/config.json` file'
|
||||
)
|
||||
deprecate(
|
||||
'norm_type!=num_embeds_ada_norm',
|
||||
'1.0.0',
|
||||
deprecation_message,
|
||||
standard_warn=False)
|
||||
norm_type = 'ada_norm'
|
||||
|
||||
if self.is_input_continuous and self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f'Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make'
|
||||
' sure that either `in_channels` or `num_vector_embeds` is None.'
|
||||
)
|
||||
elif self.is_input_vectorized and self.is_input_patches:
|
||||
raise ValueError(
|
||||
f'Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make'
|
||||
' sure that either `num_vector_embeds` or `num_patches` is None.'
|
||||
)
|
||||
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
||||
raise ValueError(
|
||||
f'Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:'
|
||||
f' {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None.'
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
if self.is_input_continuous:
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(
|
||||
num_groups=norm_num_groups,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True)
|
||||
if use_linear_projection:
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
if use_pixelwise_attention:
|
||||
self.proj_in_plus = nn.Linear(
|
||||
pixelwise_cross_attention_dim,
|
||||
pixelwise_cross_attention_dim)
|
||||
else:
|
||||
self.proj_in_plus = None
|
||||
else:
|
||||
self.proj_in = nn.Conv2d(
|
||||
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
if use_pixelwise_attention:
|
||||
self.proj_in_plus = nn.Conv2d(
|
||||
pixelwise_cross_attention_dim,
|
||||
pixelwise_cross_attention_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
else:
|
||||
self.proj_in_plus = None
|
||||
|
||||
elif self.is_input_vectorized:
|
||||
assert sample_size is not None, 'Transformer2DModel over discrete input must provide sample_size'
|
||||
assert num_vector_embeds is not None, 'Transformer2DModel over discrete input must provide num_embed'
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
self.num_vector_embeds = num_vector_embeds
|
||||
self.num_latent_pixels = self.height * self.width
|
||||
|
||||
self.latent_image_embedding = ImagePositionalEmbeddings(
|
||||
num_embed=num_vector_embeds,
|
||||
embed_dim=inner_dim,
|
||||
height=self.height,
|
||||
width=self.width)
|
||||
elif self.is_input_patches:
|
||||
assert sample_size is not None, 'Transformer2DModel over patched input must provide sample_size'
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
pixelwise_cross_attention_dim=pixelwise_cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
use_pixelwise_attention=use_pixelwise_attention,
|
||||
) for d in range(num_layers)
|
||||
])
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
if self.is_input_continuous:
|
||||
# TODO: should use out_channels for continuous projections
|
||||
if use_linear_projection:
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
else:
|
||||
self.proj_out = nn.Conv2d(
|
||||
inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
elif self.is_input_patches:
|
||||
self.norm_out = nn.LayerNorm(
|
||||
inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
||||
self.proj_out_2 = nn.Linear(
|
||||
inner_dim, patch_size * patch_size * self.out_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
encoder_pixelwise_hidden_states=None,
|
||||
timestep=None,
|
||||
class_labels=None,
|
||||
cross_attention_kwargs=None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (when discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
|
||||
conditioning.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
||||
batch, height * width, inner_dim)
|
||||
|
||||
if self.proj_in_plus is not None:
|
||||
encoder_pixelwise_hidden_states = self.proj_in_plus(
|
||||
encoder_pixelwise_hidden_states)
|
||||
pixelwise_inner_dim = encoder_pixelwise_hidden_states.shape[
|
||||
1]
|
||||
encoder_pixelwise_hidden_states = encoder_pixelwise_hidden_states.permute(
|
||||
0, 2, 3, 1).reshape(batch, height * width,
|
||||
pixelwise_inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
||||
batch, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
if self.proj_in_plus is not None:
|
||||
pixelwise_inner_dim = encoder_pixelwise_hidden_states.shape[
|
||||
1]
|
||||
encoder_pixelwise_hidden_states = encoder_pixelwise_hidden_states.permute(
|
||||
0, 2, 3, 1).reshape(batch, height * width,
|
||||
pixelwise_inner_dim)
|
||||
encoder_pixelwise_hidden_states = self.proj_in_plus(
|
||||
encoder_pixelwise_hidden_states)
|
||||
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
elif self.is_input_patches:
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_pixelwise_hidden_states=encoder_pixelwise_hidden_states,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, width,
|
||||
inner_dim).permute(
|
||||
0, 3, 1,
|
||||
2).contiguous()
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
else:
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch, height, width,
|
||||
inner_dim).permute(
|
||||
0, 3, 1,
|
||||
2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
logits = self.out(hidden_states)
|
||||
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
||||
logits = logits.permute(0, 2, 1)
|
||||
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
elif self.is_input_patches:
|
||||
# TODO: cleanup!
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(
|
||||
2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (
|
||||
1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
height = width = int(hidden_states.shape[1]**0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size,
|
||||
self.out_channels))
|
||||
hidden_states = torch.einsum('nhwpqc->nchpwq', hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size,
|
||||
width * self.patch_size))
|
||||
|
||||
if not return_dict:
|
||||
return (output, )
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
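Not part of this diff: a minimal usage sketch for the continuous-input branch. Dimensions are illustrative, and it assumes the customised `BasicTransformerBlock` imported from `.attention` accepts these arguments and behaves like the stock diffusers block when `use_pixelwise_attention` is left at its default `False`.

import torch

# inner_dim = num_attention_heads * attention_head_dim = 320 == in_channels,
# so the residual addition at the end of forward() lines up.
model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
)
latents = torch.randn(2, 320, 16, 16)      # (batch, channels, height, width)
text_states = torch.randn(2, 77, 768)      # (batch, seq_len, cross_attention_dim)
out = model(latents, encoder_hidden_states=text_states).sample
assert out.shape == latents.shape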
|
||||
modelscope/models/cv/image_super_resolution_pasd/unet_2d_blocks.py (new file, 2983 lines): diff suppressed because it is too large
@@ -0,0 +1,887 @@
|
||||
# Part of the implementation is borrowed and modified from diffusers,
|
||||
# publicly available at https://github.com/huggingface/diffusers/tree/main/src/diffusers/models/unet_2d_condition.py
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import json
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import UNet2DConditionLoadersMixin
|
||||
from diffusers.models import ModelMixin
|
||||
from diffusers.models.attention_processor import (AttentionProcessor,
|
||||
AttnProcessor)
|
||||
from diffusers.models.embeddings import (GaussianFourierProjection,
|
||||
TextTimeEmbedding, TimestepEmbedding,
|
||||
Timesteps)
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
|
||||
from .unet_2d_blocks import (CrossAttnDownBlock2D, CrossAttnUpBlock2D,
|
||||
DownBlock2D, UNetMidBlock2DCrossAttn,
|
||||
UNetMidBlock2DSimpleCrossAttn, UpBlock2D,
|
||||
get_down_block, get_up_block)
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet2DConditionOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class UNet2DConditionModel(ModelMixin, ConfigMixin,
|
||||
UNet2DConditionLoadersMixin):
|
||||
r"""
|
||||
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
|
||||
and returns sample shaped output.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
||||
implements for all the models (such as downloading or saving, etc.)
|
||||
|
||||
Parameters:
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample.
|
||||
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||
The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
|
||||
mid block layer if `None`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
||||
The tuple of upsample blocks to use.
|
||||
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
||||
Whether to include self-attention in the basic transformer blocks, see
|
||||
[`~models.attention.BasicTransformerBlock`].
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
||||
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
||||
If `None`, it will skip the normalization and activation layers in post-processing
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to None):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
time_embedding_type (`str`, *optional*, default to `positional`):
|
||||
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
||||
time_embedding_dim (`int`, *optional*, default to `None`):
|
||||
An optional override for the dimension of the projected time embedding.
|
||||
time_embedding_act_fn (`str`, *optional*, default to `None`):
|
||||
Optional activation function to use only once on the time embeddings, before they are passed to the rest
of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
||||
timestep_post_act (`str`, *optional*, defaults to `None`):
|
||||
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
||||
time_cond_proj_dim (`int`, *optional*, default to `None`):
|
||||
The dimension of `cond_proj` layer in timestep embedding.
|
||||
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
||||
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
||||
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
|
||||
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
||||
embeddings with the class embeddings.
|
||||
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
||||
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
|
||||
`only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
|
||||
default to `False`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[int] = None,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 4,
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
'CrossAttnDownBlock2D',
|
||||
'CrossAttnDownBlock2D',
|
||||
'CrossAttnDownBlock2D',
|
||||
'DownBlock2D',
|
||||
),
|
||||
mid_block_type: Optional[str] = 'UNetMidBlock2DCrossAttn',
|
||||
up_block_types: Tuple[str] = ('UpBlock2D', 'CrossAttnUpBlock2D',
|
||||
'CrossAttnUpBlock2D',
|
||||
'CrossAttnUpBlock2D'),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = 'silu',
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = 'default',
|
||||
resnet_skip_time_act: bool = False,
|
||||
resnet_out_scale_factor: int = 1.0,
|
||||
time_embedding_type: str = 'positional',
|
||||
time_embedding_dim: Optional[int] = None,
|
||||
time_embedding_act_fn: Optional[str] = None,
|
||||
timestep_post_act: Optional[str] = None,
|
||||
time_cond_proj_dim: Optional[int] = None,
|
||||
conv_in_kernel: int = 3,
|
||||
conv_out_kernel: int = 3,
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
class_embeddings_concat: bool = False,
|
||||
mid_block_only_cross_attention: Optional[bool] = None,
|
||||
cross_attention_norm: Optional[str] = None,
|
||||
addition_embed_type_num_heads=64,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f'Must provide the same number of `down_block_types` as `up_block_types`. \
|
||||
`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.'
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f'Must provide the same number of `block_out_channels` as `down_block_types`. \
|
||||
`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.'
|
||||
)
|
||||
|
||||
if not isinstance(
|
||||
only_cross_attention,
|
||||
bool) and len(only_cross_attention) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f'Must provide the same number of `only_cross_attention` as `down_block_types`. \
|
||||
`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.'
|
||||
)
|
||||
|
||||
if not isinstance(
|
||||
attention_head_dim,
|
||||
int) and len(attention_head_dim) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f'Must provide the same number of `attention_head_dim` as `down_block_types`. \
|
||||
`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.'
|
||||
)
|
||||
|
||||
if isinstance(
|
||||
cross_attention_dim,
|
||||
list) and len(cross_attention_dim) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f'Must provide the same number of `cross_attention_dim` as `down_block_types`. \
|
||||
`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.'
|
||||
)
|
||||
|
||||
if not isinstance(
|
||||
layers_per_block,
|
||||
int) and len(layers_per_block) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f'Must provide the same number of `layers_per_block` as `down_block_types`. \
|
||||
`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}.'
|
||||
)
|
||||
|
||||
# input
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[0],
|
||||
kernel_size=conv_in_kernel,
|
||||
padding=conv_in_padding)
|
||||
|
||||
# time
|
||||
if time_embedding_type == 'fourier':
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
||||
if time_embed_dim % 2 != 0:
|
||||
raise ValueError(
|
||||
f'`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.'
|
||||
)
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
time_embed_dim // 2,
|
||||
set_W_to_weight=False,
|
||||
log=False,
|
||||
flip_sin_to_cos=flip_sin_to_cos)
|
||||
timestep_input_dim = time_embed_dim
|
||||
elif time_embedding_type == 'positional':
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
||||
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos,
|
||||
freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f'{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`.'
|
||||
)
|
||||
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
post_act_fn=timestep_post_act,
|
||||
cond_proj_dim=time_cond_proj_dim,
|
||||
)
|
||||
|
||||
if encoder_hid_dim is not None:
|
||||
self.encoder_hid_proj = nn.Linear(encoder_hid_dim,
|
||||
cross_attention_dim)
|
||||
else:
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds,
|
||||
time_embed_dim)
|
||||
elif class_embed_type == 'timestep':
|
||||
self.class_embedding = TimestepEmbedding(
|
||||
timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
||||
elif class_embed_type == 'identity':
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == 'projection':
|
||||
if projection_class_embeddings_input_dim is None:
|
||||
raise ValueError(
|
||||
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||
)
|
||||
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||
# 2. it projects from an arbitrary input dimension.
|
||||
#
|
||||
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||
self.class_embedding = TimestepEmbedding(
|
||||
projection_class_embeddings_input_dim, time_embed_dim)
|
||||
elif class_embed_type == 'simple_projection':
|
||||
if projection_class_embeddings_input_dim is None:
|
||||
raise ValueError(
|
||||
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
||||
)
|
||||
self.class_embedding = nn.Linear(
|
||||
projection_class_embeddings_input_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
if addition_embed_type == 'text':
|
||||
if encoder_hid_dim is not None:
|
||||
text_time_embedding_from_dim = encoder_hid_dim
|
||||
else:
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim,
|
||||
time_embed_dim,
|
||||
num_heads=addition_embed_type_num_heads)
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(
|
||||
f"addition_embed_type: {addition_embed_type} must be None or 'text'."
|
||||
)
|
||||
|
||||
if time_embedding_act_fn is None:
|
||||
self.time_embed_act = None
|
||||
elif time_embedding_act_fn == 'swish':
|
||||
self.time_embed_act = lambda x: F.silu(x)
|
||||
elif time_embedding_act_fn == 'mish':
|
||||
self.time_embed_act = nn.Mish()
|
||||
elif time_embedding_act_fn == 'silu':
|
||||
self.time_embed_act = nn.SiLU()
|
||||
elif time_embedding_act_fn == 'gelu':
|
||||
self.time_embed_act = nn.GELU()
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Unsupported activation function: {time_embedding_act_fn}')
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
if mid_block_only_cross_attention is None:
|
||||
mid_block_only_cross_attention = only_cross_attention
|
||||
|
||||
only_cross_attention = [only_cross_attention
|
||||
] * len(down_block_types)
|
||||
|
||||
if mid_block_only_cross_attention is None:
|
||||
mid_block_only_cross_attention = False
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim, ) * len(down_block_types)
|
||||
|
||||
if isinstance(cross_attention_dim, int):
|
||||
cross_attention_dim = (
|
||||
cross_attention_dim, ) * len(down_block_types)
|
||||
|
||||
if isinstance(layers_per_block, int):
|
||||
layers_per_block = [layers_per_block] * len(down_block_types)
|
||||
|
||||
if class_embeddings_concat:
|
||||
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
||||
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
||||
# regular time embeddings
|
||||
blocks_time_embed_dim = time_embed_dim * 2
|
||||
else:
|
||||
blocks_time_embed_dim = time_embed_dim
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim[i],
|
||||
attn_num_head_channels=attention_head_dim[i],
|
||||
downsample_padding=downsample_padding,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_skip_time_act=resnet_skip_time_act,
|
||||
resnet_out_scale_factor=resnet_out_scale_factor,
|
||||
cross_attention_norm=cross_attention_norm,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
if mid_block_type == 'UNetMidBlock2DCrossAttn':
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim[-1],
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
elif mid_block_type == 'UNetMidBlock2DSimpleCrossAttn':
|
||||
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim[-1],
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
skip_time_act=resnet_skip_time_act,
|
||||
only_cross_attention=mid_block_only_cross_attention,
|
||||
cross_attention_norm=cross_attention_norm,
|
||||
)
|
||||
elif mid_block_type is None:
|
||||
self.mid_block = None
|
||||
else:
|
||||
raise ValueError(f'unknown mid_block_type : {mid_block_type}')
|
||||
|
||||
# count how many layers upsample the images
|
||||
self.num_upsamplers = 0
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
||||
reversed_layers_per_block = list(reversed(layers_per_block))
|
||||
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
|
||||
reversed_pixelwise_cross_attention_dim = [-1, 1280, 640, 320]
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(
|
||||
i + 1,
|
||||
len(block_out_channels) - 1)]
|
||||
|
||||
# add upsample block for all BUT final layer
|
||||
if not is_final_block:
|
||||
add_upsample = True
|
||||
self.num_upsamplers += 1
|
||||
else:
|
||||
add_upsample = False
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=reversed_layers_per_block[i] + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=blocks_time_embed_dim,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=reversed_cross_attention_dim[i],
|
||||
pixelwise_cross_attention_dim=
|
||||
reversed_pixelwise_cross_attention_dim[i],
|
||||
attn_num_head_channels=reversed_attention_head_dim[i],
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_skip_time_act=resnet_skip_time_act,
|
||||
resnet_out_scale_factor=resnet_out_scale_factor,
|
||||
cross_attention_norm=cross_attention_norm,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
if norm_num_groups is not None:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0],
|
||||
num_groups=norm_num_groups,
|
||||
eps=norm_eps)
|
||||
|
||||
if act_fn == 'swish':
|
||||
self.conv_act = lambda x: F.silu(x)
|
||||
elif act_fn == 'mish':
|
||||
self.conv_act = nn.Mish()
|
||||
elif act_fn == 'silu':
|
||||
self.conv_act = nn.SiLU()
|
||||
elif act_fn == 'gelu':
|
||||
self.conv_act = nn.GELU()
|
||||
else:
|
||||
raise ValueError(f'Unsupported activation function: {act_fn}')
|
||||
|
||||
else:
|
||||
self.conv_norm_out = None
|
||||
self.conv_act = None
|
||||
|
||||
conv_out_padding = (conv_out_kernel - 1) // 2
|
||||
self.conv_out = nn.Conv2d(
|
||||
block_out_channels[0],
|
||||
out_channels,
|
||||
kernel_size=conv_out_kernel,
|
||||
padding=conv_out_padding)
|
||||
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module,
|
||||
processors: Dict[str,
|
||||
AttentionProcessor]):
|
||||
if hasattr(module, 'set_processor'):
|
||||
processors[f'{name}.processor'] = module.processor
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f'{name}.{sub_name}', child,
|
||||
processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor,
|
||||
Dict[str,
|
||||
AttentionProcessor]]):
|
||||
r"""
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
of **all** `Attention` layers.
|
||||
In case `processor` is a dict, the key needs to define the path to
|
||||
the corresponding cross attention processor.
|
||||
This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f'A dict of processors was passed, but the number of processors {len(processor)} does not match the'
|
||||
f' number of attention layers: {count}. Please make sure to pass {count} processor classes.'
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module,
|
||||
processor):
|
||||
if hasattr(module, 'set_processor'):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f'{name}.processor'))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f'{name}.{sub_name}', child,
|
||||
processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
self.set_attn_processor(AttnProcessor())
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, 'set_attention_slice'):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_sliceable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_sliceable_dims(module)
|
||||
|
||||
num_sliceable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == 'auto':
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == 'max':
|
||||
# make smallest slice possible
|
||||
slice_size = num_sliceable_layers * [1]
|
||||
|
||||
slice_size = num_sliceable_layers * [slice_size] if not isinstance(
|
||||
slice_size, list) else slice_size
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f'You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different'
|
||||
f' attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.'
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(
|
||||
f'size {size} has to be smaller or equal to {dim}.')
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module,
|
||||
slice_size: List[int]):
|
||||
if hasattr(module, 'set_attention_slice'):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D,
|
||||
CrossAttnUpBlock2D, UpBlock2D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be at least a multiple of the overall upsampling factor.
|
||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||
# on the fly if necessary.
|
||||
default_overall_up_factor = 2**self.num_upsamplers
|
||||
|
||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
|
||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||
logger.info(
|
||||
'Forward upsample size to force interpolation output size.')
|
||||
forward_upsample_size = True
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == 'mps'
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps],
|
||||
dtype=dtype,
|
||||
device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError(
|
||||
'class_labels should be provided when num_class_embeds > 0'
|
||||
)
|
||||
|
||||
if self.config.class_embed_type == 'timestep':
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# there might be better ways to encapsulate this.
|
||||
class_labels = class_labels.to(dtype=sample.dtype)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(
|
||||
dtype=sample.dtype)
|
||||
|
||||
if self.config.class_embeddings_concat:
|
||||
emb = torch.cat([emb, class_emb], dim=-1)
|
||||
else:
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type == 'text':
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
emb = emb + aug_emb
|
||||
|
||||
if self.time_embed_act is not None:
|
||||
emb = self.time_embed_act(emb)
|
||||
|
||||
if self.encoder_hid_proj is not None:
|
||||
encoder_hidden_states = self.encoder_hid_proj(
|
||||
encoder_hidden_states)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample, )
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, 'has_cross_attention'
|
||||
) and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
if down_block_additional_residuals is not None:
|
||||
new_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, down_block_additional_residual in zip(
|
||||
down_block_res_samples, down_block_additional_residuals):
|
||||
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
||||
new_down_block_res_samples = new_down_block_res_samples + (
|
||||
down_block_res_sample, )
|
||||
|
||||
down_block_res_samples = new_down_block_res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if mid_block_additional_residual is not None:
|
||||
sample = sample + mid_block_additional_residual
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||
down_block_res_samples = down_block_res_samples[:-len(
|
||||
upsample_block.resnets)]
|
||||
|
||||
down_block_additional_residual = down_block_additional_residuals[
|
||||
-len(upsample_block.resnets):]
|
||||
down_block_additional_residuals = down_block_additional_residuals[:-len(
|
||||
upsample_block.resnets)]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if hasattr(upsample_block, 'has_cross_attention'
|
||||
) and upsample_block.has_cross_attention:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_pixelwise_hidden_states_tuple=
|
||||
down_block_additional_residual,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size)
|
||||
|
||||
# 6. post-process
|
||||
if self.conv_norm_out:
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if not return_dict:
|
||||
return (sample, )
|
||||
|
||||
return UNet2DConditionOutput(sample=sample)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained_(cls, pretrained_model_path, subfolder=None, **kwargs):
|
||||
if subfolder is not None:
|
||||
pretrained_model_path = os.path.join(pretrained_model_path,
|
||||
subfolder)
|
||||
|
||||
config_file = os.path.join(pretrained_model_path, 'config.json')
|
||||
if not os.path.isfile(config_file):
|
||||
raise RuntimeError(f'{config_file} does not exist')
|
||||
with open(config_file, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
from diffusers.utils import WEIGHTS_NAME
|
||||
model = cls.from_config(config)
|
||||
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
||||
if not os.path.isfile(model_file):
|
||||
raise RuntimeError(f'{model_file} does not exist')
|
||||
state_dict = torch.load(model_file, map_location='cpu')
|
||||
for k, v in model.state_dict().items():
|
||||
if 'attn2_plus' in k:
|
||||
state_dict.update({k: v})
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
return model
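Not part of this diff: a minimal loading sketch using the `from_pretrained_` helper above. `sd_dir` is a placeholder for a locally downloaded Stable Diffusion directory whose `unet/` subfolder holds `config.json` and `diffusion_pytorch_model.bin` (diffusers' `WEIGHTS_NAME`). Note that this customised `forward` slices `down_block_additional_residuals` unconditionally in the up path, so at inference time it expects the pixel-wise residuals produced by the paired `ControlNetModel`; the sketch therefore stops after loading.

unet = UNet2DConditionModel.from_pretrained_(sd_dir, subfolder='unet')  # sd_dir is a placeholder
unet.set_default_attn_processor()  # fall back to the stock AttnProcessor
unet.eval()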
|
||||
@@ -9,6 +9,8 @@ import torch
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base.base_torch_model import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.compatible_with_transformers import \
|
||||
compatible_position_ids
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
@@ -40,6 +42,8 @@ class ReferringVideoObjectSegmentation(TorchModel):
|
||||
params_dict = torch.load(model_path, map_location='cpu')
|
||||
if 'model_state_dict' in params_dict.keys():
|
||||
params_dict = params_dict['model_state_dict']
|
||||
compatible_position_ids(
|
||||
params_dict, 'transformer.text_encoder.embeddings.position_ids')
|
||||
self.model.load_state_dict(params_dict, strict=True)
|
||||
|
||||
self.set_postprocessor(self.cfg.pipeline.dataset_name)
|
||||
|
||||
@@ -15,6 +15,9 @@ import torch.utils.checkpoint as checkpoint
|
||||
from torch import nn
|
||||
from transformers import BertConfig, BertForMaskedLM
|
||||
|
||||
from modelscope.utils.compatible_with_transformers import \
|
||||
compatible_position_ids
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
@@ -290,6 +293,8 @@ class TEAM(nn.Module):
|
||||
self.text_tensor_fc = nn.Linear(1024, 768)
|
||||
|
||||
params = torch.load(pretrained, 'cpu')
|
||||
compatible_position_ids(params,
|
||||
'text_model.bert.embeddings.position_ids')
|
||||
self.load_state_dict(params, strict=True)
|
||||
|
||||
def get_feature(self, text_data=None, text_mask=None, img_tensor=None):
|
||||
|
||||
@@ -12,6 +12,8 @@ from modelscope.metainfo import Models
|
||||
from modelscope.models import TorchModel
|
||||
from modelscope.models.base import Tensor
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.compatible_with_transformers import \
|
||||
compatible_position_ids
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
|
||||
@@ -47,6 +49,9 @@ class StarForTextToSql(TorchModel):
|
||||
open(
|
||||
os.path.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'rb'),
|
||||
map_location=self.device)
|
||||
compatible_position_ids(
|
||||
check_point['model'],
|
||||
'encoder.input_layer.plm_model.embeddings.position_ids')
|
||||
self.model.load_state_dict(check_point['model'])
|
||||
|
||||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
|
||||
@@ -22,6 +22,8 @@ from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.nlp.unite.configuration import InputFormat
|
||||
from modelscope.outputs.nlp_outputs import TranslationEvaluationOutput
|
||||
from modelscope.utils.compatible_with_transformers import \
|
||||
compatible_position_ids
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -305,6 +307,8 @@ class UniTEForTranslationEvaluation(TorchModel):
|
||||
self.encoder.pooler = None
|
||||
else:
|
||||
state_dict = torch.load(path, map_location=device)
|
||||
compatible_position_ids(state_dict,
|
||||
'encoder.embeddings.position_ids')
|
||||
self.load_state_dict(state_dict)
|
||||
logger.info('Loading checkpoint parameters from %s' % path)
|
||||
return
|
||||
|
||||
@@ -15,6 +15,8 @@ from modelscope.models import TorchModel
|
||||
from modelscope.models.base import Tensor
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.outputs import DialogueUserSatisfactionEstimationModelOutput
|
||||
from modelscope.utils.compatible_with_transformers import \
|
||||
compatible_position_ids
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .transformer import TransformerEncoder
|
||||
|
||||
@@ -45,8 +47,10 @@ class UserSatisfactionEstimation(TorchModel):
|
||||
self.device = device
|
||||
self.model = self.init_model()
|
||||
model_ckpt = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
|
||||
self.model.load_state_dict(
|
||||
torch.load(model_ckpt, map_location=torch.device('cpu')))
|
||||
stats_dict = torch.load(model_ckpt, map_location=torch.device('cpu'))
|
||||
compatible_position_ids(stats_dict,
|
||||
'private.bert.embeddings.position_ids')
|
||||
self.model.load_state_dict(stats_dict)
|
||||
|
||||
def init_model(self):
|
||||
configs = {
|
||||
|
||||
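Each of the hunks above applies the same compatibility shim: the checkpoint's state dict is routed through `compatible_position_ids` before `load_state_dict`. A condensed sketch of the pattern, with the checkpoint path, model object, and embeddings key as placeholders (the helper is assumed to drop or remap a stale `position_ids` buffer left behind by older transformers releases):

import torch

from modelscope.utils.compatible_with_transformers import \
    compatible_position_ids

state_dict = torch.load('<checkpoint>.bin', map_location='cpu')  # placeholder path
compatible_position_ids(state_dict, '<prefix>.embeddings.position_ids')  # model-specific key
model.load_state_dict(state_dict)  # `model` is whichever nn.Module the checkpoint belongs to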
@@ -727,6 +727,7 @@ TASK_OUTPUTS = {
    # {"output_img": np.array with shape (h, w, 3)}
    Tasks.skin_retouching: [OutputKeys.OUTPUT_IMG],
    Tasks.image_super_resolution: [OutputKeys.OUTPUT_IMG],
    Tasks.image_super_resolution_pasd: [OutputKeys.OUTPUT_IMG],
    Tasks.image_colorization: [OutputKeys.OUTPUT_IMG],
    Tasks.image_color_enhancement: [OutputKeys.OUTPUT_IMG],
    Tasks.image_denoising: [OutputKeys.OUTPUT_IMG],

@@ -182,6 +182,10 @@ TASK_INPUTS = {
    InputType.IMAGE,
    Tasks.crowd_counting:
    InputType.IMAGE,
    Tasks.image_super_resolution_pasd: {
        'image': InputType.IMAGE,
        'prompt': InputType.TEXT,
    },
    Tasks.image_inpainting: {
        'img': InputType.IMAGE,
        'mask': InputType.IMAGE,

@@ -38,6 +38,7 @@ if TYPE_CHECKING:
    from .image_semantic_segmentation_pipeline import ImageSemanticSegmentationPipeline
    from .image_style_transfer_pipeline import ImageStyleTransferPipeline
    from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
    from .image_super_resolution_pasd_pipeline import ImageSuperResolutionPASDPipeline
    from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline
    from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
    from .image_inpainting_pipeline import ImageInpaintingPipeline
@@ -151,6 +152,8 @@ else:
        ['ImageSemanticSegmentationPipeline'],
        'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'],
        'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'],
        'image_super_resolution_pasd_pipeline':
        ['ImageSuperResolutionPASDPipeline'],
        'image_to_image_translation_pipeline':
        ['Image2ImageTranslationPipeline'],
        'product_retrieval_embedding_pipeline':
@@ -185,8 +188,9 @@ else:
        ['FaceAttributeRecognitionPipeline'],
        'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'],
        'hand_static_pipeline': ['HandStaticPipeline'],
        'referring_video_object_segmentation_pipeline':
        ['ReferringVideoObjectSegmentationPipeline'],
        'referring_video_object_segmentation_pipeline': [
            'ReferringVideoObjectSegmentationPipeline'
        ],
        'language_guided_video_summarization_pipeline': [
            'LanguageGuidedVideoSummarizationPipeline'
        ],

229
modelscope/pipelines/cv/image_super_resolution_pasd_pipeline.py
Normal file
@@ -0,0 +1,229 @@
# Copyright © Alibaba, Inc. and its affiliates.

import os
import tempfile
from typing import Any, Dict, Optional, Union

import numpy as np
import PIL
import torch
from diffusers import AutoencoderKL, UniPCMultistepScheduler
from torchvision.models import ResNet50_Weights, resnet50
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from modelscope.metainfo import Pipelines
from modelscope.models.cv.image_portrait_enhancement.retinaface import \
    detection
from modelscope.models.cv.image_super_resolution_pasd import (
    ControlNetModel, UNet2DConditionModel)
from modelscope.models.cv.image_super_resolution_pasd.misc import (
    load_dreambooth_lora, wavelet_color_fix)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.multi_modal.diffusers_wrapped.pasd_pipeline import \
    PixelAwareStableDiffusionPipeline
from modelscope.preprocessors.image import load_image
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import create_device


@PIPELINES.register_module(
    Tasks.image_super_resolution_pasd,
    module_name=Pipelines.image_super_resolution_pasd)
class ImageSuperResolutionPASDPipeline(Pipeline):
    """ Pixel-Aware Stable Diffusion for Realistic Image Super-Resolution Pipeline.

    Example:

    >>> import cv2
    >>> from modelscope.outputs import OutputKeys
    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.utils.constant import Tasks

    >>> input_location = 'example_image.png'
    >>> prompt = ''

    >>> input = {
    >>>     'image': input_location,
    >>>     'upscale': 2,
    >>>     'prompt': prompt,
    >>>     'fidelity_scale_fg': 1.5,
    >>>     'fidelity_scale_bg': 0.7
    >>> }
    >>> pasd = pipeline(Tasks.image_super_resolution_pasd, model='damo/PASD_image_super_resolutions')
    >>> output = pasd(input)[OutputKeys.OUTPUT_IMG]
    >>> cv2.imwrite('result.png', output)

    """

    def __init__(self, model: str, device_name: str = 'cuda', **kwargs):
        """
        Use `model` to create an image super-resolution pipeline for prediction.
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)

        torch_dtype = kwargs.get('torch_dtype', torch.float16)
        self.device = create_device(device_name)
        self.config = Config.from_file(
            os.path.join(model, ModelFile.CONFIGURATION))
        cfg = self.config.model_cfg
        dreambooth_lora_ckpt = cfg['dreambooth_lora_ckpt']
        tiled_size = cfg['tiled_size']
        self.process_size = cfg['process_size']

        scheduler = UniPCMultistepScheduler.from_pretrained(
            model, subfolder='scheduler')
        text_encoder = CLIPTextModel.from_pretrained(
            model, subfolder='text_encoder')
        tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer')
        vae = AutoencoderKL.from_pretrained(model, subfolder='vae')
        feature_extractor = CLIPImageProcessor.from_pretrained(
            f'{model}/feature_extractor')
        unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet')
        controlnet = ControlNetModel.from_pretrained(
            model, subfolder='controlnet')

        unet, vae = load_dreambooth_lora(unet, vae,
                                         f'{model}/{dreambooth_lora_ckpt}')

        vae.requires_grad_(False)
        text_encoder.requires_grad_(False)
        unet.requires_grad_(False)
        controlnet.requires_grad_(False)

        text_encoder.to(self.device, dtype=torch_dtype)
        vae.to(self.device, dtype=torch_dtype)
        unet.to(self.device, dtype=torch_dtype)
        controlnet.to(self.device, dtype=torch_dtype)

        self.pipeline = PixelAwareStableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            requires_safety_checker=False,
        )
        self.pipeline._init_tiled_vae(decoder_tile_size=tiled_size)
        self.pipeline.enable_model_cpu_offload()

        self.weights = ResNet50_Weights.DEFAULT
        self.resnet_preprocess = self.weights.transforms()
        self.resnet = resnet50(weights=self.weights)
        self.resnet.eval()

        self.threshold = 0.8
        detector_model_path = f'{model}/RetinaFace-R50.pth'
        self.face_detector = detection.RetinaFaceDetection(
            detector_model_path, self.device)

    def preprocess(self, input: Input):
        return input

    def forward(self, inputs: Dict[str, Any]):
        if not isinstance(inputs, dict):
            raise ValueError(
                f'Expected the input to be a dictionary, but got {type(inputs)}'
            )

        num_inference_steps = inputs.get('num_inference_steps', 20)
        guidance_scale = inputs.get('guidance_scale', 7.5)
        added_prompt = inputs.get(
            'added_prompt',
            'clean, high-resolution, 8k, best quality, masterpiece, extremely detailed'
        )
        negative_prompt = inputs.get(
            'negative_prompt',
            'dotted, noise, blur, lowres, smooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, \
            fewer digits, cropped, worst quality, low quality')
        eta = inputs.get('eta', 0.0)
        prompt = inputs.get('prompt', '')
        upscale = inputs.get('upscale', 2)
        fidelity_scale_fg = inputs.get('fidelity_scale_fg', 1.5)
        fidelity_scale_bg = inputs.get('fidelity_scale_bg', 0.7)

        input_image = load_image(inputs['image']).convert('RGB')

        with torch.no_grad():
            generator = torch.Generator(device=self.device)

            batch = self.resnet_preprocess(input_image).unsqueeze(0)
            prediction = self.resnet(batch).squeeze(0).softmax(0)
            class_id = prediction.argmax().item()
            score = prediction[class_id].item()
            category_name = self.weights.meta['categories'][class_id]
            if score >= 0.1:
                prompt += f'{category_name}' if prompt == '' else f', {category_name}'

            prompt = added_prompt if prompt == '' else f'{prompt}, {added_prompt}'

            ori_width, ori_height = input_image.size
            resize_flag = False
            rscale = upscale
            if ori_width < self.process_size // rscale or ori_height < self.process_size // rscale:
                scale = (self.process_size // rscale) / min(
                    ori_width, ori_height)
                tmp_image = input_image.resize(
                    (int(scale * ori_width), int(scale * ori_height)))

                input_image = tmp_image
                resize_flag = True

            input_image = input_image.resize(
                (input_image.size[0] * rscale, input_image.size[1] * rscale))
            input_image = input_image.resize(
                (input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
            width, height = input_image.size

            fg_mask = None
            if fidelity_scale_fg != fidelity_scale_bg:
                fg_mask = torch.zeros([1, 1, height, width])
                facebs, _ = self.face_detector.detect(np.array(input_image))
                for fb in facebs:
                    if fb[-1] < self.threshold:
                        continue
                    fb = list(map(int, fb))
                    fg_mask[:, :, fb[1]:fb[3], fb[0]:fb[2]] = 1
                fg_mask = fg_mask.to(self.device)

            if fg_mask is None:
                fidelity_scale = min(
                    max(fidelity_scale_fg, fidelity_scale_bg), 1)
                fidelity_scale_fg = fidelity_scale_bg = fidelity_scale

            try:
                image = self.pipeline(
                    prompt,
                    input_image,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    height=height,
                    width=width,
                    guidance_scale=guidance_scale,
                    negative_prompt=negative_prompt,
                    conditioning_scale_fg=fidelity_scale_fg,
                    conditioning_scale_bg=fidelity_scale_bg,
                    fg_mask=fg_mask,
                    eta=eta,
                ).images[0]

                image = wavelet_color_fix(image, input_image)

                if resize_flag:
                    image = image.resize(
                        (ori_width * rscale, ori_height * rscale))
            except Exception as e:
                print(e)
                image = PIL.Image.new('RGB', (512, 512), (0, 0, 0))

            return {'result': image}

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        result = np.array(inputs['result'])
        return {OutputKeys.OUTPUT_IMG: result[:, :, ::-1]}
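
Beyond the keys shown in the class docstring, forward() above also reads several optional knobs from the input dict (num_inference_steps, guidance_scale, eta, added_prompt, negative_prompt). A minimal usage sketch that overrides them, assuming the same model id as in the docstring and a local 'example_image.png':

import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pasd = pipeline(
    Tasks.image_super_resolution_pasd,
    model='damo/PASD_image_super_resolutions')
result = pasd({
    'image': 'example_image.png',  # local path or URL accepted by load_image
    'prompt': 'a dog',  # a ResNet-50 guess of the image category is appended automatically
    'upscale': 2,
    'fidelity_scale_fg': 1.5,  # ControlNet conditioning scale inside detected face boxes
    'fidelity_scale_bg': 0.7,  # ... and everywhere else; equal values skip face detection
    'num_inference_steps': 20,  # same defaults forward() uses when omitted
    'guidance_scale': 7.5,
    'eta': 0.0,
})
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])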
126
modelscope/pipelines/multi_modal/diffusers_wrapped/devices.py
Normal file
@@ -0,0 +1,126 @@
# The implementation is adopted from stable-diffusion-webui, made publicly available under the Apache 2.0 License
# at https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/devices.py

import contextlib
import sys

import torch

if sys.platform == 'darwin':
    from modules import mac_specific


def has_mps() -> bool:
    if sys.platform != 'darwin':
        return False
    else:
        return mac_specific.has_mps


def get_cuda_device_string():
    return 'cuda'


def get_optimal_device_name():
    if torch.cuda.is_available():
        return get_cuda_device_string()

    if has_mps():
        return 'mps'

    return 'cpu'


def get_optimal_device():
    return torch.device(get_optimal_device_name())


def get_device_for(task):
    return get_optimal_device()


def torch_gc():

    if torch.cuda.is_available():
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    if has_mps():
        mac_specific.torch_mps_gc()


def enable_tf32():
    if torch.cuda.is_available():

        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
        if any(
                torch.cuda.get_device_capability(devid) == (7, 5)
                for devid in range(0, torch.cuda.device_count())):
            torch.backends.cudnn.benchmark = True

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


enable_tf32()

cpu = torch.device('cpu')
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device(
    'cuda')
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
unet_needs_upcast = False


def cond_cast_unet(input):
    return input.to(dtype_unet) if unet_needs_upcast else input


def cond_cast_float(input):
    return input.float() if unet_needs_upcast else input


def randn(seed, shape):
    torch.manual_seed(seed)
    return torch.randn(shape, device=device)


def randn_without_seed(shape):
    return torch.randn(shape, device=device)


def autocast(disable=False):
    if disable:
        return contextlib.nullcontext()

    return torch.autocast('cuda')


def without_autocast(disable=False):
    return torch.autocast('cuda', enabled=False) if torch.is_autocast_enabled() and \
        not disable else contextlib.nullcontext()


class NansException(Exception):
    pass


def test_for_nans(x, where):
    if not torch.all(torch.isnan(x)).item():
        return

    if where == 'unet':
        message = 'A tensor with all NaNs was produced in Unet.'

    elif where == 'vae':
        message = 'A tensor with all NaNs was produced in VAE.'

    else:
        message = 'A tensor with all NaNs was produced.'

    message += ' Use --disable-nan-check commandline argument to disable this check.'

    raise NansException(message)
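
As a rough sketch of how these helpers fit together (the ToyNet module below is a hypothetical stand-in, not part of this file): pick the device with get_optimal_device(), run the forward pass under autocast(), guard the output with test_for_nans(), and release cached memory with torch_gc().

import torch
import torch.nn as nn

from modelscope.pipelines.multi_modal.diffusers_wrapped import devices


class ToyNet(nn.Module):  # hypothetical stand-in for a real UNet/VAE

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


def run_once():
    dev = devices.get_optimal_device()  # cuda > mps > cpu
    net = ToyNet().to(dev)
    x = torch.randn(2, 8, device=dev)
    # fp16 autocast only makes sense on CUDA; disable it elsewhere
    with devices.autocast(disable=dev.type != 'cuda'):
        y = net(x)
    devices.test_for_nans(y, 'unet')  # raises NansException on an all-NaN output
    devices.torch_gc()  # frees cached CUDA memory between calls
    return y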
1131
modelscope/pipelines/multi_modal/diffusers_wrapped/pasd_pipeline.py
Normal file
File diff suppressed because it is too large

762
modelscope/pipelines/multi_modal/diffusers_wrapped/vaehook.py
Normal file
@@ -0,0 +1,762 @@
# Part of the implementation is borrowed and modified from multidiffusion-upscaler-for-automatic1111, publicly available
# at https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/blob/main/scripts/vae_optimize.py

import gc
import math
from time import time

import torch
import torch.nn.functional as F
import torch.version
from tqdm import tqdm

from .devices import device, get_optimal_device, test_for_nans, torch_gc

sd_flag = False


def get_recommend_encoder_tile_size():
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(
            device).total_memory // 2**20
        if total_memory > 16 * 1000:
            ENCODER_TILE_SIZE = 3072
        elif total_memory > 12 * 1000:
            ENCODER_TILE_SIZE = 2048
        elif total_memory > 8 * 1000:
            ENCODER_TILE_SIZE = 1536
        else:
            ENCODER_TILE_SIZE = 960
    else:
        ENCODER_TILE_SIZE = 512
    return ENCODER_TILE_SIZE


def get_recommend_decoder_tile_size():
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(
            device).total_memory // 2**20
        if total_memory > 30 * 1000:
            DECODER_TILE_SIZE = 256
        elif total_memory > 16 * 1000:
            DECODER_TILE_SIZE = 192
        elif total_memory > 12 * 1000:
            DECODER_TILE_SIZE = 128
        elif total_memory > 8 * 1000:
            DECODER_TILE_SIZE = 96
        else:
            DECODER_TILE_SIZE = 64
    else:
        DECODER_TILE_SIZE = 64
    return DECODER_TILE_SIZE


if 'global const':
    DEFAULT_ENABLED = False
    DEFAULT_MOVE_TO_GPU = False
    DEFAULT_FAST_ENCODER = True
    DEFAULT_FAST_DECODER = True
    DEFAULT_COLOR_FIX = 0
    DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
    DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()


# inplace version of silu
def inplace_nonlinearity(x):
    # Test: fix for Nans
    return F.silu(x, inplace=True)


# extracted from ldm.modules.diffusionmodules.model


# from diffusers lib
def attn_forward_new(self, h_):
    batch_size, channel, height, width = h_.shape
    hidden_states = h_.view(batch_size, channel,
                            height * width).transpose(1, 2)

    attention_mask = None
    encoder_hidden_states = None
    batch_size, sequence_length, _ = hidden_states.shape
    attention_mask = self.prepare_attention_mask(attention_mask,
                                                 sequence_length, batch_size)

    query = self.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif self.norm_cross:
        encoder_hidden_states = self.norm_encoder_hidden_states(
            encoder_hidden_states)

    key = self.to_k(encoder_hidden_states)
    value = self.to_v(encoder_hidden_states)

    query = self.head_to_batch_dim(query)
    key = self.head_to_batch_dim(key)
    value = self.head_to_batch_dim(value)

    attention_probs = self.get_attention_scores(query, key, attention_mask)
    hidden_states = torch.bmm(attention_probs, value)
    hidden_states = self.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)

    hidden_states = hidden_states.transpose(-1, -2).reshape(
        batch_size, channel, height, width)

    return hidden_states


def attn2task(task_queue, net):
    task_queue.append(('store_res', lambda x: x))
    task_queue.append(('pre_norm', net.group_norm))
    task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
    task_queue.append(['add_res', None])


def resblock2task(queue, block):
    """
    Turn a ResNetBlock into a sequence of tasks and append to the task queue

    @param queue: the target task queue
    @param block: ResNetBlock

    """
    if block.in_channels != block.out_channels:
        if sd_flag:
            if block.use_conv_shortcut:
                queue.append(('store_res', block.conv_shortcut))
            else:
                queue.append(('store_res', block.nin_shortcut))
        else:
            if block.use_in_shortcut:
                queue.append(('store_res', block.conv_shortcut))
            else:
                queue.append(('store_res', block.nin_shortcut))

    else:
        queue.append(('store_res', lambda x: x))
    queue.append(('pre_norm', block.norm1))
    queue.append(('silu', inplace_nonlinearity))
    queue.append(('conv1', block.conv1))
    queue.append(('pre_norm', block.norm2))
    queue.append(('silu', inplace_nonlinearity))
    queue.append(('conv2', block.conv2))
    queue.append(['add_res', None])


def build_sampling(task_queue, net, is_decoder):
    """
    Build the sampling part of a task queue
    @param task_queue: the target task queue
    @param net: the network
    @param is_decoder: currently building decoder or encoder
    """
    if is_decoder:
        if sd_flag:
            resblock2task(task_queue, net.mid.block_1)
            attn2task(task_queue, net.mid.attn_1)
            print(task_queue)
            resblock2task(task_queue, net.mid.block_2)
            resolution_iter = reversed(range(net.num_resolutions))
            block_ids = net.num_res_blocks + 1
            condition = 0
            module = net.up
            func_name = 'upsample'
        else:
            resblock2task(task_queue, net.mid_block.resnets[0])
            attn2task(task_queue, net.mid_block.attentions[0])
            resblock2task(task_queue, net.mid_block.resnets[1])
            resolution_iter = (range(len(net.up_blocks))
                               )  # net.num_resolutions = 3
            block_ids = 2 + 1
            condition = len(net.up_blocks) - 1
            module = net.up_blocks
            func_name = 'upsamplers'
    else:
        resolution_iter = range(net.num_resolutions)
        block_ids = net.num_res_blocks
        condition = net.num_resolutions - 1
        module = net.down
        func_name = 'downsample'

    for i_level in resolution_iter:
        for i_block in range(block_ids):
            if sd_flag:
                resblock2task(task_queue, module[i_level].block[i_block])
            else:
                resblock2task(task_queue, module[i_level].resnets[i_block])
        if i_level != condition:
            if sd_flag:
                task_queue.append(
                    (func_name, getattr(module[i_level], func_name)))
            else:
                task_queue.append((func_name, module[i_level].upsamplers[0]))

    if not is_decoder:
        if sd_flag:
            resblock2task(task_queue, net.mid.block_1)
            attn2task(task_queue, net.mid.attn_1)
            resblock2task(task_queue, net.mid.block_2)
        else:
            resblock2task(task_queue, net.mid_block.resnets[0])
            attn2task(task_queue, net.mid_block.attentions[0])
            resblock2task(task_queue, net.mid_block.resnets[1])


def build_task_queue(net, is_decoder):
    """
    Build a single task queue for the encoder or decoder
    @param net: the VAE decoder or encoder network
    @param is_decoder: currently building decoder or encoder
    @return: the task queue
    """
    task_queue = []
    task_queue.append(('conv_in', net.conv_in))

    # construct the sampling part of the task queue
    # because encoder and decoder share the same architecture, we extract the sampling part
    build_sampling(task_queue, net, is_decoder)
    if is_decoder and not sd_flag:
        net.give_pre_end = False
        net.tanh_out = False

    if not is_decoder or not net.give_pre_end:
        if sd_flag:
            task_queue.append(('pre_norm', net.norm_out))
        else:
            task_queue.append(('pre_norm', net.conv_norm_out))
        task_queue.append(('silu', inplace_nonlinearity))
        task_queue.append(('conv_out', net.conv_out))
        if is_decoder and net.tanh_out:
            task_queue.append(('tanh', torch.tanh))

    return task_queue


def clone_task_queue(task_queue):
    """
    Clone a task queue
    @param task_queue: the task queue to be cloned
    @return: the cloned task queue
    """
    return [[item for item in task] for task in task_queue]


def get_var_mean(input, num_groups, eps=1e-6):
    """
    Get mean and var for group norm
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = int(c / num_groups)
    input_reshaped = input.contiguous().view(1, int(b * num_groups),
                                             channel_in_group,
                                             *input.size()[2:])
    var, mean = torch.var_mean(
        input_reshaped, dim=[0, 2, 3, 4], unbiased=False)
    return var, mean


def custom_group_norm(input,
                      num_groups,
                      mean,
                      var,
                      weight=None,
                      bias=None,
                      eps=1e-6):
    """
    Custom group norm with fixed mean and var

    @param input: input tensor
    @param num_groups: number of groups. by default, num_groups = 32
    @param mean: mean, must be pre-calculated by get_var_mean
    @param var: var, must be pre-calculated by get_var_mean
    @param weight: weight, should be fetched from the original group norm
    @param bias: bias, should be fetched from the original group norm
    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm

    @return: normalized tensor
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = int(c / num_groups)
    input_reshaped = input.contiguous().view(1, int(b * num_groups),
                                             channel_in_group,
                                             *input.size()[2:])

    out = F.batch_norm(
        input_reshaped,
        mean,
        var,
        weight=None,
        bias=None,
        training=False,
        momentum=0,
        eps=eps)

    out = out.view(b, c, *input.size()[2:])

    # post affine transform
    if weight is not None:
        out *= weight.view(1, -1, 1, 1)
    if bias is not None:
        out += bias.view(1, -1, 1, 1)
    return out


def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Crop the valid region from the tile
    @param x: input tile
    @param input_bbox: original input bounding box
    @param target_bbox: output bounding box
    @param is_decoder: whether the tile comes from the decoder (bboxes scaled up by 8) or the encoder (scaled down by 8)
    @return: cropped tile
    """
    padded_bbox = [i * 8 if is_decoder else i // 8 for i in input_bbox]
    margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
    return x[:, :, margin[2]:x.size(2) + margin[3],
             margin[0]:x.size(3) + margin[1]]


# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓


def perfcount(fn):

    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(device)
        torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            vram = torch.cuda.max_memory_allocated(device) / 2**20
            torch.cuda.reset_peak_memory_stats(device)
            print(
                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB'
            )
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret

    return wrapper


# copy end :)


class GroupNormParam:

    def __init__(self):
        self.var_list = []
        self.mean_list = []
        self.pixel_list = []
        self.weight = None
        self.bias = None

    def add_tile(self, tile, layer):
        var, mean = get_var_mean(tile, 32)
        # For giant images, the variance can be larger than max float16
        # In this case we create a copy to float32
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
        # ============= DEBUG: test for infinite =============
        # if torch.isinf(var).any():
        #     print('var: ', var)
        # ====================================================
        self.var_list.append(var)
        self.mean_list.append(mean)
        self.pixel_list.append(tile.shape[2] * tile.shape[3])
        if hasattr(layer, 'weight'):
            self.weight = layer.weight
            self.bias = layer.bias
        else:
            self.weight = None
            self.bias = None

    def summary(self):
        """
        summarize the mean and var and return a function
        that applies group norm on each tile
        """
        if len(self.var_list) == 0:
            return None
        var = torch.vstack(self.var_list)
        mean = torch.vstack(self.mean_list)
        max_value = max(self.pixel_list)
        pixels = torch.tensor(
            self.pixel_list, dtype=torch.float32, device=device) / max_value
        sum_pixels = torch.sum(pixels)
        pixels = pixels.unsqueeze(1) / sum_pixels
        var = torch.sum(var * pixels, dim=0)
        mean = torch.sum(mean * pixels, dim=0)
        return lambda x: custom_group_norm(x, 32, mean, var, self.weight,
                                           self.bias)

    @staticmethod
    def from_tile(tile, norm):
        """
        create a function from a single tile without summary
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
            # if it is a macbook, we need to convert back to float16
            if var.device.type == 'mps':
                # clamp to avoid overflow
                var = torch.clamp(var, 0, 60000)
                var = var.half()
                mean = mean.half()
        if hasattr(norm, 'weight'):
            weight = norm.weight
            bias = norm.bias
        else:
            weight = None
            bias = None

        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)

        return group_norm_func


class VAEHook:

    def __init__(self,
                 net,
                 tile_size,
                 is_decoder,
                 fast_decoder,
                 fast_encoder,
                 color_fix,
                 to_gpu=False):
        self.net = net  # encoder | decoder
        self.tile_size = tile_size
        self.is_decoder = is_decoder
        self.fast_mode = (fast_encoder and not is_decoder) or (fast_decoder
                                                               and is_decoder)
        self.color_fix = color_fix and not is_decoder
        self.to_gpu = to_gpu
        self.pad = 11 if is_decoder else 32

    def __call__(self, x):
        B, C, H, W = x.shape
        original_device = next(self.net.parameters()).device
        try:
            if self.to_gpu:
                self.net.to(get_optimal_device())
            if max(H, W) <= self.pad * 2 + self.tile_size:
                print(
                    '[Tiled VAE]: the input size is tiny and unnecessary to tile.'
                )
                return self.net.original_forward(x)
            else:
                return self.vae_tile_forward(x)
        finally:
            self.net.to(original_device)

    def get_best_tile_size(self, lowerbound, upperbound):
        """
        Get the best tile size for GPU memory
        """
        divider = 32
        while divider >= 2:
            remainder = lowerbound % divider
            if remainder == 0:
                return lowerbound
            candidate = lowerbound - remainder + divider
            if candidate <= upperbound:
                return candidate
            divider //= 2
        return lowerbound

    def split_tiles(self, h, w):
        """
        Tool function to split the image into tiles
        @param h: height of the image
        @param w: width of the image
        @return: tile_input_bboxes, tile_output_bboxes
        """
        tile_input_bboxes, tile_output_bboxes = [], []
        tile_size = self.tile_size
        pad = self.pad
        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
        # If any of the numbers are 0, we let it be 1
        # This is to deal with long and thin images
        num_height_tiles = max(num_height_tiles, 1)
        num_width_tiles = max(num_width_tiles, 1)

        # Suggestions from https://github.com/Kahsolt: auto shrink the tile size
        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)

        print(
            f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles}={num_height_tiles*num_width_tiles} tiles.',
            f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}'
        )

        for i in range(num_height_tiles):
            for j in range(num_width_tiles):
                # bbox: [x1, x2, y1, y2]
                # the padding is unnecessary for image borders. So we directly start from (32, 32)
                input_bbox = [
                    pad + j * real_tile_width,
                    min(pad + (j + 1) * real_tile_width, w),
                    pad + i * real_tile_height,
                    min(pad + (i + 1) * real_tile_height, h),
                ]

                # if the output bbox is close to the image boundary, we extend it to the image boundary
                output_bbox = [
                    input_bbox[0] if input_bbox[0] > pad else 0,
                    input_bbox[1] if input_bbox[1] < w - pad else w,
                    input_bbox[2] if input_bbox[2] > pad else 0,
                    input_bbox[3] if input_bbox[3] < h - pad else h,
                ]

                # scale to get the final output bbox
                output_bbox = [
                    x * 8 if self.is_decoder else x // 8 for x in output_bbox
                ]
                tile_output_bboxes.append(output_bbox)

                # unconditionally expand the input bbox by pad pixels
                tile_input_bboxes.append([
                    max(0, input_bbox[0] - pad),
                    min(w, input_bbox[1] + pad),
                    max(0, input_bbox[2] - pad),
                    min(h, input_bbox[3] + pad),
                ])

        return tile_input_bboxes, tile_output_bboxes

    @torch.no_grad()
    def estimate_group_norm(self, z, task_queue, color_fix):
        device = z.device
        tile = z
        last_id = len(task_queue) - 1
        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
            last_id -= 1
        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
            raise ValueError('No group norm found in the task queue')
        # estimate until the last group norm
        for i in range(last_id + 1):
            task = task_queue[i]
            if task[0] == 'pre_norm':
                group_norm_func = GroupNormParam.from_tile(tile, task[1])
                task_queue[i] = ('apply_norm', group_norm_func)
                if i == last_id:
                    return True
                tile = group_norm_func(tile)
            elif task[0] == 'store_res':
                task_id = i + 1
                while task_id < last_id and task_queue[task_id][0] != 'add_res':
                    task_id += 1
                if task_id >= last_id:
                    continue
                task_queue[task_id][1] = task[1](tile)
            elif task[0] == 'add_res':
                tile += task[1].to(device)
                task[1] = None
            elif color_fix and task[0] == 'downsample':
                for j in range(i, last_id + 1):
                    if task_queue[j][0] == 'store_res':
                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
                return True
            else:
                tile = task[1](tile)
            try:
                test_for_nans(tile, 'vae')
            except Exception as e:
                print(
                    f'{e}. Nan detected in fast mode estimation. Fast mode disabled.'
                )
                return False

        raise IndexError('Should not reach here')

    @perfcount
    @torch.no_grad()
    def vae_tile_forward(self, z):
        """
        Decode a latent vector z into an image in a tiled manner.
        @param z: latent vector
        @return: image
        """
        device = next(self.net.parameters()).device
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach()  # detach the input to avoid backprop

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        net.last_z_shape = z.shape

        # Split the input into tiles and build a task queue for each tile
        print(
            f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}'
        )

        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Prepare tiles by splitting the input latents
        tiles = []
        for input_bbox in in_bboxes:
            tile = z[:, :, input_bbox[2]:input_bbox[3],
                     input_bbox[0]:input_bbox[1]].cpu()
            tiles.append(tile)

        num_tiles = len(tiles)
        num_completed = 0

        # Build task queues
        single_task_queue = build_task_queue(net, is_decoder)
        if self.fast_mode:
            # Fast mode: downsample the input image to the tile size,
            # then estimate the group norm parameters on the downsampled image
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            downsampled_z = F.interpolate(
                z, scale_factor=scale_factor, mode='nearest-exact')
            # use nearest-exact to keep statistics as close as possible
            print(
                f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on \
                {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
            # The downsampling will heavily distort its mean and std, so we need to recover it.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(
                downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z
                             - mean_new) / std_new * std_old + mean_old
            del std_old, mean_old, std_new, mean_new
            # occasionally the std_new is too small or too large, which exceeds the range of float16
            # so we need to clamp it to max z's range.
            downsampled_z = torch.clamp_(
                downsampled_z, min=z.min(), max=z.max())
            estimate_task_queue = clone_task_queue(single_task_queue)
            if self.estimate_group_norm(
                    downsampled_z, estimate_task_queue,
                    color_fix=self.color_fix):
                single_task_queue = estimate_task_queue
            del downsampled_z

        task_queues = [
            clone_task_queue(single_task_queue) for _ in range(num_tiles)
        ]

        # Dummy result
        result = None
        result_approx = None
        # Free memory of input latent tensor
        del z

        # Task queue execution
        pbar = tqdm(
            total=num_tiles * len(task_queues[0]),
            desc=
            f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: "
        )

        # execute the tasks back and forth when switching tiles so that we always
        # keep one tile on the GPU to reduce unnecessary data transfer
        forward = True
        interrupted = False
        while True:

            group_norm_param = GroupNormParam()
            for i in range(num_tiles) if forward else reversed(
                    range(num_tiles)):

                tile = tiles[i].to(device)
                input_bbox = in_bboxes[i]
                task_queue = task_queues[i]

                interrupted = False
                while len(task_queue) > 0:
                    # DEBUG: current task
                    task = task_queue.pop(0)
                    if task[0] == 'pre_norm':
                        group_norm_param.add_tile(tile, task[1])
                        break
                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                        task_id = 0
                        res = task[1](tile)
                        if not self.fast_mode or task[0] == 'store_res_cpu':
                            res = res.cpu()
                        while task_queue[task_id][0] != 'add_res':
                            task_id += 1
                        task_queue[task_id][1] = res
                    elif task[0] == 'add_res':
                        tile += task[1].to(device)
                        task[1] = None
                    else:
                        tile = task[1](tile)
                    pbar.update(1)

                if interrupted:
                    break

                # check for NaNs in the tile.
                # If there are NaNs, we abort the process to save user's time
                test_for_nans(tile, 'vae')

                if len(task_queue) == 0:
                    tiles[i] = None
                    num_completed += 1
                    if result is None:  # NOTE: dim C varies from different cases, can only be inited dynamically
                        result = torch.zeros(
                            (N, tile.shape[1],
                             height * 8 if is_decoder else height // 8,
                             width * 8 if is_decoder else width // 8),
                            device=device,
                            requires_grad=False)
                    result[:, :, out_bboxes[i][2]:out_bboxes[i][3],
                           out_bboxes[i][0]:out_bboxes[i]
                           [1]] = crop_valid_region(tile, in_bboxes[i],
                                                    out_bboxes[i], is_decoder)
                    del tile
                elif i == num_tiles - 1 and forward:
                    forward = False
                    tiles[i] = tile
                elif i == 0 and not forward:
                    forward = True
                    tiles[i] = tile
                else:
                    tiles[i] = tile.cpu()
                    del tile

            if interrupted:
                break
            if num_completed == num_tiles:
                break

            # insert the group norm task to the head of each task queue
            group_norm_func = group_norm_param.summary()
            if group_norm_func is not None:
                for i in range(num_tiles):
                    task_queue = task_queues[i]
                    task_queue.insert(0, ('apply_norm', group_norm_func))

        # Done!
        pbar.close()
        return result if result is not None else result_approx.to(device)
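
For orientation only: _init_tiled_vae() in the PASD pipeline (whose diff is suppressed above) presumably wires VAEHook into the VAE roughly as sketched below. The helper name and defaults here are assumptions, but the VAEHook constructor arguments and the original_forward fallback match the class above.

from diffusers import AutoencoderKL

from modelscope.pipelines.multi_modal.diffusers_wrapped.vaehook import (
    DEFAULT_DECODER_TILE_SIZE, DEFAULT_ENCODER_TILE_SIZE, VAEHook)


def attach_tiled_vae(vae: AutoencoderKL,
                     encoder_tile_size=DEFAULT_ENCODER_TILE_SIZE,
                     decoder_tile_size=DEFAULT_DECODER_TILE_SIZE):
    # Hypothetical helper: replace each sub-network's forward with a VAEHook,
    # keeping the original forward around for the small-input fallback.
    for net, is_decoder, tile_size in ((vae.encoder, False, encoder_tile_size),
                                       (vae.decoder, True, decoder_tile_size)):
        if not hasattr(net, 'original_forward'):
            net.original_forward = net.forward
        net.forward = VAEHook(
            net,
            tile_size,
            is_decoder=is_decoder,
            fast_decoder=True,
            fast_encoder=True,
            color_fix=False,
            to_gpu=False)
    return vae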
16
modelscope/utils/compatible_with_transformers.py
Normal file
@@ -0,0 +1,16 @@
import transformers
from packaging import version


def compatible_position_ids(state_dict, position_id_key):
    """Transformers no longer expects position_ids after transformers==4.31
    https://github.com/huggingface/transformers/pull/24505

    Args:
        state_dict (dict): the checkpoint state dict to patch in place.
        position_id_key (str): position_ids key,
            such as 'encoder.embeddings.position_ids'
    """
    transformer_version = version.parse('.'.join(
        transformers.__version__.split('.')[:2]))
    if transformer_version >= version.parse('4.31.0'):
        del state_dict[position_id_key]
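
A short usage sketch mirroring the call sites patched earlier in this commit; the checkpoint path and key below are placeholders:

import torch

from modelscope.utils.compatible_with_transformers import \
    compatible_position_ids


def load_compatible(model: torch.nn.Module, ckpt_path: str,
                    position_id_key: str) -> torch.nn.Module:
    state_dict = torch.load(ckpt_path, map_location='cpu')
    # On transformers>=4.31 the position_ids buffer is no longer registered,
    # so the stale key is dropped here; on older versions the dict is untouched.
    compatible_position_ids(state_dict, position_id_key)
    model.load_state_dict(state_dict)
    return model


# e.g. load_compatible(my_bert, 'pytorch_model.bin', 'encoder.embeddings.position_ids')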
@@ -75,6 +75,7 @@ class CVTasks(object):
    # image editing
    skin_retouching = 'skin-retouching'
    image_super_resolution = 'image-super-resolution'
    image_super_resolution_pasd = 'image-super-resolution-pasd'
    image_debanding = 'image-debanding'
    image_colorization = 'image-colorization'
    image_color_enhancement = 'image-color-enhancement'

43
tests/pipelines/test_image_super_resolution_pasd.py
Normal file
@@ -0,0 +1,43 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ImageSuperResolutionPASDTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/PASD_image_super_resolutions'
        self.img = 'data/test/images/dogs.jpg'
        self.input = {
            'image': self.img,
            'prompt': '',
            'upscale': 1,
            'fidelity_scale_fg': 1.5,
            'fidelity_scale_bg': 0.7
        }
        self.task = Tasks.image_super_resolution_pasd

    def pipeline_inference(self, pipeline: Pipeline, input: dict):
        result = pipeline(input)
        if result is not None:
            cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
            print(f'Output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        super_resolution = pipeline(
            Tasks.image_super_resolution_pasd, model=self.model_id)

        self.pipeline_inference(super_resolution, self.input)


if __name__ == '__main__':
    unittest.main()
@@ -7,6 +7,7 @@ isolated: # test cases that may require excessive amount of GPU memory or run
  - test_dialog_modeling.py
  - test_csanmt_translation.py
  - test_image_super_resolution.py
  - test_image_super_resolution_pasd.py
  - test_easycv_trainer.py
  - test_segformer.py
  - test_segmentation_pipeline.py
