diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 23ffdab1..d2d8115a 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -1233,6 +1233,7 @@ class Hooks(object):
     DeepspeedHook = 'DeepspeedHook'
     MegatronHook = 'MegatronHook'
     DDPHook = 'DDPHook'
+    SwiftHook = 'SwiftHook'
 
 
 class LR_Schedulers(object):
diff --git a/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py b/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py
index cec87bad..3830bb52 100644
--- a/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py
+++ b/modelscope/models/multi_modal/efficient_diffusion_tuning/efficient_stable_diffusion.py
@@ -86,6 +86,8 @@ class EfficientStableDiffusion(TorchModel):
                 self.pipe.scheduler.config)
             self.pipe = self.pipe.to(self.device)
             self.unet = self.pipe.unet
+            self.text_encoder = self.pipe.text_encoder
+            self.vae = self.pipe.vae
         else:
             # Load scheduler, tokenizer and models.
             self.noise_scheduler = DDPMScheduler.from_pretrained(
@@ -132,12 +134,19 @@ class EfficientStableDiffusion(TorchModel):
                 )
             adapter_length = tuner_config[
                 'adapter_length'] if tuner_config and 'adapter_length' in tuner_config else 10
-            adapter_config = AdapterConfig(
-                dim=-1,
-                hidden_pos=0,
-                target_modules=r'.*ff\.net\.2$',
-                adapter_length=adapter_length)
-            self.unet = Swift.prepare_model(self.unet, adapter_config)
+            adapter_config_dict = {}
+            dim_list = [320, 640, 1280]
+            target_modules_list = [r"(down_blocks.0.*ff\.net\.2$)|(up_blocks.3.*ff\.net\.2$)",
+                                   r"(down_blocks.1.*ff\.net\.2$)|(up_blocks.2.*ff\.net\.2$)",
+                                   r"(down_blocks.2.*ff\.net\.2$)|(up_blocks.1.*ff\.net\.2$)|(mid_block.*ff\.net\.2$)"]
+            for dim, target_modules in zip(dim_list, target_modules_list):
+                adapter_config = AdapterConfig(
+                    dim=dim,
+                    hidden_pos=0,
+                    target_modules=target_modules,
+                    adapter_length=adapter_length)
+                adapter_config_dict[f"adapter_{dim}"] = adapter_config
+            self.unet = Swift.prepare_model(self.unet, adapter_config_dict)
         elif tuner_name == 'swift-prompt':
             if not is_swift_available():
                 raise ValueError(
@@ -154,7 +163,8 @@ class EfficientStableDiffusion(TorchModel):
                 r'.*[down_blocks|up_blocks|mid_block]\.\d+\.attentions\.\d+\.transformer_blocks\.\d+$',
                 embedding_pos=0,
                 prompt_length=prompt_length,
-                attach_front=False)
+                attach_front=False,
+                extract_embedding=True)
             self.unet = Swift.prepare_model(self.unet, prompt_config)
         elif tuner_name in ('lora', 'control_lora'):
             # if not set the config of control-tuner, we add the lora tuner directly to the original framework,
@@ -181,13 +191,13 @@ class EfficientStableDiffusion(TorchModel):
         else:
             super().load_state_dict(state_dict=state_dict, strict=strict)
 
-    def state_dict(self):
+    def state_dict(self, *arg, **kwargs):
         if hasattr(self, 'tuner'):
-            return self.tuner.state_dict()
-        elif self.tuner_name.startswith('swift'):
-            return self.unet.state_dict()
+            return self.tuner.state_dict(*arg, **kwargs)
+        elif self.tuner_name.startswith('swift-'):
+            return self.unet.state_dict(*arg, **kwargs)
         else:
