From 574b4568ffc41d709d27f5882977ab6ff3fe85bd Mon Sep 17 00:00:00 2001
From: gg
Date: Tue, 11 Jul 2023 18:36:54 +0800
Subject: [PATCH] flake8

---
 .../finetune_speech_recognition.py | 3 +-
 .../custom_datasets/audio/asr_dataset.py | 35 +++++++++++--------
 2 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/examples/pytorch/auto_speech_recognition/finetune_speech_recognition.py b/examples/pytorch/auto_speech_recognition/finetune_speech_recognition.py
index 2fee3a2e..1716d8a0 100644
--- a/examples/pytorch/auto_speech_recognition/finetune_speech_recognition.py
+++ b/examples/pytorch/auto_speech_recognition/finetune_speech_recognition.py
@@ -27,8 +27,7 @@ if __name__ == '__main__':
     from funasr.utils.modelscope_param import modelscope_args
 
     params = modelscope_args(
-        model=
-        'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
+        model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
     )
     params.output_dir = './checkpoint'  # 模型保存路径
     params.data_path = 'speech_asr_aishell1_trainsets'  # 数据路径,可以为modelscope中已上传数据,也可以是本地数据
diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py
index 73a40813..9e64bcb3 100644
--- a/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py
+++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py
@@ -15,8 +15,8 @@ class ASRDataset(MsDataset):
 
     @classmethod
     def load_core(cls, data_dir, data_set):
-        wav_file = os.path.join(data_dir, data_set, 'wav.scp')
-        text_file = os.path.join(data_dir, data_set, 'text')
+        wav_file = os.path.join(data_dir, data_set, "wav.scp")
+        text_file = os.path.join(data_dir, data_set, "text")
         with open(wav_file) as f:
             wav_lines = f.readlines()
         with open(text_file) as f:
@@ -24,28 +24,33 @@ class ASRDataset(MsDataset):
         data_list = []
         for wav_line, text_line in zip(wav_lines, text_lines):
             item = {}
-            item['Audio:FILE'] = wav_line.strip().split()[-1]
-            item['Text:LABEL'] = ' '.join(text_line.strip().split()[1:])
+            item["Audio:FILE"] = wav_line.strip().split()[-1]
+            item["Text:LABEL"] = " ".join(text_line.strip().split()[1:])
             data_list.append(item)
         return data_list
 
     @classmethod
-    def load(cls,
-             dataset_name,
-             namespace='speech_asr',
-             train_set='train',
-             dev_set='validation',
-             download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS):
+    def load(
+        cls,
+        dataset_name,
+        namespace="speech_asr",
+        train_set="train",
+        dev_set="validation",
+        download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
+    ):
         if os.path.exists(dataset_name):
             data_dir = dataset_name
             ds_dict = {}
-            ds_dict['train'] = cls.load_core(data_dir, train_set)
-            ds_dict['validation'] = cls.load_core(data_dir, dev_set)
-            ds_dict['raw_data_dir'] = data_dir
+            ds_dict["train"] = cls.load_core(data_dir, train_set)
+            ds_dict["validation"] = cls.load_core(data_dir, dev_set)
+            ds_dict["raw_data_dir"] = data_dir
             return ds_dict
         else:
            from modelscope.msdatasets import MsDataset
+
             ds_dict = MsDataset.load(
-                dataset_name=dataset_name, namespace=namespace, download_mode=download_mode)
+                dataset_name=dataset_name,
+                namespace=namespace,
+                download_mode=download_mode,
+            )
             return ds_dict
-