diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 84c0f07f..6d20383d 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -1,24 +1,30 @@ -name: Build Docker Images +name: Build Docker Image on: workflow_dispatch: inputs: + workflow_name: + description: 'The specific name of this build' + required: true + default: 'build' modelscope_branch: - description: 'ModelScope branch to build from' + description: 'ModelScope branch to build from(release/x.xx)' required: true image_type: - description: 'The image type to build' + description: 'The image type to build(cpu/gpu/llm)' required: true modelscope_version: - description: 'ModelScope version to use' + description: 'ModelScope version to use(x.xx.x)' required: true swift_branch: - description: 'SWIFT branch to use' + description: 'SWIFT branch to use(release/x.xx)' required: true other_params: description: 'Other params in --xxx xxx' required: false +run-name: Docker-${{ inputs.modelscope_branch }}-${{ inputs.image_type }}-${{ inputs.workflow_name }}-by-@${{ github.actor }} + jobs: build: runs-on: [modelscope-self-hosted-us] diff --git a/docker/Dockerfile.ubuntu b/docker/Dockerfile.ubuntu index f4767d51..0d242e44 100644 --- a/docker/Dockerfile.ubuntu +++ b/docker/Dockerfile.ubuntu @@ -12,10 +12,6 @@ RUN apt-get update && \ {extra_content} -RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple && \ - pip config set install.trusted-host mirrors.aliyun.com && \ - cp /tmp/resources/ubuntu2204.aliyun /etc/apt/sources.list - COPY {meta_file} /tmp/install.sh RUN sh /tmp/install.sh {version_args} @@ -28,6 +24,10 @@ RUN cd /tmp && GIT_LFS_SKIP_SMUDGE=1 git clone -b {modelscope_branch} --single RUN cd /tmp && GIT_LFS_SKIP_SMUDGE=1 git clone -b {swift_branch} --single-branch https://github.com/modelscope/ms-swift.git && cd ms-swift && pip install .[all] && cd / && rm -fr /tmp/ms-swift && pip cache purge; +RUN pip config set 
global.index-url https://mirrors.aliyun.com/pypi/simple && \ + pip config set install.trusted-host mirrors.aliyun.com && \ + cp /tmp/resources/ubuntu2204.aliyun /etc/apt/sources.list + ENV SETUPTOOLS_USE_DISTUTILS=stdlib ENV VLLM_USE_MODELSCOPE=True ENV LMDEPLOY_USE_MODELSCOPE=True diff --git a/docker/build_image.py b/docker/build_image.py index be99af0e..76b50688 100644 --- a/docker/build_image.py +++ b/docker/build_image.py @@ -1,9 +1,12 @@ import argparse import os +from datetime import datetime from typing import Any docker_registry = os.environ['DOCKER_REGISTRY'] assert docker_registry, 'You must pass a valid DOCKER_REGISTRY' +timestamp = datetime.now() +formatted_time = timestamp.strftime('%Y%m%d%H%M%S') class Builder: @@ -85,12 +88,16 @@ class BaseCPUImageBuilder(Builder): return content def build(self): - image_tag = f'{docker_registry}:ubuntu{self.args.ubuntu_version}-torch{self.args.torch_version}-base' + image_tag = ( + f'{docker_registry}:ubuntu{self.args.ubuntu_version}-{self.args.python_tag}-' + f'torch{self.args.torch_version}-base') return os.system( f'DOCKER_BUILDKIT=0 docker build -t {image_tag} -f Dockerfile .') def push(self): - image_tag = f'{docker_registry}:ubuntu{self.args.ubuntu_version}-torch{self.args.torch_version}-base' + image_tag = ( + f'{docker_registry}:ubuntu{self.args.ubuntu_version}-{self.args.python_tag}-' + f'torch{self.args.torch_version}-base') return os.system(f'docker push {image_tag}') @@ -110,14 +117,14 @@ class BaseGPUImageBuilder(Builder): def build(self) -> int: image_tag = ( - f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-' + f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-{self.args.python_tag}-' f'torch{self.args.torch_version}-tf{self.args.tf_version}-base') return os.system( f'DOCKER_BUILDKIT=0 docker build -t {image_tag} -f Dockerfile .') def push(self): image_tag = ( - 
f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-' + f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-{self.args.python_tag}-' f'torch{self.args.torch_version}-tf{self.args.tf_version}-base') return os.system(f'docker push {image_tag}') @@ -129,7 +136,9 @@ class CPUImageBuilder(Builder): version_args = ( f'{self.args.torch_version} {self.args.torchvision_version} ' f'{self.args.torchaudio_version}') - base_image = f'{docker_registry}:ubuntu{self.args.ubuntu_version}-torch{self.args.torch_version}-base' + base_image = ( + f'{docker_registry}:ubuntu{self.args.ubuntu_version}-{self.args.python_tag}' + f'-torch{self.args.torch_version}-base') extra_content = """\nRUN pip install adaseq\nRUN pip install pai-easycv""" with open('docker/Dockerfile.ubuntu', 'r') as f: @@ -157,7 +166,17 @@ class CPUImageBuilder(Builder): f'{docker_registry}:ubuntu{self.args.ubuntu_version}-{self.args.python_tag}-' f'torch{self.args.torch_version}-{self.args.modelscope_version}-test' ) - return os.system(f'docker push {image_tag}') + ret = os.system(f'docker push {image_tag}') + if ret != 0: + return ret + image_tag2 = ( + f'{docker_registry}:ubuntu{self.args.ubuntu_version}-{self.args.python_tag}-' + f'torch{self.args.torch_version}-{self.args.modelscope_version}-{formatted_time}-test' + ) + ret = os.system(f'docker tag {image_tag} {image_tag2}') + if ret != 0: + return ret + return os.system(f'docker push {image_tag2}') class GPUImageBuilder(Builder): @@ -170,7 +189,7 @@ class GPUImageBuilder(Builder): f'{self.args.vllm_version} {self.args.lmdeploy_version} {self.args.autogptq_version}' ) base_image = ( - f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-' + f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-{self.args.python_tag}-' f'torch{self.args.torch_version}-tf{self.args.tf_version}-base') with open('docker/Dockerfile.ubuntu', 'r') as f: content = f.read() 
@@ -196,7 +215,17 @@ class GPUImageBuilder(Builder): f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-' f'{self.args.python_tag}-torch{self.args.torch_version}-tf{self.args.tf_version}-' f'{self.args.modelscope_version}-test') - return os.system(f'docker push {image_tag}') + ret = os.system(f'docker push {image_tag}') + if ret != 0: + return ret + image_tag2 = ( + f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-' + f'{self.args.python_tag}-torch{self.args.torch_version}-tf{self.args.tf_version}-' + f'{self.args.modelscope_version}-{formatted_time}-test') + ret = os.system(f'docker tag {image_tag} {image_tag2}') + if ret != 0: + return ret + return os.system(f'docker push {image_tag2}') class LLMImageBuilder(Builder): @@ -253,7 +282,17 @@ class LLMImageBuilder(Builder): f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-' f'{self.args.python_tag}-torch{self.args.torch_version}-{self.args.modelscope_version}-LLM-test' ) - return os.system(f'docker push {image_tag}') + ret = os.system(f'docker push {image_tag}') + if ret != 0: + return ret + image_tag2 = ( + f'{docker_registry}:ubuntu{self.args.ubuntu_version}-cuda{self.args.cuda_version}-' + f'{self.args.python_tag}-torch{self.args.torch_version}-' + f'{self.args.modelscope_version}-LLM-{formatted_time}-test') + ret = os.system(f'docker tag {image_tag} {image_tag2}') + if ret != 0: + return ret + return os.system(f'docker push {image_tag2}') parser = argparse.ArgumentParser() diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 2c2128d8..8166e004 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -193,6 +193,7 @@ class Models(object): # audio models sambert_hifigan = 'sambert-hifigan' speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' + speech_zipenhancer_ans_multiloss_16k_base = 'speech_zipenhancer_ans_multiloss_16k_base' speech_dfsmn_ans = 'speech_dfsmn_ans' 
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' speech_dfsmn_kws_char_farfield_iot = 'speech_dfsmn_kws_char_farfield_iot' @@ -551,6 +552,7 @@ class Pipelines(object): sambert_hifigan_tts = 'sambert-hifigan-tts' speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' + speech_zipenhancer_ans_multiloss_16k_base = 'speech_zipenhancer_ans_multiloss_16k_base' speech_dfsmn_ans_psm_48k_causal = 'speech_dfsmn_ans_psm_48k_causal' speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' speech_separation = 'speech-separation' diff --git a/modelscope/models/audio/ans/zipenhancer.py b/modelscope/models/audio/ans/zipenhancer.py new file mode 100644 index 00000000..544d9dc7 --- /dev/null +++ b/modelscope/models/audio/ans/zipenhancer.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import random +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.base import Tensor +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from .zipenhancer_layers.generator import (DenseEncoder, MappingDecoder, + PhaseDecoder) +from .zipenhancer_layers.scaling import ScheduledFloat +from .zipenhancer_layers.zipenhancer_layer import Zipformer2DualPathEncoder + + +@MODELS.register_module( + Tasks.acoustic_noise_suppression, + module_name=Models.speech_zipenhancer_ans_multiloss_16k_base) +class ZipenhancerDecorator(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + + h = dict( + num_tsconformers=kwargs['num_tsconformers'], + dense_channel=kwargs['dense_channel'], + former_conf=kwargs['former_conf'], + batch_first=kwargs['batch_first'], + model_num_spks=kwargs['model_num_spks'], 
+ ) + # num_tsconformers, dense_channel, former_name, former_conf, batch_first, model_num_spks + + h = AttrDict(h) + self.model = ZipEnhancer(h) + model_bin_file = os.path.join(model_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + if os.path.exists(model_bin_file): + checkpoint = torch.load( + model_bin_file, map_location=torch.device('cpu')) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + # the new trained model by user is based on ZipenhancerDecorator + self.load_state_dict(checkpoint['state_dict']) + else: + # The released model on Modelscope is based on Zipenhancer + # self.model.load_state_dict(checkpoint, strict=False) + self.model.load_state_dict(checkpoint['generator']) + # print(checkpoint['generator'].keys()) + + def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + n_fft = 400 + hop_size = 100 + win_size = 400 + noisy_wav = inputs['noisy'] + norm_factor = torch.sqrt(noisy_wav.shape[1] + / torch.sum(noisy_wav**2.0)) + noisy_audio = (noisy_wav * norm_factor) + + mag, pha, com = mag_pha_stft( + noisy_audio, + n_fft, + hop_size, + win_size, + compress_factor=0.3, + center=True) + amp_g, pha_g, com_g, _, others = self.model.forward(mag, pha) + wav = mag_pha_istft( + amp_g, + pha_g, + n_fft, + hop_size, + win_size, + compress_factor=0.3, + center=True) + + wav = wav / norm_factor + + output = { + 'wav_l2': wav, + } + + return output + + +class ZipEnhancer(nn.Module): + + def __init__(self, h): + """ + Initialize the ZipEnhancer module. + + Args: + h (object): Configuration object containing various hyperparameters and settings. + having num_tsconformers, former_name, former_conf, mask_decoder_type, ... 
+ """ + super(ZipEnhancer, self).__init__() + self.h = h + + num_tsconformers = h.num_tsconformers + self.num_tscblocks = num_tsconformers + self.dense_encoder = DenseEncoder(h, in_channel=2) + + self.TSConformer = Zipformer2DualPathEncoder( + output_downsampling_factor=1, + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + **h.former_conf) + + self.mask_decoder = MappingDecoder(h, out_channel=h.model_num_spks) + self.phase_decoder = PhaseDecoder(h, out_channel=h.model_num_spks) + + def forward(self, noisy_mag, noisy_pha): # [B, F, T] + """ + Forward pass of the ZipEnhancer module. + + Args: + noisy_mag (Tensor): Noisy magnitude input tensor of shape [B, F, T]. + noisy_pha (Tensor): Noisy phase input tensor of shape [B, F, T]. + + Returns: + Tuple: denoised magnitude, denoised phase, denoised complex representation, + (optional) predicted noise components, and other auxiliary information. + """ + others = dict() + + noisy_mag = noisy_mag.unsqueeze(-1).permute(0, 3, 2, 1) # [B, 1, T, F] + noisy_pha = noisy_pha.unsqueeze(-1).permute(0, 3, 2, 1) # [B, 1, T, F] + x = torch.cat((noisy_mag, noisy_pha), dim=1) # [B, 2, T, F] + x = self.dense_encoder(x) + + # [B, C, T, F] + x = self.TSConformer(x) + + pred_mag = self.mask_decoder(x) + pred_pha = self.phase_decoder(x) + # b, c, t, f -> b, 1, t, f -> b, f, t, 1 -> b, f, t + denoised_mag = pred_mag[:, 0, :, :].unsqueeze(1).permute(0, 3, 2, + 1).squeeze(-1) + + # b, t, f + denoised_pha = pred_pha[:, 0, :, :].unsqueeze(1).permute(0, 3, 2, + 1).squeeze(-1) + # b, t, f + denoised_com = torch.stack((denoised_mag * torch.cos(denoised_pha), + denoised_mag * torch.sin(denoised_pha)), + dim=-1) + + return denoised_mag, denoised_pha, denoised_com, None, others + + +class AttrDict(dict): + + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def mag_pha_stft(y, + n_fft, + hop_size, + win_size, + compress_factor=1.0, + center=True): + hann_window = 
torch.hann_window(win_size, device=y.device) + stft_spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode='reflect', + normalized=False, + return_complex=True) + stft_spec = torch.view_as_real(stft_spec) + mag = torch.sqrt(stft_spec.pow(2).sum(-1) + (1e-9)) + pha = torch.atan2(stft_spec[:, :, :, 1], stft_spec[:, :, :, 0] + (1e-5)) + # Magnitude Compression + mag = torch.pow(mag, compress_factor) + com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1) + + return mag, pha, com + + +def mag_pha_istft(mag, + pha, + n_fft, + hop_size, + win_size, + compress_factor=1.0, + center=True): + # Magnitude Decompression + mag = torch.pow(mag, (1.0 / compress_factor)) + com = torch.complex(mag * torch.cos(pha), mag * torch.sin(pha)) + hann_window = torch.hann_window(win_size, device=com.device) + + wav = torch.istft( + com, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center) + return wav diff --git a/modelscope/models/audio/ans/zipenhancer_layers/generator.py b/modelscope/models/audio/ans/zipenhancer_layers/generator.py new file mode 100644 index 00000000..8332ba4d --- /dev/null +++ b/modelscope/models/audio/ans/zipenhancer_layers/generator.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Alibaba, Inc. and its affiliates. 
+# Part of the implementation is borrowed and modified from MP-SENet, +# public available at https://github.com/yxlu-0102/MP-SENet + +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SubPixelConvTranspose2d(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size=(1, 3), + stride=(1, 2), + padding=(0, 1)): + super(SubPixelConvTranspose2d, self).__init__() + self.upscale_width_factor = stride[1] + self.conv1 = nn.Conv2d( + in_channels, + out_channels * self.upscale_width_factor, + kernel_size=kernel_size, + padding=padding) # only change the width + + def forward(self, x): + + b, c, t, f = x.size() + # Use conv1 for upsampling, followed by expansion only in the width dimension. + x = self.conv1(x) + # print(x.size()) + # Note: Here we do not directly use PixelShuffle because we only intend to expand in the width dimension, + # whereas PixelShuffle operates simultaneously on both height and width, hence we manually adjust accordingly. 
+ # b, 2c, t, f + # print(x.size()) + x = x.view(b, c, self.upscale_width_factor, t, + f).permute(0, 1, 3, 4, 2).contiguous() + # b, c, 2, t, f -> b, c, t, f, 2 + x = x.view(b, c, t, f * self.upscale_width_factor) + # b, c, t, 2f = 202 + # x = nn.functional.pad(x, (0, 1)) + # b, c, t, 2f = 202 + + return x + + +class DenseBlockV2(nn.Module): + """ + A denseblock for ZipEnhancer + """ + + def __init__(self, h, kernel_size=(2, 3), depth=4): + super(DenseBlockV2, self).__init__() + self.h = h + self.depth = depth + self.dense_block = nn.ModuleList([]) + for i in range(depth): + dil = 2**i + pad_length = kernel_size[0] + (dil - 1) * (kernel_size[0] - 1) - 1 + dense_conv = nn.Sequential( + nn.ConstantPad2d((1, 1, pad_length, 0), value=0.), + nn.Conv2d( + h.dense_channel * (i + 1), + h.dense_channel, + kernel_size, + dilation=(dil, 1)), + # nn.Conv2d(h.dense_channel * (i + 1), h.dense_channel, kernel_size, dilation=(dil, 1), + # padding=get_padding_2d(kernel_size, (dil, 1))), + nn.InstanceNorm2d(h.dense_channel, affine=True), + nn.PReLU(h.dense_channel)) + self.dense_block.append(dense_conv) + + def forward(self, x): + skip = x + # b, c, t, f + for i in range(self.depth): + _x = skip + x = self.dense_block[i](_x) + # print(x.size()) + skip = torch.cat([x, skip], dim=1) + return x + + +class DenseEncoder(nn.Module): + + def __init__(self, h, in_channel): + """ + Initialize the DenseEncoder module. + + Args: + h (object): Configuration object containing various hyperparameters and settings. + in_channel (int): Number of input channels. 
Example: mag + phase: 2 channels + """ + super(DenseEncoder, self).__init__() + self.h = h + self.dense_conv_1 = nn.Sequential( + nn.Conv2d(in_channel, h.dense_channel, (1, 1)), + nn.InstanceNorm2d(h.dense_channel, affine=True), + nn.PReLU(h.dense_channel)) + + self.dense_block = DenseBlockV2(h, depth=4) + + encoder_pad_kersize = (0, 1) + # Here pad was originally (0, 0),now change to (0, 1) + self.dense_conv_2 = nn.Sequential( + nn.Conv2d( + h.dense_channel, + h.dense_channel, (1, 3), (1, 2), + padding=encoder_pad_kersize), + nn.InstanceNorm2d(h.dense_channel, affine=True), + nn.PReLU(h.dense_channel)) + + def forward(self, x): + """ + Forward pass of the DenseEncoder module. + + Args: + x (Tensor): Input tensor of shape [B, C=in_channel, T, F]. + + Returns: + Tensor: Output tensor after passing through the dense encoder. Maybe: [B, C=dense_channel, T, F // 2]. + """ + # print("x: {}".format(x.size())) + x = self.dense_conv_1(x) # [b, 64, T, F] + if self.dense_block is not None: + x = self.dense_block(x) # [b, 64, T, F] + x = self.dense_conv_2(x) # [b, 64, T, F//2] + return x + + +class BaseDecoder(nn.Module): + + def __init__(self, h): + """ + Initialize the BaseDecoder module. + + Args: + h (object): Configuration object containing various hyperparameters and settings. + including upsample_type, dense_block_type. + """ + super(BaseDecoder, self).__init__() + + self.upsample_module_class = SubPixelConvTranspose2d + + # for both mag and phase decoder + self.dense_block = DenseBlockV2(h, depth=4) + + +class MappingDecoder(BaseDecoder): + + def __init__(self, h, out_channel=1): + """ + Initialize the MappingDecoderV3 module. + + Args: + h (object): Configuration object containing various hyperparameters and settings. + out_channel (int): Number of output channels. Default is 1. The number of output spearkers. 
+ """ + super(MappingDecoder, self).__init__(h) + decoder_final_kersize = (1, 2) + + self.mask_conv = nn.Sequential( + self.upsample_module_class(h.dense_channel, h.dense_channel, + (1, 3), (1, 2)), + # nn.Conv2d(h.dense_channel, out_channel, (1, 1)), + nn.InstanceNorm2d(h.dense_channel, affine=True), + nn.PReLU(h.dense_channel), + nn.Conv2d(h.dense_channel, out_channel, decoder_final_kersize)) + # Upsample at F dimension + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """ + Forward pass of the MappingDecoderV3 module. + + Args: + x (Tensor): Input tensor. [B, C, T, F] + + Returns: + Tensor: Output tensor after passing through the decoder. [B, Num_Spks, T, F] + """ + if self.dense_block is not None: + x = self.dense_block(x) + x = self.mask_conv(x) + x = self.relu(x) + # b, c=1, t, f + return x + + +class PhaseDecoder(BaseDecoder): + + def __init__(self, h, out_channel=1): + super(PhaseDecoder, self).__init__(h) + + # now change to (1, 2), previous (1, 1) + decoder_final_kersize = (1, 2) + + self.phase_conv = nn.Sequential( + self.upsample_module_class(h.dense_channel, h.dense_channel, + (1, 3), (1, 2)), + nn.InstanceNorm2d(h.dense_channel, affine=True), + nn.PReLU(h.dense_channel)) + self.phase_conv_r = nn.Conv2d(h.dense_channel, out_channel, + decoder_final_kersize) + self.phase_conv_i = nn.Conv2d(h.dense_channel, out_channel, + decoder_final_kersize) + + def forward(self, x): + if self.dense_block is not None: + x = self.dense_block(x) + x = self.phase_conv(x) + x_r = self.phase_conv_r(x) + x_i = self.phase_conv_i(x) + x = torch.atan2(x_i, x_r) + return x diff --git a/modelscope/models/audio/ans/zipenhancer_layers/scaling.py b/modelscope/models/audio/ans/zipenhancer_layers/scaling.py new file mode 100644 index 00000000..b30eec49 --- /dev/null +++ b/modelscope/models/audio/ans/zipenhancer_layers/scaling.py @@ -0,0 +1,1055 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) +# Copyright (c) 2024 Alibaba, Inc. and its affiliates. 
+# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import math +import random +from typing import Optional, Tuple, Union + +# import k2 +import torch +import torch.nn as nn +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + + +def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: + max_value = torch.max(x, y) + diff = torch.abs(x - y) + return max_value + torch.log1p(torch.exp(-diff)) + + +# RuntimeError: Exporting the operator logaddexp to ONNX opset version +# 14 is not supported. Please feel free to request support or submit +# a pull request on PyTorch GitHub. +# +# The following function is to solve the above error when exporting +# models to ONNX via torch.jit.trace() +def logaddexp(x: Tensor, y: Tensor) -> Tensor: + # Caution(fangjun): Put torch.jit.is_scripting() before + # torch.onnx.is_in_onnx_export(); + # otherwise, it will cause errors for torch.jit.script(). + # + # torch.logaddexp() works for both torch.jit.script() and + # torch.jit.trace() but it causes errors for ONNX export. + # + if torch.jit.is_scripting(): + # Note: We cannot use torch.jit.is_tracing() here as it also + # matches torch.onnx.export(). 
+ return torch.logaddexp(x, y) + elif torch.onnx.is_in_onnx_export(): + return logaddexp_onnx(x, y) + else: + # for torch.jit.trace() + return torch.logaddexp(x, y) + + +class PiecewiseLinear(object): + """ + Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with + the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] + respectively. + """ + + def __init__(self, *args): + assert len(args) >= 1, len(args) + if len(args) == 1 and isinstance(args[0], PiecewiseLinear): + self.pairs = list(args[0].pairs) + else: + self.pairs = [(float(x), float(y)) for x, y in args] + for x, y in self.pairs: + assert isinstance(x, (float, int)), type(x) + assert isinstance(y, (float, int)), type(y) + + for i in range(len(self.pairs) - 1): + assert self.pairs[i + 1][0] > self.pairs[i][0], ( + i, + self.pairs[i], + self.pairs[i + 1], + ) + + def __str__(self): + # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' + return f'PiecewiseLinear({str(self.pairs)[1:-1]})' + + def __call__(self, x): + if x <= self.pairs[0][0]: + return self.pairs[0][1] + elif x >= self.pairs[-1][0]: + return self.pairs[-1][1] + else: + cur_x, cur_y = self.pairs[0] + for i in range(1, len(self.pairs)): + next_x, next_y = self.pairs[i] + if x >= cur_x and x <= next_x: + return cur_y + (next_y - cur_y) * (x - cur_x) / ( + next_x - cur_x) + cur_x, cur_y = next_x, next_y + assert False + + def __mul__(self, alpha): + return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) + + def __add__(self, x): + if isinstance(x, (float, int)): + return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) + s, x = self.get_common_basis(x) + return PiecewiseLinear(*[(sp[0], sp[1] + xp[1]) + for sp, xp in zip(s.pairs, x.pairs)]) + + def max(self, x): + if isinstance(x, (float, int)): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear(*[(sp[0], max(sp[1], xp[1])) + for sp, xp in 
zip(s.pairs, x.pairs)]) + + def min(self, x): + if isinstance(x, float) or isinstance(x, int): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear(*[(sp[0], min(sp[1], xp[1])) + for sp, xp in zip(s.pairs, x.pairs)]) + + def __eq__(self, other): + return self.pairs == other.pairs + + def get_common_basis(self, + p: 'PiecewiseLinear', + include_crossings: bool = False): + """ + Returns (self_mod, p_mod) which are equivalent piecewise linear + functions to self and p, but with the same x values. + + p: the other piecewise linear function + include_crossings: if true, include in the x values positions + where the functions indicate by this and p cross. + """ + assert isinstance(p, PiecewiseLinear), type(p) + + # get sorted x-values without repetition. + x_vals = sorted( + set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + + if include_crossings: + extra_x_vals = [] + for i in range(len(x_vals) - 1): + _compare_results1 = (y_vals1[i] > y_vals2[i]) + _compare_results2 = (y_vals1[i + 1] > y_vals2[i + 1]) + if _compare_results1 != _compare_results2: + # if ((y_vals1[i] > y_vals2[i]) != + # (y_vals1[i + 1] > y_vals2[i + 1])): + # if the two lines in this subsegment potentially cross each other. + diff_cur = abs(y_vals1[i] - y_vals2[i]) + diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) + # `pos`, between 0 and 1, gives the relative x position, + # with 0 being x_vals[i] and 1 being x_vals[i+1]. 
+ pos = diff_cur / (diff_cur + diff_next) + extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) + extra_x_vals.append(extra_x_val) + if len(extra_x_vals) > 0: + x_vals = sorted(set(x_vals + extra_x_vals)) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( + PiecewiseLinear(*zip(x_vals, y_vals1)), + PiecewiseLinear(*zip(x_vals, y_vals2)), + ) + + +class ScheduledFloat(torch.nn.Module): + """ + This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); + it does not have a working forward() function. You are supposed to cast it to float, as + in, float(parent_module.whatever), and use it as something like a dropout prob. + + It is a floating point value whose value changes depending on the batch count of the + training loop. It is a piecewise linear function where you specify the (x,y) pairs + in sorted order on x; x corresponds to the batch index. For batch-index values before the + first x or after the last x, we just use the first or last y value. + + Example: + self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) + + `default` is used when self.batch_count is not set or not in training mode or in + torch.jit scripting mode. + """ + + def __init__(self, *args, default: float = 0.0): + super().__init__() + # self.batch_count and self.name will be written to in the training loop. 
+ self.batch_count = None + self.name = None + self.default = default + self.schedule = PiecewiseLinear(*args) + + def extra_repr(self) -> str: + return ( + f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}' + ) + + def __float__(self): + batch_count = self.batch_count + if (batch_count is None or not self.training + or torch.jit.is_scripting() or torch.jit.is_tracing()): + return float(self.default) + else: + ans = self.schedule(self.batch_count) + if random.random() < 0.0002: + logging.info( + f'ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}' + ) + return ans + + def __add__(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule + x, default=self.default) + else: + return ScheduledFloat( + self.schedule + x.schedule, default=self.default + x.default) + + def max(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule.max(x), default=self.default) + else: + return ScheduledFloat( + self.schedule.max(x.schedule), + default=max(self.default, x.default)) + + +FloatLike = Union[float, ScheduledFloat] + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. 
        # --- tail of SoftmaxFunction.forward(); the start of this method lies
        # outside this chunk ---
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        # Save only the output: backward() reconstructs x_grad from it alone.
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans, ) = ctx.saved_tensors
        with torch.cuda.amp.autocast(enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            ans = ans.to(torch.float32)
            # d(softmax)/dx applied to the upstream grad:
            #   x_grad = y*g - y * sum(y*g) along the softmax dim.
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None


def inplace_softmax(tensor, dim):
    """Numerically-stable softmax along `dim`, computed in place.

    CAUTION: mutates (and returns) `tensor`.
    """
    # Subtract the maximum value from each Tensor to prevent overflow.
    max_vals, _ = tensor.max(dim=dim, keepdim=True)
    tensor.sub_(max_vals)

    # # calculate logsumexp
    # log_sum_exp = torch.logsumexp(tensor, dim=dim, keepdim=True)
    #
    # # minus logsumexp
    # tensor.sub_(log_sum_exp)
    #
    # # Compute the exponential of each element, and store the results in-place.
    # tensor.exp_()

    # Compute the exponential of each element, and store the results in-place.
    tensor.exp_()

    # Compute the sum along the specified dimension, and store the result in-place.
    sum_exp = tensor.sum(dim=dim, keepdim=True)

    # Divide each element by the sum along that dimension, and store the result in-place.
    tensor.div_(sum_exp)
    # tensor.add_(1e-8)

    return tensor


def softmax(x: Tensor, dim: int):
    """Softmax along `dim`: memory-efficient SoftmaxFunction when gradients
    are needed, plain tensor softmax otherwise (and under jit)."""
    if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(
    ):
        return x.softmax(dim=dim)
    # inplace operator
    # return inplace_softmax(x, dim)

    return SoftmaxFunction.apply(x, dim)


class BiasNormFunction(torch.autograd.Function):
    # This computes:
    #   scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
    #   return x * scales
    # (after unsqueezing the bias), but it does it in a memory-efficient way so that
    # it can just store the returned value (chances are, this will also be needed for
    # some other reason, related to the next operation, so we can save memory).
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        bias: Tensor,
        log_scale: Tensor,
        channel_dim: int,
        store_output_for_backprop: bool,
    ) -> Tensor:
        assert bias.ndim == 1
        if channel_dim < 0:
            channel_dim = channel_dim + x.ndim
        ctx.store_output_for_backprop = store_output_for_backprop
        ctx.channel_dim = channel_dim
        # broadcast the 1-d bias across all dims after channel_dim
        for _ in range(channel_dim + 1, x.ndim):
            bias = bias.unsqueeze(-1)
        _x_bias_square = torch.mean(
            (x - bias)**2, dim=channel_dim, keepdim=True)
        scales = (_x_bias_square**-0.5) * log_scale.exp()
        ans = x * scales
        # Store either the output or the input (caller's choice); backward
        # recovers x from ans via ans / scales when the output was stored.
        ctx.save_for_backward(
            ans.detach() if store_output_for_backprop else x,
            scales.detach(),
            bias.detach(),
            log_scale.detach(),
        )
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tensor:
        ans_or_x, scales, bias, log_scale = ctx.saved_tensors
        if ctx.store_output_for_backprop:
            x = ans_or_x / scales
        else:
            x = ans_or_x
        x = x.detach()
        x.requires_grad = True
        bias.requires_grad = True
        log_scale.requires_grad = True
        with torch.enable_grad():
            # recompute scales from x, bias and log_scale.
            _x_bias_square = torch.mean(
                (x - bias)**2, dim=ctx.channel_dim, keepdim=True)
            scales = (_x_bias_square**-0.5) * log_scale.exp()
            ans = x * scales
            ans.backward(gradient=ans_grad)
        return x.grad, bias.grad.flatten(), log_scale.grad, None, None


class BiasNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm.  The observation this is based on, is that Transformer-type
    networks, especially with pre-norm, sometimes seem to set one of the
    feature dimensions to a large constant value (e.g. 50), which "defeats"
    the LayerNorm because the output magnitude is then not strongly dependent
    on the other (useful) features.  Presumably the weight and bias of the
    LayerNorm are required to allow it to do this.

    Instead, we give the BiasNorm a trainable bias that it can use when
    computing the scale for normalization.  We also give it a (scalar)
    trainable scale on the output.


    Args:
        num_channels: the number of channels, e.g. 512.
        channel_dim: the axis/dimension corresponding to the channel,
            interpreted as an offset from the input's ndim if negative.
            This is NOT the num_channels; it should typically be one of
            {-2, -1, 0, 1, 2, 3}.
        log_scale: the initial log-scale that we multiply the output by; this
            is learnable.
        log_scale_min: FloatLike, minimum allowed value of log_scale
        log_scale_max: FloatLike, maximum allowed value of log_scale
        store_output_for_backprop: only possibly affects memory use; recommend
            to set to True if you think the output of this module is more likely
            than the input of this module to be required to be stored for the
            backprop.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        log_scale: float = 1.0,
        log_scale_min: float = -1.5,
        log_scale_max: float = 1.5,
        store_output_for_backprop: bool = False,
    ) -> None:
        super(BiasNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.log_scale = nn.Parameter(torch.tensor(log_scale))
        # bias starts near zero so the module initially behaves like RMS-norm
        self.bias = nn.Parameter(
            torch.empty(num_channels).normal_(mean=0, std=1e-4))

        self.log_scale_min = log_scale_min
        self.log_scale_max = log_scale_max

        self.store_output_for_backprop = store_output_for_backprop

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # jit path: compute directly, without the custom autograd Function
            channel_dim = self.channel_dim
            if channel_dim < 0:
                channel_dim += x.ndim
            bias = self.bias
            for _ in range(channel_dim + 1, x.ndim):
                bias = bias.unsqueeze(-1)
            _x_bias_square = torch.mean(
                (x - bias)**2, dim=channel_dim, keepdim=True)
            scales = (_x_bias_square**-0.5) * self.log_scale.exp()
            return x * scales

        # clamp log_scale (via gradient manipulation) only during training
        log_scale = limit_param_value(
            self.log_scale,
            min=float(self.log_scale_min),
            max=float(self.log_scale_max),
            training=self.training,
        )

        return BiasNormFunction.apply(x, self.bias, log_scale,
                                      self.channel_dim,
                                      self.store_output_for_backprop)


def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """
    Behaves like a constructor of a modified version of nn.Linear
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Linear accepts
        e.g. in_features, out_features, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """
    ans = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale,
                                   0.1 * initial_scale)
    return ans


class ChunkCausalDepthwiseConv1d(torch.nn.Module):
    """
    Behaves like a depthwise 1d convolution, except that it is causal in
    a chunkwise way, as if we had a block-triangular attention mask.
    The chunk size is provided at test time (it should probably be
    kept in sync with the attention mask).

    This has a little more than twice the parameters of a conventional
    depthwise conv1d module: we implement it by having one
    depthwise convolution, of half the width, that is causal (via
    right-padding); and one depthwise convolution that is applied only
    within chunks, that we multiply by a scaling factor which depends
    on the position within the chunk.

    Args:
        channels: the number of input (and output) channels; the convolution
            is depthwise (groups == channels).
        kernel_size: kernel width of the chunkwise convolution; must be odd.
            The causal convolution uses (kernel_size + 1) // 2.
        initial_scale: scales the initial magnitude of the module's output
            (applied to both convolutions' weights and the causal bias).
        bias: if true, the chunkwise convolution has a bias.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        initial_scale: float = 1.0,
        bias: bool = True,
    ):
        super().__init__()
        assert kernel_size % 2 == 1

        half_kernel_size = (kernel_size + 1) // 2
        # will pad manually, on one side.
        self.causal_conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=half_kernel_size,
            padding=0,
            bias=True,
        )

        self.chunkwise_conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )

        # first row is correction factors added to the scale near the left edge of the chunk,
        # second row is correction factors added to the scale near the right edge of the chunk,
        # both of these are added to a default scale of 1.0.
        self.chunkwise_conv_scale = nn.Parameter(
            torch.zeros(2, channels, kernel_size))
        self.kernel_size = kernel_size

        with torch.no_grad():
            self.causal_conv.weight[:] *= initial_scale
            self.chunkwise_conv.weight[:] *= initial_scale
            if bias:
                torch.nn.init.uniform_(self.causal_conv.bias,
                                       -0.1 * initial_scale,
                                       0.1 * initial_scale)

    def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
        """Forward function.

        Args:
            x: a Tensor of shape (batch_size, channels, seq_len)
            chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
        """
        (batch_size, num_channels, seq_len) = x.shape

        # half_kernel_size = self.kernel_size + 1 // 2
        # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
        # in the causal conv.  It's the amount by which we must pad on the left,
        # to make the convolution causal.
        left_pad = self.kernel_size // 2

        if chunk_size < 0 or chunk_size > seq_len:
            chunk_size = seq_len
        # right_pad rounds seq_len up to a whole number of chunks
        right_pad = -seq_len % chunk_size

        x = torch.nn.functional.pad(x, (left_pad, right_pad))

        x_causal = self.causal_conv(x[..., :left_pad + seq_len])
        assert x_causal.shape == (batch_size, num_channels, seq_len)

        x_chunk = x[..., left_pad:]
        num_chunks = x_chunk.shape[2] // chunk_size
        # fold chunks into the batch dim so the chunkwise conv cannot look
        # across chunk boundaries
        x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks,
                                  chunk_size)
        x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks,
                                                      num_channels, chunk_size)
        x_chunk = self.chunkwise_conv(x_chunk)  # does not change shape

        chunk_scale = self._get_chunk_scale(chunk_size)

        x_chunk = x_chunk * chunk_scale
        x_chunk = x_chunk.reshape(batch_size, num_chunks, num_channels,
                                  chunk_size).permute(0, 2, 1, 3)
        x_chunk = x_chunk.reshape(batch_size, num_channels,
                                  num_chunks * chunk_size)[..., :seq_len]

        return x_chunk + x_causal

    def _get_chunk_scale(self, chunk_size: int):
        """Returns tensor of shape (num_channels, chunk_size) that will be used to
        scale the output of self.chunkwise_conv."""
        left_edge = self.chunkwise_conv_scale[0]
        right_edge = self.chunkwise_conv_scale[1]
        if chunk_size < self.kernel_size:
            left_edge = left_edge[:, :chunk_size]
            right_edge = right_edge[:, -chunk_size:]
        else:
            t = chunk_size - self.kernel_size
            channels = left_edge.shape[0]
            pad = torch.zeros(
                channels, t, device=left_edge.device, dtype=left_edge.dtype)
            left_edge = torch.cat((left_edge, pad), dim=-1)
            right_edge = torch.cat((pad, right_edge), dim=-1)
        return 1.0 + (left_edge + right_edge)

    def streaming_forward(
        self,
        x: Tensor,
        cache: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Streaming Forward function.

        Args:
            x: a Tensor of shape (batch_size, channels, seq_len)
            cache: cached left context of shape (batch_size, channels, left_pad)
        """
        (batch_size, num_channels, seq_len) = x.shape

        # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
        # in the causal conv.  It's the amount by which we must pad on the left,
        # to make the convolution causal.
        left_pad = self.kernel_size // 2

        # Pad cache
        assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad)
        x = torch.cat([cache, x], dim=2)
        # Update cache
        cache = x[..., -left_pad:]

        x_causal = self.causal_conv(x)
        assert x_causal.shape == (batch_size, num_channels, seq_len)

        x_chunk = x[..., left_pad:]
        x_chunk = self.chunkwise_conv(x_chunk)  # does not change shape

        # streaming: the whole segment is treated as one chunk
        chunk_scale = self._get_chunk_scale(chunk_size=seq_len)
        x_chunk = x_chunk * chunk_scale

        return x_chunk + x_causal, cache


def penalize_abs_values_gt(x: Tensor,
                           limit: float,
                           penalty: float,
                           name: str = None) -> Tensor:
    """
    Returns x unmodified, but in backprop will put a penalty for the excess of
    the absolute values of elements of x over the limit "limit".  E.g. if
    limit == 10.0, then if x has any values over 10 it will get a penalty.

    Caution: the value of this penalty will be affected by grad scaling used
    in automatic mixed precision training.  For this reasons we use this,
    it shouldn't really matter, or may even be helpful; we just use this
    to disallow really implausible values of scores to be given to softmax.

    The name is for randomly printed debug info.
    """
    x_sign = x.sign()
    over_limit = (x.abs() - limit) > 0
    # The following is a memory efficient way to penalize the absolute values of
    # x that's over the limit.  (The memory efficiency comes when you think
    # about which items torch needs to cache for the autograd, and which ones it
    # can throw away).  The numerical value of aux_loss as computed here will
    # actually be larger than it should be, by limit * over_limit.sum(), but it
    # has the same derivative as the real aux_loss which is penalty * (x.abs() -
    # limit).relu().
    aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
    # note: we don't do sum() here on aux_loss, but it's as if we had done
    # sum() due to how with_loss() works.
    x = with_loss(x, aux_loss, name)
    # you must use x for something, or this will be ineffective.
    return x


class WithLoss(torch.autograd.Function):
    """Identity on x in forward; in backward, behaves as if y.sum() had been
    added to the loss (y receives an all-ones gradient)."""

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor, name: str):
        ctx.y_shape = y.shape
        # occasionally log the auxiliary loss for debugging
        if random.random() < 0.002 and name is not None:
            loss_sum = y.sum().item()
            logging.info(f'WithLoss: name={name}, loss-sum={loss_sum:.3e}')
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        return (
            ans_grad,
            torch.ones(
                ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
            None,
        )


def with_loss(x, y, name):
    # returns x but adds y.sum() to the loss function.
    return WithLoss.apply(x, y, name)


class LimitParamValue(torch.autograd.Function):
    """Identity in forward; in backward, flips gradient signs so that values
    outside [min, max] are pushed back into the range."""

    @staticmethod
    def forward(ctx, x: Tensor, min: float, max: float):
        ctx.save_for_backward(x)
        assert max >= min
        ctx.min = min
        ctx.max = max
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x, ) = ctx.saved_tensors
        # where x < ctx.min, ensure all grads are negative (this will tend to make
        # x more positive).
        x_grad = x_grad * torch.where(
            torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
        # where x > ctx.max, ensure all grads are positive (this will tend to make
        # x more negative).
        x_grad *= torch.where(
            torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
        return x_grad, None, None


def limit_param_value(x: Tensor,
                      min: float,
                      max: float,
                      prob: float = 0.6,
                      training: bool = True):
    # You apply this to (typically) an nn.Parameter during training to ensure that its
    # (elements mostly) stays within a supplied range.  This is done by modifying the
    # gradients in backprop.
    # It's not necessary to do this on every batch: do it only some of the time,
    # to save a little time.
    if training and random.random() < prob:
        return LimitParamValue.apply(x, min, max)
    else:
        return x


def _no_op(x: Tensor) -> Tensor:
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return x
    else:
        # a no-op function that will have a node in the autograd graph,
        # to avoid certain bugs relating to backward hooks
        return x.chunk(1, dim=-1)[0]


class Identity(torch.nn.Module):
    """Identity module that still inserts a node in the autograd graph."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return _no_op(x)


# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
class Dropout2(nn.Module):
    """Dropout whose probability ``p`` may be a schedule (FloatLike);
    the rate is re-evaluated (via float()) on every forward call."""

    def __init__(self, p: FloatLike):
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        rate = float(self.p)
        return torch.nn.functional.dropout(x, p=rate, training=self.training)


class SwooshLFunction(torch.autograd.Function):
    """
    Custom autograd for the Swoosh-L activation:

        swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035

    Memory trick: the derivative (which lies in a narrow known range) is
    quantized to one uint8 per element in the forward pass, with random
    dithering, so the backward pass only has to dequantize it.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        needs_grad = x.requires_grad
        if x.dtype == torch.float16:
            x = x.to(torch.float32)

        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
        coeff = -0.08

        with torch.cuda.amp.autocast(enabled=False):
            with torch.enable_grad():
                x = x.detach()
                x.requires_grad = True
                y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035

                if not needs_grad:
                    return y

                y.backward(gradient=torch.ones_like(y))

                grad = x.grad
                # The true derivative lies in [coeff, 1 + coeff]; widen the
                # ceiling a little so the quantized value never overflows.
                floor = coeff
                ceil = 1.0 + coeff + 0.005
                scaled = (grad - floor) * (255.0 / (ceil - floor))
                # random dithering makes the uint8 quantization unbiased
                scaled = scaled + torch.rand_like(grad)
                if __name__ == '__main__':
                    # for self-testing only.
                    assert scaled.min() >= 0.0
                    assert scaled.max() < 256.0

                ctx.save_for_backward(scaled.to(torch.uint8))
                if x.dtype == torch.float16 or torch.is_autocast_enabled():
                    y = y.to(torch.float16)
                return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d, ) = ctx.saved_tensors
        # Dequantize with the same constants used by forward().
        coeff = -0.08
        floor = coeff
        ceil = 1.0 + coeff + 0.005
        deriv = d * ((ceil - floor) / 255.0) + floor
        return y_grad * deriv


class SwooshL(torch.nn.Module):
    """Module wrapper around the Swoosh-L activation."""

    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-L activation."""
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
            return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
        return SwooshLFunction.apply(x)


class SwooshLOnnx(torch.nn.Module):
    """ONNX-export-friendly Swoosh-L (uses logaddexp_onnx)."""

    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-L activation."""
        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
        return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035


class SwooshRFunction(torch.autograd.Function):
    """
    Custom autograd for the Swoosh-R activation:

        swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687

    derivatives are between -0.08 and 0.92.  Like SwooshLFunction, the
    derivative is stored as a dithered uint8 for a cheap backward pass.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        needs_grad = x.requires_grad

        if x.dtype == torch.float16:
            x = x.to(torch.float32)

        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)

        with torch.cuda.amp.autocast(enabled=False):
            with torch.enable_grad():
                x = x.detach()
                x.requires_grad = True
                y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687

                if not needs_grad:
                    return y
                y.backward(gradient=torch.ones_like(y))

                grad = x.grad
                floor = -0.08
                ceil = 0.925

                scaled = (grad - floor) * (255.0 / (ceil - floor))
                # random dithering makes the uint8 quantization unbiased
                scaled = scaled + torch.rand_like(grad)
                if __name__ == '__main__':
                    # for self-testing only.
                    assert scaled.min() >= 0.0
                    assert scaled.max() < 256.0

                ctx.save_for_backward(scaled.to(torch.uint8))
                if x.dtype == torch.float16 or torch.is_autocast_enabled():
                    y = y.to(torch.float16)
                return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d, ) = ctx.saved_tensors
        # Dequantize with the same constants used by forward().
        floor = -0.08
        ceil = 0.925
        deriv = d * ((ceil - floor) / 255.0) + floor
        return y_grad * deriv


