Add speech enhancement (enh) / speech separation (ss)

This commit is contained in:
simpleoier
2023-04-11 08:06:42 -04:00
parent e2b06d3c79
commit 181bceea24
2 changed files with 82 additions and 0 deletions

View File

@@ -821,6 +821,86 @@ class TargetSoundDetection:
#print(ans) #print(ans)
return ans return ans
class Speech_Enh_SS_SC:
    """Single-channel speech enhancement / separation using an ESPnet model.

    The pretrained model named by ``model_name`` is fetched through
    ``espnet_model_zoo`` and wrapped in ESPnet's ``SeparateSpeech`` helper
    when the object is constructed.

    Example usage:
        enh_model = Speech_Enh_SS_SC("cuda")
        enh_wav = enh_model.inference("./test_chime4_audio_M05_440C0213_PED_REAL.wav")
    """

    def __init__(self, device="cuda", model_name="lichenda/chime4_fasnet_dprnn_tac"):
        """Store the configuration and build the separation model.

        Args:
            device: Device string forwarded to ``SeparateSpeech``.
            model_name: Model identifier passed to
                ``ModelDownloader.download_and_unpack``.
        """
        self.model_name = model_name
        self.device = device
        print("Initializing ESPnet Enh to %s" % device)
        self._initialize_model()

    def _initialize_model(self):
        """Download the model files and create ``self.separate_speech``."""
        # Imported lazily so the ESPnet packages are only required when this
        # class is actually instantiated.
        from espnet_model_zoo.downloader import ModelDownloader
        from espnet2.bin.enh_inference import SeparateSpeech

        d = ModelDownloader()
        cfg = d.download_and_unpack(self.model_name)
        self.separate_speech = SeparateSpeech(
            train_config=cfg["train_config"],
            model_file=cfg["model_file"],
            # for segment-wise process on long speech
            segment_size=2.4,
            hop_size=0.8,
            normalize_segment_scale=False,
            show_progressbar=True,
            ref_channel=None,
            normalize_output_wav=True,
            device=self.device,
        )

    @staticmethod
    def _select_channel(speech, ref_channel):
        """Return a 1-D signal taken from ``speech``.

        Args:
            speech: Array as returned by ``soundfile.read``: 1-D for a mono
                file, 2-D ``(frames, channels)`` otherwise.
            ref_channel: Column to extract from a 2-D array. Ignored when
                ``speech`` is already 1-D.

        Returns:
            The 1-D array for the requested channel.

        Raises:
            ValueError: If ``speech`` is neither 1-D nor 2-D.
            IndexError: If ``ref_channel`` is out of range for a 2-D input.
        """
        if speech.ndim == 1:
            # Mono file: there is no channel axis to index into.
            return speech
        if speech.ndim != 2:
            raise ValueError(
                "expected 1-D or 2-D audio, got %d dimensions" % speech.ndim
            )
        return speech[:, ref_channel]

    def inference(self, speech_path, ref_channel=0):
        """Enhance / separate the audio file at ``speech_path``.

        Args:
            speech_path: Path of the audio file, read with ``soundfile.read``.
            ref_channel: Channel used when the file has several channels.

        Returns:
            The single element of the model output when it has exactly one
            entry, otherwise the whole output sequence (presumably one entry
            per separated source -- confirm against ESPnet).
        """
        speech, sr = soundfile.read(speech_path)
        # NumPy arrays expose ``ndim`` (not torch's ``dim()``); the helper
        # also copes with mono files, which have no channel axis.
        speech = self._select_channel(speech, ref_channel)
        # Prepend a batch axis of size one.
        enh_speech = self.separate_speech(speech[None, :], fs=sr)
        if len(enh_speech) == 1:
            return enh_speech[0]
        return enh_speech
class Speech_Enh_SS_MC:
    """Multi-channel speech enhancement / separation using an ESPnet model.

    The pretrained model named by ``model_name`` is fetched through
    ``espnet_model_zoo`` and wrapped in ESPnet's ``SeparateSpeech`` helper
    when the object is constructed.
    """

    def __init__(self, device="cuda", model_name=None, ref_channel=4):
        """Store the configuration and build the separation model.

        Args:
            device: Device string forwarded to ``SeparateSpeech``.
            model_name: Model identifier passed to
                ``ModelDownloader.download_and_unpack``.
            ref_channel: Reference channel index forwarded to
                ``SeparateSpeech``.
        """
        # NOTE(review): the default ``model_name=None`` is handed straight to
        # the downloader; callers presumably must supply a real model name --
        # confirm what ``download_and_unpack(None)`` does.
        self.model_name = model_name
        self.ref_channel = ref_channel
        self.device = device
        print("Initializing ESPnet Enh to %s" % device)
        self._initialize_model()

    def _initialize_model(self):
        """Download the model files and create ``self.separate_speech``."""
        # Imported lazily so the ESPnet packages are only required when this
        # class is actually instantiated.
        from espnet_model_zoo.downloader import ModelDownloader
        from espnet2.bin.enh_inference import SeparateSpeech

        d = ModelDownloader()
        cfg = d.download_and_unpack(self.model_name)
        self.separate_speech = SeparateSpeech(
            train_config=cfg["train_config"],
            model_file=cfg["model_file"],
            # for segment-wise process on long speech
            segment_size=2.4,
            hop_size=0.8,
            normalize_segment_scale=False,
            show_progressbar=True,
            ref_channel=self.ref_channel,
            normalize_output_wav=True,
            device=self.device,
        )

    @staticmethod
    def _to_batch(speech):
        """Prepend a batch axis of size one, leaving the other axes as they are."""
        return speech[None, ...]

    def inference(self, speech_path):
        """Enhance / separate the multi-channel audio file at ``speech_path``.

        Args:
            speech_path: Path of the audio file, read with ``soundfile.read``.

        Returns:
            The single element of the model output when it has exactly one
            entry, otherwise the whole output sequence (presumably one entry
            per separated source -- confirm against ESPnet).
        """
        speech, sr = soundfile.read(speech_path)
        # ``soundfile.read`` yields (frames, channels) and ESPnet's
        # ``SeparateSpeech`` documents its input as
        # (Batch, Nsamples [, Channels]), so the array must NOT be transposed;
        # transposing would swap the time and channel axes.
        enh_speech = self.separate_speech(self._to_batch(speech), fs=sr)
        if len(enh_speech) == 1:
            return enh_speech[0]
        return enh_speech
class ConversationBot: class ConversationBot:
def __init__(self): def __init__(self):
print("Initializing AudioGPT") print("Initializing AudioGPT")

View File

@@ -8,6 +8,8 @@ beautifulsoup4==4.10.0
Cython==0.29.24 Cython==0.29.24
diffusers diffusers
einops==0.3.0 einops==0.3.0
espnet
espnet_model_zoo
g2p-en==2.1.0 g2p-en==2.1.0
google==3.0.0 google==3.0.0
gradio gradio