[to #42322933] feat: kws support continue training from a checkpoint

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11940446

* feat: kws support continue training from a checkpoint

* log: add loading model log
This commit is contained in:
bin.xue
2023-03-09 21:07:34 +08:00
committed by wenmeng.zwm
parent a2bed42fe1
commit 281d3c630e
2 changed files with 45 additions and 8 deletions

View File

@@ -1,6 +1,8 @@
import datetime
import glob
import math
import os
import pickle
from typing import Callable, Dict, Optional
import numpy as np
@@ -29,6 +31,7 @@ BASETRAIN_CONF_HARD = 'basetrain_hard'
FINETUNE_CONF_EASY = 'finetune_easy'
FINETUNE_CONF_NORMAL = 'finetune_normal'
FINETUNE_CONF_HARD = 'finetune_hard'
CKPT_PREFIX = 'checkpoint'
EASY_RATIO = 0.1
NORMAL_RATIO = 0.6
@@ -110,9 +113,27 @@ class KWSFarfieldTrainer(BaseTrainer):
if 'single_rate' in kwargs:
self._single_rate = kwargs['single_rate']
self._batch_size = dataloader_config.batch_size_per_gpu
next_epoch = kwargs.get('next_epoch', 1)
self._current_epoch = next_epoch - 1
if 'model_bin' in kwargs:
model_bin_file = os.path.join(self.model_dir, kwargs['model_bin'])
self.model = torch.load(model_bin_file)
elif self._current_epoch > 0:
# load checkpoint
ckpt_file_pattern = os.path.join(
self.work_dir, f'{CKPT_PREFIX}_{self._current_epoch:04d}*.pth')
ckpt_files = glob.glob(ckpt_file_pattern)
if len(ckpt_files) == 1:
logger.info('Loading model from checkpoint: %s', ckpt_files[0])
self.model = torch.load(ckpt_files[0])
elif len(ckpt_files) == 0:
raise FileNotFoundError(
f'Failed to load checkpoint file like '
f'{ckpt_file_pattern}. File not found!')
else:
raise AssertionError(f'Expecting one but multiple checkpoint'
f' files are found: {ckpt_files}')
# build corresponding optimizer and loss function
lr = self.cfg.train.optimizer.lr
self.optimizer = optim.Adam(self.model.parameters(), lr)
@@ -126,7 +147,6 @@ class KWSFarfieldTrainer(BaseTrainer):
conf_file = os.path.join(self.work_dir, f'{conf_key}.conf')
update_conf(template_file, conf_file, custom_conf[conf_key])
self.conf_files.append(conf_file)
self._current_epoch = 0
self.stages = (math.floor(self._max_epochs * EASY_RATIO),
math.floor(self._max_epochs * NORMAL_RATIO),
math.floor(self._max_epochs * HARD_RATIO))
@@ -151,30 +171,33 @@ class KWSFarfieldTrainer(BaseTrainer):
logger.info('Start training...')
totaltime = datetime.datetime.now()
next_stage_head_epoch = 0
for stage, num_epoch in enumerate(self.stages):
self.run_stage(stage, num_epoch)
next_stage_head_epoch += num_epoch
epochs_to_run = next_stage_head_epoch - self._current_epoch
self.run_stage(stage, epochs_to_run)
# total time spent
totaltime = datetime.datetime.now() - totaltime
logger.info('Total time spent: {:.2f} hours\n'.format(
totaltime.total_seconds() / 3600.0))
def run_stage(self, stage, num_epoch):
def run_stage(self, stage, epochs_to_run):
"""
Run training stages with correspond data
Args:
stage: id of stage
num_epoch: the number of epoch to run in this stage
epochs_to_run: the number of epoch to run in this stage
"""
if num_epoch <= 0:
if epochs_to_run <= 0:
logger.warning(f'Invalid epoch number, stage {stage} exit!')
return
logger.info(f'Starting stage {stage}...')
dataset, dataloader = self.create_dataloader(
self.conf_files[stage * 2], self.conf_files[stage * 2 + 1])
it = iter(dataloader)
for _ in range(num_epoch):
for _ in range(epochs_to_run):
self._current_epoch += 1
epochtime = datetime.datetime.now()
logger.info('Start epoch %d...', self._current_epoch)
@@ -211,8 +234,9 @@ class KWSFarfieldTrainer(BaseTrainer):
logger.info(val_result)
self._dump_log(val_result)
# check point
ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
self._current_epoch, loss_train_epoch, loss_val_epoch)
ckpt_name = '{}_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
CKPT_PREFIX, self._current_epoch, loss_train_epoch,
loss_val_epoch)
save_path = os.path.join(self.work_dir, ckpt_name)
logger.info(f'Save model to {save_path}')
torch.save(self.model, save_path)
@@ -229,6 +253,14 @@ class KWSFarfieldTrainer(BaseTrainer):
"""
generate validation set
"""
val_dump_file = os.path.join(self.work_dir, 'val_dataset.bin')
if self._current_epoch > 0:
logger.info('Start loading validation set...')
with open(val_dump_file, 'rb') as f:
self.data_val = pickle.load(f)
logger.info('Finish loading validation set!')
return
logger.info('Start generating validation set...')
dataset, dataloader = self.create_dataloader(self.conf_files[2],
self.conf_files[3])
@@ -243,6 +275,9 @@ class KWSFarfieldTrainer(BaseTrainer):
dataloader.stop()
dataset.release()
with open(val_dump_file, 'wb') as f:
pickle.dump(self.data_val, f)
logger.info('Finish generating validation set!')
def create_dataloader(self, base_path, finetune_path):

View File

@@ -81,3 +81,5 @@ class TestKwsFarfieldTrainer(unittest.TestCase):
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files,
f'work_dir:{self.tmp_dir}')
self.assertIn('val_dataset.bin', results_files,
f'work_dir:{self.tmp_dir}')