mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
[to #42322933] fix some feedback issues
1. Fix the conflict between a local path and a remote dataset name of the form dataset_name='namespace/dataset_name' in the MsDataset.load() function. 2. Modify the obj_key.startswith check in the get_split_objects_map function to handle directory names in the 'xxx/' format. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11820290 * fix the conflict between the local path and the namespace/dataset_name form of dataset_name * fix function: get_split_objects_map * add a UT for loading a local csv file * add a new test case to the test_load_local_csv function
This commit is contained in:
@@ -214,7 +214,9 @@ class MsDataset:
|
||||
return MsDataset.to_ms_dataset(dataset_inst, target=target)
|
||||
|
||||
dataset_name = os.path.expanduser(dataset_name)
|
||||
if is_relative_path(dataset_name) and dataset_name.count('/') == 1:
|
||||
is_local_path = os.path.exists(dataset_name)
|
||||
if is_relative_path(dataset_name) and dataset_name.count(
|
||||
'/') == 1 and not is_local_path:
|
||||
dataset_name_split = dataset_name.split('/')
|
||||
namespace = dataset_name_split[0].strip()
|
||||
dataset_name = dataset_name_split[1].strip()
|
||||
|
||||
@@ -160,7 +160,7 @@ def get_split_objects_map(file_map, objects):
|
||||
|
||||
for obj_key in objects:
|
||||
for k, v in file_map.items():
|
||||
if obj_key.startswith(v + '/'):
|
||||
if obj_key.startswith(v.rstrip('/') + '/'):
|
||||
res[k].append(obj_key)
|
||||
|
||||
return res
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from modelscope.models import Model
|
||||
@@ -32,6 +33,34 @@ class ImgPreprocessor(Preprocessor):
|
||||
}
|
||||
|
||||
|
||||
class GenLocalFile:
    """Generate and clean up a local mock CSV file for dataset-loading tests."""

    @staticmethod
    def gen_mock_data() -> tuple:
        """Write a small mock CSV file under the current working directory.

        The containing directory is named by the md5 hex digest of a fixed
        string, so repeated calls reuse the same deterministic location.

        Returns:
            tuple: ``(relative_csv_path, dir_name)`` — the relative path to
            the generated CSV file and the name of its containing directory,
            both relative to the current working directory.
        """
        # Header row plus three data rows of mock content.
        mock_data_list = [
            'Title,Content,Label',
            'mock title1,mock content1,mock label1',
            'mock title2,mock content2,mock label2',
            'mock title3,mock content3,mock label3',
        ]

        mock_file_name = 'mock_file.csv'
        # Deterministic directory name: md5 of a fixed marker string.
        md = hashlib.md5()
        md.update('GenLocalFile.gen_mock_data.out_file_path'.encode('utf-8'))
        dir_name = md.hexdigest()
        os.makedirs(os.path.join(os.getcwd(), dir_name), exist_ok=True)

        mock_relative_path = os.path.join(dir_name, mock_file_name)
        with open(mock_relative_path, 'w') as f:
            for line in mock_data_list:
                f.write(line + '\n')

        return mock_relative_path, dir_name

    @staticmethod
    def clear_mock_dir(mock_dir) -> None:
        """Recursively remove the mock directory created by gen_mock_data."""
        import shutil  # local import: only needed for cleanup
        shutil.rmtree(mock_dir)
|
||||
|
||||
|
||||
class MsDatasetTest(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
@@ -57,6 +86,23 @@ class MsDatasetTest(unittest.TestCase):
|
||||
split='train').to_hf_dataset().select(range(5))
|
||||
print(next(iter(ms_ds_train)))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_load_local_csv(self):
    """Load a local csv dataset both as `xxx/xxx.csv` and as `xxx/`."""
    mock_relative_path, mock_dir_name = GenLocalFile.gen_mock_data()
    try:
        # To test dataset_name in the form of `xxx/xxx.csv`
        ds_from_single_file = MsDataset.load(mock_relative_path)
        # To test dataset_name in the form of `xxx/`
        ds_from_dir = MsDataset.load(mock_dir_name + '/')
    finally:
        # Always remove the mock fixture, even if loading raises, so a
        # failed run does not leak the directory into the working tree.
        GenLocalFile.clear_mock_dir(mock_dir_name)

    ds_from_single_file_sample = next(iter(ds_from_single_file))
    ds_from_dir_sample = next(iter(ds_from_dir))

    print(ds_from_single_file_sample)
    print(ds_from_dir_sample)
    assert ds_from_single_file_sample
    assert ds_from_dir_sample
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_ds_basic(self):
|
||||
ms_ds_full = MsDataset.load(
|
||||
|
||||
Reference in New Issue
Block a user