mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
Merge pull request #591 from modelscope/master-merge-internal20231019
Master merge internal20231019
@@ -150,7 +150,7 @@ echo -e "Building image with:\npython$python_version\npytorch$torch_version\nten
docker_file_content=`cat docker/Dockerfile.ubuntu`
if [ "$is_ci_test" != "True" ]; then
echo "Building ModelScope lib, will install ModelScope lib to image"
docker_file_content="${docker_file_content} \nRUN pip install --no-cache-dir https://modelscope.oss-cn-beijing.aliyuncs.com/releases/build/modelscope-$modelscope_version-py3-none-any.whl "
docker_file_content="${docker_file_content} \nRUN pip install --no-cache-dir numpy https://modelscope.oss-cn-beijing.aliyuncs.com/releases/build/modelscope-$modelscope_version-py3-none-any.whl && pip install --no-cache-dir -U transformers"
fi
echo "$is_dsw"
if [ "$is_dsw" == "False" ]; then
@@ -1,6 +1,4 @@
git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention && \
cd flash-attention && pip install . && \
pip install csrc/layer_norm && \
pip install csrc/rotary && \
git clone -b v2.3.2 https://github.com/Dao-AILab/flash-attention && \
cd flash-attention && python setup.py install && \
cd .. && \
rm -rf flash-attention
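The hunk above moves the image from flash-attention v1.0.8 (installed with `pip install .` plus the separate csrc/layer_norm and csrc/rotary extensions) to v2.3.2 built in one step via `python setup.py install`. A minimal runtime check, assuming the installed flash_attn package exposes `__version__` (recent releases do):

# Hedged sketch: confirm which flash-attention generation ended up in the image.
import flash_attn

major = int(flash_attn.__version__.split('.')[0])
print(flash_attn.__version__,
      'v2: single setup.py build' if major >= 2 else 'v1: csrc extras built separately')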
@@ -514,6 +514,7 @@ class Pipelines(object):
document_grounded_dialog_generate = 'document-grounded-dialog-generate'
language_identification = 'language_identification'
machine_reading_comprehension_for_ner = 'machine-reading-comprehension-for-ner'
llm = 'llm'

# audio tasks
sambert_hifigan_tts = 'sambert-hifigan-tts'
@@ -117,6 +117,7 @@ class Model(ABC):
else:
invoked_by = Invoke.PRETRAINED

ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)
if osp.exists(model_name_or_path):
local_model_dir = model_name_or_path
else:
@@ -126,7 +127,6 @@ class Model(ABC):
)

invoked_by = '%s/%s' % (Invoke.KEY, invoked_by)
ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)
local_model_dir = snapshot_download(
model_name_or_path,
revision,
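The two hunks above move `kwargs.pop('ignore_file_pattern', None)` ahead of the `osp.exists` branch, so the keyword is consumed whether the model comes from a local directory or from a hub download. A generic sketch of the pattern (illustrative names only, not the real Model API):

# Illustrative only: pop shared keywords before branching so neither path leaks them onward.
def load(model_name_or_path, **kwargs):
    ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)
    if model_name_or_path.startswith('/'):      # local directory
        source = 'local'
    else:                                       # hub download path
        source = 'snapshot_download'
    return source, ignore_file_pattern, kwargs

print(load('/tmp/my_model', ignore_file_pattern=[r'.*\.bin'], device='cpu'))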
@@ -15,10 +15,10 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
image_quality_assessment_man, image_quality_assessment_mos,
image_reid_person, image_restoration,
image_semantic_segmentation, image_super_resolution_pasd,
image_to_image_generation, image_to_image_translation,
language_guided_video_summarization, movie_scene_segmentation,
object_detection, panorama_depth_estimation,
pedestrian_attribute_recognition,
image_super_resolution_pasd_v2, image_to_image_generation,
image_to_image_translation, language_guided_video_summarization,
movie_scene_segmentation, object_detection,
panorama_depth_estimation, pedestrian_attribute_recognition,
pointcloud_sceneflow_estimation, product_retrieval_embedding,
referring_video_object_segmentation,
robust_image_classification, salient_detection,
@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .unet_2d_condition import UNet2DConditionModel
    from .controlnet import ControlNetModel

else:
    _import_structure = {
        'unet_2d_condition': ['UNet2DConditionModel'],
        'controlnet': ['ControlNetModel']
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
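This new __init__.py follows the modelscope lazy-import pattern: importing the package only registers the name mapping, and the heavy diffusers/torch modules are loaded on first attribute access. A stdlib-only sketch of the same idea (an illustration, not the LazyImportModule API itself):

# Illustration of deferred imports via a module-level __getattr__ (PEP 562 style).
import importlib
import types

_import_structure = {'math': ['sqrt'], 'json': ['loads']}
_lookup = {name: mod for mod, names in _import_structure.items() for name in names}

demo = types.ModuleType('lazy_demo')

def _lazy_getattr(name):
    if name in _lookup:
        return getattr(importlib.import_module(_lookup[name]), name)
    raise AttributeError(name)

demo.__getattr__ = _lazy_getattr   # consulted only when an attribute is missing
print(demo.sqrt(16.0))             # 'math' is imported at this call, not before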
@@ -0,0 +1,980 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import (AttentionProcessor,
                                                   AttnProcessor)
from diffusers.models.embeddings import (TextImageProjection,
                                         TextImageTimeEmbedding,
                                         TextTimeEmbedding, TimestepEmbedding,
                                         Timesteps)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput, logging
from torch import nn
from torch.nn import functional as F

from modelscope.models.cv.super_resolution.rrdbnet_arch import RRDB
from .unet_2d_blocks import (CrossAttnDownBlock2D, DownBlock2D,
                             UNetMidBlock2DCrossAttn, get_down_block)
from .unet_2d_condition import UNet2DConditionModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@dataclass
class ControlNetOutput(BaseOutput):
    """
    The output of [`ControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). The tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor
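For orientation, a hypothetical shape walk-through of the two fields, assuming a 64x64 input latent and the default block_out_channels of (320, 640, 1280, 1280); the real model emits one residual per resnet layer plus the conv_in and downsample outputs, so the count below is simplified:

import torch

batch = 2
down_block_res_samples = (                    # one entry per downsampling stage (simplified)
    torch.zeros(batch, 320, 64, 64),
    torch.zeros(batch, 320, 32, 32),
    torch.zeros(batch, 640, 16, 16),
    torch.zeros(batch, 1280, 8, 8),
)
mid_block_res_sample = torch.zeros(batch, 1280, 8, 8)   # lowest-resolution activation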
class ControlNetConditioningEmbedding(nn.Module):
|
||||
"""
|
||||
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
||||
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
||||
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
||||
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
||||
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
||||
model) to encode image-space conditions ... into feature maps ..."
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
conditioning_channels: int = 3,
|
||||
block_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
return_rgbs: bool = True,
|
||||
use_rrdb: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.return_rgbs = return_rgbs
|
||||
self.use_rrdb = use_rrdb
|
||||
|
||||
self.conv_in = nn.Conv2d(
|
||||
conditioning_channels,
|
||||
block_out_channels[0],
|
||||
kernel_size=3,
|
||||
padding=1)
|
||||
|
||||
if self.use_rrdb:
|
||||
num_rrdb_block = 2
|
||||
layers = (
|
||||
RRDB(block_out_channels[0], block_out_channels[0])
|
||||
for i in range(num_rrdb_block))
|
||||
self.preprocesser = nn.Sequential(*layers)
|
||||
|
||||
self.blocks = nn.ModuleList([])
|
||||
self.to_rgbs = nn.ModuleList([])
|
||||
|
||||
for i in range(len(block_out_channels) - 1):
|
||||
channel_in = block_out_channels[i]
|
||||
channel_out = block_out_channels[i + 1]
|
||||
self.blocks.append(
|
||||
nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
||||
self.blocks.append(
|
||||
nn.Conv2d(
|
||||
channel_in,
|
||||
channel_out,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=2))
|
||||
|
||||
if return_rgbs:
|
||||
self.to_rgbs.append(
|
||||
nn.Conv2d(channel_out, 3, kernel_size=3,
|
||||
padding=1)) # channel_in
|
||||
|
||||
self.conv_out = zero_module(
|
||||
nn.Conv2d(
|
||||
block_out_channels[-1],
|
||||
conditioning_embedding_channels,
|
||||
kernel_size=3,
|
||||
padding=1))
|
||||
|
||||
def forward(self, conditioning):
|
||||
embedding = self.conv_in(conditioning)
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
if self.use_rrdb:
|
||||
embedding = self.preprocesser(embedding)
|
||||
|
||||
out_rgbs = []
|
||||
for i, block in enumerate(self.blocks):
|
||||
embedding = block(embedding)
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
if i % 2 and self.return_rgbs: # 0
|
||||
out_rgbs.append(self.to_rgbs[i // 2](embedding))
|
||||
|
||||
embedding = self.conv_out(embedding)
|
||||
|
||||
return [embedding, out_rgbs] if self.return_rgbs else embedding
|
||||
|
||||
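A quick sanity check of the geometry described in the quote above: with the default conditioning_embedding_out_channels of (16, 32, 96, 256) the embedder applies three stride-2 convolutions, taking a 512x512 image-space condition down to the 64x64 latent grid (numbers are illustrative):

block_out_channels = (16, 32, 96, 256)
cond_hw = 512
for _ in range(len(block_out_channels) - 1):   # one stride-2 conv per channel transition
    cond_hw //= 2
print(cond_hw)   # 64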
|
||||
class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||
`class_embed_type="projection"`.
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
'CrossAttnDownBlock2D',
|
||||
'CrossAttnDownBlock2D',
|
||||
'CrossAttnDownBlock2D',
|
||||
'DownBlock2D',
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = 'silu',
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = 'default',
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = 'rgb',
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32,
|
||||
96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads=64,
|
||||
return_rgbs: bool = True,
|
||||
use_rrdb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in
|
||||
# https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f'Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: \
|
||||
{block_out_channels}. `down_block_types`: {down_block_types}.'
|
||||
)
|
||||
|
||||
if not isinstance(
|
||||
only_cross_attention,
|
||||
bool) and len(only_cross_attention) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f'Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: \
|
||||
{only_cross_attention}. `down_block_types`: {down_block_types}.'
|
||||
)
|
||||
|
||||
if not isinstance(
|
||||
num_attention_heads,
|
||||
int) and len(num_attention_heads) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f'Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: \
|
||||
{num_attention_heads}. `down_block_types`: {down_block_types}.'
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block
|
||||
] * len(down_block_types)
|
||||
|
||||
# input
|
||||
self.return_rgbs = return_rgbs
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[0],
|
||||
kernel_size=conv_in_kernel,
|
||||
padding=conv_in_padding)
|
||||
|
||||
# time
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos,
|
||||
freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
||||
encoder_hid_dim_type = 'text_proj'
|
||||
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
||||
logger.info(
|
||||
"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
|
||||
)
|
||||
|
||||
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f'`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.'
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type == 'text_proj':
|
||||
self.encoder_hid_proj = nn.Linear(encoder_hid_dim,
|
||||
cross_attention_dim)
|
||||
elif encoder_hid_dim_type == 'text_image_proj':
|
||||
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the
|
||||
# currently only use case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
|
||||
self.encoder_hid_proj = TextImageProjection(
|
||||
text_embed_dim=encoder_hid_dim,
|
||||
image_embed_dim=cross_attention_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
elif encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
||||
)
|
||||
else:
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds,
|
||||
time_embed_dim)
|
||||
elif class_embed_type == 'timestep':
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim,
|
||||
time_embed_dim)
|
||||
elif class_embed_type == 'identity':
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == 'projection':
|
||||
if projection_class_embeddings_input_dim is None:
|
||||
raise ValueError(
|
||||
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||
)
|
||||
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||
# 2. it projects from an arbitrary input dimension.
|
||||
#
|
||||
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||
self.class_embedding = TimestepEmbedding(
|
||||
projection_class_embeddings_input_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
if addition_embed_type == 'text':
|
||||
if encoder_hid_dim is not None:
|
||||
text_time_embedding_from_dim = encoder_hid_dim
|
||||
else:
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim,
|
||||
time_embed_dim,
|
||||
num_heads=addition_embed_type_num_heads)
|
||||
elif addition_embed_type == 'text_image':
|
||||
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter
|
||||
# the __init__ too much they are set to `cross_attention_dim` here as this is exactly the
|
||||
# required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
|
||||
self.add_embedding = TextImageTimeEmbedding(
|
||||
text_embed_dim=cross_attention_dim,
|
||||
image_embed_dim=cross_attention_dim,
|
||||
time_embed_dim=time_embed_dim)
|
||||
elif addition_embed_type == 'text_time':
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim,
|
||||
flip_sin_to_cos, freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(
|
||||
projection_class_embeddings_input_dim, time_embed_dim)
|
||||
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(
|
||||
f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
|
||||
)
|
||||
|
||||
# control net conditioning embedding
|
||||
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
block_out_channels=conditioning_embedding_out_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
return_rgbs=return_rgbs,
|
||||
use_rrdb=use_rrdb,
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.controlnet_down_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention
|
||||
] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim, ) * len(down_block_types)
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (
|
||||
num_attention_heads, ) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
controlnet_block = nn.Conv2d(
|
||||
output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
attention_head_dim=attention_head_dim[i]
|
||||
if attention_head_dim[i] is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
for _ in range(layers_per_block):
|
||||
controlnet_block = nn.Conv2d(
|
||||
output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
controlnet_block = nn.Conv2d(
|
||||
output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
# mid
|
||||
mid_block_channel = block_out_channels[-1]
|
||||
|
||||
controlnet_block = nn.Conv2d(
|
||||
mid_block_channel, mid_block_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_mid_block = controlnet_block
|
||||
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
in_channels=mid_block_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = 'rgb',
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32,
|
||||
96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
||||
|
||||
Parameters:
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
||||
where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block
|
||||
if 'transformer_layers_per_block' in unet.config else 1)
|
||||
encoder_hid_dim = unet.config.encoder_hid_dim if 'encoder_hid_dim' in unet.config else None
|
||||
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if 'encoder_hid_dim_type' in unet.config else None
|
||||
addition_embed_type = unet.config.addition_embed_type if 'addition_embed_type' in unet.config else None
|
||||
addition_time_embed_dim = (
|
||||
unet.config.addition_time_embed_dim
|
||||
if 'addition_time_embed_dim' in unet.config else None)
|
||||
|
||||
controlnet = cls(
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
in_channels=unet.config.in_channels,
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
down_block_types=unet.config.down_block_types,
|
||||
only_cross_attention=unet.config.only_cross_attention,
|
||||
block_out_channels=unet.config.block_out_channels,
|
||||
layers_per_block=unet.config.layers_per_block,
|
||||
downsample_padding=unet.config.downsample_padding,
|
||||
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||
act_fn=unet.config.act_fn,
|
||||
norm_num_groups=unet.config.norm_num_groups,
|
||||
norm_eps=unet.config.norm_eps,
|
||||
cross_attention_dim=unet.config.cross_attention_dim,
|
||||
attention_head_dim=unet.config.attention_head_dim,
|
||||
num_attention_heads=unet.config.num_attention_heads,
|
||||
use_linear_projection=unet.config.use_linear_projection,
|
||||
class_embed_type=unet.config.class_embed_type,
|
||||
num_class_embeds=unet.config.num_class_embeds,
|
||||
upcast_attention=unet.config.upcast_attention,
|
||||
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim=unet.config.
|
||||
projection_class_embeddings_input_dim,
|
||||
controlnet_conditioning_channel_order=
|
||||
controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels=
|
||||
conditioning_embedding_out_channels,
|
||||
)
|
||||
|
||||
if load_weights_from_unet:
|
||||
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
||||
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||
controlnet.time_embedding.load_state_dict(
|
||||
unet.time_embedding.state_dict())
|
||||
|
||||
if controlnet.class_embedding:
|
||||
controlnet.class_embedding.load_state_dict(
|
||||
unet.class_embedding.state_dict())
|
||||
|
||||
controlnet.down_blocks.load_state_dict(
|
||||
unet.down_blocks.state_dict())
|
||||
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
||||
|
||||
return controlnet
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module,
|
||||
processors: Dict[str,
|
||||
AttentionProcessor]):
|
||||
if hasattr(module, 'get_processor'):
|
||||
processors[f'{name}.processor'] = module.get_processor(
|
||||
return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f'{name}.{sub_name}', child,
|
||||
processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor,
|
||||
Dict[str,
|
||||
AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f'A dict of processors was passed, but the number of processors {len(processor)} does not match the'
|
||||
f' number of attention layers: {count}. Please make sure to pass {count} processor classes.'
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module,
|
||||
processor):
|
||||
if hasattr(module, 'set_processor'):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f'{name}.processor'))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f'{name}.{sub_name}', child,
|
||||
processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
self.set_attn_processor(AttnProcessor())
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, 'set_attention_slice'):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_sliceable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_sliceable_dims(module)
|
||||
|
||||
num_sliceable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == 'auto':
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == 'max':
|
||||
# make smallest slice possible
|
||||
slice_size = num_sliceable_layers * [1]
|
||||
|
||||
slice_size = num_sliceable_layers * [slice_size] if not isinstance(
|
||||
slice_size, list) else slice_size
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f'You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different'
|
||||
f' attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.'
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(
|
||||
f'size {size} has to be smaller or equal to {dim}.')
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module,
|
||||
slice_size: List[int]):
|
||||
if hasattr(module, 'set_attention_slice'):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.FloatTensor,
|
||||
fg_mask: Optional[torch.FloatTensor] = None,
|
||||
conditioning_scale_fg: float = 1.0,
|
||||
conditioning_scale_bg: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
"""
|
||||
The [`ControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
controlnet_cond (`torch.FloatTensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
if channel_order == 'rgb':
|
||||
# in rgb order by default
|
||||
...
|
||||
elif channel_order == 'bgr':
|
||||
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(
|
||||
f'unknown `controlnet_conditioning_channel_order`: {channel_order}'
|
||||
)
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == 'mps'
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps],
|
||||
dtype=dtype,
|
||||
device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError(
|
||||
'class_labels should be provided when num_class_embeds > 0'
|
||||
)
|
||||
|
||||
if self.config.class_embed_type == 'timestep':
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type is not None:
|
||||
if self.config.addition_embed_type == 'text':
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
elif self.config.addition_embed_type == 'text_time':
|
||||
if 'text_embeds' not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' \
|
||||
which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get('text_embeds')
|
||||
if 'time_ids' not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' \
|
||||
which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get('time_ids')
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
controlnet_cond_mid = None
|
||||
if self.return_rgbs:
|
||||
controlnet_cond, controlnet_cond_mid = self.controlnet_cond_embedding(
|
||||
controlnet_cond)
|
||||
else:
|
||||
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||
|
||||
sample = sample + controlnet_cond
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample, )
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, 'has_cross_attention'
|
||||
) and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
# 5. Control net blocks
|
||||
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(
|
||||
down_block_res_samples, self.controlnet_down_blocks):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (
|
||||
down_block_res_sample, )
|
||||
|
||||
down_block_res_samples = controlnet_down_block_res_samples
|
||||
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(
|
||||
-1, 0, len(down_block_res_samples) + 1,
|
||||
device=sample.device) # 0.1 to 1.0
|
||||
|
||||
scales = scales * conditioning_scale_fg
|
||||
down_block_res_samples = [
|
||||
sample * scale
|
||||
for sample, scale in zip(down_block_res_samples, scales)
|
||||
]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[
|
||||
-1] # last one
|
||||
else:
|
||||
if fg_mask is None:
|
||||
down_block_res_samples = [
|
||||
sample * conditioning_scale_fg
|
||||
for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale_fg
|
||||
else:
|
||||
down_block_masks = [
|
||||
torch.zeros_like(sample) + conditioning_scale_bg
|
||||
for i, sample in enumerate(down_block_res_samples)
|
||||
]
|
||||
mid_block_mask = torch.zeros_like(
|
||||
mid_block_res_sample) + conditioning_scale_bg
|
||||
|
||||
for i, sample in enumerate(down_block_masks):
|
||||
tmp_mask = F.interpolate(
|
||||
fg_mask,
|
||||
size=sample.shape[-2:]).repeat(sample.shape[0],
|
||||
sample.shape[1], 1,
|
||||
1).bool()
|
||||
down_block_masks[i] = sample.masked_fill(
|
||||
tmp_mask, conditioning_scale_fg)
|
||||
|
||||
tmp_mask = F.interpolate(
|
||||
fg_mask, size=mid_block_mask.shape[-2:]).repeat(
|
||||
mid_block_mask.shape[0], mid_block_mask.shape[1], 1,
|
||||
1).bool()
|
||||
mid_block_mask = mid_block_mask.masked_fill(
|
||||
tmp_mask, conditioning_scale_fg)
|
||||
|
||||
down_block_res_samples = [
|
||||
sample * down_block_mask for sample, down_block_mask in
|
||||
zip(down_block_res_samples, down_block_masks)
|
||||
]
|
||||
mid_block_res_sample = mid_block_res_sample * mid_block_mask
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True)
|
||||
for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(
|
||||
mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_cond_mid, down_block_res_samples,
|
||||
mid_block_res_sample)
|
||||
|
||||
return ControlNetOutput(
|
||||
controlnet_cond_mid=controlnet_cond_mid,
|
||||
down_block_res_samples=down_block_res_samples,
|
||||
mid_block_res_sample=mid_block_res_sample)
|
||||
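The guess-mode branch of forward() above weights the residuals with torch.logspace(-1, 0, ...): the highest-resolution residual is scaled by 0.1 and the mid-block residual by 1.0. With the default four down blocks and two layers per block there are 12 down-block residuals plus the mid block, hence 13 scales (counts assume the default config):

import torch

scales = torch.logspace(-1, 0, steps=13)                          # 0.1 ... 1.0, geometric spacing
print(round(scales[0].item(), 3), round(scales[-1].item(), 3))    # 0.1 1.0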
|
||||
|
||||
def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
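zero_module() is what makes the ControlNet branch start as a no-op: every 1x1 controlnet_block and the conv_out above are zero-initialized, so the residuals added to the UNet are exactly zero until training updates them. A small check of that behaviour:

import torch
from torch import nn

conv = nn.Conv2d(320, 320, kernel_size=1)
nn.init.zeros_(conv.weight)
nn.init.zeros_(conv.bias)                 # what zero_module() does to every parameter
x = torch.randn(1, 320, 64, 64)
print(conv(x).abs().max().item())         # 0.0 -- the branch contributes nothing initially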
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -9,8 +9,8 @@ import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from timm.layers.drop import drop_path
from timm.layers.weight_init import trunc_normal_
from timm.models.layers.drop import drop_path
from timm.models.layers.weight_init import trunc_normal_

from .common import Upsample, resize
@@ -11,8 +11,8 @@ from collections import OrderedDict
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.layers.drop import drop_path
from timm.layers.weight_init import trunc_normal_
from timm.models.layers.drop import drop_path
from timm.models.layers.weight_init import trunc_normal_
from torch import nn
@@ -8,8 +8,8 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from timm.layers.drop import drop_path
from timm.layers.weight_init import trunc_normal_
from timm.models.layers.drop import drop_path
from timm.models.layers.weight_init import trunc_normal_

from .common import resize
@@ -183,6 +183,12 @@ class OFATokenizerZH(PreTrainedTokenizer):
tokenize_chinese_chars=True,
strip_accents=None,
**kwargs):
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`'
)
self.vocab = load_vocab(vocab_file)
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
@@ -199,12 +205,6 @@ class OFATokenizerZH(PreTrainedTokenizer):
**kwargs,
)

if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`'
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([
(ids, tok) for tok, ids in self.vocab.items()
])
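Both hunks move the vocabulary loading ahead of super().__init__(), presumably because newer transformers tokenizer base classes call back into subclass methods (such as the vocab accessors) during construction, so that state has to exist first. A generic illustration of the ordering issue (not the real transformers API):

class Base:
    def __init__(self):
        self.size = self.vocab_size()     # base class calls back into the subclass

class GoodTokenizer(Base):
    def __init__(self, vocab):
        self.vocab = dict.fromkeys(vocab) # set the state the callback needs ...
        super().__init__()                # ... before handing control to the base class

    def vocab_size(self):
        return len(self.vocab)

print(GoodTokenizer(['a', 'b']).size)     # 2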
@@ -199,6 +199,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
padding_side='left',
num_image_tokens=20000,
**kwargs) -> None:

self.sp_tokenizer = SPTokenizer(
vocab_file, num_image_tokens=num_image_tokens)

super().__init__(
do_lower_case=do_lower_case,
remove_space=remove_space,
@@ -220,9 +224,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
self.end_token = end_token
self.mask_token = mask_token
self.gmask_token = gmask_token

self.sp_tokenizer = SPTokenizer(
vocab_file, num_image_tokens=num_image_tokens)
""" Initialisation """

@property
@@ -72,7 +72,6 @@ class ChatGLM2Tokenizer(PreTrainedTokenizer):
model_input_names = ['input_ids', 'attention_mask', 'position_ids']

def __init__(self, vocab_file, padding_side='left', **kwargs):
super().__init__(padding_side=padding_side, **kwargs)
self.name = 'GLMTokenizer'

self.vocab_file = vocab_file
@@ -82,6 +81,7 @@ class ChatGLM2Tokenizer(PreTrainedTokenizer):
'<eos>': self.tokenizer.eos_id,
'<pad>': self.tokenizer.pad_id
}
super().__init__(padding_side=padding_side, **kwargs)

def get_command(self, token):
if token in self.special_tokens:
@@ -71,8 +71,9 @@ def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]],


# This file is mainly copied from the llama code of transformers
@MODELS.register_module(Tasks.text_generation, module_name=Models.llama2)
@MODELS.register_module(Tasks.chat, module_name=Models.llama2)
@MODELS.register_module(Tasks.chat, module_name=Models.llama)
@MODELS.register_module(Tasks.text_generation, module_name=Models.llama2)
@MODELS.register_module(Tasks.text_generation, module_name=Models.llama)
class LlamaForTextGeneration(MsModelMixin, LlamaForCausalLM, TorchModel):
@@ -325,13 +325,9 @@ TASK_INPUTS = {
},

# ============ nlp tasks ===================
Tasks.chat: [
InputType.TEXT,
{
'text': InputType.TEXT,
'history': InputType.LIST,
}
],
Tasks.chat: {
'messages': InputType.LIST
},
Tasks.text_classification: [
InputType.TEXT,
(InputType.TEXT, InputType.TEXT),
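Tasks.chat now declares a single `messages` list instead of the old text/history pair. A hypothetical payload in the new shape (the role/content keys follow the common chat-messages convention and are an assumption here):

chat_input = {
    'messages': [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'Give me a one-line summary of ModelScope.'},
    ]
}
print(len(chat_input['messages']))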
@@ -1,14 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import os.path as osp
from typing import List, Optional, Union

from modelscope.hub.file_download import model_file_download
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE, Pipelines
from modelscope.models.base import Model
from modelscope.utils.config import ConfigDict, check_config
from modelscope.utils.config import Config, ConfigDict, check_config
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke,
ThirdParty)
ModelFile, ThirdParty)
from modelscope.utils.hub import read_config
from modelscope.utils.plugins import (register_modelhub_repo,
register_plugins_repo)
@@ -117,6 +119,8 @@ def pipeline(task: str = None,
model_revision,
third_party=third_party,
ignore_file_pattern=ignore_file_pattern)
if pipeline_name is None and kwargs.get('llm_first'):
pipeline_name = llm_first_checker(model, model_revision)
pipeline_props = {'type': pipeline_name}
if pipeline_name is None:
# get default pipeline for this task
@@ -196,3 +200,39 @@ def get_default_pipeline_info(task):
else:
pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
return pipeline_name, default_model


def llm_first_checker(model: Union[str, List[str], Model, List[Model]],
revision: Optional[str]) -> Optional[str]:
from modelscope.pipelines.nlp.llm_pipeline import LLM_FORMAT_MAP

def get_file_name(model: str, cfg_name: str,
revision: Optional[str]) -> Optional[str]:
if osp.exists(model):
return osp.join(model, cfg_name)
try:
return model_file_download(model, cfg_name, revision=revision)
except Exception:
return None

def parse_model_type(file: Optional[str], pattern: str) -> Optional[str]:
if file is None or not osp.exists(file):
return None
return Config.from_file(file).safe_get(pattern)

def get_model_type(model: str, revision: Optional[str]) -> Optional[str]:
cfg_file = get_file_name(model, ModelFile.CONFIGURATION, revision)
hf_cfg_file = get_file_name(model, ModelFile.CONFIG, revision)
cfg_model_type = parse_model_type(cfg_file, 'model.type')
hf_cfg_model_type = parse_model_type(hf_cfg_file, 'model_type')
return cfg_model_type or hf_cfg_model_type

if isinstance(model, list):
model = model[0]
if not isinstance(model, str):
model = model.model_dir
model_type = get_model_type(model, revision)
if model_type is not None:
model_type = model_type.lower().split('-')[0]
if model_type in LLM_FORMAT_MAP:
return 'llm'
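llm_first_checker() resolves the model type from configuration.json (model.type) or the Hugging Face config.json (model_type) and, if the lower-cased prefix appears in LLM_FORMAT_MAP, routes the request to the generic 'llm' pipeline. A sketch of how the new path is reached (the model id is illustrative and the call downloads from the ModelScope hub):

from modelscope.pipelines import pipeline

# llm_first is read from **kwargs in pipeline(); when no explicit pipeline name is found,
# llm_first_checker() decides whether the 'llm' pipeline should serve this model.
pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True)
print(type(pipe).__name__)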
@@ -8,14 +8,13 @@ import numpy as np
import PIL
import torch
from diffusers import AutoencoderKL, UniPCMultistepScheduler
from torchvision import transforms
from torchvision.models import ResNet50_Weights, resnet50
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from modelscope.metainfo import Pipelines
from modelscope.models.cv.image_portrait_enhancement.retinaface import \
detection
from modelscope.models.cv.image_super_resolution_pasd import (
ControlNetModel, UNet2DConditionModel)
from modelscope.models.cv.image_super_resolution_pasd.misc import (
load_dreambooth_lora, wavelet_color_fix)
from modelscope.outputs import OutputKeys
@@ -48,8 +47,8 @@ class ImageSuperResolutionPASDPipeline(Pipeline):
>>> 'image': input_location,
>>> 'upscale': 2,
>>> 'prompt': prompt,
>>> 'fidelity_scale_fg': 1.5,
>>> 'fidelity_scale_bg': 0.7
>>> 'fidelity_scale_fg': 1.0,
>>> 'fidelity_scale_bg': 1.0
>>> }
>>> pasd = pipeline(Tasks.image_super_resolution_pasd, model='damo/PASD_image_super_resolutions')
>>> output = pasd(input)[OutputKeys.OUTPUT_IMG]
@@ -69,6 +68,13 @@ class ImageSuperResolutionPASDPipeline(Pipeline):
self.device = create_device(device_name)
self.config = Config.from_file(
os.path.join(model, ModelFile.CONFIGURATION))
version = self.config.pipeline.get('version', 'pasd_v2')
if version == 'pasd':
from modelscope.models.cv.image_super_resolution_pasd import (
ControlNetModel, UNet2DConditionModel)
else:
from modelscope.models.cv.image_super_resolution_pasd_v2 import (
ControlNetModel, UNet2DConditionModel)
cfg = self.config.model_cfg
dreambooth_lora_ckpt = cfg['dreambooth_lora_ckpt']
tiled_size = cfg['tiled_size']
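The new branch selects between the original pasd models and the *_v2 variants based on a version field in the model's configuration.json, defaulting to 'pasd_v2'. A hypothetical excerpt of the relevant section:

# Assumed configuration.json fragment; only the 'version' key is read by the branch above.
pipeline_cfg = {
    'pipeline': {
        'version': 'pasd',   # 'pasd' -> image_super_resolution_pasd, anything else -> *_pasd_v2
    }
}
version = pipeline_cfg['pipeline'].get('version', 'pasd_v2')
print('v1' if version == 'pasd' else 'v2')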
@@ -123,6 +129,12 @@ class ImageSuperResolutionPASDPipeline(Pipeline):
self.face_detector = detection.RetinaFaceDetection(
detector_model_path, self.device)

self.resize_preproc = transforms.Compose([
transforms.Resize(
self.process_size,
interpolation=transforms.InterpolationMode.BILINEAR),
])

def preprocess(self, input: Input):
return input
@@ -145,8 +157,8 @@ class ImageSuperResolutionPASDPipeline(Pipeline):
eta = inputs.get('eta', 0.0)
prompt = inputs.get('prompt', '')
upscale = inputs.get('upscale', 2)
fidelity_scale_fg = inputs.get('fidelity_scale_fg', 1.5)
fidelity_scale_bg = inputs.get('fidelity_scale_bg', 0.7)
fidelity_scale_fg = inputs.get('fidelity_scale_fg', 1.0)
fidelity_scale_bg = inputs.get('fidelity_scale_bg', 1.0)

input_image = load_image(inputs['image']).convert('RGB')
@@ -164,19 +176,15 @@ class ImageSuperResolutionPASDPipeline(Pipeline):
|
||||
prompt = added_prompt if prompt == '' else f'{prompt}, {added_prompt}'
|
||||
|
||||
ori_width, ori_height = input_image.size
|
||||
resize_flag = False
|
||||
resize_flag = True
|
||||
rscale = upscale
|
||||
if ori_width < self.process_size // rscale or ori_height < self.process_size // rscale:
|
||||
scale = (self.process_size // rscale) / min(
|
||||
ori_width, ori_height)
|
||||
tmp_image = input_image.resize(
|
||||
(int(scale * ori_width), int(scale * ori_height)))
|
||||
|
||||
input_image = tmp_image
|
||||
resize_flag = True
|
||||
|
||||
input_image = input_image.resize(
|
||||
(input_image.size[0] * rscale, input_image.size[1] * rscale))
|
||||
|
||||
if min(input_image.size) < self.process_size:
|
||||
input_image = self.resize_preproc(input_image)
|
||||
|
||||
input_image = input_image.resize(
|
||||
(input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
|
||||
width, height = input_image.size
|
||||
|
||||
@@ -16,6 +16,8 @@ from . import utils
|
||||
if tf.__version__ >= '2.0':
|
||||
tf = tf.compat.v1
|
||||
|
||||
# test
|
||||
|
||||
# skip parse sys.argv in tf, so fix bug:
|
||||
# absl.flags._exceptions.UnrecognizedFlagError:
|
||||
# Unknown command line flag 'OCRDetectionPipeline: Unknown command line flag
|
||||
|
||||
@@ -1031,46 +1031,213 @@ class PixelAwareStableDiffusionPipeline(DiffusionPipeline,
|
||||
controlnet_latent_model_input = latent_model_input
|
||||
controlnet_prompt_embeds = prompt_embeds
|
||||
|
||||
if image is not None:
|
||||
rgbs, down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
controlnet_latent_model_input,
|
||||
_, _, h, w = latent_model_input.size()
|
||||
tile_size, tile_overlap = 120, 32
|
||||
if h < tile_size and w < tile_size:
|
||||
if image is not None:
|
||||
rgbs, down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
controlnet_latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=controlnet_prompt_embeds,
|
||||
controlnet_cond=image,
|
||||
fg_mask=fg_mask,
|
||||
conditioning_scale_fg=conditioning_scale_fg,
|
||||
conditioning_scale_bg=conditioning_scale_bg,
|
||||
guess_mode=guess_mode,
|
||||
return_dict=False,
|
||||
)
|
||||
else:
|
||||
down_block_res_samples, mid_block_res_sample = [
|
||||
None
|
||||
] * 10, [None] * 10
|
||||
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
down_block_res_samples = [
|
||||
torch.cat([torch.zeros_like(d), d])
|
||||
for d in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.cat([
|
||||
torch.zeros_like(mid_block_res_sample),
|
||||
mid_block_res_sample
|
||||
])
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=controlnet_prompt_embeds,
|
||||
controlnet_cond=image,
|
||||
fg_mask=fg_mask,
|
||||
conditioning_scale_fg=conditioning_scale_fg,
|
||||
conditioning_scale_bg=conditioning_scale_bg,
|
||||
guess_mode=guess_mode,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
return_dict=False,
|
||||
)
|
||||
)[0]
|
||||
else:
|
||||
down_block_res_samples, mid_block_res_sample = [
|
||||
None
|
||||
] * 10, [None] * 10
|
||||
tile_size = min(tile_size, min(h, w))
|
||||
tile_weights = self._gaussian_weights(
|
||||
tile_size, tile_size, 1).to(latent_model_input.device)
|
||||
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
down_block_res_samples = [
|
||||
torch.cat([torch.zeros_like(d), d])
|
||||
for d in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.cat([
|
||||
torch.zeros_like(mid_block_res_sample),
|
||||
mid_block_res_sample
|
||||
])
|
||||
grid_rows = 0
|
||||
cur_x = 0
|
||||
while cur_x < latent_model_input.size(-1):
|
||||
cur_x = max(
|
||||
grid_rows * tile_size - tile_overlap * grid_rows,
|
||||
0) + tile_size
|
||||
grid_rows += 1
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
grid_cols = 0
|
||||
cur_y = 0
|
||||
while cur_y < latent_model_input.size(-2):
|
||||
cur_y = max(
|
||||
grid_cols * tile_size - tile_overlap * grid_cols,
|
||||
0) + tile_size
|
||||
grid_cols += 1
|
||||
|
||||
input_list = []
|
||||
cond_list = []
|
||||
img_list = []
|
||||
fg_mask_list = []
|
||||
noise_preds = []
|
||||
for row in range(grid_rows):
|
||||
for col in range(grid_cols):
|
||||
if col < grid_cols - 1 or row < grid_rows - 1:
|
||||
# extract tile from input image
|
||||
ofs_x = max(
|
||||
row * tile_size - tile_overlap * row, 0)
|
||||
ofs_y = max(
|
||||
col * tile_size - tile_overlap * col, 0)
|
||||
# input tile area on total image
|
||||
if row == grid_rows - 1:
|
||||
ofs_x = w - tile_size
|
||||
if col == grid_cols - 1:
|
||||
ofs_y = h - tile_size
|
||||
|
||||
input_start_x = ofs_x
|
||||
input_end_x = ofs_x + tile_size
|
||||
input_start_y = ofs_y
|
||||
input_end_y = ofs_y + tile_size
|
||||
|
||||
# input tile dimensions
|
||||
input_tile = latent_model_input[:, :,
|
||||
input_start_y:
|
||||
input_end_y,
|
||||
input_start_x:
|
||||
input_end_x]
|
||||
input_list.append(input_tile)
|
||||
cond_tile = controlnet_latent_model_input[:, :,
|
||||
input_start_y:
|
||||
input_end_y,
|
||||
input_start_x:
|
||||
input_end_x]
|
||||
cond_list.append(cond_tile)
|
||||
img_tile = image[:, :,
|
||||
input_start_y * 8:input_end_y * 8,
|
||||
input_start_x * 8:input_end_x * 8]
|
||||
img_list.append(img_tile)
|
||||
if fg_mask is not None:
|
||||
fg_mask_tile = fg_mask[:, :, input_start_y
|
||||
* 8:input_end_y * 8,
|
||||
input_start_x
|
||||
* 8:input_end_x * 8]
|
||||
fg_mask_list.append(fg_mask_tile)
|
||||
|
||||
if len(input_list
|
||||
) == batch_size or col == grid_cols - 1:
|
||||
input_list_t = torch.cat(input_list, dim=0)
|
||||
cond_list_t = torch.cat(cond_list, dim=0)
|
||||
img_list_t = torch.cat(img_list, dim=0)
|
||||
if fg_mask is not None:
|
||||
fg_mask_list_t = torch.cat(
|
||||
fg_mask_list, dim=0)
|
||||
else:
|
||||
fg_mask_list_t = None
|
||||
|
||||
_, down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
cond_list_t,
|
||||
t,
|
||||
encoder_hidden_states=
|
||||
controlnet_prompt_embeds,
|
||||
controlnet_cond=img_list_t,
|
||||
fg_mask=fg_mask_list_t,
|
||||
conditioning_scale_fg=conditioning_scale_fg,
|
||||
conditioning_scale_bg=conditioning_scale_bg,
|
||||
guess_mode=guess_mode,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to the unconditional/conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
down_block_res_samples = [
|
||||
torch.cat([torch.zeros_like(d), d])
|
||||
for d in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.cat([
|
||||
torch.zeros_like(mid_block_res_sample),
|
||||
mid_block_res_sample
|
||||
])
|
||||
|
||||
# predict the noise residual
|
||||
model_out = self.unet(
|
||||
input_list_t,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=
|
||||
cross_attention_kwargs,
|
||||
down_block_additional_residuals=
|
||||
down_block_res_samples,
|
||||
mid_block_additional_residual=
|
||||
mid_block_res_sample,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
input_list = []
|
||||
cond_list = []
|
||||
img_list = []
|
||||
fg_mask_list = []
|
||||
|
||||
noise_preds.append(model_out)
|
||||
|
||||
# Stitch noise predictions for all tiles
|
||||
noise_pred = torch.zeros(
|
||||
latent_model_input.shape,
|
||||
device=latent_model_input.device)
|
||||
contributors = torch.zeros(
|
||||
latent_model_input.shape,
|
||||
device=latent_model_input.device)
|
||||
# Add each tile contribution to overall latents
|
||||
for row in range(grid_rows):
|
||||
for col in range(grid_cols):
|
||||
if col < grid_cols - 1 or row < grid_rows - 1:
|
||||
# extract tile from input image
|
||||
ofs_x = max(
|
||||
row * tile_size - tile_overlap * row, 0)
|
||||
ofs_y = max(
|
||||
col * tile_size - tile_overlap * col, 0)
|
||||
# input tile area on total image
|
||||
if row == grid_rows - 1:
|
||||
ofs_x = w - tile_size
|
||||
if col == grid_cols - 1:
|
||||
ofs_y = h - tile_size
|
||||
|
||||
input_start_x = ofs_x
|
||||
input_end_x = ofs_x + tile_size
|
||||
input_start_y = ofs_y
|
||||
input_end_y = ofs_y + tile_size
|
||||
|
||||
noise_pred[:, :, input_start_y:input_end_y,
|
||||
input_start_x:
|
||||
input_end_x] += noise_preds[
|
||||
row * grid_cols
|
||||
+ col] * tile_weights
|
||||
contributors[:, :, input_start_y:input_end_y,
|
||||
input_start_x:
|
||||
input_end_x] += tile_weights
|
||||
# Average overlapping areas with more than 1 contributor
|
||||
noise_pred /= contributors
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
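The tiled branch above splits large latents into overlapping tiles, runs ControlNet and the UNet per tile, then blends the per-tile noise predictions with Gaussian weights and normalizes by the accumulated weight. A self-contained sketch of that blend step, assuming the tiles cover the whole latent (helper names are illustrative, not the pipeline's own _gaussian_weights):

import torch

def gaussian_weights(tile_h, tile_w, device):
    # Separable Gaussian window peaking at the tile centre.
    ys = torch.exp(-(torch.linspace(-1, 1, tile_h, device=device) ** 2) / 0.5)
    xs = torch.exp(-(torch.linspace(-1, 1, tile_w, device=device) ** 2) / 0.5)
    return ys[:, None] * xs[None, :]

def stitch_tiles(tile_preds, offsets, out_shape, tile_size, device):
    # Accumulate weighted tile predictions, then divide by the total weight
    # so overlapping regions are averaged rather than double-counted.
    noise_pred = torch.zeros(out_shape, device=device)
    contributors = torch.zeros(out_shape, device=device)
    w = gaussian_weights(tile_size, tile_size, device)
    for pred, (oy, ox) in zip(tile_preds, offsets):
        noise_pred[..., oy:oy + tile_size, ox:ox + tile_size] += pred * w
        contributors[..., oy:oy + tile_size, ox:ox + tile_size] += w
    return noise_pred / contributors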
|
||||
|
||||
@@ -48,6 +48,7 @@ if TYPE_CHECKING:
|
||||
from .document_grounded_dialog_rerank_pipeline import DocumentGroundedDialogRerankPipeline
|
||||
from .language_identification_pipline import LanguageIdentificationPipeline
|
||||
from .machine_reading_comprehension_pipeline import MachineReadingComprehensionForNERPipeline
|
||||
from .llm_pipeline import LLMPipeline
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
@@ -119,6 +120,7 @@ else:
|
||||
'machine_reading_comprehension_pipeline': [
|
||||
'MachineReadingComprehensionForNERPipeline'
|
||||
],
|
||||
'llm_pipeline': ['LLMPipeline'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
@@ -11,6 +11,7 @@ from modelscope import (AutoModelForCausalLM, AutoTokenizer, Pipeline,
|
||||
snapshot_download)
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.models.nlp import ChatGLM2Tokenizer, Llama2Tokenizer
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.pipelines.util import is_model, is_official_hub_path
|
||||
from modelscope.utils.constant import Invoke, ModelFile, Tasks
|
||||
@@ -19,7 +20,8 @@ from modelscope.utils.logger import get_logger
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(Tasks.chat, module_name='llm-pipeline')
|
||||
@PIPELINES.register_module(Tasks.chat, module_name='llm')
|
||||
@PIPELINES.register_module(Tasks.text_generation, module_name='llm')
|
||||
class LLMPipeline(Pipeline):
|
||||
|
||||
def initiate_single_model(self, model):
|
||||
@@ -55,6 +57,9 @@ class LLMPipeline(Pipeline):
|
||||
*args,
|
||||
**kwargs):
|
||||
self.device_map = kwargs.pop('device_map', None)
|
||||
# TODO: qwen-int4 needs a 'cuda'/'auto' device_map.
|
||||
if not self.device_map and 'qwen' in kwargs['model'].lower():
|
||||
self.device_map = 'cuda'
|
||||
self.torch_dtype = kwargs.pop('torch_dtype', None)
|
||||
self.ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)
|
||||
with self._temp_configuration_file(kwargs):
|
||||
@@ -138,6 +143,8 @@ class LLMPipeline(Pipeline):
|
||||
outputs, skip_special_tokens=True, **kwargs)
|
||||
if is_messages:
|
||||
response = self.format_output(response, **kwargs)
|
||||
else:
|
||||
response = {OutputKeys.TEXT: response}
|
||||
|
||||
return response
|
||||
|
||||
@@ -260,7 +267,7 @@ def chatglm2_format_output(response, **kwargs):
|
||||
response = response.replace('[[训练时间]]', '2023年')
|
||||
messages = {'role': 'assistant', 'content': response}
|
||||
outputs = {
|
||||
'messages': messages,
|
||||
'message': messages,
|
||||
}
|
||||
return outputs
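So a messages-style call comes back as a small dict keyed by 'messages' (and 'message' for compatibility); the content below is a placeholder for whatever the model generates:

response = {
    'messages': {'role': 'assistant', 'content': '...generated reply...'},
    'message': {'role': 'assistant', 'content': '...generated reply...'},
}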
|
||||
|
||||
|
||||
@@ -773,6 +773,7 @@ def pipeline_output_to_service_base64_output(task_name, pipeline_output):
|
||||
pipeline_output = pipeline_output[0]
|
||||
for key, value in pipeline_output.items():
|
||||
if key not in task_outputs:
|
||||
json_serializable_output[key] = value
|
||||
continue # skip the output not defined.
|
||||
if key in [
|
||||
OutputKeys.OUTPUT_IMG, OutputKeys.OUTPUT_IMGS,
|
||||
|
||||
@@ -1,78 +1,85 @@
|
||||
{
|
||||
"action-detection":{
|
||||
"input":{
|
||||
"video":"data/test/videos/action_detection_test_video.mp4"
|
||||
"video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/action_detection_test_video.mp4"
|
||||
}
|
||||
},
|
||||
"action-recognition":{
|
||||
"input":{
|
||||
"video":"data/test/videos/action_recognition_test_video.mp4"
|
||||
"video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/action_recognition_test_video.mp4"
|
||||
}
|
||||
},
|
||||
"animal-recognition":{
|
||||
"input":{
|
||||
"image":"data/test/images/dogs.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/dogs.jpg"
|
||||
}
|
||||
},
|
||||
"chat":{
|
||||
"input":{
|
||||
"text":"你有什么推荐吗?",
|
||||
"history":[
|
||||
[
|
||||
"今天天气真好,",
|
||||
"今天天气真好,出去走走怎么样?"
|
||||
]
|
||||
]
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "Hello! 你是谁?"
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "我是你的助手。"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "你叫什么名字?"
|
||||
}]
|
||||
},
|
||||
"parameters": {
|
||||
"do_sample": true,
|
||||
"max_length": 512
|
||||
}
|
||||
},
|
||||
"domain-specific-object-detection":{
|
||||
"input":{
|
||||
"image":"data/test/images/image_traffic_sign.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_traffic_sign.jpg"
|
||||
}
|
||||
},
|
||||
"face-2d-keypoints":{
|
||||
"input":{
|
||||
"image":"data/test/images/face_detection.png"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/face_detection.png"
|
||||
}
|
||||
},
|
||||
"face-attribute-recognition":{
|
||||
"input":{
|
||||
"image":"data/test/images/face_recognition_1.png"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/face_recognition_1.png"
|
||||
}
|
||||
},
|
||||
"facial-expression-recognition":{
|
||||
"input":{
|
||||
"image":"data/test/images/facial_expression_recognition.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/facial_expression_recognition.jpg"
|
||||
}
|
||||
},
|
||||
"general-recognition":{
|
||||
"input":{
|
||||
"image":"data/test/images/dogs.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/dogs.jpg"
|
||||
}
|
||||
},
|
||||
"human-detection":{
|
||||
"input":{
|
||||
"image":"data/test/images/image_detection.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_detection.jpg"
|
||||
}
|
||||
},
|
||||
"image-captioning":{
|
||||
"input":{
|
||||
"image":"data/test/images/image_captioning.png"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_captioning.png"
|
||||
}
|
||||
},
|
||||
"image-classification":{
|
||||
"input":{
|
||||
"image":"data/test/images/content_check.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/content_check.jpg"
|
||||
}
|
||||
},
|
||||
"image-demoireing":{
|
||||
"input":{
|
||||
"image":"data/test/images/shop_segmentation.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/shop_segmentation.jpg"
|
||||
}
|
||||
},
|
||||
"image-object-detection":{
|
||||
"input":{
|
||||
"image":"data/test/images/image_detection.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_detection.jpg"
|
||||
}
|
||||
},
|
||||
"image-portrait-stylization":{
|
||||
@@ -82,7 +89,7 @@
|
||||
},
|
||||
"image-segmentation":{
|
||||
"input":{
|
||||
"image":"data/test/images/image_semantic_segmentation.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_semantic_segmentation.jpg"
|
||||
},
|
||||
"parameters":{
|
||||
|
||||
@@ -90,18 +97,18 @@
|
||||
},
|
||||
"image-text-retrieval":{
|
||||
"input":{
|
||||
"image":"data/test/images/image_mplug_vqa.jpg",
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_mplug_vqa.jpg",
|
||||
"text":"What is the woman doing?"
|
||||
}
|
||||
},
|
||||
"indoor-layout-estimation":{
|
||||
"input":{
|
||||
"image":"data/test/images/image_traffic_sign.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_traffic_sign.jpg"
|
||||
}
|
||||
},
|
||||
"live-category":{
|
||||
"input":{
|
||||
"video":"data/test/videos/live_category_test_video.mp4"
|
||||
"video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/live_category_test_video.mp4"
|
||||
}
|
||||
},
|
||||
"motion-generation":{
|
||||
@@ -125,22 +132,22 @@
|
||||
},
|
||||
"ocr-recognition":{
|
||||
"input":{
|
||||
"image":"data/test/images/image_ocr_recognition.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_ocr_recognition.jpg"
|
||||
}
|
||||
},
|
||||
"panorama-depth-estimation":{
|
||||
"input":{
|
||||
"image":"data/test/images/panorama_depth_estimation.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/panorama_depth_estimation.jpg"
|
||||
}
|
||||
},
|
||||
"semantic-segmentation":{
|
||||
"input":{
|
||||
"image":"data/test/images/image_salient_detection.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_salient_detection.jpg"
|
||||
}
|
||||
},
|
||||
"shop-segmentation":{
|
||||
"input":{
|
||||
"image":"data/test/images/shop_segmentation.jpg"
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/shop_segmentation.jpg"
|
||||
}
|
||||
},
|
||||
"text-classification":{
|
||||
@@ -153,7 +160,7 @@
|
||||
},
|
||||
"text-driven-segmentation":{
|
||||
"input":{
|
||||
"image":"data/test/images/text_driven_segmentation.jpg",
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/text_driven_segmentation.jpg",
|
||||
"text":"bear"
|
||||
}
|
||||
},
|
||||
@@ -194,60 +201,60 @@
|
||||
},
|
||||
"video-captioning":{
|
||||
"input":{
|
||||
"video":"data/test/videos/video_caption_and_qa_test.mp4"
|
||||
"video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_caption_and_qa_test.mp4"
|
||||
}
|
||||
},
|
||||
"video-category":{
|
||||
"input":{
|
||||
"video":"data/test/videos/video_category_test_video.mp4"
|
||||
"video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_category_test_video.mp4"
|
||||
}
|
||||
},
|
||||
"video-depth-estimation":{
|
||||
"input":{
|
||||
"video":"data/test/videos/video_depth_estimation.mp4"
|
||||
"video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_depth_estimation.mp4"
|
||||
}
|
||||
},
|
||||
"video-embedding":{
|
||||
"input":{
|
||||
"video":"data/test/videos/action_recognition_test_video.mp4"
|
||||
"video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/action_recognition_test_video.mp4"
|
||||
}
|
||||
},
|
||||
"video-multi-object-tracking":{
|
||||
"input":{
|
||||
"video":"data/test/videos/MOT17-03-partial.mp4"
|
||||
"video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/MOT17-03-partial.mp4"
|
||||
}
|
||||
},
|
||||
"video-panoptic-segmentation":{
|
||||
"input":{
|
||||
"video":"data/test/videos/kitti-step_testing_image_02_0000.mp4"
|
||||
"video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/kitti-step_testing_image_02_0000.mp4"
|
||||
}
|
||||
},
|
||||
"video-question-answering":{
|
||||
"input":{
|
||||
"video":"data/test/videos/video_caption_and_qa_test.mp4",
|
||||
"video":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_caption_and_qa_test.mp4",
|
||||
"text":"How many people are there?"
|
||||
}
|
||||
},
|
||||
"video-summarization":{
|
||||
"input":{
|
||||
"text":"data/test/videos/video_category_test_video.mp4"
|
||||
"text":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/videos/video_category_test_video.mp4"
|
||||
}
|
||||
},
|
||||
"visual-entailment":{
|
||||
"input":{
|
||||
"image":"data/test/images/dogs.jpg",
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/dogs.jpg",
|
||||
"text":"there are two birds."
|
||||
}
|
||||
},
|
||||
"visual-grounding":{
|
||||
"input":{
|
||||
"image":"data/test/images/visual_grounding.png",
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/visual_grounding.png",
|
||||
"text":"a blue turtle-like pokemon with round head"
|
||||
}
|
||||
},
|
||||
"visual-question-answering":{
|
||||
"input":{
|
||||
"image":"data/test/images/image_mplug_vqa.jpg",
|
||||
"image":"http://modelscope.oss-cn-beijing.aliyuncs.com/demo/images/image_mplug_vqa.jpg",
|
||||
"text":"What is the woman doing?"
|
||||
}
|
||||
},
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Make sure to modify __release_datetime__ to release time when making official release.
|
||||
__version__ = '1.9.2'
|
||||
__version__ = '1.9.3'
|
||||
# default release datetime for branches under active development is set
|
||||
# to be a time far-far-away-into-the-future
|
||||
__release_datetime__ = '2099-09-06 00:00:00'
|
||||
|
||||
@@ -37,12 +37,15 @@ class ModelJsonTest:
|
||||
|
||||
# init pipeline
|
||||
ppl = pipeline(
|
||||
task=task, model=model_id, model_revision=model_revision)
|
||||
task=task,
|
||||
model=model_id,
|
||||
model_revision=model_revision,
|
||||
llm_first=True)
|
||||
pipeline_info = get_pipeline_information_by_pipeline(ppl)
|
||||
|
||||
# call pipeline
|
||||
data = get_task_input_examples(task)
|
||||
print(task, data)
|
||||
|
||||
infer_result = call_pipeline_with_json(pipeline_info, ppl, data)
|
||||
result = pipeline_output_to_service_base64_output(task, infer_result)
|
||||
return result
|
||||
@@ -50,27 +53,20 @@ class ModelJsonTest:
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_list = [
|
||||
'damo/nlp_structbert_nli_chinese-base',
|
||||
'damo/nlp_structbert_word-segmentation_chinese-base',
|
||||
'damo/nlp_structbert_zero-shot-classification_chinese-base',
|
||||
'damo/cv_unet_person-image-cartoon_compound-models',
|
||||
'damo/nlp_structbert_sentiment-classification_chinese-tiny',
|
||||
'damo/nlp_csanmt_translation_zh2en',
|
||||
'damo/nlp_rom_passage-ranking_chinese-base',
|
||||
'damo/ofa_image-caption_muge_base_zh',
|
||||
'damo/nlp_raner_named-entity-recognition_chinese-base-ecom-50cls',
|
||||
'damo/nlp_structbert_sentiment-classification_chinese-ecommerce-base',
|
||||
'damo/text-to-video-synthesis',
|
||||
'qwen/Qwen-7B',
|
||||
'qwen/Qwen-7B-Chat',
|
||||
'ZhipuAI/ChatGLM-6B',
|
||||
'qwen/Qwen-7B-Chat-Int4',
|
||||
'qwen/Qwen-14B-Chat-Int4',
|
||||
'baichuan-inc/Baichuan2-7B-Chat-4bits',
|
||||
'baichuan-inc/Baichuan2-13B-Chat-4bits',
|
||||
'ZhipuAI/chatglm2-6b-int4',
|
||||
]
|
||||
tester = ModelJsonTest()
|
||||
for model in model_list:
|
||||
try:
|
||||
res = tester.test_single(model)
|
||||
print(f'\nmodel_id {model} call_pipeline_with_json run ok.\n')
|
||||
print(
|
||||
f'\nmodel_id {model} call_pipeline_with_json run ok. {res}\n\n\n\n'
|
||||
)
|
||||
except BaseException as e:
|
||||
print(
|
||||
f'\nmodel_id {model} call_pipeline_with_json run failed: {e}.\n'
|
||||
f'\nmodel_id {model} call_pipeline_with_json run failed: {e}.\n\n\n\n'
|
||||
)
|
||||
|
||||
@@ -15,13 +15,14 @@ class ImageSuperResolutionPASDTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.model_id = 'damo/PASD_image_super_resolutions'
|
||||
self.model_v2_id = 'damo/PASD_v2_image_super_resolutions'
|
||||
self.img = 'data/test/images/dogs.jpg'
|
||||
self.input = {
|
||||
'image': self.img,
|
||||
'prompt': '',
|
||||
'upscale': 1,
|
||||
'fidelity_scale_fg': 1.5,
|
||||
'fidelity_scale_bg': 0.7
|
||||
'fidelity_scale_fg': 1.0,
|
||||
'fidelity_scale_bg': 1.0
|
||||
}
|
||||
self.task = Tasks.image_super_resolution_pasd
|
||||
|
||||
@@ -38,6 +39,13 @@ class ImageSuperResolutionPASDTest(unittest.TestCase):
|
||||
|
||||
self.pipeline_inference(super_resolution, self.input)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_modelhub_v2(self):
|
||||
super_resolution = pipeline(
|
||||
Tasks.image_super_resolution_pasd, model=self.model_v2_id)
|
||||
|
||||
self.pipeline_inference(super_resolution, self.input)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -3,6 +3,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope import pipeline
|
||||
from modelscope.pipelines.nlp.llm_pipeline import LLMPipeline
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
@@ -132,143 +133,172 @@ class LLMPipelineTest(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_chatglm2(self):
|
||||
pipe = LLMPipeline(model='ZhipuAI/chatglm2-6b', device_map='auto')
|
||||
pipe = pipeline(
|
||||
task='chat', model='ZhipuAI/chatglm2-6b', llm_first=True)
|
||||
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_chatglm2int4(self):
|
||||
pipe = LLMPipeline(model='ZhipuAI/chatglm2-6b-int4')
|
||||
pipe = pipeline(
|
||||
task='chat', model='ZhipuAI/chatglm2-6b-int4', llm_first=True)
|
||||
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_chatglm232k(self):
|
||||
pipe = LLMPipeline(model='ZhipuAI/chatglm2-6b-32k', device_map='auto')
|
||||
pipe = pipeline(
|
||||
task='chat', model='ZhipuAI/chatglm2-6b-32k', llm_first=True)
|
||||
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_llama2(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='modelscope/Llama-2-7b-ms',
|
||||
torch_dtype=torch.float16,
|
||||
device_map='auto',
|
||||
ignore_file_pattern=[r'.+\.bin$'])
|
||||
ignore_file_pattern=[r'.+\.bin$'],
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.messages_en, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_llama2chat(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='modelscope/Llama-2-7b-chat-ms',
|
||||
revision='v1.0.2',
|
||||
torch_dtype=torch.float16,
|
||||
device_map='auto',
|
||||
ignore_file_pattern=[r'.+\.bin$'])
|
||||
ignore_file_pattern=[r'.+\.bin$'],
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.messages_en, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_codellama(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='AI-ModelScope/CodeLlama-7b-Instruct-hf',
|
||||
torch_dtype=torch.float16,
|
||||
device_map='auto',
|
||||
ignore_file_pattern=[r'.+\.bin$'])
|
||||
ignore_file_pattern=[r'.+\.bin$'],
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.messages_code, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_code, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_baichuan_7b(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='baichuan-inc/baichuan-7B',
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16)
|
||||
torch_dtype=torch.float16,
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_baichuan_13b(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='baichuan-inc/Baichuan-13B-Base',
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16)
|
||||
torch_dtype=torch.float16,
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_baichuan_13bchat(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='baichuan-inc/Baichuan-13B-Chat',
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16)
|
||||
torch_dtype=torch.float16,
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_baichuan2_7b(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='baichuan-inc/Baichuan2-7B-Base',
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16)
|
||||
torch_dtype=torch.float16,
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_baichuan2_7bchat(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='baichuan-inc/Baichuan2-7B-Chat',
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16)
|
||||
torch_dtype=torch.float16,
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
@unittest.skip('Need bitsandbytes')
|
||||
def test_baichuan2_7bchat_int4(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='baichuan-inc/Baichuan2-7B-Chat-4bits',
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16)
|
||||
torch_dtype=torch.float16,
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
@unittest.skip('Need bitsandbytes')
|
||||
def test_baichuan2_13bchat_int4(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='baichuan-inc/Baichuan2-13B-Chat-4bits',
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16)
|
||||
torch_dtype=torch.float16,
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_wizardlm_13b(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='AI-ModelScope/WizardLM-13B-V1.2',
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16,
|
||||
format_messages='wizardlm')
|
||||
format_messages='wizardlm',
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.messages_en, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_wizardmath(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='AI-ModelScope/WizardMath-7B-V1.0',
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16,
|
||||
format_messages='wizardcode')
|
||||
format_messages='wizardcode',
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.message_wizard_math, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_wizard_math, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_wizardcode_13b(self):
|
||||
pipe = LLMPipeline(
|
||||
pipe = pipeline(
|
||||
task='chat',
|
||||
model='AI-ModelScope/WizardCoder-Python-13B-V1.0',
|
||||
device_map='auto',
|
||||
torch_dtype=torch.float16,
|
||||
format_messages='wizardcode')
|
||||
format_messages='wizardcode',
|
||||
llm_first=True)
|
||||
print('messages: ', pipe(self.message_wizard_code, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_wizard_code, **self.gen_cfg))
|
||||
|
||||
@@ -284,19 +314,20 @@ class LLMPipelineTest(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_qwen(self):
|
||||
pipe = LLMPipeline(model='qwen/Qwen-7B-Chat', device_map='auto')
|
||||
pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True)
|
||||
print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
@unittest.skip('Need optimum and auto-gptq')
|
||||
def test_qwen_int4(self):
|
||||
pipe = LLMPipeline(model='qwen/Qwen-7B-Chat-Int4', device_map='auto')
|
||||
pipe = pipeline(
|
||||
task='chat', model='qwen/Qwen-7B-Chat-Int4', llm_first=True)
|
||||
print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_qwen_vl(self):
|
||||
pipe = LLMPipeline(model='qwen/Qwen-VL-Chat', device_map='auto')
|
||||
pipe = pipeline(task='chat', model='qwen/Qwen-VL-Chat', llm_first=True)
|
||||
print('messages: ', pipe(self.messages_mm, **self.gen_cfg))
|
||||
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
|
||||
|
||||
|
||||
@@ -231,8 +231,8 @@ def get_test_suites_to_run():
|
||||
affected_pipeline_cases.extend(
|
||||
task_pipeline_test_suite_map[affected_register_module[1]])
|
||||
else:
|
||||
logger.warn('Pipeline task: %s has no test case!'
|
||||
% affected_register_module[1])
|
||||
logger.warning('Pipeline task: %s has no test case!'
|
||||
% affected_register_module[1])
|
||||
elif affected_register_module[0] == 'MODELS':
|
||||
# ["MODELS", "keyword_spotting", "kws_kwsbp", "GenericKeyWordSpotting"],
|
||||
# ["MODELS", task, model_name, model_class_name]
|
||||
@@ -240,8 +240,8 @@ def get_test_suites_to_run():
|
||||
affected_pipeline_cases.extend(
|
||||
task_pipeline_test_suite_map[affected_register_module[1]])
|
||||
else:
|
||||
logger.warn('Pipeline task: %s has no test case!'
|
||||
% affected_register_module[1])
|
||||
logger.warning('Pipeline task: %s has no test case!'
|
||||
% affected_register_module[1])
|
||||
elif affected_register_module[0] == 'TRAINERS':
|
||||
# ["TRAINERS", "", "nlp_base_trainer", "NlpEpochBasedTrainer"],
|
||||
# ["TRAINERS", "", trainer_name, trainer_class_name]
|
||||
@@ -298,13 +298,27 @@ def get_test_suites_to_run():
|
||||
return test_suites_to_run
|
||||
|
||||
|
||||
def get_files_related_modules(files):
|
||||
def get_files_related_modules(files, reverse_import_map):
|
||||
register_modules = []
|
||||
for single_file in files:
|
||||
if single_file.startswith('./modelscope') or \
|
||||
single_file.startswith('modelscope'):
|
||||
register_modules.extend(get_file_register_modules(single_file))
|
||||
|
||||
while len(register_modules) == 0:
|
||||
logger.warn('There is no affected register module')
|
||||
deeper_imported_by = []
|
||||
has_deeper_affected_files = False
|
||||
for source_file in files:
|
||||
if len(source_file.split('/')) > 4 and source_file.startswith(
|
||||
'modelscope'):
|
||||
deeper_imported_by.extend(reverse_import_map[source_file])
|
||||
has_deeper_affected_files = True
|
||||
if not has_deeper_affected_files:
|
||||
break
|
||||
for file in deeper_imported_by:
|
||||
register_modules = get_file_register_modules(file)
|
||||
files = deeper_imported_by
|
||||
return register_modules
|
||||
|
||||
|
||||
@@ -354,8 +368,8 @@ def get_all_file_test_info():
|
||||
file_test_info = {}
|
||||
file_test_info['imports'] = import_map[f]
|
||||
file_test_info['imported_by'] = reverse_depend_map[f]
|
||||
register_modules = get_files_related_modules([f]
|
||||
+ reverse_depend_map[f])
|
||||
register_modules = get_files_related_modules(
|
||||
[f] + reverse_depend_map[f], reverse_depend_map)
|
||||
file_test_info['relate_modules'] = register_modules
|
||||
affected_pipeline_cases, affected_trainer_cases = get_modules_related_cases(
|
||||
register_modules, task_pipeline_test_suite_map,
|
||||
|
||||
@@ -90,7 +90,8 @@ class AnalysisTestClass(ast.NodeVisitor):
|
||||
if isinstance(item, ast.Name):
|
||||
res.append(self.get_variables(item.id))
|
||||
elif isinstance(item, ast.Attribute):
|
||||
res.append(self.get_variables(item.value.id))
|
||||
if hasattr(item.value, 'id'):
|
||||
res.append(self.get_variables(item.value.id))
|
||||
elif isinstance(item, ast.Str):
|
||||
res.append(self.get_variables(item.s))
|
||||
elif isinstance(item, ast.Dict):
|
||||
|
||||
@@ -8,24 +8,62 @@ import pkgutil
|
||||
import site
|
||||
import sys
|
||||
|
||||
import json
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class AnalysisSourceFileDefines(ast.NodeVisitor):
|
||||
"""Analysis source file function, class, global variable defines.
|
||||
"""
|
||||
|
||||
def __init__(self, source_file_path) -> None:
|
||||
super().__init__()
|
||||
self.global_variables = []
|
||||
self.functions = []
|
||||
self.classes = []
|
||||
self.async_functions = []
|
||||
self.symbols = []
|
||||
|
||||
self.source_file_path = source_file_path
|
||||
rel_file_path = source_file_path
|
||||
if os.path.isabs(source_file_path):
|
||||
rel_file_path = os.path.relpath(source_file_path, os.getcwd())
|
||||
|
||||
if rel_file_path.endswith('__init__.py'): # processing package
|
||||
self.base_module_name = os.path.dirname(rel_file_path).replace(
|
||||
'/', '.')
|
||||
else: # import x.y.z z is the filename
|
||||
self.base_module_name = rel_file_path.replace('/', '.').replace(
|
||||
'.py', '')
|
||||
self.symbols.append(self.base_module_name)
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef):
|
||||
self.symbols.append(self.base_module_name + '.' + node.name)
|
||||
self.classes.append(node.name)
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef):
|
||||
self.symbols.append(self.base_module_name + '.' + node.name)
|
||||
self.functions.append(node.name)
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
|
||||
self.symbols.append(self.base_module_name + '.' + node.name)
|
||||
self.async_functions.append(node.name)
|
||||
|
||||
def visit_Assign(self, node: ast.Assign):
|
||||
for tg in node.targets:
|
||||
if isinstance(tg, ast.Name):
|
||||
self.symbols.append(self.base_module_name + '.' + tg.id)
|
||||
self.global_variables.append(tg.id)
|
||||
|
||||
|
||||
def is_relative_import(path):
|
||||
# from .x import y or from ..x import y
|
||||
return path.startswith('.')
|
||||
|
||||
|
||||
def resolve_import(module_name):
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
return spec and spec.origin
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def convert_to_path(name):
|
||||
if name.startswith('.'):
|
||||
remainder = name.lstrip('.')
|
||||
@@ -39,58 +77,93 @@ def convert_to_path(name):
|
||||
return filename
|
||||
|
||||
|
||||
def resolve_relative_import(source_file_path, module_name):
|
||||
def resolve_relative_import(source_file_path, module_name, all_symbols):
|
||||
current_package = os.path.dirname(source_file_path).replace('/', '.')
|
||||
absolute_name = importlib.util.resolve_name(module_name,
|
||||
current_package) # get
|
||||
return resolve_absolute_import(absolute_name)
|
||||
return resolve_absolute_import(absolute_name, all_symbols)
|
||||
|
||||
|
||||
def onerror(name):
|
||||
logger.error('Importing module %s error!' % name)
|
||||
def resolve_absolute_import(module_name, all_symbols):
|
||||
# direct imports
|
||||
if module_name in all_symbols:
|
||||
return all_symbols[module_name]
|
||||
|
||||
# Some symbols are imported via a package __init__.py; we need to find the real file that defines the symbol.
|
||||
parent, sub = module_name.rsplit('.', 1)
|
||||
|
||||
# Case: module_name refers to a Python definition.
|
||||
for symbol, symbol_path in all_symbols.items():
|
||||
if symbol.startswith(parent) and symbol.endswith(sub):
|
||||
return all_symbols[symbol]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def resolve_absolute_import(module_name):
|
||||
module_file_path = resolve_import(module_name)
|
||||
if module_file_path is None:
|
||||
# find from base module.
|
||||
parent_module, sub_module = module_name.rsplit('.', 1)
|
||||
if parent_module in sys.modules:
|
||||
if hasattr(sys.modules[parent_module], '_import_structure'):
|
||||
import_structure = sys.modules[parent_module]._import_structure
|
||||
for k, v in import_structure.items():
|
||||
if sub_module in v:
|
||||
parent_module = parent_module + '.' + k
|
||||
break
|
||||
module_file_path = resolve_absolute_import(parent_module)
|
||||
# the parent_module is a package, we need find the module_name's file
|
||||
if os.path.basename(module_file_path) == '__init__.py' and \
|
||||
(os.path.relpath(module_file_path, site.getsitepackages()[0]) != 'modelscope/__init__.py'
|
||||
or os.path.relpath(module_file_path, os.getcwd()) != 'modelscope/__init__.py'):
|
||||
for _, sub_module_name, _ in pkgutil.walk_packages(
|
||||
[os.path.dirname(module_file_path)],
|
||||
parent_module + '.',
|
||||
onerror=onerror):
|
||||
try:
|
||||
module_ = importlib.import_module(sub_module_name)
|
||||
for k, v in module_.__dict__.items():
|
||||
if k == sub_module and v.__module__ == module_.__name__:
|
||||
module_file_path = module_.__file__
|
||||
break
|
||||
except ModuleNotFoundError as e:
|
||||
logger.warn(
|
||||
'Import error in %s, ModuleNotFoundError: %s' %
|
||||
(sub_module_name, e))
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warn('Import error in %s, Exception: %s' %
|
||||
(sub_module_name, e))
|
||||
continue
|
||||
class IndirectDefines(ast.NodeVisitor):
|
||||
"""Analysis source file function, class, global variable defines.
|
||||
"""
|
||||
|
||||
def __init__(self, source_file_path, all_symbols,
|
||||
file_symbols_map) -> None:
|
||||
super().__init__()
|
||||
self.symbols_map = {
|
||||
} # key: a symbol name in the current file; value: the path of the file that actually defines it.
|
||||
self.all_symbols = all_symbols
|
||||
self.file_symbols_map = file_symbols_map
|
||||
self.source_file_path = source_file_path
|
||||
|
||||
rel_file_path = source_file_path
|
||||
if os.path.isabs(source_file_path):
|
||||
rel_file_path = os.path.relpath(source_file_path, os.getcwd())
|
||||
|
||||
if rel_file_path.endswith('__init__.py'): # processing package
|
||||
self.base_module_name = os.path.dirname(rel_file_path).replace(
|
||||
'/', '.')
|
||||
else: # import x.y.z z is the filename
|
||||
self.base_module_name = rel_file_path.replace('/', '.').replace(
|
||||
'.py', '')
|
||||
|
||||
# A from-import binds the imported symbol in the current file,
# e.g. "from a import b" makes b available here.
|
||||
def visit_ImportFrom(self, node):
|
||||
# level 0 absolute import such as from os.path import join
|
||||
# level 1 from .x import y
|
||||
# level 2 from ..x import y
|
||||
module_name = '.' * node.level + (node.module or '')
|
||||
for alias in node.names:
|
||||
file_path = None
|
||||
if alias.name == '*': # from x import *
|
||||
if is_relative_import(module_name):
|
||||
# resolve model path.
|
||||
file_path = resolve_relative_import(
|
||||
self.source_file_path, module_name, self.all_symbols)
|
||||
elif module_name.startswith('modelscope'):
|
||||
file_path = resolve_absolute_import(
|
||||
module_name, self.all_symbols)
|
||||
else:
|
||||
file_path = None # ignore other package.
|
||||
if file_path is not None:
|
||||
for symbol in self.file_symbols_map[file_path][1:]:
|
||||
symbol_name = symbol.split('.')[-1]
|
||||
self.symbols_map[self.base_module_name
|
||||
+ symbol_name] = file_path
|
||||
else:
|
||||
return module_file_path
|
||||
else:
|
||||
module_file_path = resolve_absolute_import(parent_module)
|
||||
return module_file_path
|
||||
if not module_name.endswith('.'):
|
||||
module_name = module_name + '.'
|
||||
name = module_name + alias.name
|
||||
if alias.asname is not None:
|
||||
current_module_name = self.base_module_name + '.' + alias.asname
|
||||
else:
|
||||
current_module_name = self.base_module_name + '.' + alias.name
|
||||
if is_relative_import(name):
|
||||
# resolve model path.
|
||||
file_path = resolve_relative_import(
|
||||
self.source_file_path, name, self.all_symbols)
|
||||
elif name.startswith('modelscope'):
|
||||
file_path = resolve_absolute_import(name, self.all_symbols)
|
||||
if file_path is not None:
|
||||
self.symbols_map[current_module_name] = file_path
|
||||
|
||||
|
||||
class AnalysisSourceFileImports(ast.NodeVisitor):
|
||||
@@ -98,23 +171,19 @@ class AnalysisSourceFileImports(ast.NodeVisitor):
|
||||
List imports of the modelscope.
|
||||
"""
|
||||
|
||||
def __init__(self, source_file_path) -> None:
|
||||
def __init__(self, source_file_path, all_symbols) -> None:
|
||||
super().__init__()
|
||||
self.imports = []
|
||||
self.source_file_path = source_file_path
|
||||
self.all_symbols = all_symbols
|
||||
|
||||
def visit_Import(self, node):
|
||||
"""Processing import x,y,z or import os.path as osp"""
|
||||
for alias in node.names:
|
||||
if alias.name.startswith('modelscope'):
|
||||
file_path = resolve_absolute_import(alias.name)
|
||||
if file_path.startswith(site.getsitepackages()[0]):
|
||||
self.imports.append(
|
||||
os.path.relpath(file_path,
|
||||
site.getsitepackages()[0]))
|
||||
else:
|
||||
self.imports.append(
|
||||
os.path.relpath(file_path, os.getcwd()))
|
||||
file_path = resolve_absolute_import(alias.name,
|
||||
self.all_symbols)
|
||||
self.imports.append(os.path.relpath(file_path, os.getcwd()))
|
||||
|
||||
def visit_ImportFrom(self, node):
|
||||
# level 0 absolute import such as from os.path import join
|
||||
@@ -126,9 +195,10 @@ class AnalysisSourceFileImports(ast.NodeVisitor):
|
||||
if is_relative_import(module_name):
|
||||
# resolve model path.
|
||||
file_path = resolve_relative_import(
|
||||
self.source_file_path, module_name)
|
||||
self.source_file_path, module_name, self.all_symbols)
|
||||
elif module_name.startswith('modelscope'):
|
||||
file_path = resolve_absolute_import(module_name)
|
||||
file_path = resolve_absolute_import(
|
||||
module_name, self.all_symbols)
|
||||
else:
|
||||
file_path = None # ignore other package.
|
||||
else:
|
||||
@@ -138,9 +208,17 @@ class AnalysisSourceFileImports(ast.NodeVisitor):
|
||||
if is_relative_import(name):
|
||||
# resolve model path.
|
||||
file_path = resolve_relative_import(
|
||||
self.source_file_path, name)
|
||||
self.source_file_path, name, self.all_symbols)
|
||||
if file_path is None:
|
||||
logger.warning(
|
||||
'File: %s, import %s%s not exist!' %
|
||||
(self.source_file_path, module_name, alias.name))
|
||||
elif name.startswith('modelscope'):
|
||||
file_path = resolve_absolute_import(name)
|
||||
file_path = resolve_absolute_import(name, self.all_symbols)
|
||||
if file_path is None:
|
||||
logger.warning(
|
||||
'File: %s, import %s%s not exist!' %
|
||||
(self.source_file_path, module_name, alias.name))
|
||||
else:
|
||||
file_path = None # ignore other package.
|
||||
|
||||
@@ -152,6 +230,10 @@ class AnalysisSourceFileImports(ast.NodeVisitor):
|
||||
else:
|
||||
self.imports.append(
|
||||
os.path.relpath(file_path, os.getcwd()))
|
||||
elif module_name.startswith('modelscope'):
|
||||
logger.warning(
|
||||
'File: %s, import %s%s not exist!' %
|
||||
(self.source_file_path, module_name, alias.name))
|
||||
|
||||
|
||||
class AnalysisSourceFileRegisterModules(ast.NodeVisitor):
|
||||
@@ -216,14 +298,14 @@ class AnalysisSourceFileRegisterModules(ast.NodeVisitor):
|
||||
node.name)) # PIPELINES, task, module, class_name
|
||||
|
||||
|
||||
def get_imported_files(file_path):
|
||||
def get_imported_files(file_path, all_symbols):
|
||||
"""Get file dependencies.
|
||||
"""
|
||||
if os.path.isabs(file_path):
|
||||
file_path = os.path.relpath(file_path, os.getcwd())
|
||||
with open(file_path, 'rb') as f:
|
||||
src = f.read()
|
||||
analyzer = AnalysisSourceFileImports(file_path)
|
||||
analyzer = AnalysisSourceFileImports(file_path, all_symbols)
|
||||
analyzer.visit(ast.parse(src, filename=file_path))
|
||||
return list(set(analyzer.imports))
|
||||
|
||||
@@ -236,7 +318,6 @@ def path_to_module_name(file_path):
|
||||
|
||||
|
||||
def get_file_register_modules(file_path):
|
||||
logger.info('Get file: %s register_module' % file_path)
|
||||
with open(file_path, 'rb') as f:
|
||||
src = f.read()
|
||||
analyzer = AnalysisSourceFileRegisterModules(file_path)
|
||||
@@ -244,15 +325,51 @@ def get_file_register_modules(file_path):
|
||||
return analyzer.register_modules
|
||||
|
||||
|
||||
def get_file_defined_symbols(file_path):
|
||||
if os.path.isabs(file_path):
|
||||
file_path = os.path.relpath(file_path, os.getcwd())
|
||||
with open(file_path, 'rb') as f:
|
||||
src = f.read()
|
||||
analyzer = AnalysisSourceFileDefines(file_path)
|
||||
analyzer.visit(ast.parse(src, filename=file_path))
|
||||
return analyzer.symbols
|
||||
|
||||
|
||||
def get_indirect_symbols(file_path, symbols, file_symbols_map):
|
||||
if os.path.isabs(file_path):
|
||||
file_path = os.path.relpath(file_path, os.getcwd())
|
||||
with open(file_path, 'rb') as f:
|
||||
src = f.read()
|
||||
analyzer = IndirectDefines(file_path, symbols, file_symbols_map)
|
||||
analyzer.visit(ast.parse(src, filename=file_path))
|
||||
return analyzer.symbols_map
|
||||
|
||||
|
||||
def get_import_map():
|
||||
all_files = [
|
||||
os.path.join(dp, f) for dp, dn, filenames in os.walk(
|
||||
os.path.join(os.getcwd(), 'modelscope')) for f in filenames
|
||||
if os.path.splitext(f)[1] == '.py'
|
||||
]
|
||||
all_symbols = {}
|
||||
file_symbols_map = {}
|
||||
for f in all_files:
|
||||
file_path = os.path.relpath(f, os.getcwd())
|
||||
file_symbols_map[file_path] = get_file_defined_symbols(f)
|
||||
for s in file_symbols_map[file_path]:
|
||||
all_symbols[s] = file_path
|
||||
|
||||
# get indirect(imported) symbols, refer to origin define.
|
||||
for f in all_files:
|
||||
for name, real_path in get_indirect_symbols(f, all_symbols,
|
||||
file_symbols_map).items():
|
||||
all_symbols[name] = os.path.relpath(real_path, os.getcwd())
|
||||
|
||||
with open('symbols.json', 'w') as f:
|
||||
json.dump(all_symbols, f)
|
||||
import_map = {}
|
||||
for f in all_files:
|
||||
files = get_imported_files(f)
|
||||
files = get_imported_files(f, all_symbols)
|
||||
import_map[os.path.relpath(f, os.getcwd())] = files
|
||||
|
||||
return import_map
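To make the symbol map concrete, here is a tiny standalone illustration of the same pattern: record 'module.name' keys for a file's top-level definitions with an ast.NodeVisitor, then resolve a dotted import by a dictionary lookup (the module name and file path are invented for the example):

import ast

SRC = '''
GREETING = "hi"
def foo(): ...
class Bar: ...
'''

class Defines(ast.NodeVisitor):
    def __init__(self, module_name):
        self.module_name = module_name
        self.symbols = [module_name]

    def visit_ClassDef(self, node):       # class Bar  -> pkg.mod.Bar
        self.symbols.append(f'{self.module_name}.{node.name}')

    def visit_FunctionDef(self, node):    # def foo    -> pkg.mod.foo
        self.symbols.append(f'{self.module_name}.{node.name}')

    def visit_Assign(self, node):         # GREETING = -> pkg.mod.GREETING
        for tg in node.targets:
            if isinstance(tg, ast.Name):
                self.symbols.append(f'{self.module_name}.{tg.id}')

analyzer = Defines('pkg.mod')
analyzer.visit(ast.parse(SRC))
all_symbols = {s: 'pkg/mod.py' for s in analyzer.symbols}
print(all_symbols.get('pkg.mod.Bar'))  # -> 'pkg/mod.py'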