mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
[to #42322933] add far field KWS trainer
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10275823
This commit is contained in:
3
data/test/audios/noise_2ch.wav
Normal file
3
data/test/audios/noise_2ch.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e8d653a9a1ee49789c3df38e8da96af7118e0d8336d6ed12cd6458efa015071d
|
||||
size 2327764
|
||||
3
data/test/audios/wake_word_with_label_xyxy.wav
Normal file
3
data/test/audios/wake_word_with_label_xyxy.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c589d77404ea17d4d24daeb8624dce7e1ac919dc75e6bed44ea9d116f0514150
|
||||
size 68524
|
||||
@@ -285,6 +285,7 @@ class Trainers(object):
|
||||
|
||||
# audio trainers
|
||||
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
|
||||
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
|
||||
|
||||
|
||||
class Preprocessors(object):
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from typing import Dict, Optional
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import TorchModel
|
||||
from modelscope.models.base import Tensor
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.audio.audio_utils import update_conf
|
||||
from modelscope.utils.constant import Tasks
|
||||
from .fsmn_sele_v2 import FSMNSeleNetV2
|
||||
|
||||
|
||||
@@ -20,48 +19,38 @@ class FSMNSeleNetV2Decorator(TorchModel):
|
||||
|
||||
MODEL_TXT = 'model.txt'
|
||||
SC_CONFIG = 'sound_connect.conf'
|
||||
SC_CONF_ITEM_KWS_MODEL = '${kws_model}'
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
def __init__(self,
|
||||
model_dir: str,
|
||||
training: Optional[bool] = False,
|
||||
*args,
|
||||
**kwargs):
|
||||
"""initialize the dfsmn model from the `model_dir` path.
|
||||
|
||||
Args:
|
||||
model_dir (str): the model path.
|
||||
"""
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
|
||||
model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
|
||||
model_bin_file = os.path.join(model_dir,
|
||||
ModelFile.TORCH_MODEL_BIN_FILE)
|
||||
self._model = None
|
||||
if os.path.exists(model_bin_file):
|
||||
kwargs.pop('device')
|
||||
self._model = FSMNSeleNetV2(*args, **kwargs)
|
||||
checkpoint = torch.load(model_bin_file)
|
||||
self._model.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
self._sc = None
|
||||
if os.path.exists(model_txt_file):
|
||||
with open(sc_config_file) as f:
|
||||
lines = f.readlines()
|
||||
with open(sc_config_file, 'w') as f:
|
||||
for line in lines:
|
||||
if self.SC_CONF_ITEM_KWS_MODEL in line:
|
||||
line = line.replace(self.SC_CONF_ITEM_KWS_MODEL,
|
||||
model_txt_file)
|
||||
f.write(line)
|
||||
import py_sound_connect
|
||||
self._sc = py_sound_connect.SoundConnect(sc_config_file)
|
||||
self.size_in = self._sc.bytesPerBlockIn()
|
||||
self.size_out = self._sc.bytesPerBlockOut()
|
||||
|
||||
if self._model is None and self._sc is None:
|
||||
raise Exception(
|
||||
f'Invalid model directory! Neither {model_txt_file} nor {model_bin_file} exists.'
|
||||
)
|
||||
if training:
|
||||
self.model = FSMNSeleNetV2(*args, **kwargs)
|
||||
else:
|
||||
sc_config_file = os.path.join(model_dir, self.SC_CONFIG)
|
||||
model_txt_file = os.path.join(model_dir, self.MODEL_TXT)
|
||||
self._sc = None
|
||||
if os.path.exists(model_txt_file):
|
||||
conf_dict = dict(mode=56542, kws_model=model_txt_file)
|
||||
update_conf(sc_config_file, sc_config_file, conf_dict)
|
||||
import py_sound_connect
|
||||
self._sc = py_sound_connect.SoundConnect(sc_config_file)
|
||||
self.size_in = self._sc.bytesPerBlockIn()
|
||||
self.size_out = self._sc.bytesPerBlockOut()
|
||||
else:
|
||||
raise Exception(
|
||||
f'Invalid model directory! Failed to load model file: {model_txt_file}.'
|
||||
)
|
||||
|
||||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
...
|
||||
return self.model.forward(input)
|
||||
|
||||
def forward_decode(self, data: bytes):
|
||||
result = {'pcm': self._sc.process(data, self.size_out)}
|
||||
|
||||
21
modelscope/msdatasets/task_datasets/audio/__init__.py
Normal file
21
modelscope/msdatasets/task_datasets/audio/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .kws_farfield_dataset import KWSDataset, KWSDataLoader
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'kws_farfield_dataset': ['KWSDataset', 'KWSDataLoader'],
|
||||
}
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
Used to prepare simulated data.
|
||||
"""
|
||||
import math
|
||||
import os.path
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
BLOCK_DEC = 2
|
||||
BLOCK_CAT = 3
|
||||
FBANK_SIZE = 40
|
||||
LABEL_SIZE = 1
|
||||
LABEL_GAIN = 100.0
|
||||
|
||||
|
||||
class KWSDataset:
    """Simulated dataset for keyword spotting (KWS) and VAD training.

    Minibatches are produced on demand by ``py_sound_connect.FeatSimuKWS``
    simulators. One simulator is created per worker; the first
    ``ceil(numworkers * basetrainratio)`` workers use the basetrain
    configuration and the remaining ones use the finetune configuration.

    Args:
        conf_basetrain: basetrain configure file path
        conf_finetune: finetune configure file path, ``None`` allowed (all
            workers then use the basetrain configuration)
        numworkers: no. of workers
        basetrainratio: basetrain workers ratio
        numclasses: no. of nn output classes, 2 classes to generate vad label
        blockdec: block decimation
        blockcat: block concatenation
    """

    def __init__(self,
                 conf_basetrain,
                 conf_finetune,
                 numworkers,
                 basetrainratio,
                 numclasses,
                 blockdec=BLOCK_DEC,
                 blockcat=BLOCK_CAT):
        super().__init__()
        self.numclasses = numclasses
        self.blockdec = blockdec
        self.blockcat = blockcat
        self.sims_base = []
        self.sims_senior = []
        # Defined up front so release() is safe even if setup_sims() fails.
        self.base_conf = None
        self.senior_conf = None
        self.setup_sims(conf_basetrain, conf_finetune, numworkers,
                        basetrainratio)

    def release(self):
        """Drop every simulator and then the configuration objects.

        The simulators are dropped first because they were constructed from
        ``conf.params``.
        NOTE(review): presumably the native simulators keep referring to the
        conf object's parameters, so the conf must outlive them - confirm
        against py_sound_connect.
        """
        # `del sim` inside a for-loop only unbinds the loop variable; the
        # lists have to be emptied to actually drop the references.
        self.sims_base.clear()
        self.sims_senior.clear()
        self.base_conf = None
        self.senior_conf = None
        logger.info('KWSDataset: Released.')

    def setup_sims(self, conf_basetrain, conf_finetune, numworkers,
                   basetrainratio):
        """Create one feature simulator per worker.

        Args:
            conf_basetrain: basetrain configure file path, must exist
            conf_finetune: finetune configure file path, or ``None`` to use
                the basetrain configuration for every worker
            numworkers: total no. of simulators to create
            basetrainratio: fraction of simulators using the basetrain conf

        Raises:
            ValueError: if a given configuration file does not exist.
        """
        if not os.path.exists(conf_basetrain):
            raise ValueError(f'{conf_basetrain} does not exist!')
        if conf_finetune is not None and not os.path.exists(conf_finetune):
            raise ValueError(f'{conf_finetune} does not exist!')
        # imported lazily so the module can be loaded without the native lib
        import py_sound_connect
        logger.info('KWSDataset init SoundConnect...')
        if conf_finetune is None:
            num_base = numworkers
        else:
            num_base = math.ceil(numworkers * basetrainratio)
        num_senior = numworkers - num_base
        # hold by fields to avoid python releasing conf object
        self.base_conf = py_sound_connect.ConfigFile(conf_basetrain)
        if conf_finetune is not None:
            self.senior_conf = py_sound_connect.ConfigFile(conf_finetune)
        for _ in range(num_base):
            self.sims_base.append(
                py_sound_connect.FeatSimuKWS(self.base_conf.params))
        for _ in range(num_senior):
            self.sims_senior.append(
                py_sound_connect.FeatSimuKWS(self.senior_conf.params))
        logger.info('KWSDataset init SoundConnect finished.')

    def getBatch(self, id):
        """
        Generate a data batch

        Args:
            id: worker id

        Return: time x channel x feature, label
        """
        fs = self.get_sim(id)
        fs.processBatch()
        # get multi-channel feature vector size
        featsize = fs.featSize()
        # get label vector size
        labelsize = fs.labelSize()
        # no. of fe output channels
        numchs = featsize // FBANK_SIZE
        # get raw data: each row is one frame of features followed by labels
        data = np.frombuffer(fs.feat(), dtype='float32')
        data = data.reshape((-1, featsize + labelsize))

        # convert float label to int
        label = data[:, FBANK_SIZE * numchs:]

        if self.numclasses == 2:
            # generate vad label
            # Built as a new array instead of assigning into `label`, which
            # is a view of the frombuffer array: depending on what fs.feat()
            # returns that array may be read-only or may alias the
            # simulator's own buffer.
            label = np.where(label > 0.0, np.float32(1.0), label)
        else:
            # generate kws label
            label = np.round(label * LABEL_GAIN)
            label[label > self.numclasses - 1] = 0.0

        # decimated size
        size1 = int(np.ceil(
            label.shape[0] / self.blockdec)) - self.blockcat + 1

        label1 = np.zeros((size1, LABEL_SIZE), dtype='float32')
        # time x channel x feature
        featall = np.zeros((size1, numchs, FBANK_SIZE * self.blockcat),
                           dtype='float32')
        frames = np.arange(size1)

        # label decimation: take the label of the centre frame of each
        # concatenated block
        label1[:, :] = label[(frames + self.blockcat // 2) * self.blockdec, :]

        # feature decimation and concatenation
        for n in range(numchs):
            feat = data[:, FBANK_SIZE * n:FBANK_SIZE * (n + 1)]
            for i in range(self.blockcat):
                featall[:, n, FBANK_SIZE * i:FBANK_SIZE * (i + 1)] = \
                    feat[(frames + i) * self.blockdec, :]

        return torch.from_numpy(featall), torch.from_numpy(label1).long()

    def get_sim(self, id):
        """Return the simulator owned by worker ``id``.

        Worker ids below ``len(self.sims_base)`` map to basetrain simulators,
        the remaining ids map to finetune simulators.
        """
        num_base = len(self.sims_base)
        if id < num_base:
            fs = self.sims_base[id]
        else:
            fs = self.sims_senior[id - num_base]
        return fs
|
||||
|
||||
|
||||
class Worker(threading.Thread):
    """Producer thread that keeps feeding simulated minibatches to a queue.

    Args:
        id: worker id, also used to select the simulator in the dataset
        dataset: the dataset, must provide ``getBatch(id)``
        pool: queue as the global data buffer
    """

    def __init__(self, id, dataset, pool):
        super().__init__()

        self.id = id
        self.dataset = dataset
        self.pool = pool
        # cleared by stopWorker() to ask the loop in run() to finish
        self.isrun = True
        # number of loop iterations started so far, only used for logging
        self.nn = 0

    def run(self):
        """Generate batches and put them into the pool until stopped."""
        while self.isrun:
            self.nn += 1
            logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:1')

            # the stop flag is re-checked before every potentially slow step
            if not self.isrun:
                continue
            # get simulated minibatch
            data = self.dataset.getBatch(self.id)
            logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:2')

            if not self.isrun:
                continue
            # put data into buffer
            self.pool.put(data)
            logger.debug(f'Worker {self.id:02d} running {self.nn:05d}:3')

        logger.info('KWSDataLoader: Worker {:02d} stopped.'.format(self.id))

    def stopWorker(self):
        """
        stop the worker thread
        """
        self.isrun = False
|
||||
|
||||
|
||||
class KWSDataLoader:
    """Multi-threaded loader that groups simulated minibatches into batches.

    Worker threads push ``(feature, label)`` pairs into a bounded queue;
    ``__next__`` pops them and stacks pairs whose feature tensors have the
    same shape.

    Args:
        dataset: the dataset reference
        batchsize: data batch size
        numworkers: no. of workers
        prefetch: prefetch factor
    """

    def __init__(self, dataset, batchsize, numworkers, prefetch=2):
        self.dataset = dataset
        self.batchsize = batchsize
        # feature-shape string -> list of pending (feature, label) pairs
        self.datamap = {}
        self.isrun = True

        # data queue
        self.pool = queue.Queue(batchsize * prefetch)

        # initialize workers
        self.workerlist = []
        for id in range(numworkers):
            w = Worker(id, dataset, self.pool)
            self.workerlist.append(w)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next full batch as ``(feattensor, labeltensor)``.

        Blocks until ``batchsize`` minibatches with an identical feature
        shape have been collected. Once the loader has been stopped it
        returns ``(None, None)`` instead of raising ``StopIteration``.
        """
        while self.isrun:
            # get data from common data pool
            data = self.pool.get()
            self.pool.task_done()

            # group minibatches with the same shape
            key = str(data[0].shape)
            batchl = self.datamap.setdefault(key, [])
            batchl.append(data)

            # keep collecting until a full data batch is available
            if len(batchl) < self.batchsize:
                continue

            featbatch = [feat for feat, _ in batchl]
            labelbatch = [label for _, label in batchl]
            batchl.clear()

            feattensor = torch.stack(featbatch, dim=0)
            labeltensor = torch.stack(labelbatch, dim=0)

            if feattensor.shape[-2] == 1:
                logger.debug('KWSDataLoader: Basetrain batch.')
            else:
                logger.debug('KWSDataLoader: Finetune batch.')

            return feattensor, labeltensor

        return None, None

    def start(self):
        """
        start multi-thread data loader
        """
        for w in self.workerlist:
            w.start()

    def stop(self):
        """
        stop data loader

        Asks every worker to stop and waits until all of them have exited.
        """
        logger.info('KWSDataLoader: Stopping...')
        self.isrun = False

        for w in self.workerlist:
            w.stopWorker()

        # wait workers terminated
        # A worker may be blocked in pool.put() on a full queue, so keep
        # draining while waiting. A single drain followed by an unbounded
        # join() could hang if the queue fills up again in between.
        for w in self.workerlist:
            while w.is_alive():
                self._drain_pool()
                w.join(timeout=0.01)
        self._drain_pool()
        logger.info('KWSDataLoader: All worker stopped.')

    def _drain_pool(self):
        """Discard everything currently queued without blocking."""
        while True:
            try:
                self.pool.get_nowait()
            except queue.Empty:
                return
