Files
modelscope/tests/utils/test_config.py
wenmeng.zwm 8a030ead72 [to #42362853] feat: rename config to configuration and remove repeated task fields
1. rename maas_config to configuration
2. remove the task fields image and video; use cv instead

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9010802
2022-06-13 19:44:34 +08:00

86 lines
2.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os.path as osp
import tempfile
import unittest
from pathlib import Path
from modelscope.fileio import dump, load
from modelscope.utils.config import Config
# Reference values that every example config under configs/examples is
# expected to contain; the tests compare loaded configs against them.
obj = dict(
    a=1,
    b=dict(c=[1, 2, 3], d='dd'),
)
class ConfigTest(unittest.TestCase):
    """Tests for loading, dumping and converting :class:`Config` objects.

    Example config paths are relative (``configs/examples/...``), so they are
    resolved against the current working directory; presumably the tests are
    meant to be run from the repository root.
    """

    def _assert_matches_obj(self, cfg):
        """Check that ``cfg`` carries the reference values in ``obj``."""
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])

    def test_json(self):
        """A JSON config exposes its top-level keys as attributes."""
        config_file = 'configs/examples/configuration.json'
        cfg = Config.from_file(config_file)
        self._assert_matches_obj(cfg)

    def test_yaml(self):
        """A YAML config exposes its top-level keys as attributes."""
        config_file = 'configs/examples/configuration.yaml'
        cfg = Config.from_file(config_file)
        self._assert_matches_obj(cfg)

    def test_py(self):
        """A Python-source config exposes its top-level names as attributes."""
        config_file = 'configs/examples/configuration.py'
        cfg = Config.from_file(config_file)
        self._assert_matches_obj(cfg)

    def test_dump(self):
        """Dumping yields Python text, or JSON/YAML files chosen by suffix."""
        config_file = 'configs/examples/configuration.py'
        cfg = Config.from_file(config_file)
        self._assert_matches_obj(cfg)

        pretty_text = 'a = 1\n'
        pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n"
        json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}'
        # Block-style YAML as PyYAML emits it by default: nested mapping keys
        # are indented two spaces and list items are not indented further
        # than their parent key (assumes Config.dump uses PyYAML defaults).
        yaml_str = 'a: 1\nb:\n  c:\n  - 1\n  - 2\n  - 3\n  d: dd\n'

        # Called without a target file, dump() returns the text directly.
        self.assertEqual(pretty_text, cfg.dump())

        # Dump into a temporary directory instead of re-opening a
        # NamedTemporaryFile by name, which fails on Windows while the
        # temporary file is still held open.
        with tempfile.TemporaryDirectory() as tmp_dir:
            json_file = osp.join(tmp_dir, 'config.json')
            cfg.dump(json_file)
            with open(json_file, 'r') as infile:
                self.assertEqual(json_str, infile.read())

            yaml_file = osp.join(tmp_dir, 'config.yaml')
            cfg.dump(yaml_file)
            with open(yaml_file, 'r') as infile:
                self.assertEqual(yaml_str, infile.read())

    def test_to_dict(self):
        """to_dict() returns a plain dict holding the config's values."""
        config_file = 'configs/examples/configuration.json'
        cfg = Config.from_file(config_file)
        d = cfg.to_dict()
        self.assertIsInstance(d, dict)
        self.assertEqual(d['a'], 1)
        self.assertEqual(d['b'], obj['b'])

    def test_to_args(self):
        """to_args() feeds config values through an argparse-based parser."""

        def parse_fn(args):
            # Defaults differ from the YAML values, so every assertion below
            # proves the value came from the config file.
            parser = argparse.ArgumentParser(prog='PROG')
            parser.add_argument('--model-dir', default='')
            parser.add_argument('--lr', type=float, default=0.001)
            parser.add_argument('--optimizer', default='')
            parser.add_argument('--weight-decay', type=float, default=1e-7)
            parser.add_argument(
                '--save-checkpoint-epochs', type=int, default=30)
            return parser.parse_args(args)

        cfg = Config.from_file('configs/examples/plain_args.yaml')
        args = cfg.to_args(parse_fn)
        self.assertEqual(args.model_dir, 'path/to/model')
        self.assertAlmostEqual(args.lr, 0.01)
        self.assertAlmostEqual(args.weight_decay, 1e-6)
        self.assertEqual(args.optimizer, 'Adam')
        self.assertEqual(args.save_checkpoint_epochs, 20)
if __name__ == '__main__':
    # Running this module directly hands control to the unittest runner.
    unittest.main()