class SwooshR(torch.nn.Module):
    """Module wrapper around the Swoosh-R activation."""

    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-R activation."""
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
            return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
        return SwooshRFunction.apply(x)


class SwooshROnnx(torch.nn.Module):
    """ONNX-export-friendly Swoosh-R (uses logaddexp_onnx)."""

    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-R activation."""
        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
        return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687


# simple version of SwooshL that does not redefine the backprop, used in
# ActivationDropoutAndLinearFunction.
def SwooshLForward(x: Tensor):
    """Plain (ordinary-autograd) Swoosh-L forward pass."""
    shifted = x - 4.0
    log_sum = (1.0 + shifted.exp()).log().to(x.dtype)
    # where exp() overflowed to inf, log(1 + exp(s)) ~= s
    log_sum = torch.where(log_sum == float('inf'), shifted, log_sum)
    return log_sum - 0.08 * x - 0.035


def SwooshLForwardAndDeriv(x: Tensor):
    """
    Return (swoosh_l(x), d/dx swoosh_l(x)); plain-autograd version, see
    https://k2-fsa.github.io/k2/python_api/api.html#swoosh-l-forward-and-deriv
    """
    shifted = x - 4.0
    log_sum = (1.0 + shifted.exp()).log().to(x.dtype)
    log_sum = torch.where(log_sum == float('inf'), shifted, log_sum)

    # equals sigmoid(x - 4) - 0.08
    deriv = 0.92 - 1 / (1 + shifted.exp())

    return log_sum - 0.08 * x - 0.035, deriv


