From cfa363d433dcecfbc9ba252346d287bc36bc1b73 Mon Sep 17 00:00:00 2001 From: XDUWQ <1300964705@qq.com> Date: Wed, 12 Jul 2023 17:35:57 +0800 Subject: [PATCH] custom diffusion --- .../multi_modal/custom_diffusion/custom_diffusion_trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py b/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py index a0a1861e..99c6cb88 100644 --- a/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py +++ b/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py @@ -585,9 +585,8 @@ class CustomDiffusionTrainer(EpochBasedTrainer): index_grads_to_zero = torch.arange( len(self.model.tokenizer)) != self.modifier_token_id[0] for i in range(len(self.modifier_token_id[1:])): - index_grads_to_zero = index_grads_to_zero & ( - torch.arange(len(self.model.tokenizer)) != - self.modifier_token_id[i]) + modifier_flag = torch.arange(len(self.model.tokenizer)) != self.modifier_token_id[i] + index_grads_to_zero = index_grads_to_zero & modifier_flag grads_data = grads_text_encoder.data[ index_grads_to_zero, :].fill_(0) grads_text_encoder.data[