Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
add text-to-360pano-image pipeline, mod cv requirements
A self-developed text-to-360°-panorama image generation model planned for launch in July. Model weights: https://www.modelscope.cn/models/damo/cv_diffusion_text-to-360panorama-image_generation/summary

#### Dependency notes

##### Because xformers is required, torch 1.13.1 is recommended
```
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
```

##### Matching diffusers and xformers versions
```
pip install -U diffusers==0.18.0
pip install xformers==0.0.16
pip install triton accelerate transformers
```

##### ModelScope Library with the cv extra
```
pip install modelscope
pip install "modelscope[cv]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
```

##### In addition, the third-party Real-ESRGAN package is required; install it as follows
```
# Install basicsr - https://github.com/xinntao/BasicSR
# We use BasicSR for both training and inference
pip install basicsr
# facexlib and gfpgan are for face enhancement
pip install facexlib
pip install gfpgan
pip install Pillow
pip install tqdm
pip install realesrgan==0.3.0
```

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13346430

* add text-to-360pano-image pipeline
* add text-to-360pano-image pipeline, mod cv requirements
* rm redundant files and cv requirements; add standard input and output definitions
* fix diffusers==0.18.0 and run test
* fix diffusers==0.18.0 in multi-modal and run test again
* add model_revision='v1.0.0'
* fix yapf
* add trycatch for enabling xformers
* fix key error
* add install xformers in test/setup
* skip highres.fix in ci
* feat: Fix conflict, auto commit by WebIDE
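For reference, a minimal usage sketch of the new pipeline once the dependencies above are installed. It follows the pipeline docstring and unit test added in this commit and assumes a CUDA GPU plus access to the ModelScope hub:

```
import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline from the hub model; revision v1.0.0 is the one pinned in this commit.
txt2panoimg = pipeline(
    Tasks.text_to_360panorama_image,
    model='damo/cv_diffusion_text-to-360panorama-image_generation',
    model_revision='v1.0.0')

# 'upscale' and 'refinement' are optional (both default to True); disabling them is faster.
output = txt2panoimg({
    'prompt': 'The mountains',
    'upscale': False,
    'refinement': False
})[OutputKeys.OUTPUT_IMG]

# The pipeline returns a BGR numpy array, so it can be written directly with OpenCV.
cv2.imwrite('result.png', output)
```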
Committed by: wenmeng.zwm
Parent: a780935317
Commit: 18f998a85c
@@ -214,6 +214,7 @@ class Models(object):
     mplug_owl = 'mplug-owl'
     clip_interrogator = 'clip-interrogator'
     stable_diffusion = 'stable-diffusion'
+    text_to_360panorama_image = 'text-to-360panorama-image'

     # science models
     unifold = 'unifold'
@@ -417,6 +418,7 @@ class Pipelines(object):
     vision_efficient_tuning = 'vision-efficient-tuning'
     image_bts_depth_estimation = 'image-bts-depth-estimation'
     pedestrian_attribute_recognition = 'resnet50_pedestrian-attribute-recognition_image'
+    text_to_360panorama_image = 'text-to-360panorama-image'
     image_try_on = 'image-try-on'

     # nlp tasks
@@ -857,6 +859,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.pedestrian_attribute_recognition: (
         Pipelines.pedestrian_attribute_recognition,
         'damo/cv_resnet50_pedestrian-attribute-recognition_image'),
+    Tasks.text_to_360panorama_image: (
+        Pipelines.text_to_360panorama_image,
+        'damo/cv_diffusion_text-to-360panorama-image_generation'),
     Tasks.image_try_on: (Pipelines.image_try_on,
                          'damo/cv_SAL-VTON_virtual-try-on')
 }
modelscope/models/cv/text_to_360panorama_image/__init__.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .pipeline_base import StableDiffusionBlendExtendPipeline
+    from .pipeline_sr import StableDiffusionControlNetImg2ImgPanoPipeline
+
+else:
+    _import_structure = {
+        'pipeline_base': ['StableDiffusionBlendExtendPipeline'],
+        'pipeline_sr': ['StableDiffusionControlNetImg2ImgPanoPipeline'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
modelscope/models/cv/text_to_360panorama_image/pipeline_base.py (new file, 855 lines)
@@ -0,0 +1,855 @@
|
|||||||
|
# Copyright © Alibaba, Inc. and its affiliates.
|
||||||
|
# The implementation here is modified from diffusers.StableDiffusionPipeline,
# originally under the Apache 2.0 License and publicly available at
|
||||||
|
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import re
|
||||||
|
import warnings
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers import (AutoencoderKL, DiffusionPipeline,
|
||||||
|
StableDiffusionPipeline, UNet2DConditionModel)
|
||||||
|
from diffusers.configuration_utils import FrozenDict
|
||||||
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
|
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||||
|
from diffusers.models.vae import DecoderOutput
|
||||||
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
||||||
|
StableDiffusionSafetyChecker
|
||||||
|
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||||
|
from diffusers.utils import (deprecate, is_accelerate_available,
|
||||||
|
is_accelerate_version, logging, randn_tensor,
|
||||||
|
replace_example_docstring)
|
||||||
|
from packaging import version
|
||||||
|
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
EXAMPLE_DOC_STRING = """
|
||||||
|
Examples:
|
||||||
|
```py
|
||||||
|
>>> import torch
|
||||||
|
>>> from diffusers import EulerAncestralDiscreteScheduler
|
||||||
|
>>> modelscope.models.cv.text_to_360panorama_image import StableDiffusionBlendExtendPipeline
|
||||||
|
>>> model_id = "damo/cv_diffusion_text-to-360panorama-image_generation/sd-base"
|
||||||
|
>>> pipe = StableDiffusionBlendExtendPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
||||||
|
>>> pipe = pipe.to("cuda")
|
||||||
|
>>> pipe.vae.enable_tiling()
|
||||||
|
>>> pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||||
|
>>> # remove following line if xformers is not installed
|
||||||
|
>>> pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
>>> pipe.enable_model_cpu_offload()
|
||||||
|
>>> prompt = "a living room"
|
||||||
|
>>> image = pipe(prompt).images[0]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
re_attention = re.compile(
|
||||||
|
r"""
|
||||||
|
\\\(|
|
||||||
|
\\\)|
|
||||||
|
\\\[|
|
||||||
|
\\]|
|
||||||
|
\\\\|
|
||||||
|
\\|
|
||||||
|
\(|
|
||||||
|
\[|
|
||||||
|
:([+-]?[.\d]+)\)|
|
||||||
|
\)|
|
||||||
|
]|
|
||||||
|
[^\\()\[\]:]+|
|
||||||
|
:
|
||||||
|
""",
|
||||||
|
re.X,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_prompt_attention(text):
|
||||||
|
"""
|
||||||
|
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
||||||
|
Accepted tokens are:
|
||||||
|
(abc) - increases attention to abc by a multiplier of 1.1
|
||||||
|
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||||
|
[abc] - decreases attention to abc by a multiplier of 1.1
|
||||||
|
"""
|
||||||
|
|
||||||
|
res = []
|
||||||
|
round_brackets = []
|
||||||
|
square_brackets = []
|
||||||
|
|
||||||
|
round_bracket_multiplier = 1.1
|
||||||
|
square_bracket_multiplier = 1 / 1.1
|
||||||
|
|
||||||
|
def multiply_range(start_position, multiplier):
|
||||||
|
for p in range(start_position, len(res)):
|
||||||
|
res[p][1] *= multiplier
|
||||||
|
|
||||||
|
for m in re_attention.finditer(text):
|
||||||
|
text = m.group(0)
|
||||||
|
weight = m.group(1)
|
||||||
|
|
||||||
|
if text.startswith('\\'):
|
||||||
|
res.append([text[1:], 1.0])
|
||||||
|
elif text == '(':
|
||||||
|
round_brackets.append(len(res))
|
||||||
|
elif text == '[':
|
||||||
|
square_brackets.append(len(res))
|
||||||
|
elif weight is not None and len(round_brackets) > 0:
|
||||||
|
multiply_range(round_brackets.pop(), float(weight))
|
||||||
|
elif text == ')' and len(round_brackets) > 0:
|
||||||
|
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
||||||
|
elif text == ']' and len(square_brackets) > 0:
|
||||||
|
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||||
|
else:
|
||||||
|
res.append([text, 1.0])
|
||||||
|
|
||||||
|
for pos in round_brackets:
|
||||||
|
multiply_range(pos, round_bracket_multiplier)
|
||||||
|
|
||||||
|
for pos in square_brackets:
|
||||||
|
multiply_range(pos, square_bracket_multiplier)
|
||||||
|
|
||||||
|
if len(res) == 0:
|
||||||
|
res = [['', 1.0]]
|
||||||
|
|
||||||
|
# merge runs of identical weights
|
||||||
|
i = 0
|
||||||
|
while i + 1 < len(res):
|
||||||
|
if res[i][1] == res[i + 1][1]:
|
||||||
|
res[i][0] += res[i + 1][0]
|
||||||
|
res.pop(i + 1)
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return res
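# Editor's note (illustrative example, not part of the original commit):
#   parse_prompt_attention('a (very beautiful) scenery')
#   -> [['a ', 1.0], ['very beautiful', 1.1], [' scenery', 1.0]]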
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str],
|
||||||
|
max_length: int):
|
||||||
|
r"""
|
||||||
|
Tokenize a list of prompts and return its tokens with weights of each token.
|
||||||
|
|
||||||
|
No padding, starting or ending token is included.
|
||||||
|
"""
|
||||||
|
tokens = []
|
||||||
|
weights = []
|
||||||
|
truncated = False
|
||||||
|
for text in prompt:
|
||||||
|
texts_and_weights = parse_prompt_attention(text)
|
||||||
|
text_token = []
|
||||||
|
text_weight = []
|
||||||
|
for word, weight in texts_and_weights:
|
||||||
|
# tokenize and discard the starting and the ending token
|
||||||
|
token = pipe.tokenizer(word).input_ids[1:-1]
|
||||||
|
text_token += token
|
||||||
|
# copy the weight by length of token
|
||||||
|
text_weight += [weight] * len(token)
|
||||||
|
# stop if the text is too long (longer than truncation limit)
|
||||||
|
if len(text_token) > max_length:
|
||||||
|
truncated = True
|
||||||
|
break
|
||||||
|
# truncate
|
||||||
|
if len(text_token) > max_length:
|
||||||
|
truncated = True
|
||||||
|
text_token = text_token[:max_length]
|
||||||
|
text_weight = text_weight[:max_length]
|
||||||
|
tokens.append(text_token)
|
||||||
|
weights.append(text_weight)
|
||||||
|
if truncated:
|
||||||
|
logger.warning(
|
||||||
|
'Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples'
|
||||||
|
)
|
||||||
|
return tokens, weights
|
||||||
|
|
||||||
|
|
||||||
|
def pad_tokens_and_weights(tokens,
|
||||||
|
weights,
|
||||||
|
max_length,
|
||||||
|
bos,
|
||||||
|
eos,
|
||||||
|
pad,
|
||||||
|
no_boseos_middle=True,
|
||||||
|
chunk_length=77):
|
||||||
|
r"""
|
||||||
|
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
||||||
|
"""
|
||||||
|
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
||||||
|
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
||||||
|
for i in range(len(tokens)):
|
||||||
|
tokens[i] = [
|
||||||
|
bos
|
||||||
|
] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
|
||||||
|
if no_boseos_middle:
|
||||||
|
weights[i] = [1.0] + weights[i] + [1.0] * (
|
||||||
|
max_length - 1 - len(weights[i]))
|
||||||
|
else:
|
||||||
|
w = []
|
||||||
|
if len(weights[i]) == 0:
|
||||||
|
w = [1.0] * weights_length
|
||||||
|
else:
|
||||||
|
for j in range(max_embeddings_multiples):
|
||||||
|
w.append(1.0) # weight for starting token in this chunk
|
||||||
|
w += weights[i][j * (chunk_length - 2):min(
|
||||||
|
len(weights[i]), (j + 1) * (chunk_length - 2))]
|
||||||
|
w.append(1.0) # weight for ending token in this chunk
|
||||||
|
w += [1.0] * (weights_length - len(w))
|
||||||
|
weights[i] = w[:]
|
||||||
|
|
||||||
|
return tokens, weights
|
||||||
|
|
||||||
|
|
||||||
|
def get_unweighted_text_embeddings(
|
||||||
|
pipe: DiffusionPipeline,
|
||||||
|
text_input: torch.Tensor,
|
||||||
|
chunk_length: int,
|
||||||
|
no_boseos_middle: Optional[bool] = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
When the length of tokens is a multiple of the capacity of the text encoder,
|
||||||
|
it should be split into chunks and sent to the text encoder individually.
|
||||||
|
"""
|
||||||
|
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
||||||
|
if max_embeddings_multiples > 1:
|
||||||
|
text_embeddings = []
|
||||||
|
for i in range(max_embeddings_multiples):
|
||||||
|
# extract the i-th chunk
|
||||||
|
text_input_chunk = text_input[:, i * (chunk_length - 2):(i + 1)
|
||||||
|
* (chunk_length - 2) + 2].clone()
|
||||||
|
|
||||||
|
# cover the head and the tail by the starting and the ending tokens
|
||||||
|
text_input_chunk[:, 0] = text_input[0, 0]
|
||||||
|
text_input_chunk[:, -1] = text_input[0, -1]
|
||||||
|
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
||||||
|
|
||||||
|
if no_boseos_middle:
|
||||||
|
if i == 0:
|
||||||
|
# discard the ending token
|
||||||
|
text_embedding = text_embedding[:, :-1]
|
||||||
|
elif i == max_embeddings_multiples - 1:
|
||||||
|
# discard the starting token
|
||||||
|
text_embedding = text_embedding[:, 1:]
|
||||||
|
else:
|
||||||
|
# discard both starting and ending tokens
|
||||||
|
text_embedding = text_embedding[:, 1:-1]
|
||||||
|
|
||||||
|
text_embeddings.append(text_embedding)
|
||||||
|
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||||
|
else:
|
||||||
|
text_embeddings = pipe.text_encoder(text_input)[0]
|
||||||
|
return text_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def get_weighted_text_embeddings(
|
||||||
|
pipe: DiffusionPipeline,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
max_embeddings_multiples: Optional[int] = 3,
|
||||||
|
no_boseos_middle: Optional[bool] = False,
|
||||||
|
skip_parsing: Optional[bool] = False,
|
||||||
|
skip_weighting: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Prompts can be assigned with local weights using brackets. For example,
|
||||||
|
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
||||||
|
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
||||||
|
|
||||||
|
Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pipe (`DiffusionPipeline`):
|
||||||
|
Pipe to provide access to the tokenizer and the text encoder.
|
||||||
|
prompt (`str` or `List[str]`):
|
||||||
|
The prompt or prompts to guide the image generation.
|
||||||
|
uncond_prompt (`str` or `List[str]`):
|
||||||
|
The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
|
||||||
|
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
||||||
|
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
||||||
|
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
||||||
|
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
||||||
|
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
||||||
|
ending token in each of the chunk in the middle.
|
||||||
|
skip_parsing (`bool`, *optional*, defaults to `False`):
|
||||||
|
Skip the parsing of brackets.
|
||||||
|
skip_weighting (`bool`, *optional*, defaults to `False`):
|
||||||
|
Skip the weighting. When the parsing is skipped, it is forced True.
|
||||||
|
"""
|
||||||
|
max_length = (pipe.tokenizer.model_max_length
|
||||||
|
- 2) * max_embeddings_multiples + 2
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
prompt = [prompt]
|
||||||
|
|
||||||
|
if not skip_parsing:
|
||||||
|
prompt_tokens, prompt_weights = get_prompts_with_weights(
|
||||||
|
pipe, prompt, max_length - 2)
|
||||||
|
if uncond_prompt is not None:
|
||||||
|
if isinstance(uncond_prompt, str):
|
||||||
|
uncond_prompt = [uncond_prompt]
|
||||||
|
uncond_tokens, uncond_weights = get_prompts_with_weights(
|
||||||
|
pipe, uncond_prompt, max_length - 2)
|
||||||
|
else:
|
||||||
|
prompt_tokens = [
|
||||||
|
token[1:-1] for token in pipe.tokenizer(
|
||||||
|
prompt, max_length=max_length, truncation=True).input_ids
|
||||||
|
]
|
||||||
|
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
||||||
|
if uncond_prompt is not None:
|
||||||
|
if isinstance(uncond_prompt, str):
|
||||||
|
uncond_prompt = [uncond_prompt]
|
||||||
|
uncond_tokens = [
|
||||||
|
token[1:-1] for token in pipe.tokenizer(
|
||||||
|
uncond_prompt, max_length=max_length,
|
||||||
|
truncation=True).input_ids
|
||||||
|
]
|
||||||
|
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
||||||
|
|
||||||
|
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
||||||
|
max_length = max([len(token) for token in prompt_tokens])
|
||||||
|
if uncond_prompt is not None:
|
||||||
|
max_length = max(max_length,
|
||||||
|
max([len(token) for token in uncond_tokens]))
|
||||||
|
|
||||||
|
max_embeddings_multiples = min(
|
||||||
|
max_embeddings_multiples,
|
||||||
|
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
|
||||||
|
)
|
||||||
|
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
||||||
|
max_length = (pipe.tokenizer.model_max_length
|
||||||
|
- 2) * max_embeddings_multiples + 2
|
||||||
|
|
||||||
|
# pad the length of tokens and weights
|
||||||
|
bos = pipe.tokenizer.bos_token_id
|
||||||
|
eos = pipe.tokenizer.eos_token_id
|
||||||
|
pad = getattr(pipe.tokenizer, 'pad_token_id', eos)
|
||||||
|
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
||||||
|
prompt_tokens,
|
||||||
|
prompt_weights,
|
||||||
|
max_length,
|
||||||
|
bos,
|
||||||
|
eos,
|
||||||
|
pad,
|
||||||
|
no_boseos_middle=no_boseos_middle,
|
||||||
|
chunk_length=pipe.tokenizer.model_max_length,
|
||||||
|
)
|
||||||
|
prompt_tokens = torch.tensor(
|
||||||
|
prompt_tokens, dtype=torch.long, device=pipe.device)
|
||||||
|
if uncond_prompt is not None:
|
||||||
|
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
||||||
|
uncond_tokens,
|
||||||
|
uncond_weights,
|
||||||
|
max_length,
|
||||||
|
bos,
|
||||||
|
eos,
|
||||||
|
pad,
|
||||||
|
no_boseos_middle=no_boseos_middle,
|
||||||
|
chunk_length=pipe.tokenizer.model_max_length,
|
||||||
|
)
|
||||||
|
uncond_tokens = torch.tensor(
|
||||||
|
uncond_tokens, dtype=torch.long, device=pipe.device)
|
||||||
|
|
||||||
|
# get the embeddings
|
||||||
|
text_embeddings = get_unweighted_text_embeddings(
|
||||||
|
pipe,
|
||||||
|
prompt_tokens,
|
||||||
|
pipe.tokenizer.model_max_length,
|
||||||
|
no_boseos_middle=no_boseos_middle,
|
||||||
|
)
|
||||||
|
prompt_weights = torch.tensor(
|
||||||
|
prompt_weights,
|
||||||
|
dtype=text_embeddings.dtype,
|
||||||
|
device=text_embeddings.device)
|
||||||
|
if uncond_prompt is not None:
|
||||||
|
uncond_embeddings = get_unweighted_text_embeddings(
|
||||||
|
pipe,
|
||||||
|
uncond_tokens,
|
||||||
|
pipe.tokenizer.model_max_length,
|
||||||
|
no_boseos_middle=no_boseos_middle,
|
||||||
|
)
|
||||||
|
uncond_weights = torch.tensor(
|
||||||
|
uncond_weights,
|
||||||
|
dtype=uncond_embeddings.dtype,
|
||||||
|
device=uncond_embeddings.device)
|
||||||
|
|
||||||
|
# assign weights to the prompts and normalize in the sense of mean
|
||||||
|
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
||||||
|
if (not skip_parsing) and (not skip_weighting):
|
||||||
|
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(
|
||||||
|
text_embeddings.dtype)
|
||||||
|
text_embeddings *= prompt_weights.unsqueeze(-1)
|
||||||
|
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(
|
||||||
|
text_embeddings.dtype)
|
||||||
|
text_embeddings *= (previous_mean
|
||||||
|
/ current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
if uncond_prompt is not None:
|
||||||
|
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(
|
||||||
|
uncond_embeddings.dtype)
|
||||||
|
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
||||||
|
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(
|
||||||
|
uncond_embeddings.dtype)
|
||||||
|
uncond_embeddings *= (previous_mean
|
||||||
|
/ current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
|
if uncond_prompt is not None:
|
||||||
|
return text_embeddings, uncond_embeddings
|
||||||
|
return text_embeddings, None
|
||||||
|
|
||||||
|
|
||||||
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||||
|
"""
|
||||||
|
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||||
|
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
||||||
|
"""
|
||||||
|
std_text = noise_pred_text.std(
|
||||||
|
dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||||
|
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||||
|
# rescale the results from guidance (fixes overexposure)
|
||||||
|
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||||
|
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||||
|
noise_cfg = guidance_rescale * noise_pred_rescaled + (
|
||||||
|
1 - guidance_rescale) * noise_cfg
|
||||||
|
return noise_cfg
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionBlendExtendPipeline(StableDiffusionPipeline):
|
||||||
|
r"""
|
||||||
|
Pipeline for text-to-image generation using Stable Diffusion.
|
||||||
|
|
||||||
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||||
|
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||||
|
|
||||||
|
In addition the pipeline inherits the following loading methods:
|
||||||
|
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||||
|
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||||
|
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||||
|
|
||||||
|
as well as the following saving methods:
|
||||||
|
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vae ([`AutoencoderKL`]):
|
||||||
|
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||||
|
text_encoder ([`CLIPTextModel`]):
|
||||||
|
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||||
|
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||||
|
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||||
|
tokenizer (`CLIPTokenizer`):
|
||||||
|
Tokenizer of class
|
||||||
|
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/
|
||||||
|
en/model_doc/clip#transformers.CLIPTokenizer).
|
||||||
|
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||||
|
scheduler ([`SchedulerMixin`]):
|
||||||
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||||
|
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||||
|
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||||
|
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||||
|
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||||
|
feature_extractor ([`CLIPImageProcessor`]):
|
||||||
|
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||||
|
"""
|
||||||
|
_optional_components = ['safety_checker', 'feature_extractor']
|
||||||
|
|
||||||
|
def _encode_prompt(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
device,
|
||||||
|
num_images_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
negative_prompt=None,
|
||||||
|
max_embeddings_multiples=3,
|
||||||
|
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
lora_scale: Optional[float] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Encodes the prompt into text encoder hidden states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str` or `list(int)`):
|
||||||
|
prompt to be encoded
|
||||||
|
device: (`torch.device`):
|
||||||
|
torch device
|
||||||
|
num_images_per_prompt (`int`):
|
||||||
|
number of images that should be generated per prompt
|
||||||
|
do_classifier_free_guidance (`bool`):
|
||||||
|
whether to use classifier free guidance or not
|
||||||
|
negative_prompt (`str` or `List[str]`):
|
||||||
|
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||||
|
if `guidance_scale` is less than `1`).
|
||||||
|
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
||||||
|
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
||||||
|
"""
|
||||||
|
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||||
|
self._lora_scale = lora_scale
|
||||||
|
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
if negative_prompt_embeds is None:
|
||||||
|
if negative_prompt is None:
|
||||||
|
negative_prompt = [''] * batch_size
|
||||||
|
elif isinstance(negative_prompt, str):
|
||||||
|
negative_prompt = [negative_prompt] * batch_size
|
||||||
|
if batch_size != len(negative_prompt):
|
||||||
|
raise ValueError(
|
||||||
|
f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:'
|
||||||
|
f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches'
|
||||||
|
' the batch size of `prompt`.')
|
||||||
|
if prompt_embeds is None or negative_prompt_embeds is None:
|
||||||
|
if isinstance(self, TextualInversionLoaderMixin):
|
||||||
|
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||||
|
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||||
|
negative_prompt = self.maybe_convert_prompt(
|
||||||
|
negative_prompt, self.tokenizer)
|
||||||
|
|
||||||
|
prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
|
||||||
|
pipe=self,
|
||||||
|
prompt=prompt,
|
||||||
|
uncond_prompt=negative_prompt
|
||||||
|
if do_classifier_free_guidance else None,
|
||||||
|
max_embeddings_multiples=max_embeddings_multiples,
|
||||||
|
)
|
||||||
|
if prompt_embeds is None:
|
||||||
|
prompt_embeds = prompt_embeds1
|
||||||
|
if negative_prompt_embeds is None:
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds1
|
||||||
|
|
||||||
|
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||||
|
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||||
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
|
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt,
|
||||||
|
seq_len, -1)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
bs_embed, seq_len, _ = negative_prompt_embeds.shape
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
||||||
|
1, num_images_per_prompt, 1)
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.view(
|
||||||
|
bs_embed * num_images_per_prompt, seq_len, -1)
|
||||||
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||||
|
|
||||||
|
return prompt_embeds
|
||||||
|
|
||||||
|
def blend_v(self, a, b, blend_extent):
|
||||||
|
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
||||||
|
for y in range(blend_extent):
|
||||||
|
b[:, :,
|
||||||
|
y, :] = a[:, :, -blend_extent
|
||||||
|
+ y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
|
||||||
|
y / blend_extent)
|
||||||
|
return b
|
||||||
|
|
||||||
|
def blend_h(self, a, b, blend_extent):
|
||||||
|
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||||
|
for x in range(blend_extent):
|
||||||
|
b[:, :, :, x] = a[:, :, :, -blend_extent
|
||||||
|
+ x] * (1 - x / blend_extent) + b[:, :, :, x] * (
|
||||||
|
x / blend_extent)
|
||||||
|
return b
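# Editor's note: blend_v/blend_h linearly cross-fade the first rows/columns of `b`
# with the last rows/columns of `a` over `blend_extent` pixels. Calling blend_h with
# a == b (as done after the denoising loop below) stitches the panorama's left and
# right edges so the 360° image wraps around without a visible seam.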
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
guidance_scale: float = 7.5,
|
||||||
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
num_images_per_prompt: Optional[int] = 1,
|
||||||
|
eta: float = 0.0,
|
||||||
|
generator: Optional[Union[torch.Generator,
|
||||||
|
List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.FloatTensor] = None,
|
||||||
|
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
output_type: Optional[str] = 'pil',
|
||||||
|
return_dict: bool = True,
|
||||||
|
callback: Optional[Callable[[int, int, torch.FloatTensor],
|
||||||
|
None]] = None,
|
||||||
|
callback_steps: int = 1,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
guidance_rescale: float = 0.0,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Function invoked when calling the pipeline for generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||||
|
instead.
|
||||||
|
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||||
|
The height in pixels of the generated image.
|
||||||
|
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||||
|
The width in pixels of the generated image.
|
||||||
|
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||||
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
|
expense of slower inference.
|
||||||
|
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||||
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||||
|
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||||
|
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||||
|
usually at the expense of lower image quality.
|
||||||
|
negative_prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||||
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||||
|
less than `1`).
|
||||||
|
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||||
|
The number of images to generate per prompt.
|
||||||
|
eta (`float`, *optional*, defaults to 0.0):
|
||||||
|
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||||
|
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||||
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
|
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||||
|
to make generation deterministic.
|
||||||
|
latents (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||||
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||||
|
tensor will be generated by sampling using the supplied random `generator`.
|
||||||
|
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||||
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||||
|
argument.
|
||||||
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||||
|
The output format of the generated image. Choose between
|
||||||
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||||
|
plain tuple.
|
||||||
|
callback (`Callable`, *optional*):
|
||||||
|
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||||
|
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||||
|
callback_steps (`int`, *optional*, defaults to 1):
|
||||||
|
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||||
|
called at every step.
|
||||||
|
cross_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||||
|
`self.processor` in
|
||||||
|
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||||
|
guidance_rescale (`float`, *optional*, defaults to 0.7):
|
||||||
|
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||||
|
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
||||||
|
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||||
|
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||||
|
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||||
|
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||||
|
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||||
|
(nsfw) content, according to the `safety_checker`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def tiled_decode(
|
||||||
|
self,
|
||||||
|
z: torch.FloatTensor,
|
||||||
|
return_dict: bool = True
|
||||||
|
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||||
|
r"""Decode a batch of images using a tiled decoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
|
||||||
|
steps. This is useful to keep memory use constant regardless of image size.
|
||||||
|
The end result of tiled decoding is: different from non-tiled decoding due to each tile using a different
|
||||||
|
decoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output.
|
||||||
|
You may still see tile-sized changes in the look of the output, but they should be much less noticeable.
|
||||||
|
z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to
|
||||||
|
`True`):
|
||||||
|
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
_tile_overlap_factor = 1 - self.tile_overlap_factor
|
||||||
|
overlap_size = int(self.tile_latent_min_size
|
||||||
|
* _tile_overlap_factor)
|
||||||
|
blend_extent = int(self.tile_sample_min_size
|
||||||
|
* self.tile_overlap_factor)
|
||||||
|
row_limit = self.tile_sample_min_size - blend_extent
|
||||||
|
w = z.shape[3]
|
||||||
|
z = torch.cat([z, z[:, :, :, :w // 4]], dim=-1)
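# Editor's note: the latent is extended by a copy of its first quarter-width so that
# tiles decoded at the right edge overlap the left edge; together with blend_h below,
# this keeps the decoded panorama consistent across the wrap-around boundary.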
|
||||||
|
# Split z into overlapping 64x64 tiles and decode them separately.
|
||||||
|
# The tiles have an overlap to avoid seams between tiles.
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for i in range(0, z.shape[2], overlap_size):
|
||||||
|
row = []
|
||||||
|
tile = z[:, :, i:i + self.tile_latent_min_size, :]
|
||||||
|
tile = self.post_quant_conv(tile)
|
||||||
|
decoded = self.decoder(tile)
|
||||||
|
vae_scale_factor = decoded.shape[-1] // tile.shape[-1]
|
||||||
|
row.append(decoded)
|
||||||
|
rows.append(row)
|
||||||
|
result_rows = []
|
||||||
|
for i, row in enumerate(rows):
|
||||||
|
result_row = []
|
||||||
|
for j, tile in enumerate(row):
|
||||||
|
# blend the above tile and the left tile
|
||||||
|
# to the current tile and add the current tile to the result row
|
||||||
|
if i > 0:
|
||||||
|
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||||
|
if j > 0:
|
||||||
|
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||||
|
result_row.append(
|
||||||
|
self.blend_h(
|
||||||
|
tile[:, :, :row_limit, w * vae_scale_factor:],
|
||||||
|
tile[:, :, :row_limit, :w * vae_scale_factor],
|
||||||
|
tile.shape[-1] - w * vae_scale_factor))
|
||||||
|
result_rows.append(torch.cat(result_row, dim=3))
|
||||||
|
|
||||||
|
dec = torch.cat(result_rows, dim=2)
|
||||||
|
if not return_dict:
|
||||||
|
return (dec, )
|
||||||
|
|
||||||
|
return DecoderOutput(sample=dec)
|
||||||
|
|
||||||
|
self.vae.tiled_decode = tiled_decode.__get__(self.vae, AutoencoderKL)
|
||||||
|
|
||||||
|
# 0. Default height and width to unet
|
||||||
|
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||||
|
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||||
|
|
||||||
|
# 1. Check inputs. Raise error if not correct
|
||||||
|
self.check_inputs(prompt, height, width, callback_steps,
|
||||||
|
negative_prompt, prompt_embeds,
|
||||||
|
negative_prompt_embeds)
|
||||||
|
self.blend_extend = width // self.vae_scale_factor // 32
|
||||||
|
|
||||||
|
# 2. Define call parameters
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
device = self._execution_device
|
||||||
|
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||||
|
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||||
|
# corresponds to doing no classifier free guidance.
|
||||||
|
do_classifier_free_guidance = guidance_scale > 1.0
|
||||||
|
|
||||||
|
# 3. Encode input prompt
|
||||||
|
text_encoder_lora_scale = (
|
||||||
|
cross_attention_kwargs.get('scale', None)
|
||||||
|
if cross_attention_kwargs is not None else None)
|
||||||
|
prompt_embeds = self._encode_prompt(
|
||||||
|
prompt,
|
||||||
|
device,
|
||||||
|
num_images_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
negative_prompt,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
|
lora_scale=text_encoder_lora_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Prepare timesteps
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
|
timesteps = self.scheduler.timesteps
|
||||||
|
|
||||||
|
# 5. Prepare latent variables
|
||||||
|
num_channels_latents = self.unet.config.in_channels
|
||||||
|
latents = self.prepare_latents(
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
prompt_embeds.dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||||
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||||
|
|
||||||
|
# 7. Denoising loop
|
||||||
|
num_warmup_steps = len(
|
||||||
|
timesteps) - num_inference_steps * self.scheduler.order
|
||||||
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
# expand the latents if we are doing classifier free guidance
|
||||||
|
latent_model_input = torch.cat(
|
||||||
|
[latents] * 2) if do_classifier_free_guidance else latents
|
||||||
|
latent_model_input = self.scheduler.scale_model_input(
|
||||||
|
latent_model_input, t)
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
noise_pred = self.unet(
|
||||||
|
latent_model_input,
|
||||||
|
t,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# perform guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||||
|
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||||
|
noise_pred_text - noise_pred_uncond)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||||
|
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
|
noise_pred = rescale_noise_cfg(
|
||||||
|
noise_pred,
|
||||||
|
noise_pred_text,
|
||||||
|
guidance_rescale=guidance_rescale)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents = self.scheduler.step(
|
||||||
|
noise_pred,
|
||||||
|
t,
|
||||||
|
latents,
|
||||||
|
**extra_step_kwargs,
|
||||||
|
return_dict=False)[0]
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
condition_i = i == len(timesteps) - 1
|
||||||
|
condition_warm = (i + 1) > num_warmup_steps and (
|
||||||
|
i + 1) % self.scheduler.order == 0
|
||||||
|
if condition_i or condition_warm:
|
||||||
|
progress_bar.update()
|
||||||
|
if callback is not None and i % callback_steps == 0:
|
||||||
|
callback(i, t, latents)
|
||||||
|
latents = self.blend_h(latents, latents, self.blend_extend)
|
||||||
|
latents = self.blend_h(latents, latents, self.blend_extend)
|
||||||
|
latents = latents[:, :, :, :width // self.vae_scale_factor]
|
||||||
|
|
||||||
|
if not output_type == 'latent':
|
||||||
|
image = self.vae.decode(
|
||||||
|
latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||||
|
image, has_nsfw_concept = self.run_safety_checker(
|
||||||
|
image, device, prompt_embeds.dtype)
|
||||||
|
else:
|
||||||
|
image = latents
|
||||||
|
has_nsfw_concept = None
|
||||||
|
|
||||||
|
if has_nsfw_concept is None:
|
||||||
|
do_denormalize = [True] * image.shape[0]
|
||||||
|
else:
|
||||||
|
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||||
|
|
||||||
|
image = self.image_processor.postprocess(
|
||||||
|
image, output_type=output_type, do_denormalize=do_denormalize)
|
||||||
|
|
||||||
|
# Offload last model to CPU
|
||||||
|
if hasattr(
|
||||||
|
self,
|
||||||
|
'final_offload_hook') and self.final_offload_hook is not None:
|
||||||
|
self.final_offload_hook.offload()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (image, has_nsfw_concept)
|
||||||
|
|
||||||
|
return StableDiffusionPipelineOutput(
|
||||||
|
images=image, nsfw_content_detected=has_nsfw_concept)
|
||||||
modelscope/models/cv/text_to_360panorama_image/pipeline_sr.py (new file, 1208 lines)
File diff suppressed because it is too large.
@@ -1504,6 +1504,7 @@ TASK_OUTPUTS = {
     Tasks.document_grounded_dialog_retrieval: [OutputKeys.OUTPUT],
     Tasks.video_temporal_grounding: [OutputKeys.SCORES, OutputKeys.TBOUNDS],
     Tasks.text_to_video_synthesis: [OutputKeys.OUTPUT_VIDEO],
+    Tasks.text_to_360panorama_image: [OutputKeys.OUTPUT_IMG],

     # Tasks.image_try_on result for a single sample
     # {
@@ -398,4 +398,7 @@ TASK_INPUTS = {
         'text': InputType.TEXT
     },
     Tasks.video_summarization: InputType.TEXT,
+    Tasks.text_to_360panorama_image: {
+        'prompt': InputType.TEXT,
+    },
 }
@@ -107,6 +107,7 @@ if TYPE_CHECKING:
     from .image_bts_depth_estimation_pipeline import ImageBTSDepthEstimationPipeline
     from .pedestrian_attribute_recognition_pipeline import PedestrainAttributeRecognitionPipeline
     from .image_panoptic_segmentation_pipeline import ImagePanopticSegmentationPipeline
+    from .text_to_360panorama_image_pipeline import Text2360PanoramaImagePipeline
 else:
     _import_structure = {
         'action_recognition_pipeline': ['ActionRecognitionPipeline'],
@@ -265,6 +266,9 @@ else:
         'image_panoptic_segmentation_pipeline': [
             'ImagePanopticSegmentationPipeline',
         ],
+        'text_to_360panorama_image_pipeline': [
+            'Text2360PanoramaImagePipeline'
+        ],
     }

     import sys
modelscope/pipelines/cv/text_to_360panorama_image_pipeline.py (new file, 224 lines)
@@ -0,0 +1,224 @@
|
|||||||
|
# Copyright © Alibaba, Inc. and its affiliates.
|
||||||
|
import random
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from diffusers import (ControlNetModel, DiffusionPipeline,
|
||||||
|
EulerAncestralDiscreteScheduler,
|
||||||
|
UniPCMultistepScheduler)
|
||||||
|
from PIL import Image
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
|
from modelscope.metainfo import Pipelines
|
||||||
|
from modelscope.models.cv.text_to_360panorama_image import (
|
||||||
|
StableDiffusionBlendExtendPipeline,
|
||||||
|
StableDiffusionControlNetImg2ImgPanoPipeline)
|
||||||
|
from modelscope.pipelines.builder import PIPELINES
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
|
@PIPELINES.register_module(
|
||||||
|
Tasks.text_to_360panorama_image,
|
||||||
|
module_name=Pipelines.text_to_360panorama_image)
|
||||||
|
class Text2360PanoramaImagePipeline(Pipelines):
|
||||||
|
""" Stable Diffusion for 360 Panorama Image Generation Pipeline.
|
||||||
|
Example:
|
||||||
|
>>> import cv2
|
||||||
|
>>> from modelscope.outputs import OutputKeys
|
||||||
|
>>> from modelscope.pipelines import pipeline
|
||||||
|
>>> from modelscope.utils.constant import Tasks
|
||||||
|
>>> prompt = 'The mountains'
|
||||||
|
>>> input = {'prompt': prompt, 'upscale': True}
|
||||||
|
>>> model_id = 'damo/cv_diffusion_text-to-360panorama-image_generation'
|
||||||
|
>>> txt2panoimg = pipeline(Tasks.text_to_360panorama_image, model=model_id, model_revision='v1.0.0')
|
||||||
|
>>> output = txt2panoimg(input)[OutputKeys.OUTPUT_IMG]
|
||||||
|
>>> cv2.imwrite('result.png', output)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model: str, device: str = 'cuda', **kwargs):
|
||||||
|
"""
|
||||||
|
Use `model` to create a stable diffusion pipeline for 360 panorama image generation.
|
||||||
|
Args:
|
||||||
|
model: model id on modelscope hub.
|
||||||
|
device: str = 'cuda'
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
) if device is None else device
|
||||||
|
if device == 'gpu':
|
||||||
|
device = 'cuda'
|
||||||
|
|
||||||
|
torch_dtype = kwargs.get('torch_dtype', torch.float16)
|
||||||
|
enable_xformers_memory_efficient_attention = kwargs.get(
|
||||||
|
'enable_xformers_memory_efficient_attention', True)
|
||||||
|
|
||||||
|
model_id = model + '/sd-base/'
|
||||||
|
|
||||||
|
# init base model
|
||||||
|
self.pipe = StableDiffusionBlendExtendPipeline.from_pretrained(
|
||||||
|
model_id, torch_dtype=torch_dtype).to(device)
|
||||||
|
self.pipe.vae.enable_tiling()
|
||||||
|
self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
||||||
|
self.pipe.scheduler.config)
|
||||||
|
# remove following line if xformers is not installed
|
||||||
|
try:
|
||||||
|
if enable_xformers_memory_efficient_attention:
|
||||||
|
self.pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
self.pipe.enable_model_cpu_offload()
|
||||||
|
|
||||||
|
# init controlnet-sr model
|
||||||
|
base_model_path = model + '/sr-base'
|
||||||
|
controlnet_path = model + '/sr-control'
|
||||||
|
controlnet = ControlNetModel.from_pretrained(
|
||||||
|
controlnet_path, torch_dtype=torch_dtype)
|
||||||
|
self.pipe_sr = StableDiffusionControlNetImg2ImgPanoPipeline.from_pretrained(
|
||||||
|
base_model_path, controlnet=controlnet,
|
||||||
|
torch_dtype=torch_dtype).to(device)
|
||||||
|
self.pipe_sr.scheduler = UniPCMultistepScheduler.from_config(
|
||||||
|
self.pipe.scheduler.config)
|
||||||
|
self.pipe_sr.vae.enable_tiling()
|
||||||
|
# remove following line if xformers is not installed
|
||||||
|
try:
|
||||||
|
if enable_xformers_memory_efficient_attention:
|
||||||
|
self.pipe_sr.enable_xformers_memory_efficient_attention()
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
self.pipe_sr.enable_model_cpu_offload()
|
||||||
|
|
||||||
|
# init realesrgan model
|
||||||
|
sr_model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=2)
|
||||||
|
netscale = 2
|
||||||
|
|
||||||
|
model_path = model + '/RealESRGAN_x2plus.pth'
|
||||||
|
|
||||||
|
dni_weight = None
|
||||||
|
self.upsampler = RealESRGANer(
|
||||||
|
scale=netscale,
|
||||||
|
model_path=model_path,
|
||||||
|
dni_weight=dni_weight,
|
||||||
|
model=sr_model,
|
||||||
|
tile=384,
|
||||||
|
tile_pad=20,
|
||||||
|
pre_pad=20,
|
||||||
|
half=False,
|
||||||
|
device=device,
|
||||||
|
)
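# Editor's note: the constructor wires up the three stages that __call__ uses in order:
# self.pipe (StableDiffusionBlendExtendPipeline) generates the 512x1024 base panorama,
# self.pipe_sr (ControlNet img2img) re-renders it at higher resolution, and
# self.upsampler (Real-ESRGAN x2) performs the final super-resolution pass.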
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def blend_h(a, b, blend_extent):
|
||||||
|
blend_extent = min(a.shape[1], b.shape[1], blend_extent)
|
||||||
|
for x in range(blend_extent):
|
||||||
|
b[:, x, :] = a[:, -blend_extent
|
||||||
|
+ x, :] * (1 - x / blend_extent) + b[:, x, :] * (
|
||||||
|
x / blend_extent)
|
||||||
|
return b
|
||||||
|
|
||||||
|
def __call__(self, inputs: Dict[str, Any],
|
||||||
|
**forward_params) -> Dict[str, Any]:
|
||||||
|
if not isinstance(inputs, dict):
|
||||||
|
raise ValueError(
|
||||||
|
f'Expected the input to be a dictionary, but got {type(inputs)}'
|
||||||
|
)
|
||||||
|
num_inference_steps = inputs.get('num_inference_steps', 20)
|
||||||
|
guidance_scale = inputs.get('guidance_scale', 7.5)
|
||||||
|
preset_a_prompt = 'photorealistic, trend on artstation, ((best quality)), ((ultra high res))'
|
||||||
|
add_prompt = inputs.get('add_prompt', preset_a_prompt)
|
||||||
|
preset_n_prompt = 'persons, complex texture, small objects, sheltered, blur, worst quality, '\
|
||||||
|
'low quality, zombie, logo, text, watermark, username, monochrome, '\
|
||||||
|
'complex lighting'
|
||||||
|
negative_prompt = inputs.get('negative_prompt', preset_n_prompt)
|
||||||
|
seed = inputs.get('seed', -1)
|
||||||
|
upscale = inputs.get('upscale', True)
|
||||||
|
refinement = inputs.get('refinement', True)
|
||||||
|
|
||||||
|
if 'prompt' in inputs.keys():
|
||||||
|
prompt = inputs['prompt']
|
||||||
|
else:
|
||||||
|
# for demo_service
|
||||||
|
prompt = forward_params.get('prompt', 'the living room')
|
||||||
|
|
||||||
|
print(f'Test with prompt: {prompt}')
|
||||||
|
|
||||||
|
if seed == -1:
|
||||||
|
seed = random.randint(0, 65535)
|
||||||
|
print(f'global seed: {seed}')
|
||||||
|
|
||||||
|
generator = torch.manual_seed(seed)
|
||||||
|
|
||||||
|
prompt = '<360panorama>, ' + prompt + ', ' + add_prompt
|
||||||
|
output_img = self.pipe(
|
||||||
|
prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
height=512,
|
||||||
|
width=1024,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
generator=generator).images[0]
|
||||||
|
|
||||||
|
if not upscale:
|
||||||
|
print('finished')
|
||||||
|
else:
|
||||||
|
print('inputs: upscale=True, running upscaler.')
|
||||||
|
print('running upscaler step1. Initial super-resolution')
|
||||||
|
sr_scale = 2.0
|
||||||
|
output_img = self.pipe_sr(
|
||||||
|
prompt.replace('<360panorama>, ', ''),
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
image=output_img.resize(
|
||||||
|
(int(1536 * sr_scale), int(768 * sr_scale))),
|
||||||
|
num_inference_steps=7,
|
||||||
|
generator=generator,
|
||||||
|
control_image=output_img.resize(
|
||||||
|
(int(1536 * sr_scale), int(768 * sr_scale))),
|
||||||
|
strength=0.8,
|
||||||
|
controlnet_conditioning_scale=1.0,
|
||||||
|
guidance_scale=15,
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
print('running upscaler step2. Super-resolution with Real-ESRGAN')
|
||||||
|
output_img = output_img.resize((1536 * 2, 768 * 2))
|
||||||
|
w = output_img.size[0]
|
||||||
|
blend_extend = 10
|
||||||
|
outscale = 2
|
||||||
|
output_img = np.array(output_img)
|
||||||
|
output_img = np.concatenate(
|
||||||
|
[output_img, output_img[:, :blend_extend, :]], axis=1)
|
||||||
|
output_img, _ = self.upsampler.enhance(
|
||||||
|
output_img, outscale=outscale)
|
||||||
|
output_img = self.blend_h(output_img, output_img,
|
||||||
|
blend_extend * outscale)
|
||||||
|
output_img = Image.fromarray(output_img[:, :w * outscale, :])
|
||||||
|
|
||||||
|
if refinement:
|
||||||
|
print(
|
||||||
|
'inputs: refinement=True, running refinement. This is a bit time-consuming.'
|
||||||
|
)
|
||||||
|
sr_scale = 4
|
||||||
|
output_img = self.pipe_sr(
|
||||||
|
prompt.replace('<360panorama>, ', ''),
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
image=output_img.resize(
|
||||||
|
(int(1536 * sr_scale), int(768 * sr_scale))),
|
||||||
|
num_inference_steps=7,
|
||||||
|
generator=generator,
|
||||||
|
control_image=output_img.resize(
|
||||||
|
(int(1536 * sr_scale), int(768 * sr_scale))),
|
||||||
|
strength=0.8,
|
||||||
|
controlnet_conditioning_scale=1.0,
|
||||||
|
guidance_scale=17,
|
||||||
|
).images[0]
|
||||||
|
print('finished')
|
||||||
|
|
||||||
|
output_img = np.array(output_img)
|
||||||
|
return {'output_img': output_img[:, :, ::-1]}
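# Editor's note: the [:, :, ::-1] flip converts RGB to BGR, so the returned array can be
# written directly with cv2.imwrite (as in the unit test and the docstring example above).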
|
||||||
@@ -96,6 +96,7 @@ class CVTasks(object):
     image_face_fusion = 'image-face-fusion'
     product_retrieval_embedding = 'product-retrieval-embedding'
     controllable_image_generation = 'controllable-image-generation'
+    text_to_360panorama_image = 'text-to-360panorama-image'
     image_try_on = 'image-try-on'

     # video recognition
@@ -1,12 +1,13 @@
 accelerate
 albumentations>=1.0.3
 av>=9.2.0
+basicsr>=1.4.2
 bmt_clipit>=1.0
 chumpy
 clip>=1.0
 control_ldm
 ddpm_guided_diffusion
-diffusers>=0.13.1,<=0.15.0
+diffusers==0.18.0
 easydict
 easyrobust
 edit_distance
@@ -49,6 +50,7 @@ psutil
 pyclipper
 PyMCubes
 pytorch-lightning
+realesrgan==0.3.0
 regex
 # <0.20.0 for compatible python3.7 python3.8
 scikit-image>=0.19.3,<0.20.0
@@ -63,6 +65,7 @@ timm>=0.4.9
 torchmetrics>=0.6.2
 torchsummary>=1.5.1
 torchvision
+tqdm
 transformers>=4.26.0
 trimesh
 ujson
@@ -1,7 +1,7 @@
 accelerate
 cloudpickle
 decord>=0.6.0
-diffusers==0.15.0
+diffusers==0.18.0
 fairseq
 ftfy>=6.0.3
 librosa==0.9.2
tests/pipelines/test_text_to_360panorama_image.py (new file, 69 lines)
@@ -0,0 +1,69 @@
|
|||||||
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
from modelscope.hub.snapshot_download import snapshot_download
|
||||||
|
from modelscope.outputs import OutputKeys
|
||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
from modelscope.pipelines.cv import Text2360PanoramaImagePipeline
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
from modelscope.utils.logger import get_logger
|
||||||
|
from modelscope.utils.test_utils import test_level
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class Text2360PanoramaImageTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
logger.info('start install xformers')
|
||||||
|
cmd = [
|
||||||
|
sys.executable, '-m', 'pip', 'install', 'xformers', '-f',
|
||||||
|
'https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html'
|
||||||
|
]
|
||||||
|
subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||||
|
|
||||||
|
logger.info('install xformers finished')
|
||||||
|
|
||||||
|
self.task = Tasks.text_to_360panorama_image
|
||||||
|
self.model_id = 'damo/cv_diffusion_text-to-360panorama-image_generation'
|
||||||
|
self.prompt = 'The living room'
|
||||||
|
self.upscale = False
|
||||||
|
self.refinement = False
|
||||||
|
|
||||||
|
self.input = {
|
||||||
|
'prompt': self.prompt,
|
||||||
|
'upscale': self.upscale,
|
||||||
|
'refinement': self.refinement,
|
||||||
|
}
|
||||||
|
|
||||||
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||||
|
def test_run_by_direct_model_download(self):
|
||||||
|
output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name
|
||||||
|
cache_path = snapshot_download(self.model_id)
|
||||||
|
pipeline = Text2360PanoramaImagePipeline(cache_path)
|
||||||
|
pipeline.group_key = self.task
|
||||||
|
output = pipeline(inputs=self.input)[OutputKeys.OUTPUT_IMG]
|
||||||
|
cv2.imwrite(output_image_path, output)
|
||||||
|
print(
|
||||||
|
'pipeline: the output image path is {}'.format(output_image_path))
|
||||||
|
|
||||||
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||||
|
def test_run_with_model_from_modelhub(self):
|
||||||
|
output_image_path = tempfile.NamedTemporaryFile(suffix='.png').name
|
||||||
|
pipeline_ins = pipeline(
|
||||||
|
task=Tasks.text_to_360panorama_image,
|
||||||
|
model=self.model_id,
|
||||||
|
model_revision='v1.0.0')
|
||||||
|
output = pipeline_ins(inputs=self.input)[OutputKeys.OUTPUT_IMG]
|
||||||
|
cv2.imwrite(output_image_path, output)
|
||||||
|
print(
|
||||||
|
'pipeline: the output image path is {}'.format(output_image_path))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
@@ -62,6 +62,8 @@ isolated: # test cases that may require excessive anmount of GPU memory or run
   - test_bad_image_detecting.py
   - test_controllable_image_generation.py
   - test_image_colorization_trainer.py
+  - test_text_to_360panorama_image.py
+

 envs:
   default: # default env, case not in other env will in default, pytorch.