From 379867739548f394d0fa349ba07afe04adf4c8b6 Mon Sep 17 00:00:00 2001 From: "bin.xue" Date: Wed, 16 Nov 2022 18:39:40 +0800 Subject: [PATCH] [to #42322933] bug fix: deadlock when setting the thread number up to 90 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10743508 * fix: load model directly from .pth --- modelscope/models/audio/kws/farfield/model.py | 3 ++- .../audio/kws_farfield_dataset.py | 18 ++++++++++-------- .../trainers/audio/kws_farfield_trainer.py | 7 ++++--- modelscope/utils/audio/audio_utils.py | 5 ++++- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/modelscope/models/audio/kws/farfield/model.py b/modelscope/models/audio/kws/farfield/model.py index af1c0a27..ee0301f9 100644 --- a/modelscope/models/audio/kws/farfield/model.py +++ b/modelscope/models/audio/kws/farfield/model.py @@ -54,7 +54,8 @@ class FSMNSeleNetV2Decorator(TorchModel): ) def __del__(self): - self.tmp_dir.cleanup() + if hasattr(self, 'tmp_dir'): + self.tmp_dir.cleanup() def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: return self.model.forward(input) diff --git a/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py b/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py index 8c518ec9..d4866204 100644 --- a/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py +++ b/modelscope/msdatasets/task_datasets/audio/kws_farfield_dataset.py @@ -188,11 +188,13 @@ class Worker(threading.Thread): class KWSDataLoader: - """ - dataset: the dataset reference - batchsize: data batch size - numworkers: no. of workers - prefetch: prefetch factor + """ Load and organize audio data with multiple threads + + Args: + dataset: the dataset reference + batchsize: data batch size + numworkers: no. of workers + prefetch: prefetch factor """ def __init__(self, dataset, batchsize, numworkers, prefetch=2): @@ -202,7 +204,7 @@ class KWSDataLoader: self.isrun = True # data queue - self.pool = queue.Queue(batchsize * prefetch) + self.pool = queue.Queue(numworkers * prefetch) # initialize workers self.workerlist = [] @@ -270,11 +272,11 @@ class KWSDataLoader: w.stopWorker() while not self.pool.empty(): - self.pool.get(block=True, timeout=0.001) + self.pool.get(block=True, timeout=0.01) # wait workers terminated for w in self.workerlist: while not self.pool.empty(): - self.pool.get(block=True, timeout=0.001) + self.pool.get(block=True, timeout=0.01) w.join() logger.info('KWSDataLoader: All worker stopped.') diff --git a/modelscope/trainers/audio/kws_farfield_trainer.py b/modelscope/trainers/audio/kws_farfield_trainer.py index 85c1a496..9d6013e9 100644 --- a/modelscope/trainers/audio/kws_farfield_trainer.py +++ b/modelscope/trainers/audio/kws_farfield_trainer.py @@ -117,8 +117,7 @@ class KWSFarfieldTrainer(BaseTrainer): self._batch_size = dataloader_config.batch_size_per_gpu if 'model_bin' in kwargs: model_bin_file = os.path.join(self.model_dir, kwargs['model_bin']) - checkpoint = torch.load(model_bin_file) - self.model.load_state_dict(checkpoint) + self.model = torch.load(model_bin_file) # build corresponding optimizer and loss function lr = self.cfg.train.optimizer.lr self.optimizer = optim.Adam(self.model.parameters(), lr) @@ -219,7 +218,9 @@ class KWSFarfieldTrainer(BaseTrainer): # check point ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format( self._current_epoch, loss_train_epoch, loss_val_epoch) - torch.save(self.model, os.path.join(self.work_dir, ckpt_name)) + save_path = os.path.join(self.work_dir, ckpt_name) + logger.info(f'Save model to {save_path}') + torch.save(self.model, save_path) # time spent per epoch epochtime = datetime.datetime.now() - epochtime logger.info('Epoch {:04d} time spent: {:.2f} hours'.format( diff --git a/modelscope/utils/audio/audio_utils.py b/modelscope/utils/audio/audio_utils.py index 1ae5c8d2..c56359bd 100644 --- a/modelscope/utils/audio/audio_utils.py +++ b/modelscope/utils/audio/audio_utils.py @@ -43,7 +43,10 @@ def update_conf(origin_config_file, new_config_file, conf_item: [str, str]): def repl(matched): key = matched.group(1) if key in conf_item: - return conf_item[key] + value = conf_item[key] + if not isinstance(value, str): + value = str(value) + return value else: return None