# simple version of SwooshR that does not redefine the backprop, used in
# ActivationDropoutAndLinearFunction.
+def SwooshRForward(x: Tensor): + x_offset = x - 1.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + return log_sum - 0.08 * x - 0.313261687 + + +def SwooshRForwardAndDeriv(x: Tensor): + """ + https://k2-fsa.github.io/k2/python_api/api.html#swoosh-r-forward-and-deriv + :param x: + :return: + """ + x_offset = x - 1.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + + deriv = 0.92 - 1 / (1 + x_offset.exp()) + + return log_sum - 0.08 * x - 0.313261687, deriv + + +class ActivationDropoutAndLinear(torch.nn.Module): + """ + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. + + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = 'SwooshL', + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0, + ): + super().__init__() + # create a temporary module of nn.Linear that we'll steal the + # weights and bias from + linear_module = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale) + + self.weight = linear_module.weight + # register_parameter properly handles making it a parameter when l.bias + # is None. I think there is some reason for doing it this way rather + # than just setting it to None but I don't know what it is, maybe + # something to do with exporting the module.. + self.register_parameter('bias', linear_module.bias) + + self.activation = activation + self.dropout_p = dropout_p + self.dropout_shared_dim = dropout_shared_dim + + def forward(self, x: Tensor): + # if torch.jit.is_scripting() or torch.jit.is_tracing(): + if torch.jit.is_scripting() or torch.jit.is_tracing() or ( + not self.training): + if self.activation == 'SwooshL': + x = SwooshLForward(x) + # x = k2.swoosh_l_forward(x) + elif self.activation == 'SwooshR': + x = SwooshRForward(x) + # x = k2.swoosh_r_forward(x) + else: + assert False, self.activation + return torch.nn.functional.linear(x, self.weight, self.bias) + + # print(f"dropout_p:{float(self.dropout_p)}") + # print(f"dropout_shared_dim:{self.dropout_shared_dim}") + # return ActivationDropoutAndLinearFunction.apply( + # x, + # self.weight, + # self.bias, + # self.activation, + # float(self.dropout_p), + # self.dropout_shared_dim, + # ) + + +def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: + """ + + :param x: (b, c, t, f) + :param num_channels: + :return: x: (b, num_channels, t, f) + """ + if num_channels <= x.shape[1]: + return x[:, :num_channels, :, :] + else: + shape = list(x.shape) + shape[1] = num_channels - shape[1] + zeros = torch.zeros(shape, 
dtype=x.dtype, device=x.device) + return torch.cat((x, zeros), dim=1) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) diff --git a/modelscope/models/audio/ans/zipenhancer_layers/zipenhancer_layer.py b/modelscope/models/audio/ans/zipenhancer_layers/zipenhancer_layer.py new file mode 100644 index 00000000..32bf7cb4 --- /dev/null +++ b/modelscope/models/audio/ans/zipenhancer_layers/zipenhancer_layer.py @@ -0,0 +1,501 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Alibaba, Inc. and its affiliates. + +import copy +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor, nn + +from .scaling import FloatLike, ScheduledFloat, convert_num_channels +from .zipformer import (BypassModule, CompactRelPositionalEncoding, + SimpleDownsample, SimpleUpsample, + Zipformer2EncoderLayer) + + +class DualPathZipformer2Encoder(nn.Module): + r"""DualPathZipformer2Encoder is a stack of N encoder layers + it has two kinds of EncoderLayer including F_Zipformer2EncoderLayer and T_Zipformer2EncoderLayer + the features are modeling with the shape of + [B, C, T, F] -> [F, T * B, C] -> -> [B, C, T, F] -> [T, F * B, C] -> [B, C, T, F] + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). 
        pos_dim: the dimension for the relative positional encoding

    Examples::
        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
        >>> dualpath_zipformer_encoder = DualPathZipformer2Encoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 512, 161, 101)
        >>> out = dualpath_zipformer_encoder(src)
    """

    def __init__(
        self,
        encoder_layer: nn.Module,
        num_layers: int,
        pos_dim: int,
        dropout: float,
        warmup_begin: float,
        warmup_end: float,
        initial_layerdrop_rate: float = 0.5,
        final_layerdrop_rate: float = 0.05,
        bypass_layer=None,
    ) -> None:
        """
        Initialize the DualPathZipformer2Encoder module with the specified
        encoder layer, number of layers, positional dimension, dropout rate, warmup period, and layer drop rates.
        """
        super().__init__()
        self.encoder_pos = CompactRelPositionalEncoding(
            pos_dim, dropout_rate=0.15, length_factor=1.0)

        # independent (deep-copied) stacks for the frequency and time paths
        self.f_layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for i in range(num_layers)])
        self.t_layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for i in range(num_layers)])
        # NOTE(review): unlike f_layers/t_layers above, every entry here is
        # the SAME bypass_layer instance (no deepcopy), so all 2*num_layers
        # bypass layers share parameters -- confirm this is intended.
        self.bypass_layers = nn.ModuleList(
            [bypass_layer for i in range(num_layers * 2)])
        self.num_layers = num_layers

        assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end)

        # linearly stagger each layer's layerdrop warmup schedule
        delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
        cur_begin = warmup_begin  # interpreted as a training batch index
        for i in range(num_layers):
            cur_end = cur_begin + delta
            self.f_layers[i].bypass.skip_rate = ScheduledFloat(
                (cur_begin, initial_layerdrop_rate),
                (cur_end, final_layerdrop_rate),
                default=0.0,
            )
            self.t_layers[i].bypass.skip_rate = ScheduledFloat(
                (cur_begin, initial_layerdrop_rate),
                (cur_end, final_layerdrop_rate),
                default=0.0,
            )
            cur_begin = cur_end

    def forward(
        self,
        src: Tensor,
        chunk_size: int = -1,
        feature_mask: Union[Tensor, float] = 1.0,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in a dual-path manner, processing both temporal and frequency dimensions.

        Args:
            src: the dual-path sequence to the encoder (required):
                shape (batch_size, embedding_dim, seq_len, frequency_len).
            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. No used.
            feature_mask: something that broadcasts with src, that we'll multiply `src`
                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                True means masked position. May be None.
            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                masked position. May be None.

        Returns: a Tensor with the same shape as src.
        """

        # src: (b, c, t, f)
        b, c, t, f = src.size()
        # NOTE(review): src_f is ordered (b*t) in dim 1 while output_f in the
        # loop below is ordered (t*b); this is only consistent if the
        # positional encoding depends solely on sequence length -- confirm.
        src_f = src.permute(3, 0, 2, 1).contiguous().view(f, b * t, c)
        src_t = src.permute(2, 0, 3, 1).contiguous().view(t, b * f, c)
        pos_emb_f = self.encoder_pos(src_f)
        pos_emb_t = self.encoder_pos(src_t)

        output = src

        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            output = output * feature_mask

        for i in range(len(self.f_layers)):
            # frequency path: attend across F, with (t, b) folded into batch
            # output_org = output
            # (b, c, t, f)
            output_f_org = output.permute(3, 2, 0,
                                          1).contiguous()  # (f, t, b, c)
            output_f = output_f_org.view(f, t * b, c)
            # (f, t * b, c)
            output_f = self.f_layers[i](
                output_f,
                pos_emb_f,
                # chunk_size=chunk_size,
                # attn_mask=attn_mask,
                src_key_padding_mask=src_key_padding_mask,
            )
            output_f = output_f.view(f, t, b, c)
            output_f = self.bypass_layers[i * 2](output_f_org, output_f)

            # (f, t, b, c)
            output = output_f.permute(2, 3, 1, 0).contiguous()
            # (b, c, t, f)
            # output = self.bypass_layers[i * 2](output_org, output)

            # output_org = output

            # time path: attend across T, with (f, b) folded into batch
            output_t_org = output.permute(2, 3, 0,
                                          1).contiguous()  # (t, f, b, c)
            output_t = output_t_org.view(t, f * b, c)
            output_t = self.t_layers[i](
                output_t,
                pos_emb_t,
                # chunk_size=chunk_size,
                # attn_mask=attn_mask,
                src_key_padding_mask=src_key_padding_mask,
            )
            output_t = output_t.view(t, f, b, c)
            output_t = self.bypass_layers[i * 2 + 1](output_t_org, output_t)
            # (t, f, b, c)

            output = output_t.permute(2, 3, 0, 1).contiguous()
            # (b, c, t, f)
            # output = self.bypass_layers[i * 2 + 1](output_org, output)

            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
                output = output * feature_mask

        return output


