add eres2netv2ep4

This commit is contained in:
chenyafeng.cyf
2024-06-27 15:02:17 +08:00
parent da7d90bd60
commit 033d3d7f5a
2 changed files with 31 additions and 7 deletions

View File

@@ -38,7 +38,13 @@ class ReLU(nn.Hardtanh):
class BasicBlockERes2NetV2(nn.Module):
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
def __init__(self,
in_planes,
planes,
stride=1,
baseWidth=26,
scale=2,
expansion=2):
super(BasicBlockERes2NetV2, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.width = width
@@ -70,7 +76,6 @@ class BasicBlockERes2NetV2(nn.Module):
kernel_size=1,
stride=stride,
bias=False), nn.BatchNorm2d(self.expansion * planes))
def forward(self, x):
residual = x
@@ -103,7 +108,13 @@ class BasicBlockERes2NetV2(nn.Module):
class BasicBlockERes2NetV2AFF(nn.Module):
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
def __init__(self,
in_planes,
planes,
stride=1,
baseWidth=26,
scale=2,
expansion=2):
super(BasicBlockERes2NetV2AFF, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.width = width
@@ -235,7 +246,14 @@ class ERes2NetV2(nn.Module):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion = self.expansion))
layers.append(
block(
self.in_planes,
planes,
stride,
baseWidth=self.baseWidth,
scale=self.scale,
expansion=self.expansion))
self.in_planes = planes * self.expansion
return nn.Sequential(*layers)
@@ -285,7 +303,11 @@ class SpeakerVerificationERes2NetV2(TorchModel):
self.feature_dim = 80
self.device = create_device(self.other_config['device'])
self.embedding_model = ERes2NetV2(embed_dim=self.embed_dim, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion)
self.embedding_model = ERes2NetV2(
embed_dim=self.embed_dim,
baseWidth=self.baseWidth,
scale=self.scale,
expansion=self.expansion)
pretrained_model_name = kwargs['pretrained_model']
self.__load_check_point(pretrained_model_name)

View File

@@ -213,9 +213,11 @@ class SpeakerVerificationTest(unittest.TestCase):
self.assertTrue(OutputKeys.SCORE in result)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_speaker_verification_eres2netv2ep4w24s4_zh_cn_common_16k(self):
def test_run_with_speaker_verification_eres2netv2ep4w24s4_zh_cn_common_16k(
self):
logger.info(
'Run speaker verification for eres2netv2ep4_zh_cn_common_16k model')
'Run speaker verification for eres2netv2ep4_zh_cn_common_16k model'
)
result = self.run_pipeline(
model_id=self.eres2netv2ep4_zh_cn_16k_common_model_id,
audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],