Add Eres2netv2ep4 (#892)

* add eres2netv2ep4
This commit is contained in:
yfchenmodelscope
2024-06-27 21:09:51 +08:00
committed by GitHub
parent 5d00b7c304
commit f65f45959d
2 changed files with 63 additions and 21 deletions

View File

@@ -37,15 +37,22 @@ class ReLU(nn.Hardtanh):
class BasicBlockERes2NetV2(nn.Module):
expansion = 2
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2):
def __init__(self,
in_planes,
planes,
stride=1,
baseWidth=26,
scale=2,
expansion=2):
super(BasicBlockERes2NetV2, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.width = width
self.conv1 = nn.Conv2d(
in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
self.expansion = expansion
convs = []
bns = []
@@ -69,9 +76,6 @@ class BasicBlockERes2NetV2(nn.Module):
kernel_size=1,
stride=stride,
bias=False), nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
@@ -102,16 +106,23 @@ class BasicBlockERes2NetV2(nn.Module):
return out
class BasicBlockERes2NetV2_AFF(nn.Module):
expansion = 2
class BasicBlockERes2NetV2AFF(nn.Module):
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2):
super(BasicBlockERes2NetV2_AFF, self).__init__()
def __init__(self,
in_planes,
planes,
stride=1,
baseWidth=26,
scale=2,
expansion=2):
super(BasicBlockERes2NetV2AFF, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.width = width
self.conv1 = nn.Conv2d(
in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
self.expansion = expansion
convs = []
fuse_models = []
@@ -140,9 +151,6 @@ class BasicBlockERes2NetV2_AFF(nn.Module):
kernel_size=1,
stride=stride,
bias=False), nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
@@ -178,11 +186,14 @@ class ERes2NetV2(nn.Module):
def __init__(self,
block=BasicBlockERes2NetV2,
block_fuse=BasicBlockERes2NetV2_AFF,
block_fuse=BasicBlockERes2NetV2AFF,
num_blocks=[3, 4, 6, 3],
m_channels=64,
feat_dim=80,
embed_dim=192,
baseWidth=26,
scale=2,
expansion=2,
pooling_func='TSTP',
two_emb_layer=False):
super(ERes2NetV2, self).__init__()
@@ -191,6 +202,9 @@ class ERes2NetV2(nn.Module):
self.embed_dim = embed_dim
self.stats_dim = int(feat_dim / 8) * m_channels * 8
self.two_emb_layer = two_emb_layer
self.baseWidth = baseWidth
self.scale = scale
self.expansion = expansion
self.conv1 = nn.Conv2d(
1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
@@ -206,20 +220,20 @@ class ERes2NetV2(nn.Module):
# Downsampling module
self.layer3_ds = nn.Conv2d(
m_channels * 8,
m_channels * 16,
m_channels * 4 * self.expansion,
m_channels * 8 * self.expansion,
kernel_size=3,
padding=1,
stride=2,
bias=False)
# Bottom-up fusion module
self.fuse34 = AFF(channels=m_channels * 16, r=4)
self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
in_dim=self.stats_dim * self.expansion)
self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
embed_dim)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False)
@@ -232,8 +246,15 @@ class ERes2NetV2(nn.Module):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
layers.append(
block(
self.in_planes,
planes,
stride,
baseWidth=self.baseWidth,
scale=self.scale,
expansion=self.expansion))
self.in_planes = planes * self.expansion
return nn.Sequential(*layers)
def forward(self, x):
@@ -275,11 +296,18 @@ class SpeakerVerificationERes2NetV2(TorchModel):
super().__init__(model_dir, model_config, *args, **kwargs)
self.model_config = model_config
self.embed_dim = self.model_config['embed_dim']
self.baseWidth = self.model_config['baseWidth']
self.scale = self.model_config['scale']
self.expansion = self.model_config['expansion']
self.other_config = kwargs
self.feature_dim = 80
self.device = create_device(self.other_config['device'])
self.embedding_model = ERes2NetV2(embed_dim=self.embed_dim)
self.embedding_model = ERes2NetV2(
embed_dim=self.embed_dim,
baseWidth=self.baseWidth,
scale=self.scale,
expansion=self.expansion)
pretrained_model_name = kwargs['pretrained_model']
self.__load_check_point(pretrained_model_name)

View File

@@ -34,6 +34,7 @@ class SpeakerVerificationTest(unittest.TestCase):
lre_eres2net_large_en_cn_16k_model_id = 'damo/speech_eres2net_large_lre_en-cn_16k'
eres2net_aug_zh_cn_16k_common_model_id = 'damo/speech_eres2net_sv_zh-cn_16k-common'
eres2netv2_zh_cn_16k_common_model_id = 'iic/speech_eres2netv2_sv_zh-cn_16k-common'
eres2netv2ep4_zh_cn_16k_common_model_id = 'iic/speech_eres2netv2w24s4ep4_sv_zh-cn_16k-common'
rdino_3dspeaker_16k_model_id = 'damo/speech_rdino_ecapa_tdnn_sv_zh-cn_3dspeaker_16k'
eres2net_base_3dspeaker_16k_model_id = 'damo/speech_eres2net_base_sv_zh-cn_3dspeaker_16k'
eres2net_large_3dspeaker_16k_model_id = 'damo/speech_eres2net_large_sv_zh-cn_3dspeaker_16k'
@@ -207,6 +208,19 @@ class SpeakerVerificationTest(unittest.TestCase):
result = self.run_pipeline(
model_id=self.eres2netv2_zh_cn_16k_common_model_id,
audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],
model_revision='v1.0.2')
print(result)
self.assertTrue(OutputKeys.SCORE in result)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_speaker_verification_eres2netv2ep4w24s4_zh_cn_common_16k(
self):
logger.info(
'Run speaker verification for eres2netv2ep4_zh_cn_common_16k model'
)
result = self.run_pipeline(
model_id=self.eres2netv2ep4_zh_cn_16k_common_model_id,
audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],
model_revision='v1.0.1')
print(result)
self.assertTrue(OutputKeys.SCORE in result)