speech kws nearfield training: add gradient accumulation config

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12204960
Author:       pengteng.spt
Date:         2023-04-07 14:38:14 +08:00
Committed by: wenmeng.zwm
Parent:       63d5493962
Commit:       bb879063e9
2 changed files with 13 additions and 5 deletions


@@ -247,6 +247,7 @@ class KWSNearfieldTrainer(BaseTrainer):
         logger.info('Start training...')
         training_config = {}
         training_config['grad_clip'] = optim_conf['grad_clip']
+        training_config['grad_accum'] = optim_conf.get('grad_accum', 1)
         training_config['log_interval'] = log_interval
         training_config['world_size'] = self.world_size
         training_config['rank'] = self.rank
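
Note: grad_accum is read with optim_conf.get('grad_accum', 1), so configs without the key keep the previous one-update-per-batch behaviour. A minimal sketch of how the new key feeds the trainer (values are invented; only grad_clip and grad_accum appear in this diff):

    # hypothetical optim_conf fragment; values are illustrative
    optim_conf = {'grad_clip': 5.0, 'grad_accum': 4}
    training_config = {}
    training_config['grad_clip'] = optim_conf['grad_clip']
    training_config['grad_accum'] = optim_conf.get('grad_accum', 1)  # 4 here, 1 if absent
    # effective samples per optimizer step scale as:
    #   per_gpu_batch_size * grad_accum * world_size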


@@ -44,6 +44,7 @@ def executor_train(model, optimizer, data_loader, device, writer, args):
     rank = args.get('rank', 0)
     local_rank = args.get('local_rank', 0)
     world_size = args.get('world_size', 1)
+    accum_batchs = args.get('grad_accum', 1)
     # [For distributed] Because iteration counts are not always equal between
     # processes, send stop-flag to the other processes if iterator is finished
@@ -67,11 +68,16 @@ def executor_train(model, optimizer, data_loader, device, writer, args):
         logits, _ = model(feats)
         loss, acc = ctc_loss(logits, target, feats_lengths, target_lengths)
         loss = loss / num_utts
-        optimizer.zero_grad()
+        # normalize loss to account for gradient accumulation
+        loss = loss / accum_batchs
         loss.backward()
-        grad_norm = clip_grad_norm_(model.parameters(), clip)
-        if torch.isfinite(grad_norm):
-            optimizer.step()
+        if (batch_idx + 1) % accum_batchs == 0:
+            grad_norm = clip_grad_norm_(model.parameters(), clip)
+            if torch.isfinite(grad_norm):
+                optimizer.step()
+            optimizer.zero_grad()
         if batch_idx % log_interval == 0:
             logger.info(
                 'RANK {}/{}/{} TRAIN Batch {}/{} size {} loss {:.6f}'.format(
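
For readers unfamiliar with the pattern: the hunk above swaps the unconditional per-batch update for one update every accum_batchs batches (the value plumbed in from args in the previous hunk). A self-contained sketch of the same logic, assuming a standard PyTorch model and optimizer (model and loader here are placeholders, not this repo's API):

    import torch
    from torch.nn.utils import clip_grad_norm_

    def train_with_accumulation(model, optimizer, loader, accum_batchs, clip):
        optimizer.zero_grad()
        for batch_idx, (feats, target) in enumerate(loader):
            loss = model(feats, target)  # placeholder: forward returns a scalar loss
            loss = loss / accum_batchs   # rescale so accumulated grads match one large batch
            loss.backward()              # grads sum into .grad across the window
            if (batch_idx + 1) % accum_batchs == 0:
                grad_norm = clip_grad_norm_(model.parameters(), clip)
                if torch.isfinite(grad_norm):  # skip the step on inf/nan gradients
                    optimizer.step()
                optimizer.zero_grad()          # clear grads for the next window

One behavioural detail carried over from the diff: a trailing partial window (when the batch count is not a multiple of accum_batchs) never triggers a step, so those gradients are discarded at epoch end.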
@@ -127,7 +133,8 @@ def executor_cv(model, data_loader, device, args):
         num_seen_tokens += target_lengths.sum()
         total_loss += loss.item()
         counter[0] += loss.item()
-        counter[1] += acc * target_lengths.sum()
+        counter[1] += acc * num_utts
+        # counter[1] += acc * target_lengths.sum()
         counter[2] += num_utts
         counter[3] += target_lengths.sum()
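
The executor_cv change switches the running accuracy from token-weighted to utterance-weighted averaging (the old term is kept as a comment), with counter[2] accumulating utterances and counter[3] tokens. Assuming the epoch metric is then computed as counter[1] / counter[2], a toy example of how the two weightings diverge (numbers invented):

    batches = [(0.92, 8, 310), (0.85, 8, 120)]  # (acc, num_utts, num_tokens) per batch
    utt_weighted = sum(a * u for a, u, _ in batches) / sum(u for _, u, _ in batches)
    tok_weighted = sum(a * t for a, _, t in batches) / sum(t for _, _, t in batches)
    print(round(utt_weighted, 4), round(tok_weighted, 4))  # 0.885 0.9005

Under utterance weighting every utterance contributes equally, so long utterances no longer dominate the average.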