fix bugs of stable diffusion fp16 (#499)

This commit is contained in:
Wang Qiang
2023-08-29 10:16:07 +08:00
committed by GitHub
parent 736984af32
commit 9e6fd10ca5
3 changed files with 39 additions and 9 deletions

View File

@@ -4,7 +4,9 @@ from dataclasses import dataclass, field
import cv2
import torch
from modelscope import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import EpochBasedTrainer, build_trainer
@@ -136,9 +138,18 @@ def cfg_modify_fn(cfg):
return cfg
# build model
model_dir = snapshot_download(training_args.model)
model = Model.from_pretrained(
training_args.model,
revision=args.model_revision,
torch_type=torch.float16
if args.torch_type == 'float16' else torch.float32)
# build trainer and training
kwargs = dict(
model=training_args.model,
model_revision=args.model_revision,
model=model,
cfg_file=os.path.join(model_dir, 'configuration.json'),
class_prompt=args.class_prompt,
instance_prompt=args.instance_prompt,
modifier_token=args.modifier_token,
@@ -159,7 +170,6 @@ kwargs = dict(
if args.torch_type == 'float16' else torch.float32,
cfg_modify_fn=cfg_modify_fn)
# build trainer and training
trainer = build_trainer(name=Trainers.custom_diffusion, default_args=kwargs)
trainer.train()

View File

@@ -4,7 +4,9 @@ from dataclasses import dataclass, field
import cv2
import torch
from modelscope import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
@@ -100,9 +102,18 @@ def cfg_modify_fn(cfg):
return cfg
# build model
model_dir = snapshot_download(training_args.model)
model = Model.from_pretrained(
training_args.model,
revision=args.model_revision,
torch_type=torch.float16
if args.torch_type == 'float16' else torch.float32)
# build trainer and training
kwargs = dict(
model=training_args.model,
model_revision=args.model_revision,
model=model,
cfg_file=os.path.join(model_dir, 'configuration.json'),
work_dir=training_args.work_dir,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
@@ -117,7 +128,6 @@ kwargs = dict(
if args.torch_type == 'float16' else torch.float32,
cfg_modify_fn=cfg_modify_fn)
# build trainer and training
trainer = build_trainer(
name=Trainers.dreambooth_diffusion, default_args=kwargs)
trainer.train()

View File

@@ -4,7 +4,9 @@ from dataclasses import dataclass, field
import cv2
import torch
from modelscope import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
@@ -66,9 +68,18 @@ def cfg_modify_fn(cfg):
return cfg
# build model
model_dir = snapshot_download(training_args.model)
model = Model.from_pretrained(
training_args.model,
revision=args.model_revision,
torch_type=torch.float16
if args.torch_type == 'float16' else torch.float32)
# build trainer and training
kwargs = dict(
model=training_args.model,
model_revision=args.model_revision,
model=model,
cfg_file=os.path.join(model_dir, 'configuration.json'),
work_dir=training_args.work_dir,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
@@ -77,7 +88,6 @@ kwargs = dict(
if args.torch_type == 'float16' else torch.float32,
cfg_modify_fn=cfg_modify_fn)
# build trainer and training
trainer = build_trainer(name=Trainers.lora_diffusion, default_args=kwargs)
trainer.train()