Files
modelscope/tests/utils/test_config.py

83 lines
2.8 KiB
Python
Raw Normal View History

# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
import tempfile
import unittest

from modelscope.utils.config import Config
# Reference values that every example configuration under configs/examples
# (json, yaml and py variants) is expected to hold.
obj = dict(a=1, b=dict(c=[1, 2, 3], d='dd'))


class ConfigTest(unittest.TestCase):
    """Tests for loading, dumping and converting ``Config`` objects.

    The example files are referenced by relative paths
    (``configs/examples/...``), so these tests presumably expect the working
    directory to be the repository root -- TODO confirm against the CI setup.
    """

    def _load_example(self, config_file):
        """Load ``config_file`` and check that it holds the reference values.

        Args:
            config_file: Path of an example configuration file.

        Returns:
            The loaded ``Config`` instance, for further checks by the caller.
        """
        cfg = Config.from_file(config_file)
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])
        return cfg

    def test_json(self):
        """A JSON config is readable through attribute access."""
        self._load_example('configs/examples/configuration.json')

    def test_yaml(self):
        """A YAML config is readable through attribute access."""
        self._load_example('configs/examples/configuration.yaml')

    def test_py(self):
        """A Python config is readable through attribute access."""
        self._load_example('configs/examples/configuration.py')

    def test_dump(self):
        """Dump a config as pretty text, JSON and YAML and check the output.

        ``cfg.dump()`` without a path is expected to return the pretty
        Python-style text; with a path it is expected to write a file whose
        format matches the path's extension.
        """
        cfg = self._load_example('configs/examples/configuration.py')

        pretty_text = 'a = 1\n'
        pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n"
        self.assertEqual(pretty_text, cfg.dump())

        json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}'
        # Nested keys and their list items sit two spaces under the parent
        # key (PyYAML's default block style -- confirm against Config.dump).
        yaml_str = 'a: 1\nb:\n  c:\n  - 1\n  - 2\n  - 3\n  d: dd\n'

        # Use a temporary directory rather than reopening a NamedTemporaryFile
        # by name: Windows forbids opening that file a second time while the
        # original handle is still open.
        with tempfile.TemporaryDirectory() as tmp_dir:
            for suffix, expected in (('.json', json_str),
                                     ('.yaml', yaml_str)):
                with self.subTest(format=suffix):
                    path = os.path.join(tmp_dir, 'configuration' + suffix)
                    cfg.dump(path)
                    with open(path, 'r', encoding='utf-8') as infile:
                        self.assertEqual(expected, infile.read())
                    # The dumped file must load back to the same values.
                    self._load_example(path)

    def test_to_dict(self):
        """``to_dict`` returns a ``dict`` carrying the config's values."""
        cfg = Config.from_file('configs/examples/configuration.json')
        d = cfg.to_dict()
        self.assertIsInstance(d, dict)
        self.assertEqual(d['a'], 1)
        self.assertEqual(d['b'], obj['b'])

    def test_to_args(self):
        """``to_args`` feeds config entries through a caller-supplied parser."""

        def parse_fn(args):
            parser = argparse.ArgumentParser(prog='PROG')
            parser.add_argument('--model-dir', default='')
            parser.add_argument('--lr', type=float, default=0.001)
            parser.add_argument('--optimizer', default='')
            parser.add_argument('--weight-decay', type=float, default=1e-7)
            parser.add_argument(
                '--save-checkpoint-epochs', type=int, default=30)
            return parser.parse_args(args)

        cfg = Config.from_file('configs/examples/plain_args.yaml')
        args = cfg.to_args(parse_fn)
        # Every expected value differs from the parser default above, so each
        # assertion proves the value came from the config file.
        self.assertEqual(args.model_dir, 'path/to/model')
        self.assertAlmostEqual(args.lr, 0.01)
        self.assertAlmostEqual(args.weight_decay, 1e-6)
        self.assertEqual(args.optimizer, 'Adam')
        self.assertEqual(args.save_checkpoint_epochs, 20)


if __name__ == '__main__':
    # Allow running this module directly as a script, not only via a runner.
    unittest.main()