[to #42322933] add asr inference with pytorch (espnet framework)

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9273537
This commit is contained in:
shichen.fsc
2022-07-11 16:48:47 +08:00
parent 59495d375f
commit d7c780069f
27 changed files with 8537 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:87bde7feb3b40d75dec27e5824dd1077911f867e3f125c4bf603ec0af954d4db
size 77864

View File

@@ -23,6 +23,7 @@ class Models(object):
sambert_hifigan = 'sambert-hifigan'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
kws_kwsbp = 'kws-kwsbp'
generic_asr = 'generic-asr'
# multi-modal models
ofa = 'ofa'
@@ -68,6 +69,7 @@ class Pipelines(object):
speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
kws_kwsbp = 'kws-kwsbp'
asr_inference = 'asr-inference'
# multi-modal tasks
image_caption = 'image-captioning'
@@ -120,6 +122,7 @@ class Preprocessors(object):
linear_aec_fbank = 'linear-aec-fbank'
text_to_tacotron_symbols = 'text-to-tacotron-symbols'
wav_to_lists = 'wav-to-lists'
wav_to_scp = 'wav-to-scp'
# multi-modal
ofa_image_caption = 'ofa-image-caption'
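For reference, the new metainfo entries are plain string keys consumed by the ModelScope registries; a quick sanity-check sketch (values copied from the additions above):

from modelscope.metainfo import Models, Pipelines, Preprocessors

# The registry keys introduced by this commit and their string values.
assert Models.generic_asr == 'generic-asr'
assert Pipelines.asr_inference == 'asr-inference'
assert Preprocessors.wav_to_scp == 'wav-to-scp'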

View File

@@ -5,6 +5,7 @@ from .base import Model
from .builder import MODELS, build_model
try:
from .audio.asr import GenericAutomaticSpeechRecognition
from .audio.tts import SambertHifigan
from .audio.kws import GenericKeyWordSpotting
from .audio.ans.frcrn import FRCRNModel

View File

@@ -0,0 +1 @@
from .generic_automatic_speech_recognition import * # noqa F403

View File

@@ -0,0 +1,39 @@
import os
from typing import Any, Dict
from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
__all__ = ['GenericAutomaticSpeechRecognition']
@MODELS.register_module(
Tasks.auto_speech_recognition, module_name=Models.generic_asr)
class GenericAutomaticSpeechRecognition(Model):
def __init__(self, model_dir: str, am_model_name: str,
model_config: Dict[str, Any], *args, **kwargs):
"""initialize the info of model.
Args:
model_dir (str): the model path.
am_model_name (str): the AM model name from configuration.json
"""
self.model_cfg = {
# the recognition model dir path
'model_workspace': model_dir,
# the am model name
'am_model': am_model_name,
# the am model file path
'am_model_path': os.path.join(model_dir, am_model_name),
# the recognition model config dict
'model_config': model_config
}
def forward(self) -> Dict[str, Any]:
"""return the info of the model
"""
return self.model_cfg
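For context, a minimal usage sketch of this wrapper (the directory, file name and config dict are placeholders, not values from this diff); the class only packages paths and configuration into a dict for the pipeline to consume:

from modelscope.models.audio.asr import GenericAutomaticSpeechRecognition

# Placeholder paths/config; the wrapper performs no inference itself.
model = GenericAutomaticSpeechRecognition(
    model_dir='/path/to/asr_model',
    am_model_name='model.pb',
    model_config={'type': 'pytorch'})
cfg = model.forward()
print(cfg['am_model_path'])  # /path/to/asr_model/model.pb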

View File

@@ -3,6 +3,7 @@
from modelscope.utils.error import TENSORFLOW_IMPORT_ERROR
try:
from .asr.asr_inference_pipeline import AutomaticSpeechRecognitionPipeline
from .kws_kwsbp_pipeline import * # noqa F403
from .linear_aec_pipeline import LinearAECPipeline
except ModuleNotFoundError as e:
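Downstream, the new pipeline is normally reached through the standard ModelScope pipeline() factory; a hedged sketch (the model id and audio path are placeholders, not part of this change):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Placeholder model id and wav path; only the task constant comes from this diff.
asr = pipeline(Tasks.auto_speech_recognition, model='damo/placeholder-asr-model')
print(asr('/path/to/example.wav'))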

View File

@@ -0,0 +1,12 @@
import nltk
try:
nltk.data.find('taggers/averaged_perceptron_tagger')
except LookupError:
nltk.download(
'averaged_perceptron_tagger', halt_on_error=False, raise_on_error=True)
try:
nltk.data.find('corpora/cmudict')
except LookupError:
nltk.download('cmudict', halt_on_error=False, raise_on_error=True)

View File

@@ -0,0 +1,690 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
import argparse
import logging
import sys
import time
from pathlib import Path
from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from espnet2.asr.frontend.default import DefaultFrontend
from espnet2.asr.transducer.beam_search_transducer import BeamSearchTransducer
from espnet2.asr.transducer.beam_search_transducer import \
ExtendedHypothesis as ExtTransHypothesis # noqa: H301
from espnet2.asr.transducer.beam_search_transducer import \
Hypothesis as TransHypothesis
from espnet2.fileio.datadir_writer import DatadirWriter
from espnet2.tasks.lm import LMTask
from espnet2.text.build_tokenizer import build_tokenizer
from espnet2.text.token_id_converter import TokenIDConverter
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
from espnet2.utils import config_argparse
from espnet2.utils.types import str2bool, str2triple_str, str_or_none
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
from espnet.nets.beam_search import BeamSearch, Hypothesis
from espnet.nets.pytorch_backend.transformer.subsampling import \
TooShortUttError
from espnet.nets.scorer_interface import BatchScorerInterface
from espnet.nets.scorers.ctc import CTCPrefixScorer
from espnet.nets.scorers.length_bonus import LengthBonus
from espnet.utils.cli_utils import get_commandline_args
from typeguard import check_argument_types, check_return_type
from .espnet.tasks.asr import ASRTaskNAR as ASRTask
class Speech2Text:
def __init__(self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
transducer_conf: dict = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
ngram_scorer: str = 'full',
ngram_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = 'cpu',
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
batch_size: int = 1,
dtype: str = 'float32',
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
streaming: bool = False,
frontend_conf: dict = None):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, device)
if asr_model.frontend is None and frontend_conf is not None:
frontend = DefaultFrontend(**frontend_conf)
asr_model.frontend = frontend
asr_model.to(dtype=getattr(torch, dtype)).eval()
decoder = asr_model.decoder
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
decoder=decoder,
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device)
scorers['lm'] = lm.lm
# 3. Build ngram model
if ngram_file is not None:
if ngram_scorer == 'full':
from espnet.nets.scorers.ngram import NgramFullScorer
ngram = NgramFullScorer(ngram_file, token_list)
else:
from espnet.nets.scorers.ngram import NgramPartScorer
ngram = NgramPartScorer(ngram_file, token_list)
else:
ngram = None
scorers['ngram'] = ngram
# 4. Build BeamSearch object
if asr_model.use_transducer_decoder:
beam_search_transducer = BeamSearchTransducer(
decoder=asr_model.decoder,
joint_network=asr_model.joint_network,
beam_size=beam_size,
lm=scorers['lm'] if 'lm' in scorers else None,
lm_weight=lm_weight,
**transducer_conf,
)
beam_search = None
else:
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else 'full',
)
# TODO(karita): make all scorers batchfied
if batch_size == 1:
non_batch = [
k for k, v in beam_search.full_scorers.items()
if not isinstance(v, BatchScorerInterface)
]
if len(non_batch) == 0:
if streaming:
beam_search.__class__ = BatchBeamSearchOnlineSim
beam_search.set_streaming_config(asr_train_config)
logging.info(
'BatchBeamSearchOnlineSim implementation is selected.'
)
else:
beam_search.__class__ = BatchBeamSearch
else:
logging.warning(
f'As non-batch scorers {non_batch} are found, '
f'fall back to non-batch implementation.')
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
for scorer in scorers.values():
if isinstance(scorer, torch.nn.Module):
scorer.to(
device=device, dtype=getattr(torch, dtype)).eval()
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == 'bpe':
if bpemodel is not None:
tokenizer = build_tokenizer(
token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.beam_search = beam_search
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
@torch.no_grad()
def __call__(self, speech: Union[torch.Tensor, np.ndarray]):
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
# data: (Nsamples,) -> (1, Nsamples)
speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
# lengths: (1,)
lengths = speech.new_full([1],
dtype=torch.long,
fill_value=speech.size(1))
batch = {'speech': speech, 'speech_lengths': lengths}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
enc, enc_len = self.asr_model.encode(**batch)
if isinstance(enc, tuple):
enc = enc[0]
assert len(enc) == 1, len(enc)
predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
pre_acoustic_embeds, pre_token_length = predictor_outs[
0], predictor_outs[1]
pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)],
device=pre_acoustic_embeds.device)
decoder_outs = self.asr_model.cal_decoder_with_predictor(
enc, enc_len, pre_acoustic_embeds, pre_token_length)
decoder_out = decoder_outs[0]
yseq = decoder_out.argmax(dim=-1)
score = decoder_out.max(dim=-1)[0]
score = torch.sum(score, dim=-1)
# wrap the best path with sos/eos so the result handling below (which strips them) applies
yseq = torch.tensor(
[self.asr_model.sos] + yseq.tolist()[0] + [self.asr_model.eos],
device=yseq.device)
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis, TransHypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = None if self.asr_model.use_transducer_decoder else -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp, speech.size(1)))
return results
@staticmethod
def from_pretrained(
model_tag: Optional[str] = None,
**kwargs: Optional[Any],
):
"""Build Speech2Text instance from the pretrained model.
Args:
model_tag (Optional[str]): Model tag of the pretrained models.
Currently, the tags of espnet_model_zoo are supported.
Returns:
Speech2Text: Speech2Text instance.
"""
if model_tag is not None:
try:
from espnet_model_zoo.downloader import ModelDownloader
except ImportError:
logging.error(
'`espnet_model_zoo` is not installed. '
'Please install via `pip install -U espnet_model_zoo`.')
raise
d = ModelDownloader()
kwargs.update(**d.download_and_unpack(model_tag))
return Speech2Text(**kwargs)
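As a usage sketch for the class above (config, checkpoint and wav paths are placeholders, and soundfile is just one way to obtain the 1-D waveform array):

import soundfile

# Placeholder paths; requires an ESPnet-style training config / checkpoint pair.
speech2text = Speech2Text(
    asr_train_config='/path/to/asr_config.yaml',
    asr_model_file='/path/to/model.pth',
    device='cpu')
waveform, _ = soundfile.read('/path/to/example.wav')
text, token, token_int, hyp, nsamples = speech2text(waveform)[0]
print(text)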
def inference(
output_dir: str,
maxlenratio: float,
minlenratio: float,
batch_size: int,
dtype: str,
beam_size: int,
ngpu: int,
seed: int,
ctc_weight: float,
lm_weight: float,
ngram_weight: float,
penalty: float,
nbest: int,
num_workers: int,
log_level: Union[int, str],
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
key_file: Optional[str],
asr_train_config: Optional[str],
asr_model_file: Optional[str],
lm_train_config: Optional[str],
lm_file: Optional[str],
word_lm_train_config: Optional[str],
word_lm_file: Optional[str],
ngram_file: Optional[str],
model_tag: Optional[str],
token_type: Optional[str],
bpemodel: Optional[str],
allow_variable_data_keys: bool,
transducer_conf: Optional[dict],
streaming: bool,
frontend_conf: dict = None,
):
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError('batch decoding is not implemented')
if word_lm_train_config is not None:
raise NotImplementedError('Word LM is not implemented')
if ngpu > 1:
raise NotImplementedError('only single GPU decoding is supported')
logging.basicConfig(
level=log_level,
format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
)
if ngpu >= 1:
device = 'cuda'
else:
device = 'cpu'
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
transducer_conf=transducer_conf,
lm_train_config=lm_train_config,
lm_file=lm_file,
ngram_file=ngram_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
streaming=streaming,
frontend_conf=frontend_conf,
)
speech2text = Speech2Text.from_pretrained(
model_tag=model_tag,
**speech2text_kwargs,
)
# 3. Build data-iterator
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args,
False),
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
forward_time_total = 0.0
length_total = 0.0
# 7. Start the decoding loop over the dataset
# FIXME(kamo): The output format should be discussed.
with DatadirWriter(output_dir) as writer:
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f'{len(keys)} != {_bs}'
batch = {
k: v[0]
for k, v in batch.items() if not k.endswith('_lengths')
}
# N-best list of (text, token, token_int, hyp_object)
try:
time_beg = time.time()
results = speech2text(**batch)
time_end = time.time()
forward_time = time_end - time_beg
length = results[0][-1]
results = [results[0][:-1]]
forward_time_total += forward_time
length_total += length
except TooShortUttError as e:
logging.warning(f'Utterance {keys} {e}')
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[' ', ['<space>'], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]
for n, (text, token, token_int,
hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
ibest_writer = writer[f'{n}best_recog']
# Write the result to each file
ibest_writer['token'][key] = ' '.join(token)
ibest_writer['token_int'][key] = ' '.join(map(str, token_int))
ibest_writer['score'][key] = str(hyp.score)
if text is not None:
ibest_writer['text'][key] = text
logging.info(
'decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}'
.format(length_total, forward_time_total,
100 * forward_time_total / length_total))
def get_parser():
parser = config_argparse.ArgumentParser(
description='ASR Decoding',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
'--log_level',
type=lambda x: x.upper(),
default='INFO',
choices=('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'),
help='The verbose level of logging',
)
parser.add_argument('--output_dir', type=str, required=True)
parser.add_argument(
'--ngpu',
type=int,
default=0,
help='The number of gpus. 0 indicates CPU mode',
)
parser.add_argument('--seed', type=int, default=0, help='Random seed')
parser.add_argument(
'--dtype',
default='float32',
choices=['float16', 'float32', 'float64'],
help='Data type',
)
parser.add_argument(
'--num_workers',
type=int,
default=1,
help='The number of workers used for DataLoader',
)
group = parser.add_argument_group('Input data related')
group.add_argument(
'--data_path_and_name_and_type',
type=str2triple_str,
required=True,
action='append',
)
group.add_argument('--key_file', type=str_or_none)
group.add_argument(
'--allow_variable_data_keys', type=str2bool, default=False)
group = parser.add_argument_group('The model configuration related')
group.add_argument(
'--asr_train_config',
type=str,
help='ASR training configuration',
)
group.add_argument(
'--asr_model_file',
type=str,
help='ASR model parameter file',
)
group.add_argument(
'--lm_train_config',
type=str,
help='LM training configuration',
)
group.add_argument(
'--lm_file',
type=str,
help='LM parameter file',
)
group.add_argument(
'--word_lm_train_config',
type=str,
help='Word LM training configuration',
)
group.add_argument(
'--word_lm_file',
type=str,
help='Word LM parameter file',
)
group.add_argument(
'--ngram_file',
type=str,
help='N-gram parameter file',
)
group.add_argument(
'--model_tag',
type=str,
help='Pretrained model tag. If this option is specified, *_train_config and '
'*_file will be overwritten',
)
group = parser.add_argument_group('Beam-search related')
group.add_argument(
'--batch_size',
type=int,
default=1,
help='The batch size for inference',
)
group.add_argument(
'--nbest', type=int, default=1, help='Output N-best hypotheses')
group.add_argument('--beam_size', type=int, default=20, help='Beam size')
group.add_argument(
'--penalty', type=float, default=0.0, help='Insertion penalty')
group.add_argument(
'--maxlenratio',
type=float,
default=0.0,
help='Input length ratio to obtain max output length. '
'If maxlenratio=0.0 (default), it uses an end-detect '
'function to automatically find maximum hypothesis lengths. '
'If maxlenratio<0.0, its absolute value is interpreted '
'as a constant max output length',
)
group.add_argument(
'--minlenratio',
type=float,
default=0.0,
help='Input length ratio to obtain min output length',
)
group.add_argument(
'--ctc_weight',
type=float,
default=0.5,
help='CTC weight in joint decoding',
)
group.add_argument(
'--lm_weight', type=float, default=1.0, help='RNNLM weight')
group.add_argument(
'--ngram_weight', type=float, default=0.9, help='ngram weight')
group.add_argument('--streaming', type=str2bool, default=False)
group.add_argument(
'--frontend_conf',
default=None,
help='Frontend configuration applied when the ASR model has no built-in frontend',
)
group = parser.add_argument_group('Text converter related')
group.add_argument(
'--token_type',
type=str_or_none,
default=None,
choices=['char', 'bpe', None],
help='The token type for ASR model. '
'If not given, it is inferred from the training args',
)
group.add_argument(
'--bpemodel',
type=str_or_none,
default=None,
help='The model path of sentencepiece. '
'If not given, it is inferred from the training args',
)
group.add_argument(
'--transducer_conf',
default=None,
help='The keyword arguments for transducer beam search.',
)
return parser
def asr_inference(
output_dir: str,
maxlenratio: float,
minlenratio: float,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
asr_train_config: Optional[str],
asr_model_file: Optional[str],
nbest: int = 1,
num_workers: int = 1,
log_level: Union[int, str] = 'INFO',
batch_size: int = 1,
dtype: str = 'float32',
seed: int = 0,
key_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
word_lm_file: Optional[str] = None,
ngram_file: Optional[str] = None,
ngram_weight: float = 0.9,
model_tag: Optional[str] = None,
token_type: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
transducer_conf: Optional[dict] = None,
streaming: bool = False,
frontend_conf: dict = None,
):
inference(
output_dir=output_dir,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
dtype=dtype,
beam_size=beam_size,
ngpu=ngpu,
seed=seed,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
num_workers=num_workers,
log_level=log_level,
data_path_and_name_and_type=data_path_and_name_and_type,
key_file=key_file,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
word_lm_train_config=word_lm_train_config,
word_lm_file=word_lm_file,
ngram_file=ngram_file,
model_tag=model_tag,
token_type=token_type,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
transducer_conf=transducer_conf,
streaming=streaming,
frontend_conf=frontend_conf)
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop('config', None)
inference(**kwargs)
if __name__ == '__main__':
main()
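For completeness, a hedged example of driving the functional entry point defined above (all paths are placeholders; wav.scp follows the Kaldi 'utt_id /path/to/utt.wav' convention expected by the ESPnet data loader):

# Results land under /tmp/asr_out/1best_recog/{text,token,token_int,score}.
asr_inference(
    output_dir='/tmp/asr_out',
    maxlenratio=0.0,
    minlenratio=0.0,
    beam_size=20,
    ngpu=0,
    ctc_weight=0.5,
    lm_weight=0.0,
    penalty=0.0,
    data_path_and_name_and_type=[('/path/to/wav.scp', 'speech', 'sound')],
    asr_train_config='/path/to/asr_config.yaml',
    asr_model_file='/path/to/model.pth')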

View File

@@ -0,0 +1,193 @@
import os
from typing import Any, Dict, List
import numpy as np
def type_checking(wav_path: str,
recog_type: str = None,
audio_format: str = None,
workspace: str = None):
assert os.path.exists(wav_path), f'wav_path:{wav_path} does not exist'
r_recog_type = recog_type
r_audio_format = audio_format
r_workspace = workspace
r_wav_path = wav_path
if r_workspace is None or len(r_workspace) == 0:
r_workspace = os.path.join(os.getcwd(), '.tmp')
if r_recog_type is None:
if os.path.isfile(wav_path):
if wav_path.endswith('.wav') or wav_path.endswith('.WAV'):
r_recog_type = 'wav'
r_audio_format = 'wav'
elif os.path.isdir(wav_path):
dir_name = os.path.basename(wav_path)
if 'test' in dir_name:
r_recog_type = 'test'
elif 'dev' in dir_name:
r_recog_type = 'dev'
elif 'train' in dir_name:
r_recog_type = 'train'
if r_audio_format is None:
if find_file_by_ends(wav_path, '.ark'):
r_audio_format = 'kaldi_ark'
elif find_file_by_ends(wav_path, '.wav') or find_file_by_ends(
wav_path, '.WAV'):
r_audio_format = 'wav'
if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
# datasets with kaldi_ark file
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
elif r_audio_format == 'wav' and r_recog_type != 'wav':
# datasets with waveform files
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))
return r_recog_type, r_audio_format, r_workspace, r_wav_path
def find_file_by_ends(dir_path: str, ends: str):
dir_files = os.listdir(dir_path)
for file in dir_files:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
if file_path.endswith(ends):
return True
elif os.path.isdir(file_path):
if find_file_by_ends(file_path, ends):
return True
return False
def compute_wer(hyp_text_path: str, ref_text_path: str) -> Dict[str, Any]:
assert os.path.exists(hyp_text_path), 'hyp_text does not exist'
assert os.path.exists(ref_text_path), 'ref_text does not exist'
rst = {
'Wrd': 0,
'Corr': 0,
'Ins': 0,
'Del': 0,
'Sub': 0,
'Snt': 0,
'Err': 0.0,
'S.Err': 0.0,
'wrong_words': 0,
'wrong_sentences': 0
}
with open(ref_text_path, 'r', encoding='utf-8') as r:
r_lines = r.readlines()
with open(hyp_text_path, 'r', encoding='utf-8') as h:
h_lines = h.readlines()
for r_line in r_lines:
r_line_item = r_line.split()
r_key = r_line_item[0]
r_sentence = r_line_item[1]
for h_line in h_lines:
# find sentence from hyp text
if r_key in h_line:
h_line_item = h_line.split()
h_sentence = h_line_item[1]
out_item = compute_wer_by_line(h_sentence, r_sentence)
rst['Wrd'] += out_item['nwords']
rst['Corr'] += out_item['cor']
rst['wrong_words'] += out_item['wrong']
rst['Ins'] += out_item['ins']
rst['Del'] += out_item['del']
rst['Sub'] += out_item['sub']
rst['Snt'] += 1
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
break
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
if rst['Snt'] > 0:
rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
return rst
def compute_wer_by_line(hyp: str, ref: str) -> Dict[str, Any]:
len_hyp = len(hyp)
len_ref = len(ref)
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
for i in range(len_hyp + 1):
cost_matrix[i][0] = i
for j in range(len_ref + 1):
cost_matrix[0][j] = j
for i in range(1, len_hyp + 1):
for j in range(1, len_ref + 1):
if hyp[i - 1] == ref[j - 1]:
cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
else:
substitution = cost_matrix[i - 1][j - 1] + 1
insertion = cost_matrix[i - 1][j] + 1
deletion = cost_matrix[i][j - 1] + 1
compare_val = [substitution, insertion, deletion]
min_val = min(compare_val)
operation_idx = compare_val.index(min_val) + 1
cost_matrix[i][j] = min_val
ops_matrix[i][j] = operation_idx
match_idx = []
i = len_hyp
j = len_ref
rst = {
'nwords': len_hyp,
'cor': 0,
'wrong': 0,
'ins': 0,
'del': 0,
'sub': 0
}
while i >= 0 or j >= 0:
i_idx = max(0, i)
j_idx = max(0, j)
if ops_matrix[i_idx][j_idx] == 0: # correct
if i - 1 >= 0 and j - 1 >= 0:
match_idx.append((j - 1, i - 1))
rst['cor'] += 1
i -= 1
j -= 1
elif ops_matrix[i_idx][j_idx] == 2: # insert
i -= 1
rst['ins'] += 1
elif ops_matrix[i_idx][j_idx] == 3: # delete
j -= 1
rst['del'] += 1
elif ops_matrix[i_idx][j_idx] == 1: # substitute
i -= 1
j -= 1
rst['sub'] += 1
if i < 0 and j >= 0:
rst['del'] += 1
elif j < 0 and i >= 0:
rst['ins'] += 1
match_idx.reverse()
wrong_cnt = cost_matrix[len_hyp][len_ref]
rst['wrong'] = wrong_cnt
return rst
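A small worked example of the edit-distance scoring above (character strings stand in for token sequences):

# 'abcd' vs the reference 'abed' differ by a single substitution.
stats = compute_wer_by_line('abcd', 'abed')
print(stats['nwords'], stats['cor'], stats['sub'], stats['wrong'])  # 4 3 1 1
# compute_wer() accumulates these per-line counts into corpus-level Err / S.Err.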

View File

@@ -0,0 +1,757 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
"""Decoder definition."""
from typing import Any, List, Sequence, Tuple
import torch
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.transformer.attention import \
MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer
from espnet.nets.pytorch_backend.transformer.dynamic_conv import \
DynamicConvolution
from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import \
DynamicConvolution2D
from espnet.nets.pytorch_backend.transformer.embedding import \
PositionalEncoding
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.lightconv import \
LightweightConvolution
from espnet.nets.pytorch_backend.transformer.lightconv2d import \
LightweightConvolution2D
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \
PositionwiseFeedForward # noqa: H301
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.scorer_interface import BatchScorerInterface
from typeguard import check_argument_types
class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
"""Base class of Transfomer decoder module.
Args:
vocab_size: output dim
encoder_output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the number of units of position-wise feed forward
num_blocks: the number of decoder blocks
dropout_rate: dropout rate
self_attention_dropout_rate: dropout rate for attention
input_layer: input layer type
use_output_layer: whether to use output layer
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before: whether to use layer_norm before the first block
concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied.
i.e. x -> x + att(x)
"""
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
input_layer: str = 'embed',
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
):
assert check_argument_types()
super().__init__()
attention_dim = encoder_output_size
if input_layer == 'embed':
self.embed = torch.nn.Sequential(
torch.nn.Embedding(vocab_size, attention_dim),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == 'linear':
self.embed = torch.nn.Sequential(
torch.nn.Linear(vocab_size, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate),
)
else:
raise ValueError(
f"only 'embed' or 'linear' is supported: {input_layer}")
self.normalize_before = normalize_before
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
else:
self.output_layer = None
# Must be set by the subclass
self.decoders = None
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
# tgt_mask: (B, 1, L)
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
# m: (1, L, L)
m = subsequent_mask(
tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
memory = hs_pad
memory_mask = (
~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
memory.device)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen),
'constant', False)
x = self.embed(tgt)
x, tgt_mask, memory, memory_mask = self.decoders(
x, tgt_mask, memory, memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
return x, olens
def forward_one_step(
self,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
memory: torch.Tensor,
cache: List[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
Args:
tgt: input token ids, int64 (batch, maxlen_out)
tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
memory: encoded memory, float32 (batch, maxlen_in, feat)
cache: cached output list of (batch, max_time_out-1, size)
Returns:
y, cache: NN output value and cache per `self.decoders`.
`y.shape` is (batch, maxlen_out, token)
"""
x = self.embed(tgt)
if cache is None:
cache = [None] * len(self.decoders)
new_cache = []
for c, decoder in zip(cache, self.decoders):
x, tgt_mask, memory, memory_mask = decoder(
x, tgt_mask, memory, None, cache=c)
new_cache.append(x)
if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.output_layer is not None:
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache
def score(self, ys, state, x):
"""Score."""
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
logp, state = self.forward_one_step(
ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
return logp.squeeze(0), state
def batch_score(self, ys: torch.Tensor, states: List[Any],
xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.decoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
torch.stack([states[b][i] for b in range(n_batch)])
for i in range(n_layers)
]
# batch decoding
ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
logp, states = self.forward_one_step(
ys, ys_mask, xs, cache=batch_state)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)]
for b in range(n_batch)]
return logp, state_list
class TransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = 'embed',
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim,
self_attention_dropout_rate),
MultiHeadedAttention(attention_heads, attention_dim,
src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class ParaformerDecoder(TransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = 'embed',
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim,
self_attention_dropout_rate),
MultiHeadedAttention(attention_heads, attention_dim,
src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
# tgt_mask: (B, 1, L)
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
# m: (1, L, L)
# m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
# tgt_mask: (B, L, L)
# tgt_mask = tgt_mask & m
memory = hs_pad
memory_mask = (
~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
memory.device)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen),
'constant', False)
# x = self.embed(tgt)
x = tgt
x, tgt_mask, memory, memory_mask = self.decoders(
x, tgt_mask, memory, memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
return x, olens
class ParaformerDecoderBertEmbed(TransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = 'embed',
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
embeds_id: int = 2,
):
assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
embeds_id,
lambda lnum: DecoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim,
self_attention_dropout_rate),
MultiHeadedAttention(attention_heads, attention_dim,
src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
if embeds_id == num_blocks:
self.decoders2 = None
else:
self.decoders2 = repeat(
num_blocks - embeds_id,
lambda lnum: DecoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim,
self_attention_dropout_rate),
MultiHeadedAttention(attention_heads, attention_dim,
src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
tgt = ys_in_pad
# tgt_mask: (B, 1, L)
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
# m: (1, L, L)
# m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
# tgt_mask: (B, L, L)
# tgt_mask = tgt_mask & m
memory = hs_pad
memory_mask = (
~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
memory.device)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen),
'constant', False)
# x = self.embed(tgt)
x = tgt
x, tgt_mask, memory, memory_mask = self.decoders(
x, tgt_mask, memory, memory_mask)
embeds_outputs = x
if self.decoders2 is not None:
x, tgt_mask, memory, memory_mask = self.decoders2(
x, tgt_mask, memory, memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
olens = tgt_mask.sum(1)
return x, olens, embeds_outputs
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = 'embed',
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: bool = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
'conv_kernel_length must have the same number of values as num_blocks: '
f'{len(conv_kernel_length)} != {num_blocks}')
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
LightweightConvolution(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(attention_heads, attention_dim,
src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = 'embed',
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: bool = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
'conv_kernel_length must have the same number of values as num_blocks: '
f'{len(conv_kernel_length)} != {num_blocks}')
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
LightweightConvolution2D(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(attention_heads, attention_dim,
src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = 'embed',
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: bool = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
'conv_kernel_length must have the same number of values as num_blocks: '
f'{len(conv_kernel_length)} != {num_blocks}')
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
DynamicConvolution(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(attention_heads, attention_dim,
src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
self,
vocab_size: int,
encoder_output_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
self_attention_dropout_rate: float = 0.0,
src_attention_dropout_rate: float = 0.0,
input_layer: str = 'embed',
use_output_layer: bool = True,
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
conv_wshare: int = 4,
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
conv_usebias: bool = False,
):
assert check_argument_types()
if len(conv_kernel_length) != num_blocks:
raise ValueError(
'conv_kernel_length must have the same number of values as num_blocks: '
f'{len(conv_kernel_length)} != {num_blocks}')
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
input_layer=input_layer,
use_output_layer=use_output_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
DynamicConvolution2D(
wshare=conv_wshare,
n_feat=attention_dim,
dropout_rate=self_attention_dropout_rate,
kernel_size=conv_kernel_length[lnum],
use_kernel_mask=True,
use_bias=conv_usebias,
),
MultiHeadedAttention(attention_heads, attention_dim,
src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
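To illustrate the difference from the autoregressive base class, a hedged sketch of the ParaformerDecoder defined above (requires espnet; all shapes are illustrative): it consumes predictor-produced acoustic embeddings of encoder width directly, so the token embedding and the causal subsequent mask are skipped.

import torch

decoder = ParaformerDecoder(vocab_size=100, encoder_output_size=256, num_blocks=2)
hs_pad = torch.randn(1, 50, 256)            # encoder memory (B, T_in, D)
hlens = torch.tensor([50])
acoustic_embeds = torch.randn(1, 10, 256)   # predictor output (B, T_out, D), not token ids
token_lens = torch.tensor([10])
logits, _ = decoder(hs_pad, hlens, acoustic_embeds, token_lens)
print(logits.shape)                         # torch.Size([1, 10, 100])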

View File

@@ -0,0 +1,710 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
"""Conformer encoder definition."""
import logging
from typing import List, Optional, Tuple, Union
import torch
from espnet2.asr.ctc import CTC
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule
from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer
from espnet.nets.pytorch_backend.nets_utils import (get_activation,
make_pad_mask)
from espnet.nets.pytorch_backend.transformer.embedding import \
LegacyRelPositionalEncoding # noqa: H301
from espnet.nets.pytorch_backend.transformer.embedding import \
PositionalEncoding # noqa: H301
from espnet.nets.pytorch_backend.transformer.embedding import \
RelPositionalEncoding # noqa: H301
from espnet.nets.pytorch_backend.transformer.embedding import \
ScaledPositionalEncoding # noqa: H301
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import (
Conv1dLinear, MultiLayeredConv1d)
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \
PositionwiseFeedForward # noqa: H301
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
Conv2dSubsampling8, TooShortUttError, check_short_utt)
from typeguard import check_argument_types
from ...nets.pytorch_backend.transformer.attention import \
LegacyRelPositionMultiHeadedAttention # noqa: H301
from ...nets.pytorch_backend.transformer.attention import \
MultiHeadedAttention # noqa: H301
from ...nets.pytorch_backend.transformer.attention import \
RelPositionMultiHeadedAttention # noqa: H301
from ...nets.pytorch_backend.transformer.attention import (
LegacyRelPositionMultiHeadedAttentionSANM,
RelPositionMultiHeadedAttentionSANM)
class ConformerEncoder(AbsEncoder):
"""Conformer encoder module.
Args:
input_size (int): Input dimension.
output_size (int): Dimension of attention.
attention_heads (int): The number of heads of multi head attention.
linear_units (int): The number of units of position-wise feed forward.
num_blocks (int): The number of decoder blocks.
dropout_rate (float): Dropout rate.
attention_dropout_rate (float): Dropout rate in attention.
positional_dropout_rate (float): Dropout rate after adding positional encoding.
input_layer (Union[str, torch.nn.Module]): Input layer type.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
If True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
If False, no additional linear will be applied. i.e. x -> x + att(x)
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
rel_pos_type (str): Whether to use the latest relative positional encoding or
the legacy one. The legacy relative positional encoding will be deprecated
in the future. More Details can be found in
https://github.com/espnet/espnet/pull/2816.
encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
encoder_attn_layer_type (str): Encoder attention layer type.
activation_type (str): Encoder activation function type.
macaron_style (bool): Whether to use macaron style for positionwise layer.
use_cnn_module (bool): Whether to use convolution module.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
cnn_module_kernel (int): Kernel size of convolution module.
padding_idx (int): Padding idx for input_layer=embed.
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = 'conv2d',
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = 'linear',
positionwise_conv_kernel_size: int = 3,
macaron_style: bool = False,
rel_pos_type: str = 'legacy',
pos_enc_layer_type: str = 'rel_pos',
selfattention_layer_type: str = 'rel_selfattn',
activation_type: str = 'swish',
use_cnn_module: bool = True,
zero_triu: bool = False,
cnn_module_kernel: int = 31,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
stochastic_depth_rate: Union[float, List[float]] = 0.0,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if rel_pos_type == 'legacy':
if pos_enc_layer_type == 'rel_pos':
pos_enc_layer_type = 'legacy_rel_pos'
if selfattention_layer_type == 'rel_selfattn':
selfattention_layer_type = 'legacy_rel_selfattn'
elif rel_pos_type == 'latest':
assert selfattention_layer_type != 'legacy_rel_selfattn'
assert pos_enc_layer_type != 'legacy_rel_pos'
else:
raise ValueError('unknown rel_pos_type: ' + rel_pos_type)
activation = get_activation(activation_type)
if pos_enc_layer_type == 'abs_pos':
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == 'scaled_abs_pos':
pos_enc_class = ScaledPositionalEncoding
elif pos_enc_layer_type == 'rel_pos':
assert selfattention_layer_type == 'rel_selfattn'
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == 'legacy_rel_pos':
assert selfattention_layer_type == 'legacy_rel_selfattn'
pos_enc_class = LegacyRelPositionalEncoding
else:
raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type)
if input_layer == 'linear':
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == 'conv2d':
self.embed = Conv2dSubsampling(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == 'conv2d2':
self.embed = Conv2dSubsampling2(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == 'conv2d6':
self.embed = Conv2dSubsampling6(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == 'conv2d8':
self.embed = Conv2dSubsampling8(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == 'embed':
self.embed = torch.nn.Sequential(
torch.nn.Embedding(
input_size, output_size, padding_idx=padding_idx),
pos_enc_class(output_size, positional_dropout_rate),
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer is None:
self.embed = torch.nn.Sequential(
pos_enc_class(output_size, positional_dropout_rate))
else:
raise ValueError('unknown input_layer: ' + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == 'linear':
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
activation,
)
elif positionwise_layer_type == 'conv1d':
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == 'conv1d-linear':
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError('Support only linear or conv1d.')
if selfattention_layer_type == 'selfattn':
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == 'legacy_rel_selfattn':
assert pos_enc_layer_type == 'legacy_rel_pos'
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == 'rel_selfattn':
assert pos_enc_layer_type == 'rel_pos'
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
zero_triu,
)
else:
raise ValueError('unknown encoder_attn_layer: '
+ selfattention_layer_type)
convolution_layer = ConvolutionModule
convolution_layer_args = (output_size, cnn_module_kernel, activation)
if isinstance(stochastic_depth_rate, float):
stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
if len(stochastic_depth_rate) != num_blocks:
raise ValueError(
f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) '
f'should be equal to num_blocks ({num_blocks})')
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(*positionwise_layer_args)
if macaron_style else None,
convolution_layer(*convolution_layer_args)
if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
stochastic_depth_rate[lnum],
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(
interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
Args:
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
ilens (torch.Tensor): Input length (#batch).
prev_states (torch.Tensor): Not to be used now.
Returns:
torch.Tensor: Output tensor (#batch, L, output_size).
torch.Tensor: Output length (#batch).
torch.Tensor: Not to be used now.
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)):
short_status, limit_size = check_short_utt(self.embed,
xs_pad.size(1))
if short_status:
raise TooShortUttError(
f'has {xs_pad.size(1)} frames and is too short for subsampling '
+ # noqa: *
f'(it needs more than {limit_size} frames), return empty results', # noqa: *
xs_pad.size(1),
limit_size) # noqa: *
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
xs_pad, masks = self.encoders(xs_pad, masks)
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
xs_pad, masks = encoder_layer(xs_pad, masks)
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
if isinstance(xs_pad, tuple):
x, pos_emb = xs_pad
x = x + self.conditioning_layer(ctc_out)
xs_pad = (x, pos_emb)
else:
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
class SANMEncoder_v2(AbsEncoder):
"""Conformer encoder module.
Args:
input_size (int): Input dimension.
output_size (int): Dimension of attention.
attention_heads (int): The number of heads of multi head attention.
linear_units (int): The number of units of position-wise feed forward.
num_blocks (int): The number of decoder blocks.
dropout_rate (float): Dropout rate.
attention_dropout_rate (float): Dropout rate in attention.
positional_dropout_rate (float): Dropout rate after adding positional encoding.
input_layer (Union[str, torch.nn.Module]): Input layer type.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
If True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
If False, no additional linear will be applied. i.e. x -> x + att(x)
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
rel_pos_type (str): Whether to use the latest relative positional encoding or
the legacy one. The legacy relative positional encoding will be deprecated
in the future. More Details can be found in
https://github.com/espnet/espnet/pull/2816.
encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
encoder_attn_layer_type (str): Encoder attention layer type.
activation_type (str): Encoder activation function type.
macaron_style (bool): Whether to use macaron style for positionwise layer.
use_cnn_module (bool): Whether to use convolution module.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
cnn_module_kernel (int): Kernel size of convolution module.
padding_idx (int): Padding idx for input_layer=embed.
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = 'conv2d',
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = 'linear',
positionwise_conv_kernel_size: int = 3,
macaron_style: bool = False,
rel_pos_type: str = 'legacy',
pos_enc_layer_type: str = 'rel_pos',
selfattention_layer_type: str = 'rel_selfattn',
activation_type: str = 'swish',
use_cnn_module: bool = False,
sanm_shfit: int = 0,
zero_triu: bool = False,
cnn_module_kernel: int = 31,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
stochastic_depth_rate: Union[float, List[float]] = 0.0,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if rel_pos_type == 'legacy':
if pos_enc_layer_type == 'rel_pos':
pos_enc_layer_type = 'legacy_rel_pos'
if selfattention_layer_type == 'rel_selfattn':
selfattention_layer_type = 'legacy_rel_selfattn'
if selfattention_layer_type == 'rel_selfattnsanm':
selfattention_layer_type = 'legacy_rel_selfattnsanm'
elif rel_pos_type == 'latest':
assert selfattention_layer_type != 'legacy_rel_selfattn'
assert pos_enc_layer_type != 'legacy_rel_pos'
else:
raise ValueError('unknown rel_pos_type: ' + rel_pos_type)
activation = get_activation(activation_type)
if pos_enc_layer_type == 'abs_pos':
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == 'scaled_abs_pos':
pos_enc_class = ScaledPositionalEncoding
elif pos_enc_layer_type == 'rel_pos':
# assert selfattention_layer_type == "rel_selfattn"
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == 'legacy_rel_pos':
# assert selfattention_layer_type == "legacy_rel_selfattn"
pos_enc_class = LegacyRelPositionalEncoding
logging.warning(
'Using legacy_rel_pos and it will be deprecated in the future.'
)
else:
raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type)
if input_layer == 'linear':
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == 'conv2d':
self.embed = Conv2dSubsampling(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == 'conv2d2':
self.embed = Conv2dSubsampling2(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == 'conv2d6':
self.embed = Conv2dSubsampling6(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == 'conv2d8':
self.embed = Conv2dSubsampling8(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == 'embed':
self.embed = torch.nn.Sequential(
torch.nn.Embedding(
input_size, output_size, padding_idx=padding_idx),
pos_enc_class(output_size, positional_dropout_rate),
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer is None:
self.embed = torch.nn.Sequential(
pos_enc_class(output_size, positional_dropout_rate))
else:
raise ValueError('unknown input_layer: ' + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == 'linear':
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
activation,
)
elif positionwise_layer_type == 'conv1d':
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == 'conv1d-linear':
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError('Support only linear or conv1d.')
if selfattention_layer_type == 'selfattn':
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == 'legacy_rel_selfattn':
assert pos_enc_layer_type == 'legacy_rel_pos'
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
logging.warning(
'Using legacy_rel_selfattn and it will be deprecated in the future.'
)
elif selfattention_layer_type == 'legacy_rel_selfattnsanm':
assert pos_enc_layer_type == 'legacy_rel_pos'
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttentionSANM
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
logging.warning(
'Using legacy_rel_selfattnsanm and it will be deprecated in the future.'
)
elif selfattention_layer_type == 'rel_selfattn':
assert pos_enc_layer_type == 'rel_pos'
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
zero_triu,
)
elif selfattention_layer_type == 'rel_selfattnsanm':
assert pos_enc_layer_type == 'rel_pos'
encoder_selfattn_layer = RelPositionMultiHeadedAttentionSANM
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
zero_triu,
cnn_module_kernel,
sanm_shfit,
)
else:
raise ValueError('unknown encoder_attn_layer: '
+ selfattention_layer_type)
convolution_layer = ConvolutionModule
convolution_layer_args = (output_size, cnn_module_kernel, activation)
if isinstance(stochastic_depth_rate, float):
stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
if len(stochastic_depth_rate) != num_blocks:
raise ValueError(
f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) '
f'should be equal to num_blocks ({num_blocks})')
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(*positionwise_layer_args)
if macaron_style else None,
convolution_layer(*convolution_layer_args)
if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
stochastic_depth_rate[lnum],
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(
interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
Args:
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
ilens (torch.Tensor): Input length (#batch).
prev_states (torch.Tensor): Not to be used now.
Returns:
torch.Tensor: Output tensor (#batch, L, output_size).
torch.Tensor: Output length (#batch).
torch.Tensor: Not to be used now.
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)):
short_status, limit_size = check_short_utt(self.embed,
xs_pad.size(1))
if short_status:
raise TooShortUttError(
f'has {xs_pad.size(1)} frames and is too short for subsampling '
+ # noqa: *
f'(it needs more than {limit_size} frames), return empty results',
xs_pad.size(1),
limit_size) # noqa: *
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
xs_pad, masks = self.encoders(xs_pad, masks)
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
xs_pad, masks = encoder_layer(xs_pad, masks)
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
if isinstance(xs_pad, tuple):
x, pos_emb = xs_pad
x = x + self.conditioning_layer(ctc_out)
xs_pad = (x, pos_emb)
else:
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
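# --- Illustrative usage sketch added for this review; not part of the committed file. ---
# A rough smoke test of the conformer-style encoder defined above, assuming the
# imports at the top of this file; the 560-dim features and the batch shapes are
# assumptions chosen only for illustration.
if __name__ == '__main__':
    _encoder = SANMEncoder_v2(input_size=560, output_size=256, num_blocks=2)
    _feats = torch.randn(2, 100, 560)       # (batch, frames, feat_dim)
    _feat_lens = torch.tensor([100, 80])    # valid frame count per utterance
    _enc_out, _enc_lens, _ = _encoder(_feats, _feat_lens)
    # expected: _enc_out of shape (2, T', 256), T' reduced by conv2d subsampling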

View File

@@ -0,0 +1,500 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
"""Transformer encoder definition."""
import logging
from typing import List, Optional, Sequence, Tuple, Union
import torch
from espnet2.asr.ctc import CTC
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.transformer.embedding import \
PositionalEncoding
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import (
Conv1dLinear, MultiLayeredConv1d)
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \
PositionwiseFeedForward # noqa: H301
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
Conv2dSubsampling8, TooShortUttError, check_short_utt)
from typeguard import check_argument_types
from ...asr.streaming_utilis.chunk_utilis import overlap_chunk
from ...nets.pytorch_backend.transformer.attention import (
MultiHeadedAttention, MultiHeadedAttentionSANM)
from ...nets.pytorch_backend.transformer.encoder_layer import (
EncoderLayer, EncoderLayerChunk)
class SANMEncoder(AbsEncoder):
"""Transformer encoder module.
Args:
input_size: input dim
output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the number of units of position-wise feed forward
num_blocks: the number of decoder blocks
dropout_rate: dropout rate
attention_dropout_rate: dropout rate in attention
positional_dropout_rate: dropout rate after adding positional encoding
input_layer: input layer type
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before: whether to use layer_norm before the first block
concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied.
i.e. x -> x + att(x)
positionwise_layer_type: linear or conv1d
positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
padding_idx: padding_idx for input_layer=embed
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = 'conv2d',
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = 'linear',
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size: int = 11,
sanm_shfit: int = 0,
selfattention_layer_type: str = 'sanm',
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == 'linear':
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == 'conv2d':
self.embed = Conv2dSubsampling(input_size, output_size,
dropout_rate)
elif input_layer == 'conv2d2':
self.embed = Conv2dSubsampling2(input_size, output_size,
dropout_rate)
elif input_layer == 'conv2d6':
self.embed = Conv2dSubsampling6(input_size, output_size,
dropout_rate)
elif input_layer == 'conv2d8':
self.embed = Conv2dSubsampling8(input_size, output_size,
dropout_rate)
elif input_layer == 'embed':
self.embed = torch.nn.Sequential(
torch.nn.Embedding(
input_size, output_size, padding_idx=padding_idx),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
else:
raise ValueError('unknown input_layer: ' + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == 'linear':
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == 'conv1d':
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == 'conv1d-linear':
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError('Support only linear or conv1d.')
if selfattention_layer_type == 'selfattn':
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == 'sanm':
encoder_selfattn_layer = MultiHeadedAttentionSANM
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(
interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if self.embed is None:
xs_pad = xs_pad
elif (isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)):
short_status, limit_size = check_short_utt(self.embed,
xs_pad.size(1))
if short_status:
raise TooShortUttError(
f'has {xs_pad.size(1)} frames and is too short for subsampling '
+ # noqa: *
f'(it needs more than {limit_size} frames), return empty results',
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
xs_pad, masks = self.encoders(xs_pad, masks)
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
xs_pad, masks = encoder_layer(xs_pad, masks)
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
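# --- Illustrative usage sketch added for this review; not part of the committed file. ---
# A rough smoke test of the SANM (FSMN-augmented self-attention) encoder above;
# the 560-dim features and batch shapes are assumptions for illustration only.
if __name__ == '__main__':
    _sanm_enc = SANMEncoder(input_size=560, output_size=256, num_blocks=2)
    _feats = torch.randn(2, 100, 560)       # (batch, frames, feat_dim)
    _feat_lens = torch.tensor([100, 80])    # valid frame count per utterance
    _enc_out, _enc_lens, _ = _sanm_enc(_feats, _feat_lens)
    # expected: _enc_out of shape (2, T', 256), T' reduced by conv2d subsampling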
class SANMEncoderChunk(AbsEncoder):
"""Transformer encoder module.
Args:
input_size: input dim
output_size: dimension of attention
attention_heads: the number of heads of multi head attention
linear_units: the number of units of position-wise feed forward
num_blocks: the number of decoder blocks
dropout_rate: dropout rate
attention_dropout_rate: dropout rate in attention
positional_dropout_rate: dropout rate after adding positional encoding
input_layer: input layer type
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
normalize_before: whether to use layer_norm before the first block
concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied.
i.e. x -> x + att(x)
positionwise_layer_type: linear or conv1d
positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
padding_idx: padding_idx for input_layer=embed
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = 'conv2d',
pos_enc_class=PositionalEncoding,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = 'linear',
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size: int = 11,
sanm_shfit: int = 0,
selfattention_layer_type: str = 'sanm',
chunk_size: Union[int, Sequence[int]] = (16, ),
stride: Union[int, Sequence[int]] = (10, ),
pad_left: Union[int, Sequence[int]] = (0, ),
encoder_att_look_back_factor: Union[int, Sequence[int]] = (1, ),
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == 'linear':
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == 'conv2d':
self.embed = Conv2dSubsampling(input_size, output_size,
dropout_rate)
elif input_layer == 'conv2d2':
self.embed = Conv2dSubsampling2(input_size, output_size,
dropout_rate)
elif input_layer == 'conv2d6':
self.embed = Conv2dSubsampling6(input_size, output_size,
dropout_rate)
elif input_layer == 'conv2d8':
self.embed = Conv2dSubsampling8(input_size, output_size,
dropout_rate)
elif input_layer == 'embed':
self.embed = torch.nn.Sequential(
torch.nn.Embedding(
input_size, output_size, padding_idx=padding_idx),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
else:
raise ValueError('unknown input_layer: ' + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == 'linear':
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == 'conv1d':
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == 'conv1d-linear':
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError('Support only linear or conv1d.')
if selfattention_layer_type == 'selfattn':
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == 'sanm':
encoder_selfattn_layer = MultiHeadedAttentionSANM
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayerChunk(
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(
interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
shfit_fsmn = (kernel_size - 1) // 2
self.overlap_chunk_cls = overlap_chunk(
chunk_size=chunk_size,
stride=stride,
pad_left=pad_left,
shfit_fsmn=shfit_fsmn,
encoder_att_look_back_factor=encoder_att_look_back_factor,
)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if self.embed is None:
xs_pad = xs_pad
elif (isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)):
short_status, limit_size = check_short_utt(self.embed,
xs_pad.size(1))
if short_status:
raise TooShortUttError(
f'has {xs_pad.size(1)} frames and is too short for subsampling '
+ # noqa: *
f'(it needs more than {limit_size} frames), return empty results',
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
mask_shfit_chunk, mask_att_chunk_encoder = None, None
if self.overlap_chunk_cls is not None:
ilens = masks.squeeze(1).sum(1)
chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
xs_pad, ilens = self.overlap_chunk_cls.split_chunk(
xs_pad, ilens, chunk_outs=chunk_outs)
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype)
mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype)
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
xs_pad, masks, _, _, _ = self.encoders(xs_pad, masks, None,
mask_shfit_chunk,
mask_att_chunk_encoder)
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
xs_pad, masks, _, _, _ = encoder_layer(xs_pad, masks, None,
mask_shfit_chunk,
mask_att_chunk_encoder)
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
if self.overlap_chunk_cls is not None:
xs_pad, olens = self.overlap_chunk_cls.remove_chunk(
xs_pad, ilens, chunk_outs)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,321 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
import logging
import math
import numpy as np
import torch
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from ...nets.pytorch_backend.cif_utils.cif import \
cif_predictor as cif_predictor
np.set_printoptions(threshold=np.inf)
torch.set_printoptions(profile='full', precision=100000, linewidth=None)
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device='cpu'):
if maxlen is None:
maxlen = lengths.max()
row_vector = torch.arange(0, maxlen, 1)
matrix = torch.unsqueeze(lengths, dim=-1)
mask = row_vector < matrix
return mask.type(dtype).to(device)
class overlap_chunk():
def __init__(
self,
chunk_size: tuple = (16, ),
stride: tuple = (10, ),
pad_left: tuple = (0, ),
encoder_att_look_back_factor: tuple = (1, ),
shfit_fsmn: int = 0,
):
self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor \
= chunk_size, stride, pad_left, encoder_att_look_back_factor
self.shfit_fsmn = shfit_fsmn
self.x_add_mask = None
self.x_rm_mask = None
self.x_len = None
self.mask_shfit_chunk = None
self.mask_chunk_predictor = None
self.mask_att_chunk_encoder = None
self.mask_shift_att_chunk_decoder = None
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur \
= None, None, None, None
def get_chunk_size(self, ind: int = 0):
# with torch.no_grad:
chunk_size, stride, pad_left, encoder_att_look_back_factor = self.chunk_size[
ind], self.stride[ind], self.pad_left[
ind], self.encoder_att_look_back_factor[ind]
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, \
self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
= chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn
return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur
def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
with torch.no_grad():
x_len = x_len.cpu().numpy()
x_len_max = x_len.max()
chunk_size, stride, pad_left, encoder_att_look_back_factor = self.get_chunk_size(
ind)
shfit_fsmn = self.shfit_fsmn
chunk_size_pad_shift = chunk_size + shfit_fsmn
chunk_num_batch = np.ceil(x_len / stride).astype(np.int32)
x_len_chunk = (
chunk_num_batch - 1
) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (
chunk_num_batch - 1) * stride
x_len_chunk = x_len_chunk.astype(x_len.dtype)
x_len_chunk_max = x_len_chunk.max()
chunk_num = int(math.ceil(x_len_max / stride))
dtype = np.int32
max_len_for_x_mask_tmp = max(chunk_size, x_len_max)
x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
mask_chunk_predictor = np.zeros([0, num_units_predictor],
dtype=dtype)
mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
mask_att_chunk_encoder = np.zeros(
[0, chunk_num * chunk_size_pad_shift], dtype=dtype)
for chunk_ids in range(chunk_num):
# x_mask add
fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp),
dtype=dtype)
x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride),
dtype=dtype)
x_mask_pad_right = np.zeros(
(chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
x_cur_pad = np.concatenate(
[x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad],
axis=0)
x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn],
axis=0)
# x_mask rm
fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),
dtype=dtype)
x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
x_mask_right = np.zeros((stride, chunk_size - stride),
dtype=dtype)
x_mask_cur = np.concatenate([x_mask_cur, x_mask_right], axis=1)
x_mask_cur_pad_top = np.zeros((chunk_ids * stride, chunk_size),
dtype=dtype)
x_mask_cur_pad_bottom = np.zeros(
(max_len_for_x_mask_tmp, chunk_size), dtype=dtype)
x_rm_mask_cur = np.concatenate(
[x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom],
axis=0)
x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :
chunk_size]
x_rm_mask_cur_fsmn = np.concatenate(
[fsmn_padding, x_rm_mask_cur], axis=1)
x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn],
axis=1)
# fsmn_padding_mask
pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1],
axis=0)
mask_shfit_chunk = np.concatenate(
[mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
# predictor mask
zeros_1 = np.zeros(
[shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
zeros_3 = np.zeros(
[chunk_size - stride - pad_left, num_units_predictor],
dtype=dtype)
ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
mask_chunk_predictor_cur = np.concatenate(
[zeros_1, ones_zeros], axis=0)
mask_chunk_predictor = np.concatenate(
[mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
# encoder att mask
zeros_1_top = np.zeros(
[shfit_fsmn, chunk_num * chunk_size_pad_shift],
dtype=dtype)
zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
zeros_2 = np.zeros(
[chunk_size, zeros_2_num * chunk_size_pad_shift],
dtype=dtype)
encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
ones_2_mid = np.ones([stride, stride], dtype=dtype)
zeros_2_bottom = np.zeros([chunk_size - stride, stride],
dtype=dtype)
zeros_2_right = np.zeros([chunk_size, chunk_size - stride],
dtype=dtype)
ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right],
axis=1)
ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
zeros_remain = np.zeros(
[chunk_size, zeros_remain_num * chunk_size_pad_shift],
dtype=dtype)
ones2_bottom = np.concatenate(
[zeros_2, ones_2, ones_3, zeros_remain], axis=1)
mask_att_chunk_encoder_cur = np.concatenate(
[zeros_1_top, ones2_bottom], axis=0)
mask_att_chunk_encoder = np.concatenate(
[mask_att_chunk_encoder, mask_att_chunk_encoder_cur],
axis=0)
# decoder fsmn_shift_att_mask
zeros_1 = np.zeros([shfit_fsmn, 1])
ones_1 = np.ones([chunk_size, 1])
mask_shift_att_chunk_decoder_cur = np.concatenate(
[zeros_1, ones_1], axis=0)
mask_shift_att_chunk_decoder = np.concatenate(
[
mask_shift_att_chunk_decoder,
mask_shift_att_chunk_decoder_cur
],
axis=0)
self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max]
self.x_len_chunk = x_len_chunk
self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
self.x_len = x_len
self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
self.mask_chunk_predictor = mask_chunk_predictor[:
x_len_chunk_max, :]
self.mask_att_chunk_encoder = mask_att_chunk_encoder[:
x_len_chunk_max, :
x_len_chunk_max]
self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:
x_len_chunk_max, :]
return (self.x_add_mask, self.x_len_chunk, self.x_rm_mask, self.x_len,
self.mask_shfit_chunk, self.mask_chunk_predictor,
self.mask_att_chunk_encoder, self.mask_shift_att_chunk_decoder)
def split_chunk(self, x, x_len, chunk_outs):
"""
:param x: (b, t, d)
:param x_length: (b)
:param ind: int
:return:
"""
x = x[:, :x_len.max(), :]
b, t, d = x.size()
x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(x.device)
x *= x_len_mask[:, :, None]
x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
x_len_chunk = self.get_x_len_chunk(
chunk_outs, x_len.device, dtype=x_len.dtype)
x = torch.transpose(x, 1, 0)
x = torch.reshape(x, [t, -1])
x_chunk = torch.mm(x_add_mask, x)
x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
return x_chunk, x_len_chunk
def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
x_chunk = x_chunk[:, :x_len_chunk.max(), :]
b, t, d = x_chunk.size()
x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
x_chunk.device)
x_chunk *= x_len_chunk_mask[:, :, None]
x_rm_mask = self.get_x_rm_mask(
chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
x_len = self.get_x_len(
chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
x_chunk = torch.transpose(x_chunk, 1, 0)
x_chunk = torch.reshape(x_chunk, [t, -1])
x = torch.mm(x_rm_mask, x_chunk)
x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
return x, x_len
def get_x_add_mask(self, chunk_outs, device, idx=0, dtype=torch.float32):
x = chunk_outs[idx]
x = torch.from_numpy(x).type(dtype).to(device)
return x.detach()
def get_x_len_chunk(self, chunk_outs, device, idx=1, dtype=torch.float32):
x = chunk_outs[idx]
x = torch.from_numpy(x).type(dtype).to(device)
return x.detach()
def get_x_rm_mask(self, chunk_outs, device, idx=2, dtype=torch.float32):
x = chunk_outs[idx]
x = torch.from_numpy(x).type(dtype).to(device)
return x.detach()
def get_x_len(self, chunk_outs, device, idx=3, dtype=torch.float32):
x = chunk_outs[idx]
x = torch.from_numpy(x).type(dtype).to(device)
return x.detach()
def get_mask_shfit_chunk(self,
chunk_outs,
device,
batch_size=1,
num_units=1,
idx=4,
dtype=torch.float32):
x = chunk_outs[idx]
x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
x = torch.from_numpy(x).type(dtype).to(device)
return x.detach()
def get_mask_chunk_predictor(self,
chunk_outs,
device,
batch_size=1,
num_units=1,
idx=5,
dtype=torch.float32):
x = chunk_outs[idx]
x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
x = torch.from_numpy(x).type(dtype).to(device)
return x.detach()
def get_mask_att_chunk_encoder(self,
chunk_outs,
device,
batch_size=1,
idx=6,
dtype=torch.float32):
x = chunk_outs[idx]
x = np.tile(x[None, :, :, ], [batch_size, 1, 1])
x = torch.from_numpy(x).type(dtype).to(device)
return x.detach()
def get_mask_shift_att_chunk_decoder(self,
chunk_outs,
device,
batch_size=1,
idx=7,
dtype=torch.float32):
x = chunk_outs[idx]
x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
x = torch.from_numpy(x).type(dtype).to(device)
return x.detach()
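# --- Illustrative roundtrip sketch added for this review; not part of the committed file. ---
# overlap_chunk slices a padded batch into fixed-size, overlapping chunks for
# streaming attention and later removes the chunk padding again. The chunk_size
# and stride values below are assumptions chosen only for illustration.
if __name__ == '__main__':
    _splitter = overlap_chunk(chunk_size=(16, ), stride=(10, ))
    _x_in = torch.randn(2, 37, 8)                 # (batch, frames, dim)
    _x_len = torch.tensor([37, 25])
    _chunk_outs = _splitter.gen_chunk_mask(_x_len, ind=0)
    _x_chunk, _len_chunk = _splitter.split_chunk(_x_in, _x_len, chunk_outs=_chunk_outs)
    _x_back, _len_back = _splitter.remove_chunk(_x_chunk, _len_chunk, _chunk_outs)
    # expected: _x_back has the same (batch, frames, dim) shape as _x_in
    #           and _len_back equals _x_len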

View File

@@ -0,0 +1,250 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
import logging
import numpy as np
import torch
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from torch import nn
class CIF_Model(nn.Module):
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1):
super(CIF_Model, self).__init__()
self.pad = nn.ConstantPad1d((l_order, r_order), 0)
self.cif_conv1d = nn.Conv1d(
idim, idim, l_order + r_order + 1, groups=idim)
self.cif_output = nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
memory = self.cif_conv1d(queries)
output = memory + context
output = self.dropout(output)
output = output.transpose(1, 2)
output = torch.relu(output)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
if mask is not None:
alphas = alphas * mask.transpose(-1, -2).float()
alphas = alphas.squeeze(-1)
if target_label is not None:
target_length = (target_label != ignore_id).float().sum(-1)
else:
target_length = None
cif_length = alphas.sum(-1)
if target_label is not None:
alphas *= (target_length / cif_length)[:, None].repeat(
1, alphas.size(1))
cif_output, cif_peak = cif(hidden, alphas, self.threshold)
return cif_output, cif_length, target_length, cif_peak
def gen_frame_alignments(self,
alphas: torch.Tensor = None,
memory_sequence_length: torch.Tensor = None,
is_training: bool = True,
dtype: torch.dtype = torch.float32):
batch_size, maximum_length = alphas.size()
int_type = torch.int32
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
max_token_num = torch.max(token_num).item()
alphas_cumsum = torch.cumsum(alphas, dim=1)
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
alphas_cumsum = torch.tile(alphas_cumsum[:, None, :],
[1, max_token_num, 1])
index = torch.ones([batch_size, max_token_num], dtype=int_type)
index = torch.cumsum(index, dim=1)
index = torch.tile(index[:, :, None], [1, 1, maximum_length])
index_div = torch.floor(torch.divide(alphas_cumsum,
index)).type(int_type)
index_div_bool_zeros = index_div.eq(0)
index_div_bool_zeros_count = torch.sum(
index_div_bool_zeros, dim=-1) + 1
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0,
memory_sequence_length.max())
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(
token_num.device)
index_div_bool_zeros_count *= token_num_mask
index_div_bool_zeros_count_tile = torch.tile(
index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length])
ones = torch.ones_like(index_div_bool_zeros_count_tile)
zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
ones = torch.cumsum(ones, dim=2)
cond = index_div_bool_zeros_count_tile == ones
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(
torch.bool)
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(
int_type)
index_div_bool_zeros_count_tile_out = torch.sum(
index_div_bool_zeros_count_tile, dim=1)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(
int_type)
predictor_mask = (~make_pad_mask(
memory_sequence_length,
maxlen=memory_sequence_length.max())).type(int_type).to(
memory_sequence_length.device) # noqa: *
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
return index_div_bool_zeros_count_tile_out.detach(
), index_div_bool_zeros_count.detach()
class cif_predictor(nn.Module):
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1):
super(cif_predictor, self).__init__()
self.pad = nn.ConstantPad1d((l_order, r_order), 0)
self.cif_conv1d = nn.Conv1d(
idim, idim, l_order + r_order + 1, groups=idim)
self.cif_output = nn.Linear(idim, 1)
self.dropout = torch.nn.Dropout(p=dropout)
self.threshold = threshold
def forward(self,
hidden,
target_label=None,
mask=None,
ignore_id=-1,
mask_chunk_predictor=None,
target_label_length=None):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
memory = self.cif_conv1d(queries)
output = memory + context
output = self.dropout(output)
output = output.transpose(1, 2)
output = torch.relu(output)
output = self.cif_output(output)
alphas = torch.sigmoid(output)
if mask is not None:
alphas = alphas * mask.transpose(-1, -2).float()
if mask_chunk_predictor is not None:
alphas = alphas * mask_chunk_predictor
alphas = alphas.squeeze(-1)
if target_label_length is not None:
target_length = target_label_length
elif target_label is not None:
target_length = (target_label != ignore_id).float().sum(-1)
else:
target_length = None
token_num = alphas.sum(-1)
if target_length is not None:
alphas *= (target_length / token_num)[:, None].repeat(
1, alphas.size(1))
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
return acoustic_embeds, token_num, alphas, cif_peak
def gen_frame_alignments(self,
alphas: torch.Tensor = None,
memory_sequence_length: torch.Tensor = None,
is_training: bool = True,
dtype: torch.dtype = torch.float32):
batch_size, maximum_length = alphas.size()
int_type = torch.int32
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
max_token_num = torch.max(token_num).item()
alphas_cumsum = torch.cumsum(alphas, dim=1)
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
alphas_cumsum = torch.tile(alphas_cumsum[:, None, :],
[1, max_token_num, 1])
index = torch.ones([batch_size, max_token_num], dtype=int_type)
index = torch.cumsum(index, dim=1)
index = torch.tile(index[:, :, None], [1, 1, maximum_length])
index_div = torch.floor(torch.divide(alphas_cumsum,
index)).type(int_type)
index_div_bool_zeros = index_div.eq(0)
index_div_bool_zeros_count = torch.sum(
index_div_bool_zeros, dim=-1) + 1
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0,
memory_sequence_length.max())
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(
token_num.device)
index_div_bool_zeros_count *= token_num_mask
index_div_bool_zeros_count_tile = torch.tile(
index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length])
ones = torch.ones_like(index_div_bool_zeros_count_tile)
zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
ones = torch.cumsum(ones, dim=2)
cond = index_div_bool_zeros_count_tile == ones
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(
torch.bool)
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(
int_type)
index_div_bool_zeros_count_tile_out = torch.sum(
index_div_bool_zeros_count_tile, dim=1)
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(
int_type)
predictor_mask = (~make_pad_mask(
memory_sequence_length,
maxlen=memory_sequence_length.max())).type(int_type).to(
memory_sequence_length.device) # noqa: *
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
return index_div_bool_zeros_count_tile_out.detach(
), index_div_bool_zeros_count.detach()
def cif(hidden, alphas, threshold):
batch_size, len_time, hidden_size = hidden.size()
# loop vars
integrate = torch.zeros([batch_size], device=hidden.device)
frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
# intermediate vars along time
list_fires = []
list_frames = []
for t in range(len_time):
alpha = alphas[:, t]
distribution_completion = torch.ones([batch_size],
device=hidden.device) - integrate
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(
fire_place,
integrate - torch.ones([batch_size], device=hidden.device),
integrate)
cur = torch.where(fire_place, distribution_completion, alpha)
remainds = alpha - cur
frame += cur[:, None] * hidden[:, t, :]
list_frames.append(frame)
frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
remainds[:, None] * hidden[:, t, :], frame)
fires = torch.stack(list_fires, 1)
frames = torch.stack(list_frames, 1)
list_ls = []
len_labels = torch.round(alphas.sum(-1)).int()
max_label_len = len_labels.max()
for b in range(batch_size):
fire = fires[b, :]
ls = torch.index_select(frames[b, :, :], 0,
torch.nonzero(fire >= threshold).squeeze())
pad_l = torch.zeros([max_label_len - ls.size(0), hidden_size],
device=hidden.device)
list_ls.append(torch.cat([ls, pad_l], 0))
return torch.stack(list_ls, 0), fires
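# --- Illustrative sketch added for this review; not part of the committed file. ---
# cif() integrates the per-frame weights `alphas` and "fires" one acoustic
# embedding every time the running sum crosses `threshold`. The toy weights
# below (summing to 3.0) are assumptions chosen to make the firing visible.
if __name__ == '__main__':
    _hidden = torch.randn(1, 6, 4)                             # (batch, T, D)
    _alphas = torch.tensor([[0.4, 0.7, 0.3, 0.6, 0.5, 0.5]])   # fires at t = 1, 3, 5
    _embeds, _fires = cif(_hidden, _alphas, threshold=1.0)
    # expected: _embeds of shape (1, 3, 4) -> three fired token embeddings,
    #           _fires of shape (1, 6)     -> running integration values per frame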

View File

@@ -0,0 +1,680 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
"""Multi-Head Attention layer definition."""
import logging
import math
import numpy
import torch
from torch import nn
torch.set_printoptions(profile='full', precision=1)
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self, n_head, n_feat, dropout_rate):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
def forward_qkv(self, query, key, value):
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
Returns:
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
return q, k, v
def forward_attention(self, value, scores, mask):
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = float(
numpy.finfo(torch.tensor(
0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(
scores, dim=-1).masked_fill(mask,
0.0) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(
scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, query, key, value, mask):
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
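# --- Illustrative sketch added for this review; not part of the committed file. ---
# Plain multi-head self-attention over a padded batch; the shapes below are
# assumptions for illustration. The mask marks valid (non-padded) key frames.
if __name__ == '__main__':
    _att = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
    _x = torch.randn(2, 10, 256)                      # (batch, time, d_model)
    _mask = torch.ones(2, 1, 10, dtype=torch.bool)    # (batch, 1, time2), True = keep
    _out = _att(_x, _x, _x, _mask)                    # (2, 10, 256)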
class MultiHeadedAttentionSANM(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self,
n_head,
n_feat,
dropout_rate,
kernel_size,
sanm_shfit=0):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttentionSANM, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
self.fsmn_block = nn.Conv1d(
n_feat,
n_feat,
kernel_size,
stride=1,
padding=0,
groups=n_feat,
bias=False)
# padding
left_padding = (kernel_size - 1) // 2
if sanm_shfit > 0:
left_padding = left_padding + sanm_shfit
right_padding = kernel_size - 1 - left_padding
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
"""FSMN memory block applied to the value sequence.
:param inputs: (#batch, time1, size)
:param mask: Mask tensor (#batch, 1, time)
:return: (#batch, time1, size)
"""
# b, t, d = inputs.size()
mask = mask[:, 0, :, None]
if mask_shfit_chunk is not None:
mask = mask * mask_shfit_chunk
inputs *= mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x += inputs
x = self.dropout(x)
return x * mask
def forward_qkv(self, query, key, value):
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
Returns:
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
return q, k, v
def forward_attention(self,
value,
scores,
mask,
mask_att_chunk_encoder=None):
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
if mask is not None:
if mask_att_chunk_encoder is not None:
mask = mask * mask_att_chunk_encoder
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = float(
numpy.finfo(torch.tensor(
0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(
scores, dim=-1).masked_fill(mask,
0.0) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(
scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self,
query,
key,
value,
mask,
mask_shfit_chunk=None,
mask_att_chunk_encoder=None):
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
fsmn_memory = self.forward_fsmn(value, mask, mask_shfit_chunk)
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
att_outs = self.forward_attention(v, scores, mask,
mask_att_chunk_encoder)
return att_outs + fsmn_memory
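# --- Illustrative sketch added for this review; not part of the committed file. ---
# SANM attention adds an FSMN memory block (a depthwise Conv1d over the value
# sequence) on top of standard multi-head attention. kernel_size=11 mirrors the
# SANM encoder default; the other shapes are assumptions for illustration.
if __name__ == '__main__':
    _sanm_att = MultiHeadedAttentionSANM(
        n_head=4, n_feat=256, dropout_rate=0.0, kernel_size=11, sanm_shfit=0)
    _x = torch.randn(2, 10, 256)                      # (batch, time, d_model)
    _mask = torch.ones(2, 1, 10, dtype=torch.bool)    # (batch, 1, time), True = keep
    _out = _sanm_att(_x, _x, _x, _mask)               # (2, 10, 256)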
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, head, time1, time2).
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((*x.size()[:3], 1),
device=x.device,
dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)
if self.zero_triu:
ones = torch.ones((x.size(2), x.size(3)))
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, mask):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time1)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask)
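# --- Illustrative sketch added for this review; not part of the committed file. ---
# The legacy relative-position variant additionally consumes a positional
# embedding tensor (Transformer-XL style); shapes are assumptions for illustration.
if __name__ == '__main__':
    _rel_att = LegacyRelPositionMultiHeadedAttention(
        n_head=4, n_feat=256, dropout_rate=0.0)
    _x = torch.randn(2, 10, 256)                      # (batch, time, d_model)
    _pos_emb = torch.randn(1, 10, 256)                # legacy rel. positional embedding
    _mask = torch.ones(2, 1, 10, dtype=torch.bool)    # (batch, 1, time), True = keep
    _out = _rel_att(_x, _x, _x, _pos_emb, _mask)      # (2, 10, 256)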
class LegacyRelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM):
"""Multi-Head Attention layer with relative position encoding (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self,
n_head,
n_feat,
dropout_rate,
zero_triu=False,
kernel_size=15,
sanm_shfit=0):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, head, time1, time2).
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((*x.size()[:3], 1),
device=x.device,
dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)
if self.zero_triu:
ones = torch.ones((x.size(2), x.size(3)), device=x.device)
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, mask):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
fsmn_memory = self.forward_fsmn(value, mask)
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time1)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)
att_outs = self.forward_attention(v, scores, mask)
return att_outs + fsmn_memory
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding (new implementation).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable biases are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((*x.size()[:3], 1),
device=x.device,
dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(
x)[:, :, :, :x.size(-1) // 2
+ 1] # only keep the positions from 0 to time2
if self.zero_triu:
ones = torch.ones((x.size(2), x.size(3)), device=x.device)
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, mask):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
pos_emb (torch.Tensor): Positional embedding tensor
(#batch, 2*time1-1, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, 2*time1-1)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask)
class RelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM):
"""Multi-Head Attention layer with relative position encoding (new implementation).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self,
n_head,
n_feat,
dropout_rate,
zero_triu=False,
kernel_size=15,
sanm_shfit=0):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable biases are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((*x.size()[:3], 1),
device=x.device,
dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(
x)[:, :, :, :x.size(-1) // 2
+ 1] # only keep the positions from 0 to time2
if self.zero_triu:
ones = torch.ones((x.size(2), x.size(3)), device=x.device)
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, mask):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
pos_emb (torch.Tensor): Positional embedding tensor
(#batch, 2*time1-1, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
fsmn_memory = self.forward_fsmn(value, mask)
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, 2*time1-1)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)
att_outs = self.forward_attention(v, scores, mask)
return att_outs + fsmn_memory
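# Illustrative sketch, with hypothetical shapes: the SANM attention variants above
# return the sum of two branches, the usual multi-head attention output and an
# FSMN-style memory computed from the values by forward_fsmn() (provided by
# MultiHeadedAttentionSANM earlier in this file).
import torch

batch, time1, d_model = 2, 10, 256
att_outs = torch.randn(batch, time1, d_model)      # stand-in for forward_attention(v, scores, mask)
fsmn_memory = torch.randn(batch, time1, d_model)   # stand-in for forward_fsmn(value, mask)
layer_out = att_outs + fsmn_memory                 # element-wise fusion of the two branches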

View File

@@ -0,0 +1,239 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
"""Encoder self-attention layer definition."""
import torch
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from torch import nn
class EncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
stochastic_depth_rate (float): Probability to skip this layer.
During training, the layer may skip residual computation and return input
as-is with given probability.
"""
def __init__(
self,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
def forward(self, x, mask, cache=None):
"""Compute encoded features.
Args:
x (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if cache is None:
x_q = x
else:
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(x_q, x, x, mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
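# Illustrative sketch of the stochastic-depth residual rule implemented in
# EncoderLayer.forward(): at training time the layer is skipped with probability p,
# otherwise its branch output is rescaled by 1 / (1 - p). Values are hypothetical.
import torch

p = 0.1                                   # stochastic_depth_rate
x = torch.randn(2, 10, 256)               # (#batch, time, size)
f_x = torch.randn(2, 10, 256)             # stand-in for dropout(self_attn(...)) or feed_forward(x)
if torch.rand(1).item() < p:              # skip the whole layer
    out = x
else:
    out = x + (1.0 / (1 - p)) * f_x       # rescaled residual branch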
class EncoderLayerChunk(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
stochastic_depth_rate (float): Probability to skip this layer.
During training, the layer may skip residual computation and return input
as-is with given probability.
"""
def __init__(
self,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayerChunk, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
def forward(self,
x,
mask,
cache=None,
mask_shfit_chunk=None,
mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
x (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
mask_shfit_chunk (torch.Tensor): Optional chunk mask forwarded to the self-attention module.
mask_att_chunk_encoder (torch.Tensor): Optional chunk attention mask forwarded to the self-attention module.
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
None, mask_shfit_chunk and mask_att_chunk_encoder are returned unchanged.
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if cache is None:
x_q = x
else:
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
if self.concat_after:
x_concat = torch.cat(
(x,
self.self_attn(
x_q,
x,
x,
mask,
mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder)),
dim=-1)
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(
x_q,
x,
x,
mask,
mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask, None, mask_shfit_chunk, mask_att_chunk_encoder

View File

@@ -0,0 +1,890 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
import argparse
import logging
import os
from pathlib import Path
from typing import Callable, Collection, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import yaml
from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.decoder.mlm_decoder import MLMDecoder
from espnet2.asr.decoder.rnn_decoder import RNNDecoder
from espnet2.asr.decoder.transformer_decoder import \
DynamicConvolution2DTransformerDecoder # noqa: H301
from espnet2.asr.decoder.transformer_decoder import \
LightweightConvolution2DTransformerDecoder # noqa: H301
from espnet2.asr.decoder.transformer_decoder import \
LightweightConvolutionTransformerDecoder # noqa: H301
from espnet2.asr.decoder.transformer_decoder import (
DynamicConvolutionTransformerDecoder, TransformerDecoder)
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.encoder.contextual_block_conformer_encoder import \
ContextualBlockConformerEncoder # noqa: H301
from espnet2.asr.encoder.contextual_block_transformer_encoder import \
ContextualBlockTransformerEncoder # noqa: H301
from espnet2.asr.encoder.hubert_encoder import (FairseqHubertEncoder,
FairseqHubertPretrainEncoder)
from espnet2.asr.encoder.longformer_encoder import LongformerEncoder
from espnet2.asr.encoder.rnn_encoder import RNNEncoder
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder
from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder
from espnet2.asr.encoder.wav2vec2_encoder import FairSeqWav2Vec2Encoder
from espnet2.asr.espnet_model import ESPnetASRModel
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.frontend.default import DefaultFrontend
from espnet2.asr.frontend.fused import FusedFrontends
from espnet2.asr.frontend.s3prl import S3prlFrontend
from espnet2.asr.frontend.windowing import SlidingWindow
from espnet2.asr.maskctc_model import MaskCTCModel
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
from espnet2.asr.postencoder.hugging_face_transformers_postencoder import \
HuggingFaceTransformersPostEncoder # noqa: H301
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.preencoder.linear import LinearProjection
from espnet2.asr.preencoder.sinc import LightweightSincConvs
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.asr.specaug.specaug import SpecAug
from espnet2.asr.transducer.joint_network import JointNetwork
from espnet2.asr.transducer.transducer_decoder import TransducerDecoder
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.layers.global_mvn import GlobalMVN
from espnet2.layers.utterance_mvn import UtteranceMVN
from espnet2.tasks.abs_task import AbsTask
from espnet2.text.phoneme_tokenizer import g2p_choices
from espnet2.torch_utils.initialize import initialize
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet2.train.class_choices import ClassChoices
from espnet2.train.collate_fn import CommonCollateFn
from espnet2.train.preprocessor import CommonPreprocessor
from espnet2.train.trainer import Trainer
from espnet2.utils.get_default_kwargs import get_default_kwargs
from espnet2.utils.nested_dict_action import NestedDictAction
from espnet2.utils.types import (float_or_none, int_or_none, str2bool,
str_or_none)
from typeguard import check_argument_types, check_return_type
from ..asr.decoder.transformer_decoder import (ParaformerDecoder,
ParaformerDecoderBertEmbed)
from ..asr.encoder.conformer_encoder import ConformerEncoder, SANMEncoder_v2
from ..asr.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunk
from ..asr.espnet_model import AEDStreaming
from ..asr.espnet_model_paraformer import Paraformer, ParaformerBertEmbed
from ..nets.pytorch_backend.cif_utils.cif import cif_predictor
# FIXME(wjm): suggested by fairseq, We need to setup root logger before importing any fairseq libraries.
logging.basicConfig(
level='INFO',
format=f"[{os.uname()[1].split('.')[0]}]"
f' %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
)
# FIXME(wjm): create logger to set level, unset __name__ for different files to share the same logger
logger = logging.getLogger()
frontend_choices = ClassChoices(
name='frontend',
classes=dict(
default=DefaultFrontend,
sliding_window=SlidingWindow,
s3prl=S3prlFrontend,
fused=FusedFrontends,
),
type_check=AbsFrontend,
default='default',
)
specaug_choices = ClassChoices(
name='specaug',
classes=dict(specaug=SpecAug, ),
type_check=AbsSpecAug,
default=None,
optional=True,
)
normalize_choices = ClassChoices(
'normalize',
classes=dict(
global_mvn=GlobalMVN,
utterance_mvn=UtteranceMVN,
),
type_check=AbsNormalize,
default='utterance_mvn',
optional=True,
)
model_choices = ClassChoices(
'model',
classes=dict(
espnet=ESPnetASRModel,
maskctc=MaskCTCModel,
paraformer=Paraformer,
paraformer_bert_embed=ParaformerBertEmbed,
aedstreaming=AEDStreaming,
),
type_check=AbsESPnetModel,
default='espnet',
)
preencoder_choices = ClassChoices(
name='preencoder',
classes=dict(
sinc=LightweightSincConvs,
linear=LinearProjection,
),
type_check=AbsPreEncoder,
default=None,
optional=True,
)
encoder_choices = ClassChoices(
'encoder',
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
contextual_block_transformer=ContextualBlockTransformerEncoder,
contextual_block_conformer=ContextualBlockConformerEncoder,
vgg_rnn=VGGRNNEncoder,
rnn=RNNEncoder,
wav2vec2=FairSeqWav2Vec2Encoder,
hubert=FairseqHubertEncoder,
hubert_pretrain=FairseqHubertPretrainEncoder,
longformer=LongformerEncoder,
sanm=SANMEncoder,
sanm_v2=SANMEncoder_v2,
sanm_chunk=SANMEncoderChunk,
),
type_check=AbsEncoder,
default='rnn',
)
postencoder_choices = ClassChoices(
name='postencoder',
classes=dict(
hugging_face_transformers=HuggingFaceTransformersPostEncoder, ),
type_check=AbsPostEncoder,
default=None,
optional=True,
)
decoder_choices = ClassChoices(
'decoder',
classes=dict(
transformer=TransformerDecoder,
lightweight_conv=LightweightConvolutionTransformerDecoder,
lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
dynamic_conv=DynamicConvolutionTransformerDecoder,
dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
rnn=RNNDecoder,
transducer=TransducerDecoder,
mlm=MLMDecoder,
paraformer_decoder=ParaformerDecoder,
paraformer_decoder_bert_embed=ParaformerDecoderBertEmbed,
),
type_check=AbsDecoder,
default='rnn',
)
predictor_choices = ClassChoices(
name='predictor',
classes=dict(
cif_predictor=cif_predictor,
ctc_predictor=None,
),
type_check=None,
default='cif_predictor',
optional=True,
)
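# Illustrative sketch: each ClassChoices instance above maps a configuration string
# (e.g. --encoder transformer) to a concrete class, and build_model() below instantiates
# it with the matching *_conf dict. The values here are hypothetical.
encoder_class = encoder_choices.get_class('transformer')   # -> espnet2 TransformerEncoder
encoder = encoder_class(input_size=80)                      # remaining kwargs normally come from args.encoder_conf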
class ASRTask(AbsTask):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --model and --model_conf
model_choices,
# --preencoder and --preencoder_conf
preencoder_choices,
# --encoder and --encoder_conf
encoder_choices,
# --postencoder and --postencoder_conf
postencoder_choices,
# --decoder and --decoder_conf
decoder_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description='Task related')
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead, append the option name to `required` as follows:
required = parser.get_default('required')
required += ['token_list']
group.add_argument(
'--token_list',
type=str_or_none,
default=None,
help='A text mapping int-id to token',
)
group.add_argument(
'--init',
type=lambda x: str_or_none(x.lower()),
default=None,
help='The initialization method',
choices=[
'chainer',
'xavier_uniform',
'xavier_normal',
'kaiming_uniform',
'kaiming_normal',
None,
],
)
group.add_argument(
'--input_size',
type=int_or_none,
default=None,
help='The number of input dimension of the feature',
)
group.add_argument(
'--ctc_conf',
action=NestedDictAction,
default=get_default_kwargs(CTC),
help='The keyword arguments for CTC class.',
)
group.add_argument(
'--joint_net_conf',
action=NestedDictAction,
default=None,
help='The keyword arguments for joint network class.',
)
group = parser.add_argument_group(description='Preprocess related')
group.add_argument(
'--use_preprocessor',
type=str2bool,
default=True,
help='Apply preprocessing to data or not',
)
group.add_argument(
'--token_type',
type=str,
default='bpe',
choices=['bpe', 'char', 'word', 'phn'],
help='The text will be tokenized '
'at the specified token level',
)
group.add_argument(
'--bpemodel',
type=str_or_none,
default=None,
help='The model file of sentencepiece',
)
parser.add_argument(
'--non_linguistic_symbols',
type=str_or_none,
help='non_linguistic_symbols file path',
)
parser.add_argument(
'--cleaner',
type=str_or_none,
choices=[None, 'tacotron', 'jaconv', 'vietnamese'],
default=None,
help='Apply text cleaning',
)
parser.add_argument(
'--g2p',
type=str_or_none,
choices=g2p_choices,
default=None,
help='Specify g2p method if --token_type=phn',
)
parser.add_argument(
'--speech_volume_normalize',
type=float_or_none,
default=None,
help='Scale the maximum amplitude to the given value.',
)
parser.add_argument(
'--rir_scp',
type=str_or_none,
default=None,
help='The file path of rir scp file.',
)
parser.add_argument(
'--rir_apply_prob',
type=float,
default=1.0,
help='The probability of applying RIR convolution.',
)
parser.add_argument(
'--noise_scp',
type=str_or_none,
default=None,
help='The file path of noise scp file.',
)
parser.add_argument(
'--noise_apply_prob',
type=float,
default=1.0,
help='The probability of adding noise.',
)
parser.add_argument(
'--noise_db_range',
type=str,
default='13_15',
help='The range of noise decibel level.',
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[
List[str], Dict[str, torch.Tensor]], ]:
assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
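# Illustrative sketch, assuming the espnet2 collate convention: CommonCollateFn pads a
# mini-batch of variable-length numpy features into tensors and adds *_lengths entries.
# The shapes and output keys below are assumptions, not taken from this commit.
import numpy as np

collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
batch = [('utt1', {'speech': np.zeros((120, 80), dtype=np.float32)}),
         ('utt2', {'speech': np.zeros((95, 80), dtype=np.float32)})]
utt_ids, tensors = collate_fn(batch)
# tensors['speech']         -> float tensor of shape (2, 120, 80), padded with 0.0
# tensors['speech_lengths'] -> tensor([120, 95])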
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
token_type=args.token_type,
token_list=args.token_list,
bpemodel=args.bpemodel,
non_linguistic_symbols=args.non_linguistic_symbols,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
# NOTE(kamo): Check attribute existence for backward compatibility
rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None,
rir_apply_prob=args.rir_apply_prob if hasattr(
args, 'rir_apply_prob') else 1.0,
noise_scp=args.noise_scp
if hasattr(args, 'noise_scp') else None,
noise_apply_prob=args.noise_apply_prob if hasattr(
args, 'noise_apply_prob') else 1.0,
noise_db_range=args.noise_db_range if hasattr(
args, 'noise_db_range') else '13_15',
speech_volume_normalize=args.speech_volume_normalize
if hasattr(args, 'speech_volume_normalize') else None,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(cls,
train: bool = True,
inference: bool = False) -> Tuple[str, ...]:
if not inference:
retval = ('speech', 'text')
else:
# Recognition mode
retval = ('speech', )
return retval
@classmethod
def optional_data_names(cls,
train: bool = True,
inference: bool = False) -> Tuple[str, ...]:
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel:
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding='utf-8') as f:
token_list = [line.rstrip() for line in f]
# Overwriting token_list to keep it as "portable".
args.token_list = list(token_list)
elif isinstance(args.token_list, (tuple, list)):
token_list = list(args.token_list)
else:
raise RuntimeError('token_list must be str or list')
vocab_size = len(token_list)
logger.info(f'Vocabulary size: {vocab_size}')
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# 4. Pre-encoder input block
# NOTE(kan-bayashi): Use getattr to keep the compatibility
if getattr(args, 'preencoder', None) is not None:
preencoder_class = preencoder_choices.get_class(args.preencoder)
preencoder = preencoder_class(**args.preencoder_conf)
input_size = preencoder.output_size()
else:
preencoder = None
# 4. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
# 5. Post-encoder block
# NOTE(kan-bayashi): Use getattr to keep the compatibility
encoder_output_size = encoder.output_size()
if getattr(args, 'postencoder', None) is not None:
postencoder_class = postencoder_choices.get_class(args.postencoder)
postencoder = postencoder_class(
input_size=encoder_output_size, **args.postencoder_conf)
encoder_output_size = postencoder.output_size()
else:
postencoder = None
# 5. Decoder
decoder_class = decoder_choices.get_class(args.decoder)
if args.decoder == 'transducer':
decoder = decoder_class(
vocab_size,
embed_pad=0,
**args.decoder_conf,
)
joint_network = JointNetwork(
vocab_size,
encoder.output_size(),
decoder.dunits,
**args.joint_net_conf,
)
else:
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
**args.decoder_conf,
)
joint_network = None
# 6. CTC
ctc = CTC(
odim=vocab_size,
encoder_output_size=encoder_output_size,
**args.ctc_conf)
# 7. Build model
try:
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class('espnet')
model = model_class(
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
postencoder=postencoder,
decoder=decoder,
ctc=ctc,
joint_network=joint_network,
token_list=token_list,
**args.model_conf,
)
# FIXME(kamo): Should be done in model?
# 8. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model
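# Illustrative sketch, assuming the standard espnet2 AbsTask API: at inference time a task
# class like ASRTask is normally driven through build_model_from_file, which loads the
# training-time YAML config into an argparse.Namespace and dispatches to build_model()
# above. The file paths are hypothetical.
model, train_args = ASRTask.build_model_from_file(
    config_file='exp/asr_train/config.yaml',          # asr_train_config
    model_file='exp/asr_train/valid.acc.ave.pth',     # asr_model_file
    device='cpu',
)
model.eval()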
class ASRTaskNAR(AbsTask):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --model and --model_conf
model_choices,
# --preencoder and --preencoder_conf
preencoder_choices,
# --encoder and --encoder_conf
encoder_choices,
# --postencoder and --postencoder_conf
postencoder_choices,
# --decoder and --decoder_conf
decoder_choices,
# --predictor and --predictor_conf
predictor_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description='Task related')
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead, append the option name to `required` as follows:
required = parser.get_default('required')
required += ['token_list']
group.add_argument(
'--token_list',
type=str_or_none,
default=None,
help='A text mapping int-id to token',
)
group.add_argument(
'--init',
type=lambda x: str_or_none(x.lower()),
default=None,
help='The initialization method',
choices=[
'chainer',
'xavier_uniform',
'xavier_normal',
'kaiming_uniform',
'kaiming_normal',
None,
],
)
group.add_argument(
'--input_size',
type=int_or_none,
default=None,
help='The number of input dimension of the feature',
)
group.add_argument(
'--ctc_conf',
action=NestedDictAction,
default=get_default_kwargs(CTC),
help='The keyword arguments for CTC class.',
)
group.add_argument(
'--joint_net_conf',
action=NestedDictAction,
default=None,
help='The keyword arguments for joint network class.',
)
group = parser.add_argument_group(description='Preprocess related')
group.add_argument(
'--use_preprocessor',
type=str2bool,
default=True,
help='Apply preprocessing to data or not',
)
group.add_argument(
'--token_type',
type=str,
default='bpe',
choices=['bpe', 'char', 'word', 'phn'],
help='The text will be tokenized '
'at the specified token level',
)
group.add_argument(
'--bpemodel',
type=str_or_none,
default=None,
help='The model file of sentencepiece',
)
parser.add_argument(
'--non_linguistic_symbols',
type=str_or_none,
help='non_linguistic_symbols file path',
)
parser.add_argument(
'--cleaner',
type=str_or_none,
choices=[None, 'tacotron', 'jaconv', 'vietnamese'],
default=None,
help='Apply text cleaning',
)
parser.add_argument(
'--g2p',
type=str_or_none,
choices=g2p_choices,
default=None,
help='Specify g2p method if --token_type=phn',
)
parser.add_argument(
'--speech_volume_normalize',
type=float_or_none,
default=None,
help='Scale the maximum amplitude to the given value.',
)
parser.add_argument(
'--rir_scp',
type=str_or_none,
default=None,
help='The file path of rir scp file.',
)
parser.add_argument(
'--rir_apply_prob',
type=float,
default=1.0,
help='The probability of applying RIR convolution.',
)
parser.add_argument(
'--noise_scp',
type=str_or_none,
default=None,
help='The file path of noise scp file.',
)
parser.add_argument(
'--noise_apply_prob',
type=float,
default=1.0,
help='The probability of adding noise.',
)
parser.add_argument(
'--noise_db_range',
type=str,
default='13_15',
help='The range of noise decibel level.',
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[
List[str], Dict[str, torch.Tensor]], ]:
assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
token_type=args.token_type,
token_list=args.token_list,
bpemodel=args.bpemodel,
non_linguistic_symbols=args.non_linguistic_symbols,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
# NOTE(kamo): Check attribute existence for backward compatibility
rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None,
rir_apply_prob=args.rir_apply_prob if hasattr(
args, 'rir_apply_prob') else 1.0,
noise_scp=args.noise_scp
if hasattr(args, 'noise_scp') else None,
noise_apply_prob=args.noise_apply_prob if hasattr(
args, 'noise_apply_prob') else 1.0,
noise_db_range=args.noise_db_range if hasattr(
args, 'noise_db_range') else '13_15',
speech_volume_normalize=args.speech_volume_normalize
if hasattr(args, 'speech_volume_normalize') else None,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(cls,
train: bool = True,
inference: bool = False) -> Tuple[str, ...]:
if not inference:
retval = ('speech', 'text')
else:
# Recognition mode
retval = ('speech', )
return retval
@classmethod
def optional_data_names(cls,
train: bool = True,
inference: bool = False) -> Tuple[str, ...]:
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding='utf-8') as f:
token_list = [line.rstrip() for line in f]
# Overwriting token_list to keep it as "portable".
args.token_list = list(token_list)
elif isinstance(args.token_list, (tuple, list)):
token_list = list(args.token_list)
else:
raise RuntimeError('token_list must be str or list')
vocab_size = len(token_list)
# logger.info(f'Vocabulary size: {vocab_size }')
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# 4. Pre-encoder input block
# NOTE(kan-bayashi): Use getattr to keep the compatibility
if getattr(args, 'preencoder', None) is not None:
preencoder_class = preencoder_choices.get_class(args.preencoder)
preencoder = preencoder_class(**args.preencoder_conf)
input_size = preencoder.output_size()
else:
preencoder = None
# 4. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
# 5. Post-encoder block
# NOTE(kan-bayashi): Use getattr to keep the compatibility
encoder_output_size = encoder.output_size()
if getattr(args, 'postencoder', None) is not None:
postencoder_class = postencoder_choices.get_class(args.postencoder)
postencoder = postencoder_class(
input_size=encoder_output_size, **args.postencoder_conf)
encoder_output_size = postencoder.output_size()
else:
postencoder = None
# 5. Decoder
decoder_class = decoder_choices.get_class(args.decoder)
if args.decoder == 'transducer':
decoder = decoder_class(
vocab_size,
embed_pad=0,
**args.decoder_conf,
)
joint_network = JointNetwork(
vocab_size,
encoder.output_size(),
decoder.dunits,
**args.joint_net_conf,
)
else:
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
**args.decoder_conf,
)
joint_network = None
# 6. CTC
ctc = CTC(
odim=vocab_size,
encoder_output_size=encoder_output_size,
**args.ctc_conf)
predictor_class = predictor_choices.get_class(args.predictor)
predictor = predictor_class(**args.predictor_conf)
# 7. Build model
try:
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class('espnet')
model = model_class(
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
postencoder=postencoder,
decoder=decoder,
ctc=ctc,
joint_network=joint_network,
token_list=token_list,
predictor=predictor,
**args.model_conf,
)
# FIXME(kamo): Should be done in model?
# 8. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model

View File

@@ -0,0 +1,217 @@
import io
import os
import shutil
import threading
from typing import Any, Dict, List, Sequence, Tuple, Union
import yaml
from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import WavToScp
from modelscope.utils.constant import Tasks
from .asr_engine import asr_env_checking, asr_inference_paraformer_espnet
from .asr_engine.common import asr_utils
__all__ = ['AutomaticSpeechRecognitionPipeline']
@PIPELINES.register_module(
Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference)
class AutomaticSpeechRecognitionPipeline(Pipeline):
"""ASR Pipeline
"""
def __init__(self,
model: Union[List[Model], List[str]] = None,
preprocessor: WavToScp = None,
**kwargs):
"""use `model` and `preprocessor` to create an asr pipeline for prediction
"""
assert model is not None, 'asr model should be provided'
model_list: List = []
if isinstance(model[0], Model):
model_list = model
else:
model_list.append(Model.from_pretrained(model[0]))
if len(model) == 2 and model[1] is not None:
model_list.append(Model.from_pretrained(model[1]))
super().__init__(model=model_list, preprocessor=preprocessor, **kwargs)
self._preprocessor = preprocessor
self._am_model = model_list[0]
if len(model_list) == 2 and model_list[1] is not None:
self._lm_model = model_list[1]
def __call__(self,
wav_path: str,
recog_type: str = None,
audio_format: str = None,
workspace: str = None) -> Dict[str, Any]:
assert len(wav_path) > 0, 'wav_path should be provided'
self._recog_type = recog_type
self._audio_format = audio_format
self._workspace = workspace
self._wav_path = wav_path
if recog_type is None or audio_format is None or workspace is None:
self._recog_type, self._audio_format, self._workspace, self._wav_path = asr_utils.type_checking(
wav_path, recog_type, audio_format, workspace)
if self._preprocessor is None:
self._preprocessor = WavToScp(workspace=self._workspace)
output = self._preprocessor.forward(self._am_model.forward(),
self._recog_type,
self._audio_format, self._wav_path)
output = self.forward(output)
rst = self.postprocess(output)
return rst
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Decoding
"""
j: int = 0
process = []
while j < inputs['thread_count']:
data_cmd: Sequence[Tuple[str, str, str]]
if inputs['audio_format'] == 'wav':
data_cmd = [(os.path.join(inputs['workspace'],
'data.' + str(j) + '.scp'), 'speech',
'sound')]
elif inputs['audio_format'] == 'kaldi_ark':
data_cmd = [(os.path.join(inputs['workspace'],
'data.' + str(j) + '.scp'), 'speech',
'kaldi_ark')]
output_dir: str = os.path.join(inputs['output'],
'output.' + str(j))
if not os.path.exists(output_dir):
os.mkdir(output_dir)
with open(inputs['asr_model_config'], 'r', encoding='utf-8') as config_file:
    root = yaml.full_load(config_file)
frontend_conf = None
if 'frontend_conf' in root:
frontend_conf = root['frontend_conf']
cmd = {
'model_type': inputs['model_type'],
'beam_size': root['beam_size'],
'penalty': root['penalty'],
'maxlenratio': root['maxlenratio'],
'minlenratio': root['minlenratio'],
'ctc_weight': root['ctc_weight'],
'lm_weight': root['lm_weight'],
'output_dir': output_dir,
'ngpu': 0,
'log_level': 'ERROR',
'data_path_and_name_and_type': data_cmd,
'asr_train_config': inputs['am_model_config'],
'asr_model_file': inputs['am_model_path'],
'batch_size': inputs['model_config']['batch_size'],
'frontend_conf': frontend_conf
}
thread = AsrInferenceThread(j, cmd)
thread.start()
j += 1
process.append(thread)
for p in process:
p.join()
return inputs
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""process the asr results
"""
rst = {'rec_result': 'None'}
# single wav task
if inputs['recog_type'] == 'wav' and inputs['audio_format'] == 'wav':
text_file: str = os.path.join(inputs['output'], 'output.0',
'1best_recog', 'text')
if os.path.exists(text_file):
with open(text_file, 'r', encoding='utf-8') as f:
    result_str: str = f.readline()
if len(result_str) > 0:
result_list = result_str.split()
if len(result_list) >= 2:
rst['rec_result'] = result_list[1]
# run with datasets, and audio format is waveform or kaldi_ark
elif inputs['recog_type'] != 'wav':
inputs['reference_text'] = self._ref_text_tidy(inputs)
inputs['datasets_result'] = asr_utils.compute_wer(
inputs['hypothesis_text'], inputs['reference_text'])
else:
raise ValueError('recog_type and audio_format are mismatching')
if 'datasets_result' in inputs:
rst['datasets_result'] = inputs['datasets_result']
# remove workspace dir (.tmp)
if os.path.exists(self._workspace):
shutil.rmtree(self._workspace)
return rst
def _ref_text_tidy(self, inputs: Dict[str, Any]) -> str:
ref_text: str = os.path.join(inputs['output'], 'text.ref')
k: int = 0
while k < inputs['thread_count']:
output_text = os.path.join(inputs['output'], 'output.' + str(k),
'1best_recog', 'text')
if os.path.exists(output_text):
with open(output_text, 'r', encoding='utf-8') as i:
lines = i.readlines()
with open(ref_text, 'a', encoding='utf-8') as o:
for line in lines:
o.write(line)
k += 1
return ref_text
class AsrInferenceThread(threading.Thread):
def __init__(self, threadID, cmd):
threading.Thread.__init__(self)
self._threadID = threadID
self._cmd = cmd
def run(self):
if self._cmd['model_type'] == 'pytorch':
asr_inference_paraformer_espnet.asr_inference(
batch_size=self._cmd['batch_size'],
output_dir=self._cmd['output_dir'],
maxlenratio=self._cmd['maxlenratio'],
minlenratio=self._cmd['minlenratio'],
beam_size=self._cmd['beam_size'],
ngpu=self._cmd['ngpu'],
ctc_weight=self._cmd['ctc_weight'],
lm_weight=self._cmd['lm_weight'],
penalty=self._cmd['penalty'],
log_level=self._cmd['log_level'],
data_path_and_name_and_type=self.
_cmd['data_path_and_name_and_type'],
asr_train_config=self._cmd['asr_train_config'],
asr_model_file=self._cmd['asr_model_file'],
frontend_conf=self._cmd['frontend_conf'])
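# Illustrative usage sketch of the pipeline defined above; it mirrors the unit tests
# further below, from which the model id and wav path are taken.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model=['damo/speech_paraformer_asr_nat-aishell1-pytorch'])
result = asr('data/test/audios/asr_example.wav')
# result -> {'rec_result': '...'} for a single wav, or {'rec_result': 'None',
# 'datasets_result': {...}} when wav_path points at a dataset directory.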

View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.error import AUDIO_IMPORT_ERROR, TENSORFLOW_IMPORT_ERROR
from .asr import WavToScp
from .base import Preprocessor
from .builder import PREPROCESSORS, build_preprocessor
from .common import Compose

View File

@@ -0,0 +1,254 @@
import io
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
import yaml
from modelscope.metainfo import Preprocessors
from modelscope.models.base import Model
from modelscope.utils.constant import Fields
from .base import Preprocessor
from .builder import PREPROCESSORS
__all__ = ['WavToScp']
@PREPROCESSORS.register_module(
Fields.audio, module_name=Preprocessors.wav_to_scp)
class WavToScp(Preprocessor):
"""generate audio scp from wave or ark
Args:
workspace (str):
"""
def __init__(self, workspace: str = None):
# the workspace path
if workspace is None or len(workspace) == 0:
self._workspace = os.path.join(os.getcwd(), '.tmp')
else:
self._workspace = workspace
if not os.path.exists(self._workspace):
os.mkdir(self._workspace)
def __call__(self,
model: List[Model] = None,
recog_type: str = None,
audio_format: str = None,
wav_path: str = None) -> Dict[str, Any]:
assert len(model) > 0, 'preprocess model is invalid'
assert len(recog_type) > 0, 'preprocess recog_type is empty'
assert len(audio_format) > 0, 'preprocess audio_format is empty'
assert len(wav_path) > 0, 'preprocess wav_path is empty'
self._am_model = model[0]
if len(model) == 2 and model[1] is not None:
self._lm_model = model[1]
out = self.forward(self._am_model.forward(), recog_type, audio_format,
wav_path)
return out
def forward(self, model: Dict[str, Any], recog_type: str,
audio_format: str, wav_path: str) -> Dict[str, Any]:
assert len(recog_type) > 0, 'preprocess recog_type is empty'
assert len(audio_format) > 0, 'preprocess audio_format is empty'
assert len(wav_path) > 0, 'preprocess wav_path is empty'
assert os.path.exists(wav_path), 'preprocess wav_path does not exist'
assert len(
model['am_model']) > 0, 'preprocess model[am_model] is empty'
assert len(model['am_model_path']
) > 0, 'preprocess model[am_model_path] is empty'
assert os.path.exists(
model['am_model_path']), 'preprocess am_model_path does not exist'
assert len(model['model_workspace']
) > 0, 'preprocess model[model_workspace] is empty'
assert os.path.exists(model['model_workspace']
), 'preprocess model_workspace does not exist'
assert len(model['model_config']
) > 0, 'preprocess model[model_config] is empty'
# the am model name
am_model: str = model['am_model']
# the am model file path
am_model_path: str = model['am_model_path']
# the recognition model dir path
model_workspace: str = model['model_workspace']
# the recognition model config dict
global_model_config_dict: str = model['model_config']
rst = {
'workspace': os.path.join(self._workspace, recog_type),
'am_model': am_model,
'am_model_path': am_model_path,
'model_workspace': model_workspace,
# the asr type setting, eg: test dev train wav
'recog_type': recog_type,
# the asr audio format setting, eg: wav, kaldi_ark
'audio_format': audio_format,
# the test wav file path or the dataset path
'wav_path': wav_path,
'model_config': global_model_config_dict
}
out = self._config_checking(rst)
out = self._env_setting(out)
if audio_format == 'wav':
out = self._scp_generation_from_wav(out)
elif audio_format == 'kaldi_ark':
out = self._scp_generation_from_ark(out)
return out
def _config_checking(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""config checking
"""
assert inputs['model_config'].__contains__(
'type'), 'model type does not exist'
assert inputs['model_config'].__contains__(
'batch_size'), 'batch_size does not exist'
assert inputs['model_config'].__contains__(
'am_model_config'), 'am_model_config does not exist'
assert inputs['model_config'].__contains__(
'asr_model_config'), 'asr_model_config does not exist'
assert inputs['model_config'].__contains__(
'asr_model_wav_config'), 'asr_model_wav_config does not exist'
am_model_config: str = os.path.join(
inputs['model_workspace'],
inputs['model_config']['am_model_config'])
assert os.path.exists(
am_model_config), 'am_model_config does not exist'
inputs['am_model_config'] = am_model_config
asr_model_config: str = os.path.join(
inputs['model_workspace'],
inputs['model_config']['asr_model_config'])
assert os.path.exists(
asr_model_config), 'asr_model_config does not exist'
asr_model_wav_config: str = os.path.join(
inputs['model_workspace'],
inputs['model_config']['asr_model_wav_config'])
assert os.path.exists(
asr_model_wav_config), 'asr_model_wav_config does not exist'
inputs['model_type'] = inputs['model_config']['type']
if inputs['audio_format'] == 'wav':
inputs['asr_model_config'] = asr_model_wav_config
else:
inputs['asr_model_config'] = asr_model_config
return inputs
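# Illustrative sketch of the model_config dict that _config_checking() above expects
# (normally read from the model's configuration.json); the file names are hypothetical.
model_config = {
    'type': 'pytorch',                            # becomes inputs['model_type']
    'batch_size': 1,
    'am_model_config': 'config.yaml',             # training config, relative to model_workspace
    'asr_model_config': 'decoding.yaml',          # decoding config used for kaldi_ark input
    'asr_model_wav_config': 'decoding_wav.yaml',  # decoding config used for wav input
}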
def _env_setting(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if not os.path.exists(inputs['workspace']):
os.mkdir(inputs['workspace'])
inputs['output'] = os.path.join(inputs['workspace'], 'logdir')
if not os.path.exists(inputs['output']):
os.mkdir(inputs['output'])
# run with datasets, should set datasets_path and text_path
if inputs['recog_type'] != 'wav':
inputs['datasets_path'] = inputs['wav_path']
# run with datasets, and audio format is waveform
if inputs['audio_format'] == 'wav':
inputs['wav_path'] = os.path.join(inputs['datasets_path'],
'wav', inputs['recog_type'])
inputs['hypothesis_text'] = os.path.join(
inputs['datasets_path'], 'transcript', 'data.text')
assert os.path.exists(inputs['hypothesis_text']
), 'hypothesis text does not exist'
elif inputs['audio_format'] == 'kaldi_ark':
inputs['wav_path'] = os.path.join(inputs['datasets_path'],
inputs['recog_type'])
inputs['hypothesis_text'] = os.path.join(
inputs['wav_path'], 'data.text')
assert os.path.exists(inputs['hypothesis_text']
), 'hypothesis text does not exist'
return inputs
def _scp_generation_from_wav(self, inputs: Dict[str,
Any]) -> Dict[str, Any]:
"""scp generation from waveform files
"""
# find all waveform files
wav_list = []
if inputs['recog_type'] == 'wav':
file_path = inputs['wav_path']
if os.path.isfile(file_path):
if file_path.endswith('.wav') or file_path.endswith('.WAV'):
wav_list.append(file_path)
else:
wav_dir: str = inputs['wav_path']
wav_list = self._recursion_dir_all_wave(wav_list, wav_dir)
list_count: int = len(wav_list)
inputs['wav_count'] = list_count
# store all wav into data.0.scp
inputs['thread_count'] = 1
j: int = 0
wav_list_path = os.path.join(inputs['workspace'], 'data.0.scp')
with open(wav_list_path, 'a') as f:
while j < list_count:
wav_file = wav_list[j]
wave_scp_content: str = os.path.splitext(
os.path.basename(wav_file))[0]
wave_scp_content += ' ' + wav_file + '\n'
f.write(wave_scp_content)
j += 1
return inputs
def _scp_generation_from_ark(self, inputs: Dict[str,
Any]) -> Dict[str, Any]:
"""scp generation from kaldi ark file
"""
inputs['thread_count'] = 1
ark_scp_path = os.path.join(inputs['wav_path'], 'data.scp')
ark_file_path = os.path.join(inputs['wav_path'], 'data.ark')
assert os.path.exists(ark_scp_path), 'data.scp does not exist'
assert os.path.exists(ark_file_path), 'data.ark does not exist'
new_ark_scp_path = os.path.join(inputs['workspace'], 'data.0.scp')
with open(ark_scp_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
with open(new_ark_scp_path, 'w', encoding='utf-8') as n:
for line in lines:
outs = line.strip().split(' ')
if len(outs) == 2:
key = outs[0]
sub = outs[1].split(':')
if len(sub) == 2:
nums = sub[1]
content = key + ' ' + ark_file_path + ':' + nums + '\n'
n.write(content)
return inputs
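# Illustrative sketch of the Kaldi-style scp lines that the two generators above write
# into data.0.scp; the utterance id and paths are hypothetical.
wav_scp_line = 'BAC009S0764W0121 /data/wav/test/BAC009S0764W0121.wav\n'  # from _scp_generation_from_wav
ark_scp_line = 'BAC009S0764W0121 /data/test/data.ark:17\n'               # from _scp_generation_from_ark
with open('data.0.scp', 'w', encoding='utf-8') as f:
    f.write(wav_scp_line)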
def _recursion_dir_all_wave(self, wav_list,
dir_path: str) -> List[str]:
dir_files = os.listdir(dir_path)
for file in dir_files:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
if file_path.endswith('.wav') or file_path.endswith('.WAV'):
wav_list.append(file_path)
elif os.path.isdir(file_path):
self._recursion_dir_all_wave(wav_list, file_path)
return wav_list

View File

@@ -1,3 +1,4 @@
espnet==202204
#tts
h5py
inflect

View File

@@ -0,0 +1,199 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tarfile
import unittest
import requests
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
WAV_FILE = 'data/test/audios/asr_example.wav'
LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz'
LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz'
AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz'
AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz'
def un_tar_gz(fname, dirs):
with tarfile.open(fname) as t:
    t.extractall(path=dirs)
class AutomaticSpeechRecognitionTest(unittest.TestCase):
def setUp(self) -> None:
self._am_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch'
# this temporary workspace dir will store waveform files
self._workspace = os.path.join(os.getcwd(), '.tmp')
if not os.path.exists(self._workspace):
os.mkdir(self._workspace)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_wav(self):
'''run with single waveform file
'''
wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
inference_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=[self._am_model_id])
self.assertTrue(inference_16k_pipline is not None)
rec_result = inference_16k_pipline(wav_file_path)
self.assertTrue(len(rec_result['rec_result']) > 0)
self.assertTrue(rec_result['rec_result'] != 'None')
'''
result structure:
{
'rec_result': '每一天都要快乐喔'
}
or
{
'rec_result': 'None'
}
'''
print('test_run_with_wav rec result: ' + rec_result['rec_result'])
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_wav_dataset(self):
'''run with datasets, and audio format is waveform
datasets directory:
<dataset_path>
wav
test # testsets
xx.wav
...
dev # devsets
yy.wav
...
train # trainsets
zz.wav
...
transcript
data.text # hypothesis text
'''
# download the testsets archive if it is not cached yet
testsets_file_path = os.path.join(self._workspace,
LITTLE_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(LITTLE_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)
testsets_dir_name = os.path.splitext(
os.path.basename(
os.path.splitext(
os.path.basename(LITTLE_TESTSETS_FILE))[0]))[0]
# dataset_path = <cwd>/.tmp/data_aishell/wav/test
dataset_path = os.path.join(self._workspace, testsets_dir_name, 'wav',
'test')
# untar the dataset_path file
if not os.path.exists(dataset_path):
un_tar_gz(testsets_file_path, self._workspace)
inference_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=[self._am_model_id])
self.assertTrue(inference_16k_pipline is not None)
rec_result = inference_16k_pipline(wav_path=dataset_path)
self.assertTrue(len(rec_result['datasets_result']) > 0)
self.assertTrue(rec_result['datasets_result']['Wrd'] > 0)
'''
result structure:
{
'rec_result': 'None',
'datasets_result':
{
'Wrd': 1654, # the number of words
'Snt': 128, # the number of sentences
'Corr': 1573, # the number of correct words
'Ins': 1, # the number of insert words
'Del': 1, # the number of delete words
'Sub': 80, # the number of substitution words
'wrong_words': 82, # the number of wrong words
'wrong_sentences': 47, # the number of wrong sentences
'Err': 4.96, # WER/CER
'S.Err': 36.72 # SER
}
}
'''
print('test_run_with_wav_dataset datasets result: ')
print(rec_result['datasets_result'])
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_ark_dataset(self):
'''run with datasets, and audio format is kaldi_ark
datasets directory:
<dataset_path>
test # testsets
data.ark
data.scp
data.text
dev # devsets
data.ark
data.scp
data.text
train # trainsets
data.ark
data.scp
data.text
'''
# download the testsets archive if it is not cached yet
testsets_file_path = os.path.join(self._workspace,
AISHELL1_TESTSETS_FILE)
if not os.path.exists(testsets_file_path):
r = requests.get(AISHELL1_TESTSETS_URL)
with open(testsets_file_path, 'wb') as f:
f.write(r.content)
testsets_dir_name = os.path.splitext(
os.path.basename(
os.path.splitext(
os.path.basename(AISHELL1_TESTSETS_FILE))[0]))[0]
# dataset_path = <cwd>/.tmp/aishell1/test
dataset_path = os.path.join(self._workspace, testsets_dir_name, 'test')
# untar the dataset_path file
if not os.path.exists(dataset_path):
un_tar_gz(testsets_file_path, self._workspace)
inference_16k_pipline = pipeline(
task=Tasks.auto_speech_recognition, model=[self._am_model_id])
self.assertTrue(inference_16k_pipline is not None)
rec_result = inference_16k_pipline(wav_path=dataset_path)
self.assertTrue(len(rec_result['datasets_result']) > 0)
self.assertTrue(rec_result['datasets_result']['Wrd'] > 0)
'''
result structure:
{
'rec_result': 'None',
'datasets_result':
{
'Wrd': 104816, # the number of words
'Snt': 7176, # the number of sentences
'Corr': 99327, # the number of correct words
'Ins': 104, # the number of insert words
'Del': 155, # the number of delete words
'Sub': 5334, # the number of substitution words
'wrong_words': 5593, # the number of wrong words
'wrong_sentences': 2898, # the number of wrong sentences
'Err': 5.34, # WER/CER
'S.Err': 40.38 # SER
}
}
'''
print('test_run_with_ark_dataset datasets result: ')
print(rec_result['datasets_result'])
if __name__ == '__main__':
unittest.main()