Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-24 03:59:23 +01:00)
Split training and evaluating code for nearfield KWS trainer
* Fix judgment of FA cases for certain keywords in DET computation
* Split the code so that train and evaluate can be used independently
* Fix pre-commit errors

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11453810
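With this split, the trainer is built once and the training and evaluation stages are driven separately; the data lists are now passed to train() rather than to the trainer constructor (see the updated unit test below). A minimal usage sketch, assuming placeholder model id and paths, and an evaluate(checkpoint_path=...) signature inferred from the diff:

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

# build the trainer once; train/cv/trans data are no longer constructor arguments
kwargs = dict(
    model='damo/speech_charctc_kws_phone-xiaoyun',  # placeholder model id
    work_dir='./kws_work_dir',
    cfg_file='./config.json')
trainer = build_trainer(
    Trainers.speech_kws_fsmn_char_ctc_nearfield, default_args=kwargs)

# training stage: data lists go to train()
trainer.train(
    train_data='./train_wav.scp',
    cv_data='./cv_wav.scp',
    trans_data='./trans.txt')

# evaluation stage can now be run on its own; arguments here are illustrative
trainer.evaluate(checkpoint_path=None)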
Committed by: wenmeng.zwm
Parent: af62b3e9ad
Commit: e502e89c61
@@ -102,7 +102,7 @@ class KWSNearfieldTrainer(BaseTrainer):
os.makedirs(self.work_dir)
logger.info(f'Current working dir is {work_dir}')

# 2. prepare dataset and dataloader
# 2. prepare preset files
token_file = os.path.join(self.model_dir, 'train/tokens.txt')
assert os.path.exists(token_file), f'{token_file} is missing'
self.token_table = read_token(token_file)
@@ -111,6 +111,24 @@ class KWSNearfieldTrainer(BaseTrainer):
assert os.path.exists(lexicon_file), f'{lexicon_file} is missing'
self.lexicon_table = read_lexicon(lexicon_file)

feature_transform_file = os.path.join(
self.model_dir, 'train/feature_transform.txt.80dim-l2r2')
assert os.path.exists(feature_transform_file), \
f'{feature_transform_file} is missing'
configs.model['cmvn_file'] = feature_transform_file

# 3. write config.yaml for inference
self.configs = configs
if self.rank == 0:
if not os.path.exists(self.work_dir):
os.makedirs(self.work_dir)
saved_config_path = os.path.join(self.work_dir, 'config.yaml')
with open(saved_config_path, 'w') as fout:
data = yaml.dump(configs.to_dict())
fout.write(data)

def train(self, *args, **kwargs):
# 1. prepare dataset and dataloader
assert kwargs['train_data'], 'please config train data in dict kwargs'
assert kwargs['cv_data'], 'please config cv data in dict kwargs'
assert kwargs[
@@ -119,7 +137,7 @@ class KWSNearfieldTrainer(BaseTrainer):
self.cv_data = kwargs['cv_data']
self.trans_data = kwargs['trans_data']

train_conf = configs['preprocessor']
train_conf = self.configs['preprocessor']
cv_conf = copy.deepcopy(train_conf)
cv_conf['speed_perturb'] = False
cv_conf['spec_aug'] = False
@@ -137,31 +155,25 @@ class KWSNearfieldTrainer(BaseTrainer):
batch_size=None,
pin_memory=kwargs.get('pin_memory', False),
persistent_workers=True,
num_workers=configs.train.dataloader.workers_per_gpu,
prefetch_factor=configs.train.dataloader.get('prefetch', 2))
num_workers=self.configs.train.dataloader.workers_per_gpu,
prefetch_factor=self.configs.train.dataloader.get('prefetch', 2))
self.cv_dataloader = DataLoader(
self.cv_dataset,
batch_size=None,
pin_memory=kwargs.get('pin_memory', False),
persistent_workers=True,
num_workers=configs.evaluation.dataloader.workers_per_gpu,
prefetch_factor=configs.evaluation.dataloader.get('prefetch', 2))
num_workers=self.configs.evaluation.dataloader.workers_per_gpu,
prefetch_factor=self.configs.evaluation.dataloader.get(
'prefetch', 2))

# 3. build model, and load checkpoint
feature_transform_file = os.path.join(
self.model_dir, 'train/feature_transform.txt.80dim-l2r2')
assert os.path.exists(feature_transform_file), \
f'{feature_transform_file} is missing'
configs.model['cmvn_file'] = feature_transform_file

# 3.1 Init kws model from configs
self.model = self.build_model(configs)
# 2. Init kws model from configs
self.model = self.build_model(self.configs)
num_params = count_parameters(self.model)
if self.rank == 0:
# print(model)
logger.warning('the number of model params: {}'.format(num_params))

# 3.2 if specify checkpoint, load infos and params
# 3. if specify checkpoint, load infos and params
if self.checkpoint is not None and os.path.exists(self.checkpoint):
load_checkpoint(self.checkpoint, self.model)
info_path = re.sub('.pt$', '.yaml', self.checkpoint)
@@ -173,12 +185,13 @@ class KWSNearfieldTrainer(BaseTrainer):
logger.warning('Training with random initialized params')
infos = {}
self.start_epoch = infos.get('epoch', -1) + 1
configs['train']['start_epoch'] = self.start_epoch
self.configs['train']['start_epoch'] = self.start_epoch

lr_last_epoch = infos.get('lr', configs['train']['optimizer']['lr'])
configs['train']['optimizer']['lr'] = lr_last_epoch
lr_last_epoch = infos.get('lr',
self.configs['train']['optimizer']['lr'])
self.configs['train']['optimizer']['lr'] = lr_last_epoch

# 3.3 model placement
# 4. model placement
self.device_name = kwargs.get('device', 'gpu')
if self.world_size > 1:
self.device_name = f'cuda:{self.local_rank}'
@@ -192,17 +205,15 @@ class KWSNearfieldTrainer(BaseTrainer):
else:
self.model = self.model.to(self.device)

# 4. write config.yaml for inference and export
self.configs = configs
# 5. update training config file
if self.rank == 0:
if not os.path.exists(self.work_dir):
os.makedirs(self.work_dir)
saved_config_path = os.path.join(self.work_dir, 'config.yaml')
with open(saved_config_path, 'w') as fout:
data = yaml.dump(configs.to_dict())
data = yaml.dump(self.configs.to_dict())
fout.write(data)

def train(self, *args, **kwargs):
logger.info('Start training...')

writer = None
@@ -301,7 +312,7 @@ class KWSNearfieldTrainer(BaseTrainer):
os.environ['CUDA_VISIBLE_DEVICES'] will be setted
'''
# 1. get checkpoint
if checkpoint_path is not None and checkpoint_path != '':
if checkpoint_path is not None and os.path.exists(checkpoint_path):
logger.warning(
f'evaluating with specific model: {checkpoint_path}')
eval_checkpoint = checkpoint_path
@@ -326,7 +337,8 @@ class KWSNearfieldTrainer(BaseTrainer):
self.avg_checkpoint,
self.work_dir,
)
logger.warning(f'average convert to kaldi: {kaldi_cvt}')
logger.warning(
f'average model convert to kaldi network: {kaldi_cvt}')

eval_checkpoint = self.avg_checkpoint
logger.warning(

@@ -15,12 +15,14 @@

import glob
import os
import threading

import json
import kaldiio
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np
import torchaudio
import torch

from modelscope.utils.logger import get_logger
from .file_utils import make_pair, read_lists
@@ -30,6 +32,51 @@ logger = get_logger()
font = fm.FontProperties(size=15)


class thread_wrapper(threading.Thread):

def __init__(self, func, args=()):
super(thread_wrapper, self).__init__()
self.func = func
self.args = args
self.result = []

def run(self):
self.result = self.func(*self.args)

def get_result(self):
try:
return self.result
except Exception:
return None


def count_duration(tid, data_lists):
results = []

for obj in data_lists:
assert 'key' in obj
assert 'wav' in obj
assert 'txt' in obj
# key = obj['key']
wav_file = obj['wav']
# txt = obj['txt']

try:
rate, waveform = kaldiio.load_mat(wav_file)
waveform = torch.tensor(waveform, dtype=torch.float32)
waveform = waveform.unsqueeze(0)
frames = len(waveform[0])
duration = frames / float(rate)
except Exception:
logging.info(f'load file failed: {wav_file}')
duration = 0.0

obj['duration'] = duration
results.append(obj)

return results


def load_data_and_score(keywords_list, data_file, trans_file, score_file):
# score_table: {uttid: [keywordlist]}
score_table = {}
@@ -54,6 +101,26 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file):
trans_lists = read_lists(trans_file)
data_lists = make_pair(wav_lists, trans_lists)

# count duration for each wave use multi-thread
num_workers = 8
start = 0
step = int(len(data_lists) / num_workers)
tasks = []
for idx in range(8):
if idx != num_workers - 1:
task = thread_wrapper(count_duration,
(idx, data_lists[start:start + step]))
else:
task = thread_wrapper(count_duration, (idx, data_lists[start:]))
task.start()
tasks.append(task)
start += step

duration_lists = []
for task in tasks:
task.join()
duration_lists += task.get_result()

# build empty structure for keyword-filler infos
keyword_filler_table = {}
for keyword in keywords_list:
@@ -63,35 +130,36 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file):
keyword_filler_table[keyword]['filler_table'] = {}
keyword_filler_table[keyword]['filler_duration'] = 0.0

for obj in data_lists:
for obj in duration_lists:
assert 'key' in obj
assert 'wav' in obj
assert 'txt' in obj
key = obj['key']
wav_file = obj['wav']
txt = obj['txt']
assert key in score_table
assert 'duration' in obj

waveform, rate = torchaudio.load(wav_file)
frames = len(waveform[0])
duration = frames / float(rate)
key = obj['key']
# wav_file = obj['wav']
txt = obj['txt']
duration = obj['duration']
assert key in score_table

for keyword in keywords_list:
if txt.find(keyword) != -1:
if keyword == score_table[key]['kw']:
keyword_filler_table[keyword]['keyword_table'].update(
{key: score_table[key]['confi']})
keyword_filler_table[keyword][
'keyword_duration'] += duration
else:
# uttrance detected but not match this keyword
keyword_filler_table[keyword]['keyword_table'].update(
{key: -1.0})
keyword_filler_table[keyword][
'keyword_duration'] += duration
keyword_filler_table[keyword]['keyword_duration'] += duration
else:
keyword_filler_table[keyword]['filler_table'].update(
{key: score_table[key]['confi']})
if keyword == score_table[key]['kw']:
keyword_filler_table[keyword]['filler_table'].update(
{key: score_table[key]['confi']})
else:
# uttrance if detected, which is not FA for this keyword
keyword_filler_table[keyword]['filler_table'].update(
{key: -1.0})
keyword_filler_table[keyword]['filler_duration'] += duration

return keyword_filler_table

@@ -84,14 +84,16 @@ class TestKwsNearfieldTrainer(unittest.TestCase):
kwargs = dict(
model=self.model_id,
work_dir=self.tmp_dir,
cfg_file=self.config_file,
train_data=self.train_scp,
cv_data=self.cv_scp,
trans_data=self.trans_file)
cfg_file=self.config_file)

trainer = build_trainer(
Trainers.speech_kws_fsmn_char_ctc_nearfield, default_args=kwargs)
trainer.train()

kwargs = dict(
train_data=self.train_scp,
cv_data=self.cv_scp,
trans_data=self.trans_file)
trainer.train(**kwargs)

rank, _ = get_dist_info()
if rank == 0: