This commit is contained in:
xingjun.wang
2023-05-22 10:53:18 +08:00
parent 52aea36c12
commit 48c0d2a9af
468 changed files with 12942 additions and 7176 deletions

View File

@@ -4,30 +4,32 @@ from modelscope.msdatasets import MsDataset
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.training_args import TrainingArgs
@dataclass
class StableDiffusionArguments(TrainingArgs):
def __call__(self, config):
config = super().__call__(config)
config.train.lr_scheduler.T_max = self.max_epochs
config.model.inference = False
return config
args = StableDiffusionArguments.from_cli(task='efficient-diffusion-tuning')
training_args = TrainingArgs(task='efficient-diffusion-tuning').parse_cli()
config, args = training_args.to_config()
print(args)
dataset = MsDataset.load(args.dataset_name, namespace=args.namespace)
dataset = MsDataset.load(
args.train_dataset_name, namespace=args.train_dataset_namespace)
train_dataset = dataset['train']
validation_dataset = dataset['validation']
def cfg_modify_fn(cfg):
if args.use_model_config:
cfg.merge_from_dict(config)
else:
cfg = config
cfg.train.lr_scheduler.T_max = training_args.max_epochs
cfg.model.inference = False
return cfg
kwargs = dict(
model=args.model,
work_dir=args.work_dir,
model=training_args.model,
work_dir=training_args.work_dir,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
cfg_modify_fn=args)
cfg_modify_fn=cfg_modify_fn)
trainer: EpochBasedTrainer = build_trainer(name='trainer', default_args=kwargs)
trainer.train()