class DualPathDownsampledZipformer2Encoder(nn.Module):
    r"""
    DualPathDownsampledZipformer2Encoder is a dual-path zipformer encoder evaluated at a reduced frame rate,
    after convolutional downsampling, and then upsampled again at the output, and combined
    with the origin input, so that the output has the same shape as the input.
    The features are downsampled-upsampled at the time and frequency domain.

    """

    def __init__(self, encoder: nn.Module, dim: int, t_downsample: int,
                 f_downsample: int, dropout: FloatLike):
        """
        Initialize the DualPathDownsampledZipformer2Encoder module with the specified
        encoder, dimension, temporal and frequency downsampling factors r, and dropout rate.
        """
        super(DualPathDownsampledZipformer2Encoder, self).__init__()
        self.downsample_factor = t_downsample
        self.t_downsample_factor = t_downsample
        self.f_downsample_factor = f_downsample

        # factor 1 means "no resampling"; the corresponding submodules are
        # only created when actually needed
        if self.t_downsample_factor != 1:
            self.downsample_t = SimpleDownsample(dim, t_downsample, dropout)
            self.upsample_t = SimpleUpsample(dim, t_downsample)
        if self.f_downsample_factor != 1:
            self.downsample_f = SimpleDownsample(dim, f_downsample, dropout)
            self.upsample_f = SimpleUpsample(dim, f_downsample)

        # self.num_layers = encoder.num_layers
        self.encoder = encoder

        self.out_combiner = BypassModule(dim, straight_through_rate=0)

    def forward(
        self,
        src: Tensor,
        chunk_size: int = -1,
        feature_mask: Union[Tensor, float] = 1.0,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Downsample the input, process through the encoder, and then upsample back to the original shape.

        Args:
            src: the sequence to the encoder (required): shape (batch_size, embedding_dim, seq_len, frequency_len).
            feature_mask: 1.0
            attn_mask: None
            src_key_padding_mask: None.

        Returns: a Tensor with the same shape as src.
            (batch_size, embedding_dim, seq_len, frequency_len)
        """
        # src: (b, c, t, f)
        b, c, t, f = src.size()
        # print(src.size())

        src_orig = src.permute(2, 3, 0, 1)  # (t, f, b, c)

        # (b, c, t, f)
        src = src.permute(2, 0, 3, 1).contiguous().view(t, b * f, c)
        # -> (t, b * f, c)
        if self.t_downsample_factor != 1:
            src = self.downsample_t(src)
        # (t//ds + 1, b * f, c)
        downsample_t = src.size(0)
        src = src.view(downsample_t, b, f,
                       c).permute(2, 1, 0,
                                  3).contiguous().view(f, b * downsample_t, c)
        # src = self.upsample_f(src)
        if self.f_downsample_factor != 1:
            src = self.downsample_f(src)
        # (f//ds + 1, b * downsample_t, c)
        downsample_f = src.size(0)
        src = src.view(downsample_f, b, downsample_t, c).permute(1, 3, 2, 0)
        # (b, c, downsample_t, downsample_f)
        # print(src.size())

        # ds = self.downsample_factor
        # if attn_mask is not None:
        #     attn_mask = attn_mask[::ds, ::ds]

        src = self.encoder(
            src,
            chunk_size=chunk_size,
            feature_mask=feature_mask,
            attn_mask=attn_mask,
            src_key_padding_mask=src_key_padding_mask,
        )

        # (b, c, downsample_t, downsample_f)
        src = src.permute(3, 0, 2,
                          1).contiguous().view(downsample_f, b * downsample_t,
                                               c)
        if self.f_downsample_factor != 1:
            src = self.upsample_f(src)
        # (f, b * downsample_t, c)
        # upsampling may overshoot; slice back to the original F, then T
        src = src[:f].view(f, b, downsample_t,
                           c).permute(2, 1, 0, 3).contiguous().view(
                               downsample_t, b * f, c)
        # (downsample_t, b * f, c)
        if self.t_downsample_factor != 1:
            src = self.upsample_t(src)
        # (t, b * f, c)
        src = src[:t].view(t, b, f, c).permute(0, 2, 1, 3).contiguous()
        # (t, f, b, c)
        out = self.out_combiner(src_orig, src)
        # (t, f, b, c)

        out = out.permute(2, 3, 0, 1).contiguous()
        # (b, c, t, f)
        # print(out.size())

        # remove any extra frames that are not a multiple of downsample_factor
        # src = src[: src_orig.shape[0]]  # slice here

        return out


class Zipformer2DualPathEncoder(nn.Module):

    def __init__(
        self,
        output_downsampling_factor: int = 2,
        downsampling_factor: Tuple[int] = (2, 4),
        f_downsampling_factor: Tuple[int] = None,
        encoder_dim: Union[int, Tuple[int]] = 384,
        num_encoder_layers: Union[int, Tuple[int]] = 4,
        encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
        query_head_dim: Union[int, Tuple[int]] = 24,
        pos_head_dim: Union[int, Tuple[int]] = 4,
        value_head_dim: Union[int, Tuple[int]] = 12,
        num_heads: Union[int, Tuple[int]] = 8,
        feedforward_dim: Union[int, Tuple[int]] = 1536,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        pos_dim: int = 192,
        dropout: FloatLike = None,  # see code below for default
        warmup_batches: float = 4000.0,
        causal: bool = False,
        chunk_size: Tuple[int] = [-1],  # NOTE(review): mutable (list) default; not mutated in this chunk, but a tuple would be safer.
        left_context_frames: Tuple[int] = [-1],  # NOTE(review): same mutable-default concern as chunk_size.
    ):
        """
        Initialize the Zipformer2DualPathEncoder module.
        Zipformer2DualPathEncoder processes the hidden features of the noisy speech using dual-path modeling.
        It has two kinds of blocks: DualPathZipformer2Encoder and DualPathDownsampledZipformer2Encoder.
        DualPathZipformer2Encoder processes the 4D features with the shape of [B, C, T, F].
        DualPathDownsampledZipformer2Encoder first downsamples the hidden features
        and processes features using dual-path modeling like DualPathZipformer2Encoder.

        Args:
            Various hyperparameters and settings for the encoder.
        """
        super(Zipformer2DualPathEncoder, self).__init__()

        if dropout is None:
            dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))

        def _to_tuple(x):
            """Converts a single int or a 1-tuple of an int to a tuple with the same length
            as downsampling_factor"""
            if isinstance(x, int):
                x = (x, )
            if len(x) == 1:
                x = x * len(downsampling_factor)
            else:
                assert len(x) == len(downsampling_factor) and isinstance(
                    x[0], int)
            return x

        self.output_downsampling_factor = output_downsampling_factor  # int
        self.downsampling_factor = downsampling_factor  # tuple

        # frequency downsampling defaults to the time downsampling factors
        if f_downsampling_factor is None:
            f_downsampling_factor = downsampling_factor
        self.f_downsampling_factor = _to_tuple(f_downsampling_factor)

        self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
        self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
            encoder_unmasked_dim)  # tuple
        num_encoder_layers = _to_tuple(num_encoder_layers)
        self.num_encoder_layers = num_encoder_layers
        self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
        self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
        pos_head_dim = _to_tuple(pos_head_dim)
        self.num_heads = num_heads = _to_tuple(num_heads)
        feedforward_dim = _to_tuple(feedforward_dim)
        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(
            cnn_module_kernel)

        self.causal = causal
        self.chunk_size = chunk_size
        self.left_context_frames = left_context_frames

        for u, d in zip(encoder_unmasked_dim, encoder_dim):
            assert u <= d

        # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
        encoders = []

        num_encoders = len(downsampling_factor)
        # "1,2,4,8,4,2",

        for i in range(num_encoders):
            encoder_layer = Zipformer2EncoderLayer(
                embed_dim=encoder_dim[i],
                pos_dim=pos_dim,
                num_heads=num_heads[i],
                query_head_dim=query_head_dim[i],
                pos_head_dim=pos_head_dim[i],
                value_head_dim=value_head_dim[i],
                feedforward_dim=feedforward_dim[i],
                dropout=dropout,
                cnn_module_kernel=cnn_module_kernel[i],
                causal=causal,
            )

            # For the segment of the warmup period, we let the Conv2dSubsampling
            # layer learn something.  Then we start to warm up the other encoders.
            encoder = DualPathZipformer2Encoder(
                encoder_layer,
                num_encoder_layers[i],
                pos_dim=pos_dim,
                dropout=dropout,
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                final_layerdrop_rate=0.035 * (downsampling_factor[i]**0.5),
                bypass_layer=BypassModule(
                    encoder_dim[i], straight_through_rate=0),
            )

            # wrap in a downsampled encoder when either axis is reduced
            if downsampling_factor[i] != 1 or f_downsampling_factor[i] != 1:
                encoder = DualPathDownsampledZipformer2Encoder(
                    encoder,
                    dim=encoder_dim[i],
                    t_downsample=downsampling_factor[i],
                    f_downsample=f_downsampling_factor[i],
                    dropout=dropout,
                )

            encoders.append(encoder)

        self.encoders = nn.ModuleList(encoders)

        self.downsample_output = SimpleDownsample(
            max(encoder_dim),
            downsample=output_downsampling_factor,
            dropout=dropout)

    def forward(self, x):
        """
        Forward pass of the Zipformer2DualPathEncoder module.

        Args:
            x (Tensor): Input tensor of shape [B, C, T, F].

        Returns:
            Tensor: Output tensor after passing through the encoder.
+ """ + outputs = [] + + # if torch.jit.is_scripting() or torch.jit.is_tracing(): + # feature_masks = [1.0] * len(self.encoder_dim) + # else: + # feature_masks = self.get_feature_masks(x) + feature_masks = [1.0] * len(self.encoder_dim) + attn_mask = None + + chunk_size = -1 + # left_context_chunks = -1 + + for i, module in enumerate(self.encoders): + + x = convert_num_channels(x, self.encoder_dim[i]) + + x = module( + x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=None, + attn_mask=attn_mask, + ) + outputs.append(x) + + # (b, c, t, f) + return x + + +if __name__ == '__main__': + + # {2,2,2,2,2,2} {192,256,256,256,256,256} {512,768,768,768,768,768} + downsampling_factor = (1, 2, 4, 3) # + encoder_dim = (16, 32, 64, 64) + pos_dim = 48 # zipformer base设置 + num_heads = (4, 4, 4, 4) # "4,4,4,8,4,4" + query_head_dim = (16, ) * len(downsampling_factor) # 32 + pos_head_dim = (4, ) * len(downsampling_factor) # 4 + value_head_dim = (12, ) * len(downsampling_factor) # 12 + feedforward_dim = (32, 64, 128, 128) # + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + cnn_module_kernel = (15, ) * len(downsampling_factor) # 31,31,15,15,15,31 + causal = False + encoder_unmasked_dim = (16, ) * len(downsampling_factor) + + num_encoder_layers = (1, 1, 1, 1) + warmup_batches = 4000.0 + + net = Zipformer2DualPathEncoder( + output_downsampling_factor=1, + downsampling_factor=downsampling_factor, + num_encoder_layers=num_encoder_layers, + encoder_dim=encoder_dim, + encoder_unmasked_dim=encoder_unmasked_dim, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + value_head_dim=value_head_dim, + pos_dim=pos_dim, + num_heads=num_heads, + feedforward_dim=feedforward_dim, + cnn_module_kernel=cnn_module_kernel, + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=warmup_batches, + causal=causal, + ) + + # net = DownsampledZipformer2Encoder( + # None, 128, 2, 0. 
+ # ) + # x = torch.randn((101, 2, 128)) + b = 4 + t = 321 + f = 101 + c = 64 + + # x = torch.randn((101, 2, 128)) + x = torch.randn((b, c, t, f)) + + x = net(x) + print(x.size()) diff --git a/modelscope/models/audio/ans/zipenhancer_layers/zipformer.py b/modelscope/models/audio/ans/zipenhancer_layers/zipformer.py new file mode 100644 index 00000000..2ac5eb0d --- /dev/null +++ b/modelscope/models/audio/ans/zipenhancer_layers/zipformer.py @@ -0,0 +1,1084 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# Copyright (c) 2024 Alibaba, Inc. and its affiliates. +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor, nn + +from .scaling import \ + Identity # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +from .scaling import \ + ScaledLinear # not as in other dirs.. just scales down initial parameter values. +from .scaling import (ActivationDropoutAndLinear, BiasNorm, + ChunkCausalDepthwiseConv1d, Dropout2, FloatLike, + ScheduledFloat, limit_param_value, + penalize_abs_values_gt, softmax) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). 
+ nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (required). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module (default=31). + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), + conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), + (16000, 0.0), + default=0), + const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), + (4000.0, 0.025), + default=0), + ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), + (50000.0, 0.0)), + ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), + (50000.0, 0.0)), + bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), + (4000.0, 0.02), + default=0), + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. 
+ self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. + self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.feed_forward1 = FeedforwardModule(embed_dim, + (feedforward_dim * 3) // 4, + dropout) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, + dropout) + + self.feed_forward3 = FeedforwardModule(embed_dim, + (feedforward_dim * 5) // 4, + dropout) + + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4) + + self.conv_module1 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal) + + self.conv_module2 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal) + + # TODO: remove it + self.bypass_scale = nn.Parameter(torch.full((embed_dim, ), 0.5)) + + self.norm = BiasNorm(embed_dim) + + self.balancer1 = Identity() + self.balancer_na = Identity() + self.balancer_ff2 = Identity() + self.balancer_ff3 = Identity() + self.whiten = Identity() + self.balancer2 = Identity() + + def get_sequence_dropout_mask(self, x: Tensor, + dropout_rate: float) -> Optional[Tensor]: + if (dropout_rate == 0.0 or not self.training + or torch.jit.is_scripting() or torch.jit.is_tracing()): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to( + x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to 
x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. 
+ + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + if torch.jit.is_scripting() or torch.jit.is_tracing(): + attention_skip_rate = 0.0 + else: + attention_skip_rate = ( + float(self.attention_skip_rate) if self.training else 0.0) + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + self_attn_dropout_mask = self.get_sequence_dropout_mask( + src, attention_skip_rate) + + selected_attn_weights = attn_weights[0:1] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < float( + self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. + # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) + + na = self.balancer_na( + self.nonlin_attention(src, selected_attn_weights)) + + src = src + ( + na if self_attn_dropout_mask is None else na + * self_attn_dropout_mask) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + ( + self_attn if self_attn_dropout_mask is None else self_attn + * self_attn_dropout_mask) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float( + self.conv_skip_rate) if self.training else 0.0 + + src = src + self.sequence_dropout( + self.conv_module1( + src, + chunk_size=chunk_size, + src_key_padding_mask=src_key_padding_mask), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or 
torch.jit.is_tracing(): + ff2_skip_rate = 0.0 + else: + ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + ( + self_attn if self_attn_dropout_mask is None else self_attn + * self_attn_dropout_mask) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float( + self.conv_skip_rate) if self.training else 0.0 + + src = src + self.sequence_dropout( + self.conv_module2( + src, + chunk_size=chunk_size, + src_key_padding_mask=src_key_padding_mask), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff3_skip_rate = 0.0 + else: + ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. 
+ """ + + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), + default=0), + scale_max: FloatLike = 1.0, + ): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim, ), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 corresponds to bypassing + # this module. + if torch.jit.is_scripting() or torch.jit.is_tracing( + ) or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, + min=float(self.scale_min), + max=float(self.scale_max)) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand( + (batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. 
+ straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + _rand_tensor = torch.rand((batch_size, 1), device=ans.device) + mask = (_rand_tensor < straight_through_rate) + ans = torch.maximum(ans, mask.to(ans.dtype)) + return ans + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + # bypass_scale = self._get_bypass_scale(src.shape[1]) + bypass_scale = self._get_bypass_scale(src.shape[-2]) + return src_orig + (src - src_orig) * bypass_scale + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__(self, channels: int, downsample: int, dropout: FloatLike): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + self.dropout = copy.deepcopy(dropout) + + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + # right-pad src, repeating the last element. 
+ pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1:].expand(pad, src.shape[1], + src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. + """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, + num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the Fourier transform of that fixed interval. 
The
+    atan() function would compress the "long tails" too small,
+    making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
+    function to compress large offsets to a smaller range before applying atan().
+    Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
+    as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim)
+
+
+    Args:
+        embed_dim: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length: just a heuristic for initialization.
+        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+           less weight to small differences of offset near the origin.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        dropout_rate: FloatLike,
+        max_len: int = 1000,
+        length_factor: float = 1.0,
+    ) -> None:
+        """Construct a CompactRelPositionalEncoding object."""
+        super(CompactRelPositionalEncoding, self).__init__()
+        self.embed_dim = embed_dim
+        assert embed_dim % 2 == 0, embed_dim
+        self.dropout = Dropout2(dropout_rate)
+        self.pe = None
+        assert length_factor >= 1.0, length_factor
+        self.extend_pe(torch.tensor(0.0).expand(max_len))
+
+    def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
+        """Reset the positional encodings."""
+        T = x.size(0) + left_context_len
+
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(0) >= T * 2 - 1:
+                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+
+        # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
+        x = torch.arange(
+            -(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
+
+        freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
+
+        # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution
+        # for small time offsets but less resolution for large time offsets.
+ compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + _tmp_tensor = ((x.abs() + compression_length).log() + - math.log(compression_length)) + x_compressed = (compression_length * x.sign() * _tmp_tensor) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed + / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). 
+ """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[self.pe.size(0) // 2 - x_size_left + + 1:self.pe.size(0) // 2 # noqa E203 + + x.size(0), :, ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), + (4000.0, 0.0)), + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. 
+
+        key_head_dim = query_head_dim
+        in_proj_dim = (query_head_dim + key_head_dim
+                       + pos_head_dim) * num_heads
+
+        # the initial_scale is supposed to take over the "scaling" factor of
+        # head_dim ** -0.5 that has been used in previous forms of attention,
+        # dividing it between the query and key. Note: this module is intended
+        # to be used with the ScaledAdam optimizer; with most other optimizers,
+        # it would be necessary to apply the scaling factor in the forward function.
+        self.in_proj = ScaledLinear(
+            embed_dim,
+            in_proj_dim,
+            bias=True,
+            initial_scale=query_head_dim**-0.25)
+
+        self.whiten_keys = Identity()
+        self.balance_keys = Identity()
+
+        # linear transformation for positional encoding.
+        self.linear_pos = ScaledLinear(
+            pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05)
+
+        # the following are for diagnostics only, see --print-diagnostics option
+        self.copy_pos_query = Identity()
+        self.copy_query = Identity()
+
+    def forward(
+        self,
+        x: Tensor,
+        pos_emb: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+        attn_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""
+        Args:
+            x: input of shape (seq_len, batch_size, embed_dim)
+            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+            key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
+               are True in this mask will be ignored as sources in the attention weighting.
+            attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
+               interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
+               saying which positions are allowed to attend to which other positions.
+        Returns:
+           a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
+           interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
+ """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim:2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim:] + assert p.shape[-1] == num_heads * pos_head_dim, (p.shape[-1], + num_heads, + pos_head_dim) + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys( + self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + # print(f"MHSAW {q.shape} {k.shape}") + attn_scores = torch.matmul(q, k) + + use_pos_scores = False + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # We can't put random.random() in the same line + use_pos_scores = True + elif not self.training or random.random() >= float( + self.pos_emb_skip_rate): + use_pos_scores = True + + if use_pos_scores: + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, + pos_head_dim).permute(2, 0, 3, 1) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] 
+ # print(f"MHSAW pos {p.shape} {pos_emb.shape}") + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, + seq_len) + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + # print(attn_scores.shape, pos_scores.shape) + if self.training: + attn_scores = attn_scores + pos_scores + else: + # inplace operator important + # attn_scores.add_(pos_scores) + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. 
The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training) + + return attn_weights + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. 
+ + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear( + embed_dim, num_heads * value_head_dim, bias=True) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, + embed_dim, + bias=True, + initial_scale=0.05) + + self.whiten = Identity() + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj( + x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + # print(f"SelfAttetion pos {attn_weights.shape} {x.shape}") + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, + 3).contiguous().view(seq_len, batch_size, + num_heads * value_head_dim)) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. 
+ x = self.out_proj(x) + x = self.whiten(x) + + return x + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, + dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Identity() + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation='SwooshL', + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.1, + ) + + self.out_whiten = Identity() + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + self.balancer = Identity() + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05) + + self.whiten1 = Identity() + self.whiten2 = Identity() + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. 
+ Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + # print(f"nonlinattion {attn_weights.shape} {x.shape}") + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). 
+ + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + self.balancer1 = Identity() + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d( + channels=bottleneck_dim, kernel_size=kernel_size) + if causal else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + )) + + self.balancer2 = Identity() + + self.whiten = Identity() + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation='SwooshR', + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). 
+ + if src_key_padding_mask is not None: + x = x.masked_fill( + src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if (not torch.jit.is_scripting() and not torch.jit.is_tracing() + and chunk_size >= 0): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), 'Must initialize model with causal=True if you use chunk_size' + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + # with record_function("depthwise_conv"): + x = self.depthwise_conv(x) + # pass + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py index 3719689c..0b03beca 100644 --- a/modelscope/pipelines/audio/ans_pipeline.py +++ b/modelscope/pipelines/audio/ans_pipeline.py @@ -122,3 +122,127 @@ class ANSPipeline(Pipeline): np.frombuffer(inputs[OutputKeys.OUTPUT_PCM], dtype=np.int16), self.SAMPLE_RATE) return inputs + + +@PIPELINES.register_module( + Tasks.acoustic_noise_suppression, + module_name=Pipelines.speech_zipenhancer_ans_multiloss_16k_base) +class ANSZipEnhancerPipeline(Pipeline): + r"""ANS (Acoustic Noise Suppression) Inference Pipeline . + + When invoke the class with pipeline.__call__(), it accept only one parameter: + inputs(str): the path of wav file + """ + SAMPLE_RATE = 16000 + + def __init__(self, model, **kwargs): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. 
+ """ + super().__init__(model=model, **kwargs) + self.model.eval() + self.stream_mode = kwargs.get('stream_mode', False) + + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: + if self.stream_mode: + raise TypeError('This model does not support stream mode!') + if isinstance(inputs, bytes): + data1, fs = sf.read(io.BytesIO(inputs)) + elif isinstance(inputs, str): + # file_bytes = File.read(inputs) + # data1, fs = sf.read(io.BytesIO(file_bytes)) + data1, fs = sf.read(inputs) + else: + raise TypeError(f'Unsupported type {type(inputs)}.') + if len(data1.shape) > 1: + data1 = data1[:, 0] + if fs != self.SAMPLE_RATE: + data1 = librosa.resample( + data1, orig_sr=fs, target_sr=self.SAMPLE_RATE) + data1 = audio_norm(data1) + data = data1.astype(np.float32) + inputs = np.reshape(data, [1, data.shape[0]]) + return {'ndarray': inputs, 'nsamples': data.shape[0]} + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + ndarray = inputs['ndarray'] + if isinstance(ndarray, torch.Tensor): + ndarray = ndarray.cpu().numpy() + nsamples = inputs['nsamples'] + decode_do_segement = False + window = 16000 * 2 # 2s + stride = int(window * 0.75) + print('inputs:{}'.format(ndarray.shape)) + b, t = ndarray.shape # size() + if t > window * 5: # 10s + decode_do_segement = True + print('decode_do_segement') + + if t < window: + ndarray = np.concatenate( + [ndarray, np.zeros((ndarray.shape[0], window - t))], 1) + elif decode_do_segement: + if t < window + stride: + padding = window + stride - t + print('padding: {}'.format(padding)) + ndarray = np.concatenate( + [ndarray, np.zeros((ndarray.shape[0], padding))], 1) + else: + if (t - window) % stride != 0: + # padding = t - (t - window) // stride * stride + padding = ( + (t - window) // stride + 1) * stride + window - t + print('padding: {}'.format(padding)) + ndarray = np.concatenate( + [ndarray, + np.zeros((ndarray.shape[0], padding))], 1) + # else: + # if (t - window) % stride != 0: + # 
padding = t - (t - window) // stride * stride + # print('padding: {}'.format(padding)) + # ndarray = np.concatenate( + # [ndarray, np.zeros((ndarray.shape[0], padding))], 1) + print('inputs after padding:{}'.format(ndarray.shape)) + with torch.no_grad(): + ndarray = torch.from_numpy(np.float32(ndarray)).to(self.device) + b, t = ndarray.shape + if decode_do_segement: + outputs = np.zeros(t) + give_up_length = (window - stride) // 2 + current_idx = 0 + while current_idx + window <= t: + # print('current_idx: {}'.format(current_idx)) + print( + '\rcurrent_idx: {} {:.2f}%'.format( + current_idx, current_idx * 100 / t), + end='') + tmp_input = dict(noisy=ndarray[:, current_idx:current_idx + + window]) + tmp_output = self.model( + tmp_input, )['wav_l2'][0].cpu().numpy() + end_index = current_idx + window - give_up_length + if current_idx == 0: + outputs[current_idx: + end_index] = tmp_output[:-give_up_length] + else: + outputs[current_idx + + give_up_length:end_index] = tmp_output[ + give_up_length:-give_up_length] + current_idx += stride + print('\rcurrent_idx: {} {:.2f}%'.format(current_idx, 100)) + else: + outputs = self.model( + dict(noisy=ndarray))['wav_l2'][0].cpu().numpy() + outputs = (outputs[:nsamples] * 32768).astype(np.int16).tobytes() + return {OutputKeys.OUTPUT_PCM: outputs} + + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + if 'output_path' in kwargs.keys(): + sf.write( + kwargs['output_path'], + np.frombuffer(inputs[OutputKeys.OUTPUT_PCM], dtype=np.int16), + self.SAMPLE_RATE) + return inputs diff --git a/requirements/datasets.txt b/requirements/datasets.txt index b290664e..d20154e1 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -1,6 +1,6 @@ addict attrs -datasets>=3.0.0 +datasets>=3.0.0,<=3.0.1 einops oss2 Pillow diff --git a/requirements/framework.txt b/requirements/framework.txt index dabab41f..9aa4c045 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -1,6 +1,6 @@ addict 
attrs -datasets>=3.0.0 +datasets>=3.0.0,<=3.0.1 einops oss2 Pillow diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py index 6130ea31..b853f419 100644 --- a/tests/pipelines/test_speech_signal_process.py +++ b/tests/pipelines/test_speech_signal_process.py @@ -150,6 +150,36 @@ class SpeechSignalProcessTest(unittest.TestCase): w.write(pcm) audio = f.read(block_size) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_zipenhancer_ans(self): + model_id = 'damo/speech_zipenhancer_ans_multiloss_16k_base' + ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) + output_path = os.path.abspath('output.wav') + ans(os.path.join(os.getcwd(), NOISE_SPEECH_FILE), + output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_zipenhancer_ans_url(self): + model_id = 'damo/speech_zipenhancer_ans_multiloss_16k_base' + ans = pipeline(Tasks.acoustic_noise_suppression, model=model_id) + output_path = os.path.abspath('output.wav') + ans(NOISE_SPEECH_URL, output_path=output_path) + print(f'Processed audio saved to {output_path}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_zipenhancer_ans_bytes(self): + model_id = 'damo/speech_zipenhancer_ans_multiloss_16k_base' + ans = pipeline( + Tasks.acoustic_noise_suppression, + model=model_id, + pipeline_name=Pipelines.speech_zipenhancer_ans_multiloss_16k_base) + output_path = os.path.abspath('output.wav') + with open(os.path.join(os.getcwd(), NOISE_SPEECH_FILE), 'rb') as f: + data = f.read() + ans(data, output_path=output_path) + print(f'Processed audio saved to {output_path}') + if __name__ == '__main__': unittest.main()