Merge branch merge_master_github_0628 into master
Title: Merge branch 'master-github' into merge_master_github_0628
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13104791
examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py (new file)
@@ -0,0 +1,116 @@
from dataclasses import dataclass, field

import cv2

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.training_args import TrainingArgs
from modelscope.utils.constant import DownloadMode, Tasks


# Load configuration file and dataset
@dataclass(init=False)
class StableDiffusionDreamboothArguments(TrainingArgs):
    with_prior_preservation: bool = field(
        default=False, metadata={
            'help': 'Whether to enable prior loss.',
        })

    instance_prompt: str = field(
        default='a photo of sks dog',
        metadata={
            'help': 'The instance prompt for dreambooth.',
        })

    class_prompt: str = field(
        default='a photo of dog',
        metadata={
            'help': 'The class prompt for dreambooth.',
        })

    class_data_dir: str = field(
        default='./tmp/class_data',
        metadata={
            'help': 'Save class prompt images path.',
        })

    num_class_images: int = field(
        default=200,
        metadata={
            'help': 'The numbers of saving class images.',
        })

    resolution: int = field(
        default=512, metadata={
            'help': 'The class images resolution.',
        })

    prior_loss_weight: float = field(
        default=1.0,
        metadata={
            'help': 'The weight of instance and prior loss.',
        })

    prompt: str = field(
        default='dog', metadata={
            'help': 'The pipeline prompt.',
        })


training_args = StableDiffusionDreamboothArguments(
    task='text-to-image-synthesis').parse_cli()
config, args = training_args.to_config()

train_dataset = MsDataset.load(
    args.train_dataset_name,
    split='train',
    download_mode=DownloadMode.FORCE_REDOWNLOAD)
validation_dataset = MsDataset.load(
    args.train_dataset_name,
    split='validation',
    download_mode=DownloadMode.FORCE_REDOWNLOAD)


def cfg_modify_fn(cfg):
    if args.use_model_config:
        cfg.merge_from_dict(config)
    else:
        cfg = config
    cfg.train.lr_scheduler = {
        'type': 'LambdaLR',
        'lr_lambda': lambda _: 1,
        'last_epoch': -1
    }
    return cfg


kwargs = dict(
    model=training_args.model,
    model_revision=args.model_revision,
    work_dir=training_args.work_dir,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    instance_prompt=args.instance_prompt,
    class_prompt=args.class_prompt,
    class_data_dir=args.class_data_dir,
    num_class_images=args.num_class_images,
    resolution=args.resolution,
    prior_loss_weight=args.prior_loss_weight,
    prompt=args.prompt,
    cfg_modify_fn=cfg_modify_fn)

# build trainer and training
trainer = build_trainer(
    name=Trainers.dreambooth_diffusion, default_args=kwargs)
trainer.train()

# pipeline after training and save result
pipe = pipeline(
    task=Tasks.text_to_image_synthesis,
    model=training_args.work_dir + '/output',
    model_revision=args.model_revision)

output = pipe({'text': args.prompt})
cv2.imwrite('./dreambooth_result.png', output['output_imgs'][0])
@@ -0,0 +1,20 @@
PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py \
    --model 'AI-ModelScope/stable-diffusion-v1-5' \
    --model_revision 'v1.0.8' \
    --work_dir './tmp/dreambooth_diffusion' \
    --train_dataset_name 'buptwq/lora-stable-diffusion-finetune' \
    --with_prior_preservation false \
    --instance_prompt "a photo of sks dog" \
    --class_prompt "a photo of dog" \
    --class_data_dir "./tmp/class_data" \
    --num_class_images 200 \
    --resolution 512 \
    --prior_loss_weight 1.0 \
    --prompt "dog" \
    --max_epochs 150 \
    --save_ckpt_strategy 'by_epoch' \
    --logging_interval 1 \
    --train.dataloader.workers_per_gpu 0 \
    --evaluation.dataloader.workers_per_gpu 0 \
    --train.optimizer.lr 5e-6 \
    --use_model_config true
@@ -1,12 +1,27 @@
from dataclasses import dataclass, field

import cv2

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.training_args import TrainingArgs
-from modelscope.utils.constant import DownloadMode
+from modelscope.utils.constant import DownloadMode, Tasks

-training_args = TrainingArgs(task='text-to-image-synthesis').parse_cli()

+# Load configuration file and dataset
+@dataclass(init=False)
+class StableDiffusionLoraArguments(TrainingArgs):
+    prompt: str = field(
+        default='dog', metadata={
+            'help': 'The pipeline prompt.',
+        })
+
+
+training_args = StableDiffusionLoraArguments(
+    task='text-to-image-synthesis').parse_cli()
config, args = training_args.to_config()
print(args)

train_dataset = MsDataset.load(
    args.train_dataset_name,
@@ -28,17 +43,27 @@ def cfg_modify_fn(cfg):
        'lr_lambda': lambda _: 1,
        'last_epoch': -1
    }
    cfg.train.optimizer.lr = 1e-4
    return cfg


kwargs = dict(
    model=training_args.model,
-    model_revision='v1.0.6',
+    model_revision=args.model_revision,
    work_dir=training_args.work_dir,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    cfg_modify_fn=cfg_modify_fn)

# build trainer and training
trainer = build_trainer(name=Trainers.lora_diffusion, default_args=kwargs)
trainer.train()

# pipeline after training and save result
pipe = pipeline(
    task=Tasks.text_to_image_synthesis,
    model=training_args.model,
    lora_dir=training_args.work_dir + '/output',
    model_revision=args.model_revision)

output = pipe({'text': args.prompt})
cv2.imwrite('./lora_result.png', output['output_imgs'][0])
@@ -1,11 +1,12 @@
-PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/finetune_stable_diffusion.py \
+PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py \
    --model 'AI-ModelScope/stable-diffusion-v1-5' \
-    --model_revision 'v1.0.6' \
+    --model_revision 'v1.0.9' \
    --prompt "a dog" \
    --work_dir './tmp/lora_diffusion' \
    --train_dataset_name 'buptwq/lora-stable-diffusion-finetune' \
    --max_epochs 100 \
    --save_ckpt_strategy 'by_epoch' \
-    --logging_interval 100 \
+    --logging_interval 1 \
    --train.dataloader.workers_per_gpu 0 \
    --evaluation.dataloader.workers_per_gpu 0 \
    --train.optimizer.lr 1e-4 \
@@ -13,11 +13,13 @@ if TYPE_CHECKING:
    from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter
    from .torch_model_exporter import TorchModelExporter
    from .cv import FaceDetectionSCRFDExporter
+    from .multi_modal import StableDiffuisonExporter
else:
    _import_structure = {
        'base': ['Exporter'],
        'builder': ['build_exporter'],
        'cv': ['CartoonTranslationExporter', 'FaceDetectionSCRFDExporter'],
+        'multi_modal': ['StableDiffuisonExporter'],
        'nlp': [
            'CsanmtForTranslationExporter',
            'SbertForSequenceClassificationExporter',
modelscope/exporters/multi_modal/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .stable_diffusion_export import StableDiffuisonExporter
else:
    _import_structure = {
        'stable_diffusion_export': ['StableDiffuisonExporter'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
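With the lazy-import table above in place, the exporter class is only materialized when it is first touched, not when the package is imported. A minimal sketch of what that looks like from the caller's side (this is an illustration, assuming the submodule key in _import_structure resolves to the new exporter file):

import modelscope.exporters.multi_modal as multi_modal

# LazyImportModule defers the real submodule import until this attribute access.
print(multi_modal.StableDiffuisonExporter)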
modelscope/exporters/multi_modal/stable_diffusion_exporter.py (new file, 303 lines)
@@ -0,0 +1,303 @@
import argparse
import os
import shutil
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Mapping, Tuple

import onnx
import torch
from diffusers import (OnnxRuntimeModel, OnnxStableDiffusionPipeline,
                       StableDiffusionPipeline)
from packaging import version
from torch.onnx import export
from torch.utils.data.dataloader import default_collate

from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import ModeKeys, Tasks
from modelscope.utils.hub import snapshot_download


@EXPORTERS.register_module(
    Tasks.text_to_image_synthesis, module_name=Models.stable_diffusion)
class StableDiffuisonExporter(TorchModelExporter):

    @torch.no_grad()
    def export_onnx(self,
                    output_path: str,
                    opset: int = 14,
                    fp16: bool = False):
        """Export the model as onnx format files.

        Args:
            model_path: The model id or local path.
            output_dir: The output dir.
            opset: The version of the ONNX operator set to use.
            fp16: Whether to use float16.
        """
        output_path = Path(output_path)

        # Conversion weight accuracy and device.
        dtype = torch.float16 if fp16 else torch.float32
        if fp16 and torch.cuda.is_available():
            device = 'cuda'
        elif fp16 and not torch.cuda.is_available():
            raise ValueError(
                '`float16` model export is only supported on GPUs with CUDA')
        else:
            device = 'cpu'
        self.model = self.model.to(device)

        # Text encoder
        num_tokens = self.model.text_encoder.config.max_position_embeddings
        text_hidden_size = self.model.text_encoder.config.hidden_size
        text_input = self.model.tokenizer(
            'A sample prompt',
            padding='max_length',
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors='pt',
        )
        self.export_help(
            self.model.text_encoder,
            model_args=(text_input.input_ids.to(
                device=device, dtype=torch.int32)),
            output_path=output_path / 'text_encoder' / 'model.onnx',
            ordered_input_names=['input_ids'],
            output_names=['last_hidden_state', 'pooler_output'],
            dynamic_axes={
                'input_ids': {
                    0: 'batch',
                    1: 'sequence'
                },
            },
            opset=opset,
        )
        del self.model.text_encoder

        # UNET
        unet_in_channels = self.model.unet.config.in_channels
        unet_sample_size = self.model.unet.config.sample_size
        unet_path = output_path / 'unet' / 'model.onnx'
        self.export_help(
            self.model.unet,
            model_args=(
                torch.randn(2, unet_in_channels, unet_sample_size,
                            unet_sample_size).to(device=device, dtype=dtype),
                torch.randn(2).to(device=device, dtype=dtype),
                torch.randn(2, num_tokens,
                            text_hidden_size).to(device=device, dtype=dtype),
                False,
            ),
            output_path=unet_path,
            ordered_input_names=[
                'sample', 'timestep', 'encoder_hidden_states', 'return_dict'
            ],
            output_names=[
                'out_sample'
            ],  # has to be different from "sample" for correct tracing
            dynamic_axes={
                'sample': {
                    0: 'batch',
                    1: 'channels',
                    2: 'height',
                    3: 'width'
                },
                'timestep': {
                    0: 'batch'
                },
                'encoder_hidden_states': {
                    0: 'batch',
                    1: 'sequence'
                },
            },
            opset=opset,
            use_external_data_format=
            True,  # UNet is > 2GB, so the weights need to be split
        )
        unet_model_path = str(unet_path.absolute().as_posix())
        unet_dir = os.path.dirname(unet_model_path)
        unet = onnx.load(unet_model_path)
        # clean up existing tensor files
        shutil.rmtree(unet_dir)
        os.mkdir(unet_dir)
        # collate external tensor files into one
        onnx.save_model(
            unet,
            unet_model_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location='weights.pb',
            convert_attribute=False,
        )
        del self.model.unet

        # VAE ENCODER
        vae_encoder = self.model.vae
        vae_in_channels = vae_encoder.config.in_channels
        vae_sample_size = vae_encoder.config.sample_size
        # need to get the raw tensor output (sample) from the encoder
        vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
            sample, return_dict)[0].sample()
        self.export_help(
            vae_encoder,
            model_args=(
                torch.randn(1, vae_in_channels, vae_sample_size,
                            vae_sample_size).to(device=device, dtype=dtype),
                False,
            ),
            output_path=output_path / 'vae_encoder' / 'model.onnx',
            ordered_input_names=['sample', 'return_dict'],
            output_names=['latent_sample'],
            dynamic_axes={
                'sample': {
                    0: 'batch',
                    1: 'channels',
                    2: 'height',
                    3: 'width'
                },
            },
            opset=opset,
        )

        # VAE DECODER
        vae_decoder = self.model.vae
        vae_latent_channels = vae_decoder.config.latent_channels
        vae_out_channels = vae_decoder.config.out_channels
        # forward only through the decoder part
        vae_decoder.forward = vae_encoder.decode
        self.export_help(
            vae_decoder,
            model_args=(
                torch.randn(1, vae_latent_channels, unet_sample_size,
                            unet_sample_size).to(device=device, dtype=dtype),
                False,
            ),
            output_path=output_path / 'vae_decoder' / 'model.onnx',
            ordered_input_names=['latent_sample', 'return_dict'],
            output_names=['sample'],
            dynamic_axes={
                'latent_sample': {
                    0: 'batch',
                    1: 'channels',
                    2: 'height',
                    3: 'width'
                },
            },
            opset=opset,
        )
        del self.model.vae

        # SAFETY CHECKER
        if self.model.safety_checker is not None:
            safety_checker = self.model.safety_checker
            clip_num_channels = safety_checker.config.vision_config.num_channels
            clip_image_size = safety_checker.config.vision_config.image_size
            safety_checker.forward = safety_checker.forward_onnx
            self.export_help(
                self.model.safety_checker,
                model_args=(
                    torch.randn(
                        1,
                        clip_num_channels,
                        clip_image_size,
                        clip_image_size,
                    ).to(device=device, dtype=dtype),
                    torch.randn(1, vae_sample_size, vae_sample_size,
                                vae_out_channels).to(
                                    device=device, dtype=dtype),
                ),
                output_path=output_path / 'safety_checker' / 'model.onnx',
                ordered_input_names=['clip_input', 'images'],
                output_names=['out_images', 'has_nsfw_concepts'],
                dynamic_axes={
                    'clip_input': {
                        0: 'batch',
                        1: 'channels',
                        2: 'height',
                        3: 'width'
                    },
                    'images': {
                        0: 'batch',
                        1: 'height',
                        2: 'width',
                        3: 'channels'
                    },
                },
                opset=opset,
            )
            del self.model.safety_checker
            safety_checker = OnnxRuntimeModel.from_pretrained(
                output_path / 'safety_checker')
            feature_extractor = self.model.feature_extractor
        else:
            safety_checker = None
            feature_extractor = None

        onnx_pipeline = OnnxStableDiffusionPipeline(
            vae_encoder=OnnxRuntimeModel.from_pretrained(output_path
                                                         / 'vae_encoder'),
            vae_decoder=OnnxRuntimeModel.from_pretrained(output_path
                                                         / 'vae_decoder'),
            text_encoder=OnnxRuntimeModel.from_pretrained(output_path
                                                          / 'text_encoder'),
            tokenizer=self.model.tokenizer,
            unet=OnnxRuntimeModel.from_pretrained(output_path / 'unet'),
            scheduler=self.model.noise_scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=safety_checker is not None,
        )

        onnx_pipeline.save_pretrained(output_path)
        print('ONNX pipeline model saved to', output_path)

        del self.model
        del onnx_pipeline
        _ = OnnxStableDiffusionPipeline.from_pretrained(
            output_path, provider='CPUExecutionProvider')
        print('ONNX pipeline model is loadable')

    def export_help(
        self,
        model,
        model_args: tuple,
        output_path: Path,
        ordered_input_names,
        output_names,
        dynamic_axes,
        opset,
        use_external_data_format=False,
    ):
        output_path.parent.mkdir(parents=True, exist_ok=True)

        is_torch_less_than_1_11 = version.parse(
            version.parse(
                torch.__version__).base_version) < version.parse('1.11')
        if is_torch_less_than_1_11:
            export(
                model,
                model_args,
                f=output_path.as_posix(),
                input_names=ordered_input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                enable_onnx_checker=True,
                opset_version=opset,
            )
        else:
            export(
                model,
                model_args,
                f=output_path.as_posix(),
                input_names=ordered_input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                do_constant_folding=True,
                opset_version=opset,
            )
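For reference, a minimal way to drive this exporter end to end, mirroring tests/export/test_export_stable_diffusion.py later in this diff (the output directory name here is only illustrative):

from modelscope.exporters import Exporter
from modelscope.models import Model

# Load the Stable Diffusion model, then export text encoder, UNet, VAE and safety checker to ONNX.
model = Model.from_pretrained('AI-ModelScope/stable-diffusion-v1-5')
Exporter.from_model(model).export_onnx(output_path='./stable_diffusion_onnx', opset=14)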
@@ -900,6 +900,7 @@ class MultiModalTrainers(object):
    efficient_diffusion_tuning = 'efficient-diffusion-tuning'
    stable_diffusion = 'stable-diffusion'
    lora_diffusion = 'lora-diffusion'
+    dreambooth_diffusion = 'dreambooth-diffusion'


class AudioTrainers(object):
@@ -30,13 +30,12 @@ class StableDiffusion(TorchModel):
        """ Initialize a vision efficient diffusion tuning model.

        Args:
-            model_dir: model id or path, where model_dir/pytorch_model.bin
+            model_dir: model id or path
        """
        super().__init__(model_dir, *args, **kwargs)
-        pretrained_model_name_or_path = kwargs.pop(
-            'pretrained_model_name_or_path', 'runwayml/stable-diffusion-v1-5')
        revision = kwargs.pop('revision', None)
-        self.lora_tune = kwargs.pop('lora_tune', True)
+        self.lora_tune = kwargs.pop('lora_tune', False)
+        self.dreambooth_tune = kwargs.pop('dreambooth_tune', False)

        self.weight_dtype = torch.float32
        self.device = torch.device(
@@ -44,19 +43,16 @@ class StableDiffusion(TorchModel):

        # Load scheduler, tokenizer and models
        self.noise_scheduler = DDPMScheduler.from_pretrained(
-            pretrained_model_name_or_path, subfolder='scheduler')
+            model_dir, subfolder='scheduler')
        self.tokenizer = CLIPTokenizer.from_pretrained(
-            pretrained_model_name_or_path,
-            subfolder='tokenizer',
-            revision=revision)
+            model_dir, subfolder='tokenizer', revision=revision)
        self.text_encoder = CLIPTextModel.from_pretrained(
-            pretrained_model_name_or_path,
-            subfolder='text_encoder',
-            revision=revision)
+            model_dir, subfolder='text_encoder', revision=revision)
        self.vae = AutoencoderKL.from_pretrained(
-            pretrained_model_name_or_path, subfolder='vae', revision=revision)
+            model_dir, subfolder='vae', revision=revision)
        self.unet = UNet2DConditionModel.from_pretrained(
-            pretrained_model_name_or_path, subfolder='unet', revision=revision)
+            model_dir, subfolder='unet', revision=revision)
+        self.safety_checker = None

        # Freeze gradient calculation and move to device
        if self.vae is not None:
@@ -89,6 +85,7 @@ class StableDiffusion(TorchModel):
        self.unet.train()
        self.unet = self.unet.to(self.device)

        # Convert to latent space
        with torch.no_grad():
            latents = self.vae.encode(
                target.to(dtype=self.weight_dtype)).latent_dist.sample()
@@ -130,6 +127,9 @@ class StableDiffusion(TorchModel):
        model_pred = self.unet(noisy_latents, timesteps,
                               encoder_hidden_states).sample

+        if model_pred.shape[1] == 6:
+            model_pred, _ = torch.chunk(model_pred, 2, dim=1)
+
        loss = F.mse_loss(model_pred.float(), target.float(), reduction='mean')

        output = {OutputKeys.LOSS: loss}
@@ -143,8 +143,9 @@ class StableDiffusion(TorchModel):
                        config: Optional[dict] = None,
                        save_config_function: Callable = save_configuration,
                        **kwargs):
-        # Save only the lora model, skip saving and copying the original weights
-        if self.lora_tune:
+        config['pipeline']['type'] = 'diffusers-stable-diffusion'
+        # Skip copying the original weights for lora and dreambooth method
+        if self.lora_tune or self.dreambooth_tune:
            pass
        else:
            super().save_pretrained(target_folder, save_checkpoint_names,
@@ -11,6 +11,7 @@ import torch
from PIL import Image
from timm.data import create_transform
+from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from modelscope.hub.snapshot_download import snapshot_download
@@ -55,7 +56,6 @@ class DiffusionImageGenerationPreprocessor(Preprocessor):
            transforms.Resize(
                self.preprocessor_resolution,
                interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(self.preprocessor_mean,
                                 self.preprocessor_std),
@@ -0,0 +1,2 @@
# Copyright © Alibaba, Inc. and its affiliates.
from .dreambooth_diffusion_trainer import DreamboothDiffusionTrainer
@@ -0,0 +1,384 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import hashlib
import itertools
import shutil
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import Union

import torch
import torch.nn.functional as F
from diffusers import DiffusionPipeline
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm

from modelscope.metainfo import Trainers
from modelscope.outputs import ModelOutputBase, OutputKeys
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.hooks.checkpoint.checkpoint_hook import CheckpointHook
from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \
    CheckpointProcessor
from modelscope.trainers.optimizer.builder import build_optimizer
from modelscope.trainers.trainer import EpochBasedTrainer
from modelscope.utils.config import ConfigDict
from modelscope.utils.constant import ModeKeys
from modelscope.utils.file_utils import func_receive_dict_inputs
from modelscope.utils.torch_utils import is_dist


class DreamboothCheckpointProcessor(CheckpointProcessor):

    def __init__(self, model_dir):
        self.model_dir = model_dir

    def save_checkpoints(self,
                         trainer,
                         checkpoint_path_prefix,
                         output_dir,
                         meta=None):
        """Save the state dict for dreambooth model.
        """
        pipeline_args = {}
        if trainer.model.text_encoder is not None:
            pipeline_args['text_encoder'] = trainer.model.text_encoder
        pipeline = DiffusionPipeline.from_pretrained(
            self.model_dir,
            unet=trainer.model.unet,
            **pipeline_args,
        )
        scheduler_args = {}
        pipeline.scheduler = pipeline.scheduler.from_config(
            pipeline.scheduler.config, **scheduler_args)
        pipeline.save_pretrained(output_dir)


class ClassDataset(Dataset):

    def __init__(
        self,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num_images=None,
        size=512,
        center_crop=False,
    ):
        """A dataset to prepare class images with the prompts for fine-tuning the model.
        It pre-processes the images and the tokenizes prompts.

        Args:
            tokenizer: The tokenizer to use for tokenization.
            class_data_root: The saved class data path.
            class_prompt: The prompt to use for class images.
            class_num_images: The number of class images to use.
            size: The size to resize the images.
            center_crop: Whether to do center crop or random crop.

        """
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num_images is not None:
                self.num_class_images = min(
                    len(self.class_images_path), class_num_images)
            else:
                self.num_class_images = len(self.class_images_path)
            self.class_prompt = class_prompt
        else:
            raise ValueError(
                f"Class {self.class_data_root} class data root doesn't exists."
            )

        self.image_transforms = transforms.Compose([
            transforms.Resize(
                size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size)
            if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return self.num_class_images

    def __getitem__(self, index):
        example = {}

        if self.class_data_root:
            class_image = Image.open(
                self.class_images_path[index % self.num_class_images])
            class_image = exif_transpose(class_image)

            if not class_image.mode == 'RGB':
                class_image = class_image.convert('RGB')
            example['pixel_values'] = self.image_transforms(class_image)

        class_text_inputs = self.tokenizer(
            self.class_prompt,
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt')
        input_ids = torch.squeeze(class_text_inputs.input_ids)
        example['input_ids'] = input_ids

        return example


class PromptDataset(Dataset):

    def __init__(self, prompt, num_samples):
        """Dataset to prepare the prompts to generate class images.

        Args:
            prompt: Class prompt.
            num_samples: The number sample for class images.

        """
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example['prompt'] = self.prompt
        example['index'] = index
        return example


@TRAINERS.register_module(module_name=Trainers.dreambooth_diffusion)
class DreamboothDiffusionTrainer(EpochBasedTrainer):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        """Dreambooth trainers for fine-tuning stable diffusion

        Args:
            with_prior_preservation: a boolean indicating whether to enable prior loss.
            instance_prompt: a string specifying the instance prompt.
            class_prompt: a string specifying the class prompt.
            class_data_dir: the path to the class data directory.
            num_class_images: the number of class images to generate.
            prior_loss_weight: the weight of the prior loss.

        """
        self.with_prior_preservation = kwargs.pop('with_prior_preservation',
                                                  False)
        self.instance_prompt = kwargs.pop('instance_prompt',
                                          'a photo of sks dog')
        self.class_prompt = kwargs.pop('class_prompt', 'a photo of dog')
        self.class_data_dir = kwargs.pop('class_data_dir', '/tmp/class_data')
        self.num_class_images = kwargs.pop('num_class_images', 200)
        self.resolution = kwargs.pop('resolution', 512)
        self.prior_loss_weight = kwargs.pop('prior_loss_weight', 1.0)

        # Save checkpoint and configurate files.
        ckpt_hook = list(
            filter(lambda hook: isinstance(hook, CheckpointHook),
                   self.hooks))[0]
        ckpt_hook.set_processor(DreamboothCheckpointProcessor(self.model_dir))

        # Check for conflicts and conflicts
        if self.with_prior_preservation:
            if self.class_data_dir is None:
                raise ValueError(
                    'You must specify a data directory for class images.')
            if self.class_prompt is None:
                raise ValueError('You must specify prompt for class images.')
        else:
            if self.class_data_dir is not None:
                warnings.warn(
                    'You need not use --class_data_dir without --with_prior_preservation.'
                )
            if self.class_prompt is not None:
                warnings.warn(
                    'You need not use --class_prompt without --with_prior_preservation.'
                )

        # Generate class images if prior preservation is enabled.
        if self.with_prior_preservation:
            class_images_dir = Path(self.class_data_dir)
            if not class_images_dir.exists():
                class_images_dir.mkdir(parents=True)
            cur_class_images = len(list(class_images_dir.iterdir()))

            if cur_class_images < self.num_class_images:
                if torch.cuda.device_count() > 1:
                    warnings.warn('Multiple GPU inference not yet supported.')
                pipeline = DiffusionPipeline.from_pretrained(
                    self.model_dir,
                    torch_dtype=torch.float32,
                    safety_checker=None,
                    revision=None,
                )
                pipeline.set_progress_bar_config(disable=True)

                num_new_images = self.num_class_images - cur_class_images
                sample_dataset = PromptDataset(self.instance_prompt,
                                               num_new_images)
                sample_dataloader = torch.utils.data.DataLoader(sample_dataset)

                pipeline.to(self.device)

                for example in tqdm(
                        sample_dataloader, desc='Generating class images'):
                    images = pipeline(example['prompt']).images
                    for i, image in enumerate(images):
                        hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                        image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                        image.save(image_filename)

                del pipeline
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            # Class Dataset and DataLoaders creation
            class_dataset = ClassDataset(
                class_data_root=self.class_data_dir
                if self.with_prior_preservation else None,
                class_prompt=self.class_prompt,
                class_num_images=self.num_class_images,
                tokenizer=self.model.tokenizer,
                size=self.resolution,
                center_crop=False,
            )
            class_dataloader = torch.utils.data.DataLoader(
                class_dataset,
                batch_size=1,
                shuffle=True,
            )
            self.iter_class_dataloader = itertools.cycle(class_dataloader)

    def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
        try:
            return build_optimizer(
                self.model.unet.parameters(),
                cfg=cfg,
                default_args=default_args)
        except KeyError as e:
            self.logger.error(
                f'Build optimizer error, the optimizer {cfg} is a torch native component, '
                f'please check if your torch with version: {torch.__version__} matches the config.'
            )
            raise e

    def train_step(self, model, inputs):
        """ Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`TorchModel`): The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        self._mode = ModeKeys.TRAIN
        # call model forward but not __call__ to skip postprocess

        receive_dict_inputs = func_receive_dict_inputs(
            self.unwrap_module(self.model).forward)

        if isinstance(inputs, Mapping) and not receive_dict_inputs:
            train_outputs = model.forward(**inputs)
        else:
            train_outputs = model.forward(inputs)

        if self.with_prior_preservation:
            # Convert to latent space
            batch = next(self.iter_class_dataloader)
            target_prior = batch['pixel_values'].to(self.device)
            input_ids = batch['input_ids'].to(self.device)
            with torch.no_grad():
                latents = self.model.vae.encode(
                    target_prior.to(dtype=torch.float32)).latent_dist.sample()
                latents = latents * self.model.vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0,
                self.model.noise_scheduler.num_train_timesteps, (bsz, ),
                device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = self.model.noise_scheduler.add_noise(
                latents, noise, timesteps)

            # Get the text embedding for conditioning
            with torch.no_grad():
                encoder_hidden_states = self.model.text_encoder(input_ids)[0]

            # Get the target for loss depending on the prediction type
            if self.model.noise_scheduler.config.prediction_type == 'epsilon':
                target_prior = noise
            elif self.model.noise_scheduler.config.prediction_type == 'v_prediction':
                target_prior = self.model.noise_scheduler.get_velocity(
                    latents, noise, timesteps)
            else:
                raise ValueError(
                    f'Unknown prediction type {self.model.noise_scheduler.config.prediction_type}'
                )

            # Predict the noise residual and compute loss
            model_pred_prior = self.model.unet(noisy_latents, timesteps,
                                               encoder_hidden_states).sample

            # Compute prior loss
            prior_loss = F.mse_loss(
                model_pred_prior.float(),
                target_prior.float(),
                reduction='mean')
            # Add the prior loss to the instance loss.
            train_outputs[
                OutputKeys.LOSS] += self.prior_loss_weight * prior_loss

        if isinstance(train_outputs, ModelOutputBase):
            train_outputs = train_outputs.to_dict()
        if not isinstance(train_outputs, dict):
            raise TypeError('"model.forward()" must return a dict')

        # add model output info to log
        if 'log_vars' not in train_outputs:
            default_keys_pattern = ['loss']
            match_keys = set([])
            for key_p in default_keys_pattern:
                match_keys.update(
                    [key for key in train_outputs.keys() if key_p in key])

            log_vars = {}
            for key in match_keys:
                value = train_outputs.get(key, None)
                if value is not None:
                    if is_dist():
                        value = value.data.clone().to('cuda')
                        dist.all_reduce(value.div_(dist.get_world_size()))
                    log_vars.update({key: value.item()})
            self.log_buffer.update(log_vars)
        else:
            self.log_buffer.update(train_outputs['log_vars'])

        self.train_outputs = train_outputs
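To make the objective in train_step above concrete: when prior preservation is enabled, the weighted prior loss is simply added to the instance loss returned by the model. A minimal, self-contained sketch of that combination (random tensors stand in for the UNet noise predictions; prior_loss_weight matches the trainer default):

import torch
import torch.nn.functional as F

prior_loss_weight = 1.0
# Stand-ins for predictions/targets on instance images and on generated class images.
model_pred = torch.randn(2, 4, 64, 64)
target = torch.randn(2, 4, 64, 64)
model_pred_prior = torch.randn(2, 4, 64, 64)
target_prior = torch.randn(2, 4, 64, 64)

instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction='mean')
prior_loss = F.mse_loss(
    model_pred_prior.float(), target_prior.float(), reduction='mean')
loss = instance_loss + prior_loss_weight * prior_loss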
@@ -1 +1,2 @@
+# Copyright © Alibaba, Inc. and its affiliates.
from .lora_diffusion_trainer import LoraDiffusionTrainer
tests/export/test_export_stable_diffusion.py (new file, 31 lines)
@@ -0,0 +1,31 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
from collections import OrderedDict

from modelscope.exporters import Exporter
from modelscope.models import Model
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class TestExportStableDiffusion(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)
        self.model_id = 'AI-ModelScope/stable-diffusion-v1-5'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_export_stable_diffusion(self):
        model = Model.from_pretrained(self.model_id)
        Exporter.from_model(model).export_onnx(
            output_path=self.tmp_dir, opset=14)


if __name__ == '__main__':
    unittest.main()
@@ -21,7 +21,7 @@ class DiffusersStableDiffusionTest(unittest.TestCase):
    def test_run(self):
        diffusers_pipeline = pipeline(task=self.task, model=self.model_id)
        output = diffusers_pipeline({
-            'prompt': self.test_input,
+            'text': self.test_input,
            'height': 512,
            'width': 512
        })
tests/trainers/test_dreambooth_diffusion_trainer.py (new file, 98 lines)
@@ -0,0 +1,98 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
import shutil
import tempfile
import unittest

import cv2

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, Tasks
from modelscope.utils.test_utils import test_level


class TestDreamboothDiffusionTrainer(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

        self.train_dataset = MsDataset.load(
            'buptwq/lora-stable-diffusion-finetune',
            split='train',
            download_mode=DownloadMode.FORCE_REDOWNLOAD)
        self.eval_dataset = MsDataset.load(
            'buptwq/lora-stable-diffusion-finetune',
            split='validation',
            download_mode=DownloadMode.FORCE_REDOWNLOAD)

        self.max_epochs = 5

        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_dreambooth_diffusion_train(self):
        model_id = 'AI-ModelScope/stable-diffusion-v1-5'
        model_revision = 'v1.0.8'
        prompt = 'a dog.'

        def cfg_modify_fn(cfg):
            cfg.train.max_epochs = self.max_epochs
            cfg.train.lr_scheduler = {
                'type': 'LambdaLR',
                'lr_lambda': lambda _: 1,
                'last_epoch': -1
            }
            cfg.train.optimizer.lr = 5e-6
            return cfg

        kwargs = dict(
            model=model_id,
            model_revision=model_revision,
            work_dir=self.tmp_dir,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            cfg_modify_fn=cfg_modify_fn)

        trainer = build_trainer(
            name=Trainers.dreambooth_diffusion, default_args=kwargs)
        trainer.train()
        result = trainer.evaluate()
        print(f'Dreambooth-diffusion train output: {result}.')

        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)

        pipe = pipeline(
            task=Tasks.text_to_image_synthesis, model=f'{self.tmp_dir}/output')
        output = pipe({'text': prompt})
        cv2.imwrite('./dreambooth_result.png', output['output_imgs'][0])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_dreambooth_diffusion_eval(self):
        model_id = 'AI-ModelScope/stable-diffusion-v1-5'
        model_revision = 'v1.0.8'

        kwargs = dict(
            model=model_id,
            model_revision=model_revision,
            work_dir=self.tmp_dir,
            train_dataset=None,
            eval_dataset=self.eval_dataset)

        trainer = build_trainer(
            name=Trainers.dreambooth_diffusion, default_args=kwargs)
        result = trainer.evaluate()
        print(f'Dreambooth-diffusion eval output: {result}.')


if __name__ == '__main__':
    unittest.main()
@@ -38,7 +38,7 @@ class TestLoraDiffusionTrainer(unittest.TestCase):
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_lora_diffusion_train(self):
        model_id = 'AI-ModelScope/stable-diffusion-v1-5'
-        model_revision = 'v1.0.6'
+        model_revision = 'v1.0.9'

        def cfg_modify_fn(cfg):
            cfg.train.max_epochs = self.max_epochs
@@ -70,7 +70,7 @@ class TestLoraDiffusionTrainer(unittest.TestCase):
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_lora_diffusion_eval(self):
        model_id = 'AI-ModelScope/stable-diffusion-v1-5'
-        model_revision = 'v1.0.6'
+        model_revision = 'v1.0.9'

        kwargs = dict(
            model=model_id,