mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 20:19:22 +01:00
custom diffusion
This commit is contained in:
@@ -706,8 +706,8 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
|
||||
model_pred.float(), target.float(), reduction='none')
|
||||
loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
|
||||
|
||||
# train_outputs = {}
|
||||
self.train_outputs[OutputKeys.LOSS] = loss
|
||||
train_outputs = {}
|
||||
train_outputs[OutputKeys.LOSS] = loss
|
||||
|
||||
# Zero out the gradients for all token embeddings except the newly added
|
||||
# embeddings for the concept, as we only want to optimize the concept embeddings
|
||||
@@ -725,24 +725,24 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
|
||||
index_grads_to_zero, :] = grads_text_encoder.data[
|
||||
index_grads_to_zero, :].fill_(0)
|
||||
|
||||
# # add model output info to log
|
||||
# if 'log_vars' not in train_outputs:
|
||||
# default_keys_pattern = ['loss']
|
||||
# match_keys = set([])
|
||||
# for key_p in default_keys_pattern:
|
||||
# match_keys.update(
|
||||
# [key for key in train_outputs.keys() if key_p in key])
|
||||
# add model output info to log
|
||||
if 'log_vars' not in train_outputs:
|
||||
default_keys_pattern = ['loss']
|
||||
match_keys = set([])
|
||||
for key_p in default_keys_pattern:
|
||||
match_keys.update(
|
||||
[key for key in train_outputs.keys() if key_p in key])
|
||||
|
||||
# log_vars = {}
|
||||
# for key in match_keys:
|
||||
# value = train_outputs.get(key, None)
|
||||
# if value is not None:
|
||||
# if is_dist():
|
||||
# value = value.data.clone().to('cuda')
|
||||
# dist.all_reduce(value.div_(dist.get_world_size()))
|
||||
# log_vars.update({key: value.item()})
|
||||
# self.log_buffer.update(log_vars)
|
||||
# else:
|
||||
# self.log_buffer.update(train_outputs['log_vars'])
|
||||
log_vars = {}
|
||||
for key in match_keys:
|
||||
value = train_outputs.get(key, None)
|
||||
if value is not None:
|
||||
if is_dist():
|
||||
value = value.data.clone().to('cuda')
|
||||
dist.all_reduce(value.div_(dist.get_world_size()))
|
||||
log_vars.update({key: value.item()})
|
||||
self.log_buffer.update(log_vars)
|
||||
else:
|
||||
self.log_buffer.update(train_outputs['log_vars'])
|
||||
|
||||
# self.train_outputs = train_outputs
|
||||
self.train_outputs = train_outputs
|
||||
|
||||
Reference in New Issue
Block a user