mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 20:49:37 +01:00
245 lines
8.5 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.

import io
import os
from typing import Any, Dict, Tuple, Union

import numpy as np
import scipy.io.wavfile as wav
import torch

from modelscope.fileio import File
from modelscope.preprocessors import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields, ModeKeys


class AudioBrainPreprocessor(Preprocessor):
    """A preprocessor that takes an audio file path and reads it into a tensor.

    Args:
        takes: the audio file field name
        provides: the tensor field name
        mode: process mode, default 'inference'
    """

    def __init__(self,
                 takes: str,
                 provides: str,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        super(AudioBrainPreprocessor, self).__init__(mode, *args, **kwargs)
        self.takes = takes
        self.provides = provides
        import speechbrain as sb
        self.read_audio = sb.dataio.dataio.read_audio

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        result = self.read_audio(data[self.takes])
        data[self.provides] = result
        return data
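
# A minimal usage sketch of AudioBrainPreprocessor (the field names
# 'wav_path' and 'signal' are illustrative assumptions, not part of this
# file):
#
#   preprocessor = AudioBrainPreprocessor(takes='wav_path', provides='signal')
#   sample = preprocessor({'wav_path': 'speech.wav'})
#   # sample['signal'] now holds the waveform tensor read by speechbrain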


def load_kaldi_feature_transform(filename):
    fp = open(filename, 'r', encoding='utf-8')
    all_str = fp.read()
    pos1 = all_str.find('AddShift')
    pos2 = all_str.find('[', pos1)
    pos3 = all_str.find(']', pos2)
    mean = np.fromstring(all_str[pos2 + 1:pos3], dtype=np.float32, sep=' ')
    pos1 = all_str.find('Rescale')
    pos2 = all_str.find('[', pos1)
    pos3 = all_str.find(']', pos2)
    scale = np.fromstring(all_str[pos2 + 1:pos3], dtype=np.float32, sep=' ')
    fp.close()
    return mean, scale
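
# load_kaldi_feature_transform only relies on the text layout of the
# transform file: the first "[ ... ]" block after the word "AddShift" is
# read as the shift vector and the first "[ ... ]" block after "Rescale"
# as the scale vector. An illustrative (not real) file body:
#
#   <AddShift> 3 3
#   [ -1.5 -2.0 -1.8 ]
#   <Rescale> 3 3
#   [ 0.5 0.4 0.6 ]
#
# would yield mean = [-1.5, -2.0, -1.8] and scale = [0.5, 0.4, 0.6].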


class Feature:
    r"""Extract features from one utterance."""

    def __init__(self,
                 fbank_config,
                 feat_type='spec',
                 mvn_file=None,
                 cuda=False):
        r"""Initialize the feature extractor.

        Args:
            fbank_config (dict): feature extraction config; the keys
                'frame_length' and 'frame_shift' (in ms) and
                'sample_frequency' (in Hz) are required.
            feat_type (str):
                raw: do nothing
                fbank: use kaldi.fbank
                spec: real and imaginary parts of the STFT
                logpow: log(1 + |x|^2)
            mvn_file (str): the path of the data file for mean-variance
                normalization
            cuda (bool): whether to move tensors to the GPU
        """
        self.fbank_config = fbank_config
        self.feat_type = feat_type
        self.n_fft = fbank_config['frame_length'] * fbank_config[
            'sample_frequency'] // 1000
        self.hop_length = fbank_config['frame_shift'] * fbank_config[
            'sample_frequency'] // 1000
        self.window = torch.hamming_window(self.n_fft, periodic=False)
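        # Illustrative numbers (not from this file): with frame_length=32 ms,
        # frame_shift=16 ms and sample_frequency=16000 Hz this gives
        # n_fft = 32 * 16000 // 1000 = 512 samples and
        # hop_length = 16 * 16000 // 1000 = 256 samples.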

        self.mvn = False
        if mvn_file is not None and os.path.exists(mvn_file):
            print(f'loading mvn file: {mvn_file}')
            shift, scale = load_kaldi_feature_transform(mvn_file)
            self.shift = torch.from_numpy(shift)
            self.scale = torch.from_numpy(scale)
            self.mvn = True
        if cuda:
            self.window = self.window.cuda()
            if self.mvn:
                self.shift = self.shift.cuda()
                self.scale = self.scale.cuda()

    def compute(self, utt):
        r"""Compute the feature of one utterance.

        Args:
            utt: waveform tensor in the [-32768, 32767] range

        Returns:
            feature tensor of shape [..., T, F]
        """
        if self.feat_type == 'raw':
            return utt
        elif self.feat_type == 'fbank':
            # have to use a local import until the modelscope framework
            # supports lazy loading
            import torchaudio.compliance.kaldi as kaldi
            if len(utt.shape) == 1:
                utt = utt.unsqueeze(0)
            feat = kaldi.fbank(utt, **self.fbank_config)
        elif self.feat_type == 'spec':
            # scale int16-range samples to [-1, 1) before the STFT
            spec = torch.stft(
                utt / 32768,
                self.n_fft,
                self.hop_length,
                self.n_fft,
                self.window,
                center=False,
                return_complex=True)
            feat = torch.cat([spec.real, spec.imag], dim=-2).permute(-1, -2)
        elif self.feat_type == 'logpow':
            spec = torch.stft(
                utt,
                self.n_fft,
                self.hop_length,
                self.n_fft,
                self.window,
                center=False,
                return_complex=True)
            abspow = torch.abs(spec)**2
            feat = torch.log(1 + abspow).permute(-1, -2)
        return feat

    def normalize(self, feat):
        if self.mvn:
            feat = feat + self.shift
            feat = feat * self.scale
        return feat
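
# A minimal usage sketch of Feature (the config values and the `samples`
# array are illustrative assumptions, not taken from this file):
#
#   feature = Feature(
#       fbank_config={
#           'frame_length': 32,        # ms
#           'frame_shift': 16,         # ms
#           'sample_frequency': 16000  # Hz
#       },
#       feat_type='spec')
#   samples = np.zeros(16000, dtype=np.float32)  # 1 s of silence, int16 range
#   feat = feature.normalize(feature.compute(torch.from_numpy(samples)))
#   # With center=False, 1 s at 16 kHz gives (16000 - 512) // 256 + 1 = 61
#   # frames of 2 * (512 // 2 + 1) = 514 real+imag values, so feat has
#   # shape (61, 514).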


@PREPROCESSORS.register_module(Fields.audio)
class LinearAECAndFbank(Preprocessor):
    SAMPLE_RATE = 16000

    def __init__(self, io_config):
        import MinDAEC
        self.trunc_length = 7200 * self.SAMPLE_RATE
        self.linear_aec_delay = io_config['linear_aec_delay']
        self.feature = Feature(io_config['fbank_config'],
                               io_config['feat_type'], io_config['mvn'])
        self.mitaec = MinDAEC.load()
        self.mask_on_mic = io_config['mask_on'] == 'nearend_mic'

    def __call__(self, data: Union[Tuple, Dict[str, Any]]) -> Dict[str, Any]:
        """Linear-filter the near-end mic and far-end audio, then extract the feature.

        Args:
            data: Dict with the keys "nearend_mic" and "farend_speech" (and
                optionally "nearend_speech") and the corresponding audios,
                or a tuple of (nearend_mic, farend_speech).

        Returns:
            Dict with Tensor values: "base" (the linear-filtered audio),
            "target" (the near-end speech) and "feature" (the extracted
            feature).
        """
        if isinstance(data, tuple):
            nearend_mic, fs = self.load_wav(data[0])
            farend_speech, fs = self.load_wav(data[1])
            nearend_speech = np.zeros_like(nearend_mic)
        else:
            # read files
            nearend_mic, fs = self.load_wav(data['nearend_mic'])
            farend_speech, fs = self.load_wav(data['farend_speech'])
            if 'nearend_speech' in data:
                nearend_speech, fs = self.load_wav(data['nearend_speech'])
            else:
                nearend_speech = np.zeros_like(nearend_mic)

        out_mic, out_ref, out_linear, out_echo = self.mitaec.do_linear_aec(
            nearend_mic, farend_speech)
        # fix 20ms linear aec delay by delaying the target speech
        extra_zeros = np.zeros([int(self.linear_aec_delay * fs)])
        nearend_speech = np.concatenate([extra_zeros, nearend_speech])
        # truncate files to the same length
        flen = min(
            len(out_mic), len(out_ref), len(out_linear), len(out_echo),
            len(nearend_speech))
        fstart = 0
        flen = min(flen, self.trunc_length)
        nearend_mic, out_ref, out_linear, out_echo, nearend_speech = (
            out_mic[fstart:flen], out_ref[fstart:flen],
            out_linear[fstart:flen], out_echo[fstart:flen],
            nearend_speech[fstart:flen])

        # extract features (frames, [mic, linear, ref, aes?])
        feat = torch.FloatTensor()

        nearend_mic = torch.from_numpy(np.float32(nearend_mic))
        fbank_nearend_mic = self.feature.compute(nearend_mic)
        feat = torch.cat([feat, fbank_nearend_mic], dim=1)

        out_linear = torch.from_numpy(np.float32(out_linear))
        fbank_out_linear = self.feature.compute(out_linear)
        feat = torch.cat([feat, fbank_out_linear], dim=1)

        out_echo = torch.from_numpy(np.float32(out_echo))
        fbank_out_echo = self.feature.compute(out_echo)
        feat = torch.cat([feat, fbank_out_echo], dim=1)

        # feature transform
        feat = self.feature.normalize(feat)

        # prepare target
        if nearend_speech is not None:
            nearend_speech = torch.from_numpy(np.float32(nearend_speech))

        if self.mask_on_mic:
            base = nearend_mic
        else:
            base = out_linear
        out_data = {'base': base, 'target': nearend_speech, 'feature': feat}
        return out_data

    @staticmethod
    def load_wav(inputs):
        import librosa
        if isinstance(inputs, bytes):
            inputs = io.BytesIO(inputs)
        elif isinstance(inputs, str):
            file_bytes = File.read(inputs)
            inputs = io.BytesIO(file_bytes)
        else:
            raise TypeError(f'Unsupported input type: {type(inputs)}.')
        sample_rate, data = wav.read(inputs)
        if len(data.shape) > 1:
            raise ValueError('modelscope error: The audio must be mono.')
        if sample_rate != LinearAECAndFbank.SAMPLE_RATE:
            data = librosa.resample(data, sample_rate,
                                    LinearAECAndFbank.SAMPLE_RATE)
        return data.astype(np.float32), LinearAECAndFbank.SAMPLE_RATE
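
# A minimal usage sketch of LinearAECAndFbank (all config values and file
# names are illustrative assumptions, not taken from this file):
#
#   io_config = {
#       'linear_aec_delay': 0.02,  # seconds, the "20ms" delay noted above
#       'fbank_config': {
#           'frame_length': 32,        # ms
#           'frame_shift': 16,         # ms
#           'sample_frequency': 16000  # Hz
#       },
#       'feat_type': 'spec',
#       'mvn': 'final.feature_transform',  # None or a missing path skips mvn
#       'mask_on': 'nearend_mic',
#   }
#   preprocessor = LinearAECAndFbank(io_config)
#   out = preprocessor({'nearend_mic': 'mic.wav', 'farend_speech': 'ref.wav'})
#   # out['base'], out['target'] and out['feature'] are torch Tensors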