pre-commit

This commit is contained in:
gg
2023-07-12 09:43:24 +08:00
parent 574b4568ff
commit 49c6d8bcf6
2 changed files with 19 additions and 15 deletions

View File

@@ -1,16 +1,19 @@
import os
from modelscope.msdatasets.dataset_cls.custom_datasets import ASRDataset
from modelscope.utils.constant import DownloadMode
from modelscope.trainers import build_trainer
from modelscope.metainfo import Trainers
from modelscope.msdatasets.dataset_cls.custom_datasets import ASRDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode
def modelscope_finetune(params):
if not os.path.exists(params.output_dir):
os.makedirs(params.output_dir, exist_ok=True)
# dataset split ["train", "validation"]
ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr', download_mode=params.download_mode)
ds_dict = ASRDataset.load(
params.data_path,
namespace='speech_asr',
download_mode=params.download_mode)
kwargs = dict(
model=params.model,
data_dir=ds_dict,
@@ -27,7 +30,8 @@ if __name__ == '__main__':
from funasr.utils.modelscope_param import modelscope_args
params = modelscope_args(
model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
model=
'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
)
params.output_dir = './checkpoint' # 模型保存路径
params.data_path = 'speech_asr_aishell1_trainsets' # 数据路径可以为modelscope中已上传数据也可以是本地数据

View File

@@ -15,8 +15,8 @@ class ASRDataset(MsDataset):
@classmethod
def load_core(cls, data_dir, data_set):
wav_file = os.path.join(data_dir, data_set, "wav.scp")
text_file = os.path.join(data_dir, data_set, "text")
wav_file = os.path.join(data_dir, data_set, 'wav.scp')
text_file = os.path.join(data_dir, data_set, 'text')
with open(wav_file) as f:
wav_lines = f.readlines()
with open(text_file) as f:
@@ -24,8 +24,8 @@ class ASRDataset(MsDataset):
data_list = []
for wav_line, text_line in zip(wav_lines, text_lines):
item = {}
item["Audio:FILE"] = wav_line.strip().split()[-1]
item["Text:LABEL"] = " ".join(text_line.strip().split()[1:])
item['Audio:FILE'] = wav_line.strip().split()[-1]
item['Text:LABEL'] = ' '.join(text_line.strip().split()[1:])
data_list.append(item)
return data_list
@@ -33,17 +33,17 @@ class ASRDataset(MsDataset):
def load(
cls,
dataset_name,
namespace="speech_asr",
train_set="train",
dev_set="validation",
namespace='speech_asr',
train_set='train',
dev_set='validation',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
):
if os.path.exists(dataset_name):
data_dir = dataset_name
ds_dict = {}
ds_dict["train"] = cls.load_core(data_dir, train_set)
ds_dict["validation"] = cls.load_core(data_dir, dev_set)
ds_dict["raw_data_dir"] = data_dir
ds_dict['train'] = cls.load_core(data_dir, train_set)
ds_dict['validation'] = cls.load_core(data_dir, dev_set)
ds_dict['raw_data_dir'] = data_dir
return ds_dict
else:
from modelscope.msdatasets import MsDataset