diff --git a/examples/pytorch/stable_diffusion/custom/finetune_stable_diffusion_custom.py b/examples/pytorch/stable_diffusion/custom/finetune_stable_diffusion_custom.py index 76a050c4..83914127 100644 --- a/examples/pytorch/stable_diffusion/custom/finetune_stable_diffusion_custom.py +++ b/examples/pytorch/stable_diffusion/custom/finetune_stable_diffusion_custom.py @@ -4,7 +4,9 @@ from dataclasses import dataclass, field import cv2 import torch +from modelscope import snapshot_download from modelscope.metainfo import Trainers +from modelscope.models import Model from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline from modelscope.trainers import EpochBasedTrainer, build_trainer @@ -136,9 +138,18 @@ def cfg_modify_fn(cfg): return cfg +# build model +model_dir = snapshot_download(training_args.model) +model = Model.from_pretrained( + training_args.model, + revision=args.model_revision, + torch_type=torch.float16 + if args.torch_type == 'float16' else torch.float32) + +# build trainer and training kwargs = dict( - model=training_args.model, - model_revision=args.model_revision, + model=model, + cfg_file=os.path.join(model_dir, 'configuration.json'), class_prompt=args.class_prompt, instance_prompt=args.instance_prompt, modifier_token=args.modifier_token, @@ -159,7 +170,6 @@ kwargs = dict( if args.torch_type == 'float16' else torch.float32, cfg_modify_fn=cfg_modify_fn) -# build trainer and training trainer = build_trainer(name=Trainers.custom_diffusion, default_args=kwargs) trainer.train() diff --git a/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py b/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py index a6bf10d4..2b741ede 100644 --- a/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py +++ b/examples/pytorch/stable_diffusion/dreambooth/finetune_stable_diffusion_dreambooth.py @@ -4,7 +4,9 @@ from dataclasses import dataclass, field import cv2 import torch +from modelscope import snapshot_download from modelscope.metainfo import Trainers +from modelscope.models import Model from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline from modelscope.trainers import build_trainer @@ -100,9 +102,18 @@ def cfg_modify_fn(cfg): return cfg +# build model +model_dir = snapshot_download(training_args.model) +model = Model.from_pretrained( + training_args.model, + revision=args.model_revision, + torch_type=torch.float16 + if args.torch_type == 'float16' else torch.float32) + +# build trainer and training kwargs = dict( - model=training_args.model, - model_revision=args.model_revision, + model=model, + cfg_file=os.path.join(model_dir, 'configuration.json'), work_dir=training_args.work_dir, train_dataset=train_dataset, eval_dataset=validation_dataset, @@ -117,7 +128,6 @@ kwargs = dict( if args.torch_type == 'float16' else torch.float32, cfg_modify_fn=cfg_modify_fn) -# build trainer and training trainer = build_trainer( name=Trainers.dreambooth_diffusion, default_args=kwargs) trainer.train() diff --git a/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py b/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py index 97bfad02..b6f9e57a 100644 --- a/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py +++ b/examples/pytorch/stable_diffusion/lora/finetune_stable_diffusion_lora.py @@ -4,7 +4,9 @@ from dataclasses import dataclass, field import cv2 import torch +from modelscope import snapshot_download from modelscope.metainfo import Trainers +from modelscope.models import Model from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline from modelscope.trainers import build_trainer @@ -66,9 +68,18 @@ def cfg_modify_fn(cfg): return cfg +# build model +model_dir = snapshot_download(training_args.model) +model = Model.from_pretrained( + training_args.model, + revision=args.model_revision, + torch_type=torch.float16 + if args.torch_type == 'float16' else torch.float32) + +# build trainer and training kwargs = dict( - model=training_args.model, - model_revision=args.model_revision, + model=model, + cfg_file=os.path.join(model_dir, 'configuration.json'), work_dir=training_args.work_dir, train_dataset=train_dataset, eval_dataset=validation_dataset, @@ -77,7 +88,6 @@ kwargs = dict( if args.torch_type == 'float16' else torch.float32, cfg_modify_fn=cfg_modify_fn) -# build trainer and training trainer = build_trainer(name=Trainers.lora_diffusion, default_args=kwargs) trainer.train()