Compatible with Swift on SD Tuner (#554)

Co-authored-by: zeyinzi.jzyz <zeyinzi.jzyz@alibaba-inc.com>
Authored by jiangzeyinzi on 2023-09-21 16:02:31 +08:00; committed by GitHub
parent cd976a366a
commit 3e6acb7998
7 changed files with 166 additions and 16 deletions

@@ -1233,6 +1233,7 @@ class Hooks(object):
     DeepspeedHook = 'DeepspeedHook'
     MegatronHook = 'MegatronHook'
     DDPHook = 'DDPHook'
+    SwiftHook = 'SwiftHook'


 class LR_Schedulers(object):

@@ -86,6 +86,8 @@ class EfficientStableDiffusion(TorchModel):
                 self.pipe.scheduler.config)
             self.pipe = self.pipe.to(self.device)
             self.unet = self.pipe.unet
+            self.text_encoder = self.pipe.text_encoder
+            self.vae = self.pipe.vae
         else:
             # Load scheduler, tokenizer and models.
             self.noise_scheduler = DDPMScheduler.from_pretrained(
@@ -132,12 +134,19 @@ class EfficientStableDiffusion(TorchModel):
                 )
             adapter_length = tuner_config[
                 'adapter_length'] if tuner_config and 'adapter_length' in tuner_config else 10
-            adapter_config = AdapterConfig(
-                dim=-1,
-                hidden_pos=0,
-                target_modules=r'.*ff\.net\.2$',
-                adapter_length=adapter_length)
-            self.unet = Swift.prepare_model(self.unet, adapter_config)
+            adapter_config_dict = {}
+            dim_list = [320, 640, 1280]
+            target_modules_list = [r"(down_blocks.0.*ff\.net\.2$)|(up_blocks.3.*ff\.net\.2$)",
+                                   r"(down_blocks.1.*ff\.net\.2$)|(up_blocks.2.*ff\.net\.2$)",
+                                   r"(down_blocks.2.*ff\.net\.2$)|(up_blocks.1.*ff\.net\.2$)|(mid_block.*ff\.net\.2$)"]
+            for dim, target_modules in zip(dim_list, target_modules_list):
+                adapter_config = AdapterConfig(
+                    dim=dim,
+                    hidden_pos=0,
+                    target_modules=target_modules,
+                    adapter_length=adapter_length)
+                adapter_config_dict[f"adapter_{dim}"] = adapter_config
+            self.unet = Swift.prepare_model(self.unet, adapter_config_dict)
         elif tuner_name == 'swift-prompt':
             if not is_swift_available():
                 raise ValueError(
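Net effect of the hunk above: instead of one UNet-wide adapter with dim=-1, three adapters are built, keyed by the channel width of the blocks they attach to (320, 640, 1280), each targeting only the ff.net.2 feed-forward projections of the matching down/up/mid blocks, and the whole dict is handed to Swift.prepare_model in a single call. A minimal standalone sketch of the same pattern, assuming ms-swift and diffusers are installed; the checkpoint id and adapter_length value are placeholders, not taken from this commit:

from diffusers import UNet2DConditionModel
from swift import AdapterConfig, Swift

unet = UNet2DConditionModel.from_pretrained(
    'runwayml/stable-diffusion-v1-5', subfolder='unet')  # placeholder checkpoint

dims = [320, 640, 1280]
targets = [
    r'(down_blocks.0.*ff\.net\.2$)|(up_blocks.3.*ff\.net\.2$)',
    r'(down_blocks.1.*ff\.net\.2$)|(up_blocks.2.*ff\.net\.2$)',
    r'(down_blocks.2.*ff\.net\.2$)|(up_blocks.1.*ff\.net\.2$)|(mid_block.*ff\.net\.2$)',
]
adapter_configs = {
    f'adapter_{dim}': AdapterConfig(
        dim=dim,                 # matches the hidden width of the targeted blocks
        hidden_pos=0,            # adapt the first positional hidden state, as in the hunk
        target_modules=target,   # regex over module names inside the UNet
        adapter_length=10)       # placeholder bottleneck length
    for dim, target in zip(dims, targets)
}
unet = Swift.prepare_model(unet, adapter_configs)  # returns a SwiftModel wrapping the UNet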
@@ -154,7 +163,8 @@ class EfficientStableDiffusion(TorchModel):
                 r'.*[down_blocks|up_blocks|mid_block]\.\d+\.attentions\.\d+\.transformer_blocks\.\d+$',
                 embedding_pos=0,
                 prompt_length=prompt_length,
-                attach_front=False)
+                attach_front=False,
+                extract_embedding=True)
             self.unet = Swift.prepare_model(self.unet, prompt_config)
         elif tuner_name in ('lora', 'control_lora'):
             # if not set the config of control-tuner, we add the lora tuner directly to the original framework,
@@ -181,13 +191,13 @@ class EfficientStableDiffusion(TorchModel):
         else:
             super().load_state_dict(state_dict=state_dict, strict=strict)

-    def state_dict(self):
+    def state_dict(self, *arg, **kwargs):
         if hasattr(self, 'tuner'):
-            return self.tuner.state_dict()
-        elif self.tuner_name.startswith('swift'):
-            return self.unet.state_dict()
+            return self.tuner.state_dict(*arg, **kwargs)
+        elif self.tuner_name.startswith('swift-'):
+            return self.unet.state_dict(*arg, **kwargs)
         else:
-            return super().state_dict()
+            return super().state_dict(*arg, **kwargs)

     def tokenize_caption(self, captions):
         """ Convert caption text to token data.
@@ -204,7 +214,7 @@ class EfficientStableDiffusion(TorchModel):
             return_tensors='pt')
         return inputs.input_ids

-    def forward(self, prompt='', cond=None, target=None, **args):
+    def forward(self, prompt, cond=None, target=None, **args):
         if self.inference:
             if 'generator_seed' in args and isinstance(args['generator_seed'],
                                                        int):
@@ -213,11 +223,13 @@ class EfficientStableDiffusion(TorchModel):
             else:
                 generator = None
             num_inference_steps = args.get('num_inference_steps', 30)
+            guidance_scale = args.get('guidance_scale', 7.5)
             if self.is_control:
                 _ = self.tuner(cond.to(self.device)).control_states
             images = self.pipe(
                 prompt,
                 num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
                 generator=generator).images
             return images
         else:
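The inference branch of forward now also reads guidance_scale from the keyword arguments and passes it to the diffusers pipeline, alongside the existing num_inference_steps and generator_seed handling. A hedged example of a call that exercises the new argument; construction of the EfficientStableDiffusion model is omitted and the prompt is a placeholder:

model.eval()
images = model(
    prompt='an anime-style portrait, soft lighting',  # placeholder prompt
    num_inference_steps=30,   # denoising steps (default 30, as above)
    guidance_scale=7.5,       # classifier-free guidance strength (default 7.5)
    generator_seed=42)        # int seed for a reproducible generator (see the hunk above)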
@@ -243,8 +255,8 @@ class EfficientStableDiffusion(TorchModel):
             input_ids = self.tokenize_caption(prompt).to(self.device)

             # Get the text embedding for conditioning
-            with torch.no_grad():
-                encoder_hidden_states = self.text_encoder(input_ids)[0]
+            # with torch.no_grad():
+            encoder_hidden_states = self.text_encoder(input_ids)[0]

             # Inject control states to unet
             if self.is_control:

@@ -53,10 +53,13 @@ class DiffusionImageGenerationPreprocessor(Preprocessor):
         self.preprocessor_mean = kwargs.pop('mean', [0.5])
         self.preprocessor_std = kwargs.pop('std', [0.5])
         self.preprocessor_image_keys = set(kwargs.pop('image_keys', []))
+        self.center_crop = kwargs.pop('center_crop', True)
         self.transform_input = transforms.Compose([
             transforms.Resize(
                 self.preprocessor_resolution,
                 interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(self.preprocessor_resolution) if self.center_crop else transforms.RandomCrop(self.preprocessor_resolution),
             transforms.ToTensor(),
             transforms.Normalize(self.preprocessor_mean,
                                  self.preprocessor_std),
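The preprocessor picks up a center_crop switch: True (the default) keeps the deterministic CenterCrop, False swaps in a RandomCrop at the same resolution for light augmentation. A sketch of passing the new option, restricted to the kwargs visible in this hunk; the import path and direct construction are assumptions, since the diff does not show how the preprocessor is normally built:

# Import path assumed for illustration; not shown in this diff.
from modelscope.preprocessors import DiffusionImageGenerationPreprocessor

preprocessor = DiffusionImageGenerationPreprocessor(
    mean=[0.5],           # normalization mean (default shown above)
    std=[0.5],            # normalization std (default shown above)
    center_crop=False)    # use RandomCrop instead of CenterCrop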

@@ -19,6 +19,7 @@ if TYPE_CHECKING:
     from .distributed.ddp_hook import DDPHook
     from .distributed.deepspeed_hook import DeepspeedHook
     from .distributed.megatron_hook import MegatronHook
+    from .swift.swift_hook import SwiftHook

 else:
     _import_structure = {

@@ -40,6 +41,7 @@ else:
         'distributed.ddp_hook': ['DDPHook'],
         'distributed.deepspeed_hook': ['DeepspeedHook'],
         'distributed.megatron_hook': ['MegatronHook'],
+        'swift.swift_hook': ['SwiftHook'],
         'priority': ['Priority', 'get_priority']
     }
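With both the eager (TYPE_CHECKING) branch and the lazy _import_structure table updated, the new hook is reachable from the package root. A one-line check, assuming the hunks above belong to modelscope/trainers/hooks/__init__.py (the file path is not shown in this extract):

from modelscope.trainers.hooks import SwiftHook  # resolved via the lazy import table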

@@ -0,0 +1 @@
from .swift_hook import SwiftHook

@@ -0,0 +1,131 @@
import os
import shutil

from modelscope.metainfo import Hooks
from modelscope.trainers import EpochBasedTrainer
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.checkpoint.checkpoint_hook import (
    BestCkptSaverHook, CheckpointHook, CheckpointProcessor)
from modelscope.trainers.hooks.checkpoint.load_checkpoint_hook import \
    LoadCheckpointHook
from modelscope.trainers.hooks.hook import Hook
from modelscope.utils.import_utils import is_swift_available
from modelscope.utils.checkpoint import save_configuration


class SwiftCheckpointProcessor(CheckpointProcessor):

    _BIN_FILE_DIR = 'model'
    SWIFT_SAVE_SUFFIX = '_swift'

    @staticmethod
    def copy_files_and_dump_config(trainer, output_dir, config, bin_file):
        """Copy useful files to the target output folder and dump the target configuration.json.
        """
        model = trainer.unwrap_module(trainer.model)

        class SaveConfig:

            def __init__(self, output_dir, config):
                self.output_dir = output_dir
                self.config = config

            def __call__(self, _output_dir, _config):
                self.config = _config

            def save_config(self):
                save_configuration(self.output_dir, self.config)

        for pop_key in [
                'push_to_hub', 'hub_repo_id', 'hub_token', 'private_hub'
        ]:
            if config.safe_get('train.checkpoint.period.'
                               + pop_key) is not None:
                config.safe_get('train.checkpoint.period').pop(pop_key)
            if config.safe_get('train.checkpoint.best.' + pop_key) is not None:
                config.safe_get('train.checkpoint.best').pop(pop_key)

        save_config_fn = SaveConfig(output_dir, config)

        if hasattr(model, 'save_pretrained'):
            if not is_swift_available():
                raise ValueError(
                    'Please install swift by `pip install ms-swift` to use SwiftHook.'
                )
            from swift import SwiftModel
            if isinstance(model, SwiftModel):
                _swift_output_dir = output_dir + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
                model.save_pretrained(
                    save_directory=_swift_output_dir,
                    safe_serialization=config.safe_get('train.checkpoint.safe_serialization', False),
                    adapter_name=config.safe_get('train.checkpoint.adapter_name', 'default')
                )
            else:
                model.save_pretrained(
                    output_dir,
                    bin_file,
                    save_function=lambda *args, **kwargs: None,
                    config=save_config_fn.config,
                    save_config_function=save_config_fn)

        if trainer.train_preprocessor is not None:
            trainer.train_preprocessor.save_pretrained(
                output_dir,
                save_config_fn.config,
                save_config_function=save_config_fn)
        if trainer.eval_preprocessor is not None:
            trainer.eval_preprocessor.save_pretrained(
                output_dir,
                save_config_fn.config,
                save_config_function=save_config_fn)
        save_config_fn.save_config()

    def link_dir(self, source_dir, output_dir):
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        shutil.copytree(source_dir, output_dir)

    def save_swift_model_state(self, model, filename):
        model.save_pretrained(filename)

    def save_checkpoints(self,
                         trainer,
                         checkpoint_path_prefix,
                         output_dir,
                         meta=None,
                         save_optimizers=True):
        model = trainer.unwrap_module(trainer.model)
        _model_file, _train_state_file = self._get_state_file_name(
            checkpoint_path_prefix)
        _swift_save_dir = checkpoint_path_prefix + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
        _swift_output_dir = output_dir + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX

        self.save_trainer_state(trainer, model, _train_state_file, meta,
                                save_optimizers)
        self.save_model_state(model, _model_file)
        self.link(model, _model_file, output_dir)
        self.save_swift_model_state(model, _swift_save_dir)
        self.link_dir(_swift_save_dir, _swift_output_dir)


@HOOKS.register_module(module_name=Hooks.SwiftHook)
class SwiftHook(Hook):

    _BIN_FILE_DIR = 'model'

    def __init__(self):
        pass

    def register_processor(self, trainer: EpochBasedTrainer):
        processor = SwiftCheckpointProcessor()
        ckpt_hook = trainer.get_hook(CheckpointHook)
        if len(ckpt_hook) > 0 and not isinstance(ckpt_hook[0].processor,
                                                 SwiftCheckpointProcessor):
            ckpt_hook[0].set_processor(processor)
        best_ckpt_hook = trainer.get_hook(BestCkptSaverHook)
        if len(best_ckpt_hook) > 0 and not isinstance(
                best_ckpt_hook[0].processor, SwiftCheckpointProcessor):
            best_ckpt_hook[0].set_processor(processor)
        load_ckpt_hook = trainer.get_hook(LoadCheckpointHook)
        if len(load_ckpt_hook) > 0 and not isinstance(
                load_ckpt_hook[0].processor, SwiftCheckpointProcessor):
            load_ckpt_hook[0].set_processor(processor)
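SwiftHook keeps no state of its own; register_processor simply swaps a SwiftCheckpointProcessor onto whatever CheckpointHook, BestCkptSaverHook and LoadCheckpointHook the trainer already registered, so Swift tuner weights are written with save_pretrained into a sibling '<checkpoint>_swift' directory next to the regular checkpoint files. A hedged sketch of enabling the hook through the trainer configuration; the cfg layout follows the usual ModelScope hook convention, and the trainer name, model id and work_dir are placeholders, not taken from this commit:

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

def cfg_modify_fn(cfg):
    # 'SwiftHook' is the registry name added to Hooks in this commit.
    cfg.train.hooks.append({'type': 'SwiftHook'})
    return cfg

trainer = build_trainer(
    name=Trainers.default,               # placeholder; the SD tuner uses its own trainer name
    default_args=dict(
        model='<model-id>',              # placeholder model id
        work_dir='./work_dir',           # placeholder output directory
        cfg_modify_fn=cfg_modify_fn))
trainer.train()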

@@ -22,7 +22,7 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
             split='train',
             subset_name='Anime').remap_columns({'Image:FILE': 'target:FILE'})
-        self.max_epochs = 30
+        self.max_epochs = 1
         self.lr = 0.0001
         self.tmp_dir = tempfile.TemporaryDirectory().name