Split training and evaluating code for nearfield kws trainer

* fix the false-alarm (FA) judgement for certain keywords in DET evaluation
* split the code so that training and evaluation can each be used on their own
* fix pre-commit errors

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11453810
This commit is contained in:
pengteng.spt
2023-01-31 09:43:19 +00:00
committed by wenmeng.zwm
parent af62b3e9ad
commit e502e89c61
3 changed files with 128 additions and 46 deletions

View File

@@ -102,7 +102,7 @@ class KWSNearfieldTrainer(BaseTrainer):
os.makedirs(self.work_dir)
logger.info(f'Current working dir is {work_dir}')
# 2. prepare dataset and dataloader
# 2. prepare preset files
token_file = os.path.join(self.model_dir, 'train/tokens.txt')
assert os.path.exists(token_file), f'{token_file} is missing'
self.token_table = read_token(token_file)
@@ -111,6 +111,24 @@ class KWSNearfieldTrainer(BaseTrainer):
assert os.path.exists(lexicon_file), f'{lexicon_file} is missing'
self.lexicon_table = read_lexicon(lexicon_file)
feature_transform_file = os.path.join(
self.model_dir, 'train/feature_transform.txt.80dim-l2r2')
assert os.path.exists(feature_transform_file), \
f'{feature_transform_file} is missing'
configs.model['cmvn_file'] = feature_transform_file
# 3. write config.yaml for inference
self.configs = configs
if self.rank == 0:
if not os.path.exists(self.work_dir):
os.makedirs(self.work_dir)
saved_config_path = os.path.join(self.work_dir, 'config.yaml')
with open(saved_config_path, 'w') as fout:
data = yaml.dump(configs.to_dict())
fout.write(data)
def train(self, *args, **kwargs):
# 1. prepare dataset and dataloader
assert kwargs['train_data'], 'please config train data in dict kwargs'
assert kwargs['cv_data'], 'please config cv data in dict kwargs'
assert kwargs[
@@ -119,7 +137,7 @@ class KWSNearfieldTrainer(BaseTrainer):
self.cv_data = kwargs['cv_data']
self.trans_data = kwargs['trans_data']
train_conf = configs['preprocessor']
train_conf = self.configs['preprocessor']
cv_conf = copy.deepcopy(train_conf)
cv_conf['speed_perturb'] = False
cv_conf['spec_aug'] = False
@@ -137,31 +155,25 @@ class KWSNearfieldTrainer(BaseTrainer):
batch_size=None,
pin_memory=kwargs.get('pin_memory', False),
persistent_workers=True,
num_workers=configs.train.dataloader.workers_per_gpu,
prefetch_factor=configs.train.dataloader.get('prefetch', 2))
num_workers=self.configs.train.dataloader.workers_per_gpu,
prefetch_factor=self.configs.train.dataloader.get('prefetch', 2))
self.cv_dataloader = DataLoader(
self.cv_dataset,
batch_size=None,
pin_memory=kwargs.get('pin_memory', False),
persistent_workers=True,
num_workers=configs.evaluation.dataloader.workers_per_gpu,
prefetch_factor=configs.evaluation.dataloader.get('prefetch', 2))
num_workers=self.configs.evaluation.dataloader.workers_per_gpu,
prefetch_factor=self.configs.evaluation.dataloader.get(
'prefetch', 2))
# 3. build model, and load checkpoint
feature_transform_file = os.path.join(
self.model_dir, 'train/feature_transform.txt.80dim-l2r2')
assert os.path.exists(feature_transform_file), \
f'{feature_transform_file} is missing'
configs.model['cmvn_file'] = feature_transform_file
# 3.1 Init kws model from configs
self.model = self.build_model(configs)
# 2. Init kws model from configs
self.model = self.build_model(self.configs)
num_params = count_parameters(self.model)
if self.rank == 0:
# print(model)
logger.warning('the number of model params: {}'.format(num_params))
# 3.2 if specify checkpoint, load infos and params
# 3. if specify checkpoint, load infos and params
if self.checkpoint is not None and os.path.exists(self.checkpoint):
load_checkpoint(self.checkpoint, self.model)
info_path = re.sub('.pt$', '.yaml', self.checkpoint)
@@ -173,12 +185,13 @@ class KWSNearfieldTrainer(BaseTrainer):
logger.warning('Training with random initialized params')
infos = {}
self.start_epoch = infos.get('epoch', -1) + 1
configs['train']['start_epoch'] = self.start_epoch
self.configs['train']['start_epoch'] = self.start_epoch
lr_last_epoch = infos.get('lr', configs['train']['optimizer']['lr'])
configs['train']['optimizer']['lr'] = lr_last_epoch
lr_last_epoch = infos.get('lr',
self.configs['train']['optimizer']['lr'])
self.configs['train']['optimizer']['lr'] = lr_last_epoch
# 3.3 model placement
# 4. model placement
self.device_name = kwargs.get('device', 'gpu')
if self.world_size > 1:
self.device_name = f'cuda:{self.local_rank}'
@@ -192,17 +205,15 @@ class KWSNearfieldTrainer(BaseTrainer):
else:
self.model = self.model.to(self.device)
# 4. write config.yaml for inference and export
self.configs = configs
# 5. update training config file
if self.rank == 0:
if not os.path.exists(self.work_dir):
os.makedirs(self.work_dir)
saved_config_path = os.path.join(self.work_dir, 'config.yaml')
with open(saved_config_path, 'w') as fout:
data = yaml.dump(configs.to_dict())
data = yaml.dump(self.configs.to_dict())
fout.write(data)
def train(self, *args, **kwargs):
logger.info('Start training...')
writer = None
@@ -301,7 +312,7 @@ class KWSNearfieldTrainer(BaseTrainer):
os.environ['CUDA_VISIBLE_DEVICES'] will be set
'''
# 1. get checkpoint
if checkpoint_path is not None and checkpoint_path != '':
if checkpoint_path is not None and os.path.exists(checkpoint_path):
logger.warning(
f'evaluating with specific model: {checkpoint_path}')
eval_checkpoint = checkpoint_path
@@ -326,7 +337,8 @@ class KWSNearfieldTrainer(BaseTrainer):
self.avg_checkpoint,
self.work_dir,
)
logger.warning(f'average convert to kaldi: {kaldi_cvt}')
logger.warning(
f'average model convert to kaldi network: {kaldi_cvt}')
eval_checkpoint = self.avg_checkpoint
logger.warning(

View File

@@ -15,12 +15,14 @@
import glob
import os
import threading
import json
import kaldiio
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np
import torchaudio
import torch
from modelscope.utils.logger import get_logger
from .file_utils import make_pair, read_lists
@@ -30,6 +32,51 @@ logger = get_logger()
font = fm.FontProperties(size=15)
class thread_wrapper(threading.Thread):
    """Thread that runs ``func(*args)`` and keeps the value it returns.

    Args:
        func: Callable executed by ``run`` when the thread is started.
        args: Positional arguments unpacked into ``func``.
    """

    def __init__(self, func, args=()):
        super().__init__()
        self.func = func
        self.args = args
        # Default value seen by callers until ``run`` has stored a result.
        self.result = []

    def run(self):
        """Call ``func`` with the stored arguments and store its return value."""
        self.result = self.func(*self.args)

    def get_result(self):
        """Return the value stored by ``run``, or the initial ``[]`` if none was stored."""
        # ``self.result`` is always assigned in ``__init__``, so a plain
        # attribute read cannot fail and needs no exception handling.
        return self.result
def count_duration(tid, data_lists):
    """Attach a 'duration' entry to every item of ``data_lists``.

    Args:
        tid: Worker index supplied by the caller; not used inside this function.
        data_lists: Iterable of dicts, each containing 'key', 'wav' and 'txt'.
            The 'wav' value is handed to ``kaldiio.load_mat``.

    Returns:
        list: The same dict objects (updated in place), each with a
        'duration' value computed as frame count divided by sample rate,
        or 0.0 when the file could not be loaded.

    Raises:
        AssertionError: If an item lacks 'key', 'wav' or 'txt'.
    """
    results = []
    for obj in data_lists:
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        wav_file = obj['wav']

        try:
            rate, waveform = kaldiio.load_mat(wav_file)
            # Only the length along the first axis is needed, so no tensor
            # copy of the audio data is made.
            duration = len(waveform) / float(rate)
        except Exception:
            # Best effort: an unreadable file contributes zero duration.
            # Uses the module-level ``logger``; ``logging`` is not imported
            # in this module, so the old call failed inside this handler.
            logger.warning(f'load file failed: {wav_file}')
            duration = 0.0

        obj['duration'] = duration
        results.append(obj)
    return results
def load_data_and_score(keywords_list, data_file, trans_file, score_file):
# score_table: {uttid: [keywordlist]}
score_table = {}
@@ -54,6 +101,26 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file):
trans_lists = read_lists(trans_file)
data_lists = make_pair(wav_lists, trans_lists)
# count duration for each wave use multi-thread
num_workers = 8
start = 0
step = int(len(data_lists) / num_workers)
tasks = []
for idx in range(8):
if idx != num_workers - 1:
task = thread_wrapper(count_duration,
(idx, data_lists[start:start + step]))
else:
task = thread_wrapper(count_duration, (idx, data_lists[start:]))
task.start()
tasks.append(task)
start += step
duration_lists = []
for task in tasks:
task.join()
duration_lists += task.get_result()
# build empty structure for keyword-filler infos
keyword_filler_table = {}
for keyword in keywords_list:
@@ -63,35 +130,36 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file):
keyword_filler_table[keyword]['filler_table'] = {}
keyword_filler_table[keyword]['filler_duration'] = 0.0
for obj in data_lists:
for obj in duration_lists:
assert 'key' in obj
assert 'wav' in obj
assert 'txt' in obj
key = obj['key']
wav_file = obj['wav']
txt = obj['txt']
assert key in score_table
assert 'duration' in obj
waveform, rate = torchaudio.load(wav_file)
frames = len(waveform[0])
duration = frames / float(rate)
key = obj['key']
# wav_file = obj['wav']
txt = obj['txt']
duration = obj['duration']
assert key in score_table
for keyword in keywords_list:
if txt.find(keyword) != -1:
if keyword == score_table[key]['kw']:
keyword_filler_table[keyword]['keyword_table'].update(
{key: score_table[key]['confi']})
keyword_filler_table[keyword][
'keyword_duration'] += duration
else:
# utterance contains this keyword, but the detection did not match it
keyword_filler_table[keyword]['keyword_table'].update(
{key: -1.0})
keyword_filler_table[keyword][
'keyword_duration'] += duration
keyword_filler_table[keyword]['keyword_duration'] += duration
else:
keyword_filler_table[keyword]['filler_table'].update(
{key: score_table[key]['confi']})
if keyword == score_table[key]['kw']:
keyword_filler_table[keyword]['filler_table'].update(
{key: score_table[key]['confi']})
else:
# utterance was not detected as this keyword, so it is not a false alarm (FA) for it
keyword_filler_table[keyword]['filler_table'].update(
{key: -1.0})
keyword_filler_table[keyword]['filler_duration'] += duration
return keyword_filler_table

View File

@@ -84,14 +84,16 @@ class TestKwsNearfieldTrainer(unittest.TestCase):
kwargs = dict(
model=self.model_id,
work_dir=self.tmp_dir,
cfg_file=self.config_file,
train_data=self.train_scp,
cv_data=self.cv_scp,
trans_data=self.trans_file)
cfg_file=self.config_file)
trainer = build_trainer(
Trainers.speech_kws_fsmn_char_ctc_nearfield, default_args=kwargs)
trainer.train()
kwargs = dict(
train_data=self.train_scp,
cv_data=self.cv_scp,
trans_data=self.trans_file)
trainer.train(**kwargs)
rank, _ = get_dist_info()
if rank == 0: