mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 13:15:06 +02:00
fix bugs of stable diffusion fp16 (#499)
This commit is contained in:
@@ -4,7 +4,9 @@ from dataclasses import dataclass, field
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models import Model
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.trainers import EpochBasedTrainer, build_trainer
|
||||
@@ -136,9 +138,18 @@ def cfg_modify_fn(cfg):
|
||||
return cfg
|
||||
|
||||
|
||||
# build model
|
||||
model_dir = snapshot_download(training_args.model)
|
||||
model = Model.from_pretrained(
|
||||
training_args.model,
|
||||
revision=args.model_revision,
|
||||
torch_type=torch.float16
|
||||
if args.torch_type == 'float16' else torch.float32)
|
||||
|
||||
# build trainer and training
|
||||
kwargs = dict(
|
||||
model=training_args.model,
|
||||
model_revision=args.model_revision,
|
||||
model=model,
|
||||
cfg_file=os.path.join(model_dir, 'configuration.json'),
|
||||
class_prompt=args.class_prompt,
|
||||
instance_prompt=args.instance_prompt,
|
||||
modifier_token=args.modifier_token,
|
||||
@@ -159,7 +170,6 @@ kwargs = dict(
|
||||
if args.torch_type == 'float16' else torch.float32,
|
||||
cfg_modify_fn=cfg_modify_fn)
|
||||
|
||||
# build trainer and training
|
||||
trainer = build_trainer(name=Trainers.custom_diffusion, default_args=kwargs)
|
||||
trainer.train()
|
||||
|
||||
|
||||
@@ -4,7 +4,9 @@ from dataclasses import dataclass, field
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models import Model
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.trainers import build_trainer
|
||||
@@ -100,9 +102,18 @@ def cfg_modify_fn(cfg):
|
||||
return cfg
|
||||
|
||||
|
||||
# build model
|
||||
model_dir = snapshot_download(training_args.model)
|
||||
model = Model.from_pretrained(
|
||||
training_args.model,
|
||||
revision=args.model_revision,
|
||||
torch_type=torch.float16
|
||||
if args.torch_type == 'float16' else torch.float32)
|
||||
|
||||
# build trainer and training
|
||||
kwargs = dict(
|
||||
model=training_args.model,
|
||||
model_revision=args.model_revision,
|
||||
model=model,
|
||||
cfg_file=os.path.join(model_dir, 'configuration.json'),
|
||||
work_dir=training_args.work_dir,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=validation_dataset,
|
||||
@@ -117,7 +128,6 @@ kwargs = dict(
|
||||
if args.torch_type == 'float16' else torch.float32,
|
||||
cfg_modify_fn=cfg_modify_fn)
|
||||
|
||||
# build trainer and training
|
||||
trainer = build_trainer(
|
||||
name=Trainers.dreambooth_diffusion, default_args=kwargs)
|
||||
trainer.train()
|
||||
|
||||
@@ -4,7 +4,9 @@ from dataclasses import dataclass, field
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models import Model
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.trainers import build_trainer
|
||||
@@ -66,9 +68,18 @@ def cfg_modify_fn(cfg):
|
||||
return cfg
|
||||
|
||||
|
||||
# build model
|
||||
model_dir = snapshot_download(training_args.model)
|
||||
model = Model.from_pretrained(
|
||||
training_args.model,
|
||||
revision=args.model_revision,
|
||||
torch_type=torch.float16
|
||||
if args.torch_type == 'float16' else torch.float32)
|
||||
|
||||
# build trainer and training
|
||||
kwargs = dict(
|
||||
model=training_args.model,
|
||||
model_revision=args.model_revision,
|
||||
model=model,
|
||||
cfg_file=os.path.join(model_dir, 'configuration.json'),
|
||||
work_dir=training_args.work_dir,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=validation_dataset,
|
||||
@@ -77,7 +88,6 @@ kwargs = dict(
|
||||
if args.torch_type == 'float16' else torch.float32,
|
||||
cfg_modify_fn=cfg_modify_fn)
|
||||
|
||||
# build trainer and training
|
||||
trainer = build_trainer(name=Trainers.lora_diffusion, default_args=kwargs)
|
||||
trainer.train()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user