diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index 9f34186c..e4948310 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -214,7 +214,9 @@ class MsDataset: return MsDataset.to_ms_dataset(dataset_inst, target=target) dataset_name = os.path.expanduser(dataset_name) - if is_relative_path(dataset_name) and dataset_name.count('/') == 1: + is_local_path = os.path.exists(dataset_name) + if is_relative_path(dataset_name) and dataset_name.count( + '/') == 1 and not is_local_path: dataset_name_split = dataset_name.split('/') namespace = dataset_name_split[0].strip() dataset_name = dataset_name_split[1].strip() diff --git a/modelscope/msdatasets/utils/dataset_utils.py b/modelscope/msdatasets/utils/dataset_utils.py index 785337eb..4c80af7d 100644 --- a/modelscope/msdatasets/utils/dataset_utils.py +++ b/modelscope/msdatasets/utils/dataset_utils.py @@ -160,7 +160,7 @@ def get_split_objects_map(file_map, objects): for obj_key in objects: for k, v in file_map.items(): - if obj_key.startswith(v + '/'): + if obj_key.startswith(v.rstrip('/') + '/'): res[k].append(obj_key) return res diff --git a/tests/msdatasets/test_ms_dataset.py b/tests/msdatasets/test_ms_dataset.py index 2bea0c4c..51074bca 100644 --- a/tests/msdatasets/test_ms_dataset.py +++ b/tests/msdatasets/test_ms_dataset.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import hashlib
import os
import shutil
from typing import Tuple


class GenLocalFile:
    """Create and remove a tiny CSV fixture used by local-file loading tests.

    The fixture is written to a directory directly under the current working
    directory. The directory name is the MD5 hex digest of a fixed seed
    string, so it is the same on every run and unlikely to clash with a real
    folder. MD5 is used only to derive a name, not for any security purpose.
    """

    # Seed hashed to obtain the deterministic fixture directory name.
    _DIR_NAME_SEED = 'GenLocalFile.gen_mock_data.out_file_path'
    _MOCK_FILE_NAME = 'mock_file.csv'
    # Header row followed by three data rows.
    _MOCK_ROWS = (
        'Title,Content,Label',
        'mock title1,mock content1,mock label1',
        'mock title2,mock content2,mock label2',
        'mock title3,mock content3,mock label3',
    )

    @staticmethod
    def gen_mock_data() -> Tuple[str, str]:
        """Write the mock CSV file and report where it was written.

        Returns:
            A ``(relative_file_path, dir_name)`` pair. Both values are
            relative to the current working directory at call time:
            ``dir_name`` is the fixture directory and ``relative_file_path``
            is ``dir_name`` joined with ``mock_file.csv``.
        """
        # Hash once and reuse the digest for the directory and the file path.
        seed = GenLocalFile._DIR_NAME_SEED.encode('utf-8')
        mock_dir_name = hashlib.md5(seed).hexdigest()

        os.makedirs(os.path.join(os.getcwd(), mock_dir_name), exist_ok=True)

        # The path is deliberately kept relative: callers use it to exercise
        # dataset names of the form `xxx/xxx.csv`.
        mock_relative_path = os.path.join(mock_dir_name,
                                          GenLocalFile._MOCK_FILE_NAME)
        with open(mock_relative_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(GenLocalFile._MOCK_ROWS) + '\n')

        return mock_relative_path, mock_dir_name

    @staticmethod
    def clear_mock_dir(mock_dir: str) -> None:
        """Delete the fixture directory and everything inside it.

        A directory that does not exist is treated as already cleaned up, so
        this can be called from ``finally`` blocks or more than once.

        Args:
            mock_dir: Path of the directory to remove.
        """
        try:
            shutil.rmtree(mock_dir)
        except FileNotFoundError:
            # Nothing left to remove.
            pass