This commit is contained in:
yangdongchao
2023-04-06 22:18:47 +08:00
parent 322ed8cbb2
commit e3a7194d59

View File

@@ -1113,9 +1113,9 @@ class RaDur_fusion(nn.Module):
self.detection = CDur_CNN_mul_scale_fusion(inputdim, outputdim, time_resolution)
self.softmax = nn.Softmax(dim=2)
#self.temperature = 5
if model_config['pre_train']:
self.encoder.load_state_dict(torch.load(model_config['encoder_path'])['model'])
self.detection.load_state_dict(torch.load(model_config['CDur_path']))
# if model_config['pre_train']:
# self.encoder.load_state_dict(torch.load(model_config['encoder_path'])['model'])
# self.detection.load_state_dict(torch.load(model_config['CDur_path']))
self.q = nn.Linear(128,128)
self.k = nn.Linear(128,128)