mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
custom diffusion
This commit is contained in:
@@ -585,9 +585,8 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
|
||||
index_grads_to_zero = torch.arange(
|
||||
len(self.model.tokenizer)) != self.modifier_token_id[0]
|
||||
for i in range(len(self.modifier_token_id[1:])):
|
||||
index_grads_to_zero = index_grads_to_zero & (
|
||||
torch.arange(len(self.model.tokenizer)) !=
|
||||
self.modifier_token_id[i])
|
||||
modifier_flag = torch.arange(len(self.model.tokenizer)) != self.modifier_token_id[i]
|
||||
index_grads_to_zero = index_grads_to_zero & modifier_flag
|
||||
grads_data = grads_text_encoder.data[
|
||||
index_grads_to_zero, :].fill_(0)
|
||||
grads_text_encoder.data[
|
||||
|
||||
Reference in New Issue
Block a user