custom diffusion

This commit is contained in:
XDUWQ
2023-07-12 17:35:57 +08:00
parent c1aa120029
commit cfa363d433

View File

@@ -585,9 +585,8 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
index_grads_to_zero = torch.arange(
len(self.model.tokenizer)) != self.modifier_token_id[0]
for i in range(len(self.modifier_token_id[1:])):
index_grads_to_zero = index_grads_to_zero & (
torch.arange(len(self.model.tokenizer)) !=
self.modifier_token_id[i])
modifier_flag = torch.arange(len(self.model.tokenizer)) != self.modifier_token_id[i]
index_grads_to_zero = index_grads_to_zero & modifier_flag
grads_data = grads_text_encoder.data[
index_grads_to_zero, :].fill_(0)
grads_text_encoder.data[