support swift trainer and pipeline (#547)

* support swift trainer and pipeline

* support swift lora pipeline

* stable diffusion xl trainer

* tests sdxl

* fix diffusers attention

* swift support

* support swift sd

---------

Co-authored-by: 翊靖 <yijing.wq@alibaba-inc.com>
This commit is contained in:
Wang Qiang
2023-09-25 19:24:54 +08:00
committed by GitHub
parent 70fe158d13
commit ef97e3b0fe
5 changed files with 50 additions and 6 deletions

View File

@@ -158,9 +158,9 @@ class StableDiffusion(TorchModel):
config: Optional[dict] = None,
save_config_function: Callable = save_configuration,
**kwargs):
config['pipeline']['type'] = 'diffusers-stable-diffusion'
# Skip copying the original weights for lora and dreambooth method
if self.lora_tune or self.dreambooth_tune:
config['pipeline']['type'] = 'diffusers-stable-diffusion'
pass
else:
super().save_pretrained(target_folder, save_checkpoint_names,

View File

@@ -244,9 +244,9 @@ class StableDiffusionXL(TorchModel):
config: Optional[dict] = None,
save_config_function: Callable = save_configuration,
**kwargs):
config['pipeline']['type'] = 'diffusers-stable-diffusion-xl'
# Skip copying the original weights for lora and dreambooth method
if self.lora_tune or self.dreambooth_tune:
if self.lora_tune:
config['pipeline']['type'] = 'diffusers-stable-diffusion-xl'
pass
else:
super().save_pretrained(target_folder, save_checkpoint_names,

View File

@@ -12,7 +12,7 @@ import numpy as np
import torch
import torch.nn.functional as F
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline
from diffusers.models.cross_attention import CrossAttention
from diffusers.models.attention_processor import Attention
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import \
StableDiffusionPipelineOutput
from PIL import Image
@@ -245,7 +245,7 @@ class Cones2AttnProcessor:
super().__init__()
def __call__(self,
attn: CrossAttention,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None):

View File

@@ -17,6 +17,7 @@ from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.multi_modal.diffusers_wrapped.diffusers_pipeline import \
DiffusersPipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.import_utils import is_swift_available
@PIPELINES.register_module(
@@ -38,9 +39,11 @@ class StableDiffusionPipeline(DiffusersPipeline):
custom_dir: custom diffusion weight dir for unet.
modifier_token: token to use as a modifier for the concept of custom diffusion.
use_safetensors: load safetensors weights.
use_swift: Whether to use swift lora dir for unet.
"""
use_safetensors = kwargs.pop('use_safetensors', False)
torch_type = kwargs.pop('torch_type', torch.float32)
use_swift = kwargs.pop('use_swift', False)
# check custom diffusion input value
if custom_dir is None and modifier_token is not None:
raise ValueError(
@@ -58,7 +61,17 @@ class StableDiffusionPipeline(DiffusersPipeline):
# load lora module to unet
if lora_dir is not None:
assert os.path.exists(lora_dir), f"{lora_dir} isn't exist"
self.pipeline.unet.load_attn_procs(lora_dir)
if use_swift:
if not is_swift_available():
raise ValueError(
'Please install swift by `pip install ms-swift` to use efficient_tuners.'
)
from swift import Swift
self.pipeline.unet = Swift.from_pretrained(
self.pipeline.unet, lora_dir)
else:
self.pipeline.unet.load_attn_procs(lora_dir)
# load custom diffusion to unet
if custom_dir is not None:
assert os.path.exists(custom_dir), f"{custom_dir} isn't exist"

View File

@@ -1,4 +1,5 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
from typing import Union
import torch
@@ -7,16 +8,46 @@ from torch import nn
from modelscope.metainfo import Trainers
from modelscope.models.base import Model, TorchModel
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.hooks.checkpoint.checkpoint_hook import CheckpointHook
from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \
CheckpointProcessor
from modelscope.trainers.optimizer.builder import build_optimizer
from modelscope.trainers.trainer import EpochBasedTrainer
from modelscope.utils.config import ConfigDict
class SwiftDiffusionCheckpointProcessor(CheckpointProcessor):
    """Checkpoint processor for swift (ms-swift) tuned diffusion models.

    Overrides the default checkpoint saving so that only the swift-wrapped
    unet weights are written out, using swift's ``save_pretrained`` API
    instead of the trainer's generic state-dict dump.
    """

    def save_checkpoints(self,
                         trainer,
                         checkpoint_path_prefix,
                         output_dir,
                         meta=None,
                         save_optimizers=True):
        """Save the state dict for the swift lora tuned model.

        Args:
            trainer: Trainer whose ``model.unet`` exposes ``save_pretrained``
                (presumably a swift-wrapped module — confirm against the
                swift trainer setup).
            checkpoint_path_prefix: Unused; kept for interface compatibility
                with ``CheckpointProcessor``.
            output_dir: Directory the unet weights are written into.
            meta: Unused; kept for interface compatibility.
            save_optimizers: Unused; optimizer state is not saved here.
        """
        # The original wrapped output_dir in a single-argument
        # os.path.join, which is a no-op; pass the directory through.
        trainer.model.unet.save_pretrained(output_dir)
@TRAINERS.register_module(module_name=Trainers.stable_diffusion)
class StableDiffusionTrainer(EpochBasedTrainer):
def __init__(self, *args, **kwargs):
    """Stable Diffusion trainer for fine-tuning.

    Args:
        use_swift (bool, optional): Whether to use swift for lora tuning;
            when True, checkpoint saving is delegated to
            ``SwiftDiffusionCheckpointProcessor``. Defaults to False.
    """
    # Pop 'use_swift' BEFORE delegating: the original popped it after
    # super().__init__(*args, **kwargs), which leaked the unexpected
    # keyword into the base trainer's constructor.
    use_swift = kwargs.pop('use_swift', False)
    super().__init__(*args, **kwargs)
    # Swap in the swift-aware checkpoint processor so lora weights are
    # saved via swift's save_pretrained.
    if use_swift:
        ckpt_hook = list(
            filter(lambda hook: isinstance(hook, CheckpointHook),
                   self.hooks))[0]
        ckpt_hook.set_processor(SwiftDiffusionCheckpointProcessor())
def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
try: