Fix speech kws nearfield training with multi-gpu
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12117620
@@ -20,7 +20,8 @@ from torch.utils.data import IterableDataset

import modelscope.msdatasets.dataset_cls.custom_datasets.audio.kws_nearfield_processor as processor
from modelscope.trainers.audio.kws_utils.file_utils import (make_pair,
                                                            read_lists)
                                                            read_lists,
                                                            tokenize)
from modelscope.utils.logger import get_logger

logger = get_logger()
@@ -119,12 +120,41 @@ class DataList(IterableDataset):
                data.update(sampler_info)
                yield data

    def dump(self, dump_file):
        with open(dump_file, 'w', encoding='utf8') as fout:
            for obj in self.lists:
                if hasattr(obj, 'get') and obj.get('tokens', None) is not None:
                    assert 'key' in obj
                    assert 'wav' in obj
                    assert 'txt' in obj
                    assert len(obj['tokens']) == len(obj['txt'])
                    dump_line = obj['key'] + ':\n'
                    dump_line += '\t' + obj['wav'] + '\n'
                    dump_line += '\t'
                    for token, idx in zip(obj['tokens'], obj['txt']):
                        dump_line += '%s(%d) ' % (token, idx)
                    dump_line += '\n\n'
                    fout.write(dump_line)
                else:
                    infos = json.loads(obj)
                    assert 'key' in infos
                    assert 'wav' in infos
                    assert 'txt' in infos
                    dump_line = infos['key'] + ':\n'
                    dump_line += '\t' + infos['wav'] + '\n'
                    dump_line += '\t'
                    dump_line += '%d' % infos['txt']
                    dump_line += '\n\n'
                    fout.write(dump_line)
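For reference, each entry written by dump() is a small block: the utterance key, the wav path, then token(id) pairs. A sketch of one dumped entry (hypothetical key, path and tokens, not taken from the repository):

utt_0001:
	/data/wav/utt_0001.wav
	你(25) 好(102) hello(331)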
def kws_nearfield_dataset(data_file,
                          trans_file,
                          conf,
                          symbol_table,
                          lexicon_table,
                          need_dump=False,
                          dump_file='',
                          partition=True):
    """ Construct dataset from arguments

@@ -137,6 +167,8 @@ def kws_nearfield_dataset(data_file,
        trans_file (str): transcription list with kaldi style
        symbol_table (Dict): token list, [token_str, token_id]
        lexicon_table (Dict): words list defined with basic tokens
        need_dump (bool): whether to dump data with mapping tokens or not
        dump_file (str): dumping file path
        partition (bool): whether to do data partition in terms of rank
    """

@@ -146,14 +178,14 @@ def kws_nearfield_dataset(data_file,
    wav_lists = read_lists(data_file)
    trans_lists = read_lists(trans_file)
    lists = make_pair(wav_lists, trans_lists)
    lists = tokenize(lists, symbol_table, lexicon_table)

    shuffle = conf.get('shuffle', True)
    dataset = DataList(lists, shuffle=shuffle, partition=partition)
    if need_dump:
        dataset.dump(dump_file)

    dataset = Processor(dataset, processor.parse_wav)
    dataset = Processor(dataset, processor.tokenize, symbol_table,
                        lexicon_table, conf.get('split_with_space', False))

    dataset = Processor(dataset, processor.filter, **filter_conf)

    feature_extraction_conf = conf.get('feature_extraction_conf', {})
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random

import json
@@ -24,6 +23,9 @@ import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence

# torch.set_printoptions(profile="full")
from modelscope.utils.logger import get_logger

logger = get_logger()


def parse_wav(data):
@@ -53,54 +55,7 @@ def parse_wav(data):
                    key=key, label=txt, wav=waveform, sample_rate=sample_rate)
                yield example
            except Exception:
                logging.warning('Failed to read {}'.format(wav_file))


def tokenize(data, token_table, lexicon_table, split_with_space=False):
    """ Decode text to chars
        Inplace operation

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]
            token_table (Dict): token list, [token_str, token_id]
            lexicon_table (Dict): words list defined with basic tokens
            split_with_space (bool): if transcription split with space or not

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    for sample in data:
        assert 'label' in sample
        txt = sample['label'].strip()

        if token_table is None or lexicon_table is None:
            # to be compatible with hard token map for max-pooling loss
            label = int(txt)
        else:
            parts = [txt]
            tokens = []
            for part in parts:
                if split_with_space:
                    part = part.split(' ')
                for ch in part:
                    if ch == ' ':
                        ch = '▁'
                    tokens.append(ch)

            label = []
            for ch in tokens:
                if ch in lexicon_table:
                    for sub_ch in lexicon_table[ch]:
                        if sub_ch in token_table:
                            label.append(token_table[sub_ch])
                        else:
                            label.append(token_table['<blk>'])
                else:
                    label.append(token_table['<blk>'])

        sample['tokens'] = tokens
        sample['label'] = label
        yield sample
                logger.warning('Failed to read {}'.format(wav_file))


def filter(data, max_length=10240, min_length=10):
@@ -128,11 +83,11 @@ def filter(data, max_length=10240, min_length=10):

        # print("{} num frames is {}".format(sample['key'], num_frames))
        if num_frames < min_length:
            logging.warning('{} is discard for too short: {} frames'.format(
            logger.warning('{} is discard for too short: {} frames'.format(
                sample['key'], num_frames))
            continue
        if num_frames > max_length:
            logging.warning('{} is discard for too long: {} frames'.format(
            logger.warning('{} is discard for too long: {} frames'.format(
                sample['key'], num_frames))
            continue
        yield sample
@@ -6,6 +6,7 @@ import re
from typing import Callable, Dict, Optional

import torch
import torch.distributed as dist
import yaml
from tensorboardX import SummaryWriter
from torch import nn as nn
@@ -23,11 +24,10 @@ from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.device import create_device
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
                                          init_dist, set_random_seed)
from modelscope.utils.torch_utils import set_random_seed
from .kws_utils.batch_utils import executor_cv, executor_test, executor_train
from .kws_utils.det_utils import compute_det
from .kws_utils.file_utils import query_tokens_id, read_lexicon, read_token
from .kws_utils.file_utils import query_token_set, read_lexicon, read_token
from .kws_utils.model_utils import (average_model, convert_to_kaldi,
                                    count_parameters)
@@ -47,14 +47,11 @@ class KWSNearfieldTrainer(BaseTrainer):
                 **kwargs):
        '''
        Args:
            work_dir (str): main directory for training
            model (str): model id in modelscope
            work_dir (str): main directory for training and evaluating
            cfg_file (str): config file for training and evaluating
            kwargs:
                checkpoint (str): basemodel checkpoint, if None, default to use base.pt in model path
                train_data (int): wave list with kaldi style for training
                cv_data (int): wave list with kaldi style for cross validation
                trans_data (str): transcription list with kaldi style, merge train and cv
                tensorboard_dir (str): path to save tensorboard results,
                    create 'tensorboard_dir' in work_dir by default
                seed (int): random seed
        '''
        if isinstance(model, str):
            self.model_dir = self.get_or_download_model_dir(
@@ -69,20 +66,12 @@ class KWSNearfieldTrainer(BaseTrainer):
        super().__init__(cfg_file, arg_parse_fn)
        configs = Config.from_file(cfg_file)

        print(kwargs)
        self.launcher = 'pytorch'
        self.dist_backend = configs.train.get('dist_backend', 'nccl')
        self.tensorboard_dir = kwargs.get('tensorboard_dir', 'tensorboard')
        self.checkpoint = kwargs.get(
            'checkpoint', os.path.join(self.model_dir, 'train/base.pt'))
        self.avg_checkpoint = None

        # 1. get rank info
        set_random_seed(kwargs.get('seed', 666))
        self.get_dist_info()
        logger.info('RANK {}/{}/{}, Master addr:{}, Master port:{}'.format(
            self.world_size, self.rank, self.local_rank, self.master_addr,
            self.master_port))
        self.init_dist()

        self.work_dir = work_dir
        if self.rank == 0:
@@ -116,6 +105,24 @@ class KWSNearfieldTrainer(BaseTrainer):
                fout.write(data)

    def train(self, *args, **kwargs):
        '''
        Args:
            kwargs:
                train_data (int): wave list with kaldi style for training
                cv_data (int): wave list with kaldi style for cross validation
                trans_data (str): transcription list with kaldi style, merge train and cv
                checkpoint (str): basemodel checkpoint, if None, default to use base.pt in model path
                tensorboard_dir (str): path to save tensorboard results,
                    create 'tensorboard_dir' in work_dir by default
                need_dump (bool): whether to dump data with mapping tokens or not
        '''
        train_checkpoint = kwargs.get('checkpoint', None)
        if train_checkpoint is not None and os.path.exists(train_checkpoint):
            self.checkpoint = train_checkpoint
        else:
            self.checkpoint = os.path.join(self.model_dir, 'train/base.pt')
        self.tensorboard_dir = kwargs.get('tensorboard_dir', 'tensorboard')

        # 1. prepare dataset and dataloader
        assert kwargs['train_data'], 'please config train data in dict kwargs'
        assert kwargs['cv_data'], 'please config cv data in dict kwargs'
@@ -124,32 +131,36 @@ class KWSNearfieldTrainer(BaseTrainer):
        self.train_data = kwargs['train_data']
        self.cv_data = kwargs['cv_data']
        self.trans_data = kwargs['trans_data']
        self.need_dump = kwargs.get(
            'need_dump', False) and (True if self.rank == 0 else False)

        train_conf = self.configs['preprocessor']
        cv_conf = copy.deepcopy(train_conf)
        cv_conf['speed_perturb'] = False
        cv_conf['spec_aug'] = False
        cv_conf['shuffle'] = False
        self.train_dataset = kws_nearfield_dataset(self.train_data,
                                                   self.trans_data, train_conf,
                                                   self.token_table,
                                                   self.lexicon_table, True)

        dump_train_file = os.path.join(self.work_dir, 'dump_train.txt')
        dump_cv_file = os.path.join(self.work_dir, 'dump_cv.txt')
        self.train_dataset = kws_nearfield_dataset(
            self.train_data, self.trans_data, train_conf, self.token_table,
            self.lexicon_table, self.need_dump, dump_train_file, True)
        self.cv_dataset = kws_nearfield_dataset(self.cv_data, self.trans_data,
                                                cv_conf, self.token_table,
                                                self.lexicon_table, True)
                                                self.lexicon_table,
                                                self.need_dump, dump_cv_file,
                                                True)

        self.train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=None,
            pin_memory=kwargs.get('pin_memory', False),
            persistent_workers=True,
            num_workers=self.configs.train.dataloader.workers_per_gpu,
            prefetch_factor=self.configs.train.dataloader.get('prefetch', 2))
        self.cv_dataloader = DataLoader(
            self.cv_dataset,
            batch_size=None,
            pin_memory=kwargs.get('pin_memory', False),
            persistent_workers=True,
            num_workers=self.configs.evaluation.dataloader.workers_per_gpu,
            prefetch_factor=self.configs.evaluation.dataloader.get(
                'prefetch', 2))
@@ -180,18 +191,18 @@ class KWSNearfieldTrainer(BaseTrainer):
            self.configs['train']['optimizer']['lr'] = lr_last_epoch

        # 4. model placement
        self.device_name = kwargs.get('device', 'gpu')
        device_name = kwargs.get('device', 'gpu')
        if self.world_size > 1:
            self.device_name = f'cuda:{self.local_rank}'
        self.device = create_device(self.device_name)
            device_name = f'cuda:{self.local_rank}'
        self.train_device = create_device(device_name)

        if self.world_size > 1:
            assert (torch.cuda.is_available())
            # cuda model is required for nn.parallel.DistributedDataParallel
            self.model.cuda()
            self.model = self.model.to(self.train_device)
            self.model = torch.nn.parallel.DistributedDataParallel(self.model)
        else:
            self.model = self.model.to(self.device)
            self.model = self.model.to(self.train_device)

        # 5. update training config file
        if self.rank == 0:
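A minimal sketch of the per-rank placement pattern used in the hunk above (assuming torch.distributed is already initialized; names are illustrative, not from the repo):

import torch

def place_model(model, local_rank, world_size):
    if world_size > 1:
        device = torch.device(f'cuda:{local_rank}')
        # each rank keeps its own replica; DDP all-reduces gradients during backward()
        model = torch.nn.parallel.DistributedDataParallel(model.to(device))
    else:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
    return model, device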
@@ -230,7 +241,7 @@ class KWSNearfieldTrainer(BaseTrainer):
        if self.start_epoch == 0 and self.rank == 0:
            save_model_path = os.path.join(self.work_dir, 'init.pt')
            save_checkpoint(self.model, save_model_path, None, None, None,
                            False)
                            False, True)

        # Start training loop
        logger.info('Start training...')
@@ -250,17 +261,18 @@ class KWSNearfieldTrainer(BaseTrainer):
            lr = optimizer.param_groups[0]['lr']
            logger.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
            executor_train(self.model, optimizer, self.train_dataloader,
                           self.device, writer, training_config)
            cv_loss = executor_cv(self.model, self.cv_dataloader, self.device,
                                  training_config)
            logger.info('Epoch {} EVAL info cv_loss {:.6f}'.format(
                epoch, cv_loss))
                           self.train_device, writer, training_config)
            cv_loss, cv_acc = executor_cv(self.model, self.cv_dataloader,
                                          self.train_device, training_config)
            logger.info(
                'Epoch {} EVAL info cv_loss {:.6f}, cv_acc {:.2f}'.format(
                    epoch, cv_loss, cv_acc))

            if self.rank == 0:
                save_model_path = os.path.join(self.work_dir,
                                               '{}.pt'.format(epoch))
                save_checkpoint(self.model, save_model_path, None, None, None,
                                False)
                                False, True)

                info_path = re.sub('.pt$', '.yaml', save_model_path)
                info_dict = dict(
@@ -274,6 +286,7 @@ class KWSNearfieldTrainer(BaseTrainer):

                writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
                writer.add_scalar('epoch/lr', lr, epoch)

            final_epoch = epoch
            lr_scheduler.step(cv_loss)
@@ -296,17 +309,18 @@ class KWSNearfieldTrainer(BaseTrainer):
            average_num (int): the NO. to do model averaging(checkpoint_path==None)
            batch_size (int): batch size during evaluating
            keywords (str): keyword string, split with ','
            gpu (int): evaluating with cpu/gpu: -1 for cpu; >=0 for gpu,
                os.environ['CUDA_VISIBLE_DEVICES'] will be set
            gpu (int): evaluating with cpu/gpu: -1 for cpu; >=0 for gpu
        '''

        # 1. get checkpoint
        self.avg_checkpoint = None
        if checkpoint_path is not None and os.path.exists(checkpoint_path):
            logger.warning(
                f'evaluating with specific model: {checkpoint_path}')
            eval_checkpoint = checkpoint_path
        else:
            if self.avg_checkpoint is None:
                avg_num = kwargs.get('average_num', 5)
                avg_num = kwargs.get('average_num', 10)
                self.avg_checkpoint = os.path.join(self.work_dir,
                                                   f'avg_{avg_num}.pt')
                logger.warning(
@@ -353,13 +367,13 @@ class KWSNearfieldTrainer(BaseTrainer):
        test_conf['speed_perturb'] = False
        test_conf['spec_aug'] = False
        test_conf['shuffle'] = False
        test_conf['feature_extraction_conf']['dither'] = 0.0
        if kwargs.get('batch_size', None) is not None:
            test_conf['batch_conf']['batch_size'] = kwargs['batch_size']

        test_dataset = kws_nearfield_dataset(test_data, trans_data, test_conf,
                                             self.token_table,
                                             self.lexicon_table, False)
                                             self.lexicon_table, False, '',
                                             False)
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=None,
@@ -375,33 +389,39 @@ class KWSNearfieldTrainer(BaseTrainer):
        keywords_str = kwargs['keywords']
        keywords_list = keywords_str.strip().replace(' ', '').split(',')
        keywords_token = {}
        keywords_tokenset = {0}
        keywords_idxset = {0}
        keywords_strset = {'<blk>'}
        keywords_tokenmap = {'<blk>': 0}
        for keyword in keywords_list:
            ids = query_tokens_id(keyword, self.token_table,
                                  self.lexicon_table)
            strs, indexes = query_token_set(keyword, self.token_table,
                                            self.lexicon_table)
            keywords_token[keyword] = {}
            keywords_token[keyword]['token_id'] = ids
            keywords_token[keyword]['token_id'] = indexes
            keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
                                                           for i in ids)
            [keywords_tokenset.add(i) for i in ids]
        logger.warning(f'Token set is: {keywords_tokenset}')
                                                           for i in indexes)
            [keywords_strset.add(i) for i in strs]
            [keywords_idxset.add(i) for i in indexes]
            for txt, idx in zip(strs, indexes):
                if keywords_tokenmap.get(txt, None) is None:
                    keywords_tokenmap[txt] = idx

        token_print = ''
        for txt, idx in keywords_tokenmap.items():
            token_print += f'{txt}({idx}) '
        logger.warning(f'Token set is: {token_print}')

        # 5. build model and load checkpoint
        # support assign specific gpu device
        os.environ['CUDA_VISIBLE_DEVICES'] = str(kwargs.get('gpu', -1))
        # Init kws model from configs
        use_cuda = kwargs.get('gpu', -1) >= 0 and torch.cuda.is_available()
        device_name = kwargs.get('device', 'cpu')
        if self.world_size > 1 and use_cuda:
            device_name = f'cuda:{self.local_rank}'
        self.test_device = create_device(device_name)

        if kwargs.get('jit_model', None):
            model = torch.jit.load(eval_checkpoint)
            # For script model, only cpu is supported.
            device = torch.device('cpu')
        else:
            # Init kws model from configs
            model = self.build_model(self.configs)
            load_checkpoint(eval_checkpoint, model)
            device = torch.device('cuda' if use_cuda else 'cpu')
        model = model.to(device)
        model.eval()
        self.test_model = self.build_model(self.configs)
        load_checkpoint(eval_checkpoint, self.test_model)
        self.test_model = self.test_model.to(self.test_device)

        testing_config = {}
        if kwargs.get('test_dir', None) is not None:
@@ -417,9 +437,9 @@ class KWSNearfieldTrainer(BaseTrainer):
        # 6. executing evaluation and get score file
        logger.info('Start evaluating...')
        totaltime = datetime.datetime.now()
        score_file = executor_test(model, test_dataloader, device,
                                   keywords_token, keywords_tokenset,
                                   testing_config)
        score_file = executor_test(self.test_model, test_dataloader,
                                   self.test_device, keywords_token,
                                   keywords_idxset, testing_config)
        totaltime = datetime.datetime.now() - totaltime
        logger.info('Total time spent: {:.2f} hours'.format(
            totaltime.total_seconds() / 3600.0))
@@ -448,7 +468,7 @@ class KWSNearfieldTrainer(BaseTrainer):
        elif isinstance(model, nn.Module):
            return model

    def get_dist_info(self):
    def init_dist(self, train_nodes=1):
        if os.getenv('RANK', None) is None:
            os.environ['RANK'] = '0'
        if os.getenv('LOCAL_RANK', None) is None:
@@ -466,6 +486,28 @@ class KWSNearfieldTrainer(BaseTrainer):
        self.master_addr = os.environ['MASTER_ADDR']
        self.master_port = os.environ['MASTER_PORT']

        init_dist(self.launcher, self.dist_backend)
        self.rank, self.world_size = get_dist_info()
        self.local_rank = get_local_rank()
        if train_nodes == 1:
            if self.world_size > 1:
                logger.info('init dist on multiple gpus, this gpu {}'.format(
                    self.local_rank))
                dist.init_process_group(
                    backend=self.dist_backend, init_method='env://')
        elif train_nodes > 1:
            dist.init_process_group(
                backend=self.dist_backend, init_method='env://')
            dist.barrier()

        logger.info('RANK {}/{}/{}, Master addr:{}, Master port:{}'.format(
            self.world_size, self.rank, self.local_rank, self.master_addr,
            self.master_port))

    def uninit_dist(self, train_nodes=1):
        if train_nodes == 1:
            if self.world_size > 1:
                logger.info(
                    'destory dist on multiple gpus, this gpu {}'.format(
                        self.local_rank))
                dist.destroy_process_group()
        elif train_nodes > 1:
            dist.barrier()
            dist.destroy_process_group()
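init_dist() above relies on the standard env:// rendezvous, so the environment must already carry the rank information (in a real multi-GPU run a launcher such as torchrun provides it). A stripped-down sketch of that handshake, with single-process defaults of my own choosing:

import os
import torch.distributed as dist

def minimal_env_init(backend='gloo'):
    # the trainer uses its configured dist_backend (nccl by default) on GPUs
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend=backend, init_method='env://')
    dist.barrier()  # make sure every rank reaches this point before training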
@@ -65,7 +65,8 @@ def executor_train(model, optimizer, data_loader, device, writer, args):
        if num_utts == 0:
            continue
        logits, _ = model(feats)
        loss = ctc_loss(logits, target, feats_lengths, target_lengths)
        loss, acc = ctc_loss(logits, target, feats_lengths, target_lengths)
        loss = loss / num_utts
        optimizer.zero_grad()
        loss.backward()
        grad_norm = clip_grad_norm_(model.parameters(), clip)
@@ -90,11 +91,12 @@ def executor_cv(model, data_loader, device, args):
    epoch = args.get('epoch', 0)
    # in order to avoid division by 0
    num_seen_utts = 1
    num_seen_tokens = 1
    total_loss = 0.0
    # [For distributed] Because iteration counts are not always equal between
    # processes, send stop-flag to the other processes if iterator is finished
    iterator_stop = torch.tensor(0).to(device)
    counter = torch.zeros((3, ), device=device)
    counter = torch.zeros((4, ), device=device)

    rank = args.get('rank', 0)
    local_rank = args.get('local_rank', 0)
@@ -117,18 +119,25 @@ def executor_cv(model, data_loader, device, args):
            if num_utts == 0:
                continue
            logits, _ = model(feats)
            loss = ctc_loss(logits, target, feats_lengths, target_lengths)

            loss, acc = ctc_loss(logits, target, feats_lengths, target_lengths,
                                 True)
            if torch.isfinite(loss):
                num_seen_utts += num_utts
                total_loss += loss.item() * num_utts
                counter[0] += loss.item() * num_utts
                counter[1] += num_utts
                num_seen_tokens += target_lengths.sum()
                total_loss += loss.item()
                counter[0] += loss.item()
                counter[1] += acc * target_lengths.sum()
                counter[2] += num_utts
                counter[3] += target_lengths.sum()

            if batch_idx % log_interval == 0:
                logger.info(
                    'RANK {}/{}/{} CV Batch {}/{} size {} loss {:.6f} history loss {:.6f}'
                    'RANK {}/{}/{} CV Batch {}/{} size {} loss {:.6f} acc {:.2f} history loss {:.6f}'
                    .format(world_size, rank, local_rank, epoch, batch_idx,
                            num_utts, loss.item(), total_loss / num_seen_utts))
                            num_utts,
                            loss.item() / num_utts, acc,
                            total_loss / num_seen_utts))
        else:
            iterator_stop.fill_(1)
            if world_size > 1:
@@ -136,14 +145,15 @@ def executor_cv(model, data_loader, device, args):

    if world_size > 1:
        dist.all_reduce(counter, ReduceOp.SUM)
        logger.info('Total utts number is {}'.format(counter[1]))
        logger.info('Total utts number is {}'.format(counter[2]))
        counter = counter.to('cpu')

    return counter[0].item() / counter[1].item()
    return counter[0].item() / counter[2].item(), counter[1].item(
    ) / counter[2].item()
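The widened counter packs everything that must be summed across ranks into one tensor, so a single all_reduce suffices. A plain restatement of how the reduced values become the returned pair (the helper name is mine, not from the repo):

def summarize_cv_counter(counter):
    # counter[0]: summed batch losses     counter[1]: summed acc * token count
    # counter[2]: number of utterances    counter[3]: number of tokens
    total_loss, weighted_acc, num_utts, _num_tokens = counter.tolist()
    cv_loss = total_loss / num_utts
    cv_acc = weighted_acc / num_utts
    return cv_loss, cv_acc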
def executor_test(model, data_loader, device, keywords_token,
                  keywords_tokenset, args):
def executor_test(model, data_loader, device, keywords_token, keywords_idxset,
                  args):
    ''' Test model with decoder
    '''
    assert args.get('test_dir', None) is not None, \
@@ -151,6 +161,7 @@ def executor_test(model, data_loader, device, keywords_token,
    score_abs_path = os.path.join(args['test_dir'], 'score.txt')
    log_interval = args.get('log_interval', 10)

    model.eval()
    infer_seconds = 0.0
    decode_seconds = 0.0
    with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
@@ -175,8 +186,7 @@ def executor_test(model, data_loader, device, keywords_token,
                key = keys[i]
                score = logits[i][:feats_lengths[i]]
                hyps = ctc_prefix_beam_search(score, feats_lengths[i],
                                              keywords_tokenset)

                                              keywords_idxset)
                hit_keyword = None
                hit_score = 1.0
                # start = 0; end = 0
@@ -243,8 +253,11 @@ def is_sublist(main_list, check_list):
    return -1


def ctc_loss(logits: torch.Tensor, target: torch.Tensor,
             logits_lengths: torch.Tensor, target_lengths: torch.Tensor):
def ctc_loss(logits: torch.Tensor,
             target: torch.Tensor,
             logits_lengths: torch.Tensor,
             target_lengths: torch.Tensor,
             need_acc: bool = False):
    """ CTC Loss
    Args:
        logits: (B, D), D is the number of keywords plus 1 (non-keyword)
@@ -255,14 +268,51 @@ def ctc_loss(logits: torch.Tensor, target: torch.Tensor,
        (float): loss of current batch
    """

    acc = 0.0
    if need_acc:
        acc = acc_utterance(logits, target, logits_lengths, target_lengths)

    # logits: (B, L, D) -> (L, B, D)
    logits = logits.transpose(0, 1)
    logits = logits.log_softmax(2)
    loss = F.ctc_loss(
        logits, target, logits_lengths, target_lengths, reduction='sum')
    loss = loss / logits.size(1)
    # loss = loss / logits.size(1)

    return loss
    return loss, acc
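For clarity, a standalone sketch of the tensor shapes the updated ctc_loss works with (sizes are made up; only the transpose/log_softmax/F.ctc_loss steps mirror the code above, the final division now lives in the executor):

import torch
import torch.nn.functional as F

B, L, D = 2, 50, 6                       # batch, input frames, vocab incl. blank
logits = torch.randn(B, L, D)            # batch-first model output
target = torch.randint(1, D, (B, 8))     # label ids, 0 is the blank
logits_lengths = torch.full((B, ), L, dtype=torch.long)
target_lengths = torch.full((B, ), 8, dtype=torch.long)

log_probs = logits.transpose(0, 1).log_softmax(2)   # (L, B, D) as F.ctc_loss expects
loss = F.ctc_loss(log_probs, target, logits_lengths, target_lengths, reduction='sum')
loss = loss / B                          # per-utterance normalization done by the caller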
def acc_utterance(logits: torch.Tensor, target: torch.Tensor,
                  logits_length: torch.Tensor, target_length: torch.Tensor):
    if logits is None:
        return 0

    logits = logits.softmax(2)  # (1, maxlen, vocab_size)
    logits = logits.cpu()
    target = target.cpu()

    total_word = 0
    total_ins = 0
    total_sub = 0
    total_del = 0
    calculator = Calculator()
    for i in range(logits.size(0)):
        score = logits[i][:logits_length[i]]
        hyps = ctc_prefix_beam_search(score, logits_length[i], None, 3, 5)
        lab = [str(item) for item in target[i][:target_length[i]].tolist()]
        rec = []
        if len(hyps) > 0:
            rec = [str(item) for item in hyps[0][0]]
        result = calculator.calculate(lab, rec)
        # print(f'result:{result}')
        if result['all'] != 0:
            total_word += result['all']
            total_ins += result['ins']
            total_sub += result['sub']
            total_del += result['del']

    return float(total_word - total_ins - total_sub
                 - total_del) * 100.0 / total_word


def ctc_prefix_beam_search(
@@ -304,9 +354,14 @@ def ctc_prefix_beam_search(
        filter_probs = []
        filter_index = []
        for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
            if prob > 0.05 and idx in keywords_tokenset:
                filter_probs.append(prob)
                filter_index.append(idx)
            if keywords_tokenset is not None:
                if prob > 0.05 and idx in keywords_tokenset:
                    filter_probs.append(prob)
                    filter_index.append(idx)
            else:
                if prob > 0.05:
                    filter_probs.append(prob)
                    filter_index.append(idx)

        if len(filter_index) == 0:
            continue
@@ -363,3 +418,161 @@ def ctc_prefix_beam_search(

    hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
    return hyps


class Calculator:

    def __init__(self):
        self.data = {}
        self.space = []
        self.cost = {}
        self.cost['cor'] = 0
        self.cost['sub'] = 1
        self.cost['del'] = 1
        self.cost['ins'] = 1

    def calculate(self, lab, rec):
        # Initialization
        lab.insert(0, '')
        rec.insert(0, '')
        while len(self.space) < len(lab):
            self.space.append([])
        for row in self.space:
            for element in row:
                element['dist'] = 0
                element['error'] = 'non'
            while len(row) < len(rec):
                row.append({'dist': 0, 'error': 'non'})
        for i in range(len(lab)):
            self.space[i][0]['dist'] = i
            self.space[i][0]['error'] = 'del'
        for j in range(len(rec)):
            self.space[0][j]['dist'] = j
            self.space[0][j]['error'] = 'ins'
        self.space[0][0]['error'] = 'non'
        for token in lab:
            if token not in self.data and len(token) > 0:
                self.data[token] = {
                    'all': 0,
                    'cor': 0,
                    'sub': 0,
                    'ins': 0,
                    'del': 0
                }
        for token in rec:
            if token not in self.data and len(token) > 0:
                self.data[token] = {
                    'all': 0,
                    'cor': 0,
                    'sub': 0,
                    'ins': 0,
                    'del': 0
                }
        # Computing edit distance
        for i, lab_token in enumerate(lab):
            for j, rec_token in enumerate(rec):
                if i == 0 or j == 0:
                    continue
                min_dist = sys.maxsize
                min_error = 'none'
                dist = self.space[i - 1][j]['dist'] + self.cost['del']
                error = 'del'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                dist = self.space[i][j - 1]['dist'] + self.cost['ins']
                error = 'ins'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                if lab_token == rec_token:
                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
                    error = 'cor'
                else:
                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
                    error = 'sub'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                self.space[i][j]['dist'] = min_dist
                self.space[i][j]['error'] = min_error
        # Tracing back
        result = {
            'lab': [],
            'rec': [],
            'all': 0,
            'cor': 0,
            'sub': 0,
            'ins': 0,
            'del': 0
        }
        i = len(lab) - 1
        j = len(rec) - 1
        while True:
            if self.space[i][j]['error'] == 'cor':  # correct
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
                    result['all'] = result['all'] + 1
                    result['cor'] = result['cor'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, rec[j])
                i = i - 1
                j = j - 1
            elif self.space[i][j]['error'] == 'sub':  # substitution
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
                    result['all'] = result['all'] + 1
                    result['sub'] = result['sub'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, rec[j])
                i = i - 1
                j = j - 1
            elif self.space[i][j]['error'] == 'del':  # deletion
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
                    result['all'] = result['all'] + 1
                    result['del'] = result['del'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, '')
                i = i - 1
            elif self.space[i][j]['error'] == 'ins':  # insertion
                if len(rec[j]) > 0:
                    self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
                    result['ins'] = result['ins'] + 1
                result['lab'].insert(0, '')
                result['rec'].insert(0, rec[j])
                j = j - 1
            elif self.space[i][j]['error'] == 'non':  # starting point
                break
            else:  # shouldn't reach here
                print(
                    'this should not happen , i = {i} , j = {j} , error = {error}'
                    .format(i=i, j=j, error=self.space[i][j]['error']))
        return result

    def overall(self):
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in self.data:
            result['all'] = result['all'] + self.data[token]['all']
            result['cor'] = result['cor'] + self.data[token]['cor']
            result['sub'] = result['sub'] + self.data[token]['sub']
            result['ins'] = result['ins'] + self.data[token]['ins']
            result['del'] = result['del'] + self.data[token]['del']
        return result

    def cluster(self, data):
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in data:
            if token in self.data:
                result['all'] = result['all'] + self.data[token]['all']
                result['cor'] = result['cor'] + self.data[token]['cor']
                result['sub'] = result['sub'] + self.data[token]['sub']
                result['ins'] = result['ins'] + self.data[token]['ins']
                result['del'] = result['del'] + self.data[token]['del']
        return result

    def keys(self):
        return list(self.data.keys())
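A small worked example of the edit-distance bookkeeping above (inputs are mine; the numbers follow from tracing the code by hand):

calc = Calculator()
result = calc.calculate(['你', '好'], ['你', '坏'])
# -> {'all': 2, 'cor': 1, 'sub': 1, 'ins': 0, 'del': 0, 'lab': [...], 'rec': [...]}
# acc_utterance() then scores this as (all - ins - sub - del) * 100.0 / all = 50.0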
@@ -25,7 +25,7 @@ import numpy as np
import torch

from modelscope.utils.logger import get_logger
from .file_utils import make_pair, read_lists
from .file_utils import make_pair, read_lists, space_mixed_label

logger = get_logger()
@@ -68,7 +68,7 @@ def count_duration(tid, data_lists):
            frames = len(waveform[0])
            duration = frames / float(rate)
        except Exception:
            logging.info(f'load file failed: {wav_file}')
            logger.info(f'load file failed: {wav_file}')
            duration = 0.0

        obj['duration'] = duration
@@ -88,11 +88,12 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file):
            is_detected = arr[1]
            if is_detected == 'detected':
                if key not in score_table:
                    score_table.update(
                        {key: {
                            'kw': arr[2],
                    score_table.update({
                        key: {
                            'kw': space_mixed_label(arr[2]),
                            'confi': float(arr[3])
                        }})
                        }
                    })
            else:
                if key not in score_table:
                    score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})
@@ -100,13 +101,14 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file):
    wav_lists = read_lists(data_file)
    trans_lists = read_lists(trans_file)
    data_lists = make_pair(wav_lists, trans_lists)
    logger.info(f'origin list samples: {len(data_lists)}')

    # count duration for each wave using multiple threads
    num_workers = 8
    start = 0
    step = int(len(data_lists) / num_workers)
    tasks = []
    for idx in range(8):
    for idx in range(num_workers):
        if idx != num_workers - 1:
            task = thread_wrapper(count_duration,
                                  (idx, data_lists[start:start + step]))
@@ -120,10 +122,12 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file):
    for task in tasks:
        task.join()
        duration_lists += task.get_result()
    logger.info(f'after list samples: {len(duration_lists)}')

    # build empty structure for keyword-filler infos
    keyword_filler_table = {}
    for keyword in keywords_list:
        keyword = space_mixed_label(keyword)
        keyword_filler_table[keyword] = {}
        keyword_filler_table[keyword]['keyword_table'] = {}
        keyword_filler_table[keyword]['keyword_duration'] = 0.0
@@ -139,11 +143,15 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file):
        key = obj['key']
        # wav_file = obj['wav']
        txt = obj['txt']
        txt = space_mixed_label(txt)
        txt_regstr_lrblk = ' ' + txt + ' '
        duration = obj['duration']
        assert key in score_table

        for keyword in keywords_list:
            if txt.find(keyword) != -1:
            keyword = space_mixed_label(keyword)
            keyword_regstr_lrblk = ' ' + keyword + ' '
            if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
                if keyword == score_table[key]['kw']:
                    keyword_filler_table[keyword]['keyword_table'].update(
                        {key: score_table[key]['confi']})
@@ -203,12 +211,13 @@ def compute_det(**kwargs):

    score_step = kwargs.get('score_step', 0.001)

    keywords_list = keywords.replace(' ', '').strip().split(',')
    keywords_list = keywords.strip().split(',')
    keyword_filler_table = load_data_and_score(keywords_list, test_data,
                                               trans_data, score_file)

    stats_files = {}
    for keyword in keywords_list:
        keyword = space_mixed_label(keyword)
        keyword_dur = keyword_filler_table[keyword]['keyword_duration']
        keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
        filler_dur = keyword_filler_table[keyword]['filler_duration']
@@ -221,7 +230,8 @@ def compute_det(**kwargs):
            keyword_dur / 3600.0, keyword_num))
        logger.info(' Filler duration: {} Hours'.format(filler_dur / 3600.0))

        stats_file = os.path.join(stats_dir, 'stats_' + keyword + '.txt')
        stats_file = os.path.join(
            stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
        with open(stats_file, 'w', encoding='utf8') as fout:
            threshold = 0.0
            while threshold <= 1.0:
@@ -18,14 +18,35 @@ from modelscope.utils.logger import get_logger

logger = get_logger()

remove_str = ['!sil', '(noise)', '(noise', 'noise)', '·', '’']
symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'


def split_mixed_label(input_str):
    tokens = []
    s = input_str.lower()
    while len(s) > 0:
        match = re.match(r'[A-Za-z!?,<>()\']+', s)
        if match is not None:
            word = match.group(0)
        else:
            word = s[0:1]
        tokens.append(word)
        s = s.replace(word, '', 1).strip(' ')
    return tokens


def space_mixed_label(input_str):
    splits = split_mixed_label(input_str)
    space_str = ''.join(f'{sub} ' for sub in splits)
    return space_str.strip()
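As an illustration of the two helpers above (behaviour read off the regex; the example string is mine): English letter runs stay whole, every other character becomes its own token, and the result is lower-cased.

split_mixed_label('你好Hello world')   # -> ['你', '好', 'hello', 'world']
space_mixed_label('你好Hello world')   # -> '你 好 hello world'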
def read_lists(list_file):
    lists = []
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            lists.append(line.strip())
            if line.strip() != '':
                lists.append(line.strip())
    return lists
@@ -37,14 +58,7 @@ def make_pair(wav_lists, trans_lists):
            logger.debug('invalid line in trans file: {}'.format(line.strip()))
            continue

        trans_table[arr[0]] = line.replace(arr[0], '')\
            .replace(' ', '')\
            .replace('(noise)', '')\
            .replace('noise)', '')\
            .replace('(noise', '')\
            .replace('!sil', '')\
            .replace('·', '')\
            .replace('’', '').strip()
        trans_table[arr[0]] = line.replace(arr[0], '').strip()

    lists = []
    for line in wav_lists:
@@ -86,27 +100,110 @@ def read_lexicon(lexicon_file):
    return lexicon_table


def query_tokens_id(txt, symbol_table, lexicon_table):
    label = tuple()
    tokens = []
def query_token_set(txt, symbol_table, lexicon_table):
    tokens_str = tuple()
    tokens_idx = tuple()

    parts = [txt.replace(' ', '').strip()]
    parts = split_mixed_label(txt)
    for part in parts:
        for ch in part:
            if ch == ' ':
                ch = '▁'
            tokens.append(ch)

    for ch in tokens:
        if ch in symbol_table:
            label = label + (symbol_table[ch], )
        elif ch in lexicon_table:
            for sub_ch in lexicon_table[ch]:
                if sub_ch in symbol_table:
                    label = label + (symbol_table[sub_ch], )
                else:
                    label = label + (symbol_table['<blk>'], )
        if part == '!sil' or part == '(sil)' or part == '<sil>':
            tokens_str = tokens_str + ('!sil', )
        elif part == '<blk>' or part == '<blank>':
            tokens_str = tokens_str + ('<blk>', )
        elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>':
            tokens_str = tokens_str + ('<GBG>', )
        elif part in symbol_table:
            tokens_str = tokens_str + (part, )
        elif part in lexicon_table:
            for ch in lexicon_table[part]:
                tokens_str = tokens_str + (ch, )
        else:
            label = label + (symbol_table['<blk>'], )
            # case with symbols or meaningless english letter combination
            part = re.sub(symbol_str, '', part)
            for ch in part:
                tokens_str = tokens_str + (ch, )

    return label
    for ch in tokens_str:
        if ch in symbol_table:
            tokens_idx = tokens_idx + (symbol_table[ch], )
        elif ch == '!sil':
            if 'sil' in symbol_table:
                tokens_idx = tokens_idx + (symbol_table['sil'], )
            else:
                tokens_idx = tokens_idx + (symbol_table['<blk>'], )
        elif ch == '<GBG>':
            if '<GBG>' in symbol_table:
                tokens_idx = tokens_idx + (symbol_table['<GBG>'], )
            else:
                tokens_idx = tokens_idx + (symbol_table['<blk>'], )
        else:
            if '<GBG>' in symbol_table:
                tokens_idx = tokens_idx + (symbol_table['<GBG>'], )
                logger.info(
                    f'\'{ch}\' is not in token set, replace with <GBG>')
            else:
                tokens_idx = tokens_idx + (symbol_table['<blk>'], )
                logger.info(
                    f'\'{ch}\' is not in token set, replace with <blk>')

    return tokens_str, tokens_idx
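A usage sketch for the new lookup (the tables here are hypothetical; real ones come from the token and lexicon files loaded by read_token()/read_lexicon()):

symbol_table = {'<blk>': 0, '你': 25, '好': 102, 'hello': 331}
lexicon_table = {}
strs, idxs = query_token_set('你好hello', symbol_table, lexicon_table)
# strs -> ('你', '好', 'hello'), idxs -> (25, 102, 331)
# an out-of-vocabulary token would map to '<GBG>' if present in symbol_table, else '<blk>'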
def query_token_list(txt, symbol_table, lexicon_table):
    tokens_str = []
    tokens_idx = []

    parts = split_mixed_label(txt)
    for part in parts:
        if part == '!sil' or part == '(sil)' or part == '<sil>':
            tokens_str.append('!sil')
        elif part == '<blk>' or part == '<blank>':
            tokens_str.append('<blk>')
        elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>':
            tokens_str.append('<GBG>')
        elif part in symbol_table:
            tokens_str.append(part)
        elif part in lexicon_table:
            for ch in lexicon_table[part]:
                tokens_str.append(ch)
        else:
            # case with symbols or meaningless english letter combination
            part = re.sub(symbol_str, '', part)
            for ch in part:
                tokens_str.append(ch)

    for ch in tokens_str:
        if ch in symbol_table:
            tokens_idx.append(symbol_table[ch])
        elif ch == '!sil':
            if 'sil' in symbol_table:
                tokens_idx.append(symbol_table['sil'])
            else:
                tokens_idx.append(symbol_table['<blk>'])
        elif ch == '<GBG>':
            if '<GBG>' in symbol_table:
                tokens_idx.append(symbol_table['<GBG>'])
            else:
                tokens_idx.append(symbol_table['<blk>'])
        else:
            if '<GBG>' in symbol_table:
                tokens_idx.append(symbol_table['<GBG>'])
                logger.info(
                    f'\'{ch}\' is not in token set, replace with <GBG>')
            else:
                tokens_idx.append(symbol_table['<blk>'])
                logger.info(
                    f'\'{ch}\' is not in token set, replace with <blk>')

    return tokens_str, tokens_idx


def tokenize(data_list, symbol_table, lexicon_table):
    for sample in data_list:
        assert 'txt' in sample
        txt = sample['txt'].strip()
        strs, indexs = query_token_list(txt, symbol_table, lexicon_table)
        sample['tokens'] = strs
        sample['txt'] = indexs

    return data_list
@@ -100,6 +100,9 @@ def save_checkpoint(model: torch.nn.Module,
        checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

    if with_model:
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model = model.module

        _weights = weights_to_cpu(model.state_dict())
        if not with_meta:
            checkpoint = _weights
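The unwrap added above matters because DDP prefixes every parameter key with 'module.'; saving model.module keeps checkpoints loadable by a non-distributed model. A minimal sketch of the same idea in isolation:

import torch

def unwrap(model):
    # return the plain module so state_dict() keys carry no 'module.' prefix
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model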