|
||||
279
modelscope/trainers/audio/kws_farfield_trainer.py
Normal file
279
modelscope/trainers/audio/kws_farfield_trainer.py
Normal file
@@ -0,0 +1,279 @@
|
||||
import datetime
|
||||
import math
|
||||
import os
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch import optim as optim
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models import Model, TorchModel
|
||||
from modelscope.msdatasets.task_datasets.audio import KWSDataLoader, KWSDataset
|
||||
from modelscope.trainers.base import BaseTrainer
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.utils.audio.audio_utils import update_conf
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
|
||||
from modelscope.utils.data_utils import to_device
|
||||
from modelscope.utils.device import create_device
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
|
||||
init_dist, is_master)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
BASETRAIN_CONF_EASY = 'basetrain_easy'
|
||||
BASETRAIN_CONF_NORMAL = 'basetrain_normal'
|
||||
BASETRAIN_CONF_HARD = 'basetrain_hard'
|
||||
FINETUNE_CONF_EASY = 'finetune_easy'
|
||||
FINETUNE_CONF_NORMAL = 'finetune_normal'
|
||||
FINETUNE_CONF_HARD = 'finetune_hard'
|
||||
|
||||
EASY_RATIO = 0.1
|
||||
NORMAL_RATIO = 0.6
|
||||
HARD_RATIO = 0.3
|
||||
BASETRAIN_RATIO = 0.5
|
||||
|
||||
|
||||
@TRAINERS.register_module(module_name=Trainers.speech_dfsmn_kws_char_farfield)
class KWSFarfieldTrainer(BaseTrainer):
    """Trainer for the far-field DFSMN keyword spotting model.

    Training data is simulated on the fly by ``KWSDataset`` and
    ``KWSDataLoader``. Six simulator configuration files (basetrain/finetune
    for easy, normal and hard) are rendered from templates in the model
    directory. Training then runs three consecutive stages (easy, normal,
    hard) of ``floor(max_epochs * ratio)`` epochs each, so the total may be
    smaller than ``max_epochs``.
    """
    DEFAULT_WORK_DIR = './work_dir'
    # ordered as (basetrain, finetune) pairs per stage; run_stage() and
    # gen_val() index self.conf_files relying on this order
    conf_keys = (BASETRAIN_CONF_EASY, FINETUNE_CONF_EASY,
                 BASETRAIN_CONF_NORMAL, FINETUNE_CONF_NORMAL,
                 BASETRAIN_CONF_HARD, FINETUNE_CONF_HARD)

    def __init__(self,
                 model: str,
                 work_dir: str,
                 cfg_file: Optional[str] = None,
                 arg_parse_fn: Optional[Callable] = None,
                 model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                 custom_conf: Optional[dict] = None,
                 **kwargs):
        """Build the model, optimizer and simulator configuration files.

        Args:
            model: local model path (directory or file) or a model id that
                is downloaded with ``snapshot_download``.
            work_dir: directory receiving logs and checkpoints; it is created
                if missing and falls back to ``DEFAULT_WORK_DIR`` when empty.
            cfg_file: configuration file; defaults to the configuration file
                inside the model directory.
            arg_parse_fn: forwarded to ``BaseTrainer``.
            model_revision: revision used when downloading the model.
            custom_conf: maps each entry of ``conf_keys`` to a dict of
                placeholder values for that template. Missing entries are
                treated as empty dicts.
            **kwargs: optional overrides: ``launcher``, ``device``,
                ``max_epochs``, ``train_iters_per_epoch``,
                ``val_iters_per_epoch``, ``workers``, ``single_rate`` and
                ``model_bin`` (state dict file name inside the model
                directory).
        """
        if isinstance(model, str):
            if os.path.exists(model):
                self.model_dir = model if os.path.isdir(
                    model) else os.path.dirname(model)
            else:
                self.model_dir = snapshot_download(
                    model, revision=model_revision)
            if cfg_file is None:
                cfg_file = os.path.join(self.model_dir,
                                        ModelFile.CONFIGURATION)
        else:
            assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
            self.model_dir = os.path.dirname(cfg_file)

        super().__init__(cfg_file, arg_parse_fn)

        self.model = self.build_model()
        self.work_dir = work_dir or self.DEFAULT_WORK_DIR
        # logs and checkpoints are written here, so it has to exist
        os.makedirs(self.work_dir, exist_ok=True)
        # the number of model output dimension
        # should update config outside the trainer, if user need more wake word
        self._num_classes = self.cfg.model.num_syn

        if kwargs.get('launcher', None) is not None:
            init_dist(kwargs['launcher'])

        _, world_size = get_dist_info()
        self._dist = world_size > 1

        device_name = kwargs.get('device', 'gpu')
        if self._dist:
            local_rank = get_local_rank()
            device_name = f'cuda:{local_rank}'

        self.device = create_device(device_name)
        # model placement
        if self.device.type == 'cuda':
            self.model.to(self.device)

        if 'max_epochs' not in kwargs:
            assert hasattr(
                self.cfg.train, 'max_epochs'
            ), 'max_epochs is missing from the configuration file'
            self._max_epochs = self.cfg.train.max_epochs
        else:
            self._max_epochs = kwargs['max_epochs']
        self._train_iters = kwargs.get('train_iters_per_epoch', None)
        self._val_iters = kwargs.get('val_iters_per_epoch', None)
        if self._train_iters is None:
            self._train_iters = self.cfg.train.train_iters_per_epoch
        if self._val_iters is None:
            self._val_iters = self.cfg.evaluation.val_iters_per_epoch
        dataloader_config = self.cfg.train.dataloader
        self._threads = kwargs.get('workers', None)
        if self._threads is None:
            self._threads = dataloader_config.workers_per_gpu
        self._single_rate = BASETRAIN_RATIO
        if 'single_rate' in kwargs:
            self._single_rate = kwargs['single_rate']
        self._batch_size = dataloader_config.batch_size_per_gpu
        if 'model_bin' in kwargs:
            model_bin_file = os.path.join(self.model_dir, kwargs['model_bin'])
            # load on CPU first so a checkpoint saved on GPU also works on a
            # CPU-only machine; load_state_dict copies onto the model device
            checkpoint = torch.load(model_bin_file, map_location='cpu')
            self.model.load_state_dict(checkpoint)
        # build corresponding optimizer and loss function
        lr = self.cfg.train.optimizer.lr
        self.optimizer = optim.Adam(self.model.parameters(), lr)
        self.loss_fn = nn.CrossEntropyLoss()
        self.data_val = None
        self.json_log_path = os.path.join(self.work_dir,
                                          '{}.log.json'.format(self.timestamp))
        # render one simulator configuration file per template
        custom_conf = custom_conf or {}
        self.conf_files = []
        for conf_key in self.conf_keys:
            template_file = os.path.join(self.model_dir, conf_key)
            conf_file = os.path.join(self.model_dir, f'{conf_key}.conf')
            update_conf(template_file, conf_file,
                        custom_conf.get(conf_key, {}))
            self.conf_files.append(conf_file)
        self._current_epoch = 0
        # epochs per stage: (easy, normal, hard)
        self.stages = (math.floor(self._max_epochs * EASY_RATIO),
                       math.floor(self._max_epochs * NORMAL_RATIO),
                       math.floor(self._max_epochs * HARD_RATIO))

    def build_model(self) -> nn.Module:
        """ Instantiate a pytorch model and return.

        By default, we will create a model using config from configuration file. You can
        override this method in a subclass.

        Raises:
            TypeError: if the loaded object is neither a ``TorchModel``
                exposing a ``model`` attribute nor an ``nn.Module``.
        """
        model = Model.from_pretrained(
            self.model_dir, cfg_dict=self.cfg, training=True)
        if isinstance(model, TorchModel) and hasattr(model, 'model'):
            return model.model
        elif isinstance(model, nn.Module):
            return model
        # fail here instead of returning None and crashing later in __init__
        raise TypeError(
            f'Unsupported model type for training: {type(model)}')

    def train(self, *args, **kwargs):
        """Generate the validation set if needed and run all stages."""
        if not self.data_val:
            self.gen_val()
        logger.info('Start training...')
        totaltime = datetime.datetime.now()

        for stage, num_epoch in enumerate(self.stages):
            self.run_stage(stage, num_epoch)

        # total time spent
        totaltime = datetime.datetime.now() - totaltime
        logger.info('Total time spent: {:.2f} hours\n'.format(
            totaltime.total_seconds() / 3600.0))

    def run_stage(self, stage, num_epoch):
        """
        Run training stages with correspond data

        Args:
            stage: id of stage
            num_epoch: the number of epoch to run in this stage
        """
        if num_epoch <= 0:
            logger.warning(f'Invalid epoch number, stage {stage} exit!')
            return
        logger.info(f'Starting stage {stage}...')
        dataset, dataloader = self.create_dataloader(
            self.conf_files[stage * 2], self.conf_files[stage * 2 + 1])
        try:
            it = iter(dataloader)
            for _ in range(num_epoch):
                self._current_epoch += 1
                epochtime = datetime.datetime.now()
                logger.info('Start epoch %d...', self._current_epoch)
                loss_train_epoch = 0.0
                validbatchs = 0
                for bi in range(self._train_iters):
                    # prepare data
                    feat, label = next(it)
                    label = torch.reshape(label, (-1, ))
                    feat = to_device(feat, self.device)
                    label = to_device(label, self.device)
                    # apply model
                    self.optimizer.zero_grad()
                    predict = self.model(feat)
                    # calculate loss
                    loss = self.loss_fn(
                        torch.reshape(predict, (-1, self._num_classes)), label)
                    loss_value = loss.item()
                    # batches with a NaN loss are logged but not learned from
                    if not np.isnan(loss_value):
                        loss.backward()
                        self.optimizer.step()
                        loss_train_epoch += loss_value
                        validbatchs += 1
                    train_result = 'Epoch: {:04d}/{:04d}, batch: {:04d}/{:04d}, loss: {:.4f}'.format(
                        self._current_epoch, self._max_epochs, bi + 1,
                        self._train_iters, loss_value)
                    logger.info(train_result)
                    self._dump_log(train_result)

                # average training loss in one epoch
                # (NaN when no batch produced a usable loss, instead of a
                # ZeroDivisionError)
                if validbatchs:
                    loss_train_epoch /= validbatchs
                else:
                    loss_train_epoch = float('nan')
                loss_val_epoch = self.evaluate('')
                val_result = 'Evaluate epoch: {:04d}, loss_train: {:.4f}, loss_val: {:.4f}'.format(
                    self._current_epoch, loss_train_epoch, loss_val_epoch)
                logger.info(val_result)
                self._dump_log(val_result)
                # check point
                # NOTE: the whole module is pickled here, whereas the
                # `model_bin` option of __init__ expects a state dict.
                ckpt_name = 'checkpoint_{:04d}_loss_train_{:.4f}_loss_val_{:.4f}.pth'.format(
                    self._current_epoch, loss_train_epoch, loss_val_epoch)
                torch.save(self.model, os.path.join(self.work_dir, ckpt_name))
                # time spent per epoch
                epochtime = datetime.datetime.now() - epochtime
                logger.info('Epoch {:04d} time spent: {:.2f} hours'.format(
                    self._current_epoch,
                    epochtime.total_seconds() / 3600.0))
        finally:
            # always stop the worker threads, even when training raised
            dataloader.stop()
            dataset.release()
        logger.info(f'Stage {stage} is finished.')

    def gen_val(self):
        """
        generate validation set

        Draws ``val_iters_per_epoch`` batches from the "normal" simulator
        configurations and keeps them in ``self.data_val``.
        """
        logger.info('Start generating validation set...')
        dataset, dataloader = self.create_dataloader(self.conf_files[2],
                                                     self.conf_files[3])
        try:
            it = iter(dataloader)

            self.data_val = []
            for bi in range(self._val_iters):
                logger.info('Iterating validation data %d', bi)
                feat, label = next(it)
                label = torch.reshape(label, (-1, ))
                self.data_val.append([feat, label])
        finally:
            # always stop the worker threads, even when generation raised
            dataloader.stop()
            dataset.release()
        logger.info('Finish generating validation set!')

    def create_dataloader(self, base_path, finetune_path):
        """Create a dataset and an already started loader for two confs.

        Returns:
            ``(dataset, dataloader)``; the caller has to call
            ``dataloader.stop()`` and ``dataset.release()`` when done.
        """
        dataset = KWSDataset(base_path, finetune_path, self._threads,
                             self._single_rate, self._num_classes)
        dataloader = KWSDataLoader(
            dataset, batchsize=self._batch_size, numworkers=self._threads)
        dataloader.start()
        return dataset, dataloader

    def evaluate(self, checkpoint_path: str, *args,
                 **kwargs) -> Dict[str, float]:
        """Compute the mean loss of the current model on the validation set.

        ``checkpoint_path`` is ignored; the in-memory model is evaluated.
        Despite the inherited return annotation, a plain float is returned.
        """
        logger.info('Start validation...')
        if self.data_val is None:
            self.gen_val()
        loss_val_epoch = 0.0

        with torch.no_grad():
            for feat, label in self.data_val:
                feat = to_device(feat, self.device)
                label = to_device(label, self.device)
                # apply model
                predict = self.model(feat)
                # calculate loss
                loss = self.loss_fn(
                    torch.reshape(predict, (-1, self._num_classes)), label)
                loss_val_epoch += loss.item()
        logger.info('Finish validation.')
        # gen_val() stores exactly one entry per validation iteration
        num_batches = len(self.data_val)
        if not num_batches:
            return float('nan')
        return loss_val_epoch / num_batches

    def _dump_log(self, msg):
        """Append ``msg`` as one line to the log file (master rank only)."""
        if is_master():
            with open(self.json_log_path, 'a+') as f:
                f.write(msg)
                f.write('\n')
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import re
|
||||
import struct
|
||||
from typing import Union
|
||||
from urllib.parse import urlparse
|
||||
@@ -37,6 +38,23 @@ def audio_norm(x):
|
||||
return x
|
||||
|
||||
|
||||
def update_conf(origin_config_file, new_config_file, conf_item: dict):
    """Render a configuration template by filling ``${key}`` placeholders.

    The template is read completely before the output is opened, so
    ``origin_config_file`` and ``new_config_file`` may be the same path.

    Args:
        origin_config_file: path of the template file to read.
        new_config_file: path of the file to write.
        conf_item: maps placeholder names to values. Values are converted
            with ``str()``; placeholders without an entry are replaced by an
            empty string.
    """

    def repl(matched):
        key = matched.group(1)
        if key in conf_item:
            # values such as integers are allowed, re.sub needs a str
            return str(conf_item[key])
        # unknown placeholders are blanked out explicitly
        return ''

    with open(origin_config_file) as f:
        lines = f.readlines()
    with open(new_config_file, 'w') as f:
        for line in lines:
            # non-greedy so that several placeholders on one line are
            # substituted independently
            f.write(re.sub(r'\$\{(.*?)\}', repl, line))
