add eres2netv2ep4

This commit is contained in:
chenyafeng.cyf
2024-06-27 14:46:03 +08:00
parent 5d00b7c304
commit da7d90bd60
2 changed files with 39 additions and 21 deletions

View File

@@ -37,15 +37,16 @@ class ReLU(nn.Hardtanh):
class BasicBlockERes2NetV2(nn.Module):
expansion = 2
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2):
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
super(BasicBlockERes2NetV2, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.width = width
self.conv1 = nn.Conv2d(
in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
self.expansion = expansion
convs = []
bns = []
@@ -69,9 +70,7 @@ class BasicBlockERes2NetV2(nn.Module):
kernel_size=1,
stride=stride,
bias=False), nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
@@ -102,16 +101,17 @@ class BasicBlockERes2NetV2(nn.Module):
return out
class BasicBlockERes2NetV2_AFF(nn.Module):
expansion = 2
class BasicBlockERes2NetV2AFF(nn.Module):
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2):
super(BasicBlockERes2NetV2_AFF, self).__init__()
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
super(BasicBlockERes2NetV2AFF, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.width = width
self.conv1 = nn.Conv2d(
in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale
self.expansion = expansion
convs = []
fuse_models = []
@@ -140,9 +140,6 @@ class BasicBlockERes2NetV2_AFF(nn.Module):
kernel_size=1,
stride=stride,
bias=False), nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
@@ -178,11 +175,14 @@ class ERes2NetV2(nn.Module):
def __init__(self,
block=BasicBlockERes2NetV2,
block_fuse=BasicBlockERes2NetV2_AFF,
block_fuse=BasicBlockERes2NetV2AFF,
num_blocks=[3, 4, 6, 3],
m_channels=64,
feat_dim=80,
embed_dim=192,
baseWidth=26,
scale=2,
expansion=2,
pooling_func='TSTP',
two_emb_layer=False):
super(ERes2NetV2, self).__init__()
@@ -191,6 +191,9 @@ class ERes2NetV2(nn.Module):
self.embed_dim = embed_dim
self.stats_dim = int(feat_dim / 8) * m_channels * 8
self.two_emb_layer = two_emb_layer
self.baseWidth = baseWidth
self.scale = scale
self.expansion = expansion
self.conv1 = nn.Conv2d(
1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
@@ -206,20 +209,20 @@ class ERes2NetV2(nn.Module):
# Downsampling module
self.layer3_ds = nn.Conv2d(
m_channels * 8,
m_channels * 16,
m_channels * 4 * self.expansion,
m_channels * 8 * self.expansion,
kernel_size=3,
padding=1,
stride=2,
bias=False)
# Bottom-up fusion module
self.fuse34 = AFF(channels=m_channels * 16, r=4)
self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
in_dim=self.stats_dim * self.expansion)
self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
embed_dim)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False)
@@ -232,8 +235,8 @@ class ERes2NetV2(nn.Module):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion = self.expansion))
self.in_planes = planes * self.expansion
return nn.Sequential(*layers)
def forward(self, x):
@@ -275,11 +278,14 @@ class SpeakerVerificationERes2NetV2(TorchModel):
super().__init__(model_dir, model_config, *args, **kwargs)
self.model_config = model_config
self.embed_dim = self.model_config['embed_dim']
self.baseWidth = self.model_config['baseWidth']
self.scale = self.model_config['scale']
self.expansion = self.model_config['expansion']
self.other_config = kwargs
self.feature_dim = 80
self.device = create_device(self.other_config['device'])
self.embedding_model = ERes2NetV2(embed_dim=self.embed_dim)
self.embedding_model = ERes2NetV2(embed_dim=self.embed_dim, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion)
pretrained_model_name = kwargs['pretrained_model']
self.__load_check_point(pretrained_model_name)

View File

@@ -34,6 +34,7 @@ class SpeakerVerificationTest(unittest.TestCase):
lre_eres2net_large_en_cn_16k_model_id = 'damo/speech_eres2net_large_lre_en-cn_16k'
eres2net_aug_zh_cn_16k_common_model_id = 'damo/speech_eres2net_sv_zh-cn_16k-common'
eres2netv2_zh_cn_16k_common_model_id = 'iic/speech_eres2netv2_sv_zh-cn_16k-common'
eres2netv2ep4_zh_cn_16k_common_model_id = 'iic/speech_eres2netv2w24s4ep4_sv_zh-cn_16k-common'
rdino_3dspeaker_16k_model_id = 'damo/speech_rdino_ecapa_tdnn_sv_zh-cn_3dspeaker_16k'
eres2net_base_3dspeaker_16k_model_id = 'damo/speech_eres2net_base_sv_zh-cn_3dspeaker_16k'
eres2net_large_3dspeaker_16k_model_id = 'damo/speech_eres2net_large_sv_zh-cn_3dspeaker_16k'
@@ -207,6 +208,17 @@ class SpeakerVerificationTest(unittest.TestCase):
result = self.run_pipeline(
model_id=self.eres2netv2_zh_cn_16k_common_model_id,
audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],
model_revision='v1.0.2')
print(result)
self.assertTrue(OutputKeys.SCORE in result)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_speaker_verification_eres2netv2ep4w24s4_zh_cn_common_16k(self):
logger.info(
'Run speaker verification for eres2netv2ep4_zh_cn_common_16k model')
result = self.run_pipeline(
model_id=self.eres2netv2ep4_zh_cn_16k_common_model_id,
audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],
model_revision='v1.0.1')
print(result)
self.assertTrue(OutputKeys.SCORE in result)