diff --git a/modelscope/trainers/audio/kws_farfield_trainer.py b/modelscope/trainers/audio/kws_farfield_trainer.py
index a43d20eb..508517a7 100644
--- a/modelscope/trainers/audio/kws_farfield_trainer.py
+++ b/modelscope/trainers/audio/kws_farfield_trainer.py
@@ -1,6 +1,8 @@
 import datetime
+import glob
 import math
 import os
+import pickle
 from typing import Callable, Dict, Optional
 
 import numpy as np
@@ -29,6 +31,7 @@ BASETRAIN_CONF_HARD = 'basetrain_hard'
 FINETUNE_CONF_EASY = 'finetune_easy'
 FINETUNE_CONF_NORMAL = 'finetune_normal'
 FINETUNE_CONF_HARD = 'finetune_hard'
+CKPT_PREFIX = 'checkpoint'
 
 EASY_RATIO = 0.1
 NORMAL_RATIO = 0.6
@@ -110,9 +113,27 @@ class KWSFarfieldTrainer(BaseTrainer):
         if 'single_rate' in kwargs:
             self._single_rate = kwargs['single_rate']
         self._batch_size = dataloader_config.batch_size_per_gpu
+        next_epoch = kwargs.get('next_epoch', 1)
+        self._current_epoch = next_epoch - 1
         if 'model_bin' in kwargs:
             model_bin_file = os.path.join(self.model_dir, kwargs['model_bin'])
             self.model = torch.load(model_bin_file)
+        elif self._current_epoch > 0:
+            # load checkpoint
+            ckpt_file_pattern = os.path.join(
+                self.work_dir, f'{CKPT_PREFIX}_{self._current_epoch:04d}*.pth')
+            ckpt_files = glob.glob(ckpt_file_pattern)
+            if len(ckpt_files) == 1:
+                logger.info('Loading model from checkpoint: %s', ckpt_files[0])
+                self.model = torch.load(ckpt_files[0])
+            elif len(ckpt_files) == 0:
+                raise FileNotFoundError(
+                    f'Failed to load checkpoint file like '
+                    f'{ckpt_file_pattern}. File not found!')
+            else:
+                raise AssertionError(f'Expecting one but multiple checkpoint'
+                                     f' files are found: {ckpt_files}')
+
         # build corresponding optimizer and loss function
         lr = self.cfg.train.optimizer.lr
         self.optimizer = optim.Adam(self.model.parameters(), lr)
@@ -126,7 +147,6 @@ class KWSFarfieldTrainer(BaseTrainer):
             conf_file = os.path.join(self.work_dir, f'{conf_key}.conf')
             update_conf(template_file, conf_file, custom_conf[conf_key])
             self.conf_files.append(conf_file)
-        self._current_epoch = 0
         self.stages = (math.floor(self._max_epochs * EASY_RATIO),
                        math.floor(self._max_epochs * NORMAL_RATIO),
                        math.floor(self._max_epochs * HARD_RATIO))
@@ -151,30 +171,33 @@ class KWSFarfieldTrainer(BaseTrainer):
 
         logger.info('Start training...')
         totaltime = datetime.datetime.now()
+        next_stage_head_epoch = 0
         for stage, num_epoch in enumerate(self.stages):
-            self.run_stage(stage, num_epoch)
+            next_stage_head_epoch += num_epoch
+            epochs_to_run = next_stage_head_epoch - self._current_epoch
+            self.run_stage(stage, epochs_to_run)
 
         # total time spent
         totaltime = datetime.datetime.now() - totaltime
         logger.info('Total time spent: {:.2f} hours\n'.format(
             totaltime.total_seconds() / 3600.0))
 
-    def run_stage(self, stage, num_epoch):
+    def run_stage(self, stage, epochs_to_run):
         """
         Run training stages with correspond data
 
         Args:
             stage: id of stage
-            num_epoch: the number of epoch to run in this stage
+            epochs_to_run: the number of epoch to run in this stage
         """
-        if num_epoch <= 0:
+        if epochs_to_run <= 0:
             logger.warning(f'Invalid epoch number, stage {stage} exit!')
             return
         logger.info(f'Starting stage {stage}...')
         dataset, dataloader = self.create_dataloader(
             self.conf_files[stage * 2], self.conf_files[stage * 2 + 1])
         it = iter(dataloader)
-        for _ in range(num_epoch):
+        for _ in range(epochs_to_run):
             self._current_epoch += 1
             epochtime = datetime.datetime.now()
             logger.info('Start epoch %d...', self._current_epoch)
@@ -211,8 +234,9 @@ class KWSFarfieldTrainer(BaseTrainer):
             logger.info(val_result)
             self._dump_log(val_result)
             # check point
-            ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
-                self._current_epoch, loss_train_epoch, loss_val_epoch)
+            ckpt_name = '{}_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
+                CKPT_PREFIX, self._current_epoch, loss_train_epoch,
+                loss_val_epoch)
             save_path = os.path.join(self.work_dir, ckpt_name)
             logger.info(f'Save model to {save_path}')
             torch.save(self.model, save_path)
@@ -229,6 +253,14 @@ class KWSFarfieldTrainer(BaseTrainer):
         """
         generate validation set
         """
+        val_dump_file = os.path.join(self.work_dir, 'val_dataset.bin')
+        if self._current_epoch > 0:
+            logger.info('Start loading validation set...')
+            with open(val_dump_file, 'rb') as f:
+                self.data_val = pickle.load(f)
+            logger.info('Finish loading validation set!')
+            return
+
         logger.info('Start generating validation set...')
         dataset, dataloader = self.create_dataloader(self.conf_files[2],
                                                      self.conf_files[3])
@@ -243,6 +275,9 @@ class KWSFarfieldTrainer(BaseTrainer):
 
         dataloader.stop()
         dataset.release()
+
+        with open(val_dump_file, 'wb') as f:
+            pickle.dump(self.data_val, f)
         logger.info('Finish generating validation set!')
 
     def create_dataloader(self, base_path, finetune_path):
diff --git a/tests/trainers/audio/test_kws_farfield_trainer.py b/tests/trainers/audio/test_kws_farfield_trainer.py
index 70b68a11..cc2b38f6 100644
--- a/tests/trainers/audio/test_kws_farfield_trainer.py
+++ b/tests/trainers/audio/test_kws_farfield_trainer.py
@@ -81,3 +81,5 @@ class TestKwsFarfieldTrainer(unittest.TestCase):
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files,
                       f'work_dir:{self.tmp_dir}')
+        self.assertIn('val_dataset.bin', results_files,
+                      f'work_dir:{self.tmp_dir}')
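
Usage note: a minimal sketch of how the resume path introduced above might be exercised. The trainer key, model id, and work_dir below are illustrative assumptions, not part of this diff; only the next_epoch keyword, the checkpoint_{epoch:04d}*.pth naming, and the pickled val_dataset.bin come from the change itself.

# Resume sketch (assumed names): rebuild the trainer with the same work_dir as
# the interrupted run and pass next_epoch. The trainer then globs
# checkpoint_0005*.pth, reloads the pickled validation set (val_dataset.bin),
# and skips the epochs already covered by earlier stages.
from modelscope.trainers import build_trainer

kwargs = dict(
    model='damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya',  # placeholder model id
    work_dir='./kws_work_dir',  # same work_dir that holds the saved checkpoints
    next_epoch=6,  # resume from the checkpoint written at the end of epoch 5
    # ...plus the same dataset/custom_conf kwargs that were used for the original run
)
trainer = build_trainer(
    name='speech_dfsmn_kws_char_farfield',  # assumed trainer registry key
    default_args=kwargs)
trainer.train()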