-            return super().state_dict()
+            return super().state_dict(*arg, **kwargs)
 
     def tokenize_caption(self, captions):
         """ Convert caption text to token data.
@@ -204,7 +214,7 @@ class EfficientStableDiffusion(TorchModel):
                 return_tensors='pt')
         return inputs.input_ids
 
-    def forward(self, prompt='', cond=None, target=None, **args):
+    def forward(self, prompt, cond=None, target=None, **args):
         if self.inference:
             if 'generator_seed' in args and isinstance(args['generator_seed'],
                                                        int):
@@ -213,11 +223,13 @@ class EfficientStableDiffusion(TorchModel):
             else:
                 generator = None
             num_inference_steps = args.get('num_inference_steps', 30)
+            guidance_scale = args.get('guidance_scale', 7.5)
             if self.is_control:
                 _ = self.tuner(cond.to(self.device)).control_states
                 images = self.pipe(
                     prompt,
                     num_inference_steps=num_inference_steps,
+                    guidance_scale=guidance_scale,
                     generator=generator).images
                 return images
             else:
@@ -243,8 +255,8 @@ class EfficientStableDiffusion(TorchModel):
             input_ids = self.tokenize_caption(prompt).to(self.device)
 
             # Get the text embedding for conditioning
-            with torch.no_grad():
-                encoder_hidden_states = self.text_encoder(input_ids)[0]
+            # with torch.no_grad():
+            encoder_hidden_states = self.text_encoder(input_ids)[0]
 
             # Inject control states to unet
             if self.is_control:
diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py
index d180289b..54ad6e97 100644
--- a/modelscope/preprocessors/multi_modal.py
+++ b/modelscope/preprocessors/multi_modal.py
@@ -53,10 +53,13 @@ class DiffusionImageGenerationPreprocessor(Preprocessor):
         self.preprocessor_mean = kwargs.pop('mean', [0.5])
         self.preprocessor_std = kwargs.pop('std', [0.5])
         self.preprocessor_image_keys = set(kwargs.pop('image_keys', []))
+        self.center_crop = kwargs.pop('center_crop', True)
+
         self.transform_input = transforms.Compose([
             transforms.Resize(
                 self.preprocessor_resolution,
                 interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(self.preprocessor_resolution) if self.center_crop else transforms.RandomCrop(self.preprocessor_resolution),
             transforms.ToTensor(),
             transforms.Normalize(self.preprocessor_mean,
                                  self.preprocessor_std),
diff --git a/modelscope/trainers/hooks/__init__.py b/modelscope/trainers/hooks/__init__.py
index 072105be..a51c50e8 100644
--- a/modelscope/trainers/hooks/__init__.py
+++ b/modelscope/trainers/hooks/__init__.py
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
     from .distributed.ddp_hook import DDPHook
     from .distributed.deepspeed_hook import DeepspeedHook
     from .distributed.megatron_hook import MegatronHook
+    from .swift.swift_hook import SwiftHook
 
 else:
     _import_structure = {
@@ -40,6 +41,7 @@ else:
         'distributed.ddp_hook': ['DDPHook'],
         'distributed.deepspeed_hook': ['DeepspeedHook'],
         'distributed.megatron_hook': ['MegatronHook'],
+        'swift.swift_hook': ['SwiftHook'],
         'priority': ['Priority', 'get_priority']
     }
 
diff --git a/modelscope/trainers/hooks/swift/__init__.py b/modelscope/trainers/hooks/swift/__init__.py
new file mode 100644
index 00000000..daf16f92
--- /dev/null
+++ b/modelscope/trainers/hooks/swift/__init__.py
@@ -0,0 +1 @@
+from .swift_hook import SwiftHook
\ No newline at end of file
diff --git a/modelscope/trainers/hooks/swift/swift_hook.py b/modelscope/trainers/hooks/swift/swift_hook.py
new file mode 100644
index 00000000..262dd483
--- /dev/null
+++ b/modelscope/trainers/hooks/swift/swift_hook.py
@@ -0,0 +1,131 @@
+import os
+import shutil
+
+from modelscope.metainfo import Hooks
+from modelscope.trainers import EpochBasedTrainer
+from modelscope.trainers.hooks.builder import HOOKS
+from modelscope.trainers.hooks.checkpoint.checkpoint_hook import (
+    BestCkptSaverHook, CheckpointHook, CheckpointProcessor)
+from modelscope.trainers.hooks.checkpoint.load_checkpoint_hook import \
+    LoadCheckpointHook
+from modelscope.trainers.hooks.hook import Hook
+from modelscope.utils.import_utils import is_swift_available
+from modelscope.utils.checkpoint import save_configuration
+
+
+class SwiftCheckpointProcessor(CheckpointProcessor):
+
+    _BIN_FILE_DIR = 'model'
+    SWIFT_SAVE_SUFFIX = '_swift'
+
+    @staticmethod
+    def copy_files_and_dump_config(trainer, output_dir, config, bin_file):
+        """Copy useful files to target output folder and dumps the target configuration.json.
+        """
+        model = trainer.unwrap_module(trainer.model)
+
+        class SaveConfig:
+
+            def __init__(self, output_dir, config):
+                self.output_dir = output_dir
+                self.config = config
+
+            def __call__(self, _output_dir, _config):
+                self.config = _config
+
+            def save_config(self):
+                save_configuration(self.output_dir, self.config)
+
+        for pop_key in [
+                'push_to_hub', 'hub_repo_id', 'hub_token', 'private_hub'
+        ]:
+            if config.safe_get('train.checkpoint.period.'
+                               + pop_key) is not None:
+                config.safe_get('train.checkpoint.period').pop(pop_key)
+            if config.safe_get('train.checkpoint.best.' + pop_key) is not None:
+                config.safe_get('train.checkpoint.best').pop(pop_key)
+
+        save_config_fn = SaveConfig(output_dir, config)
+
+        if hasattr(model, 'save_pretrained'):
+            if not is_swift_available():
+                raise ValueError(
+                    'Please install swift by `pip install ms-swift` to use SwiftHook.'
+                )
+            from swift import SwiftModel
+            if isinstance(model, SwiftModel):
+                _swift_output_dir = output_dir + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
+                model.save_pretrained(
+                    save_directory=_swift_output_dir,
+                    safe_serialization=config.safe_get('train.checkpoint.safe_serialization', False),
+                    adapter_name=config.safe_get('train.checkpoint.adapter_name', 'default')
+                )
+            else:
+                model.save_pretrained(
+                    output_dir,
+                    bin_file,
+                    save_function=lambda *args, **kwargs: None,
+                    config=save_config_fn.config,
+                    save_config_function=save_config_fn)
+
+        if trainer.train_preprocessor is not None:
+            trainer.train_preprocessor.save_pretrained(
+                output_dir,
+                save_config_fn.config,
+                save_config_function=save_config_fn)
+        if trainer.eval_preprocessor is not None:
+            trainer.eval_preprocessor.save_pretrained(
+                output_dir,
+                save_config_fn.config,
+                save_config_function=save_config_fn)
+        save_config_fn.save_config()
+
+    def link_dir(self, source_dir, output_dir):
+        if os.path.exists(output_dir):
+            shutil.rmtree(output_dir)
+        shutil.copytree(source_dir, output_dir)
+
+    def save_swift_model_state(self, model, filename):
+        model.save_pretrained(filename)
+
+    def save_checkpoints(self,
+                         trainer,
+                         checkpoint_path_prefix,
+                         output_dir,
+                         meta=None,
+                         save_optimizers=True):
+        model = trainer.unwrap_module(trainer.model)
+        _model_file, _train_state_file = self._get_state_file_name(
+            checkpoint_path_prefix)
+        _swift_save_dir = checkpoint_path_prefix + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
+        _swift_output_dir = output_dir + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
+        self.save_trainer_state(trainer, model, _train_state_file, meta,
+                                save_optimizers)
+        self.save_model_state(model, _model_file)
+        self.link(model, _model_file, output_dir)
+        self.save_swift_model_state(model, _swift_save_dir)
+        self.link_dir(_swift_save_dir, _swift_output_dir)
+
+
+@HOOKS.register_module(module_name=Hooks.SwiftHook)
+class SwiftHook(Hook):
+
+    _BIN_FILE_DIR = 'model'
+
+    def __init__(self):
+        pass
+
+    def register_processor(self, trainer: EpochBasedTrainer):
+        processor = SwiftCheckpointProcessor()
+        ckpt_hook = trainer.get_hook(CheckpointHook)
+        if len(ckpt_hook) > 0 and not isinstance(ckpt_hook[0].processor,
+                                                 SwiftCheckpointProcessor):
+            ckpt_hook[0].set_processor(processor)
+        best_ckpt_hook = trainer.get_hook(BestCkptSaverHook)
+        if len(best_ckpt_hook) > 0 and not isinstance(
+                best_ckpt_hook[0].processor, SwiftCheckpointProcessor):
+            best_ckpt_hook[0].set_processor(processor)
+        load_ckpt_hook = trainer.get_hook(LoadCheckpointHook)
+        if len(load_ckpt_hook) > 0 and not isinstance(
+                load_ckpt_hook[0].processor, SwiftCheckpointProcessor):
+            load_ckpt_hook[0].set_processor(processor)
diff --git a/tests/trainers/test_efficient_diffusion_tuning_trainer_swift.py b/tests/trainers/test_efficient_diffusion_tuning_trainer_swift.py
index c661b8ee..c05e504c 100644
--- a/tests/trainers/test_efficient_diffusion_tuning_trainer_swift.py
+++ b/tests/trainers/test_efficient_diffusion_tuning_trainer_swift.py
@@ -22,7 +22,7 @@ class TestEfficientDiffusionTuningTrainerSwift(unittest.TestCase):
             split='train',
             subset_name='Anime').remap_columns({'Image:FILE': 'target:FILE'})
 
-        self.max_epochs = 30
+        self.max_epochs = 1
         self.lr = 0.0001
         self.tmp_dir = tempfile.TemporaryDirectory().name