Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
@@ -37,15 +37,22 @@ class ReLU(nn.Hardtanh):
 
 class BasicBlockERes2NetV2(nn.Module):
-    expansion = 2
 
-    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2):
+    def __init__(self,
+                 in_planes,
+                 planes,
+                 stride=1,
+                 baseWidth=26,
+                 scale=2,
+                 expansion=2):
         super(BasicBlockERes2NetV2, self).__init__()
         width = int(math.floor(planes * (baseWidth / 64.0)))
         self.width = width
         self.conv1 = nn.Conv2d(
             in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
         self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale
+        self.expansion = expansion
 
         convs = []
         bns = []
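In this hunk, expansion stops being a hard-coded class attribute and becomes a constructor argument stored on self. For illustration only (not part of the diff; the input shape and channel counts are assumptions), instantiating the updated block might look like this:

    # Hypothetical usage sketch, assuming the class above is importable.
    import torch

    block = BasicBlockERes2NetV2(
        in_planes=64, planes=64, stride=1, baseWidth=26, scale=2, expansion=2)
    out = block(torch.randn(1, 64, 80, 200))  # (batch, channels, freq, time)
    # Output channels are expected to be planes * expansion = 128, matching
    # the nn.BatchNorm2d(self.expansion * planes) used in the shortcut below.
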
@@ -69,9 +76,6 @@ class BasicBlockERes2NetV2(nn.Module):
                     kernel_size=1,
                     stride=stride,
                     bias=False), nn.BatchNorm2d(self.expansion * planes))
-        self.stride = stride
-        self.width = width
-        self.scale = scale
 
     def forward(self, x):
         residual = x
@@ -102,16 +106,23 @@ class BasicBlockERes2NetV2(nn.Module):
         return out
 
 
-class BasicBlockERes2NetV2_AFF(nn.Module):
-    expansion = 2
+class BasicBlockERes2NetV2AFF(nn.Module):
 
-    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2):
-        super(BasicBlockERes2NetV2_AFF, self).__init__()
+    def __init__(self,
+                 in_planes,
+                 planes,
+                 stride=1,
+                 baseWidth=26,
+                 scale=2,
+                 expansion=2):
+        super(BasicBlockERes2NetV2AFF, self).__init__()
         width = int(math.floor(planes * (baseWidth / 64.0)))
         self.width = width
         self.conv1 = nn.Conv2d(
             in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
         self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale
+        self.expansion = expansion
 
         convs = []
         fuse_models = []
@@ -140,9 +151,6 @@ class BasicBlockERes2NetV2_AFF(nn.Module):
                     kernel_size=1,
                     stride=stride,
                     bias=False), nn.BatchNorm2d(self.expansion * planes))
-        self.stride = stride
-        self.width = width
-        self.scale = scale
 
     def forward(self, x):
         residual = x
@@ -178,11 +186,14 @@ class ERes2NetV2(nn.Module):
 
     def __init__(self,
                  block=BasicBlockERes2NetV2,
-                 block_fuse=BasicBlockERes2NetV2_AFF,
+                 block_fuse=BasicBlockERes2NetV2AFF,
                  num_blocks=[3, 4, 6, 3],
                  m_channels=64,
                  feat_dim=80,
                  embed_dim=192,
+                 baseWidth=26,
+                 scale=2,
+                 expansion=2,
                  pooling_func='TSTP',
                  two_emb_layer=False):
         super(ERes2NetV2, self).__init__()
@@ -191,6 +202,9 @@ class ERes2NetV2(nn.Module):
         self.embed_dim = embed_dim
         self.stats_dim = int(feat_dim / 8) * m_channels * 8
         self.two_emb_layer = two_emb_layer
+        self.baseWidth = baseWidth
+        self.scale = scale
+        self.expansion = expansion
 
         self.conv1 = nn.Conv2d(
             1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
@@ -206,20 +220,20 @@ class ERes2NetV2(nn.Module):
 
         # Downsampling module
         self.layer3_ds = nn.Conv2d(
-            m_channels * 8,
-            m_channels * 16,
+            m_channels * 4 * self.expansion,
+            m_channels * 8 * self.expansion,
             kernel_size=3,
             padding=1,
             stride=2,
             bias=False)
 
         # Bottom-up fusion module
-        self.fuse34 = AFF(channels=m_channels * 16, r=4)
+        self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
 
         self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
         self.pool = getattr(pooling_layers, pooling_func)(
-            in_dim=self.stats_dim * block.expansion)
-        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
+            in_dim=self.stats_dim * self.expansion)
+        self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
                                embed_dim)
         if self.two_emb_layer:
             self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False)
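This hunk replaces every use of block.expansion with the new self.expansion when sizing the downsampling conv, the AFF fusion module, the pooling layer and the first embedding linear. A quick arithmetic sketch under the defaults shown above (feat_dim=80, m_channels=64, expansion=2, TSTP pooling), purely for illustration:

    # Illustrative arithmetic only; mirrors the expressions in the diff.
    feat_dim, m_channels, expansion = 80, 64, 2
    stats_dim = int(feat_dim / 8) * m_channels * 8   # 10 * 64 * 8 = 5120
    pool_in_dim = stats_dim * expansion              # 10240
    n_stats = 2                                      # TSTP pools mean and std
    seg_1_in = pool_in_dim * n_stats                 # 20480 inputs to seg_1
    print(stats_dim, pool_in_dim, seg_1_in)
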
@@ -232,8 +246,15 @@ class ERes2NetV2(nn.Module):
         strides = [stride] + [1] * (num_blocks - 1)
         layers = []
         for stride in strides:
-            layers.append(block(self.in_planes, planes, stride))
-            self.in_planes = planes * block.expansion
+            layers.append(
+                block(
+                    self.in_planes,
+                    planes,
+                    stride,
+                    baseWidth=self.baseWidth,
+                    scale=self.scale,
+                    expansion=self.expansion))
+            self.in_planes = planes * self.expansion
         return nn.Sequential(*layers)
 
     def forward(self, x):
@@ -275,11 +296,18 @@ class SpeakerVerificationERes2NetV2(TorchModel):
         super().__init__(model_dir, model_config, *args, **kwargs)
         self.model_config = model_config
         self.embed_dim = self.model_config['embed_dim']
+        self.baseWidth = self.model_config['baseWidth']
+        self.scale = self.model_config['scale']
+        self.expansion = self.model_config['expansion']
         self.other_config = kwargs
         self.feature_dim = 80
         self.device = create_device(self.other_config['device'])
 
-        self.embedding_model = ERes2NetV2(embed_dim=self.embed_dim)
+        self.embedding_model = ERes2NetV2(
+            embed_dim=self.embed_dim,
+            baseWidth=self.baseWidth,
+            scale=self.scale,
+            expansion=self.expansion)
 
         pretrained_model_name = kwargs['pretrained_model']
         self.__load_check_point(pretrained_model_name)
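The remaining hunks belong to the speaker-verification test suite. Before them, note that the wrapper above now reads baseWidth, scale and expansion from model_config, so a configuration along these lines would be expected (a sketch; the values shown are the defaults from the diff, and any additional keys are assumptions):

    # Hypothetical model_config; only the keys referenced in the constructor
    # above are shown.
    model_config = {
        'embed_dim': 192,
        'baseWidth': 26,
        'scale': 2,
        'expansion': 2,
    }
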
@@ -34,6 +34,7 @@ class SpeakerVerificationTest(unittest.TestCase):
     lre_eres2net_large_en_cn_16k_model_id = 'damo/speech_eres2net_large_lre_en-cn_16k'
     eres2net_aug_zh_cn_16k_common_model_id = 'damo/speech_eres2net_sv_zh-cn_16k-common'
     eres2netv2_zh_cn_16k_common_model_id = 'iic/speech_eres2netv2_sv_zh-cn_16k-common'
+    eres2netv2ep4_zh_cn_16k_common_model_id = 'iic/speech_eres2netv2w24s4ep4_sv_zh-cn_16k-common'
     rdino_3dspeaker_16k_model_id = 'damo/speech_rdino_ecapa_tdnn_sv_zh-cn_3dspeaker_16k'
     eres2net_base_3dspeaker_16k_model_id = 'damo/speech_eres2net_base_sv_zh-cn_3dspeaker_16k'
     eres2net_large_3dspeaker_16k_model_id = 'damo/speech_eres2net_large_sv_zh-cn_3dspeaker_16k'
@@ -207,6 +208,19 @@ class SpeakerVerificationTest(unittest.TestCase):
         result = self.run_pipeline(
             model_id=self.eres2netv2_zh_cn_16k_common_model_id,
             audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],
             model_revision='v1.0.2')
         print(result)
         self.assertTrue(OutputKeys.SCORE in result)
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_speaker_verification_eres2netv2ep4w24s4_zh_cn_common_16k(
+            self):
+        logger.info(
+            'Run speaker verification for eres2netv2ep4_zh_cn_common_16k model'
+        )
+        result = self.run_pipeline(
+            model_id=self.eres2netv2ep4_zh_cn_16k_common_model_id,
+            audios=[SPEAKER1_A_EN_16K_WAV, SPEAKER1_B_EN_16K_WAV],
+            model_revision='v1.0.1')
+        print(result)
+        self.assertTrue(OutputKeys.SCORE in result)
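Outside the unittest harness, the newly added eres2netv2w24s4ep4 model id would presumably be exercised through the standard ModelScope pipeline API, roughly as follows (a sketch; the wav paths are placeholders and the revision is taken from the test above):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Placeholder inputs: any pair of 16 kHz mono wav files to compare.
    sv_pipeline = pipeline(
        task=Tasks.speaker_verification,
        model='iic/speech_eres2netv2w24s4ep4_sv_zh-cn_16k-common',
        model_revision='v1.0.1')
    result = sv_pipeline(['speaker1_a.wav', 'speaker1_b.wav'])
    print(result)  # expected to include a similarity score under OutputKeys.SCORE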