Files
modelscope/examples/pytorch/stable_diffusion/finetune_stable_diffusion.py
xingjun.wxj 0db0ec5586 Merge code from github
1. Merge (add) the daily regression workflow from the GitHub PR (daily_regression.yaml)
2. Add LoRA Stable Diffusion from the GitHub PR
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13010802
* fix: the device arg did not work; rename device to ngpu (#272)

* Correct the LoRA Stable Diffusion example script (#300)

* add vad model and punc model in README.md 

add vad model and punc model

* Merge pull request #302 from modelscope/langgz-patch-1

add vad model and punc model in README.md

* add 1.6

* modify ignore

* Merge pull request #307 from modelscope/dev_rs_16

Merge release 1.6

* undo datetime to 2099

* Merge pull request #311 from modelscope/fix_master_version

undo datetime to 2099

* add daily regression workflow

* modify workflow name

* fix cron format issue

* lora trainer

* Merge pull request #315 from liuyhwangyh/add_regression_workflow

add daily regression workflow
2023-06-21 10:22:06 +08:00

45 lines
1.2 KiB
Python

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.training_args import TrainingArgs
from modelscope.utils.constant import DownloadMode
# Parse the command line for the text-to-image-synthesis task, then unpack
# the result of `to_config()` into `config` and `args` and echo the latter.
training_args = TrainingArgs(task='text-to-image-synthesis').parse_cli()
config, args = training_args.to_config()
print(args)

# Load the 'train' split and then the 'validation' split of the dataset named
# by `args.train_dataset_name`; both loads pass FORCE_REDOWNLOAD.
train_dataset, validation_dataset = (
    MsDataset.load(
        args.train_dataset_name,
        split=split_name,
        download_mode=DownloadMode.FORCE_REDOWNLOAD)
    for split_name in ('train', 'validation'))
def cfg_modify_fn(cfg):
    """Return the configuration the trainer should use.

    If ``args.use_model_config`` is truthy, the module-level ``config``
    is merged into the incoming ``cfg`` via ``merge_from_dict``. Otherwise
    the incoming ``cfg`` is discarded and ``config`` is used as-is.

    In both cases:

    * ``train.lr_scheduler`` is replaced with a ``LambdaLR`` entry whose
      lambda always returns 1;
    * ``train.optimizer.lr`` is set to 1e-4.

    Args:
        cfg: Configuration object supplied by the trainer (presumably a
            ModelScope ``Config`` — TODO confirm).

    Returns:
        The modified configuration. In the non-merge branch this is the
        module-level ``config`` object itself, mutated in place.
    """
    if not args.use_model_config:
        # Ignore the incoming config and use the CLI-derived one wholesale.
        cfg = config
    else:
        cfg.merge_from_dict(config)

    # A factor of 1 for every step should keep the LR constant, assuming
    # 'LambdaLR' resolves to torch's LambdaLR — TODO confirm.
    constant_scheduler = {
        'type': 'LambdaLR',
        'lr_lambda': lambda _: 1,
        'last_epoch': -1,
    }
    cfg.train.lr_scheduler = constant_scheduler
    cfg.train.optimizer.lr = 1e-4
    return cfg
# Default arguments handed to the trainer factory. The model id and work
# directory come from the CLI; the model revision is pinned to 'v1.0.6'.
kwargs = {
    'model': training_args.model,
    'model_revision': 'v1.0.6',
    'work_dir': training_args.work_dir,
    'train_dataset': train_dataset,
    'eval_dataset': validation_dataset,
    'cfg_modify_fn': cfg_modify_fn,
}

# Build the trainer identified by Trainers.lora_diffusion and start training.
trainer = build_trainer(default_args=kwargs, name=Trainers.lora_diffusion)
trainer.train()