mirror of
https://github.com/AIGC-Audio/AudioGPT.git
synced 2025-12-16 11:57:58 +01:00
14 lines
313 B
Python
14 lines
313 B
Python
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def clip_bce(output_dict, target_dict):
    """Binary cross-entropy between clip-level predictions and targets.

    Args:
        output_dict: mapping holding the model predictions under the key
            'clipwise_output'.
        target_dict: mapping holding the ground-truth labels under the key
            'target'.

    Returns:
        The loss tensor produced by ``F.binary_cross_entropy``, using
        PyTorch's default reduction (no ``reduction`` argument is passed).

    Note:
        ``F.binary_cross_entropy`` expects probabilities, not logits, so
        'clipwise_output' presumably has already been passed through a
        sigmoid. TODO confirm against the model. Predictions and targets
        must have matching shapes.
    """
    predicted = output_dict['clipwise_output']
    reference = target_dict['target']
    return F.binary_cross_entropy(input=predicted, target=reference)
|
|
|
|
|
|
def get_loss_func(loss_type):
    """Return the loss function registered under ``loss_type``.

    Args:
        loss_type: name of the loss. Currently only 'clip_bce' is supported.

    Returns:
        The loss callable (``clip_bce`` for 'clip_bce').

    Raises:
        ValueError: if ``loss_type`` is not a supported loss name. Previously
            an unknown name silently returned ``None``, which only failed
            later as an opaque "'NoneType' object is not callable" error.
    """
    if loss_type == 'clip_bce':
        return clip_bce

    raise ValueError(
        f"Unsupported loss_type {loss_type!r}; expected one of: 'clip_bce'")