diff --git a/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py b/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py
index 334e54e5..38184d7d 100644
--- a/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py
+++ b/modelscope/trainers/multi_modal/custom_diffusion/custom_diffusion_trainer.py
@@ -706,8 +706,8 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
             model_pred.float(), target.float(), reduction='none')
         loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
 
-        # train_outputs = {}
-        self.train_outputs[OutputKeys.LOSS] = loss
+        train_outputs = {}
+        train_outputs[OutputKeys.LOSS] = loss
 
         # Zero out the gradients for all token embeddings except the newly added
         # embeddings for the concept, as we only want to optimize the concept embeddings
@@ -725,24 +725,24 @@ class CustomDiffusionTrainer(EpochBasedTrainer):
                 index_grads_to_zero, :] = grads_text_encoder.data[
                     index_grads_to_zero, :].fill_(0)
 
-        # # add model output info to log
-        # if 'log_vars' not in train_outputs:
-        #     default_keys_pattern = ['loss']
-        #     match_keys = set([])
-        #     for key_p in default_keys_pattern:
-        #         match_keys.update(
-        #             [key for key in train_outputs.keys() if key_p in key])
+        # add model output info to log
+        if 'log_vars' not in train_outputs:
+            default_keys_pattern = ['loss']
+            match_keys = set([])
+            for key_p in default_keys_pattern:
+                match_keys.update(
+                    [key for key in train_outputs.keys() if key_p in key])
 
-        #     log_vars = {}
-        #     for key in match_keys:
-        #         value = train_outputs.get(key, None)
-        #         if value is not None:
-        #             if is_dist():
-        #                 value = value.data.clone().to('cuda')
-        #                 dist.all_reduce(value.div_(dist.get_world_size()))
-        #             log_vars.update({key: value.item()})
-        #     self.log_buffer.update(log_vars)
-        # else:
-        #     self.log_buffer.update(train_outputs['log_vars'])
+            log_vars = {}
+            for key in match_keys:
+                value = train_outputs.get(key, None)
+                if value is not None:
+                    if is_dist():
+                        value = value.data.clone().to('cuda')
+                        dist.all_reduce(value.div_(dist.get_world_size()))
+                    log_vars.update({key: value.item()})
+            self.log_buffer.update(log_vars)
+        else:
+            self.log_buffer.update(train_outputs['log_vars'])
 
-        # self.train_outputs = train_outputs
+        self.train_outputs = train_outputs
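
For context, the restored block builds `train_outputs` locally, collects every output key containing `'loss'`, and (under distributed training) averages each scalar across ranks with `all_reduce` before pushing it into the log buffer. Below is a minimal, self-contained sketch of that aggregation path; `collect_log_vars` is a hypothetical helper name, and `is_dist` here is a local stand-in for modelscope's utility of the same name.

```python
# Sketch of the un-commented log-aggregation logic, assuming torch.distributed
# backs modelscope's is_dist() helper. Helper names here are illustrative.
import torch
import torch.distributed as dist


def is_dist() -> bool:
    # Stand-in for modelscope's is_dist(): true when a process group is up.
    return dist.is_available() and dist.is_initialized()


def collect_log_vars(train_outputs: dict) -> dict:
    """Pick every output whose key contains 'loss' and average it
    across workers before it is written to the log buffer."""
    default_keys_pattern = ['loss']
    match_keys = set()
    for key_p in default_keys_pattern:
        match_keys.update(k for k in train_outputs if key_p in k)

    log_vars = {}
    for key in match_keys:
        value = train_outputs.get(key, None)
        if value is not None:
            if is_dist():
                # Divide by world size, then sum across ranks, so every
                # worker ends up logging the same global mean value.
                value = value.data.clone().to('cuda')
                dist.all_reduce(value.div_(dist.get_world_size()))
            log_vars[key] = value.item()
    return log_vars


if __name__ == '__main__':
    outputs = {'loss': torch.tensor(0.42)}
    print(collect_log_vars(outputs))  # single process: {'loss': 0.42...}
```

On a single process `is_dist()` is false and the value is logged as-is, which matches the trainer's behavior when no process group has been initialized.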