diff --git a/modelscope/trainers/audio/kws_nearfield_trainer.py b/modelscope/trainers/audio/kws_nearfield_trainer.py index ba3f5f5f..bf00c435 100644 --- a/modelscope/trainers/audio/kws_nearfield_trainer.py +++ b/modelscope/trainers/audio/kws_nearfield_trainer.py @@ -102,7 +102,7 @@ class KWSNearfieldTrainer(BaseTrainer): os.makedirs(self.work_dir) logger.info(f'Current working dir is {work_dir}') - # 2. prepare dataset and dataloader + # 2. prepare preset files token_file = os.path.join(self.model_dir, 'train/tokens.txt') assert os.path.exists(token_file), f'{token_file} is missing' self.token_table = read_token(token_file) @@ -111,6 +111,24 @@ class KWSNearfieldTrainer(BaseTrainer): assert os.path.exists(lexicon_file), f'{lexicon_file} is missing' self.lexicon_table = read_lexicon(lexicon_file) + feature_transform_file = os.path.join( + self.model_dir, 'train/feature_transform.txt.80dim-l2r2') + assert os.path.exists(feature_transform_file), \ + f'{feature_transform_file} is missing' + configs.model['cmvn_file'] = feature_transform_file + + # 3. write config.yaml for inference + self.configs = configs + if self.rank == 0: + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + saved_config_path = os.path.join(self.work_dir, 'config.yaml') + with open(saved_config_path, 'w') as fout: + data = yaml.dump(configs.to_dict()) + fout.write(data) + + def train(self, *args, **kwargs): + # 1. prepare dataset and dataloader assert kwargs['train_data'], 'please config train data in dict kwargs' assert kwargs['cv_data'], 'please config cv data in dict kwargs' assert kwargs[ @@ -119,7 +137,7 @@ class KWSNearfieldTrainer(BaseTrainer): self.cv_data = kwargs['cv_data'] self.trans_data = kwargs['trans_data'] - train_conf = configs['preprocessor'] + train_conf = self.configs['preprocessor'] cv_conf = copy.deepcopy(train_conf) cv_conf['speed_perturb'] = False cv_conf['spec_aug'] = False @@ -137,31 +155,25 @@ class KWSNearfieldTrainer(BaseTrainer): batch_size=None, pin_memory=kwargs.get('pin_memory', False), persistent_workers=True, - num_workers=configs.train.dataloader.workers_per_gpu, - prefetch_factor=configs.train.dataloader.get('prefetch', 2)) + num_workers=self.configs.train.dataloader.workers_per_gpu, + prefetch_factor=self.configs.train.dataloader.get('prefetch', 2)) self.cv_dataloader = DataLoader( self.cv_dataset, batch_size=None, pin_memory=kwargs.get('pin_memory', False), persistent_workers=True, - num_workers=configs.evaluation.dataloader.workers_per_gpu, - prefetch_factor=configs.evaluation.dataloader.get('prefetch', 2)) + num_workers=self.configs.evaluation.dataloader.workers_per_gpu, + prefetch_factor=self.configs.evaluation.dataloader.get( + 'prefetch', 2)) - # 3. build model, and load checkpoint - feature_transform_file = os.path.join( - self.model_dir, 'train/feature_transform.txt.80dim-l2r2') - assert os.path.exists(feature_transform_file), \ - f'{feature_transform_file} is missing' - configs.model['cmvn_file'] = feature_transform_file - - # 3.1 Init kws model from configs - self.model = self.build_model(configs) + # 2. Init kws model from configs + self.model = self.build_model(self.configs) num_params = count_parameters(self.model) if self.rank == 0: # print(model) logger.warning('the number of model params: {}'.format(num_params)) - # 3.2 if specify checkpoint, load infos and params + # 3. if specify checkpoint, load infos and params if self.checkpoint is not None and os.path.exists(self.checkpoint): load_checkpoint(self.checkpoint, self.model) info_path = re.sub('.pt$', '.yaml', self.checkpoint) @@ -173,12 +185,13 @@ class KWSNearfieldTrainer(BaseTrainer): logger.warning('Training with random initialized params') infos = {} self.start_epoch = infos.get('epoch', -1) + 1 - configs['train']['start_epoch'] = self.start_epoch + self.configs['train']['start_epoch'] = self.start_epoch - lr_last_epoch = infos.get('lr', configs['train']['optimizer']['lr']) - configs['train']['optimizer']['lr'] = lr_last_epoch + lr_last_epoch = infos.get('lr', + self.configs['train']['optimizer']['lr']) + self.configs['train']['optimizer']['lr'] = lr_last_epoch - # 3.3 model placement + # 4. model placement self.device_name = kwargs.get('device', 'gpu') if self.world_size > 1: self.device_name = f'cuda:{self.local_rank}' @@ -192,17 +205,15 @@ class KWSNearfieldTrainer(BaseTrainer): else: self.model = self.model.to(self.device) - # 4. write config.yaml for inference and export - self.configs = configs + # 5. update training config file if self.rank == 0: if not os.path.exists(self.work_dir): os.makedirs(self.work_dir) saved_config_path = os.path.join(self.work_dir, 'config.yaml') with open(saved_config_path, 'w') as fout: - data = yaml.dump(configs.to_dict()) + data = yaml.dump(self.configs.to_dict()) fout.write(data) - def train(self, *args, **kwargs): logger.info('Start training...') writer = None @@ -301,7 +312,7 @@ class KWSNearfieldTrainer(BaseTrainer): os.environ['CUDA_VISIBLE_DEVICES'] will be setted ''' # 1. get checkpoint - if checkpoint_path is not None and checkpoint_path != '': + if checkpoint_path is not None and os.path.exists(checkpoint_path): logger.warning( f'evaluating with specific model: {checkpoint_path}') eval_checkpoint = checkpoint_path @@ -326,7 +337,8 @@ class KWSNearfieldTrainer(BaseTrainer): self.avg_checkpoint, self.work_dir, ) - logger.warning(f'average convert to kaldi: {kaldi_cvt}') + logger.warning( + f'average model convert to kaldi network: {kaldi_cvt}') eval_checkpoint = self.avg_checkpoint logger.warning( diff --git a/modelscope/trainers/audio/kws_utils/det_utils.py b/modelscope/trainers/audio/kws_utils/det_utils.py index 97b0c2de..ee6710f7 100644 --- a/modelscope/trainers/audio/kws_utils/det_utils.py +++ b/modelscope/trainers/audio/kws_utils/det_utils.py @@ -15,12 +15,14 @@ import glob import os +import threading import json +import kaldiio import matplotlib.font_manager as fm import matplotlib.pyplot as plt import numpy as np -import torchaudio +import torch from modelscope.utils.logger import get_logger from .file_utils import make_pair, read_lists @@ -30,6 +32,51 @@ logger = get_logger() font = fm.FontProperties(size=15) +class thread_wrapper(threading.Thread): + + def __init__(self, func, args=()): + super(thread_wrapper, self).__init__() + self.func = func + self.args = args + self.result = [] + + def run(self): + self.result = self.func(*self.args) + + def get_result(self): + try: + return self.result + except Exception: + return None + + +def count_duration(tid, data_lists): + results = [] + + for obj in data_lists: + assert 'key' in obj + assert 'wav' in obj + assert 'txt' in obj + # key = obj['key'] + wav_file = obj['wav'] + # txt = obj['txt'] + + try: + rate, waveform = kaldiio.load_mat(wav_file) + waveform = torch.tensor(waveform, dtype=torch.float32) + waveform = waveform.unsqueeze(0) + frames = len(waveform[0]) + duration = frames / float(rate) + except Exception: + logging.info(f'load file failed: {wav_file}') + duration = 0.0 + + obj['duration'] = duration + results.append(obj) + + return results + + def load_data_and_score(keywords_list, data_file, trans_file, score_file): # score_table: {uttid: [keywordlist]} score_table = {} @@ -54,6 +101,26 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file): trans_lists = read_lists(trans_file) data_lists = make_pair(wav_lists, trans_lists) + # count duration for each wave use multi-thread + num_workers = 8 + start = 0 + step = int(len(data_lists) / num_workers) + tasks = [] + for idx in range(8): + if idx != num_workers - 1: + task = thread_wrapper(count_duration, + (idx, data_lists[start:start + step])) + else: + task = thread_wrapper(count_duration, (idx, data_lists[start:])) + task.start() + tasks.append(task) + start += step + + duration_lists = [] + for task in tasks: + task.join() + duration_lists += task.get_result() + # build empty structure for keyword-filler infos keyword_filler_table = {} for keyword in keywords_list: @@ -63,35 +130,36 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file): keyword_filler_table[keyword]['filler_table'] = {} keyword_filler_table[keyword]['filler_duration'] = 0.0 - for obj in data_lists: + for obj in duration_lists: assert 'key' in obj assert 'wav' in obj assert 'txt' in obj - key = obj['key'] - wav_file = obj['wav'] - txt = obj['txt'] - assert key in score_table + assert 'duration' in obj - waveform, rate = torchaudio.load(wav_file) - frames = len(waveform[0]) - duration = frames / float(rate) + key = obj['key'] + # wav_file = obj['wav'] + txt = obj['txt'] + duration = obj['duration'] + assert key in score_table for keyword in keywords_list: if txt.find(keyword) != -1: if keyword == score_table[key]['kw']: keyword_filler_table[keyword]['keyword_table'].update( {key: score_table[key]['confi']}) - keyword_filler_table[keyword][ - 'keyword_duration'] += duration else: # uttrance detected but not match this keyword keyword_filler_table[keyword]['keyword_table'].update( {key: -1.0}) - keyword_filler_table[keyword][ - 'keyword_duration'] += duration + keyword_filler_table[keyword]['keyword_duration'] += duration else: - keyword_filler_table[keyword]['filler_table'].update( - {key: score_table[key]['confi']}) + if keyword == score_table[key]['kw']: + keyword_filler_table[keyword]['filler_table'].update( + {key: score_table[key]['confi']}) + else: + # uttrance if detected, which is not FA for this keyword + keyword_filler_table[keyword]['filler_table'].update( + {key: -1.0}) keyword_filler_table[keyword]['filler_duration'] += duration return keyword_filler_table diff --git a/tests/trainers/audio/test_kws_nearfield_trainer.py b/tests/trainers/audio/test_kws_nearfield_trainer.py index a61f70bf..af434048 100644 --- a/tests/trainers/audio/test_kws_nearfield_trainer.py +++ b/tests/trainers/audio/test_kws_nearfield_trainer.py @@ -84,14 +84,16 @@ class TestKwsNearfieldTrainer(unittest.TestCase): kwargs = dict( model=self.model_id, work_dir=self.tmp_dir, - cfg_file=self.config_file, - train_data=self.train_scp, - cv_data=self.cv_scp, - trans_data=self.trans_file) + cfg_file=self.config_file) trainer = build_trainer( Trainers.speech_kws_fsmn_char_ctc_nearfield, default_args=kwargs) - trainer.train() + + kwargs = dict( + train_data=self.train_scp, + cv_data=self.cv_scp, + trans_data=self.trans_file) + trainer.train(**kwargs) rank, _ = get_dist_info() if rank == 0: