mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 20:49:37 +01:00
support swift trainer and pipeline (#547)
* support swift trainer and pipeline * support swift lora pipeline * stable diffusion xl trainer * tests sdxl * fix diffusers attention * swift support * support swift sd --------- Co-authored-by: 翊靖 <yijing.wq@alibaba-inc.com>
This commit is contained in:
@@ -158,9 +158,9 @@ class StableDiffusion(TorchModel):
|
||||
config: Optional[dict] = None,
|
||||
save_config_function: Callable = save_configuration,
|
||||
**kwargs):
|
||||
config['pipeline']['type'] = 'diffusers-stable-diffusion'
|
||||
# Skip copying the original weights for lora and dreambooth method
|
||||
if self.lora_tune or self.dreambooth_tune:
|
||||
config['pipeline']['type'] = 'diffusers-stable-diffusion'
|
||||
pass
|
||||
else:
|
||||
super().save_pretrained(target_folder, save_checkpoint_names,
|
||||
|
||||
@@ -244,9 +244,9 @@ class StableDiffusionXL(TorchModel):
|
||||
config: Optional[dict] = None,
|
||||
save_config_function: Callable = save_configuration,
|
||||
**kwargs):
|
||||
config['pipeline']['type'] = 'diffusers-stable-diffusion-xl'
|
||||
# Skip copying the original weights for lora and dreambooth method
|
||||
if self.lora_tune or self.dreambooth_tune:
|
||||
if self.lora_tune:
|
||||
config['pipeline']['type'] = 'diffusers-stable-diffusion-xl'
|
||||
pass
|
||||
else:
|
||||
super().save_pretrained(target_folder, save_checkpoint_names,
|
||||
|
||||
@@ -12,7 +12,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline
|
||||
from diffusers.models.cross_attention import CrossAttention
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import \
|
||||
StableDiffusionPipelineOutput
|
||||
from PIL import Image
|
||||
@@ -245,7 +245,7 @@ class Cones2AttnProcessor:
|
||||
super().__init__()
|
||||
|
||||
def __call__(self,
|
||||
attn: CrossAttention,
|
||||
attn: Attention,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None):
|
||||
|
||||
@@ -17,6 +17,7 @@ from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.pipelines.multi_modal.diffusers_wrapped.diffusers_pipeline import \
|
||||
DiffusersPipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.import_utils import is_swift_available
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
@@ -38,9 +39,11 @@ class StableDiffusionPipeline(DiffusersPipeline):
|
||||
custom_dir: custom diffusion weight dir for unet.
|
||||
modifier_token: token to use as a modifier for the concept of custom diffusion.
|
||||
use_safetensors: load safetensors weights.
|
||||
use_swift: Whether to load lora_dir into the unet with swift (Swift.from_pretrained) instead of load_attn_procs.
|
||||
"""
|
||||
use_safetensors = kwargs.pop('use_safetensors', False)
|
||||
torch_type = kwargs.pop('torch_type', torch.float32)
|
||||
use_swift = kwargs.pop('use_swift', False)
|
||||
# check custom diffusion input value
|
||||
if custom_dir is None and modifier_token is not None:
|
||||
raise ValueError(
|
||||
@@ -58,7 +61,17 @@ class StableDiffusionPipeline(DiffusersPipeline):
|
||||
# load lora module to unet
|
||||
if lora_dir is not None:
|
||||
assert os.path.exists(lora_dir), f"{lora_dir} isn't exist"
|
||||
self.pipeline.unet.load_attn_procs(lora_dir)
|
||||
if use_swift:
|
||||
if not is_swift_available():
|
||||
raise ValueError(
|
||||
'Please install swift by `pip install ms-swift` to use efficient_tuners.'
|
||||
)
|
||||
from swift import Swift
|
||||
self.pipeline.unet = Swift.from_pretrained(
|
||||
self.pipeline.unet, lora_dir)
|
||||
else:
|
||||
self.pipeline.unet.load_attn_procs(lora_dir)
|
||||
|
||||
# load custom diffusion to unet
|
||||
if custom_dir is not None:
|
||||
assert os.path.exists(custom_dir), f"{custom_dir} isn't exist"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
@@ -7,16 +8,46 @@ from torch import nn
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models.base import Model, TorchModel
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.trainers.hooks.checkpoint.checkpoint_hook import CheckpointHook
|
||||
from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \
|
||||
CheckpointProcessor
|
||||
from modelscope.trainers.optimizer.builder import build_optimizer
|
||||
from modelscope.trainers.trainer import EpochBasedTrainer
|
||||
from modelscope.utils.config import ConfigDict
|
||||
|
||||
|
||||
class SwiftDiffusionCheckpointProcessor(CheckpointProcessor):
|
||||
|
||||
def save_checkpoints(self,
|
||||
trainer,
|
||||
checkpoint_path_prefix,
|
||||
output_dir,
|
||||
meta=None,
|
||||
save_optimizers=True):
|
||||
"""Save the unet of the swift lora tuned model to output_dir.
|
||||
"""
|
||||
trainer.model.unet.save_pretrained(os.path.join(output_dir))
|
||||
|
||||
|
||||
@TRAINERS.register_module(module_name=Trainers.stable_diffusion)
|
||||
class StableDiffusionTrainer(EpochBasedTrainer):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Stable Diffusion trainers for fine-tuning.
|
||||
|
||||
Args:
|
||||
use_swift: Whether to use swift. If True, the checkpoint hook saves with SwiftDiffusionCheckpointProcessor.
|
||||
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
use_swift = kwargs.pop('use_swift', False)
|
||||
|
||||
# set swift lora save checkpoint processor
|
||||
if use_swift:
|
||||
ckpt_hook = list(
|
||||
filter(lambda hook: isinstance(hook, CheckpointHook),
|
||||
self.hooks))[0]
|
||||
ckpt_hook.set_processor(SwiftDiffusionCheckpointProcessor())
|
||||
|
||||
def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user