diff --git a/modelscope/models/audio/sv/ERes2NetV2.py b/modelscope/models/audio/sv/ERes2NetV2.py
index ba47dcc8..d842a094 100644
--- a/modelscope/models/audio/sv/ERes2NetV2.py
+++ b/modelscope/models/audio/sv/ERes2NetV2.py
@@ -37,15 +37,22 @@ class ReLU(nn.Hardtanh):
 
 
 class BasicBlockERes2NetV2(nn.Module):
-    expansion = 2
 
-    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2):
+    def __init__(self,
+                 in_planes,
+                 planes,
+                 stride=1,
+                 baseWidth=26,
+                 scale=2,
+                 expansion=2):
         super(BasicBlockERes2NetV2, self).__init__()
         width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.width = width
         self.conv1 = nn.Conv2d(
             in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
         self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale
+        self.expansion = expansion
 
         convs = []
         bns = []
@@ -69,9 +76,6 @@ class BasicBlockERes2NetV2(nn.Module):
                     kernel_size=1,
                     stride=stride,
                     bias=False), nn.BatchNorm2d(self.expansion * planes))
-        self.stride = stride
-        self.width = width
-        self.scale = scale
 
     def forward(self, x):
         residual = x
@@ -102,16 +106,23 @@ class BasicBlockERes2NetV2(nn.Module):
         return out
 
 
-class BasicBlockERes2NetV2_AFF(nn.Module):
-    expansion = 2
+class BasicBlockERes2NetV2AFF(nn.Module):
 
-    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2):
-        super(BasicBlockERes2NetV2_AFF, self).__init__()
+    def __init__(self,
+                 in_planes,
+                 planes,
+                 stride=1,
+                 baseWidth=26,
+                 scale=2,
+                 expansion=2):
+        super(BasicBlockERes2NetV2AFF, self).__init__()
         width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.width = width
         self.conv1 = nn.Conv2d(
             in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
         self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale
+        self.expansion = expansion
 
         convs = []
         fuse_models = []
@@ -140,9 +151,6 @@ class BasicBlockERes2NetV2_AFF(nn.Module):
                     kernel_size=1,
                     stride=stride,
                     bias=False), nn.BatchNorm2d(self.expansion * planes))
-        self.stride = stride
-        self.width = width
-        self.scale = scale
 
     def forward(self, x):
         residual = x
@@ -178,11 +186,14 @@ class ERes2NetV2(nn.Module):
 
     def __init__(self,
                  block=BasicBlockERes2NetV2,
-                 block_fuse=BasicBlockERes2NetV2_AFF,
+                 block_fuse=BasicBlockERes2NetV2AFF,
                  num_blocks=[3, 4, 6, 3],
                  m_channels=64,
                  feat_dim=80,
                  embed_dim=192,
+                 baseWidth=26,
+                 scale=2,
+                 expansion=2,
                  pooling_func='TSTP',
                  two_emb_layer=False):
         super(ERes2NetV2, self).__init__()
@@ -191,6 +202,9 @@ class ERes2NetV2(nn.Module):
         self.embed_dim = embed_dim
         self.stats_dim = int(feat_dim / 8) * m_channels * 8
         self.two_emb_layer = two_emb_layer
+        self.baseWidth = baseWidth
+        self.scale = scale
+        self.expansion = expansion
 
         self.conv1 = nn.Conv2d(
             1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
@@ -206,20 +220,20 @@ class ERes2NetV2(nn.Module):
 
         # Downsampling module
         self.layer3_ds = nn.Conv2d(
-            m_channels * 8,
-            m_channels * 16,
+            m_channels * 4 * self.expansion,
+            m_channels * 8 * self.expansion,
             kernel_size=3,
             padding=1,
             stride=2,
             bias=False)
 
         # Bottom-up fusion module
-        self.fuse34 = AFF(channels=m_channels * 16, r=4)
+        self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
 
         self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
         self.pool = getattr(pooling_layers, pooling_func)(
-            in_dim=self.stats_dim * block.expansion)
-        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
+            in_dim=self.stats_dim * self.expansion)
+        self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
                                embed_dim)
         if self.two_emb_layer:
             self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False)
@@ -232,8 +246,15 @@ class ERes2NetV2(nn.Module):
         strides = [stride] + [1] * (num_blocks - 1)
         layers = []
         for stride in strides:
-            layers.append(block(self.in_planes, planes, stride))
-            self.in_planes = planes * block.expansion
+            layers.append(
+                block(
+                    self.in_planes,
+                    planes,
+                    stride,
+                    baseWidth=self.baseWidth,
+                    scale=self.scale,
+                    expansion=self.expansion))
+            self.in_planes = planes * self.expansion
         return nn.Sequential(*layers)
 
     def forward(self, x):
@@ -275,11 +296,18 @@ class SpeakerVerificationERes2NetV2(TorchModel):
         super().__init__(model_dir, model_config, *args, **kwargs)
         self.model_config = model_config
         self.embed_dim = self.model_config['embed_dim']
+        self.baseWidth = self.model_config['baseWidth']
+        self.scale = self.model_config['scale']
+        self.expansion = self.model_config['expansion']
         self.other_config = kwargs
         self.feature_dim = 80
         self.device = create_device(self.other_config['device'])
 
-        self.embedding_model = ERes2NetV2(embed_dim=self.embed_dim)
+        self.embedding_model = ERes2NetV2(
+            embed_dim=self.embed_dim,
+            baseWidth=self.baseWidth,
+            scale=self.scale,
+            expansion=self.expansion)
 
         pretrained_model_name = kwargs['pretrained_model']
         self.__load_check_point(pretrained_model_name)
diff --git a/tests/pipelines/test_speaker_verification.py b/tests/pipelines/test_speaker_verification.py
index 8beae8ec..42eeb139 100644
--- a/tests/pipelines/test_speaker_verification.py
+++ b/tests/pipelines/test_speaker_verification.py
@@ -34,6 +34,7 @@ class SpeakerVerificationTest(unittest.TestCase):
     lre_eres2net_large_en_cn_16k_model_id = 'damo/speech_eres2net_large_lre_en-cn_16k'
     eres2net_aug_zh_cn_16k_common_model_id = 'damo/speech_eres2net_sv_zh-cn_16k-common'
     eres2netv2_zh_cn_16k_common_model_id = 'iic/speech_eres2netv2_sv_zh-cn_16k-common'
+    eres2netv2ep4_zh_cn_16k_common_model_id = 'iic/speech_eres2netv2w24s4ep4_sv_zh-cn_16k-common'
     rdino_3dspeaker_16k_model_id = 'damo/speech_rdino_ecapa_tdnn_sv_zh-cn_3dspeaker_16k'
     eres2net_base_3dspeaker_16k_model_id = 'damo/speech_eres2net_base_sv_zh-cn_3dspeaker_16k'
     eres2net_large_3dspeaker_16k_model_id = 'damo/speech_eres2net_large_sv_zh-cn_3dspeaker_16k'
@@ -207,6 +208,19 @@ class SpeakerVerificationTest(unittest.TestCase):
         result = self.run_pipeline(
             model_id=self.eres2netv2_zh_cn_16k_common_model_id,
             audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],
+            model_revision='v1.0.2')
+        print(result)
+        self.assertTrue(OutputKeys.SCORE in result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_speaker_verification_eres2netv2ep4w24s4_zh_cn_common_16k(
+            self):
+        logger.info(
+            'Run speaker verification for eres2netv2ep4_zh_cn_common_16k model'
+        )
+        result = self.run_pipeline(
+            model_id=self.eres2netv2ep4_zh_cn_16k_common_model_id,
+            audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],
             model_revision='v1.0.1')
         print(result)
         self.assertTrue(OutputKeys.SCORE in result)
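
Usage sketch (not part of the patch): the new keyword arguments can be passed straight to the ERes2NetV2 constructor instead of relying on the old class-level expansion attribute. The baseWidth=24, scale=4, expansion=4 values below are an assumption read off the "w24s4ep4" model name; the released model's model_config supplies the actual values.

    from modelscope.models.audio.sv.ERes2NetV2 import ERes2NetV2

    # Build the backbone with variant-specific hyper-parameters (assumed values).
    model = ERes2NetV2(embed_dim=192, baseWidth=24, scale=4, expansion=4)
    model.eval()

    # Quick sanity check: the parameter count now tracks baseWidth/scale/expansion.
    n_params = sum(p.numel() for p in model.parameters())
    print(f'ERes2NetV2 (w24s4ep4-style) parameters: {n_params / 1e6:.2f}M')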