custom diffusion

This commit is contained in:
XDUWQ
2023-07-12 10:14:00 +08:00
parent 2493a7d0f4
commit 2fb3665c67

View File

@@ -706,8 +706,8 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
model_pred.float(), target.float(), reduction='none')
loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
# train_outputs = {}
self.train_outputs[OutputKeys.LOSS] = loss
train_outputs = {}
train_outputs[OutputKeys.LOSS] = loss
# Zero out the gradients for all token embeddings except the newly added
# embeddings for the concept, as we only want to optimize the concept embeddings
@@ -725,24 +725,24 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
index_grads_to_zero, :] = grads_text_encoder.data[
index_grads_to_zero, :].fill_(0)
# # add model output info to log
# if 'log_vars' not in train_outputs:
# default_keys_pattern = ['loss']
# match_keys = set([])
# for key_p in default_keys_pattern:
# match_keys.update(
# [key for key in train_outputs.keys() if key_p in key])
# add model output info to log
if 'log_vars' not in train_outputs:
default_keys_pattern = ['loss']
match_keys = set([])
for key_p in default_keys_pattern:
match_keys.update(
[key for key in train_outputs.keys() if key_p in key])
# log_vars = {}
# for key in match_keys:
# value = train_outputs.get(key, None)
# if value is not None:
# if is_dist():
# value = value.data.clone().to('cuda')
# dist.all_reduce(value.div_(dist.get_world_size()))
# log_vars.update({key: value.item()})
# self.log_buffer.update(log_vars)
# else:
# self.log_buffer.update(train_outputs['log_vars'])
log_vars = {}
for key in match_keys:
value = train_outputs.get(key, None)
if value is not None:
if is_dist():
value = value.data.clone().to('cuda')
dist.all_reduce(value.div_(dist.get_world_size()))
log_vars.update({key: value.item()})
self.log_buffer.update(log_vars)
else:
self.log_buffer.update(train_outputs['log_vars'])
# self.train_outputs = train_outputs
self.train_outputs = train_outputs