Files
AudioGPT/audio_detection/audio_infer/pytorch/losses.py
2023-03-29 21:20:32 +08:00

14 lines
313 B
Python

import torch
import torch.nn.functional as F
def clip_bce(output_dict, target_dict):
    """Binary cross-entropy between clip-level predictions and targets.

    Args:
        output_dict: Mapping; its ``'clipwise_output'`` entry is used as
            the prediction input of the loss.
        target_dict: Mapping; its ``'target'`` entry is used as the
            ground-truth input of the loss.

    Returns:
        Whatever ``F.binary_cross_entropy`` returns for those two values
        when called with its default arguments.
    """
    # NOTE(review): F.binary_cross_entropy presumably expects probabilities
    # (not logits) and a float target of matching shape -- confirm at the
    # call site.
    predictions = output_dict['clipwise_output']
    labels = target_dict['target']
    return F.binary_cross_entropy(predictions, labels)
def get_loss_func(loss_type):
    """Return the loss function registered under ``loss_type``.

    Args:
        loss_type: Name of the requested loss. ``'clip_bce'`` is the only
            name this function recognises.

    Returns:
        The ``clip_bce`` function when ``loss_type == 'clip_bce'``.

    Raises:
        ValueError: If ``loss_type`` is any other value.
    """
    if loss_type == 'clip_bce':
        return clip_bce

    # Fail here, naming the bad value, rather than implicitly returning
    # None and letting the caller crash later when it tries to call it.
    raise ValueError(
        "Unknown loss_type: {!r} (expected 'clip_bce')".format(loss_type))