mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-25 04:30:48 +01:00
1. Merge(add) daily regression from github PR (daily_regression.yaml) 2. Add lora stable diffusion from github PR Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13010802 * fix: device arg not work, rename device to ngpu (#272) * Correcting the lora stable diffusion example script (#300) * add vad model and punc model in README.md add vad model and punc model * Merge pull request #302 from modelscope/langgz-patch-1 add vad model and punc model in README.md * add 1.6 * modify ignore * Merge pull request #307 from modelscope/dev_rs_16 Merge release 1.6 * undo datetime to 2099 * Merge pull request #311 from modelscope/fix_master_version undo datetime to 2099 * add daily regression workflow * modify workflow name * fix cron format issue * lora trainer * Merge pull request #315 from liuyhwangyh/add_regression_workflow add daily regression workflow
45 lines
1.2 KiB
Python
45 lines
1.2 KiB
Python
from modelscope.metainfo import Trainers
|
|
from modelscope.msdatasets import MsDataset
|
|
from modelscope.trainers import EpochBasedTrainer, build_trainer
|
|
from modelscope.trainers.training_args import TrainingArgs
|
|
from modelscope.utils.constant import DownloadMode
|
|
|
|
# Parse the command line for a text-to-image-synthesis fine-tuning run.
# `to_config()` yields two things: `config`, the configuration built from the
# CLI options, and `args`, the parsed argument values (echoed for the log).
training_args = TrainingArgs(task='text-to-image-synthesis').parse_cli()
config, args = training_args.to_config()
print(args)

# Load the 'train' split first, then the 'validation' split, both from the
# dataset named by `args.train_dataset_name`. Each load is requested with
# DownloadMode.FORCE_REDOWNLOAD, i.e. (per the mode's name) it is fetched
# afresh rather than reused from a local cache.
train_dataset, validation_dataset = (
    MsDataset.load(
        args.train_dataset_name,
        split=split_name,
        download_mode=DownloadMode.FORCE_REDOWNLOAD)
    for split_name in ('train', 'validation'))
|
|
|
|
|
|
def cfg_modify_fn(cfg):
    """Adjust the trainer configuration before training starts.

    If ``args.use_model_config`` is truthy, the CLI-derived ``config`` is
    merged into the supplied ``cfg``. Otherwise ``config`` is used in place
    of ``cfg`` entirely. In both cases the learning-rate scheduler and the
    optimizer learning rate are then overwritten.

    Args:
        cfg: The configuration object handed in by the trainer.

    Returns:
        The configuration to train with. This is ``cfg`` when the model config
        is kept, or the module-level ``config`` object otherwise (which is then
        mutated in place).
    """
    if not args.use_model_config:
        # Discard the model's own configuration and use the CLI-derived one.
        cfg = config
    else:
        cfg.merge_from_dict(config)

    # The lambda ignores its argument and always returns 1. Assuming the usual
    # LambdaLR meaning (base lr scaled by lr_lambda's result), this keeps the
    # learning rate constant.
    cfg.train.lr_scheduler = dict(
        type='LambdaLR', lr_lambda=lambda _: 1, last_epoch=-1)
    # NOTE(review): this unconditionally overrides any learning rate coming
    # from `config` / the command line — confirm that is intended.
    cfg.train.optimizer.lr = 1e-4
    return cfg
|
|
|
|
|
|
# Default arguments for the LoRA diffusion trainer. The model is pinned to
# revision 'v1.0.6', and `cfg_modify_fn` post-processes the configuration
# before training.
kwargs = {
    'model': training_args.model,
    'model_revision': 'v1.0.6',
    'work_dir': training_args.work_dir,
    'train_dataset': train_dataset,
    'eval_dataset': validation_dataset,
    'cfg_modify_fn': cfg_modify_fn,
}

trainer = build_trainer(name=Trainers.lora_diffusion, default_args=kwargs)
trainer.train()
|