Support SDXL finetune by LoRA method (#468)

* support sdxl finetune by lora

* remove useless imports

* support sdxl finetune

* upgrade diffusers to 0.19.0

* sdxl finetune

* fix bugs

* pre commit

* diffusers>=0.19.0
Wang Qiang
2023-08-23 11:23:34 +08:00
committed by GitHub
parent df53a6a89f
commit de33f4dc87
12 changed files with 574 additions and 5 deletions

View File

@@ -7,7 +7,7 @@ import torch
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
-from modelscope.trainers import EpochBasedTrainer, build_trainer
+from modelscope.trainers import build_trainer
from modelscope.trainers.training_args import TrainingArgs
from modelscope.utils.constant import DownloadMode, Tasks

View File

@@ -7,7 +7,7 @@ import torch
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
-from modelscope.trainers import EpochBasedTrainer, build_trainer
+from modelscope.trainers import build_trainer
from modelscope.trainers.training_args import TrainingArgs
from modelscope.utils.constant import DownloadMode, Tasks

View File

@@ -0,0 +1,85 @@
import os
from dataclasses import dataclass, field
import cv2
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.trainers.training_args import TrainingArgs
from modelscope.utils.constant import DownloadMode, Tasks
@dataclass(init=False)
class StableDiffusionXLLoraArguments(TrainingArgs):
prompt: str = field(
default='dog', metadata={
'help': 'The pipeline prompt.',
})
lora_rank: int = field(
default=16,
metadata={
'help': 'The rank size of lora intermediate linear.',
})
# Load configuration file and dataset
training_args = StableDiffusionXLLoraArguments(
    task='text-to-image-synthesis').parse_cli()
config, args = training_args.to_config()
if os.path.exists(args.train_dataset_name):
# Load local dataset
train_dataset = MsDataset.load(args.train_dataset_name)
validation_dataset = MsDataset.load(args.train_dataset_name)
else:
# Load online dataset
train_dataset = MsDataset.load(
args.train_dataset_name,
split='train',
download_mode=DownloadMode.FORCE_REDOWNLOAD)
validation_dataset = MsDataset.load(
args.train_dataset_name,
split='validation',
download_mode=DownloadMode.FORCE_REDOWNLOAD)
def cfg_modify_fn(cfg):
if args.use_model_config:
cfg.merge_from_dict(config)
else:
cfg = config
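    # A LambdaLR with a constant multiplier of 1 keeps the learning rate
    # fixed at train.optimizer.lr for the whole run.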
cfg.train.lr_scheduler = {
'type': 'LambdaLR',
'lr_lambda': lambda _: 1,
'last_epoch': -1
}
return cfg
kwargs = dict(
model=training_args.model,
model_revision=args.model_revision,
work_dir=training_args.work_dir,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
lora_rank=args.lora_rank,
cfg_modify_fn=cfg_modify_fn)
# build trainer and training
trainer = build_trainer(name=Trainers.lora_diffusion_xl, default_args=kwargs)
trainer.train()
# pipeline after training and save result
pipe = pipeline(
task=Tasks.text_to_image_synthesis,
model=training_args.model,
lora_dir=training_args.work_dir + '/output',
model_revision=args.model_revision)
output = pipe({'text': args.prompt})
# visualize the result in a notebook and save it to disk
output
cv2.imwrite('./lora_xl_result.png', output['output_imgs'][0])
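
The LoRA weights written during training use the diffusers layout (the checkpoint processor below saves them via StableDiffusionXLPipeline.save_lora_weights), so they can also be loaded straight into a diffusers pipeline. A minimal inference sketch, assuming diffusers>=0.19.0, the Hugging Face hub ID 'stabilityai/stable-diffusion-xl-base-1.0' for the base weights, and the work_dir used in the run script below:

import torch
from diffusers import StableDiffusionXLPipeline

# Load the base SDXL pipeline, then attach the trained LoRA weights.
pipe = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)
pipe.load_lora_weights('./tmp/lora_diffusion_xl/output')
pipe = pipe.to('cuda')

image = pipe(prompt='a dog').images[0]
image.save('./lora_xl_result.png')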

View File

@@ -0,0 +1,14 @@
PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/lora_xl/finetune_stable_diffusion_xl_lora.py \
--model 'AI-ModelScope/stable-diffusion-xl-base-1.0' \
--model_revision 'v1.0.2' \
--prompt "a dog" \
--work_dir './tmp/lora_diffusion_xl' \
--train_dataset_name 'buptwq/lora-stable-diffusion-finetune' \
--max_epochs 100 \
--lora_rank 16 \
--save_ckpt_strategy 'by_epoch' \
--logging_interval 1 \
--train.dataloader.workers_per_gpu 0 \
--evaluation.dataloader.workers_per_gpu 0 \
--train.optimizer.lr 1e-4 \
--use_model_config true

View File

@@ -218,6 +218,7 @@ class Models(object):
mplug_owl = 'mplug-owl'
clip_interrogator = 'clip-interrogator'
stable_diffusion = 'stable-diffusion'
+    stable_diffusion_xl = 'stable-diffusion-xl'
videocomposer = 'videocomposer'
text_to_360panorama_image = 'text-to-360panorama-image'
image_to_video_model = 'image-to-video-model'
@@ -938,6 +939,7 @@ class MultiModalTrainers(object):
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
stable_diffusion = 'stable-diffusion'
lora_diffusion = 'lora-diffusion'
+    lora_diffusion_xl = 'lora-diffusion-xl'
dreambooth_diffusion = 'dreambooth-diffusion'
custom_diffusion = 'custom-diffusion'

View File

@@ -1,2 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .stable_diffusion import StableDiffusion
+from .stable_diffusion_xl import StableDiffusionXL

View File

@@ -20,7 +20,7 @@ from modelscope.utils.constant import Tasks
@MODELS.register_module(
Tasks.text_to_image_synthesis, module_name=Models.stable_diffusion)
class StableDiffusion(TorchModel):
""" The implementation of efficient diffusion tuning model based on TorchModel.
""" The implementation of stable diffusion model based on TorchModel.
This model is constructed with the implementation of stable diffusion model. If you want to
finetune lightweight parameters on your own dataset, you can define you own tuner module
@@ -28,7 +28,7 @@ class StableDiffusion(TorchModel):
"""
def __init__(self, model_dir, *args, **kwargs):
""" Initialize a vision efficient diffusion tuning model.
""" Initialize a vision stable diffusion model.
Args:
model_dir: model id or path

View File

@@ -0,0 +1,254 @@
# Copyright 2023-2024 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
import random
from functools import partial
from typing import Callable, List, Optional, Union
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from packaging import version
from torchvision import transforms
from torchvision.transforms.functional import crop
from transformers import (AutoTokenizer, CLIPTextModel,
CLIPTextModelWithProjection)
from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.checkpoint import save_checkpoint, save_configuration
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()
@MODELS.register_module(
Tasks.text_to_image_synthesis, module_name=Models.stable_diffusion_xl)
class StableDiffusionXL(TorchModel):
""" The implementation of stable diffusion xl model based on TorchModel.
This model is constructed with the implementation of stable diffusion xl model. If you want to
finetune lightweight parameters on your own dataset, you can define you own tuner module
and load in this cls.
"""
def __init__(self, model_dir, *args, **kwargs):
""" Initialize a vision stable diffusion xl model.
Args:
model_dir: model id or path
"""
super().__init__(model_dir, *args, **kwargs)
revision = kwargs.pop('revision', None)
xformers_enable = kwargs.pop('xformers_enable', False)
        self.lora_tune = kwargs.pop('lora_tune', False)
        self.dreambooth_tune = kwargs.pop('dreambooth_tune', False)
self.resolution = kwargs.pop('resolution', 1024)
self.random_flip = kwargs.pop('random_flip', True)
self.weight_dtype = torch.float32
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
# Load scheduler, tokenizer and models
self.noise_scheduler = DDPMScheduler.from_pretrained(
model_dir, subfolder='scheduler')
self.tokenizer_one = AutoTokenizer.from_pretrained(
model_dir,
subfolder='tokenizer',
revision=revision,
use_fast=False)
self.tokenizer_two = AutoTokenizer.from_pretrained(
model_dir,
subfolder='tokenizer_2',
revision=revision,
use_fast=False)
self.text_encoder_one = CLIPTextModel.from_pretrained(
model_dir, subfolder='text_encoder', revision=revision)
self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
model_dir, subfolder='text_encoder_2', revision=revision)
self.vae = AutoencoderKL.from_pretrained(
model_dir, subfolder='vae', revision=revision)
self.unet = UNet2DConditionModel.from_pretrained(
model_dir, subfolder='unet', revision=revision)
self.safety_checker = None
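        # No safety checker is loaded; it is not needed for training.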
# Freeze gradient calculation and move to device
if self.vae is not None:
self.vae.requires_grad_(False)
self.vae = self.vae.to(self.device)
if self.text_encoder_one is not None:
self.text_encoder_one.requires_grad_(False)
self.text_encoder_one = self.text_encoder_one.to(self.device)
if self.text_encoder_two is not None:
self.text_encoder_two.requires_grad_(False)
self.text_encoder_two = self.text_encoder_two.to(self.device)
if self.unet is not None:
if self.lora_tune:
self.unet.requires_grad_(False)
self.unet = self.unet.to(self.device)
# xformers accelerate memory efficient attention
if xformers_enable:
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse('0.0.16'):
                logger.warning(
'xFormers 0.0.16 cannot be used for training in some GPUs. '
'If you observe problems during training, please update xFormers to at least 0.0.17.'
)
self.unet.enable_xformers_memory_efficient_attention()
def tokenize_caption(self, tokenizer, captions):
""" Convert caption text to token data.
Args:
tokenizer: the tokenizer one or two.
captions: a batch of texts.
Returns: token's data as tensor.
"""
inputs = tokenizer(
captions,
max_length=tokenizer.model_max_length,
padding='max_length',
truncation=True,
return_tensors='pt')
return inputs.input_ids
def compute_time_ids(self, original_size, crops_coords_top_left):
target_size = (self.resolution, self.resolution)
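        # SDXL micro-conditioning packs six values per sample:
        # (orig_height, orig_width, crop_top, crop_left, target_height,
        # target_width). Passing the true resize/crop metadata at training
        # time lets the UNet learn to undo cropped-composition artifacts.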
add_time_ids = list(original_size + crops_coords_top_left
+ target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(self.device, dtype=self.weight_dtype)
return add_time_ids
def encode_prompt(self,
text_encoders,
tokenizers,
prompt,
text_input_ids_list=None):
prompt_embeds_list = []
for i, text_encoder in enumerate(text_encoders):
if tokenizers is not None:
tokenizer = tokenizers[i]
                text_input_ids = self.tokenize_caption(tokenizer, prompt)
else:
assert text_input_ids_list is not None
text_input_ids = text_input_ids_list[i]
prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
output_hidden_states=True,
)
            # We are always interested only in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds)
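        # SDXL conditions on both encoders: the per-token hidden states from
        # the penultimate layers are concatenated along the channel dim, and
        # the pooled projection of the second encoder is passed separately
        # as the 'text_embeds' added condition.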
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
return prompt_embeds, pooled_prompt_embeds
def preprocessing_data(self, text, target):
train_crop = transforms.RandomCrop(self.resolution)
train_resize = transforms.Resize(
self.resolution,
interpolation=transforms.InterpolationMode.BILINEAR)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
image = target
        # image is a (C, H, W) tensor, so record original size as (height, width)
        original_size = (image.size()[-2], image.size()[-1])
image = train_resize(image)
y1, x1, h, w = train_crop.get_params(
image, (self.resolution, self.resolution))
image = crop(image, y1, x1, h, w)
if self.random_flip and random.random() < 0.5:
            # mirror the crop's left offset to match the horizontal flip
            x1 = image.size()[-1] - x1
image = train_flip(image)
crop_top_left = (y1, x1)
input_ids_one = self.tokenize_caption(self.tokenizer_one, text)
input_ids_two = self.tokenize_caption(self.tokenizer_two, text)
return original_size, crop_top_left, image, input_ids_one, input_ids_two
def forward(self, text='', target=None):
self.unet.train()
self.unet = self.unet.to(self.device)
# processing data
original_size, crop_top_left, image, input_ids_one, input_ids_two = self.preprocessing_data(
text, target)
        # Convert to latent space; encode the preprocessed (resized, cropped)
        # image so the latents match the crop metadata passed via time_ids
        with torch.no_grad():
            latents = self.vae.encode(
                image.to(dtype=self.weight_dtype)).latent_dist.sample()
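            # Scale by the VAE scaling factor (0.13025 for SDXL) so the
            # latent variance matches what the noise scheduler expects.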
latents = latents * self.vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0,
            self.noise_scheduler.config.num_train_timesteps, (bsz, ),
device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = self.noise_scheduler.add_noise(latents, noise,
timesteps)
add_time_ids = self.compute_time_ids(original_size, crop_top_left)
# Predict the noise residual
unet_added_conditions = {'time_ids': add_time_ids}
prompt_embeds, pooled_prompt_embeds = self.encode_prompt(
text_encoders=[self.text_encoder_one, self.text_encoder_two],
tokenizers=None,
prompt=None,
text_input_ids_list=[input_ids_one, input_ids_two])
unet_added_conditions.update({'text_embeds': pooled_prompt_embeds})
# Predict the noise residual and compute loss
model_pred = self.unet(
noisy_latents,
timesteps,
prompt_embeds,
added_cond_kwargs=unet_added_conditions).sample
# Get the target for loss depending on the prediction type
if self.noise_scheduler.config.prediction_type == 'epsilon':
target = noise
elif self.noise_scheduler.config.prediction_type == 'v_prediction':
            target = self.noise_scheduler.get_velocity(latents, noise,
                                                       timesteps)
else:
raise ValueError(
f'Unknown prediction type {self.noise_scheduler.config.prediction_type}'
)
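        # With SDXL base the scheduler predicts epsilon, so the target is the
        # sampled noise; the v-prediction branch covers schedulers configured
        # for velocity targets.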
loss = F.mse_loss(model_pred.float(), target.float(), reduction='mean')
output = {OutputKeys.LOSS: loss}
return output
def save_pretrained(self,
target_folder: Union[str, os.PathLike],
save_checkpoint_names: Union[str, List[str]] = None,
save_function: Callable = partial(
save_checkpoint, with_meta=False),
config: Optional[dict] = None,
save_config_function: Callable = save_configuration,
**kwargs):
        if config is not None:
            config['pipeline']['type'] = 'diffusers-stable-diffusion-xl'
        # Skip copying the original weights for the lora and dreambooth methods
        if not (self.lora_tune or self.dreambooth_tune):
            super().save_pretrained(target_folder, save_checkpoint_names,
                                    save_function, config,
                                    save_config_function, **kwargs)

View File

@@ -0,0 +1,2 @@
# Copyright © Alibaba, Inc. and its affiliates.
from .lora_diffusion_xl_trainer import LoraDiffusionXLTrainer

View File

@@ -0,0 +1,122 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import Dict
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionXLPipeline
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import (LoRAAttnProcessor,
LoRAAttnProcessor2_0)
from modelscope.metainfo import Trainers
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.hooks.checkpoint.checkpoint_hook import CheckpointHook
from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \
CheckpointProcessor
from modelscope.trainers.optimizer.builder import build_optimizer
from modelscope.trainers.trainer import EpochBasedTrainer
from modelscope.utils.config import ConfigDict
def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
    """Collect the attention processor parameters of a UNet.

    Returns:
        a state dict containing just the attention processor parameters.
    """
attn_processors = unet.attn_processors
attn_processors_state_dict = {}
for attn_processor_key, attn_processor in attn_processors.items():
for parameter_key, parameter in attn_processor.state_dict().items():
attn_processors_state_dict[
f'{attn_processor_key}.{parameter_key}'] = parameter
return attn_processors_state_dict
class LoraDiffusionXLCheckpointProcessor(CheckpointProcessor):
def save_checkpoints(self,
trainer,
checkpoint_path_prefix,
output_dir,
meta=None,
save_optimizers=True):
"""Save the state dict for lora tune stable diffusion xl model.
"""
        unet_lora_layers_to_save = unet_attn_processors_state_dict(
            trainer.model.unet)
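        # The weights are saved in the diffusers LoRA layout, so they can
        # later be reloaded with StableDiffusionXLPipeline.load_lora_weights.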
StableDiffusionXLPipeline.save_lora_weights(
output_dir, unet_lora_layers=unet_lora_layers_to_save)
@TRAINERS.register_module(module_name=Trainers.lora_diffusion_xl)
class LoraDiffusionXLTrainer(EpochBasedTrainer):
    def __init__(self, *args, **kwargs):
        """Lora trainer for fine-tuning stable diffusion xl.

        Args:
            lora_rank: The rank size of lora intermediate linear.
        """
        lora_rank = kwargs.pop('lora_rank', 16)
        super().__init__(*args, **kwargs)
# set lora save checkpoint processor
ckpt_hook = list(
filter(lambda hook: isinstance(hook, CheckpointHook),
self.hooks))[0]
ckpt_hook.set_processor(LoraDiffusionXLCheckpointProcessor())
# Add lora weights to attention layers and set correct lora layers
unet_lora_attn_procs = {}
unet_lora_parameters = []
for name, attn_processor in self.model.unet.attn_processors.items():
cross_attention_dim = None if name.endswith(
'attn1.processor'
) else self.model.unet.config.cross_attention_dim
if name.startswith('mid_block'):
hidden_size = self.model.unet.config.block_out_channels[-1]
elif name.startswith('up_blocks'):
block_id = int(name[len('up_blocks.')])
hidden_size = list(
reversed(
self.model.unet.config.block_out_channels))[block_id]
elif name.startswith('down_blocks'):
block_id = int(name[len('down_blocks.')])
hidden_size = self.model.unet.config.block_out_channels[
block_id]
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(
F, 'scaled_dot_product_attention') else LoRAAttnProcessor)
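            # LoRAAttnProcessor2_0 relies on torch 2.0's fused
            # scaled_dot_product_attention when it is available.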
module = lora_attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=lora_rank)
unet_lora_attn_procs[name] = module
unet_lora_parameters.extend(module.parameters())
self.model.unet.set_attn_processor(unet_lora_attn_procs)
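        # Only these low-rank adapters receive gradients; the original UNet
        # weights stay frozen, which is what makes LoRA tuning cheap.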
def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
try:
return build_optimizer(
self.model.unet.parameters(),
cfg=cfg,
default_args=default_args)
except KeyError as e:
            self.logger.error(
                f'Failed to build optimizer from config {cfg}: the optimizer '
                f'is a torch-native component, so please check that your '
                f'torch version ({torch.__version__}) matches the config.')
raise e

View File

@@ -1,7 +1,7 @@
accelerate
cloudpickle
decord>=0.6.0
-diffusers==0.18.0
+diffusers>=0.19.0
fairseq
ftfy>=6.0.3
librosa==0.9.2

View File

@@ -0,0 +1,89 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
import shutil
import tempfile
import unittest
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode
from modelscope.utils.test_utils import test_level
class TestLoraDiffusionXLTrainer(unittest.TestCase):
def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
self.train_dataset = MsDataset.load(
'buptwq/lora-stable-diffusion-finetune',
split='train',
download_mode=DownloadMode.FORCE_REDOWNLOAD)
self.eval_dataset = MsDataset.load(
'buptwq/lora-stable-diffusion-finetune',
split='validation',
download_mode=DownloadMode.FORCE_REDOWNLOAD)
self.max_epochs = 5
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
def tearDown(self):
shutil.rmtree(self.tmp_dir)
super().tearDown()
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_lora_diffusion_xl_train(self):
model_id = 'AI-ModelScope/stable-diffusion-xl-base-1.0'
model_revision = 'v1.0.2'
def cfg_modify_fn(cfg):
cfg.train.max_epochs = self.max_epochs
cfg.train.lr_scheduler = {
'type': 'LambdaLR',
'lr_lambda': lambda _: 1,
'last_epoch': -1
}
cfg.train.optimizer.lr = 1e-4
return cfg
kwargs = dict(
model=model_id,
model_revision=model_revision,
work_dir=self.tmp_dir,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
cfg_modify_fn=cfg_modify_fn)
trainer = build_trainer(
            name=Trainers.lora_diffusion_xl, default_args=kwargs)
trainer.train()
result = trainer.evaluate()
print(f'Lora-diffusion-xl train output: {result}.')
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_lora_diffusion_xl_eval(self):
model_id = 'AI-ModelScope/stable-diffusion-xl-base-1.0'
model_revision = 'v1.0.2'
kwargs = dict(
model=model_id,
model_revision=model_revision,
work_dir=self.tmp_dir,
train_dataset=None,
eval_dataset=self.eval_dataset)
trainer = build_trainer(
name=Trainers.lora_diffusion_xl, default_args=kwargs)
result = trainer.evaluate()
print(f'Lora-diffusion-xl eval output: {result}.')
if __name__ == '__main__':
unittest.main()