From 582c3a5415a3413d26aa1452793541563d2a6a60 Mon Sep 17 00:00:00 2001
From: "baiguan.yt"
Date: Wed, 11 Oct 2023 09:59:19 +0800
Subject: [PATCH 01/21] add pasd_v2 with better fidelity
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14202282
---
 modelscope/models/cv/__init__.py              |    8 +-
 .../__init__.py                               |   24 +
 .../controlnet.py                             |  980 +++++
 .../unet_2d_blocks.py                         | 3439 +++++++++++++++++
 .../unet_2d_condition.py                      | 1101 ++++++
 .../image_super_resolution_pasd_pipeline.py   |   38 +-
 .../diffusers_wrapped/pasd_pipeline.py        |  237 +-
 .../test_image_super_resolution_pasd.py       |   12 +-
 8 files changed, 5783 insertions(+), 56 deletions(-)
 create mode 100644 modelscope/models/cv/image_super_resolution_pasd_v2/__init__.py
 create mode 100644 modelscope/models/cv/image_super_resolution_pasd_v2/controlnet.py
 create mode 100644 modelscope/models/cv/image_super_resolution_pasd_v2/unet_2d_blocks.py
 create mode 100644 modelscope/models/cv/image_super_resolution_pasd_v2/unet_2d_condition.py

diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py
index 3fc455c5..5da87a00 100644
--- a/modelscope/models/cv/__init__.py
+++ b/modelscope/models/cv/__init__.py
@@ -15,10 +15,10 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
                image_quality_assessment_man, image_quality_assessment_mos,
                image_reid_person, image_restoration,
                image_semantic_segmentation, image_super_resolution_pasd,
-               image_to_image_generation, image_to_image_translation,
-               language_guided_video_summarization, movie_scene_segmentation,
-               object_detection, panorama_depth_estimation,
-               pedestrian_attribute_recognition,
+               image_super_resolution_pasd_v2, image_to_image_generation,
+               image_to_image_translation, language_guided_video_summarization,
+               movie_scene_segmentation, object_detection,
+               panorama_depth_estimation, pedestrian_attribute_recognition,
                pointcloud_sceneflow_estimation, product_retrieval_embedding,
                referring_video_object_segmentation, robust_image_classification,
                salient_detection,
diff --git a/modelscope/models/cv/image_super_resolution_pasd_v2/__init__.py b/modelscope/models/cv/image_super_resolution_pasd_v2/__init__.py
new file mode 100644
index 00000000..1448c348
--- /dev/null
+++ b/modelscope/models/cv/image_super_resolution_pasd_v2/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .unet_2d_condition import UNet2DConditionModel
+    from .controlnet import ControlNetModel
+
+else:
+    _import_structure = {
+        'unet_2d_condition': ['UNet2DConditionModel'],
+        'controlnet': ['ControlNetModel']
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
diff --git a/modelscope/models/cv/image_super_resolution_pasd_v2/controlnet.py b/modelscope/models/cv/image_super_resolution_pasd_v2/controlnet.py
new file mode 100644
index 00000000..0440c82d
--- /dev/null
+++ b/modelscope/models/cv/image_super_resolution_pasd_v2/controlnet.py
@@ -0,0 +1,980 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention_processor import (AttentionProcessor,
+                                                   AttnProcessor)
+from diffusers.models.embeddings import (TextImageProjection,
+                                          TextImageTimeEmbedding,
+                                          TextTimeEmbedding, TimestepEmbedding,
+                                          Timesteps)
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import BaseOutput, logging
+from torch import nn
+from torch.nn import functional as F
+
+from modelscope.models.cv.super_resolution.rrdbnet_arch import RRDB
+from .unet_2d_blocks import (CrossAttnDownBlock2D, DownBlock2D,
+                             UNetMidBlock2DCrossAttn, get_down_block)
+from .unet_2d_condition import UNet2DConditionModel
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class ControlNetOutput(BaseOutput):
+    """
+    The output of [`ControlNetModel`].
+
+    Args:
+        down_block_res_samples (`tuple[torch.Tensor]`):
+            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
+            used to condition the original UNet's downsampling activations.
+        mid_block_res_sample (`torch.Tensor`):
+            The activation of the middle block (the lowest sample resolution). The tensor should be of shape
+            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+            Output can be used to condition the original UNet's middle block activation.
+        controlnet_cond_mid (`List[torch.Tensor]`, *optional*):
+            Intermediate RGB predictions produced by the conditioning embedding when `return_rgbs` is enabled.
+    """
+
+    down_block_res_samples: Tuple[torch.Tensor]
+    mid_block_res_sample: torch.Tensor
+    controlnet_cond_mid: Optional[List[torch.Tensor]] = None
+
+
+class ControlNetConditioningEmbedding(nn.Module):
+    """
+    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+    model) to encode image-space conditions ... into feature maps ..."
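+
+    This PASD v2 variant of the conditioning encoder additionally applies a small stack of RRDB blocks right after
+    the input convolution when `use_rrdb` is enabled, and, when `return_rgbs` is enabled, also returns intermediate
+    RGB predictions from each downsampling stage alongside the conditioning embedding.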
+ """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + return_rgbs: bool = True, + use_rrdb: bool = True, + ): + super().__init__() + + self.return_rgbs = return_rgbs + self.use_rrdb = use_rrdb + + self.conv_in = nn.Conv2d( + conditioning_channels, + block_out_channels[0], + kernel_size=3, + padding=1) + + if self.use_rrdb: + num_rrdb_block = 2 + layers = ( + RRDB(block_out_channels[0], block_out_channels[0]) + for i in range(num_rrdb_block)) + self.preprocesser = nn.Sequential(*layers) + + self.blocks = nn.ModuleList([]) + self.to_rgbs = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append( + nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append( + nn.Conv2d( + channel_in, + channel_out, + kernel_size=3, + padding=1, + stride=2)) + + if return_rgbs: + self.to_rgbs.append( + nn.Conv2d(channel_out, 3, kernel_size=3, + padding=1)) # channel_in + + self.conv_out = zero_module( + nn.Conv2d( + block_out_channels[-1], + conditioning_embedding_channels, + kernel_size=3, + padding=1)) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + if self.use_rrdb: + embedding = self.preprocesser(embedding) + + out_rgbs = [] + for i, block in enumerate(self.blocks): + embedding = block(embedding) + embedding = F.silu(embedding) + + if i % 2 and self.return_rgbs: # 0 + out_rgbs.append(self.to_rgbs[i // 2](embedding)) + + embedding = self.conv_out(embedding) + + return [embedding, out_rgbs] if self.return_rgbs else embedding + + +class ControlNetModel(ModelMixin, ConfigMixin): + """ + A ControlNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. 
+ encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. 
+ global_pool_conditions (`bool`, defaults to `False`): + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + 'CrossAttnDownBlock2D', + 'CrossAttnDownBlock2D', + 'CrossAttnDownBlock2D', + 'DownBlock2D', + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = 'silu', + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = 'default', + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = 'rgb', + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, + 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads=64, + return_rgbs: bool = True, + use_rrdb: bool = False, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in + # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f'Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: \ + {block_out_channels}. `down_block_types`: {down_block_types}.' + ) + + if not isinstance( + only_cross_attention, + bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f'Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: \ + {only_cross_attention}. `down_block_types`: {down_block_types}.' + ) + + if not isinstance( + num_attention_heads, + int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f'Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: \ + {num_attention_heads}. `down_block_types`: {down_block_types}.' 
+ ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block + ] * len(down_block_types) + + # input + self.return_rgbs = return_rgbs + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, + freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = 'text_proj' + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info( + "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined." + ) + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f'`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.' + ) + + if encoder_hid_dim_type == 'text_proj': + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, + cross_attention_dim) + elif encoder_hid_dim_type == 'text_image_proj': + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the + # currently only use case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, + time_embed_dim) + elif class_embed_type == 'timestep': + self.class_embedding = TimestepEmbedding(timestep_input_dim, + time_embed_dim) + elif class_embed_type == 'identity': + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == 'projection': + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
+ self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == 'text': + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, + time_embed_dim, + num_heads=addition_embed_type_num_heads) + elif addition_embed_type == 'text_image': + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter + # the __init__ too much they are set to `cross_attention_dim` here as this is exactly the + # required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, + image_embed_dim=cross_attention_dim, + time_embed_dim=time_embed_dim) + elif addition_embed_type == 'text_time': + self.add_time_proj = Timesteps(addition_time_embed_dim, + flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError( + f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'." + ) + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + return_rgbs=return_rgbs, + use_rrdb=use_rrdb, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention + ] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim, ) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = ( + num_attention_heads, ) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d( + output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + attention_head_dim=attention_head_dim[i] + if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d( + output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + 
self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = nn.Conv2d( + output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = nn.Conv2d( + mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = 'rgb', + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, + 96, 256), + load_weights_from_unet: bool = True, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. + """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block + if 'transformer_layers_per_block' in unet.config else 1) + encoder_hid_dim = unet.config.encoder_hid_dim if 'encoder_hid_dim' in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if 'encoder_hid_dim_type' in unet.config else None + addition_embed_type = unet.config.addition_embed_type if 'addition_embed_type' in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim + if 'addition_time_embed_dim' in unet.config else None) + + controlnet = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config. 
+ projection_class_embeddings_input_dim, + controlnet_conditioning_channel_order= + controlnet_conditioning_channel_order, + conditioning_embedding_out_channels= + conditioning_embedding_out_channels, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict( + unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict( + unet.class_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict( + unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, + processors: Dict[str, + AttentionProcessor]): + if hasattr(module, 'get_processor'): + processors[f'{name}.processor'] = module.get_processor( + return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f'{name}.{sub_name}', child, + processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, + Dict[str, + AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f'A dict of processors was passed, but the number of processors {len(processor)} does not match the' + f' number of attention layers: {count}. Please make sure to pass {count} processor classes.' + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, + processor): + if hasattr(module, 'set_processor'): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f'{name}.processor')) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f'{name}.{sub_name}', child, + processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. 
+ """ + self.set_attn_processor(AttnProcessor()) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, 'set_attention_slice'): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == 'auto': + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == 'max': + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance( + slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f'You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different' + f' attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.' + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError( + f'size {size} has to be smaller or equal to {dim}.') + + # Recursively walk through all the children. 
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, + slice_size: List[int]): + if hasattr(module, 'set_attention_slice'): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.FloatTensor, + fg_mask: Optional[torch.FloatTensor] = None, + conditioning_scale_fg: float = 1.0, + conditioning_scale_bg: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + """ + The [`ControlNetModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == 'rgb': + # in rgb order by default + ... + elif channel_order == 'bgr': + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError( + f'unknown `controlnet_conditioning_channel_order`: {channel_order}' + ) + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. 
time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == 'mps' + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], + dtype=dtype, + device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + 'class_labels should be provided when num_class_embeds > 0' + ) + + if self.config.class_embed_type == 'timestep': + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if self.config.addition_embed_type is not None: + if self.config.addition_embed_type == 'text': + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == 'text_time': + if 'text_embeds' not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' \ + which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get('text_embeds') + if 'time_ids' not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' \ + which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get('time_ids') + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + # 2. pre-process + sample = self.conv_in(sample) + + controlnet_cond_mid = None + if self.return_rgbs: + controlnet_cond, controlnet_cond_mid = self.controlnet_cond_embedding( + controlnet_cond) + else: + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + + sample = sample + controlnet_cond + + # 3. down + down_block_res_samples = (sample, ) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, 'has_cross_attention' + ) and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. 
mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip( + down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + ( + down_block_res_sample, ) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace( + -1, 0, len(down_block_res_samples) + 1, + device=sample.device) # 0.1 to 1.0 + + scales = scales * conditioning_scale_fg + down_block_res_samples = [ + sample * scale + for sample, scale in zip(down_block_res_samples, scales) + ] + mid_block_res_sample = mid_block_res_sample * scales[ + -1] # last one + else: + if fg_mask is None: + down_block_res_samples = [ + sample * conditioning_scale_fg + for sample in down_block_res_samples + ] + mid_block_res_sample = mid_block_res_sample * conditioning_scale_fg + else: + down_block_masks = [ + torch.zeros_like(sample) + conditioning_scale_bg + for i, sample in enumerate(down_block_res_samples) + ] + mid_block_mask = torch.zeros_like( + mid_block_res_sample) + conditioning_scale_bg + + for i, sample in enumerate(down_block_masks): + tmp_mask = F.interpolate( + fg_mask, + size=sample.shape[-2:]).repeat(sample.shape[0], + sample.shape[1], 1, + 1).bool() + down_block_masks[i] = sample.masked_fill( + tmp_mask, conditioning_scale_fg) + + tmp_mask = F.interpolate( + fg_mask, size=mid_block_mask.shape[-2:]).repeat( + mid_block_mask.shape[0], mid_block_mask.shape[1], 1, + 1).bool() + mid_block_mask = mid_block_mask.masked_fill( + tmp_mask, conditioning_scale_fg) + + down_block_res_samples = [ + sample * down_block_mask for sample, down_block_mask in + zip(down_block_res_samples, down_block_masks) + ] + mid_block_res_sample = mid_block_res_sample * mid_block_mask + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) + for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean( + mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (controlnet_cond_mid, down_block_res_samples, + mid_block_res_sample) + + return ControlNetOutput( + controlnet_cond_mid=controlnet_cond_mid, + down_block_res_samples=down_block_res_samples, + mid_block_res_sample=mid_block_res_sample) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/modelscope/models/cv/image_super_resolution_pasd_v2/unet_2d_blocks.py b/modelscope/models/cv/image_super_resolution_pasd_v2/unet_2d_blocks.py new file mode 100644 index 00000000..33de31e6 --- /dev/null +++ b/modelscope/models/cv/image_super_resolution_pasd_v2/unet_2d_blocks.py @@ -0,0 +1,3439 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.models.activations import get_activation +from diffusers.models.attention import AdaGroupNorm +from diffusers.models.attention_processor import (Attention, + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0) +from diffusers.models.dual_transformer_2d import DualTransformer2DModel +from diffusers.models.resnet import (Downsample2D, FirDownsample2D, + FirUpsample2D, KDownsample2D, KUpsample2D, + ResnetBlock2D, Upsample2D) +from diffusers.models.transformer_2d import Transformer2DModel +from diffusers.utils import is_torch_version, logging +from einops import rearrange +from torch import nn + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift='default', + attention_type='default', + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + downsample_type=None, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f'It is recommended to provide `attention_head_dim` when calling `get_down_block`. 
\ + Defaulting `attention_head_dim` to {num_attention_heads}.') + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith( + 'UNetRes') else down_block_type + if down_block_type == 'DownBlock2D': + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == 'ResnetDownsampleBlock2D': + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == 'AttnDownBlock2D': + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or 'conv' # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == 'CrossAttnDownBlock2D': + if cross_attention_dim is None: + raise ValueError( + 'cross_attention_dim must be specified for CrossAttnDownBlock2D' + ) + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif down_block_type == 'SimpleCrossAttnDownBlock2D': + if cross_attention_dim is None: + raise ValueError( + 'cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D' + ) + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == 'SkipDownBlock2D': + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + 
resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == 'AttnSkipDownBlock2D': + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == 'DownEncoderBlock2D': + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == 'AttnDownEncoderBlock2D': + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == 'KDownBlock2D': + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == 'KCrossAttnDownBlock2D': + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f'{down_block_type} does not exist.') + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift='default', + attention_type='default', + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + upsample_type=None, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f'It is recommended to provide `attention_head_dim` when calling `get_up_block`. 
\ + Defaulting `attention_head_dim` to {num_attention_heads}.') + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith( + 'UNetRes') else up_block_type + if up_block_type == 'UpBlock2D': + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + transformer_layers_per_block=transformer_layers_per_block, + num_attention_heads=num_attention_heads, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + elif up_block_type == 'ResnetUpsampleBlock2D': + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == 'CrossAttnUpBlock2D': + if cross_attention_dim is None: + raise ValueError( + 'cross_attention_dim must be specified for CrossAttnUpBlock2D') + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif up_block_type == 'SimpleCrossAttnUpBlock2D': + if cross_attention_dim is None: + raise ValueError( + 'cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D' + ) + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == 'AttnUpBlock2D': + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or 'conv' # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif 
up_block_type == 'SkipUpBlock2D': + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == 'AttnSkipUpBlock2D': + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == 'UpDecoderBlock2D': + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == 'AttnUpDecoderBlock2D': + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == 'KUpBlock2D': + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == 'KCrossAttnUpBlock2D': + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) + + raise ValueError(f'{up_block_type} does not exist.') + + +class AutoencoderTinyBlock(nn.Module): + + def __init__(self, in_channels: int, out_channels: int, act_fn: str): + super().__init__() + act_fn = get_activation(act_fn) + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + ) + self.skip = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + if in_channels != out_channels else nn.Identity()) + self.fuse = nn.ReLU() + + def forward(self, x): + return self.fuse(self.conv(x) + self.skip(x)) + + +class UNetMidBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', # default, spatial + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min( + in_channels // 4, 32) + self.add_attention = add_attention + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + 
out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f'It is not recommend to pass `attention_head_dim=None`. \ + Defaulting `attention_head_dim` to `in_channels`: {in_channels}.' + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups + if resnet_time_scale_shift == 'default' else None, + spatial_norm_dim=temb_channels + if resnet_time_scale_shift == 'spatial' else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + )) + else: + attentions.append(None) + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + attention_type='default', + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min( + in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + )) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + 
cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + )) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = { + 'use_reentrant': False + } if is_torch_version('>=', '1.11.0') else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DSimpleCrossAttn(nn.Module): + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min( + in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr( + F, 'scaled_dot_product_attention') else + AttnAddedKVProcessor()) + + attentions.append( + Attention( + query_dim=in_channels, + 
cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + )) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + )) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, + # so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param + # for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of + # via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class AttnDownBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + downsample_padding=1, + downsample_type='conv', + ): + super().__init__() + resnets = [] + attentions = [] + self.downsample_type = downsample_type + + if attention_head_dim is None: + logger.warn( + f'It is not recommend to pass `attention_head_dim=None`. \ + Defaulting `attention_head_dim` to `in_channels`: {out_channels}.' 
+ ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + )) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if downsample_type == 'conv': + self.downsamplers = nn.ModuleList([ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name='op') + ]) + elif downsample_type == 'resnet': + self.downsamplers = nn.ModuleList([ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, upsample_size=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states = output_states + (hidden_states, ) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + if self.downsample_type == 'resnet': + hidden_states = downsampler(hidden_states, temb=temb) + else: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states, ) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + attention_type='default', + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + 
cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + )) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + )) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name='op') + ]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals=None, + ): + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = { + 'use_reentrant': False + } if is_torch_version('>=', '1.11.0') else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states, ) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states, ) + + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + 
dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name='op') + ]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version('>=', '1.11.0'): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states, ) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states, ) + + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name='op') + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f'It is not recommend to pass `attention_head_dim=None`. \ + Defaulting `attention_head_dim` to `in_channels`: {out_channels}.' 
+ ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + )) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name='op') + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnSkipDownBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=np.sqrt(2.0), + add_downsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + if attention_head_dim is None: + logger.warn( + f'It is not recommend to pass `attention_head_dim=None`. \ + Defaulting `attention_head_dim` to `in_channels`: {out_channels}.' 
+ ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + )) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel='fir', + ) + self.downsamplers = nn.ModuleList( + [FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d( + 3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states, ) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states, ) + + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel='fir', + ) + self.downsamplers = nn.ModuleList( + 
[FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d( + 3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward(self, hidden_states, temb=None, skip_sample=None): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states, ) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states, ) + + return hidden_states, output_states, skip_sample + + +class ResnetDownsampleBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + skip_time_act=False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + )) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version('>=', '1.11.0'): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states, ) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states, ) + + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + 
add_downsample=True, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, + ): + super().__init__() + + self.has_cross_attention = True + + resnets = [] + attentions = [] + + self.attention_head_dim = attention_head_dim + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + )) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr( + F, 'scaled_dot_product_attention') else + AttnAddedKVProcessor()) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + )) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + output_states = () + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, + # so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param + # for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param + # instead of via attention_mask. 
+ # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + output_states = output_states + (hidden_states, ) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states, ) + + return hidden_states, output_states + + +class KDownBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = 'gelu', + resnet_group_size: int = 32, + add_downsample=False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm='ada_group', + conv_shortcut_bias=False, + )) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + # YiYi's comments- might be able to use FirDownsample2D, look into details later + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version('>=', '1.11.0'): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states, ) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class KCrossAttnDownBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + cross_attention_dim: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_group_size: int = 32, + add_downsample=True, + attention_head_dim: int = 64, + add_self_attention: bool = False, + resnet_eps: float = 1e-5, + resnet_act_fn: str = 'gelu', + ): + 
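+ # Each of the num_layers blocks built below pairs a ResnetBlock2D (using
+ # 'ada_group' time-embedding normalization) with a KAttentionBlock that
+ # cross-attends to the encoder hidden states; a single shared KDownsample2D
+ # is appended when add_downsample is True.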
super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm='ada_group', + conv_shortcut_bias=False, + )) + attentions.append( + KAttentionBlock( + out_channels, + out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm='layer_norm', + group_size=resnet_group_size, + )) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = { + 'use_reentrant': False + } if is_torch_version('>=', '1.11.0') else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.downsamplers is None: + output_states += (None, ) + else: + output_states += (hidden_states, ) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class AttnUpBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + upsample_type='conv', + ): + super().__init__() + resnets = [] + attentions = [] + + self.upsample_type = upsample_type + + if attention_head_dim is None: + logger.warn( + f'It is not recommend to pass `attention_head_dim=None`. 
\ + Defaulting `attention_head_dim` to `in_channels`: {out_channels}.' + ) + attention_head_dim = out_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers + - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + )) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if upsample_type == 'conv': + self.upsamplers = nn.ModuleList([ + Upsample2D( + out_channels, use_conv=True, out_channels=out_channels) + ]) + elif upsample_type == 'resnet': + self.upsamplers = nn.ModuleList([ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ]) + else: + self.upsamplers = None + + def forward(self, + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], + dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + if self.upsample_type == 'resnet': + hidden_states = upsampler(hidden_states, temb=temb) + else: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + attention_type='default', + use_pixelwise_attention=True, + ): + super().__init__() + resnets = [] + attentions = [] + pixel_attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers + - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, 
+ groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + )) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + )) + pixel_attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=res_skip_channels, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) if use_pixelwise_attention else None) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.pixel_attentions = nn.ModuleList(pixel_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([ + Upsample2D( + out_channels, use_conv=True, out_channels=out_channels) + ]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + pixelwise_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn, pix_attn in zip(self.resnets, self.attentions, + self.pixel_attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], + dim=1) + + if pixelwise_hidden_states is not None: + pixelwise_hidden_state = pixelwise_hidden_states[-1] + pixelwise_hidden_states = pixelwise_hidden_states[:-1] + pixelwise_hidden_state = rearrange(pixelwise_hidden_state, + 'b c h w -> b (h w) c') + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = { + 'use_reentrant': False + } if is_torch_version('>=', '1.11.0') else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if pixelwise_hidden_states is not None: + hidden_states = pix_attn( + hidden_states, + 
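+ # cross-attend to the pixel-wise features flattened above via rearrange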
encoder_hidden_states=pixelwise_hidden_state, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=None, + encoder_attention_mask=None, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if pixelwise_hidden_states is not None: + hidden_states = pix_attn( + hidden_states, + encoder_hidden_states=pixelwise_hidden_state, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=None, + encoder_attention_mask=None, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + transformer_layers_per_block: int = 1, + num_attention_heads=1, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_pixelwise_attention=True, + ): + super().__init__() + resnets = [] + pixel_attentions = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers + - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + pixel_attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=out_channels, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) if use_pixelwise_attention else None) + + self.resnets = nn.ModuleList(resnets) + self.pixel_attentions = nn.ModuleList(pixel_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([ + Upsample2D( + out_channels, use_conv=True, out_channels=out_channels) + ]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None, + pixelwise_hidden_states=None, + cross_attention_kwargs=None): + for resnet, pix_attn in zip(self.resnets, self.pixel_attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], + dim=1) + + if pixelwise_hidden_states is not None: + pixelwise_hidden_state = pixelwise_hidden_states[-1] + pixelwise_hidden_states = pixelwise_hidden_states[:-1] + pixelwise_hidden_state = rearrange(pixelwise_hidden_state, + 'b c h w -> b (h w) c') + + if self.training and 
self.gradient_checkpointing: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version('>=', '1.11.0'): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if pixelwise_hidden_states is not None: + hidden_states = pix_attn( + hidden_states, + encoder_hidden_states=pixelwise_hidden_state, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=None, + encoder_attention_mask=None, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', # default, spatial + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + temb_channels=None, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([ + Upsample2D( + out_channels, use_conv=True, out_channels=out_channels) + ]) + else: + self.upsamplers = None + + def forward(self, hidden_states, temb=None): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + add_upsample=True, + temb_channels=None, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f'It is not recommend to pass `attention_head_dim=None`. \ + Defaulting `attention_head_dim` to `out_channels`: {out_channels}.' 
+ ) + attention_head_dim = out_channels + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups + if resnet_time_scale_shift != 'spatial' else None, + spatial_norm_dim=temb_channels + if resnet_time_scale_shift == 'spatial' else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + )) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([ + Upsample2D( + out_channels, use_conv=True, out_channels=out_channels) + ]) + else: + self.upsamplers = None + + def forward(self, hidden_states, temb=None): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=temb) + hidden_states = attn(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=np.sqrt(2.0), + add_upsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers + - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, + 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + if attention_head_dim is None: + logger.warn( + f'It is not recommend to pass `attention_head_dim=None`. \ + Defaulting `attention_head_dim` to `out_channels`: {out_channels}.' 
+ ) + attention_head_dim = out_channels + + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + )) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel='fir', + ) + self.skip_conv = nn.Conv2d( + out_channels, + 3, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), + num_channels=out_channels, + eps=resnet_eps, + affine=True) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, + hidden_states, + res_hidden_states_tuple, + temb=None, + skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], + dim=1) + + hidden_states = resnet(hidden_states, temb) + + hidden_states = self.attentions[0](hidden_states) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_upsample=True, + upsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers + - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, + 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + )) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + 
groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel='fir', + ) + self.skip_conv = nn.Conv2d( + out_channels, + 3, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), + num_channels=out_channels, + eps=resnet_eps, + affine=True) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, + hidden_states, + res_hidden_states_tuple, + temb=None, + skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], + dim=1) + + hidden_states = resnet(hidden_states, temb) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class ResnetUpsampleBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + skip_time_act=False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers + - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + )) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], + dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs) + 
+ return custom_forward + + if is_torch_version('>=', '1.11.0'): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class SimpleCrossAttnUpBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = 'default', + resnet_act_fn: str = 'swish', + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers + - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + )) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr( + F, 'scaled_dot_product_attention') else + AttnAddedKVProcessor()) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + )) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is 
defined: we are doing cross-attn, + # so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param + # for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param + # instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + # resnet + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], + dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class KUpBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 5, + resnet_eps: float = 1e-5, + resnet_act_fn: str = 'gelu', + resnet_group_size: Optional[int] = 32, + add_upsample=True, + ): + super().__init__() + resnets = [] + k_in_channels = 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=k_out_channels if + (i == num_layers - 1) else out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm='ada_group', + conv_shortcut_bias=False, + )) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None): + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], + dim=1) + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version('>=', '1.11.0'): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + 
hidden_states, + temb, + use_reentrant=False) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class KCrossAttnUpBlock2D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = 'gelu', + resnet_group_size: int = 32, + attention_head_dim=1, # attention dim_head + cross_attention_dim: int = 768, + add_upsample: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + + is_first_block = in_channels == out_channels == temb_channels + is_middle_block = in_channels != out_channels + add_self_attention = True if is_first_block else False + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + # in_channels, and out_channels for the block (k-unet) + k_in_channels = out_channels if is_first_block else 2 * out_channels + k_out_channels = in_channels + + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + if is_middle_block and (i == num_layers - 1): + conv_2d_out_channels = k_out_channels + else: + conv_2d_out_channels = None + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + conv_2d_out_channels=conv_2d_out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm='ada_group', + conv_shortcut_bias=False, + )) + attentions.append( + KAttentionBlock( + k_out_channels if (i == num_layers - 1) else out_channels, + k_out_channels // attention_head_dim if + (i == num_layers + - 1) else out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm='layer_norm', + upcast_attention=upcast_attention, + )) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], + dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + 
return custom_forward + + ckpt_kwargs: Dict[str, Any] = { + 'use_reentrant': False + } if is_torch_version('>=', '1.11.0') else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +# can potentially later be renamed to `No-feed-forward` attention +class KAttentionBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + upcast_attention: bool = False, + temb_channels: int = 768, # for ada_group_norm + add_self_attention: bool = False, + cross_attention_norm: Optional[str] = None, + group_size: int = 32, + ): + super().__init__() + self.add_self_attention = add_self_attention + + # 1. Self-Attn + if add_self_attention: + self.norm1 = AdaGroupNorm(temb_channels, dim, + max(1, dim // group_size)) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + cross_attention_norm=None, + ) + + # 2. Cross-Attn + self.norm2 = AdaGroupNorm(temb_channels, dim, + max(1, dim // group_size)) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_attention_norm=cross_attention_norm, + ) + + def _to_3d(self, hidden_states, height, weight): + return hidden_states.permute(0, 2, 3, + 1).reshape(hidden_states.shape[0], + height * weight, -1) + + def _to_4d(self, hidden_states, height, weight): + return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], + -1, height, weight) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + # TODO: mark emb as non-optional (self.norm2 requires it). 
+ # requires assessing impact of change to positional param interface. + emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + # 1. Self-Attention + if self.add_self_attention: + norm_hidden_states = self.norm1(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, + weight) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention/None + norm_hidden_states = self.norm2(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask + if encoder_hidden_states is None else encoder_attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + return hidden_states diff --git a/modelscope/models/cv/image_super_resolution_pasd_v2/unet_2d_condition.py b/modelscope/models/cv/image_super_resolution_pasd_v2/unet_2d_condition.py new file mode 100644 index 00000000..0f68b91f --- /dev/null +++ b/modelscope/models/cv/image_super_resolution_pasd_v2/unet_2d_condition.py @@ -0,0 +1,1101 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import json +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.models import ModelMixin +from diffusers.models.attention_processor import (AttentionProcessor, + AttnProcessor) +from diffusers.models.embeddings import (GaussianFourierProjection, + TextTimeEmbedding, TimestepEmbedding, + Timesteps) +from diffusers.utils import BaseOutput, logging + +from .unet_2d_blocks import (UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, get_down_block, + get_up_block) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 
+ """ + + sample: torch.FloatTensor = None + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, + UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. 
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. 
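+    Example:
+        An illustrative sketch rather than a guaranteed recipe; the checkpoint directory below is a
+        placeholder for a local diffusers-style UNet folder containing `config.json` and the weight
+        file expected by `from_pretrained_orig`. Note that this PASD variant feeds the paired
+        ControlNet features into its up blocks, so `forward` additionally expects
+        `down_block_additional_residuals` and `mid_block_additional_residual` at inference time.
+
+        >>> from modelscope.models.cv.image_super_resolution_pasd_v2.unet_2d_condition import UNet2DConditionModel
+        >>> unet = UNet2DConditionModel.from_pretrained_orig(
+        >>>     '/path/to/stable-diffusion-v1-5', subfolder='unet')
+        >>> unet = unet.eval()
+        >>> print(sum(p.numel() for p in unet.parameters()))  # rough sanity check of model size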
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + 'CrossAttnDownBlock2D', + 'CrossAttnDownBlock2D', + 'CrossAttnDownBlock2D', + 'DownBlock2D', + ), + mid_block_type: Optional[str] = 'UNetMidBlock2DCrossAttn', + up_block_types: Tuple[str] = ('UpBlock2D', 'CrossAttnUpBlock2D', + 'CrossAttnUpBlock2D', + 'CrossAttnUpBlock2D'), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = 'silu', + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = 'default', + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = 'positional', + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = 'default', + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + 'At the moment it is not possible to define the number of attention heads via \ + `num_attention_heads` because of a naming issue as described in\ + https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. \ + Passing `num_attention_heads` will only be supported in diffusers v0.19.' + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in + # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f'Must provide the same number of `down_block_types` as `up_block_types`. \ + `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.' 
+ ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f'Must provide the same number of `block_out_channels` as `down_block_types`. \ + `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.' + ) + + if not isinstance( + only_cross_attention, + bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f'Must provide the same number of `only_cross_attention` as `down_block_types`. \ + `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.' + ) + + if not isinstance( + num_attention_heads, + int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f'Must provide the same number of `num_attention_heads` as `down_block_types`. \ + `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}.' + ) + + if not isinstance( + attention_head_dim, + int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f'Must provide the same number of `attention_head_dim` as `down_block_types`. \ + `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.' + ) + + if isinstance( + cross_attention_dim, + list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f'Must provide the same number of `cross_attention_dim` as `down_block_types`. \ + `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.' + ) + + if not isinstance( + layers_per_block, + int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f'Must provide the same number of `layers_per_block` as `down_block_types`. \ + `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}.' + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding) + + # time + if time_embedding_type == 'fourier': + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError( + f'`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.' + ) + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos) + timestep_input_dim = time_embed_dim + elif time_embedding_type == 'positional': + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, + freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f'{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`.' + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = 'text_proj' + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info( + "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined." + ) + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f'`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.' 
+ ) + + if encoder_hid_dim_type == 'text_proj': + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, + cross_attention_dim) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, + time_embed_dim) + elif class_embed_type == 'timestep': + self.class_embedding = TimestepEmbedding( + timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == 'identity': + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == 'projection': + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == 'simple_projection': + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear( + projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == 'text': + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, + time_embed_dim, + num_heads=addition_embed_type_num_heads) + elif addition_embed_type is not None: + raise ValueError( + f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'." + ) + + if time_embedding_act_fn is None: + self.time_embed_act = None + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention + ] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = ( + num_attention_heads, ) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim, ) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = ( + cross_attention_dim, ) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block + ] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. 
The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] + if attention_head_dim[i] is not None else output_channel, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == 'UNetMidBlock2DCrossAttn': + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == 'UNetMidBlock2DSimpleCrossAttn': + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f'unknown mid_block_type : {mid_block_type}') + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list( + reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == 
len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min( + i + 1, + len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block= + reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] + if attention_head_dim[i] is not None else output_channel, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps) + + if act_fn == 'swish': + self.conv_act = lambda x: F.silu(x) + elif act_fn == 'mish': + self.conv_act = nn.Mish() + elif act_fn == 'silu': + self.conv_act = nn.SiLU() + elif act_fn == 'gelu': + self.conv_act = nn.GELU() + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=conv_out_kernel, + padding=conv_out_padding) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, + processors: Dict[str, + AttentionProcessor]): + if hasattr(module, 'get_processor'): + processors[f'{name}.processor'] = module.get_processor( + return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f'{name}.{sub_name}', child, + processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, + Dict[str, + AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. 
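+            Example (an illustrative sketch; `unet` is assumed to be an already constructed
+            `UNet2DConditionModel` from this module):
+
+            >>> from diffusers.models.attention_processor import AttnProcessor
+            >>> # a single processor shared by every attention layer
+            >>> unet.set_attn_processor(AttnProcessor())
+            >>> # or an explicit per-layer mapping, keyed the same way as `attn_processors`
+            >>> unet.set_attn_processor(
+            >>>     {name: AttnProcessor() for name in unet.attn_processors.keys()})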
+ + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f'A dict of processors was passed, but the number of processors {len(processor)} does not match the' + f' number of attention layers: {count}. Please make sure to pass {count} processor classes.' + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, + processor): + if hasattr(module, 'set_processor'): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f'{name}.processor')) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f'{name}.{sub_name}', child, + processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, 'set_attention_slice'): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == 'auto': + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == 'max': + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance( + slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f'You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different' + f' attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.' + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError( + f'size {size} has to be smaller or equal to {dim}.') + + # Recursively walk through all the children. 
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, + slice_size: List[int]): + if hasattr(module, 'set_attention_slice'): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, 'gradient_checkpointing'): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
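+        # For example, with the default four down/up block configuration, three of the up blocks
+        # upsample, so the overall factor below is 2**3 = 8: a 96x96 latent needs no special
+        # handling, while a 100x100 latent sets `forward_upsample_size` so that each up block
+        # receives an explicit `upsample_size` taken from the spatial size of its skip connection.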
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == 'mps' + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], + dtype=dtype, + device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + 'class_labels should be provided when num_class_embeds > 0' + ) + + if self.config.class_embed_type == 'timestep': + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. 
+ class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to( + dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == 'text': + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == 'text_image': + # Kandinsky 2.1 - style + if 'image_embeds' not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' \ + which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get('image_embeds') + text_embs = added_cond_kwargs.get('text_embeds', + encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == 'text_time': + # SDXL - style + if 'text_embeds' not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' \ + which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get('text_embeds') + if 'time_ids' not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' \ + which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get('time_ids') + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == 'image': + # Kandinsky 2.2 - style + if 'image_embeds' not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' \ + which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get('image_embeds') + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == 'image_hint': + # Kandinsky 2.2 - style + if 'image_embeds' not in added_cond_kwargs or 'hint' not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which \ + requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get('image_embeds') + hint = added_cond_kwargs.get('hint') + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == 'text_proj': + encoder_hidden_states = self.encoder_hid_proj( + encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == 'text_image_proj': + # Kadinsky 2.1 - style + if 'image_embeds' not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' \ + which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get('image_embeds') + encoder_hidden_states = self.encoder_hid_proj( + 
encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == 'image_proj': + # Kandinsky 2.2 - style + if 'image_embeds' not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' \ + which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get('image_embeds') + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get( + 'gligen', None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop('gligen') + cross_attention_kwargs['gligen'] = { + 'objs': self.position_net(**gligen_args) + } + + # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample, ) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, 'has_cross_attention' + ) and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals[ + 'additional_residuals'] = down_block_additional_residuals.pop( + 0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + ( + down_block_res_sample, ) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + # To support T2I-Adapter-XL + if all([ + is_adapter, + len(down_block_additional_residuals) > 0, + sample.shape == down_block_additional_residuals[0].shape + ]): + sample += down_block_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[:-len( + upsample_block.resnets)] + + down_block_additional_residual = down_block_additional_residuals[ + -len(upsample_block.resnets):] + down_block_additional_residuals = down_block_additional_residuals[:-len( + upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, 'has_cross_attention' + ) and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + pixelwise_hidden_states=down_block_additional_residual, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + pixelwise_hidden_states=down_block_additional_residual, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample, ) + + return UNet2DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_orig(cls, + pretrained_model_path, + subfolder=None, + **kwargs): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, + subfolder) + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f'{config_file} does not exist') + with open(config_file, 'r') as f: + config = json.load(f) + + from diffusers.utils import WEIGHTS_NAME + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f'{model_file} does not exist') + state_dict = torch.load(model_file, map_location='cpu') + for k, v in model.state_dict().items(): + if 'attn2_plus' in k: + state_dict.update({k: v}) + model.load_state_dict(state_dict, strict=False) + + return model diff --git a/modelscope/pipelines/cv/image_super_resolution_pasd_pipeline.py b/modelscope/pipelines/cv/image_super_resolution_pasd_pipeline.py index c4451483..93fd5fd5 100644 --- a/modelscope/pipelines/cv/image_super_resolution_pasd_pipeline.py +++ b/modelscope/pipelines/cv/image_super_resolution_pasd_pipeline.py @@ -8,14 +8,13 @@ import numpy as np import PIL import torch from diffusers import AutoencoderKL, UniPCMultistepScheduler +from torchvision import transforms from torchvision.models import ResNet50_Weights, resnet50 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from modelscope.metainfo import Pipelines from modelscope.models.cv.image_portrait_enhancement.retinaface import \ detection -from modelscope.models.cv.image_super_resolution_pasd import ( - ControlNetModel, UNet2DConditionModel) from modelscope.models.cv.image_super_resolution_pasd.misc import ( load_dreambooth_lora, wavelet_color_fix) from modelscope.outputs import OutputKeys @@ -48,8 +47,8 @@ class ImageSuperResolutionPASDPipeline(Pipeline): >>> 'image': input_location, 
>>> 'upscale': 2, >>> 'prompt': prompt, - >>> 'fidelity_scale_fg': 1.5, - >>> 'fidelity_scale_bg': 0.7 + >>> 'fidelity_scale_fg': 1.0, + >>> 'fidelity_scale_bg': 1.0 >>> } >>> pasd = pipeline(Tasks.image_super_resolution_pasd, model='damo/PASD_image_super_resolutions') >>> output = pasd(input)[OutputKeys.OUTPUT_IMG] @@ -69,6 +68,13 @@ class ImageSuperResolutionPASDPipeline(Pipeline): self.device = create_device(device_name) self.config = Config.from_file( os.path.join(model, ModelFile.CONFIGURATION)) + version = self.config.pipeline.get('version', 'pasd_v2') + if version == 'pasd': + from modelscope.models.cv.image_super_resolution_pasd import ( + ControlNetModel, UNet2DConditionModel) + else: + from modelscope.models.cv.image_super_resolution_pasd_v2 import ( + ControlNetModel, UNet2DConditionModel) cfg = self.config.model_cfg dreambooth_lora_ckpt = cfg['dreambooth_lora_ckpt'] tiled_size = cfg['tiled_size'] @@ -123,6 +129,12 @@ class ImageSuperResolutionPASDPipeline(Pipeline): self.face_detector = detection.RetinaFaceDetection( detector_model_path, self.device) + self.resize_preproc = transforms.Compose([ + transforms.Resize( + self.process_size, + interpolation=transforms.InterpolationMode.BILINEAR), + ]) + def preprocess(self, input: Input): return input @@ -145,8 +157,8 @@ class ImageSuperResolutionPASDPipeline(Pipeline): eta = inputs.get('eta', 0.0) prompt = inputs.get('prompt', '') upscale = inputs.get('upscale', 2) - fidelity_scale_fg = inputs.get('fidelity_scale_fg', 1.5) - fidelity_scale_bg = inputs.get('fidelity_scale_bg', 0.7) + fidelity_scale_fg = inputs.get('fidelity_scale_fg', 1.0) + fidelity_scale_bg = inputs.get('fidelity_scale_bg', 1.0) input_image = load_image(inputs['image']).convert('RGB') @@ -164,19 +176,15 @@ class ImageSuperResolutionPASDPipeline(Pipeline): prompt = added_prompt if prompt == '' else f'{prompt}, {added_prompt}' ori_width, ori_height = input_image.size - resize_flag = False + resize_flag = True rscale = upscale - if ori_width < self.process_size // rscale or ori_height < self.process_size // rscale: - scale = (self.process_size // rscale) / min( - ori_width, ori_height) - tmp_image = input_image.resize( - (int(scale * ori_width), int(scale * ori_height))) - - input_image = tmp_image - resize_flag = True input_image = input_image.resize( (input_image.size[0] * rscale, input_image.size[1] * rscale)) + + if min(input_image.size) < self.process_size: + input_image = self.resize_preproc(input_image) + input_image = input_image.resize( (input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8)) width, height = input_image.size diff --git a/modelscope/pipelines/multi_modal/diffusers_wrapped/pasd_pipeline.py b/modelscope/pipelines/multi_modal/diffusers_wrapped/pasd_pipeline.py index fee262b5..eb48e4ed 100644 --- a/modelscope/pipelines/multi_modal/diffusers_wrapped/pasd_pipeline.py +++ b/modelscope/pipelines/multi_modal/diffusers_wrapped/pasd_pipeline.py @@ -1031,46 +1031,213 @@ class PixelAwareStableDiffusionPipeline(DiffusionPipeline, controlnet_latent_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds - if image is not None: - rgbs, down_block_res_samples, mid_block_res_sample = self.controlnet( - controlnet_latent_model_input, + _, _, h, w = latent_model_input.size() + tile_size, tile_overlap = 120, 32 + if h < tile_size and w < tile_size: + if image is not None: + rgbs, down_block_res_samples, mid_block_res_sample = self.controlnet( + controlnet_latent_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + 
controlnet_cond=image, + fg_mask=fg_mask, + conditioning_scale_fg=conditioning_scale_fg, + conditioning_scale_bg=conditioning_scale_bg, + guess_mode=guess_mode, + return_dict=False, + ) + else: + down_block_res_samples, mid_block_res_sample = [ + None + ] * 10, [None] * 10 + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [ + torch.cat([torch.zeros_like(d), d]) + for d in down_block_res_samples + ] + mid_block_res_sample = torch.cat([ + torch.zeros_like(mid_block_res_sample), + mid_block_res_sample + ]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=image, - fg_mask=fg_mask, - conditioning_scale_fg=conditioning_scale_fg, - conditioning_scale_bg=conditioning_scale_bg, - guess_mode=guess_mode, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, return_dict=False, - ) + )[0] else: - down_block_res_samples, mid_block_res_sample = [ - None - ] * 10, [None] * 10 + tile_size = min(tile_size, min(h, w)) + tile_weights = self._gaussian_weights( + tile_size, tile_size, 1).to(latent_model_input.device) - if guess_mode and do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. - # To apply the output of ControlNet to both the unconditional and conditional batches, - # add 0 to the unconditional batch to keep it unchanged. - down_block_res_samples = [ - torch.cat([torch.zeros_like(d), d]) - for d in down_block_res_samples - ] - mid_block_res_sample = torch.cat([ - torch.zeros_like(mid_block_res_sample), - mid_block_res_sample - ]) + grid_rows = 0 + cur_x = 0 + while cur_x < latent_model_input.size(-1): + cur_x = max( + grid_rows * tile_size - tile_overlap * grid_rows, + 0) + tile_size + grid_rows += 1 - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - return_dict=False, - )[0] + grid_cols = 0 + cur_y = 0 + while cur_y < latent_model_input.size(-2): + cur_y = max( + grid_cols * tile_size - tile_overlap * grid_cols, + 0) + tile_size + grid_cols += 1 + + input_list = [] + cond_list = [] + img_list = [] + fg_mask_list = [] + noise_preds = [] + for row in range(grid_rows): + for col in range(grid_cols): + if col < grid_cols - 1 or row < grid_rows - 1: + # extract tile from input image + ofs_x = max( + row * tile_size - tile_overlap * row, 0) + ofs_y = max( + col * tile_size - tile_overlap * col, 0) + # input tile area on total image + if row == grid_rows - 1: + ofs_x = w - tile_size + if col == grid_cols - 1: + ofs_y = h - tile_size + + input_start_x = ofs_x + input_end_x = ofs_x + tile_size + input_start_y = ofs_y + input_end_y = ofs_y + tile_size + + # input tile dimensions + input_tile = latent_model_input[:, :, + input_start_y: + input_end_y, + input_start_x: + input_end_x] + input_list.append(input_tile) + cond_tile = controlnet_latent_model_input[:, :, + input_start_y: + input_end_y, + input_start_x: + input_end_x] + cond_list.append(cond_tile) + img_tile = image[:, 
:, + input_start_y * 8:input_end_y * 8, + input_start_x * 8:input_end_x * 8] + img_list.append(img_tile) + if fg_mask is not None: + fg_mask_tile = fg_mask[:, :, input_start_y + * 8:input_end_y * 8, + input_start_x + * 8:input_end_x * 8] + fg_mask_list.append(fg_mask_tile) + + if len(input_list + ) == batch_size or col == grid_cols - 1: + input_list_t = torch.cat(input_list, dim=0) + cond_list_t = torch.cat(cond_list, dim=0) + img_list_t = torch.cat(img_list, dim=0) + if fg_mask is not None: + fg_mask_list_t = torch.cat( + fg_mask_list, dim=0) + else: + fg_mask_list_t = None + + _, down_block_res_samples, mid_block_res_sample = self.controlnet( + cond_list_t, + t, + encoder_hidden_states= + controlnet_prompt_embeds, + controlnet_cond=img_list_t, + fg_mask=fg_mask_list_t, + conditioning_scale_fg=conditioning_scale_fg, + conditioning_scale_bg=conditioning_scale_bg, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to the unconditional/conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [ + torch.cat([torch.zeros_like(d), d]) + for d in down_block_res_samples + ] + mid_block_res_sample = torch.cat([ + torch.zeros_like(mid_block_res_sample), + mid_block_res_sample + ]) + + # predict the noise residual + model_out = self.unet( + input_list_t, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs= + cross_attention_kwargs, + down_block_additional_residuals= + down_block_res_samples, + mid_block_additional_residual= + mid_block_res_sample, + return_dict=False, + )[0] + + input_list = [] + cond_list = [] + img_list = [] + fg_mask_list = [] + + noise_preds.append(model_out) + + # Stitch noise predictions for all tiles + noise_pred = torch.zeros( + latent_model_input.shape, + device=latent_model_input.device) + contributors = torch.zeros( + latent_model_input.shape, + device=latent_model_input.device) + # Add each tile contribution to overall latents + for row in range(grid_rows): + for col in range(grid_cols): + if col < grid_cols - 1 or row < grid_rows - 1: + # extract tile from input image + ofs_x = max( + row * tile_size - tile_overlap * row, 0) + ofs_y = max( + col * tile_size - tile_overlap * col, 0) + # input tile area on total image + if row == grid_rows - 1: + ofs_x = w - tile_size + if col == grid_cols - 1: + ofs_y = h - tile_size + + input_start_x = ofs_x + input_end_x = ofs_x + tile_size + input_start_y = ofs_y + input_end_y = ofs_y + tile_size + + noise_pred[:, :, input_start_y:input_end_y, + input_start_x: + input_end_x] += noise_preds[ + row * grid_cols + + col] * tile_weights + contributors[:, :, input_start_y:input_end_y, + input_start_x: + input_end_x] += tile_weights + # Average overlapping areas with more than 1 contributor + noise_pred /= contributors # perform guidance if do_classifier_free_guidance: diff --git a/tests/pipelines/test_image_super_resolution_pasd.py b/tests/pipelines/test_image_super_resolution_pasd.py index 6547001a..77fdb3d5 100644 --- a/tests/pipelines/test_image_super_resolution_pasd.py +++ b/tests/pipelines/test_image_super_resolution_pasd.py @@ -15,13 +15,14 @@ class ImageSuperResolutionPASDTest(unittest.TestCase): def setUp(self) -> None: self.model_id = 'damo/PASD_image_super_resolutions' + self.model_v2_id = 'damo/PASD_v2_image_super_resolutions' self.img = 'data/test/images/dogs.jpg' self.input = { 'image': self.img, 'prompt': '', 'upscale': 
1, - 'fidelity_scale_fg': 1.5, - 'fidelity_scale_bg': 0.7 + 'fidelity_scale_fg': 1.0, + 'fidelity_scale_bg': 1.0 } self.task = Tasks.image_super_resolution_pasd @@ -38,6 +39,13 @@ class ImageSuperResolutionPASDTest(unittest.TestCase): self.pipeline_inference(super_resolution, self.input) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub_v2(self): + super_resolution = pipeline( + Tasks.image_super_resolution_pasd, model=self.model_v2_id) + + self.pipeline_inference(super_resolution, self.input) + if __name__ == '__main__': unittest.main() From 4cad376298445a15307bf6fdb866debbc2370be0 Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Fri, 13 Oct 2023 14:04:04 +0800 Subject: [PATCH 02/21] Add llm_first parameter for pipeline Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14264249 * support llm_first parameter * register_module(Tasks.text_generation) * fix bug * update format & fix out_base64 for int4 * pre-commit --- modelscope/metainfo.py | 1 + modelscope/models/base/base_model.py | 2 +- modelscope/pipeline_inputs.py | 10 +-- modelscope/pipelines/builder.py | 44 ++++++++++- modelscope/pipelines/nlp/__init__.py | 2 + modelscope/pipelines/nlp/llm_pipeline.py | 11 ++- modelscope/utils/input_output.py | 1 + modelscope/utils/pipeline_inputs.json | 21 ++++-- tests/json_call_test.py | 32 ++++---- tests/pipelines/test_llm_pipeline.py | 95 ++++++++++++++++-------- 10 files changed, 150 insertions(+), 69 deletions(-) diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index a8b93cc3..ea56efb5 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -514,6 +514,7 @@ class Pipelines(object): document_grounded_dialog_generate = 'document-grounded-dialog-generate' language_identification = 'language_identification' machine_reading_comprehension_for_ner = 'machine-reading-comprehension-for-ner' + llm = 'llm' # audio tasks sambert_hifigan_tts = 'sambert-hifigan-tts' diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index 9f225383..a3b65812 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -117,6 +117,7 @@ class Model(ABC): else: invoked_by = Invoke.PRETRAINED + ignore_file_pattern = kwargs.pop('ignore_file_pattern', None) if osp.exists(model_name_or_path): local_model_dir = model_name_or_path else: @@ -126,7 +127,6 @@ class Model(ABC): ) invoked_by = '%s/%s' % (Invoke.KEY, invoked_by) - ignore_file_pattern = kwargs.pop('ignore_file_pattern', None) local_model_dir = snapshot_download( model_name_or_path, revision, diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py index bffbebbd..d97a95f9 100644 --- a/modelscope/pipeline_inputs.py +++ b/modelscope/pipeline_inputs.py @@ -325,13 +325,9 @@ TASK_INPUTS = { }, # ============ nlp tasks =================== - Tasks.chat: [ - InputType.TEXT, - { - 'text': InputType.TEXT, - 'history': InputType.LIST, - } - ], + Tasks.chat: { + 'messages': InputType.LIST + }, Tasks.text_classification: [ InputType.TEXT, (InputType.TEXT, InputType.TEXT), diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index d6dff693..525bc92c 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -1,14 +1,16 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import os +import os.path as osp from typing import List, Optional, Union +from modelscope.hub.file_download import model_file_download from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE, Pipelines from modelscope.models.base import Model -from modelscope.utils.config import ConfigDict, check_config +from modelscope.utils.config import Config, ConfigDict, check_config from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, - ThirdParty) + ModelFile, ThirdParty) from modelscope.utils.hub import read_config from modelscope.utils.plugins import (register_modelhub_repo, register_plugins_repo) @@ -117,6 +119,8 @@ def pipeline(task: str = None, model_revision, third_party=third_party, ignore_file_pattern=ignore_file_pattern) + if pipeline_name is None and kwargs.get('llm_first'): + pipeline_name = llm_first_checker(model, model_revision) pipeline_props = {'type': pipeline_name} if pipeline_name is None: # get default pipeline for this task @@ -196,3 +200,39 @@ def get_default_pipeline_info(task): else: pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task] return pipeline_name, default_model + + +def llm_first_checker(model: Union[str, List[str], Model, List[Model]], + revision: Optional[str]) -> Optional[str]: + from modelscope.pipelines.nlp.llm_pipeline import LLM_FORMAT_MAP + + def get_file_name(model: str, cfg_name: str, + revision: Optional[str]) -> Optional[str]: + if osp.exists(model): + return osp.join(model, cfg_name) + try: + return model_file_download(model, cfg_name, revision=revision) + except Exception: + return None + + def parse_model_type(file: Optional[str], pattern: str) -> Optional[str]: + if file is None or not osp.exists(file): + return None + return Config.from_file(file).safe_get(pattern) + + def get_model_type(model: str, revision: Optional[str]) -> Optional[str]: + cfg_file = get_file_name(model, ModelFile.CONFIGURATION, revision) + hf_cfg_file = get_file_name(model, ModelFile.CONFIG, revision) + cfg_model_type = parse_model_type(cfg_file, 'model.type') + hf_cfg_model_type = parse_model_type(hf_cfg_file, 'model_type') + return cfg_model_type or hf_cfg_model_type + + if isinstance(model, list): + model = model[0] + if not isinstance(model, str): + model = model.model_dir + model_type = get_model_type(model, revision) + if model_type is not None: + model_type = model_type.lower().split('-')[0] + if model_type in LLM_FORMAT_MAP: + return 'llm' diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 23473007..df7e2068 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -48,6 +48,7 @@ if TYPE_CHECKING: from .document_grounded_dialog_rerank_pipeline import DocumentGroundedDialogRerankPipeline from .language_identification_pipline import LanguageIdentificationPipeline from .machine_reading_comprehension_pipeline import MachineReadingComprehensionForNERPipeline + from .llm_pipeline import LLMPipeline else: _import_structure = { @@ -119,6 +120,7 @@ else: 'machine_reading_comprehension_pipeline': [ 'MachineReadingComprehensionForNERPipeline' ], + 'llm_pipeline': ['LLMPipeline'], } import sys diff --git a/modelscope/pipelines/nlp/llm_pipeline.py b/modelscope/pipelines/nlp/llm_pipeline.py index 63fc55ea..e2979ccb 100644 --- a/modelscope/pipelines/nlp/llm_pipeline.py +++ b/modelscope/pipelines/nlp/llm_pipeline.py @@ -11,6 +11,7 @@ from modelscope import (AutoModelForCausalLM, AutoTokenizer, Pipeline, 
snapshot_download) from modelscope.models.base import Model from modelscope.models.nlp import ChatGLM2Tokenizer, Llama2Tokenizer +from modelscope.outputs import OutputKeys from modelscope.pipelines.builder import PIPELINES from modelscope.pipelines.util import is_model, is_official_hub_path from modelscope.utils.constant import Invoke, ModelFile, Tasks @@ -19,7 +20,8 @@ from modelscope.utils.logger import get_logger logger = get_logger() -@PIPELINES.register_module(Tasks.chat, module_name='llm-pipeline') +@PIPELINES.register_module(Tasks.chat, module_name='llm') +@PIPELINES.register_module(Tasks.text_generation, module_name='llm') class LLMPipeline(Pipeline): def initiate_single_model(self, model): @@ -55,6 +57,9 @@ class LLMPipeline(Pipeline): *args, **kwargs): self.device_map = kwargs.pop('device_map', None) + # TODO: qwen-int4 need 'cuda'/'auto' device_map. + if not self.device_map and 'qwen' in kwargs['model'].lower(): + self.device_map = 'cuda' self.torch_dtype = kwargs.pop('torch_dtype', None) self.ignore_file_pattern = kwargs.pop('ignore_file_pattern', None) with self._temp_configuration_file(kwargs): @@ -138,6 +143,8 @@ class LLMPipeline(Pipeline): outputs, skip_special_tokens=True, **kwargs) if is_messages: response = self.format_output(response, **kwargs) + else: + response = {OutputKeys.TEXT: response} return response @@ -260,7 +267,7 @@ def chatglm2_format_output(response, **kwargs): response = response.replace('[[训练时间]]', '2023年') messages = {'role': 'assistant', 'content': response} outputs = { - 'messages': messages, + 'message': messages, } return outputs diff --git a/modelscope/utils/input_output.py b/modelscope/utils/input_output.py index d8e32cce..679069c1 100644 --- a/modelscope/utils/input_output.py +++ b/modelscope/utils/input_output.py @@ -773,6 +773,7 @@ def pipeline_output_to_service_base64_output(task_name, pipeline_output): pipeline_output = pipeline_output[0] for key, value in pipeline_output.items(): if key not in task_outputs: + json_serializable_output[key] = value continue # skip the output not defined. if key in [ OutputKeys.OUTPUT_IMG, OutputKeys.OUTPUT_IMGS, diff --git a/modelscope/utils/pipeline_inputs.json b/modelscope/utils/pipeline_inputs.json index 2ba31bcc..0cb9c1b1 100644 --- a/modelscope/utils/pipeline_inputs.json +++ b/modelscope/utils/pipeline_inputs.json @@ -16,13 +16,20 @@ }, "chat":{ "input":{ - "text":"你有什么推荐吗?", - "history":[ - [ - "今天天气真好,", - "今天天气真好,出去走走怎么样?" - ] - ] + "messages": [{ + "role": "user", + "content": "Hello! 你是谁?" + }, { + "role": "assistant", + "content": "我是你的助手。" + }, { + "role": "user", + "content": "你叫什么名字?" 
+ }] + }, + "parameters": { + "do_sample": true, + "max_length": 512 } }, "domain-specific-object-detection":{ diff --git a/tests/json_call_test.py b/tests/json_call_test.py index 658c947f..7073a90d 100644 --- a/tests/json_call_test.py +++ b/tests/json_call_test.py @@ -37,12 +37,15 @@ class ModelJsonTest: # init pipeline ppl = pipeline( - task=task, model=model_id, model_revision=model_revision) + task=task, + model=model_id, + model_revision=model_revision, + llm_first=True) pipeline_info = get_pipeline_information_by_pipeline(ppl) # call pipeline data = get_task_input_examples(task) - print(task, data) + infer_result = call_pipeline_with_json(pipeline_info, ppl, data) result = pipeline_output_to_service_base64_output(task, infer_result) return result @@ -50,27 +53,20 @@ class ModelJsonTest: if __name__ == '__main__': model_list = [ - 'damo/nlp_structbert_nli_chinese-base', - 'damo/nlp_structbert_word-segmentation_chinese-base', - 'damo/nlp_structbert_zero-shot-classification_chinese-base', - 'damo/cv_unet_person-image-cartoon_compound-models', - 'damo/nlp_structbert_sentiment-classification_chinese-tiny', - 'damo/nlp_csanmt_translation_zh2en', - 'damo/nlp_rom_passage-ranking_chinese-base', - 'damo/ofa_image-caption_muge_base_zh', - 'damo/nlp_raner_named-entity-recognition_chinese-base-ecom-50cls', - 'damo/nlp_structbert_sentiment-classification_chinese-ecommerce-base', - 'damo/text-to-video-synthesis', - 'qwen/Qwen-7B', - 'qwen/Qwen-7B-Chat', - 'ZhipuAI/ChatGLM-6B', + 'qwen/Qwen-7B-Chat-Int4', + 'qwen/Qwen-14B-Chat-Int4', + 'baichuan-inc/Baichuan2-7B-Chat-4bits', + 'baichuan-inc/Baichuan2-13B-Chat-4bits', + 'ZhipuAI/chatglm2-6b-int4', ] tester = ModelJsonTest() for model in model_list: try: res = tester.test_single(model) - print(f'\nmodel_id {model} call_pipeline_with_json run ok.\n') + print( + f'\nmodel_id {model} call_pipeline_with_json run ok. 
{res}\n\n\n\n' + ) except BaseException as e: print( - f'\nmodel_id {model} call_pipeline_with_json run failed: {e}.\n' + f'\nmodel_id {model} call_pipeline_with_json run failed: {e}.\n\n\n\n' ) diff --git a/tests/pipelines/test_llm_pipeline.py b/tests/pipelines/test_llm_pipeline.py index 1b6d211a..9b7e832f 100644 --- a/tests/pipelines/test_llm_pipeline.py +++ b/tests/pipelines/test_llm_pipeline.py @@ -3,6 +3,7 @@ import unittest import torch +from modelscope import pipeline from modelscope.pipelines.nlp.llm_pipeline import LLMPipeline from modelscope.utils.test_utils import test_level @@ -132,143 +133,172 @@ class LLMPipelineTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_chatglm2(self): - pipe = LLMPipeline(model='ZhipuAI/chatglm2-6b', device_map='auto') + pipe = pipeline( + task='chat', model='ZhipuAI/chatglm2-6b', llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_chatglm2int4(self): - pipe = LLMPipeline(model='ZhipuAI/chatglm2-6b-int4') + pipe = pipeline( + task='chat', model='ZhipuAI/chatglm2-6b-int4', llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_chatglm232k(self): - pipe = LLMPipeline(model='ZhipuAI/chatglm2-6b-32k', device_map='auto') + pipe = pipeline( + task='chat', model='ZhipuAI/chatglm2-6b-32k', llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_llama2(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='modelscope/Llama-2-7b-ms', torch_dtype=torch.float16, device_map='auto', - ignore_file_pattern=[r'.+\.bin$']) + ignore_file_pattern=[r'.+\.bin$'], + llm_first=True) print('messages: ', pipe(self.messages_en, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_en, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_llama2chat(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='modelscope/Llama-2-7b-chat-ms', revision='v1.0.2', torch_dtype=torch.float16, device_map='auto', - ignore_file_pattern=[r'.+\.bin$']) + ignore_file_pattern=[r'.+\.bin$'], + llm_first=True) print('messages: ', pipe(self.messages_en, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_en, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_codellama(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='AI-ModelScope/CodeLlama-7b-Instruct-hf', torch_dtype=torch.float16, device_map='auto', - ignore_file_pattern=[r'.+\.bin$']) + ignore_file_pattern=[r'.+\.bin$'], + llm_first=True) print('messages: ', pipe(self.messages_code, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_code, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_baichuan_7b(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/baichuan-7B', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) 
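# For reference, a minimal sketch of the llm_first usage these tests exercise;
# the model id, message content and generation settings below are illustrative,
# and the 'chat' task takes the {'messages': [...]} input format registered
# earlier in pipeline_inputs.json.
from modelscope import pipeline

sketch_pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True)
sketch_messages = {'messages': [{'role': 'user', 'content': 'Hello! 你是谁?'}]}
print(sketch_pipe(sketch_messages, do_sample=True, max_length=512))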
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_baichuan_13b(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/Baichuan-13B-Base', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_baichuan_13bchat(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/Baichuan-13B-Chat', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_baichuan2_7b(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/Baichuan2-7B-Base', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_baichuan2_7bchat(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/Baichuan2-7B-Chat', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skip('Need bitsandbytes') def test_baichuan2_7bchat_int4(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/Baichuan2-7B-Chat-4bits', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skip('Need bitsandbytes') def test_baichuan2_13bchat_int4(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='baichuan-inc/Baichuan2-13B-Chat-4bits', device_map='auto', - torch_dtype=torch.float16) + torch_dtype=torch.float16, + llm_first=True) print('messages: ', pipe(self.messages_zh, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_wizardlm_13b(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='AI-ModelScope/WizardLM-13B-V1.2', device_map='auto', torch_dtype=torch.float16, - format_messages='wizardlm') + format_messages='wizardlm', + llm_first=True) print('messages: ', pipe(self.messages_en, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_en, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_wizardmath(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='AI-ModelScope/WizardMath-7B-V1.0', device_map='auto', torch_dtype=torch.float16, - format_messages='wizardcode') + format_messages='wizardcode', + llm_first=True) print('messages: ', pipe(self.message_wizard_math, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_wizard_math, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_wizardcode_13b(self): - pipe = LLMPipeline( + pipe = pipeline( + task='chat', model='AI-ModelScope/WizardCoder-Python-13B-V1.0', 
device_map='auto', torch_dtype=torch.float16, - format_messages='wizardcode') + format_messages='wizardcode', + llm_first=True) print('messages: ', pipe(self.message_wizard_code, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_wizard_code, **self.gen_cfg)) @@ -284,19 +314,20 @@ class LLMPipelineTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_qwen(self): - pipe = LLMPipeline(model='qwen/Qwen-7B-Chat', device_map='auto') + pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True) print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skip('Need optimum and auto-gptq') def test_qwen_int4(self): - pipe = LLMPipeline(model='qwen/Qwen-7B-Chat-Int4', device_map='auto') + pipe = pipeline( + task='chat', model='qwen/Qwen-7B-Chat-Int4', llm_first=True) print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_qwen_vl(self): - pipe = LLMPipeline(model='qwen/Qwen-VL-Chat', device_map='auto') + pipe = pipeline(task='chat', model='qwen/Qwen-VL-Chat', llm_first=True) print('messages: ', pipe(self.messages_mm, **self.gen_cfg)) print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg)) From 4ac59e29b0d8c29066a41bb6fd6e6adedae24115 Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Fri, 13 Oct 2023 14:11:30 +0800 Subject: [PATCH 03/21] refactor ci to analyze file dependency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit import分析不需要实际import文件,通过静态扫描,得到文件中定义的符号,在分析import时找到import的符号所在的文件,从而建立起关联。 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14281794 * refactor ci * fix typo * fix bug * test * add indirect import name resolve * remove test code --- modelscope/pipelines/cv/ocr_utils/ops.py | 2 + tests/run_analysis.py | 28 ++- tests/utils/case_file_analyzer.py | 3 +- tests/utils/source_file_analyzer.py | 255 +++++++++++++++++------ 4 files changed, 211 insertions(+), 77 deletions(-) diff --git a/modelscope/pipelines/cv/ocr_utils/ops.py b/modelscope/pipelines/cv/ocr_utils/ops.py index 73b58c38..842cc0fb 100644 --- a/modelscope/pipelines/cv/ocr_utils/ops.py +++ b/modelscope/pipelines/cv/ocr_utils/ops.py @@ -16,6 +16,8 @@ from . import utils if tf.__version__ >= '2.0': tf = tf.compat.v1 +# test + # skip parse sys.argv in tf, so fix bug: # absl.flags._exceptions.UnrecognizedFlagError: # Unknown command line flag 'OCRDetectionPipeline: Unknown command line flag diff --git a/tests/run_analysis.py b/tests/run_analysis.py index 1fb12ff6..ac0f2ac9 100644 --- a/tests/run_analysis.py +++ b/tests/run_analysis.py @@ -231,8 +231,8 @@ def get_test_suites_to_run(): affected_pipeline_cases.extend( task_pipeline_test_suite_map[affected_register_module[1]]) else: - logger.warn('Pipeline task: %s has no test case!' - % affected_register_module[1]) + logger.warning('Pipeline task: %s has no test case!' + % affected_register_module[1]) elif affected_register_module[0] == 'MODELS': # ["MODELS", "keyword_spotting", "kws_kwsbp", "GenericKeyWordSpotting"], # ["MODELS", task, model_name, model_class_name] @@ -240,8 +240,8 @@ def get_test_suites_to_run(): affected_pipeline_cases.extend( task_pipeline_test_suite_map[affected_register_module[1]]) else: - logger.warn('Pipeline task: %s has no test case!' 
- % affected_register_module[1]) + logger.warning('Pipeline task: %s has no test case!' + % affected_register_module[1]) elif affected_register_module[0] == 'TRAINERS': # ["TRAINERS", "", "nlp_base_trainer", "NlpEpochBasedTrainer"], # ["TRAINERS", "", trainer_name, trainer_class_name] @@ -298,13 +298,27 @@ def get_test_suites_to_run(): return test_suites_to_run -def get_files_related_modules(files): +def get_files_related_modules(files, reverse_import_map): register_modules = [] for single_file in files: if single_file.startswith('./modelscope') or \ single_file.startswith('modelscope'): register_modules.extend(get_file_register_modules(single_file)) + while len(register_modules) == 0: + logger.warn('There is no affected register module') + deeper_imported_by = [] + has_deeper_affected_files = False + for source_file in files: + if len(source_file.split('/')) > 4 and source_file.startswith( + 'modelscope'): + deeper_imported_by.extend(reverse_import_map[source_file]) + has_deeper_affected_files = True + if not has_deeper_affected_files: + break + for file in deeper_imported_by: + register_modules = get_file_register_modules(file) + files = deeper_imported_by return register_modules @@ -354,8 +368,8 @@ def get_all_file_test_info(): file_test_info = {} file_test_info['imports'] = import_map[f] file_test_info['imported_by'] = reverse_depend_map[f] - register_modules = get_files_related_modules([f] - + reverse_depend_map[f]) + register_modules = get_files_related_modules( + [f] + reverse_depend_map[f], reverse_depend_map) file_test_info['relate_modules'] = register_modules affected_pipeline_cases, affected_trainer_cases = get_modules_related_cases( register_modules, task_pipeline_test_suite_map, diff --git a/tests/utils/case_file_analyzer.py b/tests/utils/case_file_analyzer.py index 63be95bd..f1b73a20 100644 --- a/tests/utils/case_file_analyzer.py +++ b/tests/utils/case_file_analyzer.py @@ -90,7 +90,8 @@ class AnalysisTestClass(ast.NodeVisitor): if isinstance(item, ast.Name): res.append(self.get_variables(item.id)) elif isinstance(item, ast.Attribute): - res.append(self.get_variables(item.value.id)) + if hasattr(item.value, 'id'): + res.append(self.get_variables(item.value.id)) elif isinstance(item, ast.Str): res.append(self.get_variables(item.s)) elif isinstance(item, ast.Dict): diff --git a/tests/utils/source_file_analyzer.py b/tests/utils/source_file_analyzer.py index ef31c8aa..356d009f 100644 --- a/tests/utils/source_file_analyzer.py +++ b/tests/utils/source_file_analyzer.py @@ -8,24 +8,62 @@ import pkgutil import site import sys +import json + from modelscope.utils.logger import get_logger logger = get_logger() +class AnalysisSourceFileDefines(ast.NodeVisitor): + """Analysis source file function, class, global variable defines. 
+ """ + + def __init__(self, source_file_path) -> None: + super().__init__() + self.global_variables = [] + self.functions = [] + self.classes = [] + self.async_functions = [] + self.symbols = [] + + self.source_file_path = source_file_path + rel_file_path = source_file_path + if os.path.isabs(source_file_path): + rel_file_path = os.path.relpath(source_file_path, os.getcwd()) + + if rel_file_path.endswith('__init__.py'): # processing package + self.base_module_name = os.path.dirname(rel_file_path).replace( + '/', '.') + else: # import x.y.z z is the filename + self.base_module_name = rel_file_path.replace('/', '.').replace( + '.py', '') + self.symbols.append(self.base_module_name) + + def visit_ClassDef(self, node: ast.ClassDef): + self.symbols.append(self.base_module_name + '.' + node.name) + self.classes.append(node.name) + + def visit_FunctionDef(self, node: ast.FunctionDef): + self.symbols.append(self.base_module_name + '.' + node.name) + self.functions.append(node.name) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): + self.symbols.append(self.base_module_name + '.' + node.name) + self.async_functions.append(node.name) + + def visit_Assign(self, node: ast.Assign): + for tg in node.targets: + if isinstance(tg, ast.Name): + self.symbols.append(self.base_module_name + '.' + tg.id) + self.global_variables.append(tg.id) + + def is_relative_import(path): # from .x import y or from ..x import y return path.startswith('.') -def resolve_import(module_name): - try: - spec = importlib.util.find_spec(module_name) - return spec and spec.origin - except Exception: - return None - - def convert_to_path(name): if name.startswith('.'): remainder = name.lstrip('.') @@ -39,58 +77,93 @@ def convert_to_path(name): return filename -def resolve_relative_import(source_file_path, module_name): +def resolve_relative_import(source_file_path, module_name, all_symbols): current_package = os.path.dirname(source_file_path).replace('/', '.') absolute_name = importlib.util.resolve_name(module_name, current_package) # get - return resolve_absolute_import(absolute_name) + return resolve_absolute_import(absolute_name, all_symbols) -def onerror(name): - logger.error('Importing module %s error!' % name) +def resolve_absolute_import(module_name, all_symbols): + # direct imports + if module_name in all_symbols: + return all_symbols[module_name] + + # some symble import by package __init__.py, we need find the real file which define the symbel. + parent, sub = module_name.rsplit('.', 1) + + # case module_name is a python Definition + for symbol, symbol_path in all_symbols.items(): + if symbol.startswith(parent) and symbol.endswith(sub): + return all_symbols[symbol] + + return None -def resolve_absolute_import(module_name): - module_file_path = resolve_import(module_name) - if module_file_path is None: - # find from base module. - parent_module, sub_module = module_name.rsplit('.', 1) - if parent_module in sys.modules: - if hasattr(sys.modules[parent_module], '_import_structure'): - import_structure = sys.modules[parent_module]._import_structure - for k, v in import_structure.items(): - if sub_module in v: - parent_module = parent_module + '.' 
+ k - break - module_file_path = resolve_absolute_import(parent_module) - # the parent_module is a package, we need find the module_name's file - if os.path.basename(module_file_path) == '__init__.py' and \ - (os.path.relpath(module_file_path, site.getsitepackages()[0]) != 'modelscope/__init__.py' - or os.path.relpath(module_file_path, os.getcwd()) != 'modelscope/__init__.py'): - for _, sub_module_name, _ in pkgutil.walk_packages( - [os.path.dirname(module_file_path)], - parent_module + '.', - onerror=onerror): - try: - module_ = importlib.import_module(sub_module_name) - for k, v in module_.__dict__.items(): - if k == sub_module and v.__module__ == module_.__name__: - module_file_path = module_.__file__ - break - except ModuleNotFoundError as e: - logger.warn( - 'Import error in %s, ModuleNotFoundError: %s' % - (sub_module_name, e)) - continue - except Exception as e: - logger.warn('Import error in %s, Exception: %s' % - (sub_module_name, e)) - continue +class IndirectDefines(ast.NodeVisitor): + """Analysis source file function, class, global variable defines. + """ + + def __init__(self, source_file_path, all_symbols, + file_symbols_map) -> None: + super().__init__() + self.symbols_map = { + } # key symbol name in current file, value the real file path. + self.all_symbols = all_symbols + self.file_symbols_map = file_symbols_map + self.source_file_path = source_file_path + + rel_file_path = source_file_path + if os.path.isabs(source_file_path): + rel_file_path = os.path.relpath(source_file_path, os.getcwd()) + + if rel_file_path.endswith('__init__.py'): # processing package + self.base_module_name = os.path.dirname(rel_file_path).replace( + '/', '.') + else: # import x.y.z z is the filename + self.base_module_name = rel_file_path.replace('/', '.').replace( + '.py', '') + + # import from will get the symbol in current file. + # from a import b, will get b in current file. + def visit_ImportFrom(self, node): + # level 0 absolute import such as from os.path import join + # level 1 from .x import y + # level 2 from ..x import y + module_name = '.' * node.level + (node.module or '') + for alias in node.names: + file_path = None + if alias.name == '*': # from x import * + if is_relative_import(module_name): + # resolve model path. + file_path = resolve_relative_import( + self.source_file_path, module_name, self.all_symbols) + elif module_name.startswith('modelscope'): + file_path = resolve_absolute_import( + module_name, self.all_symbols) + else: + file_path = None # ignore other package. + if file_path is not None: + for symbol in self.file_symbols_map[file_path][1:]: + symbol_name = symbol.split('.')[-1] + self.symbols_map[self.base_module_name + + symbol_name] = file_path else: - return module_file_path - else: - module_file_path = resolve_absolute_import(parent_module) - return module_file_path + if not module_name.endswith('.'): + module_name = module_name + '.' + name = module_name + alias.name + if alias.asname is not None: + current_module_name = self.base_module_name + '.' + alias.asname + else: + current_module_name = self.base_module_name + '.' + alias.name + if is_relative_import(name): + # resolve model path. 
+ file_path = resolve_relative_import( + self.source_file_path, name, self.all_symbols) + elif name.startswith('modelscope'): + file_path = resolve_absolute_import(name, self.all_symbols) + if file_path is not None: + self.symbols_map[current_module_name] = file_path class AnalysisSourceFileImports(ast.NodeVisitor): @@ -98,23 +171,19 @@ class AnalysisSourceFileImports(ast.NodeVisitor): List imports of the modelscope. """ - def __init__(self, source_file_path) -> None: + def __init__(self, source_file_path, all_symbols) -> None: super().__init__() self.imports = [] self.source_file_path = source_file_path + self.all_symbols = all_symbols def visit_Import(self, node): """Processing import x,y,z or import os.path as osp""" for alias in node.names: if alias.name.startswith('modelscope'): - file_path = resolve_absolute_import(alias.name) - if file_path.startswith(site.getsitepackages()[0]): - self.imports.append( - os.path.relpath(file_path, - site.getsitepackages()[0])) - else: - self.imports.append( - os.path.relpath(file_path, os.getcwd())) + file_path = resolve_absolute_import(alias.name, + self.all_symbols) + self.imports.append(os.path.relpath(file_path, os.getcwd())) def visit_ImportFrom(self, node): # level 0 absolute import such as from os.path import join @@ -126,9 +195,10 @@ class AnalysisSourceFileImports(ast.NodeVisitor): if is_relative_import(module_name): # resolve model path. file_path = resolve_relative_import( - self.source_file_path, module_name) + self.source_file_path, module_name, self.all_symbols) elif module_name.startswith('modelscope'): - file_path = resolve_absolute_import(module_name) + file_path = resolve_absolute_import( + module_name, self.all_symbols) else: file_path = None # ignore other package. else: @@ -138,9 +208,17 @@ class AnalysisSourceFileImports(ast.NodeVisitor): if is_relative_import(name): # resolve model path. file_path = resolve_relative_import( - self.source_file_path, name) + self.source_file_path, name, self.all_symbols) + if file_path is None: + logger.warning( + 'File: %s, import %s%s not exist!' % + (self.source_file_path, module_name, alias.name)) elif name.startswith('modelscope'): - file_path = resolve_absolute_import(name) + file_path = resolve_absolute_import(name, self.all_symbols) + if file_path is None: + logger.warning( + 'File: %s, import %s%s not exist!' % + (self.source_file_path, module_name, alias.name)) else: file_path = None # ignore other package. @@ -152,6 +230,10 @@ class AnalysisSourceFileImports(ast.NodeVisitor): else: self.imports.append( os.path.relpath(file_path, os.getcwd())) + elif module_name.startswith('modelscope'): + logger.warning( + 'File: %s, import %s%s not exist!' % + (self.source_file_path, module_name, alias.name)) class AnalysisSourceFileRegisterModules(ast.NodeVisitor): @@ -216,14 +298,14 @@ class AnalysisSourceFileRegisterModules(ast.NodeVisitor): node.name)) # PIPELINES, task, module, class_name -def get_imported_files(file_path): +def get_imported_files(file_path, all_symbols): """Get file dependencies. 
""" if os.path.isabs(file_path): file_path = os.path.relpath(file_path, os.getcwd()) with open(file_path, 'rb') as f: src = f.read() - analyzer = AnalysisSourceFileImports(file_path) + analyzer = AnalysisSourceFileImports(file_path, all_symbols) analyzer.visit(ast.parse(src, filename=file_path)) return list(set(analyzer.imports)) @@ -236,7 +318,6 @@ def path_to_module_name(file_path): def get_file_register_modules(file_path): - logger.info('Get file: %s register_module' % file_path) with open(file_path, 'rb') as f: src = f.read() analyzer = AnalysisSourceFileRegisterModules(file_path) @@ -244,15 +325,51 @@ def get_file_register_modules(file_path): return analyzer.register_modules +def get_file_defined_symbols(file_path): + if os.path.isabs(file_path): + file_path = os.path.relpath(file_path, os.getcwd()) + with open(file_path, 'rb') as f: + src = f.read() + analyzer = AnalysisSourceFileDefines(file_path) + analyzer.visit(ast.parse(src, filename=file_path)) + return analyzer.symbols + + +def get_indirect_symbols(file_path, symbols, file_symbols_map): + if os.path.isabs(file_path): + file_path = os.path.relpath(file_path, os.getcwd()) + with open(file_path, 'rb') as f: + src = f.read() + analyzer = IndirectDefines(file_path, symbols, file_symbols_map) + analyzer.visit(ast.parse(src, filename=file_path)) + return analyzer.symbols_map + + def get_import_map(): all_files = [ os.path.join(dp, f) for dp, dn, filenames in os.walk( os.path.join(os.getcwd(), 'modelscope')) for f in filenames if os.path.splitext(f)[1] == '.py' ] + all_symbols = {} + file_symbols_map = {} + for f in all_files: + file_path = os.path.relpath(f, os.getcwd()) + file_symbols_map[file_path] = get_file_defined_symbols(f) + for s in file_symbols_map[file_path]: + all_symbols[s] = file_path + + # get indirect(imported) symbols, refer to origin define. 
+ for f in all_files: + for name, real_path in get_indirect_symbols(f, all_symbols, + file_symbols_map).items(): + all_symbols[name] = os.path.relpath(real_path, os.getcwd()) + + with open('symbols.json', 'w') as f: + json.dump(all_symbols, f) import_map = {} for f in all_files: - files = get_imported_files(f) + files = get_imported_files(f, all_symbols) import_map[os.path.relpath(f, os.getcwd())] = files return import_map From 6e424cdb69b6e8b87a4a8cf9b76e095f6c72596e Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Fri, 13 Oct 2023 19:54:48 +0800 Subject: [PATCH 04/21] force upgrade transformers when build image Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14293744 * force upgrade transformers when build image --- .dev_scripts/build_image.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.dev_scripts/build_image.sh b/.dev_scripts/build_image.sh index 9775d72e..9ce2a4a8 100644 --- a/.dev_scripts/build_image.sh +++ b/.dev_scripts/build_image.sh @@ -150,7 +150,7 @@ echo -e "Building image with:\npython$python_version\npytorch$torch_version\nten docker_file_content=`cat docker/Dockerfile.ubuntu` if [ "$is_ci_test" != "True" ]; then echo "Building ModelScope lib, will install ModelScope lib to image" - docker_file_content="${docker_file_content} \nRUN pip install --no-cache-dir https://modelscope.oss-cn-beijing.aliyuncs.com/releases/build/modelscope-$modelscope_version-py3-none-any.whl " + docker_file_content="${docker_file_content} \nRUN pip install --no-cache-dir -U transformers && pip install --no-cache-dir https://modelscope.oss-cn-beijing.aliyuncs.com/releases/build/modelscope-$modelscope_version-py3-none-any.whl " fi echo "$is_dsw" if [ "$is_dsw" == "False" ]; then From 049bde9ddff2fa836970e80a3c0fdbfe7bb8ff01 Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Fri, 13 Oct 2023 22:27:51 +0800 Subject: [PATCH 05/21] =?UTF-8?q?move=20venv=20import=20from=20file=20leve?= =?UTF-8?q?l=20to=20class=20level=20to=20avoid=20import=20error=E2=80=A6?= =?UTF-8?q?=20(#575)=20Link:=20https://code.alibaba-inc.com/Ali-MaaS/MaaS-?= =?UTF-8?q?lib/codereview/14301042=20*=20move=20venv=20import=20from=20fil?= =?UTF-8?q?e=20level=20to=20class=20level=20to=20avoid=20import=20error?= =?UTF-8?q?=E2=80=A6=20(#575)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * move venv import from file level to class level to avoid import error on windows --------- authored-by: Zhicheng Zhang --- modelscope/utils/plugins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelscope/utils/plugins.py b/modelscope/utils/plugins.py index 3d39514a..b4485830 100644 --- a/modelscope/utils/plugins.py +++ b/modelscope/utils/plugins.py @@ -9,7 +9,6 @@ import os import pkgutil import shutil import sys -import venv from contextlib import contextmanager from fnmatch import fnmatch from pathlib import Path @@ -1144,6 +1143,7 @@ class EnvsManager(object): cfg = read_config(model_dir) self.plugins = cfg.get('plugins', []) self.allow_remote = cfg.get('allow_remote', False) + import venv self.env_builder = venv.EnvBuilder( system_site_packages=True, clear=False, From e75f5b4bc41be484e2e2e4dc3dd94d9495c4746f Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Fri, 13 Oct 2023 22:42:08 +0800 Subject: [PATCH 06/21] version to 1.9.3 --- modelscope/version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modelscope/version.py b/modelscope/version.py index 97e90b1d..f7f006e7 100644 --- a/modelscope/version.py +++ 
b/modelscope/version.py @@ -1,5 +1,5 @@ # Make sure to modify __release_datetime__ to release time when making official release. -__version__ = '1.9.2' +__version__ = '1.9.3' # default release datetime for branches under active development is set # to be a time far-far-away-into-the-future -__release_datetime__ = '2023-09-06 00:00:00' +__release_datetime__ = '2023-10-17 00:00:00' From 087cb4e463dc5dc022954796df484f3bfe65c9ba Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Mon, 16 Oct 2023 16:13:28 +0800 Subject: [PATCH 07/21] upgrade flash attention to 2.32.2 --- docker/scripts/install_flash_attension.sh | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docker/scripts/install_flash_attension.sh b/docker/scripts/install_flash_attension.sh index 6a3301c2..f37e567d 100644 --- a/docker/scripts/install_flash_attension.sh +++ b/docker/scripts/install_flash_attension.sh @@ -1,6 +1,4 @@ - git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention && \ - cd flash-attention && pip install . && \ - pip install csrc/layer_norm && \ - pip install csrc/rotary && \ + git clone -b v2.3.2 https://github.com/Dao-AILab/flash-attention && \ + cd flash-attention && python setup.py install && \ cd .. && \ rm -rf flash-attention From c67b2cfc342c2fc44f861c665cacdb539d1d09db Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Mon, 16 Oct 2023 22:12:31 +0800 Subject: [PATCH 08/21] fix ofa new transformers compatible issue Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14317517 * fix ofa new transformers compatible issue * fix timm.layers to timm.models.layers compatible issue --- modelscope/models/cv/shop_segmentation/head_fpn.py | 4 ++-- modelscope/models/cv/shop_segmentation/models.py | 4 ++-- modelscope/models/cv/shop_segmentation/neck_fpn.py | 4 ++-- .../models/multi_modal/ofa/tokenization_ofa.py | 12 ++++++------ 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/modelscope/models/cv/shop_segmentation/head_fpn.py b/modelscope/models/cv/shop_segmentation/head_fpn.py index dfa284d4..0d4027cb 100644 --- a/modelscope/models/cv/shop_segmentation/head_fpn.py +++ b/modelscope/models/cv/shop_segmentation/head_fpn.py @@ -9,8 +9,8 @@ import numpy as np import torch import torch.nn as nn from mmcv.cnn import ConvModule -from timm.layers.drop import drop_path -from timm.layers.weight_init import trunc_normal_ +from timm.models.layers.drop import drop_path +from timm.models.layers.weight_init import trunc_normal_ from .common import Upsample, resize diff --git a/modelscope/models/cv/shop_segmentation/models.py b/modelscope/models/cv/shop_segmentation/models.py index 1b07a08c..a206e9f1 100644 --- a/modelscope/models/cv/shop_segmentation/models.py +++ b/modelscope/models/cv/shop_segmentation/models.py @@ -11,8 +11,8 @@ from collections import OrderedDict import torch import torch.nn.functional as F import torch.utils.checkpoint as checkpoint -from timm.layers.drop import drop_path -from timm.layers.weight_init import trunc_normal_ +from timm.models.layers.drop import drop_path +from timm.models.layers.weight_init import trunc_normal_ from torch import nn diff --git a/modelscope/models/cv/shop_segmentation/neck_fpn.py b/modelscope/models/cv/shop_segmentation/neck_fpn.py index 12c11d76..d344de71 100644 --- a/modelscope/models/cv/shop_segmentation/neck_fpn.py +++ b/modelscope/models/cv/shop_segmentation/neck_fpn.py @@ -8,8 +8,8 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule -from timm.layers.drop import drop_path -from 
timm.layers.weight_init import trunc_normal_ +from timm.models.layers.drop import drop_path +from timm.models.layers.weight_init import trunc_normal_ from .common import resize diff --git a/modelscope/models/multi_modal/ofa/tokenization_ofa.py b/modelscope/models/multi_modal/ofa/tokenization_ofa.py index 77de7a1d..ea79a327 100644 --- a/modelscope/models/multi_modal/ofa/tokenization_ofa.py +++ b/modelscope/models/multi_modal/ofa/tokenization_ofa.py @@ -183,6 +183,12 @@ class OFATokenizerZH(PreTrainedTokenizer): tokenize_chinese_chars=True, strip_accents=None, **kwargs): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained " + 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' + ) + self.vocab = load_vocab(vocab_file) super().__init__( do_lower_case=do_lower_case, do_basic_tokenize=do_basic_tokenize, @@ -199,12 +205,6 @@ class OFATokenizerZH(PreTrainedTokenizer): **kwargs, ) - if not os.path.isfile(vocab_file): - raise ValueError( - f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained " - 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' - ) - self.vocab = load_vocab(vocab_file) self.ids_to_tokens = collections.OrderedDict([ (ids, tok) for tok, ids in self.vocab.items() ]) From 19e7c1c80700ca5f9544fe826b40930d01923e69 Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Mon, 16 Oct 2023 22:12:31 +0800 Subject: [PATCH 09/21] fix ofa new transformers compatible issue Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14317517 * fix ofa new transformers compatible issue * fix timm.layers to timm.models.layers compatible issue --- modelscope/models/cv/shop_segmentation/head_fpn.py | 4 ++-- modelscope/models/cv/shop_segmentation/models.py | 4 ++-- modelscope/models/cv/shop_segmentation/neck_fpn.py | 4 ++-- .../models/multi_modal/ofa/tokenization_ofa.py | 12 ++++++------ 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/modelscope/models/cv/shop_segmentation/head_fpn.py b/modelscope/models/cv/shop_segmentation/head_fpn.py index dfa284d4..0d4027cb 100644 --- a/modelscope/models/cv/shop_segmentation/head_fpn.py +++ b/modelscope/models/cv/shop_segmentation/head_fpn.py @@ -9,8 +9,8 @@ import numpy as np import torch import torch.nn as nn from mmcv.cnn import ConvModule -from timm.layers.drop import drop_path -from timm.layers.weight_init import trunc_normal_ +from timm.models.layers.drop import drop_path +from timm.models.layers.weight_init import trunc_normal_ from .common import Upsample, resize diff --git a/modelscope/models/cv/shop_segmentation/models.py b/modelscope/models/cv/shop_segmentation/models.py index 1b07a08c..a206e9f1 100644 --- a/modelscope/models/cv/shop_segmentation/models.py +++ b/modelscope/models/cv/shop_segmentation/models.py @@ -11,8 +11,8 @@ from collections import OrderedDict import torch import torch.nn.functional as F import torch.utils.checkpoint as checkpoint -from timm.layers.drop import drop_path -from timm.layers.weight_init import trunc_normal_ +from timm.models.layers.drop import drop_path +from timm.models.layers.weight_init import trunc_normal_ from torch import nn diff --git a/modelscope/models/cv/shop_segmentation/neck_fpn.py b/modelscope/models/cv/shop_segmentation/neck_fpn.py index 12c11d76..d344de71 100644 --- a/modelscope/models/cv/shop_segmentation/neck_fpn.py +++ b/modelscope/models/cv/shop_segmentation/neck_fpn.py @@ -8,8 
+8,8 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule -from timm.layers.drop import drop_path -from timm.layers.weight_init import trunc_normal_ +from timm.models.layers.drop import drop_path +from timm.models.layers.weight_init import trunc_normal_ from .common import resize diff --git a/modelscope/models/multi_modal/ofa/tokenization_ofa.py b/modelscope/models/multi_modal/ofa/tokenization_ofa.py index 77de7a1d..ea79a327 100644 --- a/modelscope/models/multi_modal/ofa/tokenization_ofa.py +++ b/modelscope/models/multi_modal/ofa/tokenization_ofa.py @@ -183,6 +183,12 @@ class OFATokenizerZH(PreTrainedTokenizer): tokenize_chinese_chars=True, strip_accents=None, **kwargs): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained " + 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' + ) + self.vocab = load_vocab(vocab_file) super().__init__( do_lower_case=do_lower_case, do_basic_tokenize=do_basic_tokenize, @@ -199,12 +205,6 @@ class OFATokenizerZH(PreTrainedTokenizer): **kwargs, ) - if not os.path.isfile(vocab_file): - raise ValueError( - f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained " - 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' - ) - self.vocab = load_vocab(vocab_file) self.ids_to_tokens = collections.OrderedDict([ (ids, tok) for tok, ids in self.vocab.items() ]) From c01141f97e1cf28c5f67a75e988513b4e01deffe Mon Sep 17 00:00:00 2001 From: Jintao Date: Thu, 12 Oct 2023 10:27:31 +0800 Subject: [PATCH 10/21] fix merge error (#582) --- modelscope/models/base/base_model.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index a3b65812..8e6d4ae6 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -9,7 +9,7 @@ from modelscope.metainfo import Tasks from modelscope.models.builder import build_backbone, build_model from modelscope.utils.automodel_utils import (can_load_by_ms, try_to_load_hf_model) -from modelscope.utils.config import Config +from modelscope.utils.config import Config, ConfigDict from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile from modelscope.utils.device import verify_device from modelscope.utils.logger import get_logger @@ -142,15 +142,10 @@ class Model(ABC): task_name = cfg.task if 'task' in kwargs: task_name = kwargs.pop('task') - try: - model_cfg = cfg.model - if hasattr(model_cfg, - 'model_type') and not hasattr(model_cfg, 'type'): - model_cfg.type = model_cfg.model_type - model_type = model_cfg.type - except Exception: - model_cfg = {} - model_type = '' + model_cfg = getattr(cfg, 'model', ConfigDict()) + if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): + model_cfg.type = model_cfg.model_type + model_type = getattr(model_cfg, 'type', None) if isinstance(device, str) and device.startswith('gpu'): device = 'cuda' + device[3:] use_hf = kwargs.pop('use_hf', None) @@ -162,7 +157,7 @@ class Model(ABC): model = try_to_load_hf_model(local_model_dir, task_name, use_hf, **kwargs) if model is not None: - device_map = kwargs.get('device_map', None) + device_map = kwargs.pop('device_map', None) if device_map is None and device is not None: model = model.to(device) return model From f5b83ebd83bf421e13295172213c06b589c8863f Mon 
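The tokenizer patches in this series (OFA here, ChatGLM2 and ChatGLM further below) apply the same reordering: state that the base class may touch, such as the vocabulary or the SentencePiece model, is built before super().__init__() runs, because newer transformers releases may call get_vocab() and related hooks from inside PreTrainedTokenizer.__init__. A minimal sketch with a hypothetical ToySlowTokenizer (not one of the ModelScope classes):

    import collections

    from transformers import PreTrainedTokenizer


    class ToySlowTokenizer(PreTrainedTokenizer):
        """Hypothetical tokenizer illustrating the init-order requirement."""

        def __init__(self, vocab_file, **kwargs):
            # Build the vocabulary first: the parent constructor may call
            # get_vocab() while wiring up added/special tokens.
            with open(vocab_file, encoding='utf-8') as f:
                tokens = f.read().splitlines()
            self.vocab = collections.OrderedDict(
                (tok, idx) for idx, tok in enumerate(tokens))
            super().__init__(**kwargs)

        @property
        def vocab_size(self):
            return len(self.vocab)

        def get_vocab(self):
            return dict(self.vocab)

A usable tokenizer would also implement _tokenize and the token/id conversion hooks; they are omitted here to keep the ordering point visible.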
Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Tue, 17 Oct 2023 22:15:54 +0800 Subject: [PATCH 11/21] fix chatglm2 can't find tokenizer issue Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14335080 * fix chatglm2 can't find tokenizer issue --- modelscope/models/nlp/chatglm2/tokenization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelscope/models/nlp/chatglm2/tokenization.py b/modelscope/models/nlp/chatglm2/tokenization.py index 7014dc9c..4523dcdd 100644 --- a/modelscope/models/nlp/chatglm2/tokenization.py +++ b/modelscope/models/nlp/chatglm2/tokenization.py @@ -72,7 +72,6 @@ class ChatGLM2Tokenizer(PreTrainedTokenizer): model_input_names = ['input_ids', 'attention_mask', 'position_ids'] def __init__(self, vocab_file, padding_side='left', **kwargs): - super().__init__(padding_side=padding_side, **kwargs) self.name = 'GLMTokenizer' self.vocab_file = vocab_file @@ -82,6 +81,7 @@ class ChatGLM2Tokenizer(PreTrainedTokenizer): '': self.tokenizer.eos_id, '': self.tokenizer.pad_id } + super().__init__(padding_side=padding_side, **kwargs) def get_command(self, token): if token in self.special_tokens: From 66430171ae3618dc8e86fc39910aa75da54f781f Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Tue, 17 Oct 2023 22:15:54 +0800 Subject: [PATCH 12/21] fix chatglm2 can't find tokenizer issue Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14335080 * fix chatglm2 can't find tokenizer issue --- modelscope/models/nlp/chatglm2/tokenization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelscope/models/nlp/chatglm2/tokenization.py b/modelscope/models/nlp/chatglm2/tokenization.py index 7014dc9c..4523dcdd 100644 --- a/modelscope/models/nlp/chatglm2/tokenization.py +++ b/modelscope/models/nlp/chatglm2/tokenization.py @@ -72,7 +72,6 @@ class ChatGLM2Tokenizer(PreTrainedTokenizer): model_input_names = ['input_ids', 'attention_mask', 'position_ids'] def __init__(self, vocab_file, padding_side='left', **kwargs): - super().__init__(padding_side=padding_side, **kwargs) self.name = 'GLMTokenizer' self.vocab_file = vocab_file @@ -82,6 +81,7 @@ class ChatGLM2Tokenizer(PreTrainedTokenizer): '': self.tokenizer.eos_id, '': self.tokenizer.pad_id } + super().__init__(padding_side=padding_side, **kwargs) def get_command(self, token): if token in self.special_tokens: From 0908e20da2756ad9434d019550d43e8f7e8e1608 Mon Sep 17 00:00:00 2001 From: Jintao Date: Thu, 12 Oct 2023 10:27:31 +0800 Subject: [PATCH 13/21] fix merge error (#582) --- modelscope/models/base/base_model.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index a3b65812..8e6d4ae6 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -9,7 +9,7 @@ from modelscope.metainfo import Tasks from modelscope.models.builder import build_backbone, build_model from modelscope.utils.automodel_utils import (can_load_by_ms, try_to_load_hf_model) -from modelscope.utils.config import Config +from modelscope.utils.config import Config, ConfigDict from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile from modelscope.utils.device import verify_device from modelscope.utils.logger import get_logger @@ -142,15 +142,10 @@ class Model(ABC): task_name = cfg.task if 'task' in kwargs: task_name = kwargs.pop('task') - try: - model_cfg = cfg.model - if hasattr(model_cfg, - 'model_type') and not hasattr(model_cfg, 'type'): - 
model_cfg.type = model_cfg.model_type - model_type = model_cfg.type - except Exception: - model_cfg = {} - model_type = '' + model_cfg = getattr(cfg, 'model', ConfigDict()) + if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): + model_cfg.type = model_cfg.model_type + model_type = getattr(model_cfg, 'type', None) if isinstance(device, str) and device.startswith('gpu'): device = 'cuda' + device[3:] use_hf = kwargs.pop('use_hf', None) @@ -162,7 +157,7 @@ class Model(ABC): model = try_to_load_hf_model(local_model_dir, task_name, use_hf, **kwargs) if model is not None: - device_map = kwargs.get('device_map', None) + device_map = kwargs.pop('device_map', None) if device_map is None and device is not None: model = model.to(device) return model From f568454bbef187a9e8316f300bd061a340f5bc35 Mon Sep 17 00:00:00 2001 From: "suluyan.sly" Date: Wed, 18 Oct 2023 16:29:13 +0800 Subject: [PATCH 14/21] [swingdeploy] oss examples Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14342415 * oss_examples --- modelscope/utils/pipeline_inputs.json | 66 +++++++++++++-------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/modelscope/utils/pipeline_inputs.json b/modelscope/utils/pipeline_inputs.json index 0cb9c1b1..03a00636 100644 --- a/modelscope/utils/pipeline_inputs.json +++ b/modelscope/utils/pipeline_inputs.json @@ -1,17 +1,17 @@ { "action-detection":{ "input":{ - "video":"data/test/videos/action_detection_test_video.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/action_detection_test_video.mp4" } }, "action-recognition":{ "input":{ - "video":"data/test/videos/action_recognition_test_video.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/action_recognition_test_video.mp4" } }, "animal-recognition":{ "input":{ - "image":"data/test/images/dogs.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/dogs.jpg" } }, "chat":{ @@ -34,52 +34,52 @@ }, "domain-specific-object-detection":{ "input":{ - "image":"data/test/images/image_traffic_sign.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_traffic_sign.jpg" } }, "face-2d-keypoints":{ "input":{ - "image":"data/test/images/face_detection.png" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/face_detection.png" } }, "face-attribute-recognition":{ "input":{ - "image":"data/test/images/face_recognition_1.png" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/face_recognition_1.png" } }, "facial-expression-recognition":{ "input":{ - "image":"data/test/images/facial_expression_recognition.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/facial_expression_recognition.jpg" } }, "general-recognition":{ "input":{ - "image":"data/test/images/dogs.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/dogs.jpg" } }, "human-detection":{ "input":{ - "image":"data/test/images/image_detection.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_detection.jpg" } }, "image-captioning":{ "input":{ - "image":"data/test/images/image_captioning.png" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_captioning.png" } }, "image-classification":{ "input":{ - "image":"data/test/images/content_check.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/content_check.jpg" } }, "image-demoireing":{ "input":{ - "image":"data/test/images/shop_segmentation.jpg" + 
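The base_model change above trades a catch-all try/except for explicit defaults (getattr with a ConfigDict fallback) and switches device_map from kwargs.get to kwargs.pop so the value is not forwarded a second time. The same idea in isolation, with SimpleNamespace standing in for ModelScope's attribute-style config and a hypothetical helper name:

    from types import SimpleNamespace

    def resolve_model_type(cfg, **kwargs):
        # Explicit defaults instead of a broad try/except, so unrelated
        # errors in the config are no longer silently swallowed.
        model_cfg = getattr(cfg, 'model', SimpleNamespace())
        if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
            model_cfg.type = model_cfg.model_type
        model_type = getattr(model_cfg, 'type', None)
        # pop() removes device_map from the kwargs that get forwarded on,
        # so it cannot be passed twice further down the call chain.
        device_map = kwargs.pop('device_map', None)
        return model_type, device_map, kwargs

    # resolve_model_type(SimpleNamespace(model=SimpleNamespace(model_type='llama')),
    #                    device_map='auto')  ->  ('llama', 'auto', {})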
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/shop_segmentation.jpg" } }, "image-object-detection":{ "input":{ - "image":"data/test/images/image_detection.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_detection.jpg" } }, "image-portrait-stylization":{ @@ -89,7 +89,7 @@ }, "image-segmentation":{ "input":{ - "image":"data/test/images/image_semantic_segmentation.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_semantic_segmentation.jpg" }, "parameters":{ @@ -97,18 +97,18 @@ }, "image-text-retrieval":{ "input":{ - "image":"data/test/images/image_mplug_vqa.jpg", + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_mplug_vqa.jpg", "text":"What is the woman doing?" } }, "indoor-layout-estimation":{ "input":{ - "image":"data/test/images/image_traffic_sign.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_traffic_sign.jpg" } }, "live-category":{ "input":{ - "video":"data/test/videos/live_category_test_video.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/live_category_test_video.mp4" } }, "motion-generation":{ @@ -132,22 +132,22 @@ }, "ocr-recognition":{ "input":{ - "image":"data/test/images/image_ocr_recognition.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_ocr_recognition.jpg" } }, "panorama-depth-estimation":{ "input":{ - "image":"data/test/images/panorama_depth_estimation.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/panorama_depth_estimation.jpg" } }, "semantic-segmentation":{ "input":{ - "image":"data/test/images/image_salient_detection.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_salient_detection.jpg" } }, "shop-segmentation":{ "input":{ - "image":"data/test/images/shop_segmentation.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/shop_segmentation.jpg" } }, "text-classification":{ @@ -160,7 +160,7 @@ }, "text-driven-segmentation":{ "input":{ - "image":"data/test/images/text_driven_segmentation.jpg", + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/text_driven_segmentation.jpg", "text":"bear" } }, @@ -201,60 +201,60 @@ }, "video-captioning":{ "input":{ - "video":"data/test/videos/video_caption_and_qa_test.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_caption_and_qa_test.mp4" } }, "video-category":{ "input":{ - "video":"data/test/videos/video_category_test_video.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_category_test_video.mp4" } }, "video-depth-estimation":{ "input":{ - "video":"data/test/videos/video_depth_estimation.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_depth_estimation.mp4" } }, "video-embedding":{ "input":{ - "video":"data/test/videos/action_recognition_test_video.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/action_recognition_test_video.mp4" } }, "video-multi-object-tracking":{ "input":{ - "video":"data/test/videos/MOT17-03-partial.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/MOT17-03-partial.mp4" } }, "video-panoptic-segmentation":{ "input":{ - "video":"data/test/videos/kitti-step_testing_image_02_0000.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/kitti-step_testing_image_02_0000.mp4" } }, "video-question-answering":{ "input":{ - "video":"data/test/videos/video_caption_and_qa_test.mp4", + 
"video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_caption_and_qa_test.mp4", "text":"How many people are there?" } }, "video-summarization":{ "input":{ - "text":"data/test/videos/video_category_test_video.mp4" + "text":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_category_test_video.mp4" } }, "visual-entailment":{ "input":{ - "image":"data/test/images/dogs.jpg", + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/dogs.jpg", "text":"there are two birds." } }, "visual-grounding":{ "input":{ - "image":"data/test/images/visual_grounding.png", + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/visual_grounding.png", "text":"a blue turtle-like pokemon with round head" } }, "visual-question-answering":{ "input":{ - "image":"data/test/images/image_mplug_vqa.jpg", + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_mplug_vqa.jpg", "text":"What is the woman doing?" } }, From 6201d1bfbc29cce4e57a74f675fcb013a4020ddf Mon Sep 17 00:00:00 2001 From: "suluyan.sly" Date: Wed, 18 Oct 2023 16:29:13 +0800 Subject: [PATCH 15/21] [swingdeploy] oss examples Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14342415 * oss_examples --- modelscope/utils/pipeline_inputs.json | 66 +++++++++++++-------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/modelscope/utils/pipeline_inputs.json b/modelscope/utils/pipeline_inputs.json index 0cb9c1b1..03a00636 100644 --- a/modelscope/utils/pipeline_inputs.json +++ b/modelscope/utils/pipeline_inputs.json @@ -1,17 +1,17 @@ { "action-detection":{ "input":{ - "video":"data/test/videos/action_detection_test_video.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/action_detection_test_video.mp4" } }, "action-recognition":{ "input":{ - "video":"data/test/videos/action_recognition_test_video.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/action_recognition_test_video.mp4" } }, "animal-recognition":{ "input":{ - "image":"data/test/images/dogs.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/dogs.jpg" } }, "chat":{ @@ -34,52 +34,52 @@ }, "domain-specific-object-detection":{ "input":{ - "image":"data/test/images/image_traffic_sign.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_traffic_sign.jpg" } }, "face-2d-keypoints":{ "input":{ - "image":"data/test/images/face_detection.png" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/face_detection.png" } }, "face-attribute-recognition":{ "input":{ - "image":"data/test/images/face_recognition_1.png" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/face_recognition_1.png" } }, "facial-expression-recognition":{ "input":{ - "image":"data/test/images/facial_expression_recognition.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/facial_expression_recognition.jpg" } }, "general-recognition":{ "input":{ - "image":"data/test/images/dogs.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/dogs.jpg" } }, "human-detection":{ "input":{ - "image":"data/test/images/image_detection.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_detection.jpg" } }, "image-captioning":{ "input":{ - "image":"data/test/images/image_captioning.png" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_captioning.png" } }, "image-classification":{ "input":{ - "image":"data/test/images/content_check.jpg" 
+ "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/content_check.jpg" } }, "image-demoireing":{ "input":{ - "image":"data/test/images/shop_segmentation.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/shop_segmentation.jpg" } }, "image-object-detection":{ "input":{ - "image":"data/test/images/image_detection.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_detection.jpg" } }, "image-portrait-stylization":{ @@ -89,7 +89,7 @@ }, "image-segmentation":{ "input":{ - "image":"data/test/images/image_semantic_segmentation.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_semantic_segmentation.jpg" }, "parameters":{ @@ -97,18 +97,18 @@ }, "image-text-retrieval":{ "input":{ - "image":"data/test/images/image_mplug_vqa.jpg", + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_mplug_vqa.jpg", "text":"What is the woman doing?" } }, "indoor-layout-estimation":{ "input":{ - "image":"data/test/images/image_traffic_sign.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_traffic_sign.jpg" } }, "live-category":{ "input":{ - "video":"data/test/videos/live_category_test_video.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/live_category_test_video.mp4" } }, "motion-generation":{ @@ -132,22 +132,22 @@ }, "ocr-recognition":{ "input":{ - "image":"data/test/images/image_ocr_recognition.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_ocr_recognition.jpg" } }, "panorama-depth-estimation":{ "input":{ - "image":"data/test/images/panorama_depth_estimation.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/panorama_depth_estimation.jpg" } }, "semantic-segmentation":{ "input":{ - "image":"data/test/images/image_salient_detection.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_salient_detection.jpg" } }, "shop-segmentation":{ "input":{ - "image":"data/test/images/shop_segmentation.jpg" + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/shop_segmentation.jpg" } }, "text-classification":{ @@ -160,7 +160,7 @@ }, "text-driven-segmentation":{ "input":{ - "image":"data/test/images/text_driven_segmentation.jpg", + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/text_driven_segmentation.jpg", "text":"bear" } }, @@ -201,60 +201,60 @@ }, "video-captioning":{ "input":{ - "video":"data/test/videos/video_caption_and_qa_test.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_caption_and_qa_test.mp4" } }, "video-category":{ "input":{ - "video":"data/test/videos/video_category_test_video.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_category_test_video.mp4" } }, "video-depth-estimation":{ "input":{ - "video":"data/test/videos/video_depth_estimation.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_depth_estimation.mp4" } }, "video-embedding":{ "input":{ - "video":"data/test/videos/action_recognition_test_video.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/action_recognition_test_video.mp4" } }, "video-multi-object-tracking":{ "input":{ - "video":"data/test/videos/MOT17-03-partial.mp4" + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/MOT17-03-partial.mp4" } }, "video-panoptic-segmentation":{ "input":{ - "video":"data/test/videos/kitti-step_testing_image_02_0000.mp4" + 
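The swingdeploy patches above swap the local data/test paths in pipeline_inputs.json for public OSS URLs, so the example inputs resolve without a checkout of the test assets. Assuming the file keeps its task-name to example-input layout, an entry can be smoke-tested roughly like this (a sketch; the chosen task is illustrative and assumes a default model is registered for it):

    import json

    from modelscope.pipelines import pipeline

    with open('modelscope/utils/pipeline_inputs.json') as f:
        examples = json.load(f)

    task = 'image-captioning'      # illustrative choice of entry
    image_url = examples[task]['input']['image']
    captioner = pipeline(task)     # falls back to the task's default model
    print(captioner(image_url))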
"video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/kitti-step_testing_image_02_0000.mp4" } }, "video-question-answering":{ "input":{ - "video":"data/test/videos/video_caption_and_qa_test.mp4", + "video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_caption_and_qa_test.mp4", "text":"How many people are there?" } }, "video-summarization":{ "input":{ - "text":"data/test/videos/video_category_test_video.mp4" + "text":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_category_test_video.mp4" } }, "visual-entailment":{ "input":{ - "image":"data/test/images/dogs.jpg", + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/dogs.jpg", "text":"there are two birds." } }, "visual-grounding":{ "input":{ - "image":"data/test/images/visual_grounding.png", + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/visual_grounding.png", "text":"a blue turtle-like pokemon with round head" } }, "visual-question-answering":{ "input":{ - "image":"data/test/images/image_mplug_vqa.jpg", + "image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_mplug_vqa.jpg", "text":"What is the woman doing?" } }, From 2c3bf9629d76d401cfe962106f72685771f1e416 Mon Sep 17 00:00:00 2001 From: "wenmeng.zwm" Date: Wed, 18 Oct 2023 20:24:42 +0800 Subject: [PATCH 16/21] fix chatglm sp_tokenizer error Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14352495 --- modelscope/models/nlp/chatglm/tokenization.py | 7 ++++--- modelscope/models/nlp/llama/text_generation.py | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/modelscope/models/nlp/chatglm/tokenization.py b/modelscope/models/nlp/chatglm/tokenization.py index f5f8cd0c..6ce1b90d 100644 --- a/modelscope/models/nlp/chatglm/tokenization.py +++ b/modelscope/models/nlp/chatglm/tokenization.py @@ -199,6 +199,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer): padding_side='left', num_image_tokens=20000, **kwargs) -> None: + + self.sp_tokenizer = SPTokenizer( + vocab_file, num_image_tokens=num_image_tokens) + super().__init__( do_lower_case=do_lower_case, remove_space=remove_space, @@ -220,9 +224,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer): self.end_token = end_token self.mask_token = mask_token self.gmask_token = gmask_token - - self.sp_tokenizer = SPTokenizer( - vocab_file, num_image_tokens=num_image_tokens) """ Initialisation """ @property diff --git a/modelscope/models/nlp/llama/text_generation.py b/modelscope/models/nlp/llama/text_generation.py index b9cc8032..d95cae34 100644 --- a/modelscope/models/nlp/llama/text_generation.py +++ b/modelscope/models/nlp/llama/text_generation.py @@ -71,6 +71,8 @@ def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]], # This file is mainly copied from the llama code of transformers +@MODELS.register_module(Tasks.chat, module_name=Models.llama2) +@MODELS.register_module(Tasks.chat, module_name=Models.llama) @MODELS.register_module(Tasks.text_generation, module_name=Models.llama2) @MODELS.register_module(Tasks.chat, module_name=Models.llama2) @MODELS.register_module(Tasks.text_generation, module_name=Models.llama) From e10237074e4129dbd17457e08bf21d59e496f785 Mon Sep 17 00:00:00 2001 From: "wenmeng.zwm" Date: Wed, 18 Oct 2023 20:24:42 +0800 Subject: [PATCH 17/21] fix chatglm sp_tokenizer error Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14352495 --- modelscope/models/nlp/chatglm/tokenization.py | 7 ++++--- modelscope/models/nlp/llama/text_generation.py | 2 ++ 2 files changed, 6 
insertions(+), 3 deletions(-) diff --git a/modelscope/models/nlp/chatglm/tokenization.py b/modelscope/models/nlp/chatglm/tokenization.py index f5f8cd0c..6ce1b90d 100644 --- a/modelscope/models/nlp/chatglm/tokenization.py +++ b/modelscope/models/nlp/chatglm/tokenization.py @@ -199,6 +199,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer): padding_side='left', num_image_tokens=20000, **kwargs) -> None: + + self.sp_tokenizer = SPTokenizer( + vocab_file, num_image_tokens=num_image_tokens) + super().__init__( do_lower_case=do_lower_case, remove_space=remove_space, @@ -220,9 +224,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer): self.end_token = end_token self.mask_token = mask_token self.gmask_token = gmask_token - - self.sp_tokenizer = SPTokenizer( - vocab_file, num_image_tokens=num_image_tokens) """ Initialisation """ @property diff --git a/modelscope/models/nlp/llama/text_generation.py b/modelscope/models/nlp/llama/text_generation.py index b9cc8032..d95cae34 100644 --- a/modelscope/models/nlp/llama/text_generation.py +++ b/modelscope/models/nlp/llama/text_generation.py @@ -71,6 +71,8 @@ def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]], # This file is mainly copied from the llama code of transformers +@MODELS.register_module(Tasks.chat, module_name=Models.llama2) +@MODELS.register_module(Tasks.chat, module_name=Models.llama) @MODELS.register_module(Tasks.text_generation, module_name=Models.llama2) @MODELS.register_module(Tasks.chat, module_name=Models.llama2) @MODELS.register_module(Tasks.text_generation, module_name=Models.llama) From f493ed007ce66ae61262a5ac7db5b0f8346e0fc1 Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Wed, 18 Oct 2023 21:17:51 +0800 Subject: [PATCH 18/21] force rebuid image --- .dev_scripts/build_image.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.dev_scripts/build_image.sh b/.dev_scripts/build_image.sh index 9ce2a4a8..386e5aad 100644 --- a/.dev_scripts/build_image.sh +++ b/.dev_scripts/build_image.sh @@ -150,7 +150,7 @@ echo -e "Building image with:\npython$python_version\npytorch$torch_version\nten docker_file_content=`cat docker/Dockerfile.ubuntu` if [ "$is_ci_test" != "True" ]; then echo "Building ModelScope lib, will install ModelScope lib to image" - docker_file_content="${docker_file_content} \nRUN pip install --no-cache-dir -U transformers && pip install --no-cache-dir https://modelscope.oss-cn-beijing.aliyuncs.com/releases/build/modelscope-$modelscope_version-py3-none-any.whl " + docker_file_content="${docker_file_content} \nRUN pip install --no-cache-dir https://modelscope.oss-cn-beijing.aliyuncs.com/releases/build/modelscope-$modelscope_version-py3-none-any.whl && pip install --no-cache-dir -U transformers" fi echo "$is_dsw" if [ "$is_dsw" == "False" ]; then From 0390bdea676a1804ddd7099b6d8aef57a695fda2 Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Thu, 19 Oct 2023 12:12:35 +0800 Subject: [PATCH 19/21] remove llama2 dup in chat task Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14359280 * remove llama2 dup in chat task --- modelscope/models/nlp/llama/text_generation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelscope/models/nlp/llama/text_generation.py b/modelscope/models/nlp/llama/text_generation.py index d95cae34..45b9d5f0 100644 --- a/modelscope/models/nlp/llama/text_generation.py +++ b/modelscope/models/nlp/llama/text_generation.py @@ -74,7 +74,6 @@ def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]], @MODELS.register_module(Tasks.chat, 
module_name=Models.llama2) @MODELS.register_module(Tasks.chat, module_name=Models.llama) @MODELS.register_module(Tasks.text_generation, module_name=Models.llama2) -@MODELS.register_module(Tasks.chat, module_name=Models.llama2) @MODELS.register_module(Tasks.text_generation, module_name=Models.llama) class LlamaForTextGeneration(MsModelMixin, LlamaForCausalLM, TorchModel): From 8e187bdb962b7671a3d9384f80100cfc0cf4094c Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Thu, 19 Oct 2023 12:12:35 +0800 Subject: [PATCH 20/21] remove llama2 dup in chat task Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14359280 * remove llama2 dup in chat task --- modelscope/models/nlp/llama/text_generation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelscope/models/nlp/llama/text_generation.py b/modelscope/models/nlp/llama/text_generation.py index d95cae34..45b9d5f0 100644 --- a/modelscope/models/nlp/llama/text_generation.py +++ b/modelscope/models/nlp/llama/text_generation.py @@ -74,7 +74,6 @@ def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]], @MODELS.register_module(Tasks.chat, module_name=Models.llama2) @MODELS.register_module(Tasks.chat, module_name=Models.llama) @MODELS.register_module(Tasks.text_generation, module_name=Models.llama2) -@MODELS.register_module(Tasks.chat, module_name=Models.llama2) @MODELS.register_module(Tasks.text_generation, module_name=Models.llama) class LlamaForTextGeneration(MsModelMixin, LlamaForCausalLM, TorchModel): From b14f3464e59fce10e1eec665bfafd4cbaa252df8 Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Thu, 19 Oct 2023 12:30:10 +0800 Subject: [PATCH 21/21] force rebuild image --- .dev_scripts/build_image.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.dev_scripts/build_image.sh b/.dev_scripts/build_image.sh index 386e5aad..c1e61890 100644 --- a/.dev_scripts/build_image.sh +++ b/.dev_scripts/build_image.sh @@ -150,7 +150,7 @@ echo -e "Building image with:\npython$python_version\npytorch$torch_version\nten docker_file_content=`cat docker/Dockerfile.ubuntu` if [ "$is_ci_test" != "True" ]; then echo "Building ModelScope lib, will install ModelScope lib to image" - docker_file_content="${docker_file_content} \nRUN pip install --no-cache-dir https://modelscope.oss-cn-beijing.aliyuncs.com/releases/build/modelscope-$modelscope_version-py3-none-any.whl && pip install --no-cache-dir -U transformers" + docker_file_content="${docker_file_content} \nRUN pip install --no-cache-dir numpy https://modelscope.oss-cn-beijing.aliyuncs.com/releases/build/modelscope-$modelscope_version-py3-none-any.whl && pip install --no-cache-dir -U transformers" fi echo "$is_dsw" if [ "$is_dsw" == "False" ]; then
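Patches 19 and 20 drop a duplicated chat-task registration of Models.llama2 that patches 16/17 had left in place next to the newly added decorators. The registry pattern involved looks roughly like the simplified sketch below; ModelScope's real Registry and its duplicate-handling policy are not reproduced here:

    class TaskRegistry:
        """Minimal (task, name) -> class registry in decorator style."""

        def __init__(self):
            self._modules = {}

        def register_module(self, task, module_name):
            def decorator(cls):
                # In this sketch a repeated registration of the same key is
                # merely redundant: the later decorator rebinds the same class.
                self._modules[(task, module_name)] = cls
                return cls
            return decorator

    MODELS = TaskRegistry()

    @MODELS.register_module('chat', 'llama2')
    @MODELS.register_module('text-generation', 'llama2')
    class LlamaForTextGenerationSketch:
        pass

Under this sketch the extra decorator is a no-op, which is consistent with the dedup being a pure cleanup that keeps each (task, model) pair registered exactly once.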