mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-25 12:39:25 +01:00)
Compatible with Swift on SD Tuner (#554)
Co-authored-by: zeyinzi.jzyz <zeyinzi.jzyz@alibaba-inc.com>
@@ -1233,6 +1233,7 @@ class Hooks(object):
     DeepspeedHook = 'DeepspeedHook'
     MegatronHook = 'MegatronHook'
     DDPHook = 'DDPHook'
+    SwiftHook = 'SwiftHook'


 class LR_Schedulers(object):
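The new constant registers the hook's name string with modelscope's hook registry. As a hedged sketch (the `train.hooks` list with a `type` key is modelscope's usual configuration convention; the surrounding layout is illustrative), a trainer configuration enables the hook by this name:

    # configuration.json fragment (illustrative)
    "train": {
        "hooks": [
            {"type": "SwiftHook"}
        ]
    }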
@@ -86,6 +86,8 @@ class EfficientStableDiffusion(TorchModel):
                 self.pipe.scheduler.config)
             self.pipe = self.pipe.to(self.device)
             self.unet = self.pipe.unet
+            self.text_encoder = self.pipe.text_encoder
+            self.vae = self.pipe.vae
         else:
             # Load scheduler, tokenizer and models.
             self.noise_scheduler = DDPMScheduler.from_pretrained(
@@ -132,12 +134,19 @@ class EfficientStableDiffusion(TorchModel):
                 )
             adapter_length = tuner_config[
                 'adapter_length'] if tuner_config and 'adapter_length' in tuner_config else 10
-            adapter_config = AdapterConfig(
-                dim=-1,
-                hidden_pos=0,
-                target_modules=r'.*ff\.net\.2$',
-                adapter_length=adapter_length)
-            self.unet = Swift.prepare_model(self.unet, adapter_config)
+            adapter_config_dict = {}
+            dim_list = [320, 640, 1280]
+            target_modules_list = [r"(down_blocks.0.*ff\.net\.2$)|(up_blocks.3.*ff\.net\.2$)",
+                                   r"(down_blocks.1.*ff\.net\.2$)|(up_blocks.2.*ff\.net\.2$)",
+                                   r"(down_blocks.2.*ff\.net\.2$)|(up_blocks.1.*ff\.net\.2$)|(mid_block.*ff\.net\.2$)"]
+            for dim, target_modules in zip(dim_list, target_modules_list):
+                adapter_config = AdapterConfig(
+                    dim=dim,
+                    hidden_pos=0,
+                    target_modules=target_modules,
+                    adapter_length=adapter_length)
+                adapter_config_dict[f"adapter_{dim}"] = adapter_config
+            self.unet = Swift.prepare_model(self.unet, adapter_config_dict)
         elif tuner_name == 'swift-prompt':
             if not is_swift_available():
                 raise ValueError(
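Instead of one `dim=-1` adapter over every `ff.net.2` layer, the change builds one AdapterConfig per UNet feature width and hands the whole dict to `Swift.prepare_model`, so each adapter only wraps blocks whose feed-forward output matches its `dim`. A minimal standalone sketch of the same pattern (the checkpoint id is illustrative; the config fields and regexes are taken from the diff):

    # Sketch, assuming ms-swift and diffusers are installed.
    from diffusers import UNet2DConditionModel
    from swift import AdapterConfig, Swift

    unet = UNet2DConditionModel.from_pretrained(
        'runwayml/stable-diffusion-v1-5', subfolder='unet')  # example checkpoint

    adapter_length = 10
    adapter_configs = {}
    dim_list = [320, 640, 1280]
    target_modules_list = [
        r'(down_blocks.0.*ff\.net\.2$)|(up_blocks.3.*ff\.net\.2$)',
        r'(down_blocks.1.*ff\.net\.2$)|(up_blocks.2.*ff\.net\.2$)',
        r'(down_blocks.2.*ff\.net\.2$)|(up_blocks.1.*ff\.net\.2$)|(mid_block.*ff\.net\.2$)'
    ]
    # One adapter per feature width, attached only to ff.net.2 layers of that width.
    for dim, target_modules in zip(dim_list, target_modules_list):
        adapter_configs[f'adapter_{dim}'] = AdapterConfig(
            dim=dim,
            hidden_pos=0,
            target_modules=target_modules,
            adapter_length=adapter_length)

    unet = Swift.prepare_model(unet, adapter_configs)  # wraps matching modules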
@@ -154,7 +163,8 @@ class EfficientStableDiffusion(TorchModel):
                 r'.*[down_blocks|up_blocks|mid_block]\.\d+\.attentions\.\d+\.transformer_blocks\.\d+$',
                 embedding_pos=0,
                 prompt_length=prompt_length,
-                attach_front=False)
+                attach_front=False,
+                extract_embedding=True)
             self.unet = Swift.prepare_model(self.unet, prompt_config)
         elif tuner_name in ('lora', 'control_lora'):
             # if not set the config of control-tuner, we add the lora tuner directly to the original framework,
@@ -181,13 +191,13 @@ class EfficientStableDiffusion(TorchModel):
         else:
             super().load_state_dict(state_dict=state_dict, strict=strict)

-    def state_dict(self):
+    def state_dict(self, *arg, **kwargs):
         if hasattr(self, 'tuner'):
-            return self.tuner.state_dict()
-        elif self.tuner_name.startswith('swift'):
-            return self.unet.state_dict()
+            return self.tuner.state_dict(*arg, **kwargs)
+        elif self.tuner_name.startswith('swift-'):
+            return self.unet.state_dict(*arg, **kwargs)
         else:
-            return super().state_dict()
+            return super().state_dict(*arg, **kwargs)

     def tokenize_caption(self, captions):
         """ Convert caption text to token data.
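Forwarding `*arg, **kwargs` means PyTorch's standard `state_dict` options now reach whichever module actually holds the weights. For example:

    # keep_vars is a stock torch.nn.Module.state_dict argument; with the
    # change it is passed through to the tuner or the Swift-wrapped unet.
    sd = model.state_dict(keep_vars=True)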
@@ -204,7 +214,7 @@ class EfficientStableDiffusion(TorchModel):
                 return_tensors='pt')
         return inputs.input_ids

-    def forward(self, prompt='', cond=None, target=None, **args):
+    def forward(self, prompt, cond=None, target=None, **args):
         if self.inference:
             if 'generator_seed' in args and isinstance(args['generator_seed'],
                                                        int):
@@ -213,11 +223,13 @@ class EfficientStableDiffusion(TorchModel):
             else:
                 generator = None
+            num_inference_steps = args.get('num_inference_steps', 30)
+            guidance_scale = args.get('guidance_scale', 7.5)
             if self.is_control:
                 _ = self.tuner(cond.to(self.device)).control_states
             images = self.pipe(
                 prompt,
                 num_inference_steps=num_inference_steps,
                 guidance_scale=guidance_scale,
                 generator=generator).images
             return images
         else:
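The two `args.get` calls surface sampler settings through `forward`'s keyword arguments. An illustrative call (the prompt text is made up; the key names come from the diff):

    images = model(
        prompt='an anime character, watercolor style',
        num_inference_steps=50,  # falls back to 30 when omitted
        guidance_scale=9.0,      # falls back to 7.5 when omitted
        generator_seed=42)       # int seed builds the generator above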
@@ -243,8 +255,8 @@ class EfficientStableDiffusion(TorchModel):
             input_ids = self.tokenize_caption(prompt).to(self.device)

             # Get the text embedding for conditioning
-            with torch.no_grad():
-                encoder_hidden_states = self.text_encoder(input_ids)[0]
+            # with torch.no_grad():
+            encoder_hidden_states = self.text_encoder(input_ids)[0]

             # Inject control states to unet
             if self.is_control:
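Presumably the `torch.no_grad()` guard is dropped so gradients can reach the text encoder, which the constructor above now exposes as `self.text_encoder` for tuning.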
@@ -53,10 +53,13 @@ class DiffusionImageGenerationPreprocessor(Preprocessor):
         self.preprocessor_mean = kwargs.pop('mean', [0.5])
         self.preprocessor_std = kwargs.pop('std', [0.5])
         self.preprocessor_image_keys = set(kwargs.pop('image_keys', []))
+        self.center_crop = kwargs.pop('center_crop', True)

         self.transform_input = transforms.Compose([
             transforms.Resize(
                 self.preprocessor_resolution,
                 interpolation=transforms.InterpolationMode.BILINEAR),
-            transforms.CenterCrop(self.preprocessor_resolution),
+            transforms.CenterCrop(self.preprocessor_resolution)
+            if self.center_crop else transforms.RandomCrop(
+                self.preprocessor_resolution),
             transforms.ToTensor(),
             transforms.Normalize(self.preprocessor_mean,
                                  self.preprocessor_std),
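The new `center_crop` key preserves the old deterministic crop by default while allowing random crops for augmentation. A hedged construction sketch (the `resolution` kwarg name is an assumption; `center_crop` comes from the diff):

    preprocessor = DiffusionImageGenerationPreprocessor(
        resolution=512,      # assumed name feeding preprocessor_resolution
        center_crop=False)   # new key: False selects transforms.RandomCrop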
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
     from .distributed.ddp_hook import DDPHook
     from .distributed.deepspeed_hook import DeepspeedHook
     from .distributed.megatron_hook import MegatronHook
+    from .swift.swift_hook import SwiftHook

 else:
     _import_structure = {
@@ -40,6 +41,7 @@ else:
         'distributed.ddp_hook': ['DDPHook'],
         'distributed.deepspeed_hook': ['DeepspeedHook'],
         'distributed.megatron_hook': ['MegatronHook'],
+        'swift.swift_hook': ['SwiftHook'],
         'priority': ['Priority', 'get_priority']
     }

1  modelscope/trainers/hooks/swift/__init__.py  Normal file
@@ -0,0 +1 @@
+from .swift_hook import SwiftHook
131  modelscope/trainers/hooks/swift/swift_hook.py  Normal file
@@ -0,0 +1,131 @@
+import os
+import shutil
+
+from modelscope.metainfo import Hooks
+from modelscope.trainers import EpochBasedTrainer
+from modelscope.trainers.hooks.builder import HOOKS
+from modelscope.trainers.hooks.checkpoint.checkpoint_hook import (
+    BestCkptSaverHook, CheckpointHook, CheckpointProcessor)
+from modelscope.trainers.hooks.checkpoint.load_checkpoint_hook import \
+    LoadCheckpointHook
+from modelscope.trainers.hooks.hook import Hook
+from modelscope.utils.import_utils import is_swift_available
+from modelscope.utils.checkpoint import save_configuration
+
+
+class SwiftCheckpointProcessor(CheckpointProcessor):
+
+    _BIN_FILE_DIR = 'model'
+    SWIFT_SAVE_SUFFIX = '_swift'
+
+    @staticmethod
+    def copy_files_and_dump_config(trainer, output_dir, config, bin_file):
+        """Copy useful files to target output folder and dumps the target configuration.json.
+        """
+        model = trainer.unwrap_module(trainer.model)
+
+        class SaveConfig:
+
+            def __init__(self, output_dir, config):
+                self.output_dir = output_dir
+                self.config = config
+
+            def __call__(self, _output_dir, _config):
+                self.config = _config
+
+            def save_config(self):
+                save_configuration(self.output_dir, self.config)
+
+        for pop_key in [
+                'push_to_hub', 'hub_repo_id', 'hub_token', 'private_hub'
+        ]:
+            if config.safe_get('train.checkpoint.period.'
+                               + pop_key) is not None:
+                config.safe_get('train.checkpoint.period').pop(pop_key)
+            if config.safe_get('train.checkpoint.best.' + pop_key) is not None:
+                config.safe_get('train.checkpoint.best').pop(pop_key)
+
+        save_config_fn = SaveConfig(output_dir, config)
+
+        if hasattr(model, 'save_pretrained'):
+            if not is_swift_available():
+                raise ValueError(
+                    'Please install swift by `pip install ms-swift` to use SwiftHook.'
+                )
+            from swift import SwiftModel
+            if isinstance(model, SwiftModel):
+                _swift_output_dir = output_dir + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
+                model.save_pretrained(
+                    save_directory=_swift_output_dir,
+                    safe_serialization=config.safe_get(
+                        'train.checkpoint.safe_serialization', False),
+                    adapter_name=config.safe_get(
+                        'train.checkpoint.adapter_name', 'default'))
+            else:
+                model.save_pretrained(
+                    output_dir,
+                    bin_file,
+                    save_function=lambda *args, **kwargs: None,
+                    config=save_config_fn.config,
+                    save_config_function=save_config_fn)
+
+        if trainer.train_preprocessor is not None:
+            trainer.train_preprocessor.save_pretrained(
+                output_dir,
+                save_config_fn.config,
+                save_config_function=save_config_fn)
+        if trainer.eval_preprocessor is not None:
+            trainer.eval_preprocessor.save_pretrained(
+                output_dir,
+                save_config_fn.config,
+                save_config_function=save_config_fn)
+        save_config_fn.save_config()
+
+    def link_dir(self, source_dir, output_dir):
+        if os.path.exists(output_dir):
+            shutil.rmtree(output_dir)
+        shutil.copytree(source_dir, output_dir)
+
+    def save_swift_model_state(self, model, filename):
+        model.save_pretrained(filename)
+
+    def save_checkpoints(self,
+                         trainer,
+                         checkpoint_path_prefix,
+                         output_dir,
+                         meta=None,
+                         save_optimizers=True):
+        model = trainer.unwrap_module(trainer.model)
+        _model_file, _train_state_file = self._get_state_file_name(
+            checkpoint_path_prefix)
+        _swift_save_dir = checkpoint_path_prefix + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
+        _swift_output_dir = output_dir + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
+        self.save_trainer_state(trainer, model, _train_state_file, meta,
+                                save_optimizers)
+        self.save_model_state(model, _model_file)
+        self.link(model, _model_file, output_dir)
+        self.save_swift_model_state(model, _swift_save_dir)
+        self.link_dir(_swift_save_dir, _swift_output_dir)
+
+
+@HOOKS.register_module(module_name=Hooks.SwiftHook)
+class SwiftHook(Hook):
+
+    _BIN_FILE_DIR = 'model'
+
+    def __init__(self):
+        pass
+
+    def register_processor(self, trainer: EpochBasedTrainer):
+        processor = SwiftCheckpointProcessor()
+        ckpt_hook = trainer.get_hook(CheckpointHook)
+        if len(ckpt_hook) > 0 and not isinstance(ckpt_hook[0].processor,
+                                                 SwiftCheckpointProcessor):
+            ckpt_hook[0].set_processor(processor)
+        best_ckpt_hook = trainer.get_hook(BestCkptSaverHook)
+        if len(best_ckpt_hook) > 0 and not isinstance(
+                best_ckpt_hook[0].processor, SwiftCheckpointProcessor):
+            best_ckpt_hook[0].set_processor(processor)
+        load_ckpt_hook = trainer.get_hook(LoadCheckpointHook)
+        if len(load_ckpt_hook) > 0 and not isinstance(
+                load_ckpt_hook[0].processor, SwiftCheckpointProcessor):
+            load_ckpt_hook[0].set_processor(processor)
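Beyond the regular checkpoint files, `save_checkpoints` writes each Swift model into a parallel `<checkpoint_prefix>_swift` directory via `SwiftModel.save_pretrained` and mirrors it to `<output_dir>_swift` through `link_dir`. A hedged sketch of wiring the hook in (the `cfg_modify_fn` pattern is modelscope's usual one; the hook type string is the name registered above):

    # Illustrative trainer setup: register SwiftHook via the config.
    def cfg_modify_fn(cfg):
        cfg.train.hooks.append({'type': 'SwiftHook'})
        return cfg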
@@ -22,7 +22,7 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
             split='train',
             subset_name='Anime').remap_columns({'Image:FILE': 'target:FILE'})

-        self.max_epochs = 30
+        self.max_epochs = 1
         self.lr = 0.0001

         self.tmp_dir = tempfile.TemporaryDirectory().name