diff --git a/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py index c0696615..326508e6 100644 --- a/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py +++ b/modelscope/msdatasets/dataset_cls/custom_datasets/audio/asr_dataset.py @@ -3,6 +3,8 @@ import os from modelscope.msdatasets.ms_dataset import MsDataset +from modelscope.utils.constant import DownloadMode +from typing import Optional class ASRDataset(MsDataset): @@ -33,16 +35,21 @@ class ASRDataset(MsDataset): dataset_name, namespace='speech_asr', train_set='train', - dev_set='validation'): - if os.path.exists(dataset_name): - data_dir = dataset_name - ds_dict = {} - ds_dict['train'] = cls.load_core(data_dir, train_set) - ds_dict['validation'] = cls.load_core(data_dir, dev_set) - ds_dict['raw_data_dir'] = data_dir + dev_set='validation', + download_mode: Optional[DownloadMode] = None): + if download_mode is not None: + ds_dict = MsDataset.load( + dataset_name=dataset_name, namespace=namespace, download_mode=download_mode) return ds_dict else: - from modelscope.msdatasets import MsDataset - ds_dict = MsDataset.load( - dataset_name=dataset_name, namespace=namespace) - return ds_dict + if os.path.exists(dataset_name): + data_dir = dataset_name + ds_dict = {} + ds_dict['train'] = cls.load_core(data_dir, train_set) + ds_dict['validation'] = cls.load_core(data_dir, dev_set) + ds_dict['raw_data_dir'] = data_dir + return ds_dict + else: + ds_dict = MsDataset.load( + dataset_name=dataset_name, namespace=namespace) + return ds_dict