mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 13:15:06 +02:00
custom diffusion
This commit is contained in:
@@ -0,0 +1,162 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import cv2
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.trainers import EpochBasedTrainer, build_trainer
|
||||
from modelscope.trainers.training_args import TrainingArgs
|
||||
from modelscope.utils.constant import DownloadMode, Tasks
|
||||
|
||||
|
||||
# Load configuration file and dataset
|
||||
@dataclass(init=False)
class StableDiffusionCustomArguments(TrainingArgs):
    """Extra options for the Custom Diffusion fine-tuning example.

    The fields below are declared on top of those inherited from
    ``TrainingArgs``. Each field stores its description under the ``help``
    key of its metadata.
    """

    # Prompt describing the class that the instance images belong to.
    class_prompt: str = field(
        default=None,
        metadata=dict(
            help=('The prompt to specify images in the same class as provided '
                  'instance images.')))

    # Prompt containing the identifier of the specific instance.
    instance_prompt: str = field(
        default=None,
        metadata=dict(
            help='The prompt with identifier specifying the instance.'))

    # Token used as a modifier for the concept.
    modifier_token: str = field(
        default=None,
        metadata=dict(help='A token to use as a modifier for the concept.'))

    # Minimum number of class images for the prior preservation loss.
    num_class_images: int = field(
        default=200,
        metadata=dict(
            help=('Minimal class images for prior preservation loss. '
                  'If there are not enough images already present in '
                  'class_data_dir, additional images will be sampled with '
                  'class_prompt.')))

    # Per-device batch size of the training dataloader.
    train_batch_size: int = field(
        default=4,
        metadata=dict(
            help='Batch size (per device) for the training dataloader.'))

    # Per-device batch size used when sampling images.
    sample_batch_size: int = field(
        default=4,
        metadata=dict(help='Batch size (per device) for sampling images.'))

    # Initializer word(s) for the modifier token.
    initializer_token: str = field(
        default='ktn+pll+ucd',
        metadata=dict(help='A token to use as initializer word.'))

    # Folder holding the class images.
    class_data_dir: str = field(
        default='/tmp/class_data',
        metadata=dict(
            help='A folder containing the training data of class images.'))

    # Target resolution of the input images.
    resolution: int = field(
        default=512,
        metadata=dict(
            help=('The resolution for input images, all the images in the '
                  'train/validation dataset will be resized to this')))

    # Weight applied to the prior preservation loss.
    prior_loss_weight: float = field(
        default=1.0,
        metadata=dict(help='The weight of prior preservation loss.'))

    # Selects which parameters are fine-tuned (see the help text).
    freeze_model: str = field(
        default='crossattn_kv',
        metadata=dict(
            help=('crossattn to enable fine-tuning of all params in the '
                  'cross attention.')))

    # Local directory or online ID of the instance data.
    instance_data_name: str = field(
        default='buptwq/custom-stable-diffusion-cat',
        metadata=dict(help='The instance data local dir or online ID.'))
|
||||
|
||||
|
||||
# Parse the command line; to_config() yields the (config, args) pair that the
# rest of the script reads from.
arg_spec = StableDiffusionCustomArguments(task='text-to-image-synthesis')
training_args = arg_spec.parse_cli()
config, args = training_args.to_config()

if not os.path.exists(args.train_dataset_name):
    # Not a local path: load the online dataset, one split per role.
    remote_options = {'download_mode': DownloadMode.FORCE_REDOWNLOAD}
    train_dataset = MsDataset.load(
        args.train_dataset_name, split='train', **remote_options)
    validation_dataset = MsDataset.load(
        args.train_dataset_name, split='validation', **remote_options)
else:
    # Local dataset: the same data is used for training and validation.
    train_dataset = MsDataset.load(args.train_dataset_name)
    validation_dataset = MsDataset.load(args.train_dataset_name)
|
||||
|
||||
|
||||
def cfg_modify_fn(cfg):
    """Return the trainer configuration adjusted by the parsed CLI options.

    When ``args.use_model_config`` is truthy, the CLI-derived ``config`` is
    merged into ``cfg``; otherwise ``config`` replaces ``cfg`` entirely. In
    both cases the learning-rate scheduler is overridden before returning.

    Args:
        cfg: The configuration object handed in by the trainer.

    Returns:
        The configuration object to train with.
    """
    if not args.use_model_config:
        cfg = config
    else:
        cfg.merge_from_dict(config)

    # LambdaLR with a multiplier that is always 1.
    scheduler = dict(type='LambdaLR')
    scheduler['lr_lambda'] = lambda _: 1
    scheduler['last_epoch'] = -1
    cfg.train.lr_scheduler = scheduler
    return cfg
|
||||
|
||||
|
||||
# Arguments handed to the trainer: model identity, the Custom Diffusion
# options from the command line, the datasets and the config hook.
kwargs = {
    'model': training_args.model,
    'model_revision': args.model_revision,
    'class_prompt': args.class_prompt,
    'instance_prompt': args.instance_prompt,
    'modifier_token': args.modifier_token,
    'num_class_images': args.num_class_images,
    'train_batch_size': args.train_batch_size,
    'sample_batch_size': args.sample_batch_size,
    'initializer_token': args.initializer_token,
    'class_data_dir': args.class_data_dir,
    'resolution': args.resolution,
    'prior_loss_weight': args.prior_loss_weight,
    'freeze_model': args.freeze_model,
    'instance_data_name': args.instance_data_name,
    'work_dir': training_args.work_dir,
    'train_dataset': train_dataset,
    'eval_dataset': validation_dataset,
    'cfg_modify_fn': cfg_modify_fn,
}

# Build the Custom Diffusion trainer and run training.
trainer = build_trainer(name=Trainers.custom_diffusion, default_args=kwargs)
trainer.train()
|
||||
|
||||
# Run the text-to-image pipeline with the fine-tuned weights and save the
# first generated image.

# Use the modifier token that training was given; keep the previous
# hard-coded value as the fallback when none was supplied.
modifier_token = args.modifier_token or '<new1>'

# No `prompt` option is declared by this script's argument class, so a plain
# `args.prompt` may not exist. Fall back to the instance prompt instead of
# failing with an AttributeError after training has already finished.
prompt = getattr(args, 'prompt', None) or args.instance_prompt

pipe = pipeline(
    task=Tasks.text_to_image_synthesis,
    model=training_args.model,
    custom_dir=os.path.join(training_args.work_dir, 'output'),
    modifier_token=modifier_token,
    model_revision=args.model_revision)

output = pipe({'text': prompt})

# cv2.imwrite reports failure through its return value rather than raising,
# so check it explicitly.
result_path = './custom_result.png'
if not cv2.imwrite(result_path, output['output_imgs'][0]):
    raise RuntimeError(
        f'Failed to write the generated image to {result_path}')
|
||||
17
examples/pytorch/stable_diffusion/custom/run_train_custom.sh
Normal file
17
examples/pytorch/stable_diffusion/custom/run_train_custom.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
# Launch the Custom Diffusion fine-tuning example with torchrun.
# Paths are relative, so run this from the repository root.
PYTHONPATH=. torchrun examples/pytorch/stable_diffusion/custom/finetune_stable_diffusion_custom.py \
    --model 'AI-ModelScope/stable-diffusion-v1-5' \
    --model_revision 'v1.0.9' \
    --class_prompt "dog" \
    --instance_prompt="photo of a <new1> dog" \
    --work_dir './tmp/custom_diffusion' \
    --class_data_dir './tmp/class_data' \
    --train_dataset_name 'buptwq/lora-stable-diffusion-finetune-dog' \
    --max_epochs 2 \
    --modifier_token "<new1>" \
    --num_class_images=200 \
    --save_ckpt_strategy 'by_epoch' \
    --logging_interval 1 \
    --train.dataloader.workers_per_gpu 0 \
    --evaluation.dataloader.workers_per_gpu 0 \
    --train.optimizer.lr 1e-5 \
    --use_model_config true
|
||||
Reference in New Issue
Block a user