Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-25 12:39:25 +01:00)

[to #42322933] bug fix: deadlock when the worker thread number is set as high as 90

* fix: load model directly from .pth

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10743508

@@ -54,7 +54,8 @@ class FSMNSeleNetV2Decorator(TorchModel):
         )

     def __del__(self):
-        self.tmp_dir.cleanup()
+        if hasattr(self, 'tmp_dir'):
+            self.tmp_dir.cleanup()

     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         return self.model.forward(input)
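
Why the guard: `__del__` still runs when `__init__` raises partway through, so an instance that never got as far as creating `self.tmp_dir` would otherwise die with a secondary AttributeError during garbage collection. A minimal sketch of the pattern, assuming `tmp_dir` is a `tempfile.TemporaryDirectory` (inferred from the `cleanup()` call); the class below is a made-up stand-in, not the decorator itself:

import tempfile

class TmpDirOwner:
    """Toy stand-in for the decorator's tmp_dir handling."""

    def __init__(self, fail=False):
        if fail:
            # __del__ still runs on this half-built object once it is collected
            raise RuntimeError('failed before tmp_dir was created')
        self.tmp_dir = tempfile.TemporaryDirectory()

    def __del__(self):
        # without the guard this would raise AttributeError when __init__ failed early
        if hasattr(self, 'tmp_dir'):
            self.tmp_dir.cleanup()

try:
    TmpDirOwner(fail=True)
except RuntimeError:
    print('init failed, but no AttributeError from __del__')
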
@@ -188,11 +188,13 @@ class Worker(threading.Thread):


 class KWSDataLoader:
-    """
-    dataset: the dataset reference
-    batchsize: data batch size
-    numworkers: no. of workers
-    prefetch: prefetch factor
+    """ Load and organize audio data with multiple threads
+
+    Args:
+        dataset: the dataset reference
+        batchsize: data batch size
+        numworkers: no. of workers
+        prefetch: prefetch factor
     """

     def __init__(self, dataset, batchsize, numworkers, prefetch=2):

@@ -202,7 +204,7 @@ class KWSDataLoader:
         self.isrun = True

         # data queue
-        self.pool = queue.Queue(batchsize * prefetch)
+        self.pool = queue.Queue(numworkers * prefetch)

         # initialize workers
         self.workerlist = []
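
For context, `queue.Queue(maxsize)` makes `put()` block once `maxsize` items are queued, so the useful capacity has to track the number of producer threads rather than the batch size, which is presumably the reason for switching the size to `numworkers * prefetch`. A small, self-contained illustration of that blocking behaviour:

import queue

pool = queue.Queue(maxsize=2)
pool.put('a')
pool.put('b')                 # queue is now at capacity
print(pool.full())            # -> True
try:
    pool.put('c', block=True, timeout=0.2)   # a producer thread would block here until a consumer calls get()
except queue.Full:
    print('put() timed out: the queue is full')
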

@@ -270,11 +272,11 @@ class KWSDataLoader:
             w.stopWorker()

-        while not self.pool.empty():
-            self.pool.get(block=True, timeout=0.001)
-
         # wait workers terminated
         for w in self.workerlist:
+            while not self.pool.empty():
+                self.pool.get(block=True, timeout=0.01)
             w.join()

         logger.info('KWSDataLoader: All worker stopped.')
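
The deadlock named in the commit title comes from the old shutdown order: the queue was drained once, then every worker was joined, but with many workers (the report mentions 90) some of them can still be blocked inside `pool.put()` on a full queue, so `join()` never returns. Draining inside the per-worker join loop keeps making room until every producer has finished. A minimal sketch of that shutdown pattern using plain threads; the `Worker` class itself is not part of this diff, so the producer below is a stand-in, not the actual modelscope code:

import queue
import threading
import time

NUMWORKERS, PREFETCH = 8, 2
pool = queue.Queue(maxsize=NUMWORKERS * PREFETCH)  # capacity scales with the producer count
running = True

def producer():
    # stand-in for a Worker thread: keep pushing until asked to stop
    while running:
        pool.put('feature-batch', block=True)      # blocks whenever the queue is full

workers = [threading.Thread(target=producer) for _ in range(NUMWORKERS)]
for w in workers:
    w.start()
time.sleep(0.2)          # let the queue fill up and the producers block

running = False
for w in workers:
    # drain while joining: a producer stuck in put() gets room, finishes its put,
    # re-checks the flag and exits; joining without draining could hang forever
    while not pool.empty():
        pool.get(block=True, timeout=0.01)
    w.join()
print('all workers stopped')
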

@@ -117,8 +117,7 @@ class KWSFarfieldTrainer(BaseTrainer):
         self._batch_size = dataloader_config.batch_size_per_gpu
         if 'model_bin' in kwargs:
             model_bin_file = os.path.join(self.model_dir, kwargs['model_bin'])
-            checkpoint = torch.load(model_bin_file)
-            self.model.load_state_dict(checkpoint)
+            self.model = torch.load(model_bin_file)
         # build corresponding optimizer and loss function
         lr = self.cfg.train.optimizer.lr
         self.optimizer = optim.Adam(self.model.parameters(), lr)
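
The old code treated `model_bin` as a bare state dict; the new code expects the file to hold the whole pickled `nn.Module`, which matches how the epoch checkpoints are written below with `torch.save(self.model, ...)` and matches the commit note "load model directly from .pth". A short sketch of the two conventions (the model and file names are placeholders):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                       # stand-in for the KWS model

# old convention: persist only the parameters, rebuild the module, then load them
torch.save(model.state_dict(), 'model_state.pth')
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load('model_state.pth'))

# new convention: pickle the whole module and get an nn.Module back from torch.load
torch.save(model, 'model_full.pth')
restored = torch.load('model_full.pth')

Note that recent PyTorch releases default `torch.load` to `weights_only=True`, so unpickling a full module there needs an explicit `weights_only=False`; the trainer code predates that default.
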

@@ -219,7 +218,9 @@ class KWSFarfieldTrainer(BaseTrainer):
             # check point
             ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
                 self._current_epoch, loss_train_epoch, loss_val_epoch)
-            torch.save(self.model, os.path.join(self.work_dir, ckpt_name))
+            save_path = os.path.join(self.work_dir, ckpt_name)
+            logger.info(f'Save model to {save_path}')
+            torch.save(self.model, save_path)
             # time spent per epoch
             epochtime = datetime.datetime.now() - epochtime
             logger.info('Epoch {:04d} time spent: {:.2f} hours'.format(

@@ -43,7 +43,10 @@ def update_conf(origin_config_file, new_config_file, conf_item: [str, str]):
     def repl(matched):
         key = matched.group(1)
         if key in conf_item:
-            return conf_item[key]
+            value = conf_item[key]
+            if not isinstance(value, str):
+                value = str(value)
+            return value
         else:
             return None
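
The reason for the coercion: `repl` is evidently used as a replacement callback (e.g. with `re.sub`), and Python's `re` module requires such a callback to return a string, so returning a raw int or float straight from `conf_item` raises `TypeError: expected str instance`. A self-contained sketch of the pattern; the `${key}` placeholder syntax below is made up for illustration, since the actual regex used by `update_conf` is not part of this diff:

import re

conf_item = {'num_workers': 8, 'lr': 0.001}

def repl(matched):
    key = matched.group(1)
    if key in conf_item:
        value = conf_item[key]
        if not isinstance(value, str):
            value = str(value)          # re.sub requires the callback to return a str
        return value
    return ''                           # unknown keys are blanked in this sketch

line = 'num_workers = ${num_workers}, lr = ${lr}'
print(re.sub(r'\$\{(\w+)\}', repl, line))   # -> num_workers = 8, lr = 0.001
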