chenyafeng.cyf
2023-11-29 10:03:52 +08:00
committed by hemu.zp
parent a19fe73afb
commit ae42543389
2 changed files with 8 additions and 3 deletions

View File

@@ -19,6 +19,7 @@ from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.models.audio.sv.fusion import AFF
from modelscope.utils.constant import Tasks
from modelscope.utils.device import create_device
class ReLU(nn.Hardtanh):
@@ -314,6 +315,7 @@ class SpeakerVerificationERes2Net(TorchModel):
self.m_channels = self.model_config['channels']
self.other_config = kwargs
self.feature_dim = 80
self.device = create_device(self.other_config['device'])
self.embedding_model = ERes2Net(
embed_dim=self.embed_dim, m_channels=self.m_channels)
@@ -321,6 +323,7 @@ class SpeakerVerificationERes2Net(TorchModel):
pretrained_model_name = kwargs['pretrained_model']
self.__load_check_point(pretrained_model_name)
self.embedding_model.to(self.device)
self.embedding_model.eval()
def forward(self, audio):
@@ -333,7 +336,7 @@ class SpeakerVerificationERes2Net(TorchModel):
) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
# audio shape: [N, T]
feature = self.__extract_feature(audio)
embedding = self.embedding_model(feature)
embedding = self.embedding_model(feature.to(self.device))
return embedding.detach().cpu()

View File

@@ -19,6 +19,7 @@ from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.models.audio.sv.fusion import AFF
from modelscope.utils.constant import Tasks
from modelscope.utils.device import create_device
class ReLU(nn.Hardtanh):
@@ -308,12 +309,13 @@ class SpeakerVerificationERes2Net(TorchModel):
self.model_config = model_config
self.other_config = kwargs
self.feature_dim = 80
self.device = create_device(self.other_config['device'])
self.embedding_model = ERes2Net_aug()
pretrained_model_name = kwargs['pretrained_model']
self.__load_check_point(pretrained_model_name)
self.embedding_model.to(self.device)
self.embedding_model.eval()
def forward(self, audio):
@@ -326,7 +328,7 @@ class SpeakerVerificationERes2Net(TorchModel):
) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
# audio shape: [N, T]
feature = self.__extract_feature(audio)
embedding = self.embedding_model(feature)
embedding = self.embedding_model(feature.to(self.device))
return embedding.detach().cpu()