mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-23 03:29:27 +01:00
add 1.6
This commit is contained in:
@@ -4,30 +4,32 @@ from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import EpochBasedTrainer, build_trainer
|
||||
from modelscope.trainers.training_args import TrainingArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class StableDiffusionArguments(TrainingArgs):
|
||||
|
||||
def __call__(self, config):
|
||||
config = super().__call__(config)
|
||||
config.train.lr_scheduler.T_max = self.max_epochs
|
||||
config.model.inference = False
|
||||
return config
|
||||
|
||||
|
||||
args = StableDiffusionArguments.from_cli(task='efficient-diffusion-tuning')
|
||||
training_args = TrainingArgs(task='efficient-diffusion-tuning').parse_cli()
|
||||
config, args = training_args.to_config()
|
||||
print(args)
|
||||
|
||||
dataset = MsDataset.load(args.dataset_name, namespace=args.namespace)
|
||||
dataset = MsDataset.load(
|
||||
args.train_dataset_name, namespace=args.train_dataset_namespace)
|
||||
train_dataset = dataset['train']
|
||||
validation_dataset = dataset['validation']
|
||||
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
if args.use_model_config:
|
||||
cfg.merge_from_dict(config)
|
||||
else:
|
||||
cfg = config
|
||||
cfg.train.lr_scheduler.T_max = training_args.max_epochs
|
||||
cfg.model.inference = False
|
||||
return cfg
|
||||
|
||||
|
||||
kwargs = dict(
|
||||
model=args.model,
|
||||
work_dir=args.work_dir,
|
||||
model=training_args.model,
|
||||
work_dir=training_args.work_dir,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=validation_dataset,
|
||||
cfg_modify_fn=args)
|
||||
cfg_modify_fn=cfg_modify_fn)
|
||||
|
||||
trainer: EpochBasedTrainer = build_trainer(name='trainer', default_args=kwargs)
|
||||
trainer.train()
|
||||
|
||||
Reference in New Issue
Block a user