diff --git a/modelscope/models/audio/sv/ERes2NetV2.py b/modelscope/models/audio/sv/ERes2NetV2.py index 85a44716..d842a094 100644 --- a/modelscope/models/audio/sv/ERes2NetV2.py +++ b/modelscope/models/audio/sv/ERes2NetV2.py @@ -38,7 +38,13 @@ class ReLU(nn.Hardtanh): class BasicBlockERes2NetV2(nn.Module): - def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2): + def __init__(self, + in_planes, + planes, + stride=1, + baseWidth=26, + scale=2, + expansion=2): super(BasicBlockERes2NetV2, self).__init__() width = int(math.floor(planes * (baseWidth / 64.0))) self.width = width @@ -70,7 +76,6 @@ class BasicBlockERes2NetV2(nn.Module): kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(self.expansion * planes)) - def forward(self, x): residual = x @@ -103,7 +108,13 @@ class BasicBlockERes2NetV2(nn.Module): class BasicBlockERes2NetV2AFF(nn.Module): - def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2): + def __init__(self, + in_planes, + planes, + stride=1, + baseWidth=26, + scale=2, + expansion=2): super(BasicBlockERes2NetV2AFF, self).__init__() width = int(math.floor(planes * (baseWidth / 64.0))) self.width = width @@ -235,7 +246,14 @@ class ERes2NetV2(nn.Module): strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: - layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion = self.expansion)) + layers.append( + block( + self.in_planes, + planes, + stride, + baseWidth=self.baseWidth, + scale=self.scale, + expansion=self.expansion)) self.in_planes = planes * self.expansion return nn.Sequential(*layers) @@ -285,7 +303,11 @@ class SpeakerVerificationERes2NetV2(TorchModel): self.feature_dim = 80 self.device = create_device(self.other_config['device']) - self.embedding_model = ERes2NetV2(embed_dim=self.embed_dim, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion) + self.embedding_model = ERes2NetV2( + embed_dim=self.embed_dim, + baseWidth=self.baseWidth, + scale=self.scale, + expansion=self.expansion) pretrained_model_name = kwargs['pretrained_model'] self.__load_check_point(pretrained_model_name) diff --git a/tests/pipelines/test_speaker_verification.py b/tests/pipelines/test_speaker_verification.py index 0803403e..42eeb139 100644 --- a/tests/pipelines/test_speaker_verification.py +++ b/tests/pipelines/test_speaker_verification.py @@ -213,9 +213,11 @@ class SpeakerVerificationTest(unittest.TestCase): self.assertTrue(OutputKeys.SCORE in result) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_with_speaker_verification_eres2netv2ep4w24s4_zh_cn_common_16k(self): + def test_run_with_speaker_verification_eres2netv2ep4w24s4_zh_cn_common_16k( + self): logger.info( - 'Run speaker verification for eres2netv2ep4_zh_cn_common_16k model') + 'Run speaker verification for eres2netv2ep4_zh_cn_common_16k model' + ) result = self.run_pipeline( model_id=self.eres2netv2ep4_zh_cn_16k_common_model_id, audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],