diff --git a/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py b/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py index 5659a105..a6bf10d4 100644 --- a/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py +++ b/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py @@ -7,7 +7,7 @@ import torch from modelscope.metainfo import Trainers from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline -from modelscope.trainers import EpochBasedTrainer, build_trainer +from modelscope.trainers import build_trainer from modelscope.trainers.training_args import TrainingArgs from modelscope.utils.constant import DownloadMode, Tasks diff --git a/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py b/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py index f0c40be7..97bfad02 100644 --- a/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py +++ b/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py @@ -7,7 +7,7 @@ import torch from modelscope.metainfo import Trainers from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline -from modelscope.trainers import EpochBasedTrainer, build_trainer +from modelscope.trainers import build_trainer from modelscope.trainers.training_args import TrainingArgs from modelscope.utils.constant import DownloadMode, Tasks diff --git a/examples/pytorch/stable_diffusion/lora_xl/finetune_stable_diffusion_xl_lora.py b/examples/pytorch/stable_diffusion/lora_xl/finetune_stable_diffusion_xl_lora.py new file mode 100644 index 00000000..42facfec --- /dev/null +++ b/examples/pytorch/stable_diffusion/lora_xl/finetune_stable_diffusion_xl_lora.py @@ -0,0 +1,85 @@ +import os +from dataclasses import dataclass, field + +import cv2 + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline +from modelscope.trainers import build_trainer +from modelscope.trainers.training_args import TrainingArgs +from modelscope.utils.constant import DownloadMode, Tasks + + +# Load configuration file and dataset +@dataclass(init=False) +class StableDiffusionXLLoraArguments(TrainingArgs): + prompt: str = field( + default='dog', metadata={ + 'help': 'The pipeline prompt.', + }) + + lora_rank: int = field( + default=16, + metadata={ + 'help': 'The rank size of lora intermediate linear.', + }) + + +training_args = StableDiffusionXLLoraArguments( + task='text-to-image-synthesis').parse_cli() +config, args = training_args.to_config() + +if os.path.exists(args.train_dataset_name): + # Load local dataset + train_dataset = MsDataset.load(args.train_dataset_name) + validation_dataset = MsDataset.load(args.train_dataset_name) +else: + # Load online dataset + train_dataset = MsDataset.load( + args.train_dataset_name, + split='train', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + validation_dataset = MsDataset.load( + args.train_dataset_name, + split='validation', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + + +def cfg_modify_fn(cfg): + if args.use_model_config: + cfg.merge_from_dict(config) + else: + cfg = config + cfg.train.lr_scheduler = { + 'type': 'LambdaLR', + 'lr_lambda': lambda _: 1, + 'last_epoch': -1 + } + return cfg + + +kwargs = dict( + model=training_args.model, + model_revision=args.model_revision, + work_dir=training_args.work_dir, + train_dataset=train_dataset, + eval_dataset=validation_dataset, + lora_rank=args.lora_rank, + cfg_modify_fn=cfg_modify_fn) + +# build trainer and training +trainer = build_trainer(name=Trainers.lora_diffusion_xl, default_args=kwargs) +trainer.train() + +# pipeline after training and save result +pipe = pipeline( + task=Tasks.text_to_image_synthesis, + model=training_args.model, + lora_dir=training_args.work_dir + '/output', + model_revision=args.model_revision) + +output = pipe({'text': args.prompt}) +# visualize the result on ipynb and save it +output +cv2.imwrite('./lora_xl_result.png', output['output_imgs'][0]) diff --git a/examples/pytorch/stable_diffusion/lora_xl/run_train_xl_lora.sh b/examples/pytorch/stable_diffusion/lora_xl/run_train_xl_lora.sh new file mode 100644 index 00000000..fc7704a4 --- /dev/null +++ b/examples/pytorch/stable_diffusion/lora_xl/run_train_xl_lora.sh @@ -0,0 +1,14 @@ +PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/lora_xl/finetune_stable_diffusion_xl_lora.py \ + --model 'AI-ModelScope/stable-diffusion-xl-base-1.0' \ + --model_revision 'v1.0.2' \ + --prompt "a dog" \ + --work_dir './tmp/lora_diffusion_xl' \ + --train_dataset_name 'buptwq/lora-stable-diffusion-finetune' \ + --max_epochs 100 \ + --lora_rank 16 \ + --save_ckpt_strategy 'by_epoch' \ + --logging_interval 1 \ + --train.dataloader.workers_per_gpu 0 \ + --evaluation.dataloader.workers_per_gpu 0 \ + --train.optimizer.lr 1e-4 \ + --use_model_config true diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 630d4aa5..298c93d4 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -218,6 +218,7 @@ class Models(object): mplug_owl = 'mplug-owl' clip_interrogator = 'clip-interrogator' stable_diffusion = 'stable-diffusion' + stable_diffusion_xl = 'stable-diffusion-xl' videocomposer = 'videocomposer' text_to_360panorama_image = 'text-to-360panorama-image' image_to_video_model = 'image-to-video-model' @@ -938,6 +939,7 @@ class MultiModalTrainers(object): efficient_diffusion_tuning = 'efficient-diffusion-tuning' stable_diffusion = 'stable-diffusion' lora_diffusion = 'lora-diffusion' + lora_diffusion_xl = 'lora-diffusion-xl' dreambooth_diffusion = 'dreambooth-diffusion' custom_diffusion = 'custom-diffusion' diff --git a/modelscope/models/multi_modal/stable_diffusion/__init__.py b/modelscope/models/multi_modal/stable_diffusion/__init__.py index 9ffff12a..7f0b184b 100644 --- a/modelscope/models/multi_modal/stable_diffusion/__init__.py +++ b/modelscope/models/multi_modal/stable_diffusion/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .stable_diffusion import StableDiffusion +from .stable_diffusion_xl import StableDiffusionXL diff --git a/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py b/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py index 80d8ab28..6267fb9d 100644 --- a/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py +++ b/modelscope/models/multi_modal/stable_diffusion/stable_diffusion.py @@ -20,7 +20,7 @@ from modelscope.utils.constant import Tasks @MODELS.register_module( Tasks.text_to_image_synthesis, module_name=Models.stable_diffusion) class StableDiffusion(TorchModel): - """ The implementation of efficient diffusion tuning model based on TorchModel. + """ The implementation of stable diffusion model based on TorchModel. This model is constructed with the implementation of stable diffusion model. If you want to finetune lightweight parameters on your own dataset, you can define you own tuner module @@ -28,7 +28,7 @@ class StableDiffusion(TorchModel): """ def __init__(self, model_dir, *args, **kwargs): - """ Initialize a vision efficient diffusion tuning model. + """ Initialize a vision stable diffusion model. Args: model_dir: model id or path diff --git a/modelscope/models/multi_modal/stable_diffusion/stable_diffusion_xl.py b/modelscope/models/multi_modal/stable_diffusion/stable_diffusion_xl.py new file mode 100644 index 00000000..23ad6676 --- /dev/null +++ b/modelscope/models/multi_modal/stable_diffusion/stable_diffusion_xl.py @@ -0,0 +1,254 @@ +# Copyright 2023-2024 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os +import random +from functools import partial +from typing import Callable, List, Optional, Union + +import torch +import torch.nn.functional as F +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel +from packaging import version +from torchvision import transforms +from torchvision.transforms.functional import crop +from transformers import (AutoTokenizer, CLIPTextModel, + CLIPTextModelWithProjection) + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.checkpoint import save_checkpoint, save_configuration +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + Tasks.text_to_image_synthesis, module_name=Models.stable_diffusion_xl) +class StableDiffusionXL(TorchModel): + """ The implementation of stable diffusion xl model based on TorchModel. + + This model is constructed with the implementation of stable diffusion xl model. If you want to + finetune lightweight parameters on your own dataset, you can define you own tuner module + and load in this cls. + """ + + def __init__(self, model_dir, *args, **kwargs): + """ Initialize a vision stable diffusion xl model. + + Args: + model_dir: model id or path + """ + super().__init__(model_dir, *args, **kwargs) + revision = kwargs.pop('revision', None) + xformers_enable = kwargs.pop('xformers_enable', False) + self.lora_tune = kwargs.pop('lora_tune', False) + self.resolution = kwargs.pop('resolution', 1024) + self.random_flip = kwargs.pop('random_flip', True) + + self.weight_dtype = torch.float32 + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu') + + # Load scheduler, tokenizer and models + self.noise_scheduler = DDPMScheduler.from_pretrained( + model_dir, subfolder='scheduler') + self.tokenizer_one = AutoTokenizer.from_pretrained( + model_dir, + subfolder='tokenizer', + revision=revision, + use_fast=False) + self.tokenizer_two = AutoTokenizer.from_pretrained( + model_dir, + subfolder='tokenizer_2', + revision=revision, + use_fast=False) + self.text_encoder_one = CLIPTextModel.from_pretrained( + model_dir, subfolder='text_encoder', revision=revision) + self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained( + model_dir, subfolder='text_encoder_2', revision=revision) + self.vae = AutoencoderKL.from_pretrained( + model_dir, subfolder='vae', revision=revision) + self.unet = UNet2DConditionModel.from_pretrained( + model_dir, subfolder='unet', revision=revision) + self.safety_checker = None + + # Freeze gradient calculation and move to device + if self.vae is not None: + self.vae.requires_grad_(False) + self.vae = self.vae.to(self.device) + if self.text_encoder_one is not None: + self.text_encoder_one.requires_grad_(False) + self.text_encoder_one = self.text_encoder_one.to(self.device) + if self.text_encoder_two is not None: + self.text_encoder_two.requires_grad_(False) + self.text_encoder_two = self.text_encoder_two.to(self.device) + if self.unet is not None: + if self.lora_tune: + self.unet.requires_grad_(False) + self.unet = self.unet.to(self.device) + + # xformers accelerate memory efficient attention + if xformers_enable: + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse('0.0.16'): + logger.warn( + 'xFormers 0.0.16 cannot be used for training in some GPUs. ' + 'If you observe problems during training, please update xFormers to at least 0.0.17.' + ) + self.unet.enable_xformers_memory_efficient_attention() + + def tokenize_caption(self, tokenizer, captions): + """ Convert caption text to token data. + + Args: + tokenizer: the tokenizer one or two. + captions: a batch of texts. + Returns: token's data as tensor. + """ + inputs = tokenizer( + captions, + max_length=tokenizer.model_max_length, + padding='max_length', + truncation=True, + return_tensors='pt') + return inputs.input_ids + + def compute_time_ids(self, original_size, crops_coords_top_left): + target_size = (self.resolution, self.resolution) + add_time_ids = list(original_size + crops_coords_top_left + + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(self.device, dtype=self.weight_dtype) + return add_time_ids + + def encode_prompt(self, + text_encoders, + tokenizers, + prompt, + text_input_ids_list=None): + prompt_embeds_list = [] + + for i, text_encoder in enumerate(text_encoders): + if tokenizers is not None: + tokenizer = tokenizers[i] + text_input_ids = tokenize_prompt(tokenizer, prompt) + else: + assert text_input_ids_list is not None + text_input_ids = text_input_ids_list[i] + + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + def preprocessing_data(self, text, target): + train_crop = transforms.RandomCrop(self.resolution) + train_resize = transforms.Resize( + self.resolution, + interpolation=transforms.InterpolationMode.BILINEAR) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + image = target + original_size = (image.size()[-1], image.size()[-2]) + image = train_resize(image) + y1, x1, h, w = train_crop.get_params( + image, (self.resolution, self.resolution)) + image = crop(image, y1, x1, h, w) + if self.random_flip and random.random() < 0.5: + # flip + x1 = image.size()[-2] - x1 + image = train_flip(image) + crop_top_left = (y1, x1) + input_ids_one = self.tokenize_caption(self.tokenizer_one, text) + input_ids_two = self.tokenize_caption(self.tokenizer_two, text) + + return original_size, crop_top_left, image, input_ids_one, input_ids_two + + def forward(self, text='', target=None): + self.unet.train() + self.unet = self.unet.to(self.device) + + # processing data + original_size, crop_top_left, image, input_ids_one, input_ids_two = self.preprocessing_data( + text, target) + # Convert to latent space + with torch.no_grad(): + latents = self.vae.encode( + target.to(dtype=self.weight_dtype)).latent_dist.sample() + latents = latents * self.vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + self.noise_scheduler.num_train_timesteps, (bsz, ), + device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = self.noise_scheduler.add_noise(latents, noise, + timesteps) + + add_time_ids = self.compute_time_ids(original_size, crop_top_left) + + # Predict the noise residual + unet_added_conditions = {'time_ids': add_time_ids} + prompt_embeds, pooled_prompt_embeds = self.encode_prompt( + text_encoders=[self.text_encoder_one, self.text_encoder_two], + tokenizers=None, + prompt=None, + text_input_ids_list=[input_ids_one, input_ids_two]) + unet_added_conditions.update({'text_embeds': pooled_prompt_embeds}) + # Predict the noise residual and compute loss + model_pred = self.unet( + noisy_latents, + timesteps, + prompt_embeds, + added_cond_kwargs=unet_added_conditions).sample + + # Get the target for loss depending on the prediction type + if self.noise_scheduler.config.prediction_type == 'epsilon': + target = noise + elif self.noise_scheduler.config.prediction_type == 'v_prediction': + target = self.noise_scheduler.get_velocity(model_input, noise, + timesteps) + else: + raise ValueError( + f'Unknown prediction type {self.noise_scheduler.config.prediction_type}' + ) + + loss = F.mse_loss(model_pred.float(), target.float(), reduction='mean') + + output = {OutputKeys.LOSS: loss} + return output + + def save_pretrained(self, + target_folder: Union[str, os.PathLike], + save_checkpoint_names: Union[str, List[str]] = None, + save_function: Callable = partial( + save_checkpoint, with_meta=False), + config: Optional[dict] = None, + save_config_function: Callable = save_configuration, + **kwargs): + config['pipeline']['type'] = 'diffusers-stable-diffusion-xl' + # Skip copying the original weights for lora and dreambooth method + if self.lora_tune or self.dreambooth_tune: + pass + else: + super().save_pretrained(target_folder, save_checkpoint_names, + save_function, config, + save_config_function, **kwargs) diff --git a/modelscope/trainers/multi_modal/lora_diffusion_xl/__init__.py b/modelscope/trainers/multi_modal/lora_diffusion_xl/__init__.py new file mode 100644 index 00000000..9fe65472 --- /dev/null +++ b/modelscope/trainers/multi_modal/lora_diffusion_xl/__init__.py @@ -0,0 +1,2 @@ +# Copyright © Alibaba, Inc. and its affiliates. +from .lora_diffusion_xl_trainer import LoraDiffusionXLTrainer diff --git a/modelscope/trainers/multi_modal/lora_diffusion_xl/lora_diffusion_xl_trainer.py b/modelscope/trainers/multi_modal/lora_diffusion_xl/lora_diffusion_xl_trainer.py new file mode 100644 index 00000000..470d88e8 --- /dev/null +++ b/modelscope/trainers/multi_modal/lora_diffusion_xl/lora_diffusion_xl_trainer.py @@ -0,0 +1,122 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import Dict, Union + +import torch +import torch.nn.functional as F +from diffusers import StableDiffusionXLPipeline +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import (LoRAAttnProcessor, + LoRAAttnProcessor2_0) + +from modelscope.metainfo import Trainers +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.hooks.checkpoint.checkpoint_hook import CheckpointHook +from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \ + CheckpointProcessor +from modelscope.trainers.optimizer.builder import build_optimizer +from modelscope.trainers.trainer import EpochBasedTrainer +from modelscope.utils.config import ConfigDict + + +def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]: + """ + Returns: + a state dict containing just the attention processor parameters. + """ + attn_processors = unet.attn_processors + + attn_processors_state_dict = {} + + for attn_processor_key, attn_processor in attn_processors.items(): + for parameter_key, parameter in attn_processor.state_dict().items(): + attn_processors_state_dict[ + f'{attn_processor_key}.{parameter_key}'] = parameter + + return attn_processors_state_dict + + +class LoraDiffusionXLCheckpointProcessor(CheckpointProcessor): + + def save_checkpoints(self, + trainer, + checkpoint_path_prefix, + output_dir, + meta=None, + save_optimizers=True): + """Save the state dict for lora tune stable diffusion xl model. + """ + attn_processors = trainer.model.unet.attn_processors + unet_lora_layers_to_save = {} + + for attn_processor_key, attn_processor in attn_processors.items(): + for parameter_key, parameter in attn_processor.state_dict().items( + ): + unet_lora_layers_to_save[ + f'{attn_processor_key}.{parameter_key}'] = parameter + + StableDiffusionXLPipeline.save_lora_weights( + output_dir, unet_lora_layers=unet_lora_layers_to_save) + + +@TRAINERS.register_module(module_name=Trainers.lora_diffusion_xl) +class LoraDiffusionXLTrainer(EpochBasedTrainer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + """Lora trainers for fine-tuning stable diffusion xl. + + Args: + lora_rank: The rank size of lora intermediate linear. + + """ + lora_rank = kwargs.pop('lora_rank', 16) + + # set lora save checkpoint processor + ckpt_hook = list( + filter(lambda hook: isinstance(hook, CheckpointHook), + self.hooks))[0] + ckpt_hook.set_processor(LoraDiffusionXLCheckpointProcessor()) + + # Add lora weights to attention layers and set correct lora layers + unet_lora_attn_procs = {} + unet_lora_parameters = [] + for name, attn_processor in self.model.unet.attn_processors.items(): + cross_attention_dim = None if name.endswith( + 'attn1.processor' + ) else self.model.unet.config.cross_attention_dim + if name.startswith('mid_block'): + hidden_size = self.model.unet.config.block_out_channels[-1] + elif name.startswith('up_blocks'): + block_id = int(name[len('up_blocks.')]) + hidden_size = list( + reversed( + self.model.unet.config.block_out_channels))[block_id] + elif name.startswith('down_blocks'): + block_id = int(name[len('down_blocks.')]) + hidden_size = self.model.unet.config.block_out_channels[ + block_id] + + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr( + F, 'scaled_dot_product_attention') else LoRAAttnProcessor) + module = lora_attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=lora_rank) + unet_lora_attn_procs[name] = module + unet_lora_parameters.extend(module.parameters()) + + self.model.unet.set_attn_processor(unet_lora_attn_procs) + + def build_optimizer(self, cfg: ConfigDict, default_args: dict = None): + try: + return build_optimizer( + self.model.unet.parameters(), + cfg=cfg, + default_args=default_args) + except KeyError as e: + self.logger.error( + f'Build optimizer error, the optimizer {cfg} is a torch native component, ' + f'please check if your torch with version: {torch.__version__} matches the config.' + ) + raise e diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt index 4cf4d19a..59415bb0 100644 --- a/requirements/multi-modal.txt +++ b/requirements/multi-modal.txt @@ -1,7 +1,7 @@ accelerate cloudpickle decord>=0.6.0 -diffusers==0.18.0 +diffusers>=0.19.0 fairseq ftfy>=6.0.3 librosa==0.9.2 diff --git a/tests/trainers/test_lora_diffusion_xl_trainer.py b/tests/trainers/test_lora_diffusion_xl_trainer.py new file mode 100644 index 00000000..da780b5d --- /dev/null +++ b/tests/trainers/test_lora_diffusion_xl_trainer.py @@ -0,0 +1,89 @@ +# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import os +import shutil +import tempfile +import unittest + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import DownloadMode +from modelscope.utils.test_utils import test_level + + +class TestLoraDiffusionXLTrainer(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + self.train_dataset = MsDataset.load( + 'buptwq/lora-stable-diffusion-finetune', + split='train', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + self.eval_dataset = MsDataset.load( + 'buptwq/lora-stable-diffusion-finetune', + split='validation', + download_mode=DownloadMode.FORCE_REDOWNLOAD) + + self.max_epochs = 5 + + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_lora_diffusion_xl_train(self): + model_id = 'AI-ModelScope/stable-diffusion-xl-base-1.0' + model_revision = 'v1.0.2' + + def cfg_modify_fn(cfg): + cfg.train.max_epochs = self.max_epochs + cfg.train.lr_scheduler = { + 'type': 'LambdaLR', + 'lr_lambda': lambda _: 1, + 'last_epoch': -1 + } + cfg.train.optimizer.lr = 1e-4 + return cfg + + kwargs = dict( + model=model_id, + model_revision=model_revision, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.lora_diffusion, default_args=kwargs) + trainer.train() + result = trainer.evaluate() + print(f'Lora-diffusion-xl train output: {result}.') + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_lora_diffusion_xl_eval(self): + model_id = 'AI-ModelScope/stable-diffusion-xl-base-1.0' + model_revision = 'v1.0.2' + + kwargs = dict( + model=model_id, + model_revision=model_revision, + work_dir=self.tmp_dir, + train_dataset=None, + eval_dataset=self.eval_dataset) + + trainer = build_trainer( + name=Trainers.lora_diffusion_xl, default_args=kwargs) + result = trainer.evaluate() + print(f'Lora-diffusion-xl eval output: {result}.') + + +if __name__ == '__main__': + unittest.main()