mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
fix_gpu_bug
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14822269
This commit is contained in:
@@ -19,6 +19,7 @@ from modelscope.metainfo import Models
|
||||
from modelscope.models import MODELS, TorchModel
|
||||
from modelscope.models.audio.sv.fusion import AFF
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.device import create_device
|
||||
|
||||
|
||||
class ReLU(nn.Hardtanh):
|
||||
@@ -314,6 +315,7 @@ class SpeakerVerificationERes2Net(TorchModel):
|
||||
self.m_channels = self.model_config['channels']
|
||||
self.other_config = kwargs
|
||||
self.feature_dim = 80
|
||||
self.device = create_device(self.other_config['device'])
|
||||
|
||||
self.embedding_model = ERes2Net(
|
||||
embed_dim=self.embed_dim, m_channels=self.m_channels)
|
||||
@@ -321,6 +323,7 @@ class SpeakerVerificationERes2Net(TorchModel):
|
||||
pretrained_model_name = kwargs['pretrained_model']
|
||||
self.__load_check_point(pretrained_model_name)
|
||||
|
||||
self.embedding_model.to(self.device)
|
||||
self.embedding_model.eval()
|
||||
|
||||
def forward(self, audio):
|
||||
@@ -333,7 +336,7 @@ class SpeakerVerificationERes2Net(TorchModel):
|
||||
) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
|
||||
# audio shape: [N, T]
|
||||
feature = self.__extract_feature(audio)
|
||||
embedding = self.embedding_model(feature)
|
||||
embedding = self.embedding_model(feature.to(self.device))
|
||||
|
||||
return embedding.detach().cpu()
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from modelscope.metainfo import Models
|
||||
from modelscope.models import MODELS, TorchModel
|
||||
from modelscope.models.audio.sv.fusion import AFF
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.device import create_device
|
||||
|
||||
|
||||
class ReLU(nn.Hardtanh):
|
||||
@@ -308,12 +309,13 @@ class SpeakerVerificationERes2Net(TorchModel):
|
||||
self.model_config = model_config
|
||||
self.other_config = kwargs
|
||||
self.feature_dim = 80
|
||||
|
||||
self.device = create_device(self.other_config['device'])
|
||||
self.embedding_model = ERes2Net_aug()
|
||||
|
||||
pretrained_model_name = kwargs['pretrained_model']
|
||||
self.__load_check_point(pretrained_model_name)
|
||||
|
||||
self.embedding_model.to(self.device)
|
||||
self.embedding_model.eval()
|
||||
|
||||
def forward(self, audio):
|
||||
@@ -326,7 +328,7 @@ class SpeakerVerificationERes2Net(TorchModel):
|
||||
) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
|
||||
# audio shape: [N, T]
|
||||
feature = self.__extract_feature(audio)
|
||||
embedding = self.embedding_model(feature)
|
||||
embedding = self.embedding_model(feature.to(self.device))
|
||||
|
||||
return embedding.detach().cpu()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user