|
||||
|
||||
|
||||
def extract_pcm_from_wav(wav: bytes) -> bytes:
|
||||
data = wav
|
||||
if len(data) > 44:
|
||||
|
||||
@@ -14,7 +14,11 @@ nltk
|
||||
numpy<=1.18
|
||||
# protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged.
|
||||
protobuf>3,<3.21.0
|
||||
py_sound_connect
|
||||
ptflops
|
||||
py_sound_connect>=0.1
|
||||
pytorch_wavelets
|
||||
PyWavelets>=1.0.0
|
||||
scikit-learn
|
||||
SoundFile>0.10
|
||||
sox
|
||||
torchaudio
|
||||
|
||||
85
tests/trainers/audio/test_kws_farfield_trainer.py
Normal file
85
tests/trainers/audio/test_kws_farfield_trainer.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
POS_FILE = 'data/test/audios/wake_word_with_label_xyxy.wav'
|
||||
NEG_FILE = 'data/test/audios/speech_with_noise.wav'
|
||||
NOISE_FILE = 'data/test/audios/speech_with_noise.wav'
|
||||
INTERF_FILE = 'data/test/audios/speech_with_noise.wav'
|
||||
REF_FILE = 'data/test/audios/farend_speech.wav'
|
||||
NOISE_2CH_FILE = 'data/test/audios/noise_2ch.wav'
|
||||
|
||||
|
||||
class TestKwsFarfieldTrainer(unittest.TestCase):
    """End-to-end smoke test of the far-field KWS trainer."""
    REVISION = 'beta'

    def setUp(self):
        # mkdtemp() creates the directory and leaves its removal to
        # tearDown(); no dependence on a TemporaryDirectory finalizer
        self.tmp_dir = tempfile.mkdtemp()
        print(f'tmp dir: {self.tmp_dir}')
        self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'

        train_pos_list = self.create_list('pos.list', POS_FILE)
        train_neg_list = self.create_list('neg.list', NEG_FILE)
        train_noise1_list = self.create_list('noise.list', NOISE_FILE)
        train_noise2_list = self.create_list('noise_2ch.list', NOISE_2CH_FILE)
        train_interf_list = self.create_list('interf.list', INTERF_FILE)
        train_ref_list = self.create_list('ref.list', REF_FILE)

        # placeholder values for the basetrain_* templates
        base_dict = dict(
            train_pos_list=train_pos_list,
            train_neg_list=train_neg_list,
            train_noise1_list=train_noise1_list)
        # placeholder values for the finetune_* templates
        finetune_dict = dict(
            train_pos_list=train_pos_list,
            train_neg_list=train_neg_list,
            train_noise1_list=train_noise1_list,
            train_noise2_type='1',
            train_noise1_ratio='0.2',
            train_noise2_list=train_noise2_list,
            train_interf_list=train_interf_list,
            train_ref_list=train_ref_list)
        self.custom_conf = dict(
            basetrain_easy=base_dict,
            basetrain_normal=base_dict,
            basetrain_hard=base_dict,
            finetune_easy=finetune_dict,
            finetune_normal=finetune_dict,
            finetune_hard=finetune_dict)

    def create_list(self, list_name, audio_file):
        """Write a list file naming ``audio_file`` ten times.

        Returns:
            The configuration value ``'<list file path>, 1.0'``.
        """
        list_file = os.path.join(self.tmp_dir, list_name)
        audio_path = os.path.join(os.getcwd(), audio_file)
        with open(list_file, 'w') as f:
            for _ in range(10):
                f.write(f'{audio_path}\n')
        return f'{list_file}, 1.0'

    def tearDown(self) -> None:
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_normal(self):
        kwargs = dict(
            model=self.model_id,
            work_dir=self.tmp_dir,
            model_revision=self.REVISION,
            workers=2,
            max_epochs=2,
            train_iters_per_epoch=2,
            val_iters_per_epoch=1,
            custom_conf=self.custom_conf)

        trainer = build_trainer(
            Trainers.speech_dfsmn_kws_char_farfield, default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files,
                      f'work_dir:{self.tmp_dir}')
|
||||
Reference in New Issue
Block a user