Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
[to #42322933] add ASR inference with PyTorch (ESPnet framework)
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9273537
data/test/audios/asr_example.wav (Normal file, 3 lines)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:87bde7feb3b40d75dec27e5824dd1077911f867e3f125c4bf603ec0af954d4db
size 77864
@@ -23,6 +23,7 @@ class Models(object):
    sambert_hifigan = 'sambert-hifigan'
    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
    kws_kwsbp = 'kws-kwsbp'
    generic_asr = 'generic-asr'

    # multi-modal models
    ofa = 'ofa'
@@ -68,6 +69,7 @@ class Pipelines(object):
    speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
    kws_kwsbp = 'kws-kwsbp'
    asr_inference = 'asr-inference'

    # multi-modal tasks
    image_caption = 'image-captioning'
@@ -120,6 +122,7 @@ class Preprocessors(object):
    linear_aec_fbank = 'linear-aec-fbank'
    text_to_tacotron_symbols = 'text-to-tacotron-symbols'
    wav_to_lists = 'wav-to-lists'
    wav_to_scp = 'wav-to-scp'

    # multi-modal
    ofa_image_caption = 'ofa-image-caption'

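With these constants registered, the new ASR task can be resolved through the standard ModelScope factory. A minimal usage sketch (not part of the diff; the model id below is a placeholder and the exact call signature of the ASR pipeline may differ):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Resolve the registered 'asr-inference' pipeline for the ASR task.
asr = pipeline(task=Tasks.auto_speech_recognition,
               model='<asr-model-id-or-local-dir>')  # placeholder model id
print(asr('data/test/audios/asr_example.wav'))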
@@ -5,6 +5,7 @@ from .base import Model
from .builder import MODELS, build_model

try:
    from .audio.asr import GenericAutomaticSpeechRecognition
    from .audio.tts import SambertHifigan
    from .audio.kws import GenericKeyWordSpotting
    from .audio.ans.frcrn import FRCRNModel

modelscope/models/audio/asr/__init__.py (Normal file, 1 line)
@@ -0,0 +1 @@
from .generic_automatic_speech_recognition import *  # noqa F403
@@ -0,0 +1,39 @@
import os
from typing import Any, Dict

from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks

__all__ = ['GenericAutomaticSpeechRecognition']


@MODELS.register_module(
    Tasks.auto_speech_recognition, module_name=Models.generic_asr)
class GenericAutomaticSpeechRecognition(Model):

    def __init__(self, model_dir: str, am_model_name: str,
                 model_config: Dict[str, Any], *args, **kwargs):
        """Initialize the model info.

        Args:
            model_dir (str): the model path.
            am_model_name (str): the AM model name from configuration.json.
            model_config (Dict[str, Any]): the recognition model config dict.
        """

        self.model_cfg = {
            # the recognition model dir path
            'model_workspace': model_dir,
            # the am model name
            'am_model': am_model_name,
            # the am model file path
            'am_model_path': os.path.join(model_dir, am_model_name),
            # the recognition model config dict
            'model_config': model_config
        }

    def forward(self) -> Dict[str, Any]:
        """Return the info of the model."""
        return self.model_cfg
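A construction sketch (not part of the diff; all paths and names below are placeholders): the model is a thin container, so forward() simply hands the assembled config dict to the pipeline.

model = GenericAutomaticSpeechRecognition(
    model_dir='/path/to/asr_model_dir',   # placeholder local model dir
    am_model_name='model.pb',             # placeholder AM file name
    model_config={'type': 'pytorch'})     # placeholder config dict
cfg = model.forward()
print(cfg['am_model_path'])               # -> /path/to/asr_model_dir/model.pb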
@@ -3,6 +3,7 @@
from modelscope.utils.error import TENSORFLOW_IMPORT_ERROR

try:
    from .asr.asr_inference_pipeline import AutomaticSpeechRecognitionPipeline
    from .kws_kwsbp_pipeline import *  # noqa F403
    from .linear_aec_pipeline import LinearAECPipeline
except ModuleNotFoundError as e:

modelscope/pipelines/audio/asr/__init__.py (Normal file, 0 lines)
@@ -0,0 +1,12 @@
import nltk

try:
    nltk.data.find('taggers/averaged_perceptron_tagger')
except LookupError:
    nltk.download(
        'averaged_perceptron_tagger', halt_on_error=False, raise_on_error=True)

try:
    nltk.data.find('corpora/cmudict')
except LookupError:
    nltk.download('cmudict', halt_on_error=False, raise_on_error=True)
modelscope/pipelines/audio/asr/asr_engine/asr_inference_paraformer_espnet.py (Executable file, 690 lines)
@@ -0,0 +1,690 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Part of the implementation is borrowed from espnet/espnet.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from espnet2.asr.frontend.default import DefaultFrontend
|
||||
from espnet2.asr.transducer.beam_search_transducer import BeamSearchTransducer
|
||||
from espnet2.asr.transducer.beam_search_transducer import \
|
||||
ExtendedHypothesis as ExtTransHypothesis # noqa: H301
|
||||
from espnet2.asr.transducer.beam_search_transducer import \
|
||||
Hypothesis as TransHypothesis
|
||||
from espnet2.fileio.datadir_writer import DatadirWriter
|
||||
from espnet2.tasks.lm import LMTask
|
||||
from espnet2.text.build_tokenizer import build_tokenizer
|
||||
from espnet2.text.token_id_converter import TokenIDConverter
|
||||
from espnet2.torch_utils.device_funcs import to_device
|
||||
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
|
||||
from espnet2.utils import config_argparse
|
||||
from espnet2.utils.types import str2bool, str2triple_str, str_or_none
|
||||
from espnet.nets.batch_beam_search import BatchBeamSearch
|
||||
from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
|
||||
from espnet.nets.beam_search import BeamSearch, Hypothesis
|
||||
from espnet.nets.pytorch_backend.transformer.subsampling import \
|
||||
TooShortUttError
|
||||
from espnet.nets.scorer_interface import BatchScorerInterface
|
||||
from espnet.nets.scorers.ctc import CTCPrefixScorer
|
||||
from espnet.nets.scorers.length_bonus import LengthBonus
|
||||
from espnet.utils.cli_utils import get_commandline_args
|
||||
from typeguard import check_argument_types, check_return_type
|
||||
|
||||
from .espnet.tasks.asr import ASRTaskNAR as ASRTask
|
||||
|
||||
|
||||
class Speech2Text:
|
||||
|
||||
def __init__(self,
|
||||
asr_train_config: Union[Path, str] = None,
|
||||
asr_model_file: Union[Path, str] = None,
|
||||
transducer_conf: dict = None,
|
||||
lm_train_config: Union[Path, str] = None,
|
||||
lm_file: Union[Path, str] = None,
|
||||
ngram_scorer: str = 'full',
|
||||
ngram_file: Union[Path, str] = None,
|
||||
token_type: str = None,
|
||||
bpemodel: str = None,
|
||||
device: str = 'cpu',
|
||||
maxlenratio: float = 0.0,
|
||||
minlenratio: float = 0.0,
|
||||
batch_size: int = 1,
|
||||
dtype: str = 'float32',
|
||||
beam_size: int = 20,
|
||||
ctc_weight: float = 0.5,
|
||||
lm_weight: float = 1.0,
|
||||
ngram_weight: float = 0.9,
|
||||
penalty: float = 0.0,
|
||||
nbest: int = 1,
|
||||
streaming: bool = False,
|
||||
frontend_conf: dict = None):
|
||||
assert check_argument_types()
|
||||
|
||||
# 1. Build ASR model
|
||||
scorers = {}
|
||||
asr_model, asr_train_args = ASRTask.build_model_from_file(
|
||||
asr_train_config, asr_model_file, device)
|
||||
if asr_model.frontend is None and frontend_conf is not None:
|
||||
frontend = DefaultFrontend(**frontend_conf)
|
||||
asr_model.frontend = frontend
|
||||
asr_model.to(dtype=getattr(torch, dtype)).eval()
|
||||
|
||||
decoder = asr_model.decoder
|
||||
|
||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
||||
token_list = asr_model.token_list
|
||||
scorers.update(
|
||||
decoder=decoder,
|
||||
ctc=ctc,
|
||||
length_bonus=LengthBonus(len(token_list)),
|
||||
)
|
||||
|
||||
# 2. Build Language model
|
||||
if lm_train_config is not None:
|
||||
lm, lm_train_args = LMTask.build_model_from_file(
|
||||
lm_train_config, lm_file, device)
|
||||
scorers['lm'] = lm.lm
|
||||
|
||||
# 3. Build ngram model
|
||||
if ngram_file is not None:
|
||||
if ngram_scorer == 'full':
|
||||
from espnet.nets.scorers.ngram import NgramFullScorer
|
||||
|
||||
ngram = NgramFullScorer(ngram_file, token_list)
|
||||
else:
|
||||
from espnet.nets.scorers.ngram import NgramPartScorer
|
||||
|
||||
ngram = NgramPartScorer(ngram_file, token_list)
|
||||
else:
|
||||
ngram = None
|
||||
scorers['ngram'] = ngram
|
||||
|
||||
# 4. Build BeamSearch object
|
||||
if asr_model.use_transducer_decoder:
|
||||
beam_search_transducer = BeamSearchTransducer(
|
||||
decoder=asr_model.decoder,
|
||||
joint_network=asr_model.joint_network,
|
||||
beam_size=beam_size,
|
||||
lm=scorers['lm'] if 'lm' in scorers else None,
|
||||
lm_weight=lm_weight,
|
||||
**transducer_conf,
|
||||
)
|
||||
beam_search = None
|
||||
else:
|
||||
beam_search_transducer = None
|
||||
|
||||
weights = dict(
|
||||
decoder=1.0 - ctc_weight,
|
||||
ctc=ctc_weight,
|
||||
lm=lm_weight,
|
||||
ngram=ngram_weight,
|
||||
length_bonus=penalty,
|
||||
)
|
||||
beam_search = BeamSearch(
|
||||
beam_size=beam_size,
|
||||
weights=weights,
|
||||
scorers=scorers,
|
||||
sos=asr_model.sos,
|
||||
eos=asr_model.eos,
|
||||
vocab_size=len(token_list),
|
||||
token_list=token_list,
|
||||
pre_beam_score_key=None if ctc_weight == 1.0 else 'full',
|
||||
)
|
||||
|
||||
# TODO(karita): make all scorers batchfied
|
||||
if batch_size == 1:
|
||||
non_batch = [
|
||||
k for k, v in beam_search.full_scorers.items()
|
||||
if not isinstance(v, BatchScorerInterface)
|
||||
]
|
||||
if len(non_batch) == 0:
|
||||
if streaming:
|
||||
beam_search.__class__ = BatchBeamSearchOnlineSim
|
||||
beam_search.set_streaming_config(asr_train_config)
|
||||
logging.info(
|
||||
'BatchBeamSearchOnlineSim implementation is selected.'
|
||||
)
|
||||
else:
|
||||
beam_search.__class__ = BatchBeamSearch
|
||||
else:
|
||||
logging.warning(
|
||||
f'As non-batch scorers {non_batch} are found, '
|
||||
f'fall back to non-batch implementation.')
|
||||
|
||||
beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
|
||||
for scorer in scorers.values():
|
||||
if isinstance(scorer, torch.nn.Module):
|
||||
scorer.to(
|
||||
device=device, dtype=getattr(torch, dtype)).eval()
|
||||
|
||||
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
|
||||
if token_type is None:
|
||||
token_type = asr_train_args.token_type
|
||||
if bpemodel is None:
|
||||
bpemodel = asr_train_args.bpemodel
|
||||
|
||||
if token_type is None:
|
||||
tokenizer = None
|
||||
elif token_type == 'bpe':
|
||||
if bpemodel is not None:
|
||||
tokenizer = build_tokenizer(
|
||||
token_type=token_type, bpemodel=bpemodel)
|
||||
else:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = build_tokenizer(token_type=token_type)
|
||||
converter = TokenIDConverter(token_list=token_list)
|
||||
|
||||
self.asr_model = asr_model
|
||||
self.asr_train_args = asr_train_args
|
||||
self.converter = converter
|
||||
self.tokenizer = tokenizer
|
||||
self.beam_search = beam_search
|
||||
self.beam_search_transducer = beam_search_transducer
|
||||
self.maxlenratio = maxlenratio
|
||||
self.minlenratio = minlenratio
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.nbest = nbest
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, speech: Union[torch.Tensor, np.ndarray]):
|
||||
"""Inference
|
||||
|
||||
Args:
|
||||
data: Input speech data
|
||||
Returns:
|
||||
text, token, token_int, hyp
|
||||
"""
|
||||
|
||||
assert check_argument_types()
|
||||
|
||||
# Input as audio signal
|
||||
if isinstance(speech, np.ndarray):
|
||||
speech = torch.tensor(speech)
|
||||
|
||||
# data: (Nsamples,) -> (1, Nsamples)
|
||||
speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
|
||||
# lengths: (1,)
|
||||
lengths = speech.new_full([1],
|
||||
dtype=torch.long,
|
||||
fill_value=speech.size(1))
|
||||
batch = {'speech': speech, 'speech_lengths': lengths}
|
||||
|
||||
# a. To device
|
||||
batch = to_device(batch, device=self.device)
|
||||
|
||||
# b. Forward Encoder
|
||||
enc, enc_len = self.asr_model.encode(**batch)
|
||||
if isinstance(enc, tuple):
|
||||
enc = enc[0]
|
||||
assert len(enc) == 1, len(enc)
|
||||
|
||||
predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
|
||||
pre_acoustic_embeds, pre_token_length = predictor_outs[
|
||||
0], predictor_outs[1]
|
||||
pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)],
|
||||
device=pre_acoustic_embeds.device)
|
||||
decoder_outs = self.asr_model.cal_decoder_with_predictor(
|
||||
enc, enc_len, pre_acoustic_embeds, pre_token_length)
|
||||
decoder_out = decoder_outs[0]
|
||||
|
||||
yseq = decoder_out.argmax(dim=-1)
|
||||
score = decoder_out.max(dim=-1)[0]
|
||||
score = torch.sum(score, dim=-1)
|
||||
# prepend sos and append eos so the 1-best sequence matches the standard Hypothesis layout
|
||||
yseq = torch.tensor(
|
||||
[self.asr_model.sos] + yseq.tolist()[0] + [self.asr_model.eos],
|
||||
device=yseq.device)
|
||||
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
|
||||
|
||||
results = []
|
||||
for hyp in nbest_hyps:
|
||||
assert isinstance(hyp, (Hypothesis, TransHypothesis)), type(hyp)
|
||||
|
||||
# remove sos/eos and get results
|
||||
last_pos = None if self.asr_model.use_transducer_decoder else -1
|
||||
if isinstance(hyp.yseq, list):
|
||||
token_int = hyp.yseq[1:last_pos]
|
||||
else:
|
||||
token_int = hyp.yseq[1:last_pos].tolist()
|
||||
|
||||
# remove blank symbol id, which is assumed to be 0
|
||||
token_int = list(filter(lambda x: x != 0, token_int))
|
||||
|
||||
# Change integer-ids to tokens
|
||||
token = self.converter.ids2tokens(token_int)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
text = self.tokenizer.tokens2text(token)
|
||||
else:
|
||||
text = None
|
||||
|
||||
results.append((text, token, token_int, hyp, speech.size(1)))
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
model_tag: Optional[str] = None,
|
||||
**kwargs: Optional[Any],
|
||||
):
|
||||
"""Build Speech2Text instance from the pretrained model.
|
||||
|
||||
Args:
|
||||
model_tag (Optional[str]): Model tag of the pretrained models.
|
||||
Currently, the tags of espnet_model_zoo are supported.
|
||||
|
||||
Returns:
|
||||
Speech2Text: Speech2Text instance.
|
||||
|
||||
"""
|
||||
if model_tag is not None:
|
||||
try:
|
||||
from espnet_model_zoo.downloader import ModelDownloader
|
||||
|
||||
except ImportError:
|
||||
logging.error(
|
||||
'`espnet_model_zoo` is not installed. '
|
||||
'Please install via `pip install -U espnet_model_zoo`.')
|
||||
raise
|
||||
d = ModelDownloader()
|
||||
kwargs.update(**d.download_and_unpack(model_tag))
|
||||
|
||||
return Speech2Text(**kwargs)
|
||||
|
||||
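A minimal decoding sketch (not part of the diff; the config/checkpoint paths are placeholders and a 16 kHz mono float waveform is assumed):

import numpy as np

speech2text = Speech2Text(
    asr_train_config='exp/asr_paraformer/config.yaml',  # placeholder path
    asr_model_file='exp/asr_paraformer/model.pb',        # placeholder path
    device='cpu',
    nbest=1)
waveform = np.random.randn(16000).astype(np.float32)     # stand-in for real audio
for text, token, token_int, hyp, nsamples in speech2text(waveform):
    print(text, token_int)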
|
||||
def inference(
|
||||
output_dir: str,
|
||||
maxlenratio: float,
|
||||
minlenratio: float,
|
||||
batch_size: int,
|
||||
dtype: str,
|
||||
beam_size: int,
|
||||
ngpu: int,
|
||||
seed: int,
|
||||
ctc_weight: float,
|
||||
lm_weight: float,
|
||||
ngram_weight: float,
|
||||
penalty: float,
|
||||
nbest: int,
|
||||
num_workers: int,
|
||||
log_level: Union[int, str],
|
||||
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
|
||||
key_file: Optional[str],
|
||||
asr_train_config: Optional[str],
|
||||
asr_model_file: Optional[str],
|
||||
lm_train_config: Optional[str],
|
||||
lm_file: Optional[str],
|
||||
word_lm_train_config: Optional[str],
|
||||
word_lm_file: Optional[str],
|
||||
ngram_file: Optional[str],
|
||||
model_tag: Optional[str],
|
||||
token_type: Optional[str],
|
||||
bpemodel: Optional[str],
|
||||
allow_variable_data_keys: bool,
|
||||
transducer_conf: Optional[dict],
|
||||
streaming: bool,
|
||||
frontend_conf: dict = None,
|
||||
):
|
||||
assert check_argument_types()
|
||||
if batch_size > 1:
|
||||
raise NotImplementedError('batch decoding is not implemented')
|
||||
if word_lm_train_config is not None:
|
||||
raise NotImplementedError('Word LM is not implemented')
|
||||
if ngpu > 1:
|
||||
raise NotImplementedError('only single GPU decoding is supported')
|
||||
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
|
||||
)
|
||||
|
||||
if ngpu >= 1:
|
||||
device = 'cuda'
|
||||
else:
|
||||
device = 'cpu'
|
||||
|
||||
# 1. Set random-seed
|
||||
set_all_random_seed(seed)
|
||||
|
||||
# 2. Build speech2text
|
||||
speech2text_kwargs = dict(
|
||||
asr_train_config=asr_train_config,
|
||||
asr_model_file=asr_model_file,
|
||||
transducer_conf=transducer_conf,
|
||||
lm_train_config=lm_train_config,
|
||||
lm_file=lm_file,
|
||||
ngram_file=ngram_file,
|
||||
token_type=token_type,
|
||||
bpemodel=bpemodel,
|
||||
device=device,
|
||||
maxlenratio=maxlenratio,
|
||||
minlenratio=minlenratio,
|
||||
dtype=dtype,
|
||||
beam_size=beam_size,
|
||||
ctc_weight=ctc_weight,
|
||||
lm_weight=lm_weight,
|
||||
ngram_weight=ngram_weight,
|
||||
penalty=penalty,
|
||||
nbest=nbest,
|
||||
streaming=streaming,
|
||||
frontend_conf=frontend_conf,
|
||||
)
|
||||
speech2text = Speech2Text.from_pretrained(
|
||||
model_tag=model_tag,
|
||||
**speech2text_kwargs,
|
||||
)
|
||||
|
||||
# 3. Build data-iterator
|
||||
loader = ASRTask.build_streaming_iterator(
|
||||
data_path_and_name_and_type,
|
||||
dtype=dtype,
|
||||
batch_size=batch_size,
|
||||
key_file=key_file,
|
||||
num_workers=num_workers,
|
||||
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args,
|
||||
False),
|
||||
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
|
||||
allow_variable_data_keys=allow_variable_data_keys,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
forward_time_total = 0.0
|
||||
length_total = 0.0
|
||||
# 7 .Start for-loop
|
||||
# FIXME(kamo): The output format should still be discussed
|
||||
with DatadirWriter(output_dir) as writer:
|
||||
for keys, batch in loader:
|
||||
assert isinstance(batch, dict), type(batch)
|
||||
assert all(isinstance(s, str) for s in keys), keys
|
||||
_bs = len(next(iter(batch.values())))
|
||||
assert len(keys) == _bs, f'{len(keys)} != {_bs}'
|
||||
batch = {
|
||||
k: v[0]
|
||||
for k, v in batch.items() if not k.endswith('_lengths')
|
||||
}
|
||||
|
||||
# N-best list of (text, token, token_int, hyp_object)
|
||||
|
||||
try:
|
||||
time_beg = time.time()
|
||||
results = speech2text(**batch)
|
||||
time_end = time.time()
|
||||
forward_time = time_end - time_beg
|
||||
length = results[0][-1]
|
||||
results = [results[0][:-1]]
|
||||
forward_time_total += forward_time
|
||||
length_total += length
|
||||
except TooShortUttError as e:
|
||||
logging.warning(f'Utterance {keys} {e}')
|
||||
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
|
||||
results = [[' ', ['<space>'], [2], hyp]] * nbest
|
||||
|
||||
# Only supporting batch_size==1
|
||||
key = keys[0]
|
||||
for n, (text, token, token_int,
|
||||
hyp) in zip(range(1, nbest + 1), results):
|
||||
# Create a directory: outdir/{n}best_recog
|
||||
ibest_writer = writer[f'{n}best_recog']
|
||||
|
||||
# Write the result to each file
|
||||
ibest_writer['token'][key] = ' '.join(token)
|
||||
ibest_writer['token_int'][key] = ' '.join(map(str, token_int))
|
||||
ibest_writer['score'][key] = str(hyp.score)
|
||||
|
||||
if text is not None:
|
||||
ibest_writer['text'][key] = text
|
||||
|
||||
logging.info(
|
||||
'decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}'
|
||||
.format(length_total, forward_time_total,
|
||||
100 * forward_time_total / length_total))
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = config_argparse.ArgumentParser(
|
||||
description='ASR Decoding',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
# Note(kamo): Use '_' instead of '-' as separator.
|
||||
# '-' is confusing if written in yaml.
|
||||
parser.add_argument(
|
||||
'--log_level',
|
||||
type=lambda x: x.upper(),
|
||||
default='INFO',
|
||||
choices=('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET'),
|
||||
help='The verbose level of logging',
|
||||
)
|
||||
|
||||
parser.add_argument('--output_dir', type=str, required=True)
|
||||
parser.add_argument(
|
||||
'--ngpu',
|
||||
type=int,
|
||||
default=0,
|
||||
help='The number of gpus. 0 indicates CPU mode',
|
||||
)
|
||||
parser.add_argument('--seed', type=int, default=0, help='Random seed')
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
default='float32',
|
||||
choices=['float16', 'float32', 'float64'],
|
||||
help='Data type',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--num_workers',
|
||||
type=int,
|
||||
default=1,
|
||||
help='The number of workers used for DataLoader',
|
||||
)
|
||||
|
||||
group = parser.add_argument_group('Input data related')
|
||||
group.add_argument(
|
||||
'--data_path_and_name_and_type',
|
||||
type=str2triple_str,
|
||||
required=True,
|
||||
action='append',
|
||||
)
|
||||
group.add_argument('--key_file', type=str_or_none)
|
||||
group.add_argument(
|
||||
'--allow_variable_data_keys', type=str2bool, default=False)
|
||||
|
||||
group = parser.add_argument_group('The model configuration related')
|
||||
group.add_argument(
|
||||
'--asr_train_config',
|
||||
type=str,
|
||||
help='ASR training configuration',
|
||||
)
|
||||
group.add_argument(
|
||||
'--asr_model_file',
|
||||
type=str,
|
||||
help='ASR model parameter file',
|
||||
)
|
||||
group.add_argument(
|
||||
'--lm_train_config',
|
||||
type=str,
|
||||
help='LM training configuration',
|
||||
)
|
||||
group.add_argument(
|
||||
'--lm_file',
|
||||
type=str,
|
||||
help='LM parameter file',
|
||||
)
|
||||
group.add_argument(
|
||||
'--word_lm_train_config',
|
||||
type=str,
|
||||
help='Word LM training configuration',
|
||||
)
|
||||
group.add_argument(
|
||||
'--word_lm_file',
|
||||
type=str,
|
||||
help='Word LM parameter file',
|
||||
)
|
||||
group.add_argument(
|
||||
'--ngram_file',
|
||||
type=str,
|
||||
help='N-gram parameter file',
|
||||
)
|
||||
group.add_argument(
|
||||
'--model_tag',
|
||||
type=str,
|
||||
help='Pretrained model tag. If this option is specified, *_train_config and '
|
||||
'*_file will be overwritten',
|
||||
)
|
||||
|
||||
group = parser.add_argument_group('Beam-search related')
|
||||
group.add_argument(
|
||||
'--batch_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='The batch size for inference',
|
||||
)
|
||||
group.add_argument(
|
||||
'--nbest', type=int, default=1, help='Output N-best hypotheses')
|
||||
group.add_argument('--beam_size', type=int, default=20, help='Beam size')
|
||||
group.add_argument(
|
||||
'--penalty', type=float, default=0.0, help='Insertion penalty')
|
||||
group.add_argument(
|
||||
'--maxlenratio',
|
||||
type=float,
|
||||
default=0.0,
|
||||
help='Input length ratio to obtain max output length. '
|
||||
'If maxlenratio=0.0 (default), it uses an end-detect '
'function '
'to automatically find maximum hypothesis lengths. '
'If maxlenratio<0.0, its absolute value is interpreted '
'as a constant max output length',
|
||||
)
|
||||
group.add_argument(
|
||||
'--minlenratio',
|
||||
type=float,
|
||||
default=0.0,
|
||||
help='Input length ratio to obtain min output length',
|
||||
)
|
||||
group.add_argument(
|
||||
'--ctc_weight',
|
||||
type=float,
|
||||
default=0.5,
|
||||
help='CTC weight in joint decoding',
|
||||
)
|
||||
group.add_argument(
|
||||
'--lm_weight', type=float, default=1.0, help='RNNLM weight')
|
||||
group.add_argument(
|
||||
'--ngram_weight', type=float, default=0.9, help='ngram weight')
|
||||
group.add_argument('--streaming', type=str2bool, default=False)
|
||||
|
||||
group.add_argument(
|
||||
'--frontend_conf',
|
||||
default=None,
|
||||
help='Keyword arguments for DefaultFrontend, used when the ASR model has no frontend',
|
||||
)
|
||||
|
||||
group = parser.add_argument_group('Text converter related')
|
||||
group.add_argument(
|
||||
'--token_type',
|
||||
type=str_or_none,
|
||||
default=None,
|
||||
choices=['char', 'bpe', None],
|
||||
help='The token type for ASR model. '
|
||||
'If not given, refers from the training args',
|
||||
)
|
||||
group.add_argument(
|
||||
'--bpemodel',
|
||||
type=str_or_none,
|
||||
default=None,
|
||||
help='The model path of sentencepiece. '
|
||||
'If not given, refers from the training args',
|
||||
)
|
||||
group.add_argument(
|
||||
'--transducer_conf',
|
||||
default=None,
|
||||
help='The keyword arguments for transducer beam search.',
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def asr_inference(
|
||||
output_dir: str,
|
||||
maxlenratio: float,
|
||||
minlenratio: float,
|
||||
beam_size: int,
|
||||
ngpu: int,
|
||||
ctc_weight: float,
|
||||
lm_weight: float,
|
||||
penalty: float,
|
||||
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
|
||||
asr_train_config: Optional[str],
|
||||
asr_model_file: Optional[str],
|
||||
nbest: int = 1,
|
||||
num_workers: int = 1,
|
||||
log_level: Union[int, str] = 'INFO',
|
||||
batch_size: int = 1,
|
||||
dtype: str = 'float32',
|
||||
seed: int = 0,
|
||||
key_file: Optional[str] = None,
|
||||
lm_train_config: Optional[str] = None,
|
||||
lm_file: Optional[str] = None,
|
||||
word_lm_train_config: Optional[str] = None,
|
||||
word_lm_file: Optional[str] = None,
|
||||
ngram_file: Optional[str] = None,
|
||||
ngram_weight: float = 0.9,
|
||||
model_tag: Optional[str] = None,
|
||||
token_type: Optional[str] = None,
|
||||
bpemodel: Optional[str] = None,
|
||||
allow_variable_data_keys: bool = False,
|
||||
transducer_conf: Optional[dict] = None,
|
||||
streaming: bool = False,
|
||||
frontend_conf: dict = None,
|
||||
):
|
||||
inference(
|
||||
output_dir=output_dir,
|
||||
maxlenratio=maxlenratio,
|
||||
minlenratio=minlenratio,
|
||||
batch_size=batch_size,
|
||||
dtype=dtype,
|
||||
beam_size=beam_size,
|
||||
ngpu=ngpu,
|
||||
seed=seed,
|
||||
ctc_weight=ctc_weight,
|
||||
lm_weight=lm_weight,
|
||||
ngram_weight=ngram_weight,
|
||||
penalty=penalty,
|
||||
nbest=nbest,
|
||||
num_workers=num_workers,
|
||||
log_level=log_level,
|
||||
data_path_and_name_and_type=data_path_and_name_and_type,
|
||||
key_file=key_file,
|
||||
asr_train_config=asr_train_config,
|
||||
asr_model_file=asr_model_file,
|
||||
lm_train_config=lm_train_config,
|
||||
lm_file=lm_file,
|
||||
word_lm_train_config=word_lm_train_config,
|
||||
word_lm_file=word_lm_file,
|
||||
ngram_file=ngram_file,
|
||||
model_tag=model_tag,
|
||||
token_type=token_type,
|
||||
bpemodel=bpemodel,
|
||||
allow_variable_data_keys=allow_variable_data_keys,
|
||||
transducer_conf=transducer_conf,
|
||||
streaming=streaming,
|
||||
frontend_conf=frontend_conf)
|
||||
|
||||
|
||||
def main(cmd=None):
|
||||
print(get_commandline_args(), file=sys.stderr)
|
||||
parser = get_parser()
|
||||
args = parser.parse_args(cmd)
|
||||
kwargs = vars(args)
|
||||
kwargs.pop('config', None)
|
||||
inference(**kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
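A hedged end-to-end sketch (not part of the diff; all paths are placeholders): decode a Kaldi-style wav.scp with the convenience wrapper.

asr_inference(
    output_dir='./decode_out',
    maxlenratio=0.0,
    minlenratio=0.0,
    beam_size=1,
    ngpu=0,                 # CPU decoding
    ctc_weight=0.0,
    lm_weight=0.0,
    penalty=0.0,
    data_path_and_name_and_type=[('data/wav.scp', 'speech', 'sound')],
    asr_train_config='exp/asr_paraformer/config.yaml',   # placeholder
    asr_model_file='exp/asr_paraformer/model.pb')         # placeholder
# Results are written under ./decode_out/1best_recog/{text,token,token_int,score}.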
modelscope/pipelines/audio/asr/asr_engine/common/asr_utils.py (Normal file, 193 lines)
@@ -0,0 +1,193 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def type_checking(wav_path: str,
|
||||
recog_type: str = None,
|
||||
audio_format: str = None,
|
||||
workspace: str = None):
|
||||
assert os.path.exists(wav_path), f'wav_path:{wav_path} does not exist'
|
||||
|
||||
r_recog_type = recog_type
|
||||
r_audio_format = audio_format
|
||||
r_workspace = workspace
|
||||
r_wav_path = wav_path
|
||||
|
||||
if r_workspace is None or len(r_workspace) == 0:
|
||||
r_workspace = os.path.join(os.getcwd(), '.tmp')
|
||||
|
||||
if r_recog_type is None:
|
||||
if os.path.isfile(wav_path):
|
||||
if wav_path.endswith('.wav') or wav_path.endswith('.WAV'):
|
||||
r_recog_type = 'wav'
|
||||
r_audio_format = 'wav'
|
||||
|
||||
elif os.path.isdir(wav_path):
|
||||
dir_name = os.path.basename(wav_path)
|
||||
if 'test' in dir_name:
|
||||
r_recog_type = 'test'
|
||||
elif 'dev' in dir_name:
|
||||
r_recog_type = 'dev'
|
||||
elif 'train' in dir_name:
|
||||
r_recog_type = 'train'
|
||||
|
||||
if r_audio_format is None:
|
||||
if find_file_by_ends(wav_path, '.ark'):
|
||||
r_audio_format = 'kaldi_ark'
|
||||
elif find_file_by_ends(wav_path, '.wav') or find_file_by_ends(
|
||||
wav_path, '.WAV'):
|
||||
r_audio_format = 'wav'
|
||||
|
||||
if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
|
||||
# datasets with kaldi_ark file
|
||||
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
|
||||
elif r_audio_format == 'wav' and r_recog_type != 'wav':
|
||||
# datasets with waveform files
|
||||
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))
|
||||
|
||||
return r_recog_type, r_audio_format, r_workspace, r_wav_path
|
||||
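For example (hedged; assuming the bundled test wav exists), a single waveform file falls through to the 'wav'/'wav' defaults and a temporary workspace under the current directory:

recog_type, audio_format, workspace, wav_path = type_checking(
    'data/test/audios/asr_example.wav')
# -> ('wav', 'wav', os.path.join(os.getcwd(), '.tmp'),
#     'data/test/audios/asr_example.wav')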
|
||||
|
||||
def find_file_by_ends(dir_path: str, ends: str):
|
||||
dir_files = os.listdir(dir_path)
|
||||
for file in dir_files:
|
||||
file_path = os.path.join(dir_path, file)
|
||||
if os.path.isfile(file_path):
|
||||
if file_path.endswith(ends):
|
||||
return True
|
||||
elif os.path.isdir(file_path):
|
||||
if find_file_by_ends(file_path, ends):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def compute_wer(hyp_text_path: str, ref_text_path: str) -> Dict[str, Any]:
|
||||
assert os.path.exists(hyp_text_path), 'hyp_text does not exist'
|
||||
assert os.path.exists(ref_text_path), 'ref_text does not exist'
|
||||
|
||||
rst = {
|
||||
'Wrd': 0,
|
||||
'Corr': 0,
|
||||
'Ins': 0,
|
||||
'Del': 0,
|
||||
'Sub': 0,
|
||||
'Snt': 0,
|
||||
'Err': 0.0,
|
||||
'S.Err': 0.0,
|
||||
'wrong_words': 0,
|
||||
'wrong_sentences': 0
|
||||
}
|
||||
|
||||
with open(ref_text_path, 'r', encoding='utf-8') as r:
|
||||
r_lines = r.readlines()
|
||||
|
||||
with open(hyp_text_path, 'r', encoding='utf-8') as h:
|
||||
h_lines = h.readlines()
|
||||
|
||||
for r_line in r_lines:
|
||||
r_line_item = r_line.split()
|
||||
r_key = r_line_item[0]
|
||||
r_sentence = r_line_item[1]
|
||||
for h_line in h_lines:
|
||||
# find sentence from hyp text
|
||||
if r_key in h_line:
|
||||
h_line_item = h_line.split()
|
||||
h_sentence = h_line_item[1]
|
||||
out_item = compute_wer_by_line(h_sentence, r_sentence)
|
||||
rst['Wrd'] += out_item['nwords']
|
||||
rst['Corr'] += out_item['cor']
|
||||
rst['wrong_words'] += out_item['wrong']
|
||||
rst['Ins'] += out_item['ins']
|
||||
rst['Del'] += out_item['del']
|
||||
rst['Sub'] += out_item['sub']
|
||||
rst['Snt'] += 1
|
||||
if out_item['wrong'] > 0:
|
||||
rst['wrong_sentences'] += 1
|
||||
|
||||
break
|
||||
|
||||
if rst['Wrd'] > 0:
|
||||
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
|
||||
if rst['Snt'] > 0:
|
||||
rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
|
||||
|
||||
return rst
|
||||
|
||||
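# Note (added): hyp_text_path and ref_text_path are Kaldi-style text files with one
# "<utt-key> <sentence>" pair per line. Only the first whitespace-separated field
# after the key is scored (r_line_item[1] / h_line_item[1]), so a sentence with
# internal spaces would be truncated at its first space.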
|
||||
def compute_wer_by_line(hyp: list, ref: list) -> Dict[str, Any]:
|
||||
len_hyp = len(hyp)
|
||||
len_ref = len(ref)
|
||||
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
|
||||
|
||||
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
|
||||
|
||||
for i in range(len_hyp + 1):
|
||||
cost_matrix[i][0] = i
|
||||
for j in range(len_ref + 1):
|
||||
cost_matrix[0][j] = j
|
||||
|
||||
for i in range(1, len_hyp + 1):
|
||||
for j in range(1, len_ref + 1):
|
||||
if hyp[i - 1] == ref[j - 1]:
|
||||
cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
|
||||
else:
|
||||
substitution = cost_matrix[i - 1][j - 1] + 1
|
||||
insertion = cost_matrix[i - 1][j] + 1
|
||||
deletion = cost_matrix[i][j - 1] + 1
|
||||
|
||||
compare_val = [substitution, insertion, deletion]
|
||||
|
||||
min_val = min(compare_val)
|
||||
operation_idx = compare_val.index(min_val) + 1
|
||||
cost_matrix[i][j] = min_val
|
||||
ops_matrix[i][j] = operation_idx
|
||||
|
||||
match_idx = []
|
||||
i = len_hyp
|
||||
j = len_ref
|
||||
rst = {
|
||||
'nwords': len_hyp,
|
||||
'cor': 0,
|
||||
'wrong': 0,
|
||||
'ins': 0,
|
||||
'del': 0,
|
||||
'sub': 0
|
||||
}
|
||||
while i >= 0 or j >= 0:
|
||||
i_idx = max(0, i)
|
||||
j_idx = max(0, j)
|
||||
|
||||
if ops_matrix[i_idx][j_idx] == 0: # correct
|
||||
if i - 1 >= 0 and j - 1 >= 0:
|
||||
match_idx.append((j - 1, i - 1))
|
||||
rst['cor'] += 1
|
||||
|
||||
i -= 1
|
||||
j -= 1
|
||||
|
||||
elif ops_matrix[i_idx][j_idx] == 2: # insert
|
||||
i -= 1
|
||||
rst['ins'] += 1
|
||||
|
||||
elif ops_matrix[i_idx][j_idx] == 3: # delete
|
||||
j -= 1
|
||||
rst['del'] += 1
|
||||
|
||||
elif ops_matrix[i_idx][j_idx] == 1: # substitute
|
||||
i -= 1
|
||||
j -= 1
|
||||
rst['sub'] += 1
|
||||
|
||||
if i < 0 and j >= 0:
|
||||
rst['del'] += 1
|
||||
elif j < 0 and i >= 0:
|
||||
rst['ins'] += 1
|
||||
|
||||
match_idx.reverse()
|
||||
wrong_cnt = cost_matrix[len_hyp][len_ref]
|
||||
rst['wrong'] = wrong_cnt
|
||||
|
||||
return rst
|
||||
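A worked example (not part of the diff) of the per-line accounting; despite the `list` type hints, any indexable sequence of symbols works:

stats = compute_wer_by_line(hyp=['a', 'b', 'c', 'd'], ref=['a', 'x', 'c'])
# 'a' and 'c' match, 'b' substitutes 'x', and 'd' is an extra (inserted) symbol:
# -> {'nwords': 4, 'cor': 2, 'wrong': 2, 'ins': 1, 'del': 0, 'sub': 1}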
@@ -0,0 +1,757 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Part of the implementation is borrowed from espnet/espnet.
|
||||
"""Decoder definition."""
|
||||
from typing import Any, List, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from espnet2.asr.decoder.abs_decoder import AbsDecoder
|
||||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
|
||||
from espnet.nets.pytorch_backend.transformer.attention import \
|
||||
MultiHeadedAttention
|
||||
from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer
|
||||
from espnet.nets.pytorch_backend.transformer.dynamic_conv import \
|
||||
DynamicConvolution
|
||||
from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import \
|
||||
DynamicConvolution2D
|
||||
from espnet.nets.pytorch_backend.transformer.embedding import \
|
||||
PositionalEncoding
|
||||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
|
||||
from espnet.nets.pytorch_backend.transformer.lightconv import \
|
||||
LightweightConvolution
|
||||
from espnet.nets.pytorch_backend.transformer.lightconv2d import \
|
||||
LightweightConvolution2D
|
||||
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
|
||||
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \
|
||||
PositionwiseFeedForward # noqa: H301
|
||||
from espnet.nets.pytorch_backend.transformer.repeat import repeat
|
||||
from espnet.nets.scorer_interface import BatchScorerInterface
|
||||
from typeguard import check_argument_types
|
||||
|
||||
|
||||
class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
|
||||
"""Base class of Transfomer decoder module.
|
||||
|
||||
Args:
|
||||
vocab_size: output dim
|
||||
encoder_output_size: dimension of attention
|
||||
attention_heads: the number of heads of multi head attention
|
||||
linear_units: the number of units of position-wise feed forward
|
||||
num_blocks: the number of decoder blocks
|
||||
dropout_rate: dropout rate
|
||||
self_attention_dropout_rate: dropout rate for attention
|
||||
input_layer: input layer type
|
||||
use_output_layer: whether to use output layer
|
||||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
|
||||
normalize_before: whether to use layer_norm before the first block
|
||||
concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied.
|
||||
i.e. x -> x + att(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
input_layer: str = 'embed',
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
attention_dim = encoder_output_size
|
||||
|
||||
if input_layer == 'embed':
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(vocab_size, attention_dim),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'linear':
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(vocab_size, attention_dim),
|
||||
torch.nn.LayerNorm(attention_dim),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"only 'embed' or 'linear' is supported: {input_layer}")
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(attention_dim)
|
||||
if use_output_layer:
|
||||
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
|
||||
else:
|
||||
self.output_layer = None
|
||||
|
||||
# Must set by the inheritance
|
||||
self.decoders = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hs_pad: torch.Tensor,
|
||||
hlens: torch.Tensor,
|
||||
ys_in_pad: torch.Tensor,
|
||||
ys_in_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward decoder.
|
||||
|
||||
Args:
|
||||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
hlens: (batch)
|
||||
ys_in_pad:
|
||||
input token ids, int64 (batch, maxlen_out)
|
||||
if input_layer == "embed"
|
||||
input tensor (batch, maxlen_out, #mels) in the other cases
|
||||
ys_in_lens: (batch)
|
||||
Returns:
|
||||
(tuple): tuple containing:
|
||||
|
||||
x: decoded token score before softmax (batch, maxlen_out, token)
|
||||
if use_output_layer is True,
|
||||
olens: (batch, )
|
||||
"""
|
||||
tgt = ys_in_pad
|
||||
# tgt_mask: (B, 1, L)
|
||||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
|
||||
# m: (1, L, L)
|
||||
m = subsequent_mask(
|
||||
tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
|
||||
# tgt_mask: (B, L, L)
|
||||
tgt_mask = tgt_mask & m
|
||||
|
||||
memory = hs_pad
|
||||
memory_mask = (
|
||||
~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
|
||||
memory.device)
|
||||
# Padding for Longformer
|
||||
if memory_mask.shape[-1] != memory.shape[1]:
|
||||
padlen = memory.shape[1] - memory_mask.shape[-1]
|
||||
memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen),
|
||||
'constant', False)
|
||||
|
||||
x = self.embed(tgt)
|
||||
x, tgt_mask, memory, memory_mask = self.decoders(
|
||||
x, tgt_mask, memory, memory_mask)
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
if self.output_layer is not None:
|
||||
x = self.output_layer(x)
|
||||
|
||||
olens = tgt_mask.sum(1)
|
||||
return x, olens
|
||||
|
||||
def forward_one_step(
|
||||
self,
|
||||
tgt: torch.Tensor,
|
||||
tgt_mask: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
cache: List[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Forward one step.
|
||||
|
||||
Args:
|
||||
tgt: input token ids, int64 (batch, maxlen_out)
|
||||
tgt_mask: input token mask, (batch, maxlen_out)
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
|
||||
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
cache: cached output list of (batch, max_time_out-1, size)
|
||||
Returns:
|
||||
y, cache: NN output value and cache per `self.decoders`.
|
||||
`y.shape` is (batch, maxlen_out, token)
|
||||
"""
|
||||
x = self.embed(tgt)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.decoders)
|
||||
new_cache = []
|
||||
for c, decoder in zip(cache, self.decoders):
|
||||
x, tgt_mask, memory, memory_mask = decoder(
|
||||
x, tgt_mask, memory, None, cache=c)
|
||||
new_cache.append(x)
|
||||
|
||||
if self.normalize_before:
|
||||
y = self.after_norm(x[:, -1])
|
||||
else:
|
||||
y = x[:, -1]
|
||||
if self.output_layer is not None:
|
||||
y = torch.log_softmax(self.output_layer(y), dim=-1)
|
||||
|
||||
return y, new_cache
|
||||
|
||||
def score(self, ys, state, x):
|
||||
"""Score."""
|
||||
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
|
||||
logp, state = self.forward_one_step(
|
||||
ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
|
||||
return logp.squeeze(0), state
|
||||
|
||||
def batch_score(self, ys: torch.Tensor, states: List[Any],
|
||||
xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
|
||||
"""Score new token batch.
|
||||
|
||||
Args:
|
||||
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
||||
states (List[Any]): Scorer states for prefix tokens.
|
||||
xs (torch.Tensor):
|
||||
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor, List[Any]]: Tuple of
|
||||
batchified scores for the next token with shape of `(n_batch, n_vocab)`
|
||||
and next state list for ys.
|
||||
|
||||
"""
|
||||
# merge states
|
||||
n_batch = len(ys)
|
||||
n_layers = len(self.decoders)
|
||||
if states[0] is None:
|
||||
batch_state = None
|
||||
else:
|
||||
# transpose state of [batch, layer] into [layer, batch]
|
||||
batch_state = [
|
||||
torch.stack([states[b][i] for b in range(n_batch)])
|
||||
for i in range(n_layers)
|
||||
]
|
||||
|
||||
# batch decoding
|
||||
ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
|
||||
logp, states = self.forward_one_step(
|
||||
ys, ys_mask, xs, cache=batch_state)
|
||||
|
||||
# transpose state of [layer, batch] into [batch, layer]
|
||||
state_list = [[states[i][b] for i in range(n_layers)]
|
||||
for b in range(n_batch)]
|
||||
return logp, state_list
|
||||
|
||||
|
||||
class TransformerDecoder(BaseTransformerDecoder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = 'embed',
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
attention_dim = encoder_output_size
|
||||
self.decoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
MultiHeadedAttention(attention_heads, attention_dim,
|
||||
self_attention_dropout_rate),
|
||||
MultiHeadedAttention(attention_heads, attention_dim,
|
||||
src_attention_dropout_rate),
|
||||
PositionwiseFeedForward(attention_dim, linear_units,
|
||||
dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ParaformerDecoder(TransformerDecoder):
|
||||
|
||||
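# Note (added): unlike the autoregressive TransformerDecoder above, this decoder
# consumes the predictor's acoustic embeddings directly: forward() below bypasses
# self.embed and the causal subsequent_mask, so all output tokens are scored in a
# single non-autoregressive pass.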
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = 'embed',
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
attention_dim = encoder_output_size
|
||||
self.decoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
MultiHeadedAttention(attention_heads, attention_dim,
|
||||
self_attention_dropout_rate),
|
||||
MultiHeadedAttention(attention_heads, attention_dim,
|
||||
src_attention_dropout_rate),
|
||||
PositionwiseFeedForward(attention_dim, linear_units,
|
||||
dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hs_pad: torch.Tensor,
|
||||
hlens: torch.Tensor,
|
||||
ys_in_pad: torch.Tensor,
|
||||
ys_in_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward decoder.
|
||||
|
||||
Args:
|
||||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
hlens: (batch)
|
||||
ys_in_pad:
|
||||
input token ids, int64 (batch, maxlen_out)
|
||||
if input_layer == "embed"
|
||||
input tensor (batch, maxlen_out, #mels) in the other cases
|
||||
ys_in_lens: (batch)
|
||||
Returns:
|
||||
(tuple): tuple containing:
|
||||
|
||||
x: decoded token score before softmax (batch, maxlen_out, token)
|
||||
if use_output_layer is True,
|
||||
olens: (batch, )
|
||||
"""
|
||||
tgt = ys_in_pad
|
||||
# tgt_mask: (B, 1, L)
|
||||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
|
||||
# m: (1, L, L)
|
||||
# m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
|
||||
# tgt_mask: (B, L, L)
|
||||
# tgt_mask = tgt_mask & m
|
||||
|
||||
memory = hs_pad
|
||||
memory_mask = (
|
||||
~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
|
||||
memory.device)
|
||||
# Padding for Longformer
|
||||
if memory_mask.shape[-1] != memory.shape[1]:
|
||||
padlen = memory.shape[1] - memory_mask.shape[-1]
|
||||
memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen),
|
||||
'constant', False)
|
||||
|
||||
# x = self.embed(tgt)
|
||||
x = tgt
|
||||
x, tgt_mask, memory, memory_mask = self.decoders(
|
||||
x, tgt_mask, memory, memory_mask)
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
if self.output_layer is not None:
|
||||
x = self.output_layer(x)
|
||||
|
||||
olens = tgt_mask.sum(1)
|
||||
return x, olens
|
||||
|
||||
|
||||
class ParaformerDecoderBertEmbed(TransformerDecoder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = 'embed',
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
embeds_id: int = 2,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
attention_dim = encoder_output_size
|
||||
self.decoders = repeat(
|
||||
embeds_id,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
MultiHeadedAttention(attention_heads, attention_dim,
|
||||
self_attention_dropout_rate),
|
||||
MultiHeadedAttention(attention_heads, attention_dim,
|
||||
src_attention_dropout_rate),
|
||||
PositionwiseFeedForward(attention_dim, linear_units,
|
||||
dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if embeds_id == num_blocks:
|
||||
self.decoders2 = None
|
||||
else:
|
||||
self.decoders2 = repeat(
|
||||
num_blocks - embeds_id,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
MultiHeadedAttention(attention_heads, attention_dim,
|
||||
self_attention_dropout_rate),
|
||||
MultiHeadedAttention(attention_heads, attention_dim,
|
||||
src_attention_dropout_rate),
|
||||
PositionwiseFeedForward(attention_dim, linear_units,
|
||||
dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hs_pad: torch.Tensor,
|
||||
hlens: torch.Tensor,
|
||||
ys_in_pad: torch.Tensor,
|
||||
ys_in_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward decoder.
|
||||
|
||||
Args:
|
||||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
hlens: (batch)
|
||||
ys_in_pad:
|
||||
input token ids, int64 (batch, maxlen_out)
|
||||
if input_layer == "embed"
|
||||
input tensor (batch, maxlen_out, #mels) in the other cases
|
||||
ys_in_lens: (batch)
|
||||
Returns:
|
||||
(tuple): tuple containing:
|
||||
|
||||
x: decoded token score before softmax (batch, maxlen_out, token)
|
||||
if use_output_layer is True,
|
||||
olens: (batch, )
|
||||
"""
|
||||
tgt = ys_in_pad
|
||||
# tgt_mask: (B, 1, L)
|
||||
tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
|
||||
# m: (1, L, L)
|
||||
# m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
|
||||
# tgt_mask: (B, L, L)
|
||||
# tgt_mask = tgt_mask & m
|
||||
|
||||
memory = hs_pad
|
||||
memory_mask = (
|
||||
~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
|
||||
memory.device)
|
||||
# Padding for Longformer
|
||||
if memory_mask.shape[-1] != memory.shape[1]:
|
||||
padlen = memory.shape[1] - memory_mask.shape[-1]
|
||||
memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen),
|
||||
'constant', False)
|
||||
|
||||
# x = self.embed(tgt)
|
||||
x = tgt
|
||||
x, tgt_mask, memory, memory_mask = self.decoders(
|
||||
x, tgt_mask, memory, memory_mask)
|
||||
embeds_outputs = x
|
||||
if self.decoders2 is not None:
|
||||
x, tgt_mask, memory, memory_mask = self.decoders2(
|
||||
x, tgt_mask, memory, memory_mask)
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
if self.output_layer is not None:
|
||||
x = self.output_layer(x)
|
||||
|
||||
olens = tgt_mask.sum(1)
|
||||
return x, olens, embeds_outputs
|
||||
|
||||
|
||||
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = 'embed',
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
conv_wshare: int = 4,
|
||||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
|
||||
conv_usebias: int = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
if len(conv_kernel_length) != num_blocks:
|
||||
raise ValueError(
|
||||
'conv_kernel_length must have equal number of values to num_blocks: '
|
||||
f'{len(conv_kernel_length)} != {num_blocks}')
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
attention_dim = encoder_output_size
|
||||
self.decoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
LightweightConvolution(
|
||||
wshare=conv_wshare,
|
||||
n_feat=attention_dim,
|
||||
dropout_rate=self_attention_dropout_rate,
|
||||
kernel_size=conv_kernel_length[lnum],
|
||||
use_kernel_mask=True,
|
||||
use_bias=conv_usebias,
|
||||
),
|
||||
MultiHeadedAttention(attention_heads, attention_dim,
|
||||
src_attention_dropout_rate),
|
||||
PositionwiseFeedForward(attention_dim, linear_units,
|
||||
dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = 'embed',
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
conv_wshare: int = 4,
|
||||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
|
||||
conv_usebias: int = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
if len(conv_kernel_length) != num_blocks:
|
||||
raise ValueError(
|
||||
'conv_kernel_length must have equal number of values to num_blocks: '
|
||||
f'{len(conv_kernel_length)} != {num_blocks}')
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
attention_dim = encoder_output_size
|
||||
self.decoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
LightweightConvolution2D(
|
||||
wshare=conv_wshare,
|
||||
n_feat=attention_dim,
|
||||
dropout_rate=self_attention_dropout_rate,
|
||||
kernel_size=conv_kernel_length[lnum],
|
||||
use_kernel_mask=True,
|
||||
use_bias=conv_usebias,
|
||||
),
|
||||
MultiHeadedAttention(attention_heads, attention_dim,
|
||||
src_attention_dropout_rate),
|
||||
PositionwiseFeedForward(attention_dim, linear_units,
|
||||
dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = 'embed',
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
conv_wshare: int = 4,
|
||||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
|
||||
conv_usebias: int = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
if len(conv_kernel_length) != num_blocks:
|
||||
raise ValueError(
|
||||
'conv_kernel_length must have equal number of values to num_blocks: '
|
||||
f'{len(conv_kernel_length)} != {num_blocks}')
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
attention_dim = encoder_output_size
|
||||
|
||||
self.decoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
DynamicConvolution(
|
||||
wshare=conv_wshare,
|
||||
n_feat=attention_dim,
|
||||
dropout_rate=self_attention_dropout_rate,
|
||||
kernel_size=conv_kernel_length[lnum],
|
||||
use_kernel_mask=True,
|
||||
use_bias=conv_usebias,
|
||||
),
|
||||
MultiHeadedAttention(attention_heads, attention_dim,
|
||||
src_attention_dropout_rate),
|
||||
PositionwiseFeedForward(attention_dim, linear_units,
|
||||
dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = 'embed',
|
||||
use_output_layer: bool = True,
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
conv_wshare: int = 4,
|
||||
conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
|
||||
conv_usebias: int = False,
|
||||
):
|
||||
assert check_argument_types()
|
||||
if len(conv_kernel_length) != num_blocks:
|
||||
raise ValueError(
|
||||
'conv_kernel_length must have equal number of values to num_blocks: '
|
||||
f'{len(conv_kernel_length)} != {num_blocks}')
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder_output_size,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
input_layer=input_layer,
|
||||
use_output_layer=use_output_layer,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
attention_dim = encoder_output_size
|
||||
|
||||
self.decoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: DecoderLayer(
|
||||
attention_dim,
|
||||
DynamicConvolution2D(
|
||||
wshare=conv_wshare,
|
||||
n_feat=attention_dim,
|
||||
dropout_rate=self_attention_dropout_rate,
|
||||
kernel_size=conv_kernel_length[lnum],
|
||||
use_kernel_mask=True,
|
||||
use_bias=conv_usebias,
|
||||
),
|
||||
MultiHeadedAttention(attention_heads, attention_dim,
|
||||
src_attention_dropout_rate),
|
||||
PositionwiseFeedForward(attention_dim, linear_units,
|
||||
dropout_rate),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
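A brief usage sketch (not part of the patch): how one of the convolution-based decoders above might be constructed. The class is assumed to be importable from this module once the file lands in the package; the vocabulary size and encoder width are illustrative, and conv_kernel_length must supply one kernel length per block, as enforced above.

# Hypothetical construction; DynamicConvolution2DTransformerDecoder is defined above.
decoder = DynamicConvolution2DTransformerDecoder(
    vocab_size=4233,                # illustrative vocabulary size
    encoder_output_size=256,        # must match the encoder's output dimension
    num_blocks=6,
    conv_kernel_length=(11, 11, 11, 11, 11, 11),  # one kernel length per block
    conv_wshare=4,
)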
|
||||
@@ -0,0 +1,710 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Part of the implementation is borrowed from espnet/espnet.
|
||||
"""Conformer encoder definition."""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from espnet2.asr.ctc import CTC
|
||||
from espnet2.asr.encoder.abs_encoder import AbsEncoder
|
||||
from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule
|
||||
from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer
|
||||
from espnet.nets.pytorch_backend.nets_utils import (get_activation,
|
||||
make_pad_mask)
|
||||
from espnet.nets.pytorch_backend.transformer.embedding import \
|
||||
LegacyRelPositionalEncoding # noqa: H301
|
||||
from espnet.nets.pytorch_backend.transformer.embedding import \
|
||||
PositionalEncoding # noqa: H301
|
||||
from espnet.nets.pytorch_backend.transformer.embedding import \
|
||||
RelPositionalEncoding # noqa: H301
|
||||
from espnet.nets.pytorch_backend.transformer.embedding import \
|
||||
ScaledPositionalEncoding # noqa: H301
|
||||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
|
||||
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import (
|
||||
Conv1dLinear, MultiLayeredConv1d)
|
||||
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \
|
||||
PositionwiseFeedForward # noqa: H301
|
||||
from espnet.nets.pytorch_backend.transformer.repeat import repeat
|
||||
from espnet.nets.pytorch_backend.transformer.subsampling import (
|
||||
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
|
||||
Conv2dSubsampling8, TooShortUttError, check_short_utt)
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from ...nets.pytorch_backend.transformer.attention import \
|
||||
LegacyRelPositionMultiHeadedAttention # noqa: H301
|
||||
from ...nets.pytorch_backend.transformer.attention import \
|
||||
MultiHeadedAttention # noqa: H301
|
||||
from ...nets.pytorch_backend.transformer.attention import \
|
||||
RelPositionMultiHeadedAttention # noqa: H301
|
||||
from ...nets.pytorch_backend.transformer.attention import (
|
||||
LegacyRelPositionMultiHeadedAttentionSANM,
|
||||
RelPositionMultiHeadedAttentionSANM)
|
||||
|
||||
|
||||
class ConformerEncoder(AbsEncoder):
|
||||
"""Conformer encoder module.
|
||||
|
||||
Args:
|
||||
input_size (int): Input dimension.
|
||||
output_size (int): Dimension of attention.
|
||||
attention_heads (int): The number of heads of multi head attention.
|
||||
linear_units (int): The number of units of position-wise feed forward.
|
||||
num_blocks (int): The number of decoder blocks.
|
||||
dropout_rate (float): Dropout rate.
|
||||
attention_dropout_rate (float): Dropout rate in attention.
|
||||
positional_dropout_rate (float): Dropout rate after adding positional encoding.
|
||||
input_layer (Union[str, torch.nn.Module]): Input layer type.
|
||||
normalize_before (bool): Whether to use layer_norm before the first block.
|
||||
concat_after (bool): Whether to concat attention layer's input and output.
|
||||
If True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
If False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
|
||||
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
|
||||
rel_pos_type (str): Whether to use the latest relative positional encoding or
|
||||
the legacy one. The legacy relative positional encoding will be deprecated
|
||||
in the future. More Details can be found in
|
||||
https://github.com/espnet/espnet/pull/2816.
|
||||
encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
|
||||
encoder_attn_layer_type (str): Encoder attention layer type.
|
||||
activation_type (str): Encoder activation function type.
|
||||
macaron_style (bool): Whether to use macaron style for positionwise layer.
|
||||
use_cnn_module (bool): Whether to use convolution module.
|
||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
padding_idx (int): Padding idx for input_layer=embed.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = 'conv2d',
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
positionwise_layer_type: str = 'linear',
|
||||
positionwise_conv_kernel_size: int = 3,
|
||||
macaron_style: bool = False,
|
||||
rel_pos_type: str = 'legacy',
|
||||
pos_enc_layer_type: str = 'rel_pos',
|
||||
selfattention_layer_type: str = 'rel_selfattn',
|
||||
activation_type: str = 'swish',
|
||||
use_cnn_module: bool = True,
|
||||
zero_triu: bool = False,
|
||||
cnn_module_kernel: int = 31,
|
||||
padding_idx: int = -1,
|
||||
interctc_layer_idx: List[int] = [],
|
||||
interctc_use_conditioning: bool = False,
|
||||
stochastic_depth_rate: Union[float, List[float]] = 0.0,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
if rel_pos_type == 'legacy':
|
||||
if pos_enc_layer_type == 'rel_pos':
|
||||
pos_enc_layer_type = 'legacy_rel_pos'
|
||||
if selfattention_layer_type == 'rel_selfattn':
|
||||
selfattention_layer_type = 'legacy_rel_selfattn'
|
||||
elif rel_pos_type == 'latest':
|
||||
assert selfattention_layer_type != 'legacy_rel_selfattn'
|
||||
assert pos_enc_layer_type != 'legacy_rel_pos'
|
||||
else:
|
||||
raise ValueError('unknown rel_pos_type: ' + rel_pos_type)
|
||||
|
||||
activation = get_activation(activation_type)
|
||||
if pos_enc_layer_type == 'abs_pos':
|
||||
pos_enc_class = PositionalEncoding
|
||||
elif pos_enc_layer_type == 'scaled_abs_pos':
|
||||
pos_enc_class = ScaledPositionalEncoding
|
||||
elif pos_enc_layer_type == 'rel_pos':
|
||||
assert selfattention_layer_type == 'rel_selfattn'
|
||||
pos_enc_class = RelPositionalEncoding
|
||||
elif pos_enc_layer_type == 'legacy_rel_pos':
|
||||
assert selfattention_layer_type == 'legacy_rel_selfattn'
|
||||
pos_enc_class = LegacyRelPositionalEncoding
|
||||
else:
|
||||
raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type)
|
||||
|
||||
if input_layer == 'linear':
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_size, output_size),
|
||||
torch.nn.LayerNorm(output_size),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'conv2d':
|
||||
self.embed = Conv2dSubsampling(
|
||||
input_size,
|
||||
output_size,
|
||||
dropout_rate,
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'conv2d2':
|
||||
self.embed = Conv2dSubsampling2(
|
||||
input_size,
|
||||
output_size,
|
||||
dropout_rate,
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'conv2d6':
|
||||
self.embed = Conv2dSubsampling6(
|
||||
input_size,
|
||||
output_size,
|
||||
dropout_rate,
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'conv2d8':
|
||||
self.embed = Conv2dSubsampling8(
|
||||
input_size,
|
||||
output_size,
|
||||
dropout_rate,
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'embed':
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(
|
||||
input_size, output_size, padding_idx=padding_idx),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif isinstance(input_layer, torch.nn.Module):
|
||||
self.embed = torch.nn.Sequential(
|
||||
input_layer,
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer is None:
|
||||
self.embed = torch.nn.Sequential(
|
||||
pos_enc_class(output_size, positional_dropout_rate))
|
||||
else:
|
||||
raise ValueError('unknown input_layer: ' + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == 'linear':
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
activation,
|
||||
)
|
||||
elif positionwise_layer_type == 'conv1d':
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == 'conv1d-linear':
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError('Support only linear or conv1d.')
|
||||
|
||||
if selfattention_layer_type == 'selfattn':
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
elif selfattention_layer_type == 'legacy_rel_selfattn':
|
||||
assert pos_enc_layer_type == 'legacy_rel_pos'
|
||||
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
elif selfattention_layer_type == 'rel_selfattn':
|
||||
assert pos_enc_layer_type == 'rel_pos'
|
||||
encoder_selfattn_layer = RelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
zero_triu,
|
||||
)
|
||||
else:
|
||||
raise ValueError('unknown encoder_attn_layer: '
|
||||
+ selfattention_layer_type)
|
||||
|
||||
convolution_layer = ConvolutionModule
|
||||
convolution_layer_args = (output_size, cnn_module_kernel, activation)
|
||||
|
||||
if isinstance(stochastic_depth_rate, float):
|
||||
stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
|
||||
|
||||
if len(stochastic_depth_rate) != num_blocks:
|
||||
raise ValueError(
|
||||
f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) '
|
||||
f'should be equal to num_blocks ({num_blocks})')
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayer(
|
||||
output_size,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args)
|
||||
if macaron_style else None,
|
||||
convolution_layer(*convolution_layer_args)
|
||||
if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
stochastic_depth_rate[lnum],
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
self.interctc_layer_idx = interctc_layer_idx
|
||||
if len(interctc_layer_idx) > 0:
|
||||
assert 0 < min(interctc_layer_idx) and max(
|
||||
interctc_layer_idx) < num_blocks
|
||||
self.interctc_use_conditioning = interctc_use_conditioning
|
||||
self.conditioning_layer = None
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
ctc: CTC = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
|
||||
ilens (torch.Tensor): Input length (#batch).
|
||||
prev_states (torch.Tensor): Not to be used now.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, L, output_size).
|
||||
torch.Tensor: Output length (#batch).
|
||||
torch.Tensor: Not to be used now.
|
||||
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
|
||||
if (isinstance(self.embed, Conv2dSubsampling)
|
||||
or isinstance(self.embed, Conv2dSubsampling2)
|
||||
or isinstance(self.embed, Conv2dSubsampling6)
|
||||
or isinstance(self.embed, Conv2dSubsampling8)):
|
||||
short_status, limit_size = check_short_utt(self.embed,
|
||||
xs_pad.size(1))
|
||||
if short_status:
|
||||
raise TooShortUttError(
|
||||
f'has {xs_pad.size(1)} frames and is too short for subsampling '
|
||||
+ # noqa: *
|
||||
f'(it needs more than {limit_size} frames), return empty results', # noqa: *
|
||||
xs_pad.size(1),
|
||||
limit_size) # noqa: *
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
intermediate_outs = []
|
||||
if len(self.interctc_layer_idx) == 0:
|
||||
xs_pad, masks = self.encoders(xs_pad, masks)
|
||||
else:
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
xs_pad, masks = encoder_layer(xs_pad, masks)
|
||||
|
||||
if layer_idx + 1 in self.interctc_layer_idx:
|
||||
encoder_out = xs_pad
|
||||
if isinstance(encoder_out, tuple):
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
# intermediate outputs are also normalized
|
||||
if self.normalize_before:
|
||||
encoder_out = self.after_norm(encoder_out)
|
||||
|
||||
intermediate_outs.append((layer_idx + 1, encoder_out))
|
||||
|
||||
if self.interctc_use_conditioning:
|
||||
ctc_out = ctc.softmax(encoder_out)
|
||||
|
||||
if isinstance(xs_pad, tuple):
|
||||
x, pos_emb = xs_pad
|
||||
x = x + self.conditioning_layer(ctc_out)
|
||||
xs_pad = (x, pos_emb)
|
||||
else:
|
||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
|
||||
|
||||
if isinstance(xs_pad, tuple):
|
||||
xs_pad = xs_pad[0]
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
if len(intermediate_outs) > 0:
|
||||
return (xs_pad, intermediate_outs), olens, None
|
||||
return xs_pad, olens, None
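A minimal usage sketch for the Conformer encoder above, assuming the espnet dependencies resolve and the class is importable from this module; the 80-dimensional features and the lengths are illustrative and follow the (#batch, L, input_size) convention from the docstring.

import torch

encoder = ConformerEncoder(input_size=80, output_size=256, num_blocks=6)
feats = torch.randn(2, 200, 80)               # (#batch, L, input_size)
feat_lens = torch.tensor([200, 160])          # (#batch,)
xs, olens, _ = encoder(feats, feat_lens)      # xs: (#batch, L', 256) after conv2d subsampling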
|
||||
|
||||
|
||||
class SANMEncoder_v2(AbsEncoder):
|
||||
"""Conformer encoder module.
|
||||
|
||||
Args:
|
||||
input_size (int): Input dimension.
|
||||
output_size (int): Dimension of attention.
|
||||
attention_heads (int): The number of heads of multi head attention.
|
||||
linear_units (int): The number of units of position-wise feed forward.
|
||||
num_blocks (int): The number of decoder blocks.
|
||||
dropout_rate (float): Dropout rate.
|
||||
attention_dropout_rate (float): Dropout rate in attention.
|
||||
positional_dropout_rate (float): Dropout rate after adding positional encoding.
|
||||
input_layer (Union[str, torch.nn.Module]): Input layer type.
|
||||
normalize_before (bool): Whether to use layer_norm before the first block.
|
||||
concat_after (bool): Whether to concat attention layer's input and output.
|
||||
If True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
If False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
|
||||
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
|
||||
rel_pos_type (str): Whether to use the latest relative positional encoding or
|
||||
the legacy one. The legacy relative positional encoding will be deprecated
|
||||
in the future. More Details can be found in
|
||||
https://github.com/espnet/espnet/pull/2816.
|
||||
encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
|
||||
encoder_attn_layer_type (str): Encoder attention layer type.
|
||||
activation_type (str): Encoder activation function type.
|
||||
macaron_style (bool): Whether to use macaron style for positionwise layer.
|
||||
use_cnn_module (bool): Whether to use convolution module.
|
||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
padding_idx (int): Padding idx for input_layer=embed.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = 'conv2d',
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
positionwise_layer_type: str = 'linear',
|
||||
positionwise_conv_kernel_size: int = 3,
|
||||
macaron_style: bool = False,
|
||||
rel_pos_type: str = 'legacy',
|
||||
pos_enc_layer_type: str = 'rel_pos',
|
||||
selfattention_layer_type: str = 'rel_selfattn',
|
||||
activation_type: str = 'swish',
|
||||
use_cnn_module: bool = False,
|
||||
sanm_shfit: int = 0,
|
||||
zero_triu: bool = False,
|
||||
cnn_module_kernel: int = 31,
|
||||
padding_idx: int = -1,
|
||||
interctc_layer_idx: List[int] = [],
|
||||
interctc_use_conditioning: bool = False,
|
||||
stochastic_depth_rate: Union[float, List[float]] = 0.0,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
if rel_pos_type == 'legacy':
|
||||
if pos_enc_layer_type == 'rel_pos':
|
||||
pos_enc_layer_type = 'legacy_rel_pos'
|
||||
if selfattention_layer_type == 'rel_selfattn':
|
||||
selfattention_layer_type = 'legacy_rel_selfattn'
|
||||
if selfattention_layer_type == 'rel_selfattnsanm':
|
||||
selfattention_layer_type = 'legacy_rel_selfattnsanm'
|
||||
|
||||
elif rel_pos_type == 'latest':
|
||||
assert selfattention_layer_type != 'legacy_rel_selfattn'
|
||||
assert pos_enc_layer_type != 'legacy_rel_pos'
|
||||
else:
|
||||
raise ValueError('unknown rel_pos_type: ' + rel_pos_type)
|
||||
|
||||
activation = get_activation(activation_type)
|
||||
if pos_enc_layer_type == 'abs_pos':
|
||||
pos_enc_class = PositionalEncoding
|
||||
elif pos_enc_layer_type == 'scaled_abs_pos':
|
||||
pos_enc_class = ScaledPositionalEncoding
|
||||
elif pos_enc_layer_type == 'rel_pos':
|
||||
# assert selfattention_layer_type == "rel_selfattn"
|
||||
pos_enc_class = RelPositionalEncoding
|
||||
elif pos_enc_layer_type == 'legacy_rel_pos':
|
||||
# assert selfattention_layer_type == "legacy_rel_selfattn"
|
||||
pos_enc_class = LegacyRelPositionalEncoding
|
||||
logging.warning(
|
||||
'Using legacy_rel_pos and it will be deprecated in the future.'
|
||||
)
|
||||
else:
|
||||
raise ValueError('unknown pos_enc_layer: ' + pos_enc_layer_type)
|
||||
|
||||
if input_layer == 'linear':
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_size, output_size),
|
||||
torch.nn.LayerNorm(output_size),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'conv2d':
|
||||
self.embed = Conv2dSubsampling(
|
||||
input_size,
|
||||
output_size,
|
||||
dropout_rate,
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'conv2d2':
|
||||
self.embed = Conv2dSubsampling2(
|
||||
input_size,
|
||||
output_size,
|
||||
dropout_rate,
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'conv2d6':
|
||||
self.embed = Conv2dSubsampling6(
|
||||
input_size,
|
||||
output_size,
|
||||
dropout_rate,
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'conv2d8':
|
||||
self.embed = Conv2dSubsampling8(
|
||||
input_size,
|
||||
output_size,
|
||||
dropout_rate,
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'embed':
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(
|
||||
input_size, output_size, padding_idx=padding_idx),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif isinstance(input_layer, torch.nn.Module):
|
||||
self.embed = torch.nn.Sequential(
|
||||
input_layer,
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer is None:
|
||||
self.embed = torch.nn.Sequential(
|
||||
pos_enc_class(output_size, positional_dropout_rate))
|
||||
else:
|
||||
raise ValueError('unknown input_layer: ' + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == 'linear':
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
activation,
|
||||
)
|
||||
elif positionwise_layer_type == 'conv1d':
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == 'conv1d-linear':
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError('Support only linear or conv1d.')
|
||||
|
||||
if selfattention_layer_type == 'selfattn':
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
elif selfattention_layer_type == 'legacy_rel_selfattn':
|
||||
assert pos_enc_layer_type == 'legacy_rel_pos'
|
||||
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
logging.warning(
|
||||
'Using legacy_rel_selfattn and it will be deprecated in the future.'
|
||||
)
|
||||
|
||||
elif selfattention_layer_type == 'legacy_rel_selfattnsanm':
|
||||
assert pos_enc_layer_type == 'legacy_rel_pos'
|
||||
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttentionSANM
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
logging.warning(
|
||||
'Using legacy_rel_selfattnsanm and it will be deprecated in the future.'
|
||||
)
|
||||
|
||||
elif selfattention_layer_type == 'rel_selfattn':
|
||||
assert pos_enc_layer_type == 'rel_pos'
|
||||
encoder_selfattn_layer = RelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
zero_triu,
|
||||
)
|
||||
elif selfattention_layer_type == 'rel_selfattnsanm':
|
||||
assert pos_enc_layer_type == 'rel_pos'
|
||||
encoder_selfattn_layer = RelPositionMultiHeadedAttentionSANM
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
zero_triu,
|
||||
cnn_module_kernel,
|
||||
sanm_shfit,
|
||||
)
|
||||
else:
|
||||
raise ValueError('unknown encoder_attn_layer: '
|
||||
+ selfattention_layer_type)
|
||||
|
||||
convolution_layer = ConvolutionModule
|
||||
convolution_layer_args = (output_size, cnn_module_kernel, activation)
|
||||
|
||||
if isinstance(stochastic_depth_rate, float):
|
||||
stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
|
||||
|
||||
if len(stochastic_depth_rate) != num_blocks:
|
||||
raise ValueError(
|
||||
f'Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) '
|
||||
f'should be equal to num_blocks ({num_blocks})')
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayer(
|
||||
output_size,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args)
|
||||
if macaron_style else None,
|
||||
convolution_layer(*convolution_layer_args)
|
||||
if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
stochastic_depth_rate[lnum],
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
self.interctc_layer_idx = interctc_layer_idx
|
||||
if len(interctc_layer_idx) > 0:
|
||||
assert 0 < min(interctc_layer_idx) and max(
|
||||
interctc_layer_idx) < num_blocks
|
||||
self.interctc_use_conditioning = interctc_use_conditioning
|
||||
self.conditioning_layer = None
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
ctc: CTC = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
|
||||
ilens (torch.Tensor): Input length (#batch).
|
||||
prev_states (torch.Tensor): Not to be used now.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, L, output_size).
|
||||
torch.Tensor: Output length (#batch).
|
||||
torch.Tensor: Not to be used now.
|
||||
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
|
||||
if (isinstance(self.embed, Conv2dSubsampling)
|
||||
or isinstance(self.embed, Conv2dSubsampling2)
|
||||
or isinstance(self.embed, Conv2dSubsampling6)
|
||||
or isinstance(self.embed, Conv2dSubsampling8)):
|
||||
short_status, limit_size = check_short_utt(self.embed,
|
||||
xs_pad.size(1))
|
||||
if short_status:
|
||||
raise TooShortUttError(
|
||||
f'has {xs_pad.size(1)} frames and is too short for subsampling '
|
||||
+ # noqa: *
|
||||
f'(it needs more than {limit_size} frames), return empty results',
|
||||
xs_pad.size(1),
|
||||
limit_size) # noqa: *
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
intermediate_outs = []
|
||||
if len(self.interctc_layer_idx) == 0:
|
||||
xs_pad, masks = self.encoders(xs_pad, masks)
|
||||
else:
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
xs_pad, masks = encoder_layer(xs_pad, masks)
|
||||
|
||||
if layer_idx + 1 in self.interctc_layer_idx:
|
||||
encoder_out = xs_pad
|
||||
if isinstance(encoder_out, tuple):
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
# intermediate outputs are also normalized
|
||||
if self.normalize_before:
|
||||
encoder_out = self.after_norm(encoder_out)
|
||||
|
||||
intermediate_outs.append((layer_idx + 1, encoder_out))
|
||||
|
||||
if self.interctc_use_conditioning:
|
||||
ctc_out = ctc.softmax(encoder_out)
|
||||
|
||||
if isinstance(xs_pad, tuple):
|
||||
x, pos_emb = xs_pad
|
||||
x = x + self.conditioning_layer(ctc_out)
|
||||
xs_pad = (x, pos_emb)
|
||||
else:
|
||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
|
||||
|
||||
if isinstance(xs_pad, tuple):
|
||||
xs_pad = xs_pad[0]
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
if len(intermediate_outs) > 0:
|
||||
return (xs_pad, intermediate_outs), olens, None
|
||||
return xs_pad, olens, None
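SANMEncoder_v2 mirrors the Conformer encoder but adds the SANM self-attention variants. A hedged sketch of selecting the SANM relative-attention branch (argument names come from the constructor above; values are illustrative and the class is assumed importable from this module):

import torch

encoder = SANMEncoder_v2(
    input_size=80,
    output_size=256,
    rel_pos_type='latest',
    pos_enc_layer_type='rel_pos',
    selfattention_layer_type='rel_selfattnsanm',  # SANM-style relative self-attention
    cnn_module_kernel=31,                         # reused as the SANM kernel size in this branch
    sanm_shfit=0,
)
xs, olens, _ = encoder(torch.randn(2, 200, 80), torch.tensor([200, 180]))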
|
||||
@@ -0,0 +1,500 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Part of the implementation is borrowed from espnet/espnet.
|
||||
"""Transformer encoder definition."""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from espnet2.asr.ctc import CTC
|
||||
from espnet2.asr.encoder.abs_encoder import AbsEncoder
|
||||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
|
||||
from espnet.nets.pytorch_backend.transformer.embedding import \
|
||||
PositionalEncoding
|
||||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
|
||||
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import (
|
||||
Conv1dLinear, MultiLayeredConv1d)
|
||||
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import \
|
||||
PositionwiseFeedForward # noqa: H301
|
||||
from espnet.nets.pytorch_backend.transformer.repeat import repeat
|
||||
from espnet.nets.pytorch_backend.transformer.subsampling import (
|
||||
Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
|
||||
Conv2dSubsampling8, TooShortUttError, check_short_utt)
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from ...asr.streaming_utilis.chunk_utilis import overlap_chunk
|
||||
from ...nets.pytorch_backend.transformer.attention import (
|
||||
MultiHeadedAttention, MultiHeadedAttentionSANM)
|
||||
from ...nets.pytorch_backend.transformer.encoder_layer import (
|
||||
EncoderLayer, EncoderLayerChunk)
|
||||
|
||||
|
||||
class SANMEncoder(AbsEncoder):
|
||||
"""Transformer encoder module.
|
||||
|
||||
Args:
|
||||
input_size: input dim
|
||||
output_size: dimension of attention
|
||||
attention_heads: the number of heads of multi head attention
|
||||
linear_units: the number of units of position-wise feed forward
|
||||
num_blocks: the number of decoder blocks
|
||||
dropout_rate: dropout rate
|
||||
attention_dropout_rate: dropout rate in attention
|
||||
positional_dropout_rate: dropout rate after adding positional encoding
|
||||
input_layer: input layer type
|
||||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
|
||||
normalize_before: whether to use layer_norm before the first block
|
||||
concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied.
|
||||
i.e. x -> x + att(x)
|
||||
positionwise_layer_type: linear or conv1d
|
||||
positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
|
||||
padding_idx: padding_idx for input_layer=embed
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: Optional[str] = 'conv2d',
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
positionwise_layer_type: str = 'linear',
|
||||
positionwise_conv_kernel_size: int = 1,
|
||||
padding_idx: int = -1,
|
||||
interctc_layer_idx: List[int] = [],
|
||||
interctc_use_conditioning: bool = False,
|
||||
kernel_size: int = 11,
|
||||
sanm_shfit: int = 0,
|
||||
selfattention_layer_type: str = 'sanm',
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
if input_layer == 'linear':
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_size, output_size),
|
||||
torch.nn.LayerNorm(output_size),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'conv2d':
|
||||
self.embed = Conv2dSubsampling(input_size, output_size,
|
||||
dropout_rate)
|
||||
elif input_layer == 'conv2d2':
|
||||
self.embed = Conv2dSubsampling2(input_size, output_size,
|
||||
dropout_rate)
|
||||
elif input_layer == 'conv2d6':
|
||||
self.embed = Conv2dSubsampling6(input_size, output_size,
|
||||
dropout_rate)
|
||||
elif input_layer == 'conv2d8':
|
||||
self.embed = Conv2dSubsampling8(input_size, output_size,
|
||||
dropout_rate)
|
||||
elif input_layer == 'embed':
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(
|
||||
input_size, output_size, padding_idx=padding_idx),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer is None:
|
||||
if input_size == output_size:
|
||||
self.embed = None
|
||||
else:
|
||||
self.embed = torch.nn.Linear(input_size, output_size)
|
||||
else:
|
||||
raise ValueError('unknown input_layer: ' + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == 'linear':
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == 'conv1d':
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == 'conv1d-linear':
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError('Support only linear or conv1d.')
|
||||
|
||||
if selfattention_layer_type == 'selfattn':
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
elif selfattention_layer_type == 'sanm':
|
||||
encoder_selfattn_layer = MultiHeadedAttentionSANM
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayer(
|
||||
output_size,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
self.interctc_layer_idx = interctc_layer_idx
|
||||
if len(interctc_layer_idx) > 0:
|
||||
assert 0 < min(interctc_layer_idx) and max(
|
||||
interctc_layer_idx) < num_blocks
|
||||
self.interctc_use_conditioning = interctc_use_conditioning
|
||||
self.conditioning_layer = None
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
ctc: CTC = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs_pad: input tensor (B, L, D)
|
||||
ilens: input length (B)
|
||||
prev_states: Not to be used now.
|
||||
Returns:
|
||||
position embedded tensor and mask
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
|
||||
if self.embed is None:
|
||||
xs_pad = xs_pad
|
||||
elif (isinstance(self.embed, Conv2dSubsampling)
|
||||
or isinstance(self.embed, Conv2dSubsampling2)
|
||||
or isinstance(self.embed, Conv2dSubsampling6)
|
||||
or isinstance(self.embed, Conv2dSubsampling8)):
|
||||
short_status, limit_size = check_short_utt(self.embed,
|
||||
xs_pad.size(1))
|
||||
if short_status:
|
||||
raise TooShortUttError(
|
||||
f'has {xs_pad.size(1)} frames and is too short for subsampling '
|
||||
+ # noqa: *
|
||||
f'(it needs more than {limit_size} frames), return empty results',
|
||||
xs_pad.size(1),
|
||||
limit_size,
|
||||
)
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
intermediate_outs = []
|
||||
if len(self.interctc_layer_idx) == 0:
|
||||
xs_pad, masks = self.encoders(xs_pad, masks)
|
||||
else:
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
xs_pad, masks = encoder_layer(xs_pad, masks)
|
||||
|
||||
if layer_idx + 1 in self.interctc_layer_idx:
|
||||
encoder_out = xs_pad
|
||||
|
||||
# intermediate outputs are also normalized
|
||||
if self.normalize_before:
|
||||
encoder_out = self.after_norm(encoder_out)
|
||||
|
||||
intermediate_outs.append((layer_idx + 1, encoder_out))
|
||||
|
||||
if self.interctc_use_conditioning:
|
||||
ctc_out = ctc.softmax(encoder_out)
|
||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
if len(intermediate_outs) > 0:
|
||||
return (xs_pad, intermediate_outs), olens, None
|
||||
return xs_pad, olens, None
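A minimal sketch of the plain SANM encoder above with its default 'sanm' self-attention; the 'linear' input layer keeps the toy example free of Conv2d subsampling. The class is assumed importable from this module and the shapes are illustrative.

import torch

encoder = SANMEncoder(
    input_size=80,
    output_size=256,
    input_layer='linear',
    selfattention_layer_type='sanm',
    kernel_size=11,
    sanm_shfit=0,
)
xs, olens, _ = encoder(torch.randn(2, 120, 80), torch.tensor([120, 96]))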
|
||||
|
||||
|
||||
class SANMEncoderChunk(AbsEncoder):
|
||||
"""Transformer encoder module.
|
||||
|
||||
Args:
|
||||
input_size: input dim
|
||||
output_size: dimension of attention
|
||||
attention_heads: the number of heads of multi head attention
|
||||
linear_units: the number of units of position-wise feed forward
|
||||
num_blocks: the number of decoder blocks
|
||||
dropout_rate: dropout rate
|
||||
attention_dropout_rate: dropout rate in attention
|
||||
positional_dropout_rate: dropout rate after adding positional encoding
|
||||
input_layer: input layer type
|
||||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
|
||||
normalize_before: whether to use layer_norm before the first block
|
||||
concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied.
|
||||
i.e. x -> x + att(x)
|
||||
positionwise_layer_type: linear or conv1d
|
||||
positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
|
||||
padding_idx: padding_idx for input_layer=embed
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: Optional[str] = 'conv2d',
|
||||
pos_enc_class=PositionalEncoding,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
positionwise_layer_type: str = 'linear',
|
||||
positionwise_conv_kernel_size: int = 1,
|
||||
padding_idx: int = -1,
|
||||
interctc_layer_idx: List[int] = [],
|
||||
interctc_use_conditioning: bool = False,
|
||||
kernel_size: int = 11,
|
||||
sanm_shfit: int = 0,
|
||||
selfattention_layer_type: str = 'sanm',
|
||||
chunk_size: Union[int, Sequence[int]] = (16, ),
|
||||
stride: Union[int, Sequence[int]] = (10, ),
|
||||
pad_left: Union[int, Sequence[int]] = (0, ),
|
||||
encoder_att_look_back_factor: Union[int, Sequence[int]] = (1, ),
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
if input_layer == 'linear':
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_size, output_size),
|
||||
torch.nn.LayerNorm(output_size),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == 'conv2d':
|
||||
self.embed = Conv2dSubsampling(input_size, output_size,
|
||||
dropout_rate)
|
||||
elif input_layer == 'conv2d2':
|
||||
self.embed = Conv2dSubsampling2(input_size, output_size,
|
||||
dropout_rate)
|
||||
elif input_layer == 'conv2d6':
|
||||
self.embed = Conv2dSubsampling6(input_size, output_size,
|
||||
dropout_rate)
|
||||
elif input_layer == 'conv2d8':
|
||||
self.embed = Conv2dSubsampling8(input_size, output_size,
|
||||
dropout_rate)
|
||||
elif input_layer == 'embed':
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(
|
||||
input_size, output_size, padding_idx=padding_idx),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer is None:
|
||||
if input_size == output_size:
|
||||
self.embed = None
|
||||
else:
|
||||
self.embed = torch.nn.Linear(input_size, output_size)
|
||||
else:
|
||||
raise ValueError('unknown input_layer: ' + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == 'linear':
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == 'conv1d':
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == 'conv1d-linear':
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError('Support only linear or conv1d.')
|
||||
|
||||
if selfattention_layer_type == 'selfattn':
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
elif selfattention_layer_type == 'sanm':
|
||||
encoder_selfattn_layer = MultiHeadedAttentionSANM
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayerChunk(
|
||||
output_size,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
self.interctc_layer_idx = interctc_layer_idx
|
||||
if len(interctc_layer_idx) > 0:
|
||||
assert 0 < min(interctc_layer_idx) and max(
|
||||
interctc_layer_idx) < num_blocks
|
||||
self.interctc_use_conditioning = interctc_use_conditioning
|
||||
self.conditioning_layer = None
|
||||
shfit_fsmn = (kernel_size - 1) // 2
|
||||
self.overlap_chunk_cls = overlap_chunk(
|
||||
chunk_size=chunk_size,
|
||||
stride=stride,
|
||||
pad_left=pad_left,
|
||||
shfit_fsmn=shfit_fsmn,
|
||||
encoder_att_look_back_factor=encoder_att_look_back_factor,
|
||||
)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
ctc: CTC = None,
|
||||
ind: int = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs_pad: input tensor (B, L, D)
|
||||
ilens: input length (B)
|
||||
prev_states: Not to be used now.
|
||||
Returns:
|
||||
position embedded tensor and mask
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
|
||||
if self.embed is None:
|
||||
xs_pad = xs_pad
|
||||
elif (isinstance(self.embed, Conv2dSubsampling)
|
||||
or isinstance(self.embed, Conv2dSubsampling2)
|
||||
or isinstance(self.embed, Conv2dSubsampling6)
|
||||
or isinstance(self.embed, Conv2dSubsampling8)):
|
||||
short_status, limit_size = check_short_utt(self.embed,
|
||||
xs_pad.size(1))
|
||||
if short_status:
|
||||
raise TooShortUttError(
|
||||
f'has {xs_pad.size(1)} frames and is too short for subsampling '
|
||||
+ # noqa: *
|
||||
f'(it needs more than {limit_size} frames), return empty results',
|
||||
xs_pad.size(1),
|
||||
limit_size,
|
||||
)
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
mask_shfit_chunk, mask_att_chunk_encoder = None, None
|
||||
if self.overlap_chunk_cls is not None:
|
||||
ilens = masks.squeeze(1).sum(1)
|
||||
chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
|
||||
xs_pad, ilens = self.overlap_chunk_cls.split_chunk(
|
||||
xs_pad, ilens, chunk_outs=chunk_outs)
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
|
||||
chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype)
|
||||
mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
|
||||
chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype)
|
||||
|
||||
intermediate_outs = []
|
||||
if len(self.interctc_layer_idx) == 0:
|
||||
xs_pad, masks, _, _, _ = self.encoders(xs_pad, masks, None,
|
||||
mask_shfit_chunk,
|
||||
mask_att_chunk_encoder)
|
||||
else:
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
xs_pad, masks, _, _, _ = encoder_layer(xs_pad, masks, None,
|
||||
mask_shfit_chunk,
|
||||
mask_att_chunk_encoder)
|
||||
|
||||
if layer_idx + 1 in self.interctc_layer_idx:
|
||||
encoder_out = xs_pad
|
||||
|
||||
# intermediate outputs are also normalized
|
||||
if self.normalize_before:
|
||||
encoder_out = self.after_norm(encoder_out)
|
||||
|
||||
intermediate_outs.append((layer_idx + 1, encoder_out))
|
||||
|
||||
if self.interctc_use_conditioning:
|
||||
ctc_out = ctc.softmax(encoder_out)
|
||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
if self.overlap_chunk_cls is not None:
|
||||
xs_pad, olens = self.overlap_chunk_cls.remove_chunk(
|
||||
xs_pad, ilens, chunk_outs)
|
||||
if len(intermediate_outs) > 0:
|
||||
return (xs_pad, intermediate_outs), olens, None
|
||||
return xs_pad, olens, None
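The chunk variant is called like the other encoders; the chunking happens internally via overlap_chunk, and ind selects which (chunk_size, stride, pad_left, look-back) configuration to use. A hedged sketch, with the class assumed importable from this module and the values illustrative:

import torch

encoder = SANMEncoderChunk(
    input_size=80,
    output_size=256,
    input_layer='linear',
    kernel_size=11,
    chunk_size=(16,),
    stride=(10,),
    pad_left=(0,),
    encoder_att_look_back_factor=(1,),
)
xs, olens, _ = encoder(torch.randn(1, 240, 80), torch.tensor([240]), ind=0)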
|
||||
1131
modelscope/pipelines/audio/asr/asr_engine/espnet/asr/espnet_model.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,321 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Part of the implementation is borrowed from espnet/espnet.
|
||||
import logging
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
|
||||
|
||||
from ...nets.pytorch_backend.cif_utils.cif import \
|
||||
cif_predictor as cif_predictor
|
||||
|
||||
np.set_printoptions(threshold=np.inf)
|
||||
torch.set_printoptions(profile='full', precision=100000, linewidth=None)
|
||||
|
||||
|
||||
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device='cpu'):
|
||||
if maxlen is None:
|
||||
maxlen = lengths.max()
|
||||
row_vector = torch.arange(0, maxlen, 1)
|
||||
matrix = torch.unsqueeze(lengths, dim=-1)
|
||||
mask = row_vector < matrix
|
||||
|
||||
return mask.type(dtype).to(device)
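sequence_mask turns a length vector into a (batch, maxlen) 0/1 mask; a quick illustration of its output:

import torch

lengths = torch.tensor([3, 1])
print(sequence_mask(lengths, maxlen=4))
# tensor([[1., 1., 1., 0.],
#         [1., 0., 0., 0.]])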
|
||||
|
||||
|
||||
class overlap_chunk():
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: tuple = (16, ),
|
||||
stride: tuple = (10, ),
|
||||
pad_left: tuple = (0, ),
|
||||
encoder_att_look_back_factor: tuple = (1, ),
|
||||
shfit_fsmn: int = 0,
|
||||
):
|
||||
self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor \
|
||||
= chunk_size, stride, pad_left, encoder_att_look_back_factor
|
||||
self.shfit_fsmn = shfit_fsmn
|
||||
self.x_add_mask = None
|
||||
self.x_rm_mask = None
|
||||
self.x_len = None
|
||||
self.mask_shfit_chunk = None
|
||||
self.mask_chunk_predictor = None
|
||||
self.mask_att_chunk_encoder = None
|
||||
self.mask_shift_att_chunk_decoder = None
|
||||
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur \
|
||||
= None, None, None, None
|
||||
|
||||
def get_chunk_size(self, ind: int = 0):
|
||||
# with torch.no_grad:
|
||||
chunk_size, stride, pad_left, encoder_att_look_back_factor = self.chunk_size[
|
||||
ind], self.stride[ind], self.pad_left[
|
||||
ind], self.encoder_att_look_back_factor[ind]
|
||||
self.chunk_size_cur, self.stride_cur, self.pad_left_cur, \
self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
= chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn
|
||||
return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur
|
||||
|
||||
def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
|
||||
|
||||
with torch.no_grad():
|
||||
x_len = x_len.cpu().numpy()
|
||||
x_len_max = x_len.max()
|
||||
|
||||
chunk_size, stride, pad_left, encoder_att_look_back_factor = self.get_chunk_size(
|
||||
ind)
|
||||
shfit_fsmn = self.shfit_fsmn
|
||||
chunk_size_pad_shift = chunk_size + shfit_fsmn
|
||||
|
||||
chunk_num_batch = np.ceil(x_len / stride).astype(np.int32)
|
||||
x_len_chunk = (
|
||||
chunk_num_batch - 1
|
||||
) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (
|
||||
chunk_num_batch - 1) * stride
|
||||
x_len_chunk = x_len_chunk.astype(x_len.dtype)
|
||||
x_len_chunk_max = x_len_chunk.max()
|
||||
|
||||
chunk_num = int(math.ceil(x_len_max / stride))
|
||||
dtype = np.int32
|
||||
max_len_for_x_mask_tmp = max(chunk_size, x_len_max)
|
||||
x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
|
||||
x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
|
||||
mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
|
||||
mask_chunk_predictor = np.zeros([0, num_units_predictor],
|
||||
dtype=dtype)
|
||||
mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
|
||||
mask_att_chunk_encoder = np.zeros(
|
||||
[0, chunk_num * chunk_size_pad_shift], dtype=dtype)
|
||||
for chunk_ids in range(chunk_num):
|
||||
# x_mask add
|
||||
fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp),
|
||||
dtype=dtype)
|
||||
x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
|
||||
x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride),
|
||||
dtype=dtype)
|
||||
x_mask_pad_right = np.zeros(
|
||||
(chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
|
||||
x_cur_pad = np.concatenate(
|
||||
[x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
|
||||
x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
|
||||
x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad],
|
||||
axis=0)
|
||||
x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn],
|
||||
axis=0)
|
||||
|
||||
# x_mask rm
|
||||
fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),
|
||||
dtype=dtype)
|
||||
x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
|
||||
x_mask_right = np.zeros((stride, chunk_size - stride),
|
||||
dtype=dtype)
|
||||
x_mask_cur = np.concatenate([x_mask_cur, x_mask_right], axis=1)
|
||||
x_mask_cur_pad_top = np.zeros((chunk_ids * stride, chunk_size),
|
||||
dtype=dtype)
|
||||
x_mask_cur_pad_bottom = np.zeros(
|
||||
(max_len_for_x_mask_tmp, chunk_size), dtype=dtype)
|
||||
x_rm_mask_cur = np.concatenate(
|
||||
[x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom],
|
||||
axis=0)
|
||||
x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :
|
||||
chunk_size]
|
||||
x_rm_mask_cur_fsmn = np.concatenate(
|
||||
[fsmn_padding, x_rm_mask_cur], axis=1)
|
||||
x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn],
|
||||
axis=1)
|
||||
|
||||
# fsmn_padding_mask
|
||||
pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
|
||||
ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
|
||||
mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1],
|
||||
axis=0)
|
||||
mask_shfit_chunk = np.concatenate(
|
||||
[mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
|
||||
|
||||
# predictor mask
|
||||
zeros_1 = np.zeros(
|
||||
[shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
|
||||
ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
|
||||
zeros_3 = np.zeros(
|
||||
[chunk_size - stride - pad_left, num_units_predictor],
|
||||
dtype=dtype)
|
||||
ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
|
||||
mask_chunk_predictor_cur = np.concatenate(
|
||||
[zeros_1, ones_zeros], axis=0)
|
||||
mask_chunk_predictor = np.concatenate(
|
||||
[mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
|
||||
|
||||
# encoder att mask
|
||||
zeros_1_top = np.zeros(
|
||||
[shfit_fsmn, chunk_num * chunk_size_pad_shift],
|
||||
dtype=dtype)
|
||||
|
||||
zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
|
||||
zeros_2 = np.zeros(
|
||||
[chunk_size, zeros_2_num * chunk_size_pad_shift],
|
||||
dtype=dtype)
|
||||
|
||||
encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
|
||||
zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
|
||||
ones_2_mid = np.ones([stride, stride], dtype=dtype)
|
||||
zeros_2_bottom = np.zeros([chunk_size - stride, stride],
|
||||
dtype=dtype)
|
||||
zeros_2_right = np.zeros([chunk_size, chunk_size - stride],
|
||||
dtype=dtype)
|
||||
ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
|
||||
ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right],
|
||||
axis=1)
|
||||
ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
|
||||
|
||||
zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
|
||||
ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
|
||||
ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
|
||||
|
||||
zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
|
||||
zeros_remain = np.zeros(
|
||||
[chunk_size, zeros_remain_num * chunk_size_pad_shift],
|
||||
dtype=dtype)
|
||||
|
||||
ones2_bottom = np.concatenate(
|
||||
[zeros_2, ones_2, ones_3, zeros_remain], axis=1)
|
||||
mask_att_chunk_encoder_cur = np.concatenate(
|
||||
[zeros_1_top, ones2_bottom], axis=0)
|
||||
mask_att_chunk_encoder = np.concatenate(
|
||||
[mask_att_chunk_encoder, mask_att_chunk_encoder_cur],
|
||||
axis=0)
|
||||
|
||||
# decoder fsmn_shift_att_mask
|
||||
zeros_1 = np.zeros([shfit_fsmn, 1])
|
||||
ones_1 = np.ones([chunk_size, 1])
|
||||
mask_shift_att_chunk_decoder_cur = np.concatenate(
|
||||
[zeros_1, ones_1], axis=0)
|
||||
mask_shift_att_chunk_decoder = np.concatenate(
|
||||
[
|
||||
mask_shift_att_chunk_decoder,
|
||||
mask_shift_att_chunk_decoder_cur
|
||||
],
|
||||
axis=0)  # noqa: *
|
||||
|
||||
self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max]
|
||||
self.x_len_chunk = x_len_chunk
|
||||
self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
|
||||
self.x_len = x_len
|
||||
self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
|
||||
self.mask_chunk_predictor = mask_chunk_predictor[:
|
||||
x_len_chunk_max, :]
|
||||
self.mask_att_chunk_encoder = mask_att_chunk_encoder[:
|
||||
x_len_chunk_max, :
|
||||
x_len_chunk_max]
|
||||
self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:
|
||||
x_len_chunk_max, :]
|
||||
|
||||
return (self.x_add_mask, self.x_len_chunk, self.x_rm_mask, self.x_len,
|
||||
self.mask_shfit_chunk, self.mask_chunk_predictor,
|
||||
self.mask_att_chunk_encoder, self.mask_shift_att_chunk_decoder)
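A worked check of the chunked-length arithmetic above, with illustrative values: chunk_size=16, stride=10, pad_left=0 and shfit_fsmn=5 (i.e. kernel_size=11) give chunk_size_pad_shift=21, so a 35-frame utterance expands to 73 chunked frames.

import numpy as np

chunk_size, stride, pad_left, shfit_fsmn = 16, 10, 0, 5
chunk_size_pad_shift = chunk_size + shfit_fsmn                 # 21
x_len = np.array([35])
chunk_num_batch = np.ceil(x_len / stride).astype(np.int32)     # ceil(35 / 10) = 4 chunks
x_len_chunk = (chunk_num_batch - 1) * chunk_size_pad_shift \
    + shfit_fsmn + pad_left + x_len - (chunk_num_batch - 1) * stride
print(x_len_chunk)                                             # [73]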
|
||||
|
||||
def split_chunk(self, x, x_len, chunk_outs):
|
||||
"""
|
||||
:param x: (b, t, d)
|
||||
:param x_length: (b)
|
||||
:param ind: int
|
||||
:return:
|
||||
"""
|
||||
x = x[:, :x_len.max(), :]
|
||||
b, t, d = x.size()
|
||||
x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(x.device)
|
||||
x *= x_len_mask[:, :, None]
|
||||
|
||||
x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
|
||||
x_len_chunk = self.get_x_len_chunk(
|
||||
chunk_outs, x_len.device, dtype=x_len.dtype)
|
||||
x = torch.transpose(x, 1, 0)
|
||||
x = torch.reshape(x, [t, -1])
|
||||
x_chunk = torch.mm(x_add_mask, x)
|
||||
x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
|
||||
|
||||
return x_chunk, x_len_chunk
|
||||
|
||||
def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
|
||||
x_chunk = x_chunk[:, :x_len_chunk.max(), :]
|
||||
b, t, d = x_chunk.size()
|
||||
x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
|
||||
x_chunk.device)
|
||||
x_chunk *= x_len_chunk_mask[:, :, None]
|
||||
|
||||
x_rm_mask = self.get_x_rm_mask(
|
||||
chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
|
||||
x_len = self.get_x_len(
|
||||
chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
|
||||
x_chunk = torch.transpose(x_chunk, 1, 0)
|
||||
x_chunk = torch.reshape(x_chunk, [t, -1])
|
||||
x = torch.mm(x_rm_mask, x_chunk)
|
||||
x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
|
||||
|
||||
return x, x_len
|
||||
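The two methods above implement chunking as a pair of sparse matrix multiplications: x_add_mask copies frames of the flattened (time, batch * dim) sequence into overlapping chunks, and x_rm_mask maps the chunked sequence back. Below is a minimal standalone sketch of the same idea (not part of the patch); the shapes and the loop-built selection matrix are illustrative only.

import torch

t, d = 6, 2                       # frames and feature dim of a single utterance
chunk, stride = 4, 2              # chunk size and hop
x = torch.arange(t * d, dtype=torch.float32).reshape(t, d)

# 0/1 selection matrix: row c*chunk + i of the result copies frame c*stride + i
num_chunks = (t - chunk) // stride + 1
add_mask = torch.zeros(num_chunks * chunk, t)
for c in range(num_chunks):
    for i in range(chunk):
        add_mask[c * chunk + i, c * stride + i] = 1.0

x_chunk = add_mask @ x            # (num_chunks * chunk, d): frames replicated into overlapping chunks
print(x_chunk.shape)              # torch.Size([8, 2])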
|
||||
def get_x_add_mask(self, chunk_outs, device, idx=0, dtype=torch.float32):
|
||||
x = chunk_outs[idx]
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x.detach()
|
||||
|
||||
def get_x_len_chunk(self, chunk_outs, device, idx=1, dtype=torch.float32):
|
||||
x = chunk_outs[idx]
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x.detach()
|
||||
|
||||
def get_x_rm_mask(self, chunk_outs, device, idx=2, dtype=torch.float32):
|
||||
x = chunk_outs[idx]
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x.detach()
|
||||
|
||||
def get_x_len(self, chunk_outs, device, idx=3, dtype=torch.float32):
|
||||
x = chunk_outs[idx]
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x.detach()
|
||||
|
||||
def get_mask_shfit_chunk(self,
|
||||
chunk_outs,
|
||||
device,
|
||||
batch_size=1,
|
||||
num_units=1,
|
||||
idx=4,
|
||||
dtype=torch.float32):
|
||||
x = chunk_outs[idx]
|
||||
x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x.detach()
|
||||
|
||||
def get_mask_chunk_predictor(self,
|
||||
chunk_outs,
|
||||
device,
|
||||
batch_size=1,
|
||||
num_units=1,
|
||||
idx=5,
|
||||
dtype=torch.float32):
|
||||
x = chunk_outs[idx]
|
||||
x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x.detach()
|
||||
|
||||
def get_mask_att_chunk_encoder(self,
|
||||
chunk_outs,
|
||||
device,
|
||||
batch_size=1,
|
||||
idx=6,
|
||||
dtype=torch.float32):
|
||||
x = chunk_outs[idx]
|
||||
x = np.tile(x[None, :, :, ], [batch_size, 1, 1])
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x.detach()
|
||||
|
||||
def get_mask_shift_att_chunk_decoder(self,
|
||||
chunk_outs,
|
||||
device,
|
||||
batch_size=1,
|
||||
idx=7,
|
||||
dtype=torch.float32):
|
||||
x = chunk_outs[idx]
|
||||
x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
|
||||
x = torch.from_numpy(x).type(dtype).to(device)
|
||||
return x.detach()
|
||||
@@ -0,0 +1,250 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Part of the implementation is borrowed from espnet/espnet.
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
|
||||
from torch import nn
|
||||
|
||||
|
||||
class CIF_Model(nn.Module):
|
||||
|
||||
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1):
|
||||
super(CIF_Model, self).__init__()
|
||||
|
||||
self.pad = nn.ConstantPad1d((l_order, r_order), 0)
|
||||
self.cif_conv1d = nn.Conv1d(
|
||||
idim, idim, l_order + r_order + 1, groups=idim)
|
||||
self.cif_output = nn.Linear(idim, 1)
|
||||
self.dropout = torch.nn.Dropout(p=dropout)
|
||||
self.threshold = threshold
|
||||
|
||||
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1):
|
||||
h = hidden
|
||||
context = h.transpose(1, 2)
|
||||
queries = self.pad(context)
|
||||
memory = self.cif_conv1d(queries)
|
||||
output = memory + context
|
||||
output = self.dropout(output)
|
||||
output = output.transpose(1, 2)
|
||||
output = torch.relu(output)
|
||||
output = self.cif_output(output)
|
||||
alphas = torch.sigmoid(output)
|
||||
if mask is not None:
|
||||
alphas = alphas * mask.transpose(-1, -2).float()
|
||||
alphas = alphas.squeeze(-1)
|
||||
if target_label is not None:
|
||||
target_length = (target_label != ignore_id).float().sum(-1)
|
||||
else:
|
||||
target_length = None
|
||||
cif_length = alphas.sum(-1)
|
||||
if target_label is not None:
|
||||
alphas *= (target_length / cif_length)[:, None].repeat(
|
||||
1, alphas.size(1))
|
||||
cif_output, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
return cif_output, cif_length, target_length, cif_peak
|
||||
|
||||
def gen_frame_alignments(self,
|
||||
alphas: torch.Tensor = None,
|
||||
memory_sequence_length: torch.Tensor = None,
|
||||
is_training: bool = True,
|
||||
dtype: torch.dtype = torch.float32):
|
||||
batch_size, maximum_length = alphas.size()
|
||||
int_type = torch.int32
|
||||
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
|
||||
|
||||
max_token_num = torch.max(token_num).item()
|
||||
|
||||
alphas_cumsum = torch.cumsum(alphas, dim=1)
|
||||
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
|
||||
alphas_cumsum = torch.tile(alphas_cumsum[:, None, :],
|
||||
[1, max_token_num, 1])
|
||||
|
||||
index = torch.ones([batch_size, max_token_num], dtype=int_type)
|
||||
index = torch.cumsum(index, dim=1)
|
||||
index = torch.tile(index[:, :, None], [1, 1, maximum_length])
|
||||
|
||||
index_div = torch.floor(torch.divide(alphas_cumsum,
|
||||
index)).type(int_type)
|
||||
index_div_bool_zeros = index_div.eq(0)
|
||||
index_div_bool_zeros_count = torch.sum(
|
||||
index_div_bool_zeros, dim=-1) + 1
|
||||
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0,
|
||||
memory_sequence_length.max())
|
||||
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(
|
||||
token_num.device)
|
||||
index_div_bool_zeros_count *= token_num_mask
|
||||
|
||||
index_div_bool_zeros_count_tile = torch.tile(
|
||||
index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length])
|
||||
ones = torch.ones_like(index_div_bool_zeros_count_tile)
|
||||
zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
|
||||
ones = torch.cumsum(ones, dim=2)
|
||||
cond = index_div_bool_zeros_count_tile == ones
|
||||
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
|
||||
|
||||
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(
|
||||
torch.bool)
|
||||
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(
|
||||
int_type)
|
||||
index_div_bool_zeros_count_tile_out = torch.sum(
|
||||
index_div_bool_zeros_count_tile, dim=1)
|
||||
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(
|
||||
int_type)
|
||||
predictor_mask = (~make_pad_mask(
|
||||
memory_sequence_length,
|
||||
maxlen=memory_sequence_length.max())).type(int_type).to(
|
||||
memory_sequence_length.device) # noqa: *
|
||||
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
|
||||
return index_div_bool_zeros_count_tile_out.detach(
|
||||
), index_div_bool_zeros_count.detach()
|
||||
|
||||
|
||||
class cif_predictor(nn.Module):
|
||||
|
||||
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1):
|
||||
super(cif_predictor, self).__init__()
|
||||
|
||||
self.pad = nn.ConstantPad1d((l_order, r_order), 0)
|
||||
self.cif_conv1d = nn.Conv1d(
|
||||
idim, idim, l_order + r_order + 1, groups=idim)
|
||||
self.cif_output = nn.Linear(idim, 1)
|
||||
self.dropout = torch.nn.Dropout(p=dropout)
|
||||
self.threshold = threshold
|
||||
|
||||
def forward(self,
|
||||
hidden,
|
||||
target_label=None,
|
||||
mask=None,
|
||||
ignore_id=-1,
|
||||
mask_chunk_predictor=None,
|
||||
target_label_length=None):
|
||||
h = hidden
|
||||
context = h.transpose(1, 2)
|
||||
queries = self.pad(context)
|
||||
memory = self.cif_conv1d(queries)
|
||||
output = memory + context
|
||||
output = self.dropout(output)
|
||||
output = output.transpose(1, 2)
|
||||
output = torch.relu(output)
|
||||
output = self.cif_output(output)
|
||||
alphas = torch.sigmoid(output)
|
||||
if mask is not None:
|
||||
alphas = alphas * mask.transpose(-1, -2).float()
|
||||
if mask_chunk_predictor is not None:
|
||||
alphas = alphas * mask_chunk_predictor
|
||||
alphas = alphas.squeeze(-1)
|
||||
if target_label_length is not None:
|
||||
target_length = target_label_length
|
||||
elif target_label is not None:
|
||||
target_length = (target_label != ignore_id).float().sum(-1)
|
||||
else:
|
||||
target_length = None
|
||||
token_num = alphas.sum(-1)
|
||||
if target_length is not None:
|
||||
alphas *= (target_length / token_num)[:, None].repeat(
|
||||
1, alphas.size(1))
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
return acoustic_embeds, token_num, alphas, cif_peak
|
||||
|
||||
def gen_frame_alignments(self,
|
||||
alphas: torch.Tensor = None,
|
||||
memory_sequence_length: torch.Tensor = None,
|
||||
is_training: bool = True,
|
||||
dtype: torch.dtype = torch.float32):
|
||||
batch_size, maximum_length = alphas.size()
|
||||
int_type = torch.int32
|
||||
token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
|
||||
|
||||
max_token_num = torch.max(token_num).item()
|
||||
|
||||
alphas_cumsum = torch.cumsum(alphas, dim=1)
|
||||
alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
|
||||
alphas_cumsum = torch.tile(alphas_cumsum[:, None, :],
|
||||
[1, max_token_num, 1])
|
||||
|
||||
index = torch.ones([batch_size, max_token_num], dtype=int_type)
|
||||
index = torch.cumsum(index, dim=1)
|
||||
index = torch.tile(index[:, :, None], [1, 1, maximum_length])
|
||||
|
||||
index_div = torch.floor(torch.divide(alphas_cumsum,
|
||||
index)).type(int_type)
|
||||
index_div_bool_zeros = index_div.eq(0)
|
||||
index_div_bool_zeros_count = torch.sum(
|
||||
index_div_bool_zeros, dim=-1) + 1
|
||||
index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, 0,
|
||||
memory_sequence_length.max())
|
||||
token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(
|
||||
token_num.device)
|
||||
index_div_bool_zeros_count *= token_num_mask
|
||||
|
||||
index_div_bool_zeros_count_tile = torch.tile(
|
||||
index_div_bool_zeros_count[:, :, None], [1, 1, maximum_length])
|
||||
ones = torch.ones_like(index_div_bool_zeros_count_tile)
|
||||
zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
|
||||
ones = torch.cumsum(ones, dim=2)
|
||||
cond = index_div_bool_zeros_count_tile == ones
|
||||
index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
|
||||
|
||||
index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(
|
||||
torch.bool)
|
||||
index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(
|
||||
int_type)
|
||||
index_div_bool_zeros_count_tile_out = torch.sum(
|
||||
index_div_bool_zeros_count_tile, dim=1)
|
||||
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(
|
||||
int_type)
|
||||
predictor_mask = (~make_pad_mask(
|
||||
memory_sequence_length,
|
||||
maxlen=memory_sequence_length.max())).type(int_type).to(
|
||||
memory_sequence_length.device) # noqa: *
|
||||
index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
|
||||
return index_div_bool_zeros_count_tile_out.detach(
|
||||
), index_div_bool_zeros_count.detach()
|
||||
|
||||
|
||||
def cif(hidden, alphas, threshold):
|
||||
batch_size, len_time, hidden_size = hidden.size()
|
||||
|
||||
# loop vars
|
||||
integrate = torch.zeros([batch_size], device=hidden.device)
|
||||
frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
|
||||
# intermediate vars along time
|
||||
list_fires = []
|
||||
list_frames = []
|
||||
|
||||
for t in range(len_time):
|
||||
alpha = alphas[:, t]
|
||||
distribution_completion = torch.ones([batch_size],
|
||||
device=hidden.device) - integrate
|
||||
|
||||
integrate += alpha
|
||||
list_fires.append(integrate)
|
||||
|
||||
fire_place = integrate >= threshold
|
||||
integrate = torch.where(
|
||||
fire_place,
|
||||
integrate - torch.ones([batch_size], device=hidden.device),
|
||||
integrate)
|
||||
cur = torch.where(fire_place, distribution_completion, alpha)
|
||||
remainds = alpha - cur
|
||||
|
||||
frame += cur[:, None] * hidden[:, t, :]
|
||||
list_frames.append(frame)
|
||||
frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
|
||||
remainds[:, None] * hidden[:, t, :], frame)
|
||||
|
||||
fires = torch.stack(list_fires, 1)
|
||||
frames = torch.stack(list_frames, 1)
|
||||
list_ls = []
|
||||
len_labels = torch.round(alphas.sum(-1)).int()
|
||||
max_label_len = len_labels.max()
|
||||
for b in range(batch_size):
|
||||
fire = fires[b, :]
|
||||
ls = torch.index_select(frames[b, :, :], 0,
|
||||
torch.nonzero(fire >= threshold).squeeze())
|
||||
pad_l = torch.zeros([max_label_len - ls.size(0), hidden_size],
|
||||
device=hidden.device)
|
||||
list_ls.append(torch.cat([ls, pad_l], 0))
|
||||
return torch.stack(list_ls, 0), fires
|
||||
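For orientation, the accumulate-and-fire rule that cif() implements above can be stated in a few lines: the per-frame weights alphas are summed along time, and an output embedding fires every time the running sum crosses the threshold, with the remainder carried forward. A toy sketch of just the firing positions follows (not part of the patch; the alpha values are made up).

alphas = [0.4, 0.5, 0.3, 0.9, 0.2, 0.8]   # made-up per-frame weights
threshold, integrate, fired_at = 1.0, 0.0, []
for t, a in enumerate(alphas):
    integrate += a
    if integrate >= threshold:            # fire: one acoustic embedding is emitted here
        fired_at.append(t)
        integrate -= threshold            # carry the remainder into the next token
print(fired_at)                           # [2, 3, 5]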
@@ -0,0 +1,680 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Part of the implementation is borrowed from espnet/espnet.
|
||||
"""Multi-Head Attention layer definition."""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
torch.set_printoptions(profile='full', precision=1)
|
||||
|
||||
|
||||
class MultiHeadedAttention(nn.Module):
|
||||
"""Multi-Head Attention layer.
|
||||
|
||||
Args:
|
||||
n_head (int): The number of heads.
|
||||
n_feat (int): The number of features.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n_head, n_feat, dropout_rate):
|
||||
"""Construct an MultiHeadedAttention object."""
|
||||
super(MultiHeadedAttention, self).__init__()
|
||||
assert n_feat % n_head == 0
|
||||
# We assume d_v always equals d_k
|
||||
self.d_k = n_feat // n_head
|
||||
self.h = n_head
|
||||
self.linear_q = nn.Linear(n_feat, n_feat)
|
||||
self.linear_k = nn.Linear(n_feat, n_feat)
|
||||
self.linear_v = nn.Linear(n_feat, n_feat)
|
||||
self.linear_out = nn.Linear(n_feat, n_feat)
|
||||
self.attn = None
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
def forward_qkv(self, query, key, value):
|
||||
"""Transform query, key and value.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
|
||||
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
|
||||
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
|
||||
|
||||
"""
|
||||
n_batch = query.size(0)
|
||||
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
||||
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
||||
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
||||
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
||||
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
||||
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def forward_attention(self, value, scores, mask):
|
||||
"""Compute attention context vector.
|
||||
|
||||
Args:
|
||||
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
|
||||
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
|
||||
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed value (#batch, time1, d_model)
|
||||
weighted by the attention score (#batch, time1, time2).
|
||||
|
||||
"""
|
||||
n_batch = value.size(0)
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
||||
min_value = float(
|
||||
numpy.finfo(torch.tensor(
|
||||
0, dtype=scores.dtype).numpy().dtype).min)
|
||||
scores = scores.masked_fill(mask, min_value)
|
||||
self.attn = torch.softmax(
|
||||
scores, dim=-1).masked_fill(mask,
|
||||
0.0) # (batch, head, time1, time2)
|
||||
else:
|
||||
self.attn = torch.softmax(
|
||||
scores, dim=-1) # (batch, head, time1, time2)
|
||||
|
||||
p_attn = self.dropout(self.attn)
|
||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
||||
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
|
||||
self.h * self.d_k)
|
||||
) # (batch, time1, d_model)
|
||||
|
||||
return self.linear_out(x) # (batch, time1, d_model)
|
||||
|
||||
def forward(self, query, key, value, mask):
|
||||
"""Compute scaled dot product attention.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
||||
(#batch, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
||||
|
||||
"""
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
||||
return self.forward_attention(v, scores, mask)
|
||||
|
||||
|
||||
class MultiHeadedAttentionSANM(nn.Module):
|
||||
"""Multi-Head Attention layer.
|
||||
|
||||
Args:
|
||||
n_head (int): The number of heads.
|
||||
n_feat (int): The number of features.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_head,
|
||||
n_feat,
|
||||
dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit=0):
|
||||
"""Construct an MultiHeadedAttention object."""
|
||||
super(MultiHeadedAttentionSANM, self).__init__()
|
||||
assert n_feat % n_head == 0
|
||||
# We assume d_v always equals d_k
|
||||
self.d_k = n_feat // n_head
|
||||
self.h = n_head
|
||||
self.linear_q = nn.Linear(n_feat, n_feat)
|
||||
self.linear_k = nn.Linear(n_feat, n_feat)
|
||||
self.linear_v = nn.Linear(n_feat, n_feat)
|
||||
self.linear_out = nn.Linear(n_feat, n_feat)
|
||||
self.attn = None
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
self.fsmn_block = nn.Conv1d(
|
||||
n_feat,
|
||||
n_feat,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=n_feat,
|
||||
bias=False)
|
||||
# padding
|
||||
left_padding = (kernel_size - 1) // 2
|
||||
if sanm_shfit > 0:
|
||||
left_padding = left_padding + sanm_shfit
|
||||
right_padding = kernel_size - 1 - left_padding
|
||||
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
|
||||
|
||||
def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
|
||||
'''
|
||||
:param x: (#batch, time1, size).
|
||||
:param mask: Mask tensor (#batch, 1, time)
|
||||
:return:
|
||||
'''
|
||||
# b, t, d = inputs.size()
|
||||
mask = mask[:, 0, :, None]
|
||||
if mask_shfit_chunk is not None:
|
||||
mask = mask * mask_shfit_chunk
|
||||
inputs *= mask
|
||||
x = inputs.transpose(1, 2)
|
||||
x = self.pad_fn(x)
|
||||
x = self.fsmn_block(x)
|
||||
x = x.transpose(1, 2)
|
||||
x += inputs
|
||||
x = self.dropout(x)
|
||||
return x * mask
|
||||
|
||||
def forward_qkv(self, query, key, value):
|
||||
"""Transform query, key and value.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
|
||||
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
|
||||
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
|
||||
|
||||
"""
|
||||
n_batch = query.size(0)
|
||||
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
||||
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
||||
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
||||
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
||||
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
||||
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def forward_attention(self,
|
||||
value,
|
||||
scores,
|
||||
mask,
|
||||
mask_att_chunk_encoder=None):
|
||||
"""Compute attention context vector.
|
||||
|
||||
Args:
|
||||
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
|
||||
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
|
||||
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed value (#batch, time1, d_model)
|
||||
weighted by the attention score (#batch, time1, time2).
|
||||
|
||||
"""
|
||||
n_batch = value.size(0)
|
||||
if mask is not None:
|
||||
if mask_att_chunk_encoder is not None:
|
||||
mask = mask * mask_att_chunk_encoder
|
||||
|
||||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
||||
|
||||
min_value = float(
|
||||
numpy.finfo(torch.tensor(
|
||||
0, dtype=scores.dtype).numpy().dtype).min)
|
||||
scores = scores.masked_fill(mask, min_value)
|
||||
self.attn = torch.softmax(
|
||||
scores, dim=-1).masked_fill(mask,
|
||||
0.0) # (batch, head, time1, time2)
|
||||
else:
|
||||
self.attn = torch.softmax(
|
||||
scores, dim=-1) # (batch, head, time1, time2)
|
||||
|
||||
p_attn = self.dropout(self.attn)
|
||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
||||
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
|
||||
self.h * self.d_k)
|
||||
) # (batch, time1, d_model)
|
||||
|
||||
return self.linear_out(x) # (batch, time1, d_model)
|
||||
|
||||
def forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
mask,
|
||||
mask_shfit_chunk=None,
|
||||
mask_att_chunk_encoder=None):
|
||||
"""Compute scaled dot product attention.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
||||
(#batch, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
||||
|
||||
"""
|
||||
fsmn_memory = self.forward_fsmn(value, mask, mask_shfit_chunk)
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
||||
att_outs = self.forward_attention(v, scores, mask,
|
||||
mask_att_chunk_encoder)
|
||||
return att_outs + fsmn_memory
|
||||
|
||||
|
||||
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
|
||||
"""Multi-Head Attention layer with relative position encoding (old version).
|
||||
|
||||
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
||||
|
||||
Paper: https://arxiv.org/abs/1901.02860
|
||||
|
||||
Args:
|
||||
n_head (int): The number of heads.
|
||||
n_feat (int): The number of features.
|
||||
dropout_rate (float): Dropout rate.
|
||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
|
||||
"""Construct an RelPositionMultiHeadedAttention object."""
|
||||
super().__init__(n_head, n_feat, dropout_rate)
|
||||
self.zero_triu = zero_triu
|
||||
# linear transformation for positional encoding
|
||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def rel_shift(self, x):
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, head, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor.
|
||||
|
||||
"""
|
||||
zero_pad = torch.zeros((*x.size()[:3], 1),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||
|
||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
||||
x = x_padded[:, :, 1:].view_as(x)
|
||||
|
||||
if self.zero_triu:
|
||||
ones = torch.ones((x.size(2), x.size(3)))
|
||||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, query, key, value, pos_emb, mask):
|
||||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
|
||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
||||
(#batch, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
||||
|
||||
"""
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
||||
|
||||
n_batch_pos = pos_emb.size(0)
|
||||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
||||
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
||||
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
# (batch, head, time1, time2)
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
||||
|
||||
# compute matrix b and matrix d
|
||||
# (batch, head, time1, time1)
|
||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
||||
self.d_k) # (batch, head, time1, time2)
|
||||
|
||||
return self.forward_attention(v, scores, mask)
|
||||
|
||||
|
||||
class LegacyRelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM):
|
||||
"""Multi-Head Attention layer with relative position encoding (old version).
|
||||
|
||||
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
||||
|
||||
Paper: https://arxiv.org/abs/1901.02860
|
||||
|
||||
Args:
|
||||
n_head (int): The number of heads.
|
||||
n_feat (int): The number of features.
|
||||
dropout_rate (float): Dropout rate.
|
||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_head,
|
||||
n_feat,
|
||||
dropout_rate,
|
||||
zero_triu=False,
|
||||
kernel_size=15,
|
||||
sanm_shfit=0):
|
||||
"""Construct an RelPositionMultiHeadedAttention object."""
|
||||
super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit)
|
||||
self.zero_triu = zero_triu
|
||||
# linear transformation for positional encoding
|
||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def rel_shift(self, x):
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, head, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor.
|
||||
|
||||
"""
|
||||
zero_pad = torch.zeros((*x.size()[:3], 1),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||
|
||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
||||
x = x_padded[:, :, 1:].view_as(x)
|
||||
|
||||
if self.zero_triu:
|
||||
ones = torch.ones((x.size(2), x.size(3)))
|
||||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, query, key, value, pos_emb, mask):
|
||||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
|
||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
||||
(#batch, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
||||
|
||||
"""
|
||||
fsmn_memory = self.forward_fsmn(value, mask)
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
||||
|
||||
n_batch_pos = pos_emb.size(0)
|
||||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
||||
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
||||
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
# (batch, head, time1, time2)
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
||||
|
||||
# compute matrix b and matrix d
|
||||
# (batch, head, time1, time1)
|
||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
||||
self.d_k) # (batch, head, time1, time2)
|
||||
|
||||
att_outs = self.forward_attention(v, scores, mask)
|
||||
return att_outs + fsmn_memory
|
||||
|
||||
|
||||
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
||||
"""Multi-Head Attention layer with relative position encoding (new implementation).
|
||||
|
||||
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
||||
|
||||
Paper: https://arxiv.org/abs/1901.02860
|
||||
|
||||
Args:
|
||||
n_head (int): The number of heads.
|
||||
n_feat (int): The number of features.
|
||||
dropout_rate (float): Dropout rate.
|
||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
|
||||
"""Construct an RelPositionMultiHeadedAttention object."""
|
||||
super().__init__(n_head, n_feat, dropout_rate)
|
||||
self.zero_triu = zero_triu
|
||||
# linear transformation for positional encoding
|
||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def rel_shift(self, x):
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
||||
time1 means the length of query vector.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor.
|
||||
|
||||
"""
|
||||
zero_pad = torch.zeros((*x.size()[:3], 1),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||
|
||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
||||
x = x_padded[:, :, 1:].view_as(
|
||||
x)[:, :, :, :x.size(-1) // 2
|
||||
+ 1] # only keep the positions from 0 to time2
|
||||
|
||||
if self.zero_triu:
|
||||
ones = torch.ones((x.size(2), x.size(3)), device=x.device)
|
||||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, query, key, value, pos_emb, mask):
|
||||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
pos_emb (torch.Tensor): Positional embedding tensor
|
||||
(#batch, 2*time1-1, size).
|
||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
||||
(#batch, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
||||
|
||||
"""
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
||||
|
||||
n_batch_pos = pos_emb.size(0)
|
||||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
# (batch, head, time1, time2)
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
||||
|
||||
# compute matrix b and matrix d
|
||||
# (batch, head, time1, 2*time1-1)
|
||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
||||
self.d_k) # (batch, head, time1, time2)
|
||||
|
||||
return self.forward_attention(v, scores, mask)
|
||||
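The rel_shift above relies on a zero-pad-and-reshape trick to realign a (time1, 2*time1-1) relative-position score matrix so that, after the final slice, column j of row i corresponds to key position j. A standalone shape check of that trick (not part of the patch):

import torch

time1 = 3
x = torch.arange(time1 * (2 * time1 - 1), dtype=torch.float32).view(
    1, 1, time1, 2 * time1 - 1)           # (batch, head, time1, 2*time1-1)
zero_pad = torch.zeros((*x.size()[:3], 1))
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
shifted = x_padded[:, :, 1:].view_as(x)[:, :, :, :x.size(-1) // 2 + 1]
print(shifted.shape)                      # torch.Size([1, 1, 3, 3]) -> (batch, head, time1, time2)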
|
||||
|
||||
class RelPositionMultiHeadedAttentionSANM(MultiHeadedAttentionSANM):
|
||||
"""Multi-Head Attention layer with relative position encoding (new implementation).
|
||||
|
||||
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
||||
|
||||
Paper: https://arxiv.org/abs/1901.02860
|
||||
|
||||
Args:
|
||||
n_head (int): The number of heads.
|
||||
n_feat (int): The number of features.
|
||||
dropout_rate (float): Dropout rate.
|
||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_head,
|
||||
n_feat,
|
||||
dropout_rate,
|
||||
zero_triu=False,
|
||||
kernel_size=15,
|
||||
sanm_shfit=0):
|
||||
"""Construct an RelPositionMultiHeadedAttention object."""
|
||||
super().__init__(n_head, n_feat, dropout_rate, kernel_size, sanm_shfit)
|
||||
self.zero_triu = zero_triu
|
||||
# linear transformation for positional encoding
|
||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def rel_shift(self, x):
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
||||
time1 means the length of query vector.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor.
|
||||
|
||||
"""
|
||||
zero_pad = torch.zeros((*x.size()[:3], 1),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||
|
||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
||||
x = x_padded[:, :, 1:].view_as(
|
||||
x)[:, :, :, :x.size(-1) // 2
|
||||
+ 1] # only keep the positions from 0 to time2
|
||||
|
||||
if self.zero_triu:
|
||||
ones = torch.ones((x.size(2), x.size(3)), device=x.device)
|
||||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, query, key, value, pos_emb, mask):
|
||||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
pos_emb (torch.Tensor): Positional embedding tensor
|
||||
(#batch, 2*time1-1, size).
|
||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
||||
(#batch, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
||||
|
||||
"""
|
||||
fsmn_memory = self.forward_fsmn(value, mask)
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
||||
|
||||
n_batch_pos = pos_emb.size(0)
|
||||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
# (batch, head, time1, time2)
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
||||
|
||||
# compute matrix b and matrix d
|
||||
# (batch, head, time1, 2*time1-1)
|
||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
||||
self.d_k) # (batch, head, time1, time2)
|
||||
|
||||
att_outs = self.forward_attention(v, scores, mask)
|
||||
return att_outs + fsmn_memory
|
||||
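The SANM variants above add an FSMN memory branch on top of ordinary multi-head attention: a depthwise 1-D convolution over time applied to the value sequence, plus a residual connection, whose output is summed with the attention output. A minimal standalone sketch of that memory block alone (not part of the patch; the sizes are illustrative):

import torch
from torch import nn

n_feat, kernel_size = 8, 5
left = (kernel_size - 1) // 2
pad = nn.ConstantPad1d((left, kernel_size - 1 - left), 0.0)
fsmn = nn.Conv1d(n_feat, n_feat, kernel_size, groups=n_feat, bias=False)

x = torch.randn(2, 10, n_feat)                            # (batch, time, feat)
mem = fsmn(pad(x.transpose(1, 2))).transpose(1, 2) + x    # depthwise conv over time + residual
print(mem.shape)                                          # torch.Size([2, 10, 8])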
@@ -0,0 +1,239 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Part of the implementation is borrowed from espnet/espnet.
|
||||
"""Encoder self-attention layer definition."""
|
||||
|
||||
import torch
|
||||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
|
||||
from torch import nn
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
"""Encoder layer module.
|
||||
|
||||
Args:
|
||||
size (int): Input dimension.
|
||||
self_attn (torch.nn.Module): Self-attention module instance.
|
||||
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
|
||||
can be used as the argument.
|
||||
feed_forward (torch.nn.Module): Feed-forward module instance.
|
||||
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
|
||||
can be used as the argument.
|
||||
dropout_rate (float): Dropout rate.
|
||||
normalize_before (bool): Whether to use layer_norm before the first block.
|
||||
concat_after (bool): Whether to concat attention layer's input and output.
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
stochastic_depth_rate (float): Probability to skip this layer.
|
||||
During training, the layer may skip residual computation and return input
|
||||
as-is with given probability.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
self_attn,
|
||||
feed_forward,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
stochastic_depth_rate=0.0,
|
||||
):
|
||||
"""Construct an EncoderLayer object."""
|
||||
super(EncoderLayer, self).__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.norm1 = LayerNorm(size)
|
||||
self.norm2 = LayerNorm(size)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.size = size
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
if self.concat_after:
|
||||
self.concat_linear = nn.Linear(size + size, size)
|
||||
self.stochastic_depth_rate = stochastic_depth_rate
|
||||
|
||||
def forward(self, x, mask, cache=None):
|
||||
"""Compute encoded features.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, size).
|
||||
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time, size).
|
||||
torch.Tensor: Mask tensor (#batch, time).
|
||||
|
||||
"""
|
||||
skip_layer = False
|
||||
# with stochastic depth, residual connection `x + f(x)` becomes
|
||||
# `x <- x + 1 / (1 - p) * f(x)` at training time.
|
||||
stoch_layer_coeff = 1.0
|
||||
if self.training and self.stochastic_depth_rate > 0:
|
||||
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
|
||||
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
|
||||
|
||||
if skip_layer:
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
return x, mask
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
if cache is None:
|
||||
x_q = x
|
||||
else:
|
||||
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
|
||||
x_q = x[:, -1:, :]
|
||||
residual = residual[:, -1:, :]
|
||||
mask = None if mask is None else mask[:, -1:, :]
|
||||
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
|
||||
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
|
||||
else:
|
||||
x = residual + stoch_layer_coeff * self.dropout(
|
||||
self.self_attn(x_q, x, x, mask))
|
||||
if not self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
|
||||
return x, mask
|
||||
|
||||
|
||||
class EncoderLayerChunk(nn.Module):
|
||||
"""Encoder layer module.
|
||||
|
||||
Args:
|
||||
size (int): Input dimension.
|
||||
self_attn (torch.nn.Module): Self-attention module instance.
|
||||
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
|
||||
can be used as the argument.
|
||||
feed_forward (torch.nn.Module): Feed-forward module instance.
|
||||
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
|
||||
can be used as the argument.
|
||||
dropout_rate (float): Dropout rate.
|
||||
normalize_before (bool): Whether to use layer_norm before the first block.
|
||||
concat_after (bool): Whether to concat attention layer's input and output.
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
stochastic_depth_rate (float): Probability to skip this layer.
|
||||
During training, the layer may skip residual computation and return input
|
||||
as-is with given probability.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
self_attn,
|
||||
feed_forward,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
stochastic_depth_rate=0.0,
|
||||
):
|
||||
"""Construct an EncoderLayer object."""
|
||||
super(EncoderLayerChunk, self).__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.norm1 = LayerNorm(size)
|
||||
self.norm2 = LayerNorm(size)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.size = size
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
if self.concat_after:
|
||||
self.concat_linear = nn.Linear(size + size, size)
|
||||
self.stochastic_depth_rate = stochastic_depth_rate
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
mask,
|
||||
cache=None,
|
||||
mask_shfit_chunk=None,
|
||||
mask_att_chunk_encoder=None):
|
||||
"""Compute encoded features.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, size).
|
||||
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time, size).
|
||||
torch.Tensor: Mask tensor (#batch, time).
|
||||
|
||||
"""
|
||||
skip_layer = False
|
||||
# with stochastic depth, residual connection `x + f(x)` becomes
|
||||
# `x <- x + 1 / (1 - p) * f(x)` at training time.
|
||||
stoch_layer_coeff = 1.0
|
||||
if self.training and self.stochastic_depth_rate > 0:
|
||||
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
|
||||
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
|
||||
|
||||
if skip_layer:
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
return x, mask
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
if cache is None:
|
||||
x_q = x
|
||||
else:
|
||||
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
|
||||
x_q = x[:, -1:, :]
|
||||
residual = residual[:, -1:, :]
|
||||
mask = None if mask is None else mask[:, -1:, :]
|
||||
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat(
|
||||
(x,
|
||||
self.self_attn(
|
||||
x_q,
|
||||
x,
|
||||
x,
|
||||
mask,
|
||||
mask_shfit_chunk=mask_shfit_chunk,
|
||||
mask_att_chunk_encoder=mask_att_chunk_encoder)),
|
||||
dim=-1)
|
||||
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
|
||||
else:
|
||||
x = residual + stoch_layer_coeff * self.dropout(
|
||||
self.self_attn(
|
||||
x_q,
|
||||
x,
|
||||
x,
|
||||
mask,
|
||||
mask_shfit_chunk=mask_shfit_chunk,
|
||||
mask_att_chunk_encoder=mask_att_chunk_encoder))
|
||||
if not self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
|
||||
return x, mask, None, mask_shfit_chunk, mask_att_chunk_encoder
|
||||
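Both encoder layers above use stochastic depth: during training the whole layer is skipped with probability p, and when it does run, the residual branch is scaled by 1 / (1 - p) so that the expected output matches inference, where the layer always runs unscaled. A minimal sketch of that rule on a generic residual step (not part of the patch):

import torch

def residual_step(x, f, p=0.1, training=True):
    if training and p > 0 and torch.rand(1).item() < p:
        return x                           # layer skipped entirely
    coeff = 1.0 / (1.0 - p) if training and p > 0 else 1.0
    return x + coeff * f(x)                # scaled residual branch

x = torch.randn(2, 8)
print(residual_step(x, torch.tanh).shape)  # torch.Size([2, 8])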
890
modelscope/pipelines/audio/asr/asr_engine/espnet/tasks/asr.py
Normal file
@@ -0,0 +1,890 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Part of the implementation is borrowed from espnet/espnet.
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Collection, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from espnet2.asr.ctc import CTC
|
||||
from espnet2.asr.decoder.abs_decoder import AbsDecoder
|
||||
from espnet2.asr.decoder.mlm_decoder import MLMDecoder
|
||||
from espnet2.asr.decoder.rnn_decoder import RNNDecoder
|
||||
from espnet2.asr.decoder.transformer_decoder import \
|
||||
DynamicConvolution2DTransformerDecoder # noqa: H301
|
||||
from espnet2.asr.decoder.transformer_decoder import \
|
||||
LightweightConvolution2DTransformerDecoder # noqa: H301
|
||||
from espnet2.asr.decoder.transformer_decoder import \
|
||||
LightweightConvolutionTransformerDecoder # noqa: H301
|
||||
from espnet2.asr.decoder.transformer_decoder import (
|
||||
DynamicConvolutionTransformerDecoder, TransformerDecoder)
|
||||
from espnet2.asr.encoder.abs_encoder import AbsEncoder
|
||||
from espnet2.asr.encoder.contextual_block_conformer_encoder import \
|
||||
ContextualBlockConformerEncoder # noqa: H301
|
||||
from espnet2.asr.encoder.contextual_block_transformer_encoder import \
|
||||
ContextualBlockTransformerEncoder # noqa: H301
|
||||
from espnet2.asr.encoder.hubert_encoder import (FairseqHubertEncoder,
|
||||
FairseqHubertPretrainEncoder)
|
||||
from espnet2.asr.encoder.longformer_encoder import LongformerEncoder
|
||||
from espnet2.asr.encoder.rnn_encoder import RNNEncoder
|
||||
from espnet2.asr.encoder.transformer_encoder import TransformerEncoder
|
||||
from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder
|
||||
from espnet2.asr.encoder.wav2vec2_encoder import FairSeqWav2Vec2Encoder
|
||||
from espnet2.asr.espnet_model import ESPnetASRModel
|
||||
from espnet2.asr.frontend.abs_frontend import AbsFrontend
|
||||
from espnet2.asr.frontend.default import DefaultFrontend
|
||||
from espnet2.asr.frontend.fused import FusedFrontends
|
||||
from espnet2.asr.frontend.s3prl import S3prlFrontend
|
||||
from espnet2.asr.frontend.windowing import SlidingWindow
|
||||
from espnet2.asr.maskctc_model import MaskCTCModel
|
||||
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
|
||||
from espnet2.asr.postencoder.hugging_face_transformers_postencoder import \
|
||||
HuggingFaceTransformersPostEncoder # noqa: H301
|
||||
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from espnet2.asr.preencoder.linear import LinearProjection
|
||||
from espnet2.asr.preencoder.sinc import LightweightSincConvs
|
||||
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
|
||||
from espnet2.asr.specaug.specaug import SpecAug
|
||||
from espnet2.asr.transducer.joint_network import JointNetwork
|
||||
from espnet2.asr.transducer.transducer_decoder import TransducerDecoder
|
||||
from espnet2.layers.abs_normalize import AbsNormalize
|
||||
from espnet2.layers.global_mvn import GlobalMVN
|
||||
from espnet2.layers.utterance_mvn import UtteranceMVN
|
||||
from espnet2.tasks.abs_task import AbsTask
|
||||
from espnet2.text.phoneme_tokenizer import g2p_choices
|
||||
from espnet2.torch_utils.initialize import initialize
|
||||
from espnet2.train.abs_espnet_model import AbsESPnetModel
|
||||
from espnet2.train.class_choices import ClassChoices
|
||||
from espnet2.train.collate_fn import CommonCollateFn
|
||||
from espnet2.train.preprocessor import CommonPreprocessor
|
||||
from espnet2.train.trainer import Trainer
|
||||
from espnet2.utils.get_default_kwargs import get_default_kwargs
|
||||
from espnet2.utils.nested_dict_action import NestedDictAction
|
||||
from espnet2.utils.types import (float_or_none, int_or_none, str2bool,
|
||||
str_or_none)
|
||||
from typeguard import check_argument_types, check_return_type
|
||||
|
||||
from ..asr.decoder.transformer_decoder import (ParaformerDecoder,
|
||||
ParaformerDecoderBertEmbed)
|
||||
from ..asr.encoder.conformer_encoder import ConformerEncoder, SANMEncoder_v2
|
||||
from ..asr.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunk
|
||||
from ..asr.espnet_model import AEDStreaming
|
||||
from ..asr.espnet_model_paraformer import Paraformer, ParaformerBertEmbed
|
||||
from ..nets.pytorch_backend.cif_utils.cif import cif_predictor
|
||||
|
||||
# FIXME(wjm): as suggested by fairseq, we need to set up the root logger before importing any fairseq libraries.
|
||||
logging.basicConfig(
|
||||
level='INFO',
|
||||
format=f"[{os.uname()[1].split('.')[0]}]"
|
||||
f' %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
|
||||
)
|
||||
# FIXME(wjm): create a logger to set the level; leave __name__ unset so that different files share the same logger
|
||||
logger = logging.getLogger()
|
||||
|
||||
frontend_choices = ClassChoices(
|
||||
name='frontend',
|
||||
classes=dict(
|
||||
default=DefaultFrontend,
|
||||
sliding_window=SlidingWindow,
|
||||
s3prl=S3prlFrontend,
|
||||
fused=FusedFrontends,
|
||||
),
|
||||
type_check=AbsFrontend,
|
||||
default='default',
|
||||
)
|
||||
specaug_choices = ClassChoices(
|
||||
name='specaug',
|
||||
classes=dict(specaug=SpecAug, ),
|
||||
type_check=AbsSpecAug,
|
||||
default=None,
|
||||
optional=True,
|
||||
)
|
||||
normalize_choices = ClassChoices(
|
||||
'normalize',
|
||||
classes=dict(
|
||||
global_mvn=GlobalMVN,
|
||||
utterance_mvn=UtteranceMVN,
|
||||
),
|
||||
type_check=AbsNormalize,
|
||||
default='utterance_mvn',
|
||||
optional=True,
|
||||
)
|
||||
model_choices = ClassChoices(
|
||||
'model',
|
||||
classes=dict(
|
||||
espnet=ESPnetASRModel,
|
||||
maskctc=MaskCTCModel,
|
||||
paraformer=Paraformer,
|
||||
paraformer_bert_embed=ParaformerBertEmbed,
|
||||
aedstreaming=AEDStreaming,
|
||||
),
|
||||
type_check=AbsESPnetModel,
|
||||
default='espnet',
|
||||
)
|
||||
preencoder_choices = ClassChoices(
|
||||
name='preencoder',
|
||||
classes=dict(
|
||||
sinc=LightweightSincConvs,
|
||||
linear=LinearProjection,
|
||||
),
|
||||
type_check=AbsPreEncoder,
|
||||
default=None,
|
||||
optional=True,
|
||||
)
|
||||
encoder_choices = ClassChoices(
|
||||
'encoder',
|
||||
classes=dict(
|
||||
conformer=ConformerEncoder,
|
||||
transformer=TransformerEncoder,
|
||||
contextual_block_transformer=ContextualBlockTransformerEncoder,
|
||||
contextual_block_conformer=ContextualBlockConformerEncoder,
|
||||
vgg_rnn=VGGRNNEncoder,
|
||||
rnn=RNNEncoder,
|
||||
wav2vec2=FairSeqWav2Vec2Encoder,
|
||||
hubert=FairseqHubertEncoder,
|
||||
hubert_pretrain=FairseqHubertPretrainEncoder,
|
||||
longformer=LongformerEncoder,
|
||||
sanm=SANMEncoder,
|
||||
sanm_v2=SANMEncoder_v2,
|
||||
sanm_chunk=SANMEncoderChunk,
|
||||
),
|
||||
type_check=AbsEncoder,
|
||||
default='rnn',
|
||||
)
|
||||
postencoder_choices = ClassChoices(
|
||||
name='postencoder',
|
||||
classes=dict(
|
||||
hugging_face_transformers=HuggingFaceTransformersPostEncoder, ),
|
||||
type_check=AbsPostEncoder,
|
||||
default=None,
|
||||
optional=True,
|
||||
)
|
||||
decoder_choices = ClassChoices(
|
||||
'decoder',
|
||||
classes=dict(
|
||||
transformer=TransformerDecoder,
|
||||
lightweight_conv=LightweightConvolutionTransformerDecoder,
|
||||
lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
|
||||
dynamic_conv=DynamicConvolutionTransformerDecoder,
|
||||
dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
|
||||
rnn=RNNDecoder,
|
||||
transducer=TransducerDecoder,
|
||||
mlm=MLMDecoder,
|
||||
paraformer_decoder=ParaformerDecoder,
|
||||
paraformer_decoder_bert_embed=ParaformerDecoderBertEmbed,
|
||||
),
|
||||
type_check=AbsDecoder,
|
||||
default='rnn',
|
||||
)
|
||||
|
||||
predictor_choices = ClassChoices(
|
||||
name='predictor',
|
||||
classes=dict(
|
||||
cif_predictor=cif_predictor,
|
||||
ctc_predictor=None,
|
||||
),
|
||||
type_check=None,
|
||||
default='cif_predictor',
|
||||
optional=True,
|
||||
)
|
||||
|
||||
|
||||
class ASRTask(AbsTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(description='Task related')

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as follows:
        required = parser.get_default('required')
        required += ['token_list']

        group.add_argument(
            '--token_list',
            type=str_or_none,
            default=None,
            help='A text mapping int-id to token',
        )
        group.add_argument(
            '--init',
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help='The initialization method',
            choices=[
                'chainer',
                'xavier_uniform',
                'xavier_normal',
                'kaiming_uniform',
                'kaiming_normal',
                None,
            ],
        )

        group.add_argument(
            '--input_size',
            type=int_or_none,
            default=None,
            help='The number of input dimension of the feature',
        )

        group.add_argument(
            '--ctc_conf',
            action=NestedDictAction,
            default=get_default_kwargs(CTC),
            help='The keyword arguments for CTC class.',
        )
        group.add_argument(
            '--joint_net_conf',
            action=NestedDictAction,
            default=None,
            help='The keyword arguments for joint network class.',
        )

        group = parser.add_argument_group(description='Preprocess related')
        group.add_argument(
            '--use_preprocessor',
            type=str2bool,
            default=True,
            help='Apply preprocessing to data or not',
        )
        group.add_argument(
            '--token_type',
            type=str,
            default='bpe',
            choices=['bpe', 'char', 'word', 'phn'],
            help='The text will be tokenized '
            'at the specified token level',
        )
        group.add_argument(
            '--bpemodel',
            type=str_or_none,
            default=None,
            help='The model file of sentencepiece',
        )
        parser.add_argument(
            '--non_linguistic_symbols',
            type=str_or_none,
            help='non_linguistic_symbols file path',
        )
        parser.add_argument(
            '--cleaner',
            type=str_or_none,
            choices=[None, 'tacotron', 'jaconv', 'vietnamese'],
            default=None,
            help='Apply text cleaning',
        )
        parser.add_argument(
            '--g2p',
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help='Specify g2p method if --token_type=phn',
        )
        parser.add_argument(
            '--speech_volume_normalize',
            type=float_or_none,
            default=None,
            help='Scale the maximum amplitude to the given value.',
        )
        parser.add_argument(
            '--rir_scp',
            type=str_or_none,
            default=None,
            help='The file path of rir scp file.',
        )
        parser.add_argument(
            '--rir_apply_prob',
            type=float,
            default=1.0,
            help='The probability for applying RIR convolution.',
        )
        parser.add_argument(
            '--noise_scp',
            type=str_or_none,
            default=None,
            help='The file path of noise scp file.',
        )
        parser.add_argument(
            '--noise_apply_prob',
            type=float,
            default=1.0,
            help='The probability of applying noise addition.',
        )
        parser.add_argument(
            '--noise_db_range',
            type=str,
            default='13_15',
            help='The range of noise decibel level.',
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[
            List[str], Dict[str, torch.Tensor]], ]:
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                # NOTE(kamo): Check attribute existence for backward compatibility
                rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None,
                rir_apply_prob=args.rir_apply_prob if hasattr(
                    args, 'rir_apply_prob') else 1.0,
                noise_scp=args.noise_scp
                if hasattr(args, 'noise_scp') else None,
                noise_apply_prob=args.noise_apply_prob if hasattr(
                    args, 'noise_apply_prob') else 1.0,
                noise_db_range=args.noise_db_range if hasattr(
                    args, 'noise_db_range') else '13_15',
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, 'rir_scp') else None,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(cls,
                            train: bool = True,
                            inference: bool = False) -> Tuple[str, ...]:
        if not inference:
            retval = ('speech', 'text')
        else:
            # Recognition mode
            retval = ('speech', )
        return retval

    @classmethod
    def optional_data_names(cls,
                            train: bool = True,
                            inference: bool = False) -> Tuple[str, ...]:
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel:
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding='utf-8') as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError('token_list must be str or list')
        vocab_size = len(token_list)
        logger.info(f'Vocabulary size: {vocab_size}')

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, 'preencoder', None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 4. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 5. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, 'postencoder', None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf)
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 5. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)

        if args.decoder == 'transducer':
            decoder = decoder_class(
                vocab_size,
                embed_pad=0,
                **args.decoder_conf,
            )

            joint_network = JointNetwork(
                vocab_size,
                encoder.output_size(),
                decoder.dunits,
                **args.joint_net_conf,
            )
        else:
            decoder = decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **args.decoder_conf,
            )

            joint_network = None

        # 6. CTC
        ctc = CTC(
            odim=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.ctc_conf)

        # 7. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class('espnet')
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            joint_network=joint_network,
            token_list=token_list,
            **args.model_conf,
        )

        # FIXME(kamo): Should be done in model?
        # 8. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model

class ASRTaskNAR(AbsTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(description='Task related')

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as follows:
        required = parser.get_default('required')
        required += ['token_list']

        group.add_argument(
            '--token_list',
            type=str_or_none,
            default=None,
            help='A text mapping int-id to token',
        )
        group.add_argument(
            '--init',
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help='The initialization method',
            choices=[
                'chainer',
                'xavier_uniform',
                'xavier_normal',
                'kaiming_uniform',
                'kaiming_normal',
                None,
            ],
        )

        group.add_argument(
            '--input_size',
            type=int_or_none,
            default=None,
            help='The number of input dimension of the feature',
        )

        group.add_argument(
            '--ctc_conf',
            action=NestedDictAction,
            default=get_default_kwargs(CTC),
            help='The keyword arguments for CTC class.',
        )
        group.add_argument(
            '--joint_net_conf',
            action=NestedDictAction,
            default=None,
            help='The keyword arguments for joint network class.',
        )

        group = parser.add_argument_group(description='Preprocess related')
        group.add_argument(
            '--use_preprocessor',
            type=str2bool,
            default=True,
            help='Apply preprocessing to data or not',
        )
        group.add_argument(
            '--token_type',
            type=str,
            default='bpe',
            choices=['bpe', 'char', 'word', 'phn'],
            help='The text will be tokenized '
            'at the specified token level',
        )
        group.add_argument(
            '--bpemodel',
            type=str_or_none,
            default=None,
            help='The model file of sentencepiece',
        )
        parser.add_argument(
            '--non_linguistic_symbols',
            type=str_or_none,
            help='non_linguistic_symbols file path',
        )
        parser.add_argument(
            '--cleaner',
            type=str_or_none,
            choices=[None, 'tacotron', 'jaconv', 'vietnamese'],
            default=None,
            help='Apply text cleaning',
        )
        parser.add_argument(
            '--g2p',
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help='Specify g2p method if --token_type=phn',
        )
        parser.add_argument(
            '--speech_volume_normalize',
            type=float_or_none,
            default=None,
            help='Scale the maximum amplitude to the given value.',
        )
        parser.add_argument(
            '--rir_scp',
            type=str_or_none,
            default=None,
            help='The file path of rir scp file.',
        )
        parser.add_argument(
            '--rir_apply_prob',
            type=float,
            default=1.0,
            help='The probability for applying RIR convolution.',
        )
        parser.add_argument(
            '--noise_scp',
            type=str_or_none,
            default=None,
            help='The file path of noise scp file.',
        )
        parser.add_argument(
            '--noise_apply_prob',
            type=float,
            default=1.0,
            help='The probability of applying noise addition.',
        )
        parser.add_argument(
            '--noise_db_range',
            type=str,
            default='13_15',
            help='The range of noise decibel level.',
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[[Collection[Tuple[str, Dict[str, np.ndarray]]]], Tuple[
            List[str], Dict[str, torch.Tensor]], ]:
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                # NOTE(kamo): Check attribute existence for backward compatibility
                rir_scp=args.rir_scp if hasattr(args, 'rir_scp') else None,
                rir_apply_prob=args.rir_apply_prob if hasattr(
                    args, 'rir_apply_prob') else 1.0,
                noise_scp=args.noise_scp
                if hasattr(args, 'noise_scp') else None,
                noise_apply_prob=args.noise_apply_prob if hasattr(
                    args, 'noise_apply_prob') else 1.0,
                noise_db_range=args.noise_db_range if hasattr(
                    args, 'noise_db_range') else '13_15',
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, 'rir_scp') else None,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(cls,
                            train: bool = True,
                            inference: bool = False) -> Tuple[str, ...]:
        if not inference:
            retval = ('speech', 'text')
        else:
            # Recognition mode
            retval = ('speech', )
        return retval

    @classmethod
    def optional_data_names(cls,
                            train: bool = True,
                            inference: bool = False) -> Tuple[str, ...]:
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding='utf-8') as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError('token_list must be str or list')
        vocab_size = len(token_list)
        # logger.info(f'Vocabulary size: {vocab_size}')

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, 'preencoder', None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 4. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 5. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, 'postencoder', None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf)
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 5. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)

        if args.decoder == 'transducer':
            decoder = decoder_class(
                vocab_size,
                embed_pad=0,
                **args.decoder_conf,
            )

            joint_network = JointNetwork(
                vocab_size,
                encoder.output_size(),
                decoder.dunits,
                **args.joint_net_conf,
            )
        else:
            decoder = decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **args.decoder_conf,
            )

            joint_network = None

        # 6. CTC
        ctc = CTC(
            odim=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.ctc_conf)

        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 7. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class('espnet')
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            joint_network=joint_network,
            token_list=token_list,
            predictor=predictor,
            **args.model_conf,
        )

        # FIXME(kamo): Should be done in model?
        # 8. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
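Both task classes assemble the full recognizer from an argparse.Namespace that carries the training configuration. A minimal sketch of how ASRTask.build_model is typically driven, assuming an ESPnet-style config.yaml produced at training time (the file name here is illustrative, not part of this diff):

    import argparse

    import yaml

    with open('config.yaml', 'r', encoding='utf-8') as f:
        train_args = argparse.Namespace(**yaml.safe_load(f))

    # Assembles frontend, specaug, normalize, encoder, decoder and CTC as in build_model above.
    model = ASRTask.build_model(train_args)
    model.eval()
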
217
modelscope/pipelines/audio/asr/asr_inference_pipeline.py
Normal file
@@ -0,0 +1,217 @@
import io
import os
import shutil
import threading
from typing import Any, Dict, List, Sequence, Tuple, Union

import yaml

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import WavToScp
from modelscope.utils.constant import Tasks
from .asr_engine import asr_env_checking, asr_inference_paraformer_espnet
from .asr_engine.common import asr_utils

__all__ = ['AutomaticSpeechRecognitionPipeline']


@PIPELINES.register_module(
    Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference)
class AutomaticSpeechRecognitionPipeline(Pipeline):
    """ASR Pipeline
    """

    def __init__(self,
                 model: Union[List[Model], List[str]] = None,
                 preprocessor: WavToScp = None,
                 **kwargs):
        """use `model` and `preprocessor` to create an asr pipeline for prediction
        """

        assert model is not None, 'asr model should be provided'

        model_list: List = []
        if isinstance(model[0], Model):
            model_list = model
        else:
            model_list.append(Model.from_pretrained(model[0]))
            if len(model) == 2 and model[1] is not None:
                model_list.append(Model.from_pretrained(model[1]))

        super().__init__(model=model_list, preprocessor=preprocessor, **kwargs)

        self._preprocessor = preprocessor
        self._am_model = model_list[0]
        if len(model_list) == 2 and model_list[1] is not None:
            self._lm_model = model_list[1]

    def __call__(self,
                 wav_path: str,
                 recog_type: str = None,
                 audio_format: str = None,
                 workspace: str = None) -> Dict[str, Any]:
        assert len(wav_path) > 0, 'wav_path should be provided'

        self._recog_type = recog_type
        self._audio_format = audio_format
        self._workspace = workspace
        self._wav_path = wav_path

        if recog_type is None or audio_format is None or workspace is None:
            self._recog_type, self._audio_format, self._workspace, self._wav_path = asr_utils.type_checking(
                wav_path, recog_type, audio_format, workspace)

        if self._preprocessor is None:
            self._preprocessor = WavToScp(workspace=self._workspace)

        output = self._preprocessor.forward(self._am_model.forward(),
                                            self._recog_type,
                                            self._audio_format, self._wav_path)
        output = self.forward(output)
        rst = self.postprocess(output)
        return rst

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Decoding
        """

        j: int = 0
        process = []

        while j < inputs['thread_count']:
            data_cmd: Sequence[Tuple[str, str, str]]
            if inputs['audio_format'] == 'wav':
                data_cmd = [(os.path.join(inputs['workspace'],
                                          'data.' + str(j) + '.scp'), 'speech',
                             'sound')]
            elif inputs['audio_format'] == 'kaldi_ark':
                data_cmd = [(os.path.join(inputs['workspace'],
                                          'data.' + str(j) + '.scp'), 'speech',
                             'kaldi_ark')]

            output_dir: str = os.path.join(inputs['output'],
                                           'output.' + str(j))
            if not os.path.exists(output_dir):
                os.mkdir(output_dir)

            config_file = open(inputs['asr_model_config'])
            root = yaml.full_load(config_file)
            config_file.close()
            frontend_conf = None
            if 'frontend_conf' in root:
                frontend_conf = root['frontend_conf']

            cmd = {
                'model_type': inputs['model_type'],
                'beam_size': root['beam_size'],
                'penalty': root['penalty'],
                'maxlenratio': root['maxlenratio'],
                'minlenratio': root['minlenratio'],
                'ctc_weight': root['ctc_weight'],
                'lm_weight': root['lm_weight'],
                'output_dir': output_dir,
                'ngpu': 0,
                'log_level': 'ERROR',
                'data_path_and_name_and_type': data_cmd,
                'asr_train_config': inputs['am_model_config'],
                'asr_model_file': inputs['am_model_path'],
                'batch_size': inputs['model_config']['batch_size'],
                'frontend_conf': frontend_conf
            }

            thread = AsrInferenceThread(j, cmd)
            thread.start()
            j += 1
            process.append(thread)

        for p in process:
            p.join()

        return inputs

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """process the asr results
        """

        rst = {'rec_result': 'None'}

        # single wav task
        if inputs['recog_type'] == 'wav' and inputs['audio_format'] == 'wav':
            text_file: str = os.path.join(inputs['output'], 'output.0',
                                          '1best_recog', 'text')

            if os.path.exists(text_file):
                f = open(text_file, 'r')
                result_str: str = f.readline()
                f.close()
                if len(result_str) > 0:
                    result_list = result_str.split()
                    if len(result_list) >= 2:
                        rst['rec_result'] = result_list[1]

        # run with datasets, and audio format is waveform or kaldi_ark
        elif inputs['recog_type'] != 'wav':
            inputs['reference_text'] = self._ref_text_tidy(inputs)
            inputs['datasets_result'] = asr_utils.compute_wer(
                inputs['hypothesis_text'], inputs['reference_text'])

        else:
            raise ValueError('recog_type and audio_format are mismatching')

        if 'datasets_result' in inputs:
            rst['datasets_result'] = inputs['datasets_result']

        # remove workspace dir (.tmp)
        if os.path.exists(self._workspace):
            shutil.rmtree(self._workspace)

        return rst

    def _ref_text_tidy(self, inputs: Dict[str, Any]) -> str:
        ref_text: str = os.path.join(inputs['output'], 'text.ref')
        k: int = 0

        while k < inputs['thread_count']:
            output_text = os.path.join(inputs['output'], 'output.' + str(k),
                                       '1best_recog', 'text')
            if os.path.exists(output_text):
                with open(output_text, 'r', encoding='utf-8') as i:
                    lines = i.readlines()

                with open(ref_text, 'a', encoding='utf-8') as o:
                    for line in lines:
                        o.write(line)

            k += 1

        return ref_text


class AsrInferenceThread(threading.Thread):

    def __init__(self, threadID, cmd):
        threading.Thread.__init__(self)
        self._threadID = threadID
        self._cmd = cmd

    def run(self):
        if self._cmd['model_type'] == 'pytorch':
            asr_inference_paraformer_espnet.asr_inference(
                batch_size=self._cmd['batch_size'],
                output_dir=self._cmd['output_dir'],
                maxlenratio=self._cmd['maxlenratio'],
                minlenratio=self._cmd['minlenratio'],
                beam_size=self._cmd['beam_size'],
                ngpu=self._cmd['ngpu'],
                ctc_weight=self._cmd['ctc_weight'],
                lm_weight=self._cmd['lm_weight'],
                penalty=self._cmd['penalty'],
                log_level=self._cmd['log_level'],
                data_path_and_name_and_type=self.
                _cmd['data_path_and_name_and_type'],
                asr_train_config=self._cmd['asr_train_config'],
                asr_model_file=self._cmd['asr_model_file'],
                frontend_conf=self._cmd['frontend_conf'])
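Because the pipeline is registered under Tasks.auto_speech_recognition, it can be constructed through the generic pipeline() factory. A minimal usage sketch; the model id and wav file below are the ones used by the unit test added later in this diff:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    asr = pipeline(
        task=Tasks.auto_speech_recognition,
        model=['damo/speech_paraformer_asr_nat-aishell1-pytorch'])
    rec = asr('data/test/audios/asr_example.wav')
    print(rec['rec_result'])  # recognized text, or 'None' if decoding produced nothing
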
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from modelscope.utils.error import AUDIO_IMPORT_ERROR, TENSORFLOW_IMPORT_ERROR
from .asr import WavToScp
from .base import Preprocessor
from .builder import PREPROCESSORS, build_preprocessor
from .common import Compose

254
modelscope/preprocessors/asr.py
Normal file
@@ -0,0 +1,254 @@
import io
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List

import yaml

from modelscope.metainfo import Preprocessors
from modelscope.models.base import Model
from modelscope.utils.constant import Fields
from .base import Preprocessor
from .builder import PREPROCESSORS

__all__ = ['WavToScp']


@PREPROCESSORS.register_module(
    Fields.audio, module_name=Preprocessors.wav_to_scp)
class WavToScp(Preprocessor):
    """generate audio scp from wave or ark

    Args:
        workspace (str): temporary working directory; defaults to <cwd>/.tmp when empty
    """

    def __init__(self, workspace: str = None):
        # the workspace path
        if workspace is None or len(workspace) == 0:
            self._workspace = os.path.join(os.getcwd(), '.tmp')
        else:
            self._workspace = workspace

        if not os.path.exists(self._workspace):
            os.mkdir(self._workspace)

    def __call__(self,
                 model: List[Model] = None,
                 recog_type: str = None,
                 audio_format: str = None,
                 wav_path: str = None) -> Dict[str, Any]:
        assert len(model) > 0, 'preprocess model is invalid'
        assert len(recog_type) > 0, 'preprocess recog_type is empty'
        assert len(audio_format) > 0, 'preprocess audio_format is empty'
        assert len(wav_path) > 0, 'preprocess wav_path is empty'

        self._am_model = model[0]
        if len(model) == 2 and model[1] is not None:
            self._lm_model = model[1]
        out = self.forward(self._am_model.forward(), recog_type, audio_format,
                           wav_path)
        return out

    def forward(self, model: Dict[str, Any], recog_type: str,
                audio_format: str, wav_path: str) -> Dict[str, Any]:
        assert len(recog_type) > 0, 'preprocess recog_type is empty'
        assert len(audio_format) > 0, 'preprocess audio_format is empty'
        assert len(wav_path) > 0, 'preprocess wav_path is empty'
        assert os.path.exists(wav_path), 'preprocess wav_path does not exist'
        assert len(
            model['am_model']) > 0, 'preprocess model[am_model] is empty'
        assert len(model['am_model_path']
                   ) > 0, 'preprocess model[am_model_path] is empty'
        assert os.path.exists(
            model['am_model_path']), 'preprocess am_model_path does not exist'
        assert len(model['model_workspace']
                   ) > 0, 'preprocess model[model_workspace] is empty'
        assert os.path.exists(model['model_workspace']
                              ), 'preprocess model_workspace does not exist'
        assert len(model['model_config']
                   ) > 0, 'preprocess model[model_config] is empty'

        # the am model name
        am_model: str = model['am_model']
        # the am model file path
        am_model_path: str = model['am_model_path']
        # the recognition model dir path
        model_workspace: str = model['model_workspace']
        # the recognition model config dict
        global_model_config_dict: Dict[str, Any] = model['model_config']

        rst = {
            'workspace': os.path.join(self._workspace, recog_type),
            'am_model': am_model,
            'am_model_path': am_model_path,
            'model_workspace': model_workspace,
            # the asr type setting, eg: test dev train wav
            'recog_type': recog_type,
            # the asr audio format setting, eg: wav, kaldi_ark
            'audio_format': audio_format,
            # the test wav file path or the dataset path
            'wav_path': wav_path,
            'model_config': global_model_config_dict
        }

        out = self._config_checking(rst)
        out = self._env_setting(out)
        if audio_format == 'wav':
            out = self._scp_generation_from_wav(out)
        elif audio_format == 'kaldi_ark':
            out = self._scp_generation_from_ark(out)

        return out

    def _config_checking(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """config checking
        """

        assert inputs['model_config'].__contains__(
            'type'), 'model type does not exist'
        assert inputs['model_config'].__contains__(
            'batch_size'), 'batch_size does not exist'
        assert inputs['model_config'].__contains__(
            'am_model_config'), 'am_model_config does not exist'
        assert inputs['model_config'].__contains__(
            'asr_model_config'), 'asr_model_config does not exist'
        assert inputs['model_config'].__contains__(
            'asr_model_wav_config'), 'asr_model_wav_config does not exist'

        am_model_config: str = os.path.join(
            inputs['model_workspace'],
            inputs['model_config']['am_model_config'])
        assert os.path.exists(
            am_model_config), 'am_model_config does not exist'
        inputs['am_model_config'] = am_model_config

        asr_model_config: str = os.path.join(
            inputs['model_workspace'],
            inputs['model_config']['asr_model_config'])
        assert os.path.exists(
            asr_model_config), 'asr_model_config does not exist'

        asr_model_wav_config: str = os.path.join(
            inputs['model_workspace'],
            inputs['model_config']['asr_model_wav_config'])
        assert os.path.exists(
            asr_model_wav_config), 'asr_model_wav_config does not exist'

        inputs['model_type'] = inputs['model_config']['type']

        if inputs['audio_format'] == 'wav':
            inputs['asr_model_config'] = asr_model_wav_config
        else:
            inputs['asr_model_config'] = asr_model_config

        return inputs

    def _env_setting(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

        if not os.path.exists(inputs['workspace']):
            os.mkdir(inputs['workspace'])

        inputs['output'] = os.path.join(inputs['workspace'], 'logdir')
        if not os.path.exists(inputs['output']):
            os.mkdir(inputs['output'])

        # run with datasets, should set datasets_path and text_path
        if inputs['recog_type'] != 'wav':
            inputs['datasets_path'] = inputs['wav_path']

            # run with datasets, and audio format is waveform
            if inputs['audio_format'] == 'wav':
                inputs['wav_path'] = os.path.join(inputs['datasets_path'],
                                                  'wav', inputs['recog_type'])
                inputs['hypothesis_text'] = os.path.join(
                    inputs['datasets_path'], 'transcript', 'data.text')
                assert os.path.exists(inputs['hypothesis_text']
                                      ), 'hypothesis text does not exist'

            elif inputs['audio_format'] == 'kaldi_ark':
                inputs['wav_path'] = os.path.join(inputs['datasets_path'],
                                                  inputs['recog_type'])
                inputs['hypothesis_text'] = os.path.join(
                    inputs['wav_path'], 'data.text')
                assert os.path.exists(inputs['hypothesis_text']
                                      ), 'hypothesis text does not exist'

        return inputs

    def _scp_generation_from_wav(self, inputs: Dict[str,
                                                    Any]) -> Dict[str, Any]:
        """scp generation from waveform files
        """

        # find all waveform files
        wav_list = []
        if inputs['recog_type'] == 'wav':
            file_path = inputs['wav_path']
            if os.path.isfile(file_path):
                if file_path.endswith('.wav') or file_path.endswith('.WAV'):
                    wav_list.append(file_path)
        else:
            wav_dir: str = inputs['wav_path']
            wav_list = self._recursion_dir_all_wave(wav_list, wav_dir)

        list_count: int = len(wav_list)
        inputs['wav_count'] = list_count

        # store all wav into data.0.scp
        inputs['thread_count'] = 1
        j: int = 0
        wav_list_path = os.path.join(inputs['workspace'], 'data.0.scp')
        with open(wav_list_path, 'a') as f:
            while j < list_count:
                wav_file = wav_list[j]
                wave_scp_content: str = os.path.splitext(
                    os.path.basename(wav_file))[0]
                wave_scp_content += ' ' + wav_file + '\n'
                f.write(wave_scp_content)
                j += 1

        return inputs

    def _scp_generation_from_ark(self, inputs: Dict[str,
                                                    Any]) -> Dict[str, Any]:
        """scp generation from kaldi ark file
        """

        inputs['thread_count'] = 1
        ark_scp_path = os.path.join(inputs['wav_path'], 'data.scp')
        ark_file_path = os.path.join(inputs['wav_path'], 'data.ark')
        assert os.path.exists(ark_scp_path), 'data.scp does not exist'
        assert os.path.exists(ark_file_path), 'data.ark does not exist'

        new_ark_scp_path = os.path.join(inputs['workspace'], 'data.0.scp')

        with open(ark_scp_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        with open(new_ark_scp_path, 'w', encoding='utf-8') as n:
            for line in lines:
                outs = line.strip().split(' ')
                if len(outs) == 2:
                    key = outs[0]
                    sub = outs[1].split(':')
                    if len(sub) == 2:
                        nums = sub[1]
                        content = key + ' ' + ark_file_path + ':' + nums + '\n'
                        n.write(content)

        return inputs

    def _recursion_dir_all_wave(self, wav_list,
                                dir_path: str) -> List[str]:
        dir_files = os.listdir(dir_path)
        for file in dir_files:
            file_path = os.path.join(dir_path, file)
            if os.path.isfile(file_path):
                if file_path.endswith('.wav') or file_path.endswith('.WAV'):
                    wav_list.append(file_path)
            elif os.path.isdir(file_path):
                self._recursion_dir_all_wave(wav_list, file_path)

        return wav_list
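For reference, each line that WavToScp writes into data.0.scp is an '<utterance-id> <location>' pair: the wav branch points at the waveform file itself, while the kaldi_ark branch rewrites the original scp entry to reference the absolute data.ark path and byte offset. Illustrative lines only (the ids, paths and offset below are hypothetical):

    asr_example /path/to/dataset/wav/test/asr_example.wav        (wav branch)
    BAC009S0002W0122 /path/to/dataset/test/data.ark:17           (kaldi_ark branch)
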
@@ -1,3 +1,4 @@
espnet==202204
#tts
h5py
inflect

199
tests/pipelines/test_automatic_speech_recognition.py
Normal file
@@ -0,0 +1,199 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tarfile
import unittest

import requests

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

WAV_FILE = 'data/test/audios/asr_example.wav'

LITTLE_TESTSETS_FILE = 'data_aishell.tar.gz'
LITTLE_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/data_aishell.tar.gz'

AISHELL1_TESTSETS_FILE = 'aishell1.tar.gz'
AISHELL1_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/datasets/aishell1.tar.gz'


def un_tar_gz(fname, dirs):
    t = tarfile.open(fname)
    t.extractall(path=dirs)


class AutomaticSpeechRecognitionTest(unittest.TestCase):

    def setUp(self) -> None:
        self._am_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch'
        # this temporary workspace dir will store waveform files
        self._workspace = os.path.join(os.getcwd(), '.tmp')
        if not os.path.exists(self._workspace):
            os.mkdir(self._workspace)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav(self):
        '''run with a single waveform file
        '''

        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)

        inference_16k_pipline = pipeline(
            task=Tasks.auto_speech_recognition, model=[self._am_model_id])
        self.assertTrue(inference_16k_pipline is not None)

        rec_result = inference_16k_pipline(wav_file_path)
        self.assertTrue(len(rec_result['rec_result']) > 0)
        self.assertTrue(rec_result['rec_result'] != 'None')
        '''
        result structure:
        {
            'rec_result': '每一天都要快乐喔'
        }
        or
        {
            'rec_result': 'None'
        }
        '''
        print('test_run_with_wav rec result: ' + rec_result['rec_result'])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_wav_dataset(self):
        '''run with datasets, and audio format is waveform
        datasets directory:
            <dataset_path>
                wav
                    test        # testsets
                        xx.wav
                        ...
                    dev         # devsets
                        yy.wav
                        ...
                    train       # trainsets
                        zz.wav
                        ...
                transcript
                    data.text   # hypothesis text
        '''

        # download the testsets file
        testsets_file_path = os.path.join(self._workspace,
                                          LITTLE_TESTSETS_FILE)
        if not os.path.exists(testsets_file_path):
            r = requests.get(LITTLE_TESTSETS_URL)
            with open(testsets_file_path, 'wb') as f:
                f.write(r.content)

        testsets_dir_name = os.path.splitext(
            os.path.basename(
                os.path.splitext(
                    os.path.basename(LITTLE_TESTSETS_FILE))[0]))[0]
        # dataset_path = <cwd>/.tmp/data_aishell/wav/test
        dataset_path = os.path.join(self._workspace, testsets_dir_name, 'wav',
                                    'test')

        # untar the testsets file
        if not os.path.exists(dataset_path):
            un_tar_gz(testsets_file_path, self._workspace)

        inference_16k_pipline = pipeline(
            task=Tasks.auto_speech_recognition, model=[self._am_model_id])
        self.assertTrue(inference_16k_pipline is not None)

        rec_result = inference_16k_pipline(wav_path=dataset_path)
        self.assertTrue(len(rec_result['datasets_result']) > 0)
        self.assertTrue(rec_result['datasets_result']['Wrd'] > 0)
        '''
        result structure:
        {
            'rec_result': 'None',
            'datasets_result':
            {
                'Wrd': 1654,             # the number of words
                'Snt': 128,              # the number of sentences
                'Corr': 1573,            # the number of correct words
                'Ins': 1,                # the number of inserted words
                'Del': 1,                # the number of deleted words
                'Sub': 80,               # the number of substituted words
                'wrong_words': 82,       # the number of wrong words
                'wrong_sentences': 47,   # the number of wrong sentences
                'Err': 4.96,             # WER/CER
                'S.Err': 36.72           # SER
            }
        }
        '''
        print('test_run_with_wav_dataset datasets result: ')
        print(rec_result['datasets_result'])

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_ark_dataset(self):
        '''run with datasets, and audio format is kaldi_ark
        datasets directory:
            <dataset_path>
                test    # testsets
                    data.ark
                    data.scp
                    data.text
                dev     # devsets
                    data.ark
                    data.scp
                    data.text
                train   # trainsets
                    data.ark
                    data.scp
                    data.text
        '''

        # download the testsets file
        testsets_file_path = os.path.join(self._workspace,
                                          AISHELL1_TESTSETS_FILE)
        if not os.path.exists(testsets_file_path):
            r = requests.get(AISHELL1_TESTSETS_URL)
            with open(testsets_file_path, 'wb') as f:
                f.write(r.content)

        testsets_dir_name = os.path.splitext(
            os.path.basename(
                os.path.splitext(
                    os.path.basename(AISHELL1_TESTSETS_FILE))[0]))[0]
        # dataset_path = <cwd>/.tmp/aishell1/test
        dataset_path = os.path.join(self._workspace, testsets_dir_name, 'test')

        # untar the testsets file
        if not os.path.exists(dataset_path):
            un_tar_gz(testsets_file_path, self._workspace)

        inference_16k_pipline = pipeline(
            task=Tasks.auto_speech_recognition, model=[self._am_model_id])
        self.assertTrue(inference_16k_pipline is not None)

        rec_result = inference_16k_pipline(wav_path=dataset_path)
        self.assertTrue(len(rec_result['datasets_result']) > 0)
        self.assertTrue(rec_result['datasets_result']['Wrd'] > 0)
        '''
        result structure:
        {
            'rec_result': 'None',
            'datasets_result':
            {
                'Wrd': 104816,           # the number of words
                'Snt': 7176,             # the number of sentences
                'Corr': 99327,           # the number of correct words
                'Ins': 104,              # the number of inserted words
                'Del': 155,              # the number of deleted words
                'Sub': 5334,             # the number of substituted words
                'wrong_words': 5593,     # the number of wrong words
                'wrong_sentences': 2898, # the number of wrong sentences
                'Err': 5.34,             # WER/CER
                'S.Err': 40.38           # SER
            }
        }
        '''
        print('test_run_with_ark_dataset datasets result: ')
        print(rec_result['datasets_result'])


if __name__ == '__main__':
    unittest.main()
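The datasets_result figures quoted in the docstrings above follow the usual WER/SER bookkeeping and can be checked with a few lines of arithmetic (numbers taken from the first example; field names as in the docstring):

    wrd, snt = 1654, 128             # 'Wrd', 'Snt': total words and sentences
    ins, dele, sub = 1, 1, 80        # 'Ins', 'Del', 'Sub'
    wrong_words = ins + dele + sub   # 82  -> 'wrong_words'
    err = round(100.0 * wrong_words / wrd, 2)   # 4.96  -> 'Err' (WER/CER, %)
    s_err = round(100.0 * 47 / snt, 2)          # 36.72 -> 'S.Err' (SER, %), 47 wrong sentences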