[to #42322933] feat: kws support continue training from a checkpoint

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11940446

* feat: kws support continue training from a checkpoint
* log: add loading model log
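This change lets a far-field KWS training run continue from a saved checkpoint: passing next_epoch=N makes the trainer load the epoch N-1 checkpoint from work_dir before training, while model_bin still takes precedence for loading an exported model. A minimal sketch of how a resumed run could be launched, assuming modelscope's build_trainer; the trainer name, model id and work_dir below are illustrative placeholders, not taken from this commit:

# Sketch only: the trainer name, model id and work_dir are placeholders.
from modelscope.trainers import build_trainer

kwargs = dict(
    model='damo/speech_dfsmn_kws_char_farfield',  # hypothetical model id
    work_dir='./kws_work_dir',  # must contain a checkpoint_0005*.pth file
    next_epoch=6,  # resume at epoch 6, loading epoch 5's checkpoint
)
# The trainer name is a guess; use the key this trainer is registered under.
trainer = build_trainer('speech-kws-farfield', default_args=kwargs)
trainer.train()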
@@ -1,6 +1,8 @@
 import datetime
+import glob
 import math
 import os
+import pickle
 from typing import Callable, Dict, Optional
 
 import numpy as np
@@ -29,6 +31,7 @@ BASETRAIN_CONF_HARD = 'basetrain_hard'
 FINETUNE_CONF_EASY = 'finetune_easy'
 FINETUNE_CONF_NORMAL = 'finetune_normal'
 FINETUNE_CONF_HARD = 'finetune_hard'
+CKPT_PREFIX = 'checkpoint'
 
 EASY_RATIO = 0.1
 NORMAL_RATIO = 0.6
@@ -110,9 +113,27 @@ class KWSFarfieldTrainer(BaseTrainer):
         if 'single_rate' in kwargs:
             self._single_rate = kwargs['single_rate']
         self._batch_size = dataloader_config.batch_size_per_gpu
+        next_epoch = kwargs.get('next_epoch', 1)
+        self._current_epoch = next_epoch - 1
         if 'model_bin' in kwargs:
             model_bin_file = os.path.join(self.model_dir, kwargs['model_bin'])
             self.model = torch.load(model_bin_file)
+        elif self._current_epoch > 0:
+            # load checkpoint
+            ckpt_file_pattern = os.path.join(
+                self.work_dir, f'{CKPT_PREFIX}_{self._current_epoch:04d}*.pth')
+            ckpt_files = glob.glob(ckpt_file_pattern)
+            if len(ckpt_files) == 1:
+                logger.info('Loading model from checkpoint: %s', ckpt_files[0])
+                self.model = torch.load(ckpt_files[0])
+            elif len(ckpt_files) == 0:
+                raise FileNotFoundError(
+                    f'Failed to load checkpoint file like '
+                    f'{ckpt_file_pattern}. File not found!')
+            else:
+                raise AssertionError(f'Expecting one but multiple checkpoint'
+                                     f' files are found: {ckpt_files}')
 
         # build corresponding optimizer and loss function
         lr = self.cfg.train.optimizer.lr
         self.optimizer = optim.Adam(self.model.parameters(), lr)
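The elif branch above drives resumption entirely from _current_epoch (i.e. next_epoch - 1): it globs work_dir for that epoch's checkpoint and insists on exactly one match, since the train/val losses embedded in the file name are not known in advance. The same discovery rule as a standalone helper; find_resume_checkpoint is an illustrative name, not part of the trainer:

# Sketch of the discovery rule above; files are assumed to follow the
# checkpoint_{epoch:04d}_loss_train_{...}_loss_val_{...}.pth naming scheme.
import glob
import os

def find_resume_checkpoint(work_dir: str, current_epoch: int) -> str:
    pattern = os.path.join(work_dir, f'checkpoint_{current_epoch:04d}*.pth')
    matches = glob.glob(pattern)
    if len(matches) == 1:
        return matches[0]
    if not matches:
        raise FileNotFoundError(f'No checkpoint file matching {pattern}')
    raise AssertionError(f'Expected one checkpoint, found several: {matches}')

Note that torch.load can restore the model directly here because the trainer saves with torch.save(self.model, ...), pickling the whole module rather than a state_dict.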
@@ -126,7 +147,6 @@ class KWSFarfieldTrainer(BaseTrainer):
             conf_file = os.path.join(self.work_dir, f'{conf_key}.conf')
             update_conf(template_file, conf_file, custom_conf[conf_key])
             self.conf_files.append(conf_file)
-        self._current_epoch = 0
         self.stages = (math.floor(self._max_epochs * EASY_RATIO),
                        math.floor(self._max_epochs * NORMAL_RATIO),
                        math.floor(self._max_epochs * HARD_RATIO))
@@ -151,30 +171,33 @@ class KWSFarfieldTrainer(BaseTrainer):
         logger.info('Start training...')
         totaltime = datetime.datetime.now()
 
+        next_stage_head_epoch = 0
         for stage, num_epoch in enumerate(self.stages):
-            self.run_stage(stage, num_epoch)
+            next_stage_head_epoch += num_epoch
+            epochs_to_run = next_stage_head_epoch - self._current_epoch
+            self.run_stage(stage, epochs_to_run)
 
         # total time spent
         totaltime = datetime.datetime.now() - totaltime
         logger.info('Total time spent: {:.2f} hours\n'.format(
             totaltime.total_seconds() / 3600.0))
 
-    def run_stage(self, stage, num_epoch):
+    def run_stage(self, stage, epochs_to_run):
         """
         Run training stages with correspond data
 
         Args:
             stage: id of stage
-            num_epoch: the number of epoch to run in this stage
+            epochs_to_run: the number of epoch to run in this stage
         """
-        if num_epoch <= 0:
+        if epochs_to_run <= 0:
             logger.warning(f'Invalid epoch number, stage {stage} exit!')
             return
         logger.info(f'Starting stage {stage}...')
         dataset, dataloader = self.create_dataloader(
             self.conf_files[stage * 2], self.conf_files[stage * 2 + 1])
         it = iter(dataloader)
-        for _ in range(num_epoch):
+        for _ in range(epochs_to_run):
             self._current_epoch += 1
             epochtime = datetime.datetime.now()
             logger.info('Start epoch %d...', self._current_epoch)
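train() now tracks the cumulative epoch boundary of each stage: next_stage_head_epoch is the last epoch belonging to the stage, and epochs_to_run is the portion of the stage still outstanding, so a resumed run skips completed stages (run_stage exits early on a non-positive count) and finishes only the remainder of the interrupted one. A small sketch of the arithmetic; plan_stages is an illustrative stand-in for the loop in train():

# With stages (10, 60, 30): a fresh run executes all epochs, while a run
# resumed at next_epoch=31 (_current_epoch == 30) skips stage 0, runs the
# remaining 40 epochs of stage 1 and the full 30 of stage 2.
def plan_stages(stages, current_epoch):
    plan, head = [], 0
    for num_epoch in stages:
        head += num_epoch
        to_run = head - current_epoch
        if to_run > 0:
            plan.append(to_run)
            current_epoch = head  # run_stage advances _current_epoch per epoch
        else:
            plan.append(0)  # non-positive count: run_stage warns and returns
    return plan

assert plan_stages((10, 60, 30), 0) == [10, 60, 30]
assert plan_stages((10, 60, 30), 30) == [0, 40, 30]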
@@ -211,8 +234,9 @@ class KWSFarfieldTrainer(BaseTrainer):
             logger.info(val_result)
             self._dump_log(val_result)
             # check point
-            ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
-                self._current_epoch, loss_train_epoch, loss_val_epoch)
+            ckpt_name = '{}_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
+                CKPT_PREFIX, self._current_epoch, loss_train_epoch,
+                loss_val_epoch)
             save_path = os.path.join(self.work_dir, ckpt_name)
             logger.info(f'Save model to {save_path}')
             torch.save(self.model, save_path)
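Checkpoint names are now built from CKPT_PREFIX, keeping this save path and the resume glob in the @@ -110,9 +113,27 @@ hunk on the same naming scheme. For example (loss values made up):

CKPT_PREFIX = 'checkpoint'
name = '{}_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
    CKPT_PREFIX, 7, 0.1234, 0.0567)
assert name == 'checkpoint_0007_loss_train_0.1234_loss_val_0.0567.pth'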
@@ -229,6 +253,14 @@ class KWSFarfieldTrainer(BaseTrainer):
         """
         generate validation set
         """
+        val_dump_file = os.path.join(self.work_dir, 'val_dataset.bin')
+        if self._current_epoch > 0:
+            logger.info('Start loading validation set...')
+            with open(val_dump_file, 'rb') as f:
+                self.data_val = pickle.load(f)
+            logger.info('Finish loading validation set!')
+            return
+
         logger.info('Start generating validation set...')
         dataset, dataloader = self.create_dataloader(self.conf_files[2],
                                                      self.conf_files[3])
@@ -243,6 +275,9 @@ class KWSFarfieldTrainer(BaseTrainer):
 
         dataloader.stop()
         dataset.release()
+
+        with open(val_dump_file, 'wb') as f:
+            pickle.dump(self.data_val, f)
         logger.info('Finish generating validation set!')
 
     def create_dataloader(self, base_path, finetune_path):
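The last two hunks make the validation set a one-time artifact: a fresh run generates it and pickles it to val_dataset.bin in work_dir, while a resumed run (_current_epoch > 0) reloads the dump instead of regenerating it, so validation losses stay comparable across the interruption. The pattern in isolation, a sketch in which generate_val stands in for the dataloader-driven generation in the diff:

import os
import pickle

def load_or_generate_val(work_dir, current_epoch, generate_val):
    val_dump_file = os.path.join(work_dir, 'val_dataset.bin')
    if current_epoch > 0:
        # resumed run: reuse the set dumped by the original run
        with open(val_dump_file, 'rb') as f:
            return pickle.load(f)
    data_val = generate_val()  # fresh run: generate once...
    with open(val_dump_file, 'wb') as f:
        pickle.dump(data_val, f)  # ...then persist for future resumes
    return data_val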
@@ -81,3 +81,5 @@ class TestKwsFarfieldTrainer(unittest.TestCase):
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files,
                       f'work_dir:{self.tmp_dir}')
+        self.assertIn('val_dataset.bin', results_files,
+                      f'work_dir:{self.tmp_dir}')