mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
pre-commit
This commit is contained in:
@@ -1,16 +1,19 @@
|
||||
import os
|
||||
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import ASRDataset
|
||||
from modelscope.utils.constant import DownloadMode
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets import ASRDataset
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.utils.constant import DownloadMode
|
||||
|
||||
|
||||
def modelscope_finetune(params):
|
||||
if not os.path.exists(params.output_dir):
|
||||
os.makedirs(params.output_dir, exist_ok=True)
|
||||
# dataset split ["train", "validation"]
|
||||
ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr', download_mode=params.download_mode)
|
||||
ds_dict = ASRDataset.load(
|
||||
params.data_path,
|
||||
namespace='speech_asr',
|
||||
download_mode=params.download_mode)
|
||||
kwargs = dict(
|
||||
model=params.model,
|
||||
data_dir=ds_dict,
|
||||
@@ -27,7 +30,8 @@ if __name__ == '__main__':
|
||||
from funasr.utils.modelscope_param import modelscope_args
|
||||
|
||||
params = modelscope_args(
|
||||
model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
|
||||
model=
|
||||
'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
|
||||
)
|
||||
params.output_dir = './checkpoint' # 模型保存路径
|
||||
params.data_path = 'speech_asr_aishell1_trainsets' # 数据路径,可以为modelscope中已上传数据,也可以是本地数据
|
||||
|
||||
@@ -15,8 +15,8 @@ class ASRDataset(MsDataset):
|
||||
|
||||
@classmethod
|
||||
def load_core(cls, data_dir, data_set):
|
||||
wav_file = os.path.join(data_dir, data_set, "wav.scp")
|
||||
text_file = os.path.join(data_dir, data_set, "text")
|
||||
wav_file = os.path.join(data_dir, data_set, 'wav.scp')
|
||||
text_file = os.path.join(data_dir, data_set, 'text')
|
||||
with open(wav_file) as f:
|
||||
wav_lines = f.readlines()
|
||||
with open(text_file) as f:
|
||||
@@ -24,8 +24,8 @@ class ASRDataset(MsDataset):
|
||||
data_list = []
|
||||
for wav_line, text_line in zip(wav_lines, text_lines):
|
||||
item = {}
|
||||
item["Audio:FILE"] = wav_line.strip().split()[-1]
|
||||
item["Text:LABEL"] = " ".join(text_line.strip().split()[1:])
|
||||
item['Audio:FILE'] = wav_line.strip().split()[-1]
|
||||
item['Text:LABEL'] = ' '.join(text_line.strip().split()[1:])
|
||||
data_list.append(item)
|
||||
return data_list
|
||||
|
||||
@@ -33,17 +33,17 @@ class ASRDataset(MsDataset):
|
||||
def load(
|
||||
cls,
|
||||
dataset_name,
|
||||
namespace="speech_asr",
|
||||
train_set="train",
|
||||
dev_set="validation",
|
||||
namespace='speech_asr',
|
||||
train_set='train',
|
||||
dev_set='validation',
|
||||
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
|
||||
):
|
||||
if os.path.exists(dataset_name):
|
||||
data_dir = dataset_name
|
||||
ds_dict = {}
|
||||
ds_dict["train"] = cls.load_core(data_dir, train_set)
|
||||
ds_dict["validation"] = cls.load_core(data_dir, dev_set)
|
||||
ds_dict["raw_data_dir"] = data_dir
|
||||
ds_dict['train'] = cls.load_core(data_dir, train_set)
|
||||
ds_dict['validation'] = cls.load_core(data_dir, dev_set)
|
||||
ds_dict['raw_data_dir'] = data_dir
|
||||
return ds_dict
|
||||
else:
|
||||
from modelscope.msdatasets import MsDataset
|
||||
|
||||
Reference in New Issue
Block a user