From 46850d53a8ad2600ec495c0d5ec3685dc3a805e6 Mon Sep 17 00:00:00 2001 From: "pengteng.spt" Date: Wed, 29 Mar 2023 20:01:25 +0800 Subject: [PATCH] Fix speech KWS nearfield training with multi-GPU Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12117620 --- .../audio/kws_nearfield_dataset.py | 40 ++- .../audio/kws_nearfield_processor.py | 57 +--- .../trainers/audio/kws_nearfield_trainer.py | 180 ++++++++----- .../trainers/audio/kws_utils/batch_utils.py | 255 ++++++++++++++++-- .../trainers/audio/kws_utils/det_utils.py | 30 ++- .../trainers/audio/kws_utils/file_utils.py | 157 ++++++++--- modelscope/utils/checkpoint.py | 3 + 7 files changed, 537 insertions(+), 185 deletions(-) diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_dataset.py index 1b784410..2688d7ca 100644 --- a/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_dataset.py @@ -20,7 +20,8 @@ from torch.utils.data import IterableDataset import modelscope.msdatasets.dataset_cls.custom_datasets.audio.kws_nearfield_processor as processor from modelscope.trainers.audio.kws_utils.file_utils import (make_pair, - read_lists) + read_lists, + tokenize) from modelscope.utils.logger import get_logger logger = get_logger() @@ -119,12 +120,41 @@ class DataList(IterableDataset): data.update(sampler_info) yield data + def dump(self, dump_file): + with open(dump_file, 'w', encoding='utf8') as fout: + for obj in self.lists: + if hasattr(obj, 'get') and obj.get('tokens', None) is not None: + assert 'key' in obj + assert 'wav' in obj + assert 'txt' in obj + assert len(obj['tokens']) == len(obj['txt']) + dump_line = obj['key'] + ':\n' + dump_line += '\t' + obj['wav'] + '\n' + dump_line += '\t' + for token, idx in zip(obj['tokens'], obj['txt']): + dump_line += '%s(%d) ' % (token, idx) + dump_line += '\n\n' + fout.write(dump_line) + else: + infos = json.loads(obj) + assert 'key' in infos + assert 'wav' in infos + assert 'txt' in infos + dump_line = infos['key'] + ':\n' + dump_line += '\t' + infos['wav'] + '\n' + dump_line += '\t' + dump_line += '%d' % infos['txt'] + dump_line += '\n\n' + fout.write(dump_line) + def kws_nearfield_dataset(data_file, trans_file, conf, symbol_table, lexicon_table, + need_dump=False, + dump_file='', partition=True): """ Construct dataset from arguments @@ -137,6 +167,8 @@ def kws_nearfield_dataset(data_file, trans_file (str): transcription list with kaldi style symbol_table (Dict): token list, [token_str, token_id] lexicon_table (Dict): words list defined with basic tokens + need_dump (bool): whether to dump data with mapped tokens or not + dump_file (str): dump file path partition (bool): whether to do data partition in terms of rank """ @@ -146,14 +178,14 @@ def kws_nearfield_dataset(data_file, wav_lists = read_lists(data_file) trans_lists = read_lists(trans_file) lists = make_pair(wav_lists, trans_lists) + lists = tokenize(lists, symbol_table, lexicon_table) shuffle = conf.get('shuffle', True) dataset = DataList(lists, shuffle=shuffle, partition=partition) + if need_dump: + dataset.dump(dump_file) dataset = Processor(dataset, processor.parse_wav) - dataset = Processor(dataset, processor.tokenize, symbol_table, - lexicon_table, conf.get('split_with_space', False)) - dataset = Processor(dataset, processor.filter, **filter_conf) feature_extraction_conf =
conf.get('feature_extraction_conf', {}) diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_processor.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_processor.py index d27c9e38..0363cc82 100644 --- a/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_processor.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/kws_nearfield_processor.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import random import json @@ -24,6 +23,9 @@ import torchaudio.compliance.kaldi as kaldi from torch.nn.utils.rnn import pad_sequence # torch.set_printoptions(profile="full") +from modelscope.utils.logger import get_logger + +logger = get_logger() def parse_wav(data): @@ -53,54 +55,7 @@ def parse_wav(data): key=key, label=txt, wav=waveform, sample_rate=sample_rate) yield example except Exception: - logging.warning('Failed to read {}'.format(wav_file)) - - -def tokenize(data, token_table, lexicon_table, split_with_space=False): - """ Decode text to chars - Inplace operation - - Args: - data: Iterable[{key, wav, txt, sample_rate}] - token_table (Dict): token list, [token_str, token_id] - lexicon_table (Dict): words list defined with basic tokens - split_with_space (bool): if transciption split with space or not - - Returns: - Iterable[{key, wav, txt, tokens, label, sample_rate}] - """ - for sample in data: - assert 'label' in sample - txt = sample['label'].strip() - - if token_table is None or lexicon_table is None: - # to compatible with hard token map for max-pooling loss - label = int(txt) - else: - parts = [txt] - tokens = [] - for part in parts: - if split_with_space: - part = part.split(' ') - for ch in part: - if ch == ' ': - ch = '▁' - tokens.append(ch) - - label = [] - for ch in tokens: - if ch in lexicon_table: - for sub_ch in lexicon_table[ch]: - if sub_ch in token_table: - label.append(token_table[sub_ch]) - else: - label.append(token_table['<blk>']) - else: - label.append(token_table['<blk>']) - - sample['tokens'] = tokens - sample['label'] = label - yield sample + logger.warning('Failed to read {}'.format(wav_file)) def filter(data, max_length=10240, min_length=10): @@ -128,11 +83,11 @@ def filter(data, max_length=10240, min_length=10): # print("{} num frames is {}".format(sample['key'], num_frames)) if num_frames < min_length: - logging.warning('{} is discard for too short: {} frames'.format( + logger.warning('{} is discarded for being too short: {} frames'.format( sample['key'], num_frames)) continue if num_frames > max_length: - logging.warning('{} is discard for too long: {} frames'.format( + logger.warning('{} is discarded for being too long: {} frames'.format( sample['key'], num_frames)) continue yield sample diff --git a/modelscope/trainers/audio/kws_nearfield_trainer.py b/modelscope/trainers/audio/kws_nearfield_trainer.py index 5e63e87e..9a84ce35 100644 --- a/modelscope/trainers/audio/kws_nearfield_trainer.py +++ b/modelscope/trainers/audio/kws_nearfield_trainer.py @@ -6,6 +6,7 @@ import re from typing import Callable, Dict, Optional import torch +import torch.distributed as dist import yaml from tensorboardX import SummaryWriter from torch import nn as nn @@ -23,11 +24,10 @@ from modelscope.utils.config import Config from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile from modelscope.utils.device import create_device from modelscope.utils.logger import get_logger -from modelscope.utils.torch_utils import (get_dist_info,
get_local_rank, - init_dist, set_random_seed) +from modelscope.utils.torch_utils import set_random_seed from .kws_utils.batch_utils import executor_cv, executor_test, executor_train from .kws_utils.det_utils import compute_det -from .kws_utils.file_utils import query_tokens_id, read_lexicon, read_token +from .kws_utils.file_utils import query_token_set, read_lexicon, read_token from .kws_utils.model_utils import (average_model, convert_to_kaldi, count_parameters) @@ -47,14 +47,11 @@ class KWSNearfieldTrainer(BaseTrainer): **kwargs): ''' Args: - work_dir (str): main directory for training + model (str): model id in modelscope + work_dir (str): main directory for training and evaluating + cfg_file (str): config file for training and evaluating kwargs: - checkpoint (str): basemodel checkpoint, if None, default to use base.pt in model path - train_data (int): wave list with kaldi style for training - cv_data (int): wave list with kaldi style for cross validation - trans_data (str): transcription list with kaldi style, merge train and cv - tensorboard_dir (str): path to save tensorboard results, - create 'tensorboard_dir' in work_dir by default + seed (int): random seed ''' if isinstance(model, str): self.model_dir = self.get_or_download_model_dir( @@ -69,20 +66,12 @@ class KWSNearfieldTrainer(BaseTrainer): super().__init__(cfg_file, arg_parse_fn) configs = Config.from_file(cfg_file) - print(kwargs) self.launcher = 'pytorch' self.dist_backend = configs.train.get('dist_backend', 'nccl') - self.tensorboard_dir = kwargs.get('tensorboard_dir', 'tensorboard') - self.checkpoint = kwargs.get( - 'checkpoint', os.path.join(self.model_dir, 'train/base.pt')) - self.avg_checkpoint = None # 1. get rank info set_random_seed(kwargs.get('seed', 666)) - self.get_dist_info() - logger.info('RANK {}/{}/{}, Master addr:{}, Master port:{}'.format( - self.world_size, self.rank, self.local_rank, self.master_addr, - self.master_port)) + self.init_dist() self.work_dir = work_dir if self.rank == 0: @@ -116,6 +105,24 @@ class KWSNearfieldTrainer(BaseTrainer): fout.write(data) def train(self, *args, **kwargs): + ''' + Args: + kwargs: + train_data (str): wave list in kaldi style for training + cv_data (str): wave list in kaldi style for cross validation + trans_data (str): transcription list in kaldi style, merged for train and cv + checkpoint (str): base model checkpoint; if None, defaults to train/base.pt in the model path + tensorboard_dir (str): path to save tensorboard results, + create 'tensorboard_dir' in work_dir by default + need_dump (bool): whether to dump data with mapped tokens or not + ''' + train_checkpoint = kwargs.get('checkpoint', None) + if train_checkpoint is not None and os.path.exists(train_checkpoint): + self.checkpoint = train_checkpoint + else: + self.checkpoint = os.path.join(self.model_dir, 'train/base.pt') + self.tensorboard_dir = kwargs.get('tensorboard_dir', 'tensorboard') + # 1.
prepare dataset and dataloader assert kwargs['train_data'], 'please config train data in dict kwargs' assert kwargs['cv_data'], 'please config cv data in dict kwargs' assert kwargs['trans_data'], 'please config trans data in dict kwargs' self.train_data = kwargs['train_data'] self.cv_data = kwargs['cv_data'] self.trans_data = kwargs['trans_data'] + self.need_dump = kwargs.get('need_dump', False) and self.rank == 0 train_conf = self.configs['preprocessor'] cv_conf = copy.deepcopy(train_conf) cv_conf['speed_perturb'] = False cv_conf['spec_aug'] = False cv_conf['shuffle'] = False - self.train_dataset = kws_nearfield_dataset(self.train_data, - self.trans_data, train_conf, - self.token_table, - self.lexicon_table, True) + + dump_train_file = os.path.join(self.work_dir, 'dump_train.txt') + dump_cv_file = os.path.join(self.work_dir, 'dump_cv.txt') + self.train_dataset = kws_nearfield_dataset( + self.train_data, self.trans_data, train_conf, self.token_table, + self.lexicon_table, self.need_dump, dump_train_file, True) self.cv_dataset = kws_nearfield_dataset(self.cv_data, self.trans_data, cv_conf, self.token_table, - self.lexicon_table, True) + self.lexicon_table, + self.need_dump, dump_cv_file, + True) self.train_dataloader = DataLoader( self.train_dataset, batch_size=None, pin_memory=kwargs.get('pin_memory', False), - persistent_workers=True, num_workers=self.configs.train.dataloader.workers_per_gpu, prefetch_factor=self.configs.train.dataloader.get('prefetch', 2)) self.cv_dataloader = DataLoader( self.cv_dataset, batch_size=None, pin_memory=kwargs.get('pin_memory', False), - persistent_workers=True, num_workers=self.configs.evaluation.dataloader.workers_per_gpu, prefetch_factor=self.configs.evaluation.dataloader.get( 'prefetch', 2)) @@ -180,18 +191,18 @@ class KWSNearfieldTrainer(BaseTrainer): self.configs['train']['optimizer']['lr'] = lr_last_epoch # 4. model placement - self.device_name = kwargs.get('device', 'gpu') + device_name = kwargs.get('device', 'gpu') if self.world_size > 1: - self.device_name = f'cuda:{self.local_rank}' - self.device = create_device(self.device_name) + device_name = f'cuda:{self.local_rank}' + self.train_device = create_device(device_name) if self.world_size > 1: assert (torch.cuda.is_available()) # cuda model is required for nn.parallel.DistributedDataParallel - self.model.cuda() + self.model = self.model.to(self.train_device) self.model = torch.nn.parallel.DistributedDataParallel(self.model) else: - self.model = self.model.to(self.device) + self.model = self.model.to(self.train_device) # 5.
update training config file if self.rank == 0: @@ -230,7 +241,7 @@ class KWSNearfieldTrainer(BaseTrainer): if self.start_epoch == 0 and self.rank == 0: save_model_path = os.path.join(self.work_dir, 'init.pt') save_checkpoint(self.model, save_model_path, None, None, None, - False) + False, True) # Start training loop logger.info('Start training...') @@ -250,17 +261,18 @@ class KWSNearfieldTrainer(BaseTrainer): lr = optimizer.param_groups[0]['lr'] logger.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) executor_train(self.model, optimizer, self.train_dataloader, - self.device, writer, training_config) - cv_loss = executor_cv(self.model, self.cv_dataloader, self.device, - training_config) - logger.info('Epoch {} EVAL info cv_loss {:.6f}'.format( - epoch, cv_loss)) + self.train_device, writer, training_config) + cv_loss, cv_acc = executor_cv(self.model, self.cv_dataloader, + self.train_device, training_config) + logger.info( + 'Epoch {} EVAL info cv_loss {:.6f}, cv_acc {:.2f}'.format( + epoch, cv_loss, cv_acc)) if self.rank == 0: save_model_path = os.path.join(self.work_dir, '{}.pt'.format(epoch)) save_checkpoint(self.model, save_model_path, None, None, None, - False) + False, True) info_path = re.sub('.pt$', '.yaml', save_model_path) info_dict = dict( @@ -274,6 +286,7 @@ class KWSNearfieldTrainer(BaseTrainer): writer.add_scalar('epoch/cv_loss', cv_loss, epoch) writer.add_scalar('epoch/lr', lr, epoch) + final_epoch = epoch lr_scheduler.step(cv_loss) @@ -296,17 +309,18 @@ class KWSNearfieldTrainer(BaseTrainer): average_num (int): the NO. to do model averaging(checkpoint_path==None) batch_size (int): batch size during evaluating keywords (str): keyword string, split with ',' - gpu (int): evaluating with cpu/gpu: -1 for cpu; >=0 for gpu, - os.environ['CUDA_VISIBLE_DEVICES'] will be setted + gpu (int): evaluating with cpu/gpu: -1 for cpu; >=0 for gpu ''' + # 1. 
get checkpoint + self.avg_checkpoint = None if checkpoint_path is not None and os.path.exists(checkpoint_path): logger.warning( f'evaluating with specific model: {checkpoint_path}') eval_checkpoint = checkpoint_path else: if self.avg_checkpoint is None: - avg_num = kwargs.get('average_num', 5) + avg_num = kwargs.get('average_num', 10) self.avg_checkpoint = os.path.join(self.work_dir, f'avg_{avg_num}.pt') logger.warning( @@ -353,13 +367,13 @@ class KWSNearfieldTrainer(BaseTrainer): test_conf['speed_perturb'] = False test_conf['spec_aug'] = False test_conf['shuffle'] = False - test_conf['feature_extraction_conf']['dither'] = 0.0 if kwargs.get('batch_size', None) is not None: test_conf['batch_conf']['batch_size'] = kwargs['batch_size'] test_dataset = kws_nearfield_dataset(test_data, trans_data, test_conf, self.token_table, - self.lexicon_table, False) + self.lexicon_table, False, '', + False) test_dataloader = DataLoader( test_dataset, batch_size=None, @@ -375,33 +389,39 @@ class KWSNearfieldTrainer(BaseTrainer): keywords_str = kwargs['keywords'] keywords_list = keywords_str.strip().replace(' ', '').split(',') keywords_token = {} - keywords_tokenset = {0} + keywords_idxset = {0} + keywords_strset = {'<blk>'} + keywords_tokenmap = {'<blk>': 0} for keyword in keywords_list: - ids = query_tokens_id(keyword, self.token_table, - self.lexicon_table) + strs, indexes = query_token_set(keyword, self.token_table, + self.lexicon_table) keywords_token[keyword] = {} - keywords_token[keyword]['token_id'] = ids + keywords_token[keyword]['token_id'] = indexes keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) - for i in ids) - [keywords_tokenset.add(i) for i in ids] - logger.warning(f'Token set is: {keywords_tokenset}') + for i in indexes) + [keywords_strset.add(i) for i in strs] + [keywords_idxset.add(i) for i in indexes] + for txt, idx in zip(strs, indexes): + if keywords_tokenmap.get(txt, None) is None: + keywords_tokenmap[txt] = idx + + token_print = '' + for txt, idx in keywords_tokenmap.items(): + token_print += f'{txt}({idx}) ' + logger.warning(f'Token set is: {token_print}') # 5. build model and load checkpoint # support assign specific gpu device - os.environ['CUDA_VISIBLE_DEVICES'] = str(kwargs.get('gpu', -1)) + # Init kws model from configs use_cuda = kwargs.get('gpu', -1) >= 0 and torch.cuda.is_available() + device_name = kwargs.get('device', 'cpu') + if self.world_size > 1 and use_cuda: + device_name = f'cuda:{self.local_rank}' + self.test_device = create_device(device_name) - if kwargs.get('jit_model', None): - model = torch.jit.load(eval_checkpoint) - # For script model, only cpu is supported. - device = torch.device('cpu') - else: - # Init kws model from configs - model = self.build_model(self.configs) - load_checkpoint(eval_checkpoint, model) - device = torch.device('cuda' if use_cuda else 'cpu') - model = model.to(device) - model.eval() + self.test_model = self.build_model(self.configs) + load_checkpoint(eval_checkpoint, self.test_model) + self.test_model = self.test_model.to(self.test_device) testing_config = {} if kwargs.get('test_dir', None) is not None: @@ -417,9 +437,9 @@ class KWSNearfieldTrainer(BaseTrainer): # 6.
execute evaluation and get the score file logger.info('Start evaluating...') totaltime = datetime.datetime.now() - score_file = executor_test(model, test_dataloader, device, - keywords_token, keywords_tokenset, - testing_config) + score_file = executor_test(self.test_model, test_dataloader, + self.test_device, keywords_token, + keywords_idxset, testing_config) totaltime = datetime.datetime.now() - totaltime logger.info('Total time spent: {:.2f} hours'.format( totaltime.total_seconds() / 3600.0)) @@ -448,7 +468,7 @@ class KWSNearfieldTrainer(BaseTrainer): elif isinstance(model, nn.Module): return model - def get_dist_info(self): + def init_dist(self, train_nodes=1): if os.getenv('RANK', None) is None: os.environ['RANK'] = '0' if os.getenv('LOCAL_RANK', None) is None: @@ -466,6 +486,28 @@ class KWSNearfieldTrainer(BaseTrainer): self.master_addr = os.environ['MASTER_ADDR'] self.master_port = os.environ['MASTER_PORT'] - init_dist(self.launcher, self.dist_backend) - self.rank, self.world_size = get_dist_info() - self.local_rank = get_local_rank() + if train_nodes == 1: + if self.world_size > 1: + logger.info('init dist on multiple gpus, this gpu {}'.format( + self.local_rank)) + dist.init_process_group( + backend=self.dist_backend, init_method='env://') + elif train_nodes > 1: + dist.init_process_group( + backend=self.dist_backend, init_method='env://') + dist.barrier() + + logger.info('RANK {}/{}/{}, Master addr:{}, Master port:{}'.format( + self.world_size, self.rank, self.local_rank, self.master_addr, + self.master_port)) + + def uninit_dist(self, train_nodes=1): + if train_nodes == 1: + if self.world_size > 1: + logger.info( + 'destroy dist on multiple gpus, this gpu {}'.format( + self.local_rank)) + dist.destroy_process_group() + elif train_nodes > 1: + dist.barrier() + dist.destroy_process_group() diff --git a/modelscope/trainers/audio/kws_utils/batch_utils.py b/modelscope/trainers/audio/kws_utils/batch_utils.py index 8dc866e8..ac382b79 100644 --- a/modelscope/trainers/audio/kws_utils/batch_utils.py +++ b/modelscope/trainers/audio/kws_utils/batch_utils.py @@ -65,7 +65,8 @@ def executor_train(model, optimizer, data_loader, device, writer, args): if num_utts == 0: continue logits, _ = model(feats) - loss = ctc_loss(logits, target, feats_lengths, target_lengths) + loss, acc = ctc_loss(logits, target, feats_lengths, target_lengths) + loss = loss / num_utts optimizer.zero_grad() loss.backward() grad_norm = clip_grad_norm_(model.parameters(), clip) @@ -90,11 +91,12 @@ def executor_cv(model, data_loader, device, args): epoch = args.get('epoch', 0) # in order to avoid division by 0 num_seen_utts = 1 + num_seen_tokens = 1 total_loss = 0.0 # [For distributed] Because iteration counts are not always equals between # processes, send stop-flag to the other processes if iterator is finished iterator_stop = torch.tensor(0).to(device) - counter = torch.zeros((3, ), device=device) + counter = torch.zeros((4, ), device=device) rank = args.get('rank', 0) local_rank = args.get('local_rank', 0) @@ -117,18 +119,25 @@ def executor_cv(model, data_loader, device, args): if num_utts == 0: continue logits, _ = model(feats) + - loss = ctc_loss(logits, target, feats_lengths, target_lengths) + loss, acc = ctc_loss(logits, target, feats_lengths, target_lengths, + True) if torch.isfinite(loss): num_seen_utts += num_utts - total_loss += loss.item() * num_utts - counter[0] += loss.item() * num_utts - counter[1] += num_utts + num_seen_tokens += target_lengths.sum() + total_loss += loss.item() + counter[0] += loss.item() +
counter[1] += acc * target_lengths.sum() + counter[2] += num_utts + counter[3] += target_lengths.sum() if batch_idx % log_interval == 0: logger.info( - 'RANK {}/{}/{} CV Batch {}/{} size {} loss {:.6f} history loss {:.6f}' + 'RANK {}/{}/{} CV Batch {}/{} size {} loss {:.6f} acc {:.2f} history loss {:.6f}' .format(world_size, rank, local_rank, epoch, batch_idx, - num_utts, loss.item(), total_loss / num_seen_utts)) + num_utts, + loss.item() / num_utts, acc, + total_loss / num_seen_utts)) else: iterator_stop.fill_(1) if world_size > 1: @@ -136,14 +145,15 @@ if world_size > 1: dist.all_reduce(counter, ReduceOp.SUM) - logger.info('Total utts number is {}'.format(counter[1])) + logger.info('Total utts number is {}'.format(counter[2])) counter = counter.to('cpu') - return counter[0].item() / counter[1].item() + return counter[0].item() / counter[2].item(), counter[1].item( + ) / counter[2].item() -def executor_test(model, data_loader, device, keywords_token, - keywords_tokenset, args): +def executor_test(model, data_loader, device, keywords_token, keywords_idxset, + args): ''' Test model with decoder ''' assert args.get('test_dir', None) is not None, \ score_abs_path = os.path.join(args['test_dir'], 'score.txt') log_interval = args.get('log_interval', 10) + model.eval() infer_seconds = 0.0 decode_seconds = 0.0 with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout: @@ -175,8 +186,7 @@ key = keys[i] score = logits[i][:feats_lengths[i]] hyps = ctc_prefix_beam_search(score, feats_lengths[i], - keywords_tokenset) - + keywords_idxset) hit_keyword = None hit_score = 1.0 # start = 0; end = 0 @@ -243,8 +253,11 @@ def is_sublist(main_list, check_list): return -1 -def ctc_loss(logits: torch.Tensor, target: torch.Tensor, - logits_lengths: torch.Tensor, target_lengths: torch.Tensor): +def ctc_loss(logits: torch.Tensor, + target: torch.Tensor, + logits_lengths: torch.Tensor, + target_lengths: torch.Tensor, + need_acc: bool = False): """ CTC Loss Args: logits: (B, L, D), D is the number of keywords plus 1 (non-keyword) @@ -255,14 +268,51 @@ (float): loss of current batch """ + acc = 0.0 + if need_acc: + acc = acc_utterance(logits, target, logits_lengths, target_lengths) + # logits: (B, L, D) -> (L, B, D) logits = logits.transpose(0, 1) logits = logits.log_softmax(2) loss = F.ctc_loss( logits, target, logits_lengths, target_lengths, reduction='sum') - loss = loss / logits.size(1) + # loss = loss / logits.size(1) - return loss + return loss, acc + + +def acc_utterance(logits: torch.Tensor, target: torch.Tensor, + logits_length: torch.Tensor, target_length: torch.Tensor): + if logits is None: + return 0 + + logits = logits.softmax(2) # (B, maxlen, vocab_size) + logits = logits.cpu() + target = target.cpu() + + total_word = 0 + total_ins = 0 + total_sub = 0 + total_del = 0 + calculator = Calculator() + for i in range(logits.size(0)): + score = logits[i][:logits_length[i]] + hyps = ctc_prefix_beam_search(score, logits_length[i], None, 3, 5) + lab = [str(item) for item in target[i][:target_length[i]].tolist()] + rec = [] + if len(hyps) > 0: + rec = [str(item) for item in hyps[0][0]] + result = calculator.calculate(lab, rec) + # print(f'result:{result}') + if result['all'] != 0: + total_word += result['all'] + total_ins += result['ins'] + total_sub +=
result['sub'] + total_del += result['del'] + + return float(total_word - total_ins - total_sub + - total_del) * 100.0 / total_word def ctc_prefix_beam_search( @@ -304,9 +354,14 @@ def ctc_prefix_beam_search( filter_probs = [] filter_index = [] for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()): - if prob > 0.05 and idx in keywords_tokenset: - filter_probs.append(prob) - filter_index.append(idx) + if keywords_tokenset is not None: + if prob > 0.05 and idx in keywords_tokenset: + filter_probs.append(prob) + filter_index.append(idx) + else: + if prob > 0.05: + filter_probs.append(prob) + filter_index.append(idx) if len(filter_index) == 0: continue @@ -363,3 +418,161 @@ def ctc_prefix_beam_search( hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps] return hyps + + +class Calculator: + + def __init__(self): + self.data = {} + self.space = [] + self.cost = {} + self.cost['cor'] = 0 + self.cost['sub'] = 1 + self.cost['del'] = 1 + self.cost['ins'] = 1 + + def calculate(self, lab, rec): + # Initialization + lab.insert(0, '') + rec.insert(0, '') + while len(self.space) < len(lab): + self.space.append([]) + for row in self.space: + for element in row: + element['dist'] = 0 + element['error'] = 'non' + while len(row) < len(rec): + row.append({'dist': 0, 'error': 'non'}) + for i in range(len(lab)): + self.space[i][0]['dist'] = i + self.space[i][0]['error'] = 'del' + for j in range(len(rec)): + self.space[0][j]['dist'] = j + self.space[0][j]['error'] = 'ins' + self.space[0][0]['error'] = 'non' + for token in lab: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + for token in rec: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + # Computing edit distance + for i, lab_token in enumerate(lab): + for j, rec_token in enumerate(rec): + if i == 0 or j == 0: + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i - 1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist: + min_dist = dist + min_error = error + dist = self.space[i][j - 1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist: + min_dist = dist + min_error = error + if lab_token == rec_token: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] + error = 'cor' + else: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist: + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = { + 'lab': [], + 'rec': [], + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + i = len(lab) - 1 + j = len(rec) - 1 + while True: + if self.space[i][j]['error'] == 'cor': # correct + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub': # substitution + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'del': # deletion 
+ if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, '') + i = i - 1 + elif self.space[i][j]['error'] == 'ins': # insertion + if len(rec[j]) > 0: + self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 + result['ins'] = result['ins'] + 1 + result['lab'].insert(0, '') + result['rec'].insert(0, rec[j]) + j = j - 1 + elif self.space[i][j]['error'] == 'non': # starting point + break + else: # shouldn't reach here + print( + 'this should not happen, i = {i}, j = {j}, error = {error}' + .format(i=i, j=j, error=self.space[i][j]['error'])) + return result + + def overall(self): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def cluster(self, data): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in data: + if token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def keys(self): + return list(self.data.keys()) diff --git a/modelscope/trainers/audio/kws_utils/det_utils.py b/modelscope/trainers/audio/kws_utils/det_utils.py index ee6710f7..7241bbd3 100644 --- a/modelscope/trainers/audio/kws_utils/det_utils.py +++ b/modelscope/trainers/audio/kws_utils/det_utils.py @@ -25,7 +25,7 @@ import numpy as np import torch from modelscope.utils.logger import get_logger -from .file_utils import make_pair, read_lists +from .file_utils import make_pair, read_lists, space_mixed_label logger = get_logger() @@ -68,7 +68,7 @@ def count_duration(tid, data_lists): frames = len(waveform[0]) duration = frames / float(rate) except Exception: - logging.info(f'load file failed: {wav_file}') + logger.info(f'load file failed: {wav_file}') duration = 0.0 obj['duration'] = duration @@ -88,11 +88,12 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file): is_detected = arr[1] if is_detected == 'detected': if key not in score_table: - score_table.update( - {key: { - 'kw': arr[2], + score_table.update({ + key: { + 'kw': space_mixed_label(arr[2]), 'confi': float(arr[3]) - }}) + } + }) else: if key not in score_table: score_table.update({key: {'kw': 'unknown', 'confi': -1.0}}) @@ -100,13 +101,14 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file): wav_lists = read_lists(data_file) trans_lists = read_lists(trans_file) data_lists = make_pair(wav_lists, trans_lists) + logger.info(f'original list samples: {len(data_lists)}') # count duration for each wave using multiple threads num_workers = 8 start = 0 step = int(len(data_lists) / num_workers) tasks = [] - for idx in range(8): + for idx in range(num_workers): if idx != num_workers - 1: task = thread_wrapper(count_duration, (idx, data_lists[start:start + step])) @@ -120,10 +122,12 @@ for task in tasks: task.join()
duration_lists += task.get_result() + logger.info(f'list samples after duration counting: {len(duration_lists)}') # build empty structure for keyword-filler infos keyword_filler_table = {} for keyword in keywords_list: + keyword = space_mixed_label(keyword) keyword_filler_table[keyword] = {} keyword_filler_table[keyword]['keyword_table'] = {} keyword_filler_table[keyword]['keyword_duration'] = 0.0 keyword_filler_table[keyword]['filler_table'] = {} keyword_filler_table[keyword]['filler_duration'] = 0.0 @@ -139,11 +143,15 @@ def load_data_and_score(keywords_list, data_file, trans_file, score_file): key = obj['key'] # wav_file = obj['wav'] txt = obj['txt'] + txt = space_mixed_label(txt) + txt_regstr_lrblk = ' ' + txt + ' ' duration = obj['duration'] assert key in score_table for keyword in keywords_list: - if txt.find(keyword) != -1: + keyword = space_mixed_label(keyword) + keyword_regstr_lrblk = ' ' + keyword + ' ' + if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1: if keyword == score_table[key]['kw']: keyword_filler_table[keyword]['keyword_table'].update( {key: score_table[key]['confi']}) @@ -203,12 +211,13 @@ def compute_det(**kwargs): score_step = kwargs.get('score_step', 0.001) - keywords_list = keywords.replace(' ', '').strip().split(',') + keywords_list = keywords.strip().split(',') keyword_filler_table = load_data_and_score(keywords_list, test_data, trans_data, score_file) stats_files = {} for keyword in keywords_list: + keyword = space_mixed_label(keyword) keyword_dur = keyword_filler_table[keyword]['keyword_duration'] keyword_num = len(keyword_filler_table[keyword]['keyword_table']) filler_dur = keyword_filler_table[keyword]['filler_duration'] @@ -221,7 +230,8 @@ keyword_dur / 3600.0, keyword_num)) logger.info(' Filler duration: {} Hours'.format(filler_dur / 3600.0)) - stats_file = os.path.join(stats_dir, 'stats_' + keyword + '.txt') + stats_file = os.path.join( stats_dir, 'stats.'
+ keyword.replace(' ', '_') + '.txt') with open(stats_file, 'w', encoding='utf8') as fout: threshold = 0.0 while threshold <= 1.0: diff --git a/modelscope/trainers/audio/kws_utils/file_utils.py b/modelscope/trainers/audio/kws_utils/file_utils.py index 95a37153..f2754ff7 100644 --- a/modelscope/trainers/audio/kws_utils/file_utils.py +++ b/modelscope/trainers/audio/kws_utils/file_utils.py @@ -18,14 +18,35 @@ from modelscope.utils.logger import get_logger logger = get_logger() -remove_str = ['!sil', '(noise)', '(noise', 'noise)', '·', '’'] +symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+' + + +def split_mixed_label(input_str): + tokens = [] + s = input_str.lower() + while len(s) > 0: + match = re.match(r'[A-Za-z!?,<>()\']+', s) + if match is not None: + word = match.group(0) + else: + word = s[0:1] + tokens.append(word) + s = s.replace(word, '', 1).strip(' ') + return tokens + + +def space_mixed_label(input_str): + splits = split_mixed_label(input_str) + space_str = ''.join(f'{sub} ' for sub in splits) + return space_str.strip() def read_lists(list_file): lists = [] with open(list_file, 'r', encoding='utf8') as fin: for line in fin: - lists.append(line.strip()) + if line.strip() != '': + lists.append(line.strip()) return lists @@ -37,14 +58,7 @@ def make_pair(wav_lists, trans_lists): logger.debug('invalid line in trans file: {}'.format(line.strip())) continue - trans_table[arr[0]] = line.replace(arr[0], '')\ - .replace(' ', '')\ - .replace('(noise)', '')\ - .replace('noise)', '')\ - .replace('(noise', '')\ - .replace('!sil', '')\ - .replace('·', '')\ - .replace('’', '').strip() + trans_table[arr[0]] = line.replace(arr[0], '').strip() lists = [] for line in wav_lists: @@ -86,27 +100,110 @@ def read_lexicon(lexicon_file): return lexicon_table -def query_tokens_id(txt, symbol_table, lexicon_table): - label = tuple() - tokens = [] +def query_token_set(txt, symbol_table, lexicon_table): + tokens_str = tuple() + tokens_idx = tuple() - parts = [txt.replace(' ', '').strip()] + parts = split_mixed_label(txt) for part in parts: - for ch in part: - if ch == ' ': - ch = '▁' - tokens.append(ch) - - for ch in tokens: - if ch in symbol_table: - label = label + (symbol_table[ch], ) - elif ch in lexicon_table: - for sub_ch in lexicon_table[ch]: - if sub_ch in symbol_table: - label = label + (symbol_table[sub_ch], ) - else: - label = label + (symbol_table['<blk>'], ) + if part == '!sil' or part == '(sil)' or part == '<sil>': + tokens_str = tokens_str + ('!sil', ) + elif part == '<blank>' or part == '<blk>': + tokens_str = tokens_str + ('<blk>', ) + elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>': + tokens_str = tokens_str + ('<unk>', ) + elif part in symbol_table: + tokens_str = tokens_str + (part, ) + elif part in lexicon_table: + for ch in lexicon_table[part]: + tokens_str = tokens_str + (ch, ) else: - label = label + (symbol_table['<blk>'], ) + # case with symbols or meaningless English letter combinations + part = re.sub(symbol_str, '', part) + for ch in part: + tokens_str = tokens_str + (ch, ) - return label + for ch in tokens_str: + if ch in symbol_table: + tokens_idx = tokens_idx + (symbol_table[ch], ) + elif ch == '!sil': + if 'sil' in symbol_table: + tokens_idx = tokens_idx + (symbol_table['sil'], ) + else: + tokens_idx = tokens_idx + (symbol_table['<blk>'], ) + elif ch == '<unk>': + if '<unk>' in symbol_table: + tokens_idx = tokens_idx + (symbol_table['<unk>'], ) + else: + tokens_idx = tokens_idx + (symbol_table['<blk>'], ) + else: + if '<unk>' in symbol_table: + tokens_idx = tokens_idx + (symbol_table['<unk>'], ) +
logger.info( f'\'{ch}\' is not in token set, replace with <unk>') + else: + tokens_idx = tokens_idx + (symbol_table['<blk>'], ) + logger.info( f'\'{ch}\' is not in token set, replace with <blk>') + + return tokens_str, tokens_idx + + +def query_token_list(txt, symbol_table, lexicon_table): + tokens_str = [] + tokens_idx = [] + + parts = split_mixed_label(txt) + for part in parts: + if part == '!sil' or part == '(sil)' or part == '<sil>': + tokens_str.append('!sil') + elif part == '<blank>' or part == '<blk>': + tokens_str.append('<blk>') + elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>': + tokens_str.append('<unk>') + elif part in symbol_table: + tokens_str.append(part) + elif part in lexicon_table: + for ch in lexicon_table[part]: + tokens_str.append(ch) + else: + # case with symbols or meaningless English letter combinations + part = re.sub(symbol_str, '', part) + for ch in part: + tokens_str.append(ch) + + for ch in tokens_str: + if ch in symbol_table: + tokens_idx.append(symbol_table[ch]) + elif ch == '!sil': + if 'sil' in symbol_table: + tokens_idx.append(symbol_table['sil']) + else: + tokens_idx.append(symbol_table['<blk>']) + elif ch == '<unk>': + if '<unk>' in symbol_table: + tokens_idx.append(symbol_table['<unk>']) + else: + tokens_idx.append(symbol_table['<blk>']) + else: + if '<unk>' in symbol_table: + tokens_idx.append(symbol_table['<unk>']) + logger.info( f'\'{ch}\' is not in token set, replace with <unk>') + else: + tokens_idx.append(symbol_table['<blk>']) + logger.info( f'\'{ch}\' is not in token set, replace with <blk>') + + return tokens_str, tokens_idx + + +def tokenize(data_list, symbol_table, lexicon_table): + for sample in data_list: + assert 'txt' in sample + txt = sample['txt'].strip() + strs, indexes = query_token_list(txt, symbol_table, lexicon_table) + sample['tokens'] = strs + sample['txt'] = indexes + + return data_list diff --git a/modelscope/utils/checkpoint.py b/modelscope/utils/checkpoint.py index ef1fa003..a81ce68f 100644 --- a/modelscope/utils/checkpoint.py +++ b/modelscope/utils/checkpoint.py @@ -100,6 +100,9 @@ def save_checkpoint(model: torch.nn.Module, checkpoint['lr_scheduler'] = lr_scheduler.state_dict() if with_model: + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model = model.module + _weights = weights_to_cpu(model.state_dict()) if not with_meta: checkpoint = _weights
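Usage note for the tokenization that moved out of the processor and into kws_utils/file_utils.py (tokenize() / query_token_list()): labels are now mapped to tokens once, when the data list is built, instead of on every pass through the Processor pipeline. A small sketch with toy tables; the symbol table and lexicon below are hypothetical, real ones come from the model directory via read_token() and read_lexicon():

from modelscope.trainers.audio.kws_utils.file_utils import tokenize

# toy stand-ins for the real token/lexicon files
symbol_table = {'<blk>': 0, '<unk>': 1, 'ni3': 2, 'hao3': 3, 'hi': 4}
lexicon_table = {'你': ['ni3'], '好': ['hao3'], 'hi': ['hi']}

data = [{'key': 'utt1', 'wav': '/path/to/utt1.wav', 'txt': 'hi 你好'}]
data = tokenize(data, symbol_table, lexicon_table)
# data[0]['tokens'] -> ['hi', 'ni3', 'hao3']
# data[0]['txt']    -> [4, 2, 3]  (ids later consumed as the CTC label)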
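On the multi-GPU fix itself: the trainer no longer calls modelscope's init_dist/get_dist_info helpers and instead owns the process-group lifecycle (init_dist/uninit_dist above). A hedged, minimal sketch of the same single-node pattern; the function names here are illustrative and it assumes the usual torchrun environment variables (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT):

import os
import torch.distributed as dist

def init_single_node(backend='nccl'):
    # init_method='env://' picks up RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT
    if int(os.environ.get('WORLD_SIZE', '1')) > 1:
        dist.init_process_group(backend=backend, init_method='env://')

def uninit_single_node():
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()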
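The loss convention also changed: ctc_loss() now returns the batch-summed CTC loss (the in-function division by batch size is commented out) and executor_train normalizes by the number of utterances at the call site. A self-contained sketch of that convention, with random placeholder tensors:

import torch
import torch.nn.functional as F

def ctc_loss_sum(logits, target, logits_lengths, target_lengths):
    # (B, L, D) -> (L, B, D), as in ctc_loss() above
    log_probs = logits.transpose(0, 1).log_softmax(2)
    return F.ctc_loss(log_probs, target, logits_lengths, target_lengths,
                      reduction='sum')

B, L, D = 4, 100, 8
logits = torch.randn(B, L, D)
target = torch.randint(1, D, (B, 5))  # 0 is reserved for the CTC blank
loss = ctc_loss_sum(logits, target,
                    torch.full((B, ), L, dtype=torch.long),
                    torch.full((B, ), 5, dtype=torch.long))
loss = loss / B  # per-utterance loss, matching 'loss = loss / num_utts'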
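The new cv_acc metric comes from acc_utterance(): decode each utterance with ctc_prefix_beam_search (keyword filtering disabled by passing None), align hypothesis against reference with the Calculator's edit distance, then score accuracy = (N - ins - sub - del) * 100 / N. A compact independent reimplementation of that scoring, for reference only (not the Calculator class itself):

def utterance_accuracy(lab, rec):
    # dp[i][j] = (edit distance, ins, sub, del) for lab[:i] vs rec[:j]
    n, m = len(lab), len(rec)
    dp = [[(0, 0, 0, 0)] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dp[i][0] = (i, 0, 0, i)  # all deletions
    for j in range(1, m + 1):
        dp[0][j] = (j, j, 0, 0)  # all insertions
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            sub = 0 if lab[i - 1] == rec[j - 1] else 1
            cands = [
                (dp[i - 1][j - 1][0] + sub, dp[i - 1][j - 1], (0, sub, 0)),
                (dp[i][j - 1][0] + 1, dp[i][j - 1], (1, 0, 0)),
                (dp[i - 1][j][0] + 1, dp[i - 1][j], (0, 0, 1)),
            ]
            _, prev, (di, ds, dd) = min(cands, key=lambda c: c[0])
            dp[i][j] = (prev[0] + di + ds + dd, prev[1] + di,
                        prev[2] + ds, prev[3] + dd)
    _, ins, sub, dele = dp[n][m]
    return 100.0 * (n - ins - sub - dele) / max(n, 1)

print(utterance_accuracy(list('abcd'), list('abxd')))  # 75.0, one substitution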
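In det_utils, transcriptions and keywords are now both normalized with space_mixed_label() and padded with spaces before matching, so keyword hits are found on whole-token boundaries instead of raw substrings (for example, 'hi' no longer matches inside 'high'). A tiny illustration with hypothetical labels:

from modelscope.trainers.audio.kws_utils.file_utils import space_mixed_label

txt = space_mixed_label('小云小云关灯')  # -> '小 云 小 云 关 灯'
keyword = space_mixed_label('小云小云')  # -> '小 云 小 云'
hit = (' ' + txt + ' ').find(' ' + keyword + ' ') != -1
print(hit)  # True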
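Finally, the checkpoint.py hunk matters for multi-GPU because DistributedDataParallel wraps the real network as .module; saving the wrapper's state_dict would prefix every parameter key with 'module.' and break later single-process loads. A minimal illustration of the unwrap:

import torch

def plain_state_dict(model):
    # unwrap DDP so saved keys are not prefixed with 'module.'
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    return model.state_dict()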