custom diffusion

This commit is contained in:
XDUWQ
2023-07-11 20:46:32 +08:00
parent 2a79a6cee7
commit 1caa45422c
6 changed files with 1080 additions and 0 deletions

View File

@@ -0,0 +1,162 @@
import os
from dataclasses import dataclass, field
import cv2
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.training_args import TrainingArgs
from modelscope.utils.constant import DownloadMode, Tasks
# Load configuration file and dataset
@dataclass(init=False)
class StableDiffusionCustomArguments(TrainingArgs):
class_prompt: str = field(
default=None,
metadata={
'help':
'The prompt to specify images in the same class as provided instance images.',
})
instance_prompt: str = field(
default=None,
metadata={
'help': 'The prompt with identifier specifying the instance.',
})
modifier_token: str = field(
default=None,
metadata={
'help': 'A token to use as a modifier for the concept.',
})
num_class_images: int = field(
default=200,
metadata={
'help':
'Minimal class images for prior preservation loss. If there are not enough images already present in class_data_dir, additional images will be sampled with class_prompt.',
})
train_batch_size: int = field(
default=4,
metadata={
'help': 'Batch size (per device) for the training dataloader.',
})
sample_batch_size: int = field(
default=4,
metadata={
'help': 'Batch size (per device) for sampling images.',
})
initializer_token: str = field(
default='ktn+pll+ucd',
metadata={
'help': 'A token to use as initializer word.',
})
class_data_dir: str = field(
default='/tmp/class_data',
metadata={
'help': 'A folder containing the training data of class images.',
})
resolution: int = field(
default=512,
metadata={
'help':
'The resolution for input images, all the images in the train/validation dataset will be resized to this',
})
prior_loss_weight: float = field(
default=1.0,
metadata={
'help': 'The weight of prior preservation loss.',
})
freeze_model: str = field(
default='crossattn_kv',
metadata={
'help':
'crossattn to enable fine-tuning of all params in the cross attention.',
})
instance_data_name: str = field(
default='buptwq/custom-stable-diffusion-cat',
metadata={
'help': 'The instance data local dir or online ID.',
})
training_args = StableDiffusionCustomArguments(
task='text-to-image-synthesis').parse_cli()
config, args = training_args.to_config()
if os.path.exists(args.train_dataset_name):
# Load local dataset
train_dataset = MsDataset.load(args.train_dataset_name)
validation_dataset = MsDataset.load(args.train_dataset_name)
else:
# Load online dataset
train_dataset = MsDataset.load(
args.train_dataset_name,
split='train',
download_mode=DownloadMode.FORCE_REDOWNLOAD)
validation_dataset = MsDataset.load(
args.train_dataset_name,
split='validation',
download_mode=DownloadMode.FORCE_REDOWNLOAD)
def cfg_modify_fn(cfg):
if args.use_model_config:
cfg.merge_from_dict(config)
else:
cfg = config
cfg.train.lr_scheduler = {
'type': 'LambdaLR',
'lr_lambda': lambda _: 1,
'last_epoch': -1
}
return cfg
kwargs = dict(
model=training_args.model,
model_revision=args.model_revision,
class_prompt=args.class_prompt,
instance_prompt=args.instance_prompt,
modifier_token=args.modifier_token,
num_class_images=args.num_class_images,
train_batch_size=args.train_batch_size,
sample_batch_size=args.sample_batch_size,
initializer_token=args.initializer_token,
class_data_dir=args.class_data_dir,
resolution=args.resolution,
prior_loss_weight=args.prior_loss_weight,
freeze_model=args.freeze_model,
instance_data_name=args.instance_data_name,
work_dir=training_args.work_dir,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
cfg_modify_fn=cfg_modify_fn)
# build trainer and training
trainer = build_trainer(name=Trainers.custom_diffusion, default_args=kwargs)
trainer.train()
# pipeline after training and save result
pipe = pipeline(
task=Tasks.text_to_image_synthesis,
model=training_args.model,
custom_dir=training_args.work_dir + '/output',
modifier_token='<new1>',
model_revision=args.model_revision)
output = pipe({'text': args.prompt})
# visualize the result on ipynb and save it
output
cv2.imwrite('./custom_result.png', output['output_imgs'][0])

View File

@@ -0,0 +1,17 @@
PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/custom/finetune_stable_diffusion_custom.py \
--model 'AI-ModelScope/stable-diffusion-v1-5' \
--model_revision 'v1.0.9' \
--class_prompt "dog" \
--instance_prompt="photo of a <new1> dog" \
--work_dir './tmp/custom_diffusion' \
--class_data_dir './tmp/class_data' \
--train_dataset_name 'buptwq/lora-stable-diffusion-finetune-dog' \
--max_epochs 2 \
--modifier_token "<new1>" \
--num_class_images=200 \
--save_ckpt_strategy 'by_epoch' \
--logging_interval 1 \
--train.dataloader.workers_per_gpu 0 \
--evaluation.dataloader.workers_per_gpu 0 \
--train.optimizer.lr 1e-5 \
--use_model_config true