From bb879063e952707dce950910cf5189dba6fc9038 Mon Sep 17 00:00:00 2001 From: "pengteng.spt" Date: Fri, 7 Apr 2023 14:38:14 +0800 Subject: [PATCH] speech kws nearfield training add gradient accumulation config Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12204960 --- .../trainers/audio/kws_nearfield_trainer.py | 1 + .../trainers/audio/kws_utils/batch_utils.py | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/modelscope/trainers/audio/kws_nearfield_trainer.py b/modelscope/trainers/audio/kws_nearfield_trainer.py index 9a84ce35..d1e3fdee 100644 --- a/modelscope/trainers/audio/kws_nearfield_trainer.py +++ b/modelscope/trainers/audio/kws_nearfield_trainer.py @@ -247,6 +247,7 @@ class KWSNearfieldTrainer(BaseTrainer): logger.info('Start training...') training_config = {} training_config['grad_clip'] = optim_conf['grad_clip'] + training_config['grad_accum'] = optim_conf.get('grad_accum', 1) training_config['log_interval'] = log_interval training_config['world_size'] = self.world_size training_config['rank'] = self.rank diff --git a/modelscope/trainers/audio/kws_utils/batch_utils.py b/modelscope/trainers/audio/kws_utils/batch_utils.py index ac382b79..75cf804e 100644 --- a/modelscope/trainers/audio/kws_utils/batch_utils.py +++ b/modelscope/trainers/audio/kws_utils/batch_utils.py @@ -44,6 +44,7 @@ def executor_train(model, optimizer, data_loader, device, writer, args): rank = args.get('rank', 0) local_rank = args.get('local_rank', 0) world_size = args.get('world_size', 1) + accum_batchs = args.get('grad_accum', 1) # [For distributed] Because iteration counts are not always equals between # processes, send stop-flag to the other processes if iterator is finished @@ -67,11 +68,16 @@ def executor_train(model, optimizer, data_loader, device, writer, args): logits, _ = model(feats) loss, acc = ctc_loss(logits, target, feats_lengths, target_lengths) loss = loss / num_utts - optimizer.zero_grad() + + # normalize loss to account 
for batch accumulation + loss = loss / accum_batchs loss.backward() - grad_norm = clip_grad_norm_(model.parameters(), clip) - if torch.isfinite(grad_norm): - optimizer.step() + if (batch_idx + 1) % accum_batchs == 0: + grad_norm = clip_grad_norm_(model.parameters(), clip) + if torch.isfinite(grad_norm): + optimizer.step() + optimizer.zero_grad() + if batch_idx % log_interval == 0: logger.info( 'RANK {}/{}/{} TRAIN Batch {}/{} size {} loss {:.6f}'.format( @@ -127,7 +133,8 @@ def executor_cv(model, data_loader, device, args): num_seen_tokens += target_lengths.sum() total_loss += loss.item() counter[0] += loss.item() - counter[1] += acc * target_lengths.sum() + counter[1] += acc * num_utts + # counter[1] += acc * target_lengths.sum() counter[2] += num_utts counter[3] += target_lengths.sum()