mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
speech kws nearfield training add gradient accumulation config
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12204960
This commit is contained in:
committed by
wenmeng.zwm
parent
63d5493962
commit
bb879063e9
@@ -247,6 +247,7 @@ class KWSNearfieldTrainer(BaseTrainer):
|
||||
logger.info('Start training...')
|
||||
training_config = {}
|
||||
training_config['grad_clip'] = optim_conf['grad_clip']
|
||||
training_config['grad_accum'] = optim_conf.get('grad_accum', 1)
|
||||
training_config['log_interval'] = log_interval
|
||||
training_config['world_size'] = self.world_size
|
||||
training_config['rank'] = self.rank
|
||||
|
||||
@@ -44,6 +44,7 @@ def executor_train(model, optimizer, data_loader, device, writer, args):
|
||||
rank = args.get('rank', 0)
|
||||
local_rank = args.get('local_rank', 0)
|
||||
world_size = args.get('world_size', 1)
|
||||
accum_batchs = args.get('grad_accum', 1)
|
||||
|
||||
# [For distributed] Because iteration counts are not always equals between
|
||||
# processes, send stop-flag to the other processes if iterator is finished
|
||||
@@ -67,11 +68,16 @@ def executor_train(model, optimizer, data_loader, device, writer, args):
|
||||
logits, _ = model(feats)
|
||||
loss, acc = ctc_loss(logits, target, feats_lengths, target_lengths)
|
||||
loss = loss / num_utts
|
||||
optimizer.zero_grad()
|
||||
|
||||
# normlize loss to account for batch accumulation
|
||||
loss = loss / accum_batchs
|
||||
loss.backward()
|
||||
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
||||
if torch.isfinite(grad_norm):
|
||||
optimizer.step()
|
||||
if (batch_idx + 1) % accum_batchs == 0:
|
||||
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
||||
if torch.isfinite(grad_norm):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
logger.info(
|
||||
'RANK {}/{}/{} TRAIN Batch {}/{} size {} loss {:.6f}'.format(
|
||||
@@ -127,7 +133,8 @@ def executor_cv(model, data_loader, device, args):
|
||||
num_seen_tokens += target_lengths.sum()
|
||||
total_loss += loss.item()
|
||||
counter[0] += loss.item()
|
||||
counter[1] += acc * target_lengths.sum()
|
||||
counter[1] += acc * num_utts
|
||||
# counter[1] += acc * target_lengths.sum()
|
||||
counter[2] += num_utts
|
||||
counter[3] += target_lengths.sum()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user