From ae425433895e349b977137e4a67441aa59009715 Mon Sep 17 00:00:00 2001 From: "chenyafeng.cyf" Date: Wed, 29 Nov 2023 10:03:52 +0800 Subject: [PATCH] fix_gpu_bug Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14822269 --- modelscope/models/audio/sv/ERes2Net.py | 5 ++++- modelscope/models/audio/sv/ERes2Net_aug.py | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/modelscope/models/audio/sv/ERes2Net.py b/modelscope/models/audio/sv/ERes2Net.py index 0119783c..3c07390b 100644 --- a/modelscope/models/audio/sv/ERes2Net.py +++ b/modelscope/models/audio/sv/ERes2Net.py @@ -19,6 +19,7 @@ from modelscope.metainfo import Models from modelscope.models import MODELS, TorchModel from modelscope.models.audio.sv.fusion import AFF from modelscope.utils.constant import Tasks +from modelscope.utils.device import create_device class ReLU(nn.Hardtanh): @@ -314,6 +315,7 @@ class SpeakerVerificationERes2Net(TorchModel): self.m_channels = self.model_config['channels'] self.other_config = kwargs self.feature_dim = 80 + self.device = create_device(self.other_config['device']) self.embedding_model = ERes2Net( embed_dim=self.embed_dim, m_channels=self.m_channels) @@ -321,6 +323,7 @@ class SpeakerVerificationERes2Net(TorchModel): pretrained_model_name = kwargs['pretrained_model'] self.__load_check_point(pretrained_model_name) + self.embedding_model.to(self.device) self.embedding_model.eval() def forward(self, audio): @@ -333,7 +336,7 @@ class SpeakerVerificationERes2Net(TorchModel): ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]' # audio shape: [N, T] feature = self.__extract_feature(audio) - embedding = self.embedding_model(feature) + embedding = self.embedding_model(feature.to(self.device)) return embedding.detach().cpu() diff --git a/modelscope/models/audio/sv/ERes2Net_aug.py b/modelscope/models/audio/sv/ERes2Net_aug.py index d0739cad..5540ff3e 100644 --- a/modelscope/models/audio/sv/ERes2Net_aug.py +++ b/modelscope/models/audio/sv/ERes2Net_aug.py @@ -19,6 +19,7 @@ from modelscope.metainfo import Models from modelscope.models import MODELS, TorchModel from modelscope.models.audio.sv.fusion import AFF from modelscope.utils.constant import Tasks +from modelscope.utils.device import create_device class ReLU(nn.Hardtanh): @@ -308,12 +309,13 @@ class SpeakerVerificationERes2Net(TorchModel): self.model_config = model_config self.other_config = kwargs self.feature_dim = 80 - + self.device = create_device(self.other_config['device']) self.embedding_model = ERes2Net_aug() pretrained_model_name = kwargs['pretrained_model'] self.__load_check_point(pretrained_model_name) + self.embedding_model.to(self.device) self.embedding_model.eval() def forward(self, audio): @@ -326,7 +328,7 @@ class SpeakerVerificationERes2Net(TorchModel): ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]' # audio shape: [N, T] feature = self.__extract_feature(audio) - embedding = self.embedding_model(feature) + embedding = self.embedding_model(feature.to(self.device)) return embedding.detach().cpu()