Mirror of https://github.com/modelscope/modelscope.git, synced 2026-02-24 12:10:09 +01:00
Support sdxl finetune by lora method (#468)
* support sdxl finetune by lora
* remove useless imports
* support sdxl finetune
* upgrade diffusers to 0.19.0
* sdxl finetune
* fix bugs
* pre commit
* diffusers>=0.19.0
@@ -7,7 +7,7 @@ import torch
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
-from modelscope.trainers import EpochBasedTrainer, build_trainer
+from modelscope.trainers import build_trainer
from modelscope.trainers.training_args import TrainingArgs
from modelscope.utils.constant import DownloadMode, Tasks
@@ -0,0 +1,85 @@
import os
from dataclasses import dataclass, field

import cv2

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.trainers.training_args import TrainingArgs
from modelscope.utils.constant import DownloadMode, Tasks


# Load the configuration file and dataset
@dataclass(init=False)
class StableDiffusionXLLoraArguments(TrainingArgs):
    prompt: str = field(
        default='dog', metadata={
            'help': 'The pipeline prompt.',
        })

    lora_rank: int = field(
        default=16,
        metadata={
            'help': 'The rank size of lora intermediate linear.',
        })


training_args = StableDiffusionXLLoraArguments(
    task='text-to-image-synthesis').parse_cli()
config, args = training_args.to_config()

if os.path.exists(args.train_dataset_name):
    # Load a local dataset
    train_dataset = MsDataset.load(args.train_dataset_name)
    validation_dataset = MsDataset.load(args.train_dataset_name)
else:
    # Load an online dataset
    train_dataset = MsDataset.load(
        args.train_dataset_name,
        split='train',
        download_mode=DownloadMode.FORCE_REDOWNLOAD)
    validation_dataset = MsDataset.load(
        args.train_dataset_name,
        split='validation',
        download_mode=DownloadMode.FORCE_REDOWNLOAD)


def cfg_modify_fn(cfg):
    if args.use_model_config:
        cfg.merge_from_dict(config)
    else:
        cfg = config
    cfg.train.lr_scheduler = {
        'type': 'LambdaLR',
        'lr_lambda': lambda _: 1,
        'last_epoch': -1
    }
    return cfg


kwargs = dict(
    model=training_args.model,
    model_revision=args.model_revision,
    work_dir=training_args.work_dir,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    lora_rank=args.lora_rank,
    cfg_modify_fn=cfg_modify_fn)

# Build the trainer and train
trainer = build_trainer(name=Trainers.lora_diffusion_xl, default_args=kwargs)
trainer.train()

# Run the pipeline after training and save the result
pipe = pipeline(
    task=Tasks.text_to_image_synthesis,
    model=training_args.model,
    lora_dir=training_args.work_dir + '/output',
    model_revision=args.model_revision)

output = pipe({'text': args.prompt})
# Visualize the result in a notebook and save it to disk
output
cv2.imwrite('./lora_xl_result.png', output['output_imgs'][0])
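The `--lora_rank` flag above sets the bottleneck width of the low-rank adapters the trainer injects into the unet's attention layers. For intuition, here is a minimal, self-contained sketch (not part of this commit; all names are illustrative) of the update LoRA adds to a frozen linear layer:

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Hypothetical illustration: y = W x + B(A x) with W frozen."""

    def __init__(self, base: nn.Linear, rank: int = 16):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # pretrained weight stays frozen
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # the low-rank update starts at zero

    def forward(self, x):
        # Only the down- and up-projections receive gradients
        return self.base(x) + self.up(self.down(x))


layer = LoRALinear(nn.Linear(1024, 1024), rank=16)
out = layer(torch.randn(2, 1024))

Because the up-projection starts at zero, training begins exactly at the pretrained model, and the adapter weights are the only new parameters; this is why the checkpoint processor further down saves just the attention-processor state dict.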
@@ -0,0 +1,14 @@
PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/lora_xl/finetune_stable_diffusion_xl_lora.py \
    --model 'AI-ModelScope/stable-diffusion-xl-base-1.0' \
    --model_revision 'v1.0.2' \
    --prompt "a dog" \
    --work_dir './tmp/lora_diffusion_xl' \
    --train_dataset_name 'buptwq/lora-stable-diffusion-finetune' \
    --max_epochs 100 \
    --lora_rank 16 \
    --save_ckpt_strategy 'by_epoch' \
    --logging_interval 1 \
    --train.dataloader.workers_per_gpu 0 \
    --evaluation.dataloader.workers_per_gpu 0 \
    --train.optimizer.lr 1e-4 \
    --use_model_config true
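After training, the adapter saved under work_dir/output can also be loaded outside ModelScope. A hedged sketch with diffusers (>=0.19.0, as the requirements change below mandates); the Hugging Face model id and the paths here are assumptions, and the ModelScope pipeline call in the example script above remains the supported route:

from diffusers import StableDiffusionXLPipeline

# Assumed checkpoint id and output path; adjust to your setup.
pipe = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0')
pipe.load_lora_weights('./tmp/lora_diffusion_xl/output')
image = pipe(prompt='a dog').images[0]
image.save('lora_xl_result.png')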
@@ -218,6 +218,7 @@ class Models(object):
    mplug_owl = 'mplug-owl'
    clip_interrogator = 'clip-interrogator'
    stable_diffusion = 'stable-diffusion'
+    stable_diffusion_xl = 'stable-diffusion-xl'
    videocomposer = 'videocomposer'
    text_to_360panorama_image = 'text-to-360panorama-image'
    image_to_video_model = 'image-to-video-model'
@@ -938,6 +939,7 @@ class MultiModalTrainers(object):
    efficient_diffusion_tuning = 'efficient-diffusion-tuning'
    stable_diffusion = 'stable-diffusion'
    lora_diffusion = 'lora-diffusion'
+    lora_diffusion_xl = 'lora-diffusion-xl'
    dreambooth_diffusion = 'dreambooth-diffusion'
    custom_diffusion = 'custom-diffusion'
@@ -1,2 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .stable_diffusion import StableDiffusion
+from .stable_diffusion_xl import StableDiffusionXL
@@ -20,7 +20,7 @@ from modelscope.utils.constant import Tasks
@MODELS.register_module(
    Tasks.text_to_image_synthesis, module_name=Models.stable_diffusion)
class StableDiffusion(TorchModel):
-    """ The implementation of efficient diffusion tuning model based on TorchModel.
+    """ The implementation of stable diffusion model based on TorchModel.

    This model is constructed with the implementation of stable diffusion model. If you want to
    finetune lightweight parameters on your own dataset, you can define your own tuner module
@@ -28,7 +28,7 @@ class StableDiffusion(TorchModel):
    """

    def __init__(self, model_dir, *args, **kwargs):
-        """ Initialize a vision efficient diffusion tuning model.
+        """ Initialize a vision stable diffusion model.

        Args:
            model_dir: model id or path
@@ -0,0 +1,254 @@
# Copyright 2023-2024 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
import random
from functools import partial
from typing import Callable, List, Optional, Union

import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from packaging import version
from torchvision import transforms
from torchvision.transforms.functional import crop
from transformers import (AutoTokenizer, CLIPTextModel,
                          CLIPTextModelWithProjection)

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.checkpoint import save_checkpoint, save_configuration
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@MODELS.register_module(
    Tasks.text_to_image_synthesis, module_name=Models.stable_diffusion_xl)
class StableDiffusionXL(TorchModel):
    """ The implementation of the stable diffusion xl model based on TorchModel.

    This model is constructed with the implementation of the stable diffusion
    xl model. If you want to finetune lightweight parameters on your own
    dataset, you can define your own tuner module and load it in this class.
    """
    def __init__(self, model_dir, *args, **kwargs):
        """ Initialize a vision stable diffusion xl model.

        Args:
            model_dir: model id or path
        """
        super().__init__(model_dir, *args, **kwargs)
        revision = kwargs.pop('revision', None)
        xformers_enable = kwargs.pop('xformers_enable', False)
        self.lora_tune = kwargs.pop('lora_tune', False)
        # dreambooth_tune defaults to False; save_pretrained checks it below
        self.dreambooth_tune = kwargs.pop('dreambooth_tune', False)
        self.resolution = kwargs.pop('resolution', 1024)
        self.random_flip = kwargs.pop('random_flip', True)

        self.weight_dtype = torch.float32
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # Load scheduler, tokenizers and models
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            model_dir, subfolder='scheduler')
        self.tokenizer_one = AutoTokenizer.from_pretrained(
            model_dir,
            subfolder='tokenizer',
            revision=revision,
            use_fast=False)
        self.tokenizer_two = AutoTokenizer.from_pretrained(
            model_dir,
            subfolder='tokenizer_2',
            revision=revision,
            use_fast=False)
        self.text_encoder_one = CLIPTextModel.from_pretrained(
            model_dir, subfolder='text_encoder', revision=revision)
        self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
            model_dir, subfolder='text_encoder_2', revision=revision)
        self.vae = AutoencoderKL.from_pretrained(
            model_dir, subfolder='vae', revision=revision)
        self.unet = UNet2DConditionModel.from_pretrained(
            model_dir, subfolder='unet', revision=revision)
        self.safety_checker = None

        # Freeze gradient calculation and move to device
        if self.vae is not None:
            self.vae.requires_grad_(False)
            self.vae = self.vae.to(self.device)
        if self.text_encoder_one is not None:
            self.text_encoder_one.requires_grad_(False)
            self.text_encoder_one = self.text_encoder_one.to(self.device)
        if self.text_encoder_two is not None:
            self.text_encoder_two.requires_grad_(False)
            self.text_encoder_two = self.text_encoder_two.to(self.device)
        if self.unet is not None:
            if self.lora_tune:
                self.unet.requires_grad_(False)
            self.unet = self.unet.to(self.device)

        # xformers memory efficient attention
        if xformers_enable:
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse('0.0.16'):
                logger.warn(
                    'xFormers 0.0.16 cannot be used for training in some GPUs. '
                    'If you observe problems during training, please update xFormers to at least 0.0.17.'
                )
            self.unet.enable_xformers_memory_efficient_attention()
    def tokenize_caption(self, tokenizer, captions):
        """ Convert caption text to token data.

        Args:
            tokenizer: tokenizer one or two.
            captions: a batch of texts.
        Returns: the token ids as a tensor.
        """
        inputs = tokenizer(
            captions,
            max_length=tokenizer.model_max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt')
        return inputs.input_ids

    def compute_time_ids(self, original_size, crops_coords_top_left):
        # SDXL micro-conditioning: the time ids are the concatenation of
        # (original_size, crop_top_left, target_size), six integers in total
        target_size = (self.resolution, self.resolution)
        add_time_ids = list(original_size + crops_coords_top_left
                            + target_size)
        add_time_ids = torch.tensor([add_time_ids])
        add_time_ids = add_time_ids.to(self.device, dtype=self.weight_dtype)
        return add_time_ids
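    # Illustration (not in the commit): for a 768-wide, 512-high source image
    # cropped at (y1, x1) = (0, 128) with resolution 1024, the time ids are
    # original_size + crop_top_left + target_size
    # -> [768, 512, 0, 128, 1024, 1024], a (1, 6) tensor.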
    def encode_prompt(self,
                      text_encoders,
                      tokenizers,
                      prompt,
                      text_input_ids_list=None):
        prompt_embeds_list = []

        for i, text_encoder in enumerate(text_encoders):
            if tokenizers is not None:
                tokenizer = tokenizers[i]
                text_input_ids = self.tokenize_caption(tokenizer, prompt)
            else:
                assert text_input_ids_list is not None
                text_input_ids = text_input_ids_list[i]

            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # We are only interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

        # Concatenate the per-encoder embeddings along the channel dimension
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
        return prompt_embeds, pooled_prompt_embeds
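    # Shape note (not in the commit): for SDXL base, text encoder one yields
    # 768-dim and encoder two 1280-dim hidden states, so the concatenated
    # prompt_embeds are (batch, 77, 2048) and pooled_prompt_embeds (batch, 1280).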
    def preprocessing_data(self, text, target):
        train_crop = transforms.RandomCrop(self.resolution)
        train_resize = transforms.Resize(
            self.resolution,
            interpolation=transforms.InterpolationMode.BILINEAR)
        train_flip = transforms.RandomHorizontalFlip(p=1.0)
        image = target
        original_size = (image.size()[-1], image.size()[-2])
        image = train_resize(image)
        y1, x1, h, w = train_crop.get_params(
            image, (self.resolution, self.resolution))
        image = crop(image, y1, x1, h, w)
        if self.random_flip and random.random() < 0.5:
            # flip the image and mirror the crop coordinate accordingly
            x1 = image.size()[-2] - x1
            image = train_flip(image)
        crop_top_left = (y1, x1)
        input_ids_one = self.tokenize_caption(self.tokenizer_one, text)
        input_ids_two = self.tokenize_caption(self.tokenizer_two, text)

        return original_size, crop_top_left, image, input_ids_one, input_ids_two
    def forward(self, text='', target=None):
        self.unet.train()
        self.unet = self.unet.to(self.device)

        # Preprocess the data
        original_size, crop_top_left, image, input_ids_one, input_ids_two = self.preprocessing_data(
            text, target)

        # Convert images to latent space
        with torch.no_grad():
            latents = self.vae.encode(
                target.to(dtype=self.weight_dtype)).latent_dist.sample()
            latents = latents * self.vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            self.noise_scheduler.num_train_timesteps, (bsz, ),
            device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.noise_scheduler.add_noise(latents, noise,
                                                       timesteps)

        add_time_ids = self.compute_time_ids(original_size, crop_top_left)

        # Predict the noise residual
        unet_added_conditions = {'time_ids': add_time_ids}
        prompt_embeds, pooled_prompt_embeds = self.encode_prompt(
            text_encoders=[self.text_encoder_one, self.text_encoder_two],
            tokenizers=None,
            prompt=None,
            text_input_ids_list=[input_ids_one, input_ids_two])
        unet_added_conditions.update({'text_embeds': pooled_prompt_embeds})
        model_pred = self.unet(
            noisy_latents,
            timesteps,
            prompt_embeds,
            added_cond_kwargs=unet_added_conditions).sample

        # Get the target for loss depending on the prediction type
        if self.noise_scheduler.config.prediction_type == 'epsilon':
            target = noise
        elif self.noise_scheduler.config.prediction_type == 'v_prediction':
            target = self.noise_scheduler.get_velocity(latents, noise,
                                                       timesteps)
        else:
            raise ValueError(
                f'Unknown prediction type {self.noise_scheduler.config.prediction_type}'
            )

        loss = F.mse_loss(model_pred.float(), target.float(), reduction='mean')

        output = {OutputKeys.LOSS: loss}
        return output
    def save_pretrained(self,
                        target_folder: Union[str, os.PathLike],
                        save_checkpoint_names: Union[str, List[str]] = None,
                        save_function: Callable = partial(
                            save_checkpoint, with_meta=False),
                        config: Optional[dict] = None,
                        save_config_function: Callable = save_configuration,
                        **kwargs):
        config['pipeline']['type'] = 'diffusers-stable-diffusion-xl'
        # Skip copying the original weights for the lora and dreambooth methods
        if self.lora_tune or self.dreambooth_tune:
            pass
        else:
            super().save_pretrained(target_folder, save_checkpoint_names,
                                    save_function, config,
                                    save_config_function, **kwargs)
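A hedged usage sketch of the class above (not part of the diff): build the registered model through `Model.from_pretrained` and run a single forward step. The random batch tensor and the kwargs pass-through to the model constructor are assumptions.

import torch

from modelscope.models import Model

# Hypothetical smoke test: one training step on random data
model = Model.from_pretrained(
    'AI-ModelScope/stable-diffusion-xl-base-1.0',
    revision='v1.0.2',
    lora_tune=True)
images = torch.randn(1, 3, 1024, 1024)  # assumed (N, C, H, W) float batch
out = model(text=['a dog'], target=images)
print(out)  # expected: {'loss': tensor(...)}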
@@ -0,0 +1,2 @@
# Copyright © Alibaba, Inc. and its affiliates.
from .lora_diffusion_xl_trainer import LoraDiffusionXLTrainer
@@ -0,0 +1,122 @@
|
||||
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import (LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0)
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.trainers.hooks.checkpoint.checkpoint_hook import CheckpointHook
|
||||
from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \
|
||||
CheckpointProcessor
|
||||
from modelscope.trainers.optimizer.builder import build_optimizer
|
||||
from modelscope.trainers.trainer import EpochBasedTrainer
|
||||
from modelscope.utils.config import ConfigDict
|
||||
|
||||
|
||||
def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
|
||||
"""
|
||||
Returns:
|
||||
a state dict containing just the attention processor parameters.
|
||||
"""
|
||||
attn_processors = unet.attn_processors
|
||||
|
||||
attn_processors_state_dict = {}
|
||||
|
||||
for attn_processor_key, attn_processor in attn_processors.items():
|
||||
for parameter_key, parameter in attn_processor.state_dict().items():
|
||||
attn_processors_state_dict[
|
||||
f'{attn_processor_key}.{parameter_key}'] = parameter
|
||||
|
||||
return attn_processors_state_dict
|
||||
|
||||
|
||||
class LoraDiffusionXLCheckpointProcessor(CheckpointProcessor):

    def save_checkpoints(self,
                         trainer,
                         checkpoint_path_prefix,
                         output_dir,
                         meta=None,
                         save_optimizers=True):
        """Save the state dict for the lora-tuned stable diffusion xl model."""
        attn_processors = trainer.model.unet.attn_processors
        unet_lora_layers_to_save = {}

        for attn_processor_key, attn_processor in attn_processors.items():
            for parameter_key, parameter in attn_processor.state_dict().items(
            ):
                unet_lora_layers_to_save[
                    f'{attn_processor_key}.{parameter_key}'] = parameter

        StableDiffusionXLPipeline.save_lora_weights(
            output_dir, unet_lora_layers=unet_lora_layers_to_save)
@TRAINERS.register_module(module_name=Trainers.lora_diffusion_xl)
class LoraDiffusionXLTrainer(EpochBasedTrainer):

    def __init__(self, *args, **kwargs):
        """Lora trainer for fine-tuning stable diffusion xl.

        Args:
            lora_rank: The rank size of the lora intermediate linear.
        """
        super().__init__(*args, **kwargs)
        lora_rank = kwargs.pop('lora_rank', 16)

        # Set the lora checkpoint saving processor
        ckpt_hook = list(
            filter(lambda hook: isinstance(hook, CheckpointHook),
                   self.hooks))[0]
        ckpt_hook.set_processor(LoraDiffusionXLCheckpointProcessor())

        # Add lora weights to the attention layers and set the correct lora layers
        unet_lora_attn_procs = {}
        unet_lora_parameters = []
        for name, attn_processor in self.model.unet.attn_processors.items():
            # Self-attention (attn1) blocks have no cross-attention dimension
            cross_attention_dim = None if name.endswith(
                'attn1.processor'
            ) else self.model.unet.config.cross_attention_dim
            if name.startswith('mid_block'):
                hidden_size = self.model.unet.config.block_out_channels[-1]
            elif name.startswith('up_blocks'):
                block_id = int(name[len('up_blocks.')])
                hidden_size = list(
                    reversed(
                        self.model.unet.config.block_out_channels))[block_id]
            elif name.startswith('down_blocks'):
                block_id = int(name[len('down_blocks.')])
                hidden_size = self.model.unet.config.block_out_channels[
                    block_id]

            # Prefer the PyTorch 2.0 processor when SDPA is available
            lora_attn_processor_class = (
                LoRAAttnProcessor2_0 if hasattr(
                    F, 'scaled_dot_product_attention') else LoRAAttnProcessor)
            module = lora_attn_processor_class(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                rank=lora_rank)
            unet_lora_attn_procs[name] = module
            unet_lora_parameters.extend(module.parameters())

        self.model.unet.set_attn_processor(unet_lora_attn_procs)

    def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
        try:
            return build_optimizer(
                self.model.unet.parameters(),
                cfg=cfg,
                default_args=default_args)
        except KeyError as e:
            self.logger.error(
                f'Build optimizer error, the optimizer {cfg} is a torch native component, '
                f'please check if your torch with version: {torch.__version__} matches the config.'
            )
            raise e
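A hedged sketch (not part of the diff) of inspecting what this trainer actually trains, using the helper defined above; `trainer` is assumed to be the instance built in the example script:

# Only the injected attention-processor weights are new; the rest of the
# unet stays at its pretrained values when lora_tune is enabled.
lora_state = unet_attn_processors_state_dict(trainer.model.unet)
n_lora = sum(p.numel() for p in lora_state.values())
n_total = sum(p.numel() for p in trainer.model.unet.parameters())
print(f'LoRA params: {n_lora / 1e6:.1f}M of {n_total / 1e6:.0f}M total')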
@@ -1,7 +1,7 @@
accelerate
cloudpickle
decord>=0.6.0
-diffusers==0.18.0
+diffusers>=0.19.0
fairseq
ftfy>=6.0.3
librosa==0.9.2
tests/trainers/test_lora_diffusion_xl_trainer.py (new file, 89 lines)
@@ -0,0 +1,89 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
import shutil
import tempfile
import unittest

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode
from modelscope.utils.test_utils import test_level


class TestLoraDiffusionXLTrainer(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

        self.train_dataset = MsDataset.load(
            'buptwq/lora-stable-diffusion-finetune',
            split='train',
            download_mode=DownloadMode.FORCE_REDOWNLOAD)
        self.eval_dataset = MsDataset.load(
            'buptwq/lora-stable-diffusion-finetune',
            split='validation',
            download_mode=DownloadMode.FORCE_REDOWNLOAD)

        self.max_epochs = 5

        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_lora_diffusion_xl_train(self):
        model_id = 'AI-ModelScope/stable-diffusion-xl-base-1.0'
        model_revision = 'v1.0.2'

        def cfg_modify_fn(cfg):
            cfg.train.max_epochs = self.max_epochs
            cfg.train.lr_scheduler = {
                'type': 'LambdaLR',
                'lr_lambda': lambda _: 1,
                'last_epoch': -1
            }
            cfg.train.optimizer.lr = 1e-4
            return cfg

        kwargs = dict(
            model=model_id,
            model_revision=model_revision,
            work_dir=self.tmp_dir,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            cfg_modify_fn=cfg_modify_fn)

        trainer = build_trainer(
            name=Trainers.lora_diffusion_xl, default_args=kwargs)
        trainer.train()
        result = trainer.evaluate()
        print(f'Lora-diffusion-xl train output: {result}.')

        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_lora_diffusion_xl_eval(self):
        model_id = 'AI-ModelScope/stable-diffusion-xl-base-1.0'
        model_revision = 'v1.0.2'

        kwargs = dict(
            model=model_id,
            model_revision=model_revision,
            work_dir=self.tmp_dir,
            train_dataset=None,
            eval_dataset=self.eval_dataset)

        trainer = build_trainer(
            name=Trainers.lora_diffusion_xl, default_args=kwargs)
        result = trainer.evaluate()
        print(f'Lora-diffusion-xl eval output: {result}.')


if __name__ == '__main__':
    unittest.main()