2022-05-17 10:15:00 +08:00
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
|
import tempfile
|
|
|
|
|
import unittest
|
|
|
|
|
|
2023-03-21 18:10:10 +08:00
|
|
|
import numpy as np
|
|
|
|
|
|
2022-06-09 20:16:26 +08:00
|
|
|
from modelscope.fileio.io import dump, dumps, load
|
2022-05-17 10:15:00 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class FileIOTest(unittest.TestCase):
    """Round-trip tests for ``dump`` / ``dumps`` / ``load`` from ``modelscope.fileio.io``."""

    def test_format(self, format='json'):
        """Dump a mixed object to a file in ``format``, load it back and compare.

        Also checks that the file contents equal ``dumps(obj, format)`` and
        that ``load``/``dump`` raise ``TypeError`` for the path with an extra
        ``'s'`` appended to its extension.

        Args:
            format: Extension/format to exercise; ``'json'`` by default and
                ``'yaml'`` when called from ``test_yaml``. The name shadows
                the builtin but is kept so existing callers keep working.
        """
        import os

        obj = [
            1, 2, 3, 'str', {
                'model': 'resnet'
            },
            np.array([[1, 2]], dtype=np.float16),
            np.array([[1, 2]], dtype=np.float32),
            np.array([[1, 2]], dtype=np.float64),
            np.array([[1, 2]], dtype=np.int64), (1, 2)
        ]
        result_str = dumps(obj, format)

        # A private temporary directory guarantees the dumped file is removed
        # when the test ends, instead of leaking into the system temp dir, and
        # avoids the private ``tempfile._get_candidate_names`` helper.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_name = os.path.join(tmp_dir, 'test.' + format)
            dump(obj, temp_name)
            obj_load = load(temp_name)

            self.assertEqual(len(obj), len(obj_load))
            for i, obj_i in enumerate(obj):
                if isinstance(obj_i, list):
                    self.assertListEqual(obj_i, obj_load[i])
                elif isinstance(obj_i, np.ndarray):
                    # ndarray ``==`` is element-wise, so compare as lists.
                    self.assertListEqual(obj_i.tolist(), obj_load[i].tolist())
                elif isinstance(obj_i, dict):
                    self.assertDictEqual(obj_i, obj_load[i])
                else:
                    self.assertEqual(obj_i, obj_load[i])

            # The file written by ``dump`` must match the ``dumps`` string.
            with open(temp_name, 'r') as infile:
                self.assertEqual(result_str, infile.read())

            # Appending 's' gives an extension such as '.jsons' / '.yamls',
            # which this test expects both ``load`` and ``dump`` to reject.
            with self.assertRaises(TypeError):
                load(temp_name + 's')

            with self.assertRaises(TypeError):
                dump(obj, temp_name + 's')

    def test_yaml(self):
        """Run the same round-trip checks with the YAML format."""
        self.test_format('yaml')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed only when this file is run directly as a script (not when
    # imported by a test runner): hand control to unittest's CLI runner.
    unittest.main()
|