From 6942144ad71bb21a0bdb5e2e5763ad9f61b6a393 Mon Sep 17 00:00:00 2001 From: Wang Qiang <37444407+XDUWQ@users.noreply.github.com> Date: Wed, 28 Jun 2023 13:26:19 +0800 Subject: [PATCH 1/3] Stable Diffusion model checkpoint export to onnx. (#340) * stable diffusion export onnx * fix pre commit bugs * fix bugs * safety checker support * test export stable diffusion --- modelscope/exporters/__init__.py | 2 + modelscope/exporters/multi_modal/__init__.py | 22 ++ .../multi_modal/stable_diffusion_exporter.py | 303 ++++++++++++++++++ .../stable_diffusion/stable_diffusion.py | 1 + tests/export/test_export_stable_diffusion.py | 31 ++ 5 files changed, 359 insertions(+) create mode 100644 modelscope/exporters/multi_modal/__init__.py create mode 100644 modelscope/exporters/multi_modal/stable_diffusion_exporter.py create mode 100644 tests/export/test_export_stable_diffusion.py diff --git a/modelscope/exporters/__init__.py b/modelscope/exporters/__init__.py index 8b627816..e5a10a0d 100644 --- a/modelscope/exporters/__init__.py +++ b/modelscope/exporters/__init__.py @@ -13,11 +13,13 @@ if TYPE_CHECKING: from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter from .torch_model_exporter import TorchModelExporter from .cv import FaceDetectionSCRFDExporter + from .multi_modal import StableDiffuisonExporter else: _import_structure = { 'base': ['Exporter'], 'builder': ['build_exporter'], 'cv': ['CartoonTranslationExporter', 'FaceDetectionSCRFDExporter'], + 'multi_modal': ['StableDiffuisonExporter'], 'nlp': [ 'CsanmtForTranslationExporter', 'SbertForSequenceClassificationExporter', diff --git a/modelscope/exporters/multi_modal/__init__.py b/modelscope/exporters/multi_modal/__init__.py new file mode 100644 index 00000000..ab565d1c --- /dev/null +++ b/modelscope/exporters/multi_modal/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
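+# Lazily exposes StableDiffuisonExporter for text-to-image ONNX export.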
+ +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .stable_diffusion_export import StableDiffuisonExporter +else: + _import_structure = { + 'stable_diffusion_export': ['StableDiffuisonExporter'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/exporters/multi_modal/stable_diffusion_exporter.py b/modelscope/exporters/multi_modal/stable_diffusion_exporter.py new file mode 100644 index 00000000..62ab0ce5 --- /dev/null +++ b/modelscope/exporters/multi_modal/stable_diffusion_exporter.py @@ -0,0 +1,303 @@ +import argparse +import os +import shutil +from collections import OrderedDict +from pathlib import Path +from typing import Any, Dict, Mapping, Tuple + +import onnx +import torch +from diffusers import (OnnxRuntimeModel, OnnxStableDiffusionPipeline, + StableDiffusionPipeline) +from packaging import version +from torch.onnx import export +from torch.utils.data.dataloader import default_collate + +from modelscope.exporters.builder import EXPORTERS +from modelscope.exporters.torch_model_exporter import TorchModelExporter +from modelscope.metainfo import Models +from modelscope.preprocessors import Preprocessor +from modelscope.utils.constant import ModeKeys, Tasks +from modelscope.utils.hub import snapshot_download + + +@EXPORTERS.register_module( + Tasks.text_to_image_synthesis, module_name=Models.stable_diffusion) +class StableDiffuisonExporter(TorchModelExporter): + + @torch.no_grad() + def export_onnx(self, + output_path: str, + opset: int = 14, + fp16: bool = False): + """Export the model as onnx format files. + + Args: + model_path: The model id or local path. + output_dir: The output dir. + opset: The version of the ONNX operator set to use. + fp16: Whether to use float16. + """ + output_path = Path(output_path) + + # Conversion weight accuracy and device. 
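+        # fp16 export needs a CUDA device; fp32 export also runs on CPU.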
+ dtype = torch.float16 if fp16 else torch.float32 + if fp16 and torch.cuda.is_available(): + device = 'cuda' + elif fp16 and not torch.cuda.is_available(): + raise ValueError( + '`float16` model export is only supported on GPUs with CUDA') + else: + device = 'cpu' + self.model = self.model.to(device) + + # Text encoder + num_tokens = self.model.text_encoder.config.max_position_embeddings + text_hidden_size = self.model.text_encoder.config.hidden_size + text_input = self.model.tokenizer( + 'A sample prompt', + padding='max_length', + max_length=self.model.tokenizer.model_max_length, + truncation=True, + return_tensors='pt', + ) + self.export_help( + self.model.text_encoder, + model_args=(text_input.input_ids.to( + device=device, dtype=torch.int32)), + output_path=output_path / 'text_encoder' / 'model.onnx', + ordered_input_names=['input_ids'], + output_names=['last_hidden_state', 'pooler_output'], + dynamic_axes={ + 'input_ids': { + 0: 'batch', + 1: 'sequence' + }, + }, + opset=opset, + ) + del self.model.text_encoder + + # UNET + unet_in_channels = self.model.unet.config.in_channels + unet_sample_size = self.model.unet.config.sample_size + unet_path = output_path / 'unet' / 'model.onnx' + self.export_help( + self.model.unet, + model_args=( + torch.randn(2, unet_in_channels, unet_sample_size, + unet_sample_size).to(device=device, dtype=dtype), + torch.randn(2).to(device=device, dtype=dtype), + torch.randn(2, num_tokens, + text_hidden_size).to(device=device, dtype=dtype), + False, + ), + output_path=unet_path, + ordered_input_names=[ + 'sample', 'timestep', 'encoder_hidden_states', 'return_dict' + ], + output_names=[ + 'out_sample' + ], # has to be different from "sample" for correct tracing + dynamic_axes={ + 'sample': { + 0: 'batch', + 1: 'channels', + 2: 'height', + 3: 'width' + }, + 'timestep': { + 0: 'batch' + }, + 'encoder_hidden_states': { + 0: 'batch', + 1: 'sequence' + }, + }, + opset=opset, + use_external_data_format= + True, # UNet is > 2GB, so the weights need to be split + ) + unet_model_path = str(unet_path.absolute().as_posix()) + unet_dir = os.path.dirname(unet_model_path) + unet = onnx.load(unet_model_path) + # clean up existing tensor files + shutil.rmtree(unet_dir) + os.mkdir(unet_dir) + # collate external tensor files into one + onnx.save_model( + unet, + unet_model_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location='weights.pb', + convert_attribute=False, + ) + del self.model.unet + + # VAE ENCODER + vae_encoder = self.model.vae + vae_in_channels = vae_encoder.config.in_channels + vae_sample_size = vae_encoder.config.sample_size + # need to get the raw tensor output (sample) from the encoder + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode( + sample, return_dict)[0].sample() + self.export_help( + vae_encoder, + model_args=( + torch.randn(1, vae_in_channels, vae_sample_size, + vae_sample_size).to(device=device, dtype=dtype), + False, + ), + output_path=output_path / 'vae_encoder' / 'model.onnx', + ordered_input_names=['sample', 'return_dict'], + output_names=['latent_sample'], + dynamic_axes={ + 'sample': { + 0: 'batch', + 1: 'channels', + 2: 'height', + 3: 'width' + }, + }, + opset=opset, + ) + + # VAE DECODER + vae_decoder = self.model.vae + vae_latent_channels = vae_decoder.config.latent_channels + vae_out_channels = vae_decoder.config.out_channels + # forward only through the decoder part + vae_decoder.forward = vae_encoder.decode + self.export_help( + vae_decoder, + model_args=( + torch.randn(1, vae_latent_channels, 
unet_sample_size, + unet_sample_size).to(device=device, dtype=dtype), + False, + ), + output_path=output_path / 'vae_decoder' / 'model.onnx', + ordered_input_names=['latent_sample', 'return_dict'], + output_names=['sample'], + dynamic_axes={ + 'latent_sample': { + 0: 'batch', + 1: 'channels', + 2: 'height', + 3: 'width' + }, + }, + opset=opset, + ) + del self.model.vae + + # SAFETY CHECKER + if self.model.safety_checker is not None: + safety_checker = self.model.safety_checker + clip_num_channels = safety_checker.config.vision_config.num_channels + clip_image_size = safety_checker.config.vision_config.image_size + safety_checker.forward = safety_checker.forward_onnx + self.export_help( + self.model.safety_checker, + model_args=( + torch.randn( + 1, + clip_num_channels, + clip_image_size, + clip_image_size, + ).to(device=device, dtype=dtype), + torch.randn(1, vae_sample_size, vae_sample_size, + vae_out_channels).to( + device=device, dtype=dtype), + ), + output_path=output_path / 'safety_checker' / 'model.onnx', + ordered_input_names=['clip_input', 'images'], + output_names=['out_images', 'has_nsfw_concepts'], + dynamic_axes={ + 'clip_input': { + 0: 'batch', + 1: 'channels', + 2: 'height', + 3: 'width' + }, + 'images': { + 0: 'batch', + 1: 'height', + 2: 'width', + 3: 'channels' + }, + }, + opset=opset, + ) + del self.model.safety_checker + safety_checker = OnnxRuntimeModel.from_pretrained( + output_path / 'safety_checker') + feature_extractor = self.model.feature_extractor + else: + safety_checker = None + feature_extractor = None + + onnx_pipeline = OnnxStableDiffusionPipeline( + vae_encoder=OnnxRuntimeModel.from_pretrained(output_path + / 'vae_encoder'), + vae_decoder=OnnxRuntimeModel.from_pretrained(output_path + / 'vae_decoder'), + text_encoder=OnnxRuntimeModel.from_pretrained(output_path + / 'text_encoder'), + tokenizer=self.model.tokenizer, + unet=OnnxRuntimeModel.from_pretrained(output_path / 'unet'), + scheduler=self.model.noise_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=safety_checker is not None, + ) + + onnx_pipeline.save_pretrained(output_path) + print('ONNX pipeline model saved to', output_path) + + del self.model + del onnx_pipeline + _ = OnnxStableDiffusionPipeline.from_pretrained( + output_path, provider='CPUExecutionProvider') + print('ONNX pipeline model is loadable') + + def export_help( + self, + model, + model_args: tuple, + output_path: Path, + ordered_input_names, + output_names, + dynamic_axes, + opset, + use_external_data_format=False, + ): + output_path.parent.mkdir(parents=True, exist_ok=True) + + is_torch_less_than_1_11 = version.parse( + version.parse( + torch.__version__).base_version) < version.parse('1.11') + if is_torch_less_than_1_11: + export( + model, + model_args, + f=output_path.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + use_external_data_format=use_external_data_format, + enable_onnx_checker=True, + opset_version=opset, + ) + else: + export( + model, + model_args, + f=output_path.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + opset_version=opset, + ) diff --git a/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py b/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py index 8ec2149d..72f29b56 100644 --- a/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py +++ 
b/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py @@ -57,6 +57,7 @@ class StableDiffusion(TorchModel): pretrained_model_name_or_path, subfolder='vae', revision=revision) self.unet = UNet2DConditionModel.from_pretrained( pretrained_model_name_or_path, subfolder='unet', revision=revision) + self.safety_checker = None # Freeze gradient calculation and move to device if self.vae is not None: diff --git a/tests/export/test_export_stable_diffusion.py b/tests/export/test_export_stable_diffusion.py new file mode 100644 index 00000000..a2e20198 --- /dev/null +++ b/tests/export/test_export_stable_diffusion.py @@ -0,0 +1,31 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +from collections import OrderedDict + +from modelscope.exporters import Exporter +from modelscope.models import Model +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TestExportStableDiffusion(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + self.model_id = 'AI-ModelScope/stable-diffusion-v1-5' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_export_stable_diffusion(self): + model = Model.from_pretrained(self.model_id) + Exporter.from_model(model).export_onnx( + output_path=self.tmp_dir, opset=14) + + +if __name__ == '__main__': + unittest.main() From a018cd6107a7661c12c87fc643903a00c14747b1 Mon Sep 17 00:00:00 2001 From: Wang Qiang <37444407+XDUWQ@users.noreply.github.com> Date: Wed, 28 Jun 2023 20:10:28 +0800 Subject: [PATCH 2/3] Dreambooth method for finetuning stable diffusions (#339) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Copyright * dreambooth * dreambooth test trainer * fix bugs * pre-commit --------- Co-authored-by: 翊靖 --- modelscope/metainfo.py | 1 + .../stable_diffusion/stable_diffusion.py | 30 +- modelscope/preprocessors/multi_modal.py | 2 +- .../dreambooth_diffusion/__init__.py | 2 + .../dreambooth_diffusion_trainer.py | 384 ++++++++++++++++++ .../multi_modal/lora_diffusion/__init__.py | 1 + .../test_diffusers_stable_diffusion.py | 2 +- .../test_dreambooth_diffusion_trainer.py | 98 +++++ tests/trainers/test_lora_diffusion_trainer.py | 4 +- 9 files changed, 505 insertions(+), 19 deletions(-) create mode 100644 modelscope/trainers/multi_modal/dreambooth_diffusion/__init__.py create mode 100644 modelscope/trainers/multi_modal/dreambooth_diffusion/dreambooth_diffusion_trainer.py create mode 100644 tests/trainers/test_dreambooth_diffusion_trainer.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index f2529be2..d3365b7c 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -895,6 +895,7 @@ class MultiModalTrainers(object): efficient_diffusion_tuning = 'efficient-diffusion-tuning' stable_diffusion = 'stable-diffusion' lora_diffusion = 'lora-diffusion' + dreambooth_diffusion = 'dreambooth-diffusion' class AudioTrainers(object): diff --git a/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py b/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py index 72f29b56..88cb4969 100644 --- a/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py +++ b/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py @@ -30,13 +30,12 @@ class 
StableDiffusion(TorchModel): """ Initialize a vision efficient diffusion tuning model. Args: - model_dir: model id or path, where model_dir/pytorch_model.bin + model_dir: model id or path """ super().__init__(model_dir, *args, **kwargs) - pretrained_model_name_or_path = kwargs.pop( - 'pretrained_model_name_or_path', 'runwayml/stable-diffusion-v1-5') revision = kwargs.pop('revision', None) - self.lora_tune = kwargs.pop('lora_tune', True) + self.lora_tune = kwargs.pop('lora_tune', False) + self.dreambooth_tune = kwargs.pop('dreambooth_tune', False) self.weight_dtype = torch.float32 self.device = torch.device( @@ -44,19 +43,15 @@ class StableDiffusion(TorchModel): # Load scheduler, tokenizer and models self.noise_scheduler = DDPMScheduler.from_pretrained( - pretrained_model_name_or_path, subfolder='scheduler') + model_dir, subfolder='scheduler') self.tokenizer = CLIPTokenizer.from_pretrained( - pretrained_model_name_or_path, - subfolder='tokenizer', - revision=revision) + model_dir, subfolder='tokenizer', revision=revision) self.text_encoder = CLIPTextModel.from_pretrained( - pretrained_model_name_or_path, - subfolder='text_encoder', - revision=revision) + model_dir, subfolder='text_encoder', revision=revision) self.vae = AutoencoderKL.from_pretrained( - pretrained_model_name_or_path, subfolder='vae', revision=revision) + model_dir, subfolder='vae', revision=revision) self.unet = UNet2DConditionModel.from_pretrained( - pretrained_model_name_or_path, subfolder='unet', revision=revision) + model_dir, subfolder='unet', revision=revision) self.safety_checker = None # Freeze gradient calculation and move to device @@ -90,6 +85,7 @@ class StableDiffusion(TorchModel): self.unet.train() self.unet = self.unet.to(self.device) + # Convert to latent space with torch.no_grad(): latents = self.vae.encode( target.to(dtype=self.weight_dtype)).latent_dist.sample() @@ -131,6 +127,9 @@ class StableDiffusion(TorchModel): model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample + if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) + loss = F.mse_loss(model_pred.float(), target.float(), reduction='mean') output = {OutputKeys.LOSS: loss} @@ -144,8 +143,9 @@ class StableDiffusion(TorchModel): config: Optional[dict] = None, save_config_function: Callable = save_configuration, **kwargs): - # Save only the lora model, skip saving and copying the original weights - if self.lora_tune: + config['pipeline']['type'] = 'diffusers-stable-diffusion' + # Skip copying the original weights for lora and dreambooth method + if self.lora_tune or self.dreambooth_tune: pass else: super().save_pretrained(target_folder, save_checkpoint_names, diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 82d44da8..a8867aef 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -11,6 +11,7 @@ import torch from PIL import Image from timm.data import create_transform from torchvision import transforms +from torchvision.datasets import ImageFolder from torchvision.transforms import Compose, Normalize, Resize, ToTensor from modelscope.hub.snapshot_download import snapshot_download @@ -55,7 +56,6 @@ class DiffusionImageGenerationPreprocessor(Preprocessor): transforms.Resize( self.preprocessor_resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(self.preprocessor_mean, self.preprocessor_std), diff --git 
a/modelscope/trainers/multi_modal/dreambooth_diffusion/__init__.py b/modelscope/trainers/multi_modal/dreambooth_diffusion/__init__.py new file mode 100644 index 00000000..430d3c9e --- /dev/null +++ b/modelscope/trainers/multi_modal/dreambooth_diffusion/__init__.py @@ -0,0 +1,2 @@ +# Copyright © Alibaba, Inc. and its affiliates. +from .dreambooth_diffusion_trainer import DreamboothDiffusionTrainer diff --git a/modelscope/trainers/multi_modal/dreambooth_diffusion/dreambooth_diffusion_trainer.py b/modelscope/trainers/multi_modal/dreambooth_diffusion/dreambooth_diffusion_trainer.py new file mode 100644 index 00000000..65623ed8 --- /dev/null +++ b/modelscope/trainers/multi_modal/dreambooth_diffusion/dreambooth_diffusion_trainer.py @@ -0,0 +1,384 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import hashlib +import itertools +import shutil +import warnings +from collections.abc import Mapping +from pathlib import Path +from typing import Union + +import torch +import torch.nn.functional as F +from diffusers import DiffusionPipeline +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm + +from modelscope.metainfo import Trainers +from modelscope.outputs import ModelOutputBase, OutputKeys +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.hooks.checkpoint.checkpoint_hook import CheckpointHook +from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \ + CheckpointProcessor +from modelscope.trainers.optimizer.builder import build_optimizer +from modelscope.trainers.trainer import EpochBasedTrainer +from modelscope.utils.config import ConfigDict +from modelscope.utils.constant import ModeKeys +from modelscope.utils.file_utils import func_receive_dict_inputs +from modelscope.utils.torch_utils import is_dist + + +class DreamboothCheckpointProcessor(CheckpointProcessor): + + def __init__(self, model_dir): + self.model_dir = model_dir + + def save_checkpoints(self, + trainer, + checkpoint_path_prefix, + output_dir, + meta=None): + """Save the state dict for dreambooth model. + """ + pipeline_args = {} + if trainer.model.text_encoder is not None: + pipeline_args['text_encoder'] = trainer.model.text_encoder + pipeline = DiffusionPipeline.from_pretrained( + self.model_dir, + unet=trainer.model.unet, + **pipeline_args, + ) + scheduler_args = {} + pipeline.scheduler = pipeline.scheduler.from_config( + pipeline.scheduler.config, **scheduler_args) + pipeline.save_pretrained(output_dir) + + +class ClassDataset(Dataset): + + def __init__( + self, + tokenizer, + class_data_root=None, + class_prompt=None, + class_num_images=None, + size=512, + center_crop=False, + ): + """A dataset to prepare class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + + Args: + tokenizer: The tokenizer to use for tokenization. + class_data_root: The saved class data path. + class_prompt: The prompt to use for class images. + class_num_images: The number of class images to use. + size: The size to resize the images. + center_crop: Whether to do center crop or random crop. 
+ + """ + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num_images is not None: + self.num_class_images = min( + len(self.class_images_path), class_num_images) + else: + self.num_class_images = len(self.class_images_path) + self.class_prompt = class_prompt + else: + raise ValueError( + f"Class {self.class_data_root} class data root doesn't exists." + ) + + self.image_transforms = transforms.Compose([ + transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) + if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return self.num_class_images + + def __getitem__(self, index): + example = {} + + if self.class_data_root: + class_image = Image.open( + self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == 'RGB': + class_image = class_image.convert('RGB') + example['pixel_values'] = self.image_transforms(class_image) + + class_text_inputs = self.tokenizer( + self.class_prompt, + max_length=self.tokenizer.model_max_length, + truncation=True, + padding='max_length', + return_tensors='pt') + input_ids = torch.squeeze(class_text_inputs.input_ids) + example['input_ids'] = input_ids + + return example + + +class PromptDataset(Dataset): + + def __init__(self, prompt, num_samples): + """Dataset to prepare the prompts to generate class images. + + Args: + prompt: Class prompt. + num_samples: The number sample for class images. + + """ + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example['prompt'] = self.prompt + example['index'] = index + return example + + +@TRAINERS.register_module(module_name=Trainers.dreambooth_diffusion) +class DreamboothDiffusionTrainer(EpochBasedTrainer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + """Dreambooth trainers for fine-tuning stable diffusion + + Args: + with_prior_preservation: a boolean indicating whether to enable prior loss. + instance_prompt: a string specifying the instance prompt. + class_prompt: a string specifying the class prompt. + class_data_dir: the path to the class data directory. + num_class_images: the number of class images to generate. + prior_loss_weight: the weight of the prior loss. + + """ + self.with_prior_preservation = kwargs.pop('with_prior_preservation', + False) + self.instance_prompt = kwargs.pop('instance_prompt', + 'a photo of sks dog') + self.class_prompt = kwargs.pop('class_prompt', 'a photo of dog') + self.class_data_dir = kwargs.pop('class_data_dir', '/tmp/class_data') + self.num_class_images = kwargs.pop('num_class_images', 200) + self.resolution = kwargs.pop('resolution', 512) + self.prior_loss_weight = kwargs.pop('prior_loss_weight', 1.0) + + # Save checkpoint and configurate files. 
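+        # DreamboothCheckpointProcessor writes checkpoints as a complete
+        # diffusers pipeline (fine-tuned unet plus text encoder and scheduler).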
+ ckpt_hook = list( + filter(lambda hook: isinstance(hook, CheckpointHook), + self.hooks))[0] + ckpt_hook.set_processor(DreamboothCheckpointProcessor(self.model_dir)) + + # Check for conflicts and conflicts + if self.with_prior_preservation: + if self.class_data_dir is None: + raise ValueError( + 'You must specify a data directory for class images.') + if self.class_prompt is None: + raise ValueError('You must specify prompt for class images.') + else: + if self.class_data_dir is not None: + warnings.warn( + 'You need not use --class_data_dir without --with_prior_preservation.' + ) + if self.class_prompt is not None: + warnings.warn( + 'You need not use --class_prompt without --with_prior_preservation.' + ) + + # Generate class images if prior preservation is enabled. + if self.with_prior_preservation: + class_images_dir = Path(self.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < self.num_class_images: + if torch.cuda.device_count() > 1: + warnings.warn('Multiple GPU inference not yet supported.') + pipeline = DiffusionPipeline.from_pretrained( + self.model_dir, + torch_dtype=torch.float32, + safety_checker=None, + revision=None, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = self.num_class_images - cur_class_images + sample_dataset = PromptDataset(self.instance_prompt, + num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset) + + pipeline.to(self.device) + + for example in tqdm( + sample_dataloader, desc='Generating class images'): + images = pipeline(example['prompt']).images + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Class Dataset and DataLoaders creation + class_dataset = ClassDataset( + class_data_root=self.class_data_dir + if self.with_prior_preservation else None, + class_prompt=self.class_prompt, + class_num_images=self.num_class_images, + tokenizer=self.model.tokenizer, + size=self.resolution, + center_crop=False, + ) + class_dataloader = torch.utils.data.DataLoader( + class_dataset, + batch_size=1, + shuffle=True, + ) + self.iter_class_dataloader = itertools.cycle(class_dataloader) + + def build_optimizer(self, cfg: ConfigDict, default_args: dict = None): + try: + return build_optimizer( + self.model.unet.parameters(), + cfg=cfg, + default_args=default_args) + except KeyError as e: + self.logger.error( + f'Build optimizer error, the optimizer {cfg} is a torch native component, ' + f'please check if your torch with version: {torch.__version__} matches the config.' + ) + raise e + + def train_step(self, model, inputs): + """ Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`TorchModel`): The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. 
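+
+            When `with_prior_preservation` is enabled, the class-image prior
+            loss (weighted by `prior_loss_weight`) is added to this loss.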
+ """ + model.train() + self._mode = ModeKeys.TRAIN + # call model forward but not __call__ to skip postprocess + + receive_dict_inputs = func_receive_dict_inputs( + self.unwrap_module(self.model).forward) + + if isinstance(inputs, Mapping) and not receive_dict_inputs: + train_outputs = model.forward(**inputs) + else: + train_outputs = model.forward(inputs) + + if self.with_prior_preservation: + # Convert to latent space + batch = next(self.iter_class_dataloader) + target_prior = batch['pixel_values'].to(self.device) + input_ids = batch['input_ids'].to(self.device) + with torch.no_grad(): + latents = self.model.vae.encode( + target_prior.to(dtype=torch.float32)).latent_dist.sample() + latents = latents * self.model.vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + self.model.noise_scheduler.num_train_timesteps, (bsz, ), + device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = self.model.noise_scheduler.add_noise( + latents, noise, timesteps) + + # Get the text embedding for conditioning + with torch.no_grad(): + encoder_hidden_states = self.model.text_encoder(input_ids)[0] + + # Get the target for loss depending on the prediction type + if self.model.noise_scheduler.config.prediction_type == 'epsilon': + target_prior = noise + elif self.model.noise_scheduler.config.prediction_type == 'v_prediction': + target_prior = self.model.noise_scheduler.get_velocity( + latents, noise, timesteps) + else: + raise ValueError( + f'Unknown prediction type {self.model.noise_scheduler.config.prediction_type}' + ) + + # Predict the noise residual and compute loss + model_pred_prior = self.model.unet(noisy_latents, timesteps, + encoder_hidden_states).sample + + # Compute prior loss + prior_loss = F.mse_loss( + model_pred_prior.float(), + target_prior.float(), + reduction='mean') + # Add the prior loss to the instance loss. + train_outputs[ + OutputKeys.LOSS] += self.prior_loss_weight * prior_loss + + if isinstance(train_outputs, ModelOutputBase): + train_outputs = train_outputs.to_dict() + if not isinstance(train_outputs, dict): + raise TypeError('"model.forward()" must return a dict') + + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if is_dist(): + value = value.data.clone().to('cuda') + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + + self.train_outputs = train_outputs diff --git a/modelscope/trainers/multi_modal/lora_diffusion/__init__.py b/modelscope/trainers/multi_modal/lora_diffusion/__init__.py index 311d2789..ebddd00b 100644 --- a/modelscope/trainers/multi_modal/lora_diffusion/__init__.py +++ b/modelscope/trainers/multi_modal/lora_diffusion/__init__.py @@ -1 +1,2 @@ +# Copyright © Alibaba, Inc. and its affiliates. 
from .lora_diffusion_trainer import LoraDiffusionTrainer diff --git a/tests/pipelines/test_diffusers_stable_diffusion.py b/tests/pipelines/test_diffusers_stable_diffusion.py index 57eae4a3..432706c7 100644 --- a/tests/pipelines/test_diffusers_stable_diffusion.py +++ b/tests/pipelines/test_diffusers_stable_diffusion.py @@ -21,7 +21,7 @@ class DiffusersStableDiffusionTest(unittest.TestCase): def test_run(self): diffusers_pipeline = pipeline(task=self.task, model=self.model_id) output = diffusers_pipeline({ - 'prompt': self.test_input, + 'text': self.test_input, 'height': 512, 'width': 512 }) diff --git a/tests/trainers/test_dreambooth_diffusion_trainer.py b/tests/trainers/test_dreambooth_diffusion_trainer.py new file mode 100644 index 00000000..b57d8a90 --- /dev/null +++ b/tests/trainers/test_dreambooth_diffusion_trainer.py @@ -0,0 +1,98 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os +import shutil +import tempfile +import unittest + +import cv2 + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.trainers import build_trainer +from modelscope.utils.constant import DownloadMode +from modelscope.utils.test_utils import test_level + + +class TestDreamboothDiffusionTrainer(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + self.train_dataset = MsDataset.load( + 'buptwq/lora-stable-diffusion-finetune', + split='train', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + self.eval_dataset = MsDataset.load( + 'buptwq/lora-stable-diffusion-finetune', + split='validation', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + + self.max_epochs = 5 + + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_dreambooth_diffusion_train(self): + model_id = 'AI-ModelScope/stable-diffusion-v1-5' + model_revision = 'v1.0.8' + prompt = 'a dog.' 
+ + def cfg_modify_fn(cfg): + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler = { + 'type': 'LambdaLR', + 'lr_lambda': lambda _: 1, + 'last_epoch': -1 + } + cfg.train.optimizer.lr = 5e-6 + return cfg + + kwargs = dict( + model=model_id, + model_revision=model_revision, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.dreambooth_diffusion, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Dreambooth-diffusion train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + pipe = pipeline( + task=Tasks.text_to_image_synthesis, model=f'{self.tmp_dir}/output') + output = pipe({'text': prompt}) + cv2.imwrite('./dreambooth_result.png', output['output_imgs'][0]) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_dreambooth_diffusion_eval(self): + model_id = 'AI-ModelScope/stable-diffusion-v1-5' + model_revision = 'v1.0.8' + + kwargs = dict( + model=model_id, + model_revision=model_revision, + work_dir=self.tmp_dir, + train_dataset=None, + eval_dataset=self.eval_dataset) + + trainer = build_trainer( + name=Trainers.dreambooth_diffusion, default_args=kwargs) + result = trainer.evaluate() + print(f'Dreambooth-diffusion eval output: {result}.') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_lora_diffusion_trainer.py b/tests/trainers/test_lora_diffusion_trainer.py index 2ba89665..a9b9e299 100644 --- a/tests/trainers/test_lora_diffusion_trainer.py +++ b/tests/trainers/test_lora_diffusion_trainer.py @@ -38,7 +38,7 @@ class TestLoraDiffusionTrainer(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_lora_diffusion_train(self): model_id = 'AI-ModelScope/stable-diffusion-v1-5' - model_revision = 'v1.0.6' + model_revision = 'v1.0.9' def cfg_modify_fn(cfg): cfg.train.max_epochs = self.max_epochs @@ -70,7 +70,7 @@ class TestLoraDiffusionTrainer(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_lora_diffusion_eval(self): model_id = 'AI-ModelScope/stable-diffusion-v1-5' - model_revision = 'v1.0.6' + model_revision = 'v1.0.9' kwargs = dict( model=model_id, From d8381bf9fd62ce1ee3a847615a4186ef99a1e52a Mon Sep 17 00:00:00 2001 From: Wang Qiang <37444407+XDUWQ@users.noreply.github.com> Date: Wed, 28 Jun 2023 20:13:28 +0800 Subject: [PATCH 3/3] Stable diffusion examples of lora and dreambooth. 
(#341) * stable diffusion examples of lora and dreambooth * pre-commit --- .../finetune_stable_diffusion_dreambooth.py | 116 ++++++++++++++++++ .../dreambooth/run_train_dreambooth.sh | 20 +++ .../finetune_stable_diffusion_lora.py} | 35 +++++- .../{run_train.sh => lora/run_train_lora.sh} | 7 +- 4 files changed, 170 insertions(+), 8 deletions(-) create mode 100644 examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py create mode 100644 examples/pytorch/stable_diffusion/dreambooth/run_train_dreambooth.sh rename examples/pytorch/stable_diffusion/{finetune_stable_diffusion.py => lora/finetune_stable_diffusion_lora.py} (55%) rename examples/pytorch/stable_diffusion/{run_train.sh => lora/run_train_lora.sh} (67%) diff --git a/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py b/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py new file mode 100644 index 00000000..1f38cff7 --- /dev/null +++ b/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py @@ -0,0 +1,116 @@ +from dataclasses import dataclass, field + +import cv2 + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.trainers import EpochBasedTrainer, build_trainer +from modelscope.trainers.training_args import TrainingArgs +from modelscope.utils.constant import DownloadMode, Tasks + + +# Load configuration file and dataset +@dataclass(init=False) +class StableDiffusionDreamboothArguments(TrainingArgs): + with_prior_preservation: bool = field( + default=False, metadata={ + 'help': 'Whether to enable prior loss.', + }) + + instance_prompt: str = field( + default='a photo of sks dog', + metadata={ + 'help': 'The instance prompt for dreambooth.', + }) + + class_prompt: str = field( + default='a photo of dog', + metadata={ + 'help': 'The class prompt for dreambooth.', + }) + + class_data_dir: str = field( + default='./tmp/class_data', + metadata={ + 'help': 'Save class prompt images path.', + }) + + num_class_images: int = field( + default=200, + metadata={ + 'help': 'The numbers of saving class images.', + }) + + resolution: int = field( + default=512, metadata={ + 'help': 'The class images resolution.', + }) + + prior_loss_weight: float = field( + default=1.0, + metadata={ + 'help': 'The weight of instance and prior loss.', + }) + + prompt: str = field( + default='dog', metadata={ + 'help': 'The pipeline prompt.', + }) + + +training_args = StableDiffusionDreamboothArguments( + task='text-to-image-synthesis').parse_cli() +config, args = training_args.to_config() + +train_dataset = MsDataset.load( + args.train_dataset_name, + split='train', + download_mode=DownloadMode.FORCE_REDOWNLOAD) +validation_dataset = MsDataset.load( + args.train_dataset_name, + split='validation', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + + +def cfg_modify_fn(cfg): + if args.use_model_config: + cfg.merge_from_dict(config) + else: + cfg = config + cfg.train.lr_scheduler = { + 'type': 'LambdaLR', + 'lr_lambda': lambda _: 1, + 'last_epoch': -1 + } + return cfg + + +kwargs = dict( + model=training_args.model, + model_revision=args.model_revision, + work_dir=training_args.work_dir, + train_dataset=train_dataset, + eval_dataset=validation_dataset, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_dir=args.class_data_dir, + num_class_images=args.num_class_images, + resolution=args.resolution, + 
prior_loss_weight=args.prior_loss_weight, + prompt=args.prompt, + cfg_modify_fn=cfg_modify_fn) + +# build trainer and training +trainer = build_trainer( + name=Trainers.dreambooth_diffusion, default_args=kwargs) +trainer.train() + +# pipeline after training and save result +pipe = pipeline( + task=Tasks.text_to_image_synthesis, + model=training_args.work_dir + '/output', + model_revision=args.model_revision) + +output = pipe({'text': args.prompt}) +cv2.imwrite('./dreambooth_result.png', output['output_imgs'][0]) diff --git a/examples/pytorch/stable_diffusion/dreambooth/run_train_dreambooth.sh b/examples/pytorch/stable_diffusion/dreambooth/run_train_dreambooth.sh new file mode 100644 index 00000000..dcb3b2e6 --- /dev/null +++ b/examples/pytorch/stable_diffusion/dreambooth/run_train_dreambooth.sh @@ -0,0 +1,20 @@ +PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py \ + --model 'AI-ModelScope/stable-diffusion-v1-5' \ + --model_revision 'v1.0.8' \ + --work_dir './tmp/dreambooth_diffusion' \ + --train_dataset_name 'buptwq/lora-stable-diffusion-finetune' \ + --with_prior_preservation false \ + --instance_prompt "a photo of sks dog" \ + --class_prompt "a photo of dog" \ + --class_data_dir "./tmp/class_data" \ + --num_class_images 200 \ + --resolution 512 \ + --prior_loss_weight 1.0 \ + --prompt "dog" \ + --max_epochs 150 \ + --save_ckpt_strategy 'by_epoch' \ + --logging_interval 1 \ + --train.dataloader.workers_per_gpu 0 \ + --evaluation.dataloader.workers_per_gpu 0 \ + --train.optimizer.lr 5e-6 \ + --use_model_config true diff --git a/examples/pytorch/stable_diffusion/finetune_stable_diffusion.py b/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py similarity index 55% rename from examples/pytorch/stable_diffusion/finetune_stable_diffusion.py rename to examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py index ac16ed2c..183e817d 100644 --- a/examples/pytorch/stable_diffusion/finetune_stable_diffusion.py +++ b/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py @@ -1,12 +1,27 @@ +from dataclasses import dataclass, field + +import cv2 + from modelscope.metainfo import Trainers from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline from modelscope.trainers import EpochBasedTrainer, build_trainer from modelscope.trainers.training_args import TrainingArgs -from modelscope.utils.constant import DownloadMode +from modelscope.utils.constant import DownloadMode, Tasks -training_args = TrainingArgs(task='text-to-image-synthesis').parse_cli() + +# Load configuration file and dataset +@dataclass(init=False) +class StableDiffusionLoraArguments(TrainingArgs): + prompt: str = field( + default='dog', metadata={ + 'help': 'The pipeline prompt.', + }) + + +training_args = StableDiffusionLoraArguments( + task='text-to-image-synthesis').parse_cli() config, args = training_args.to_config() -print(args) train_dataset = MsDataset.load( args.train_dataset_name, @@ -28,17 +43,27 @@ def cfg_modify_fn(cfg): 'lr_lambda': lambda _: 1, 'last_epoch': -1 } - cfg.train.optimizer.lr = 1e-4 return cfg kwargs = dict( model=training_args.model, - model_revision='v1.0.6', + model_revision=args.model_revision, work_dir=training_args.work_dir, train_dataset=train_dataset, eval_dataset=validation_dataset, cfg_modify_fn=cfg_modify_fn) +# build trainer and training trainer = build_trainer(name=Trainers.lora_diffusion, default_args=kwargs) trainer.train() + +# pipeline after training and 
save result +pipe = pipeline( + task=Tasks.text_to_image_synthesis, + model=training_args.model, + lora_dir=training_args.work_dir + '/output', + model_revision=args.model_revision) + +output = pipe({'text': args.prompt}) +cv2.imwrite('./lora_result.png', output['output_imgs'][0]) diff --git a/examples/pytorch/stable_diffusion/run_train.sh b/examples/pytorch/stable_diffusion/lora/run_train_lora.sh similarity index 67% rename from examples/pytorch/stable_diffusion/run_train.sh rename to examples/pytorch/stable_diffusion/lora/run_train_lora.sh index 8e45ae88..876a2475 100644 --- a/examples/pytorch/stable_diffusion/run_train.sh +++ b/examples/pytorch/stable_diffusion/lora/run_train_lora.sh @@ -1,11 +1,12 @@ -PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/finetune_stable_diffusion.py \ +PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py \ --model 'AI-ModelScope/stable-diffusion-v1-5' \ - --model_revision 'v1.0.6' \ + --model_revision 'v1.0.9' \ + --prompt "a dog" \ --work_dir './tmp/lora_diffusion' \ --train_dataset_name 'buptwq/lora-stable-diffusion-finetune' \ --max_epochs 100 \ --save_ckpt_strategy 'by_epoch' \ - --logging_interval 100 \ + --logging_interval 1 \ --train.dataloader.workers_per_gpu 0 \ --evaluation.dataloader.workers_per_gpu 0 \ --train.optimizer.lr 1e-4 \