[to #42322933] fix some feedback issues

1. Fix the conflict between local path and remote dataset name in the form of dataset_name='namespace/dataset_name' in MsDataset.load() function.
2. Modify the obj_key.startswith value in get_split_objects_map function to adapt to dir name 'xxx/' format.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11820290

* fix the conflict between local path and namespace/dataset_name of the dataset_name

* fix function: get_split_objects_map

* add UT for loading local csv file

* add new test case for test_load_local_csv function
This commit is contained in:
xingjun.wxj
2023-03-01 20:13:31 +08:00
committed by wenmeng.zwm
parent eec5702409
commit 8e050f5876
3 changed files with 51 additions and 3 deletions

View File

@@ -214,7 +214,9 @@ class MsDataset:
return MsDataset.to_ms_dataset(dataset_inst, target=target)
dataset_name = os.path.expanduser(dataset_name)
if is_relative_path(dataset_name) and dataset_name.count('/') == 1:
is_local_path = os.path.exists(dataset_name)
if is_relative_path(dataset_name) and dataset_name.count(
'/') == 1 and not is_local_path:
dataset_name_split = dataset_name.split('/')
namespace = dataset_name_split[0].strip()
dataset_name = dataset_name_split[1].strip()

View File

@@ -160,7 +160,7 @@ def get_split_objects_map(file_map, objects):
for obj_key in objects:
for k, v in file_map.items():
if obj_key.startswith(v + '/'):
if obj_key.startswith(v.rstrip('/') + '/'):
res[k].append(obj_key)
return res

View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import hashlib
import os
import unittest
from modelscope.models import Model
@@ -32,6 +33,34 @@ class ImgPreprocessor(Preprocessor):
}
class GenLocalFile:
@staticmethod
def gen_mock_data() -> (str, str):
mock_data_list = [
'Title,Content,Label', 'mock title1,mock content1,mock label1',
'mock title2,mock content2,mock label2',
'mock title3,mock content3,mock label3'
]
mock_file_name = 'mock_file.csv'
md = hashlib.md5()
md.update('GenLocalFile.gen_mock_data.out_file_path'.encode('utf-8'))
mock_dir = os.path.join(os.getcwd(), md.hexdigest())
os.makedirs(mock_dir, exist_ok=True)
mock_relative_path = os.path.join(md.hexdigest(), mock_file_name)
with open(mock_relative_path, 'w') as f:
for line in mock_data_list:
f.write(line + '\n')
return mock_relative_path, md.hexdigest()
@staticmethod
def clear_mock_dir(mock_dir) -> None:
import shutil
shutil.rmtree(mock_dir)
class MsDatasetTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -57,6 +86,23 @@ class MsDatasetTest(unittest.TestCase):
split='train').to_hf_dataset().select(range(5))
print(next(iter(ms_ds_train)))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_load_local_csv(self):
mock_relative_path, mock_dir_name = GenLocalFile.gen_mock_data()
# To test dataset_name in the form of `xxx/xxx.csv`
ds_from_single_file = MsDataset.load(mock_relative_path)
# To test dataset_name in the form of `xxx/`
ds_from_dir = MsDataset.load(mock_dir_name + '/')
GenLocalFile.clear_mock_dir(mock_dir_name)
ds_from_single_file_sample = next(iter(ds_from_single_file))
ds_from_dir_sample = next(iter(ds_from_dir))
print(ds_from_single_file_sample)
print(ds_from_dir_sample)
assert ds_from_single_file_sample
assert ds_from_dir_sample
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ds_basic(self):
ms_ds_full = MsDataset.load(