mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
add eres2netv2ep4
This commit is contained in:
@@ -38,7 +38,13 @@ class ReLU(nn.Hardtanh):
|
||||
|
||||
class BasicBlockERes2NetV2(nn.Module):
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
|
||||
def __init__(self,
|
||||
in_planes,
|
||||
planes,
|
||||
stride=1,
|
||||
baseWidth=26,
|
||||
scale=2,
|
||||
expansion=2):
|
||||
super(BasicBlockERes2NetV2, self).__init__()
|
||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||
self.width = width
|
||||
@@ -70,7 +76,6 @@ class BasicBlockERes2NetV2(nn.Module):
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False), nn.BatchNorm2d(self.expansion * planes))
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
@@ -103,7 +108,13 @@ class BasicBlockERes2NetV2(nn.Module):
|
||||
|
||||
class BasicBlockERes2NetV2AFF(nn.Module):
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
|
||||
def __init__(self,
|
||||
in_planes,
|
||||
planes,
|
||||
stride=1,
|
||||
baseWidth=26,
|
||||
scale=2,
|
||||
expansion=2):
|
||||
super(BasicBlockERes2NetV2AFF, self).__init__()
|
||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||
self.width = width
|
||||
@@ -235,7 +246,14 @@ class ERes2NetV2(nn.Module):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion = self.expansion))
|
||||
layers.append(
|
||||
block(
|
||||
self.in_planes,
|
||||
planes,
|
||||
stride,
|
||||
baseWidth=self.baseWidth,
|
||||
scale=self.scale,
|
||||
expansion=self.expansion))
|
||||
self.in_planes = planes * self.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
@@ -285,7 +303,11 @@ class SpeakerVerificationERes2NetV2(TorchModel):
|
||||
self.feature_dim = 80
|
||||
self.device = create_device(self.other_config['device'])
|
||||
|
||||
self.embedding_model = ERes2NetV2(embed_dim=self.embed_dim, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion)
|
||||
self.embedding_model = ERes2NetV2(
|
||||
embed_dim=self.embed_dim,
|
||||
baseWidth=self.baseWidth,
|
||||
scale=self.scale,
|
||||
expansion=self.expansion)
|
||||
|
||||
pretrained_model_name = kwargs['pretrained_model']
|
||||
self.__load_check_point(pretrained_model_name)
|
||||
|
||||
@@ -213,9 +213,11 @@ class SpeakerVerificationTest(unittest.TestCase):
|
||||
self.assertTrue(OutputKeys.SCORE in result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_speaker_verification_eres2netv2ep4w24s4_zh_cn_common_16k(self):
|
||||
def test_run_with_speaker_verification_eres2netv2ep4w24s4_zh_cn_common_16k(
|
||||
self):
|
||||
logger.info(
|
||||
'Run speaker verification for eres2netv2ep4_zh_cn_common_16k model')
|
||||
'Run speaker verification for eres2netv2ep4_zh_cn_common_16k model'
|
||||
)
|
||||
result = self.run_pipeline(
|
||||
model_id=self.eres2netv2ep4_zh_cn_16k_common_model_id,
|
||||
audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],
|
||||
|
||||
Reference in New Issue
Block a user