[to #41402703] add basic modules

* add constant
 * add logger module
 * add registry and builder module
 * add fileio module
 * add requirements and setup.cfg
 * add config module and tests
 * add citest script

Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8718998
This commit is contained in:
wenmeng.zwm
2022-05-17 10:15:00 +08:00
parent 02492e2bff
commit 0a756f6a0d
33 changed files with 1894 additions and 0 deletions

13
.dev_scripts/citest.sh Normal file
View File

@@ -0,0 +1,13 @@
pip install -r requirements/runtime.txt
pip install -r requirements/tests.txt
# linter test
# use internal project for pre-commit due to the network problem
if ! pre-commit run --all-files; then
    echo "linter test failed, please run 'pre-commit run --all-files' to check"
    # POSIX requires exit status in 0-255; 'exit -1' is non-portable
    exit 1
fi
PYTHONPATH=. python tests/run.py

126
.gitignore vendored Normal file
View File

@@ -0,0 +1,126 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
/package
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
data
.vscode
.idea
# custom
*.pkl
*.pkl.json
*.log.json
*.whl
*.tar.gz
*.swp
*.log
source.sh
tensorboard.sh
.DS_Store
replace.sh
# Pytorch
*.pth

1
configs/README.md Normal file
View File

@@ -0,0 +1 @@
This folder will host example configs for each model supported by maas_lib.

View File

@@ -0,0 +1,7 @@
{
"a": 1,
"b" : {
"c": [1,2,3],
"d" : "dd"
}
}

View File

@@ -0,0 +1,2 @@
a = 1
b = dict(c=[1,2,3], d='dd')

View File

@@ -0,0 +1,4 @@
a: 1
b:
c: [1,2,3]
d: dd

View File

@@ -0,0 +1,5 @@
model_dir: path/to/model
lr: 0.01
optimizer: Adam
weight_decay: 1e-6
save_checkpoint_epochs: 20

View File

@@ -0,0 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .version import __version__
__all__ = ['__version__']

View File

@@ -0,0 +1 @@
from .io import dump, dumps, load

325
maas_lib/fileio/file.py Normal file
View File

@@ -0,0 +1,325 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import contextlib
import os
import tempfile
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Generator, Union
import requests
class Storage(metaclass=ABCMeta):
    """Abstract interface shared by all storage backends.

    Concrete backends provide byte-oriented (``read``/``write``) and
    text-oriented (``read_text``/``write_text``) access to files
    addressed by a path or uri.
    """

    @abstractmethod
    def read(self, filepath: str):
        """Read ``filepath`` and return its content as bytes."""
        pass

    @abstractmethod
    def read_text(self, filepath: str):
        """Read ``filepath`` and return its content as text."""
        pass

    @abstractmethod
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write the bytes ``obj`` to ``filepath``."""
        pass

    @abstractmethod
    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Write the string ``obj`` to ``filepath`` using ``encoding``."""
        pass
class LocalStorage(Storage):
    """Storage backend for the local filesystem."""

    def read(self, filepath: Union[str, Path]) -> bytes:
        """Read the file at ``filepath`` in 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Content of the file.
        """
        with open(filepath, 'rb') as infile:
            return infile.read()

    def read_text(self,
                  filepath: Union[str, Path],
                  encoding: str = 'utf-8') -> str:
        """Read the file at ``filepath`` in 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding used to open ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Content of the file as text.
        """
        with open(filepath, 'r', encoding=encoding) as infile:
            return infile.read()

    def _ensure_parent_dir(self, filepath):
        # create the parent directory on demand so writes never fail on a
        # missing directory
        parent = os.path.dirname(filepath)
        if parent and not os.path.exists(parent):
            os.makedirs(parent)

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write bytes ``obj`` to ``filepath`` in 'wb' mode.

        Note:
            The directory of ``filepath`` is created if it does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        self._ensure_parent_dir(filepath)
        with open(filepath, 'wb') as outfile:
            outfile.write(obj)

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Write string ``obj`` to ``filepath`` in 'w' mode.

        Note:
            The directory of ``filepath`` is created if it does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding used to open ``filepath``.
                Default: 'utf-8'.
        """
        self._ensure_parent_dir(filepath)
        with open(filepath, 'w', encoding=encoding) as outfile:
            outfile.write(obj)

    @contextlib.contextmanager
    def as_local_path(
            self,
            filepath: Union[str,
                            Path]) -> Generator[Union[str, Path], None, None]:
        """Yield ``filepath`` unchanged; exists only to mirror the remote
        backends' api."""
        yield filepath
class HTTPStorage(Storage):
    """Storage backend for http/https urls (read-only: both write methods
    raise NotImplementedError)."""

    def read(self, url):
        """Download ``url`` and return the response body as bytes.

        Raises an HTTP error (via ``raise_for_status``) when the server
        answers with a non-2xx status.
        """
        r = requests.get(url)
        r.raise_for_status()
        return r.content

    def read_text(self, url):
        """Download ``url`` and return the response body as text."""
        r = requests.get(url)
        r.raise_for_status()
        return r.text

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download ``filepath`` to a temporary local file and yield its
        path.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`.
        It can be called with a ``with`` statement; when the ``with``
        statement exits, the temporary file is removed.

        Args:
            filepath (str): Url to download.

        Examples:
            >>> storage = HTTPStorage()
            >>> # After exiting the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('http://path/to/file') as path:
            ...     # do something here
        """
        # download BEFORE creating the temp file: if the request fails we
        # neither leave a stray file behind nor hit an unbound name in
        # ``finally`` (the original created the file first)
        content = self.read(filepath)
        tmp_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            tmp_file.write(content)
            tmp_file.close()
            yield tmp_file.name
        finally:
            os.remove(tmp_file.name)

    def write(self, obj: bytes, url: Union[str, Path]) -> None:
        """Not supported: http storage is read-only."""
        raise NotImplementedError('write is not supported by HTTP Storage')

    def write_text(self,
                   obj: str,
                   url: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Not supported: http storage is read-only."""
        raise NotImplementedError(
            'write_text is not supported by HTTP Storage')
class OSSStorage(Storage):
    """OSS storage backend; every access method is currently a stub that
    raises NotImplementedError."""

    def __init__(self, oss_config_file=None):
        # read from config file or env var
        raise NotImplementedError(
            'OSSStorage.__init__ to be implemented in the future')

    def read(self, filepath):
        raise NotImplementedError(
            'OSSStorage.read to be implemented in the future')

    def read_text(self, filepath, encoding='utf-8'):
        raise NotImplementedError(
            'OSSStorage.read_text to be implemented in the future')

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download ``filepath`` into a temporary local file and yield its
        path; the file is removed when the ``with`` block exits.

        Args:
            filepath (str): File to download.

        Examples:
            >>> storage = OSSStorage()
            >>> with storage.as_local_path('oss://path/to/file') as path:
            ...     # do something here
        """
        try:
            tmp = tempfile.NamedTemporaryFile(delete=False)
            tmp.write(self.read(filepath))
            tmp.close()
            yield tmp.name
        finally:
            os.remove(tmp.name)

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        raise NotImplementedError(
            'OSSStorage.write to be implemented in the future')

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        raise NotImplementedError(
            'OSSStorage.write_text to be implemented in the future')
# Cache of lazily-created storage backend instances, keyed by prefix.
G_STORAGES = {}


class File(object):
    """Unified file access facade.

    Dispatches each call to the concrete :class:`Storage` backend selected
    by the uri prefix (``oss://``, ``http://``, ``https://``); uris without
    a ``://`` separator are treated as local paths. Backends are created
    lazily and cached in ``G_STORAGES``.
    """

    _prefix_to_storage: dict = {
        'oss': OSSStorage,
        'http': HTTPStorage,
        'https': HTTPStorage,
        'local': LocalStorage,
    }

    @staticmethod
    def _get_storage(uri):
        """Return (and cache) the storage backend matching ``uri``."""
        assert isinstance(uri,
                          str), f'uri should be str type, but got {type(uri)}'

        if '://' not in uri:
            # local path
            storage_type = 'local'
        else:
            prefix, _ = uri.split('://')
            storage_type = prefix

        assert storage_type in File._prefix_to_storage, \
            f'Unsupported uri {uri}, valid prefixes: '\
            f'{list(File._prefix_to_storage.keys())}'

        if storage_type not in G_STORAGES:
            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()

        return G_STORAGES[storage_type]

    @staticmethod
    def read(uri: str) -> bytes:
        """Read data from a given ``uri`` with 'rb' mode.

        Args:
            uri (str): Path or uri to read data from.

        Returns:
            bytes: Expected bytes object.
        """
        storage = File._get_storage(uri)
        return storage.read(uri)

    @staticmethod
    def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
        """Read data from a given ``uri`` with 'r' mode.

        Args:
            uri (str or Path): Path or uri to read data from.
            encoding (str): The encoding format used to open the ``uri``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``uri``.
        """
        storage = File._get_storage(uri)
        # NOTE(review): ``encoding`` is not forwarded because not every
        # backend's read_text accepts it (HTTPStorage.read_text takes only
        # the url) — unify the backend signatures before passing it through.
        return storage.read_text(uri)

    @staticmethod
    def write(obj: bytes, uri: Union[str, Path]) -> None:
        """Write data to a given ``uri`` with 'wb' mode.

        Note:
            ``write`` will create a directory if the directory of ``uri``
            does not exist (local backend).

        Args:
            obj (bytes): Data to be written.
            uri (str or Path): Path or uri to write data to.
        """
        storage = File._get_storage(uri)
        return storage.write(obj, uri)

    @staticmethod
    def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
        """Write data to a given ``uri`` with 'w' mode.

        Note:
            ``write_text`` will create a directory if the directory of
            ``uri`` does not exist (local backend).

        Args:
            obj (str): Data to be written.
            uri (str or Path): Path or uri to write data to.
            encoding (str): The encoding format used to open the ``uri``.
                Default: 'utf-8'.
        """
        storage = File._get_storage(uri)
        # bug fix: ``encoding`` was previously accepted but silently
        # ignored; every backend's write_text accepts it, so forward it
        return storage.write_text(obj, uri, encoding=encoding)

    @staticmethod
    @contextlib.contextmanager
    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
        """Yield a local path for ``uri``, delegating to the backend's
        ``as_local_path`` (remote backends download to a temporary file
        that is removed on exit; the local backend yields ``uri`` as-is).

        bug fix: this was an instance-style method on a class that is only
        used statically; ``@staticmethod`` makes ``File.as_local_path(uri)``
        work from instances as well as from the class.
        """
        storage = File._get_storage(uri)
        with storage.as_local_path(uri) as local_path:
            yield local_path

View File

@@ -0,0 +1,3 @@
from .base import FormatHandler
from .json import JsonHandler
from .yaml import YamlHandler

View File

@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABCMeta, abstractmethod
class FormatHandler(metaclass=ABCMeta):
    """Abstract interface for (de)serializing objects in one file format.

    Concrete handlers (e.g. json/yaml) implement ``load``, ``dump`` and
    ``dumps``.
    """

    # if `text_mode` is True, file
    # should use text mode otherwise binary mode
    text_mode = True

    @abstractmethod
    def load(self, file, **kwargs):
        # deserialize from an open file-like object
        pass

    @abstractmethod
    def dump(self, obj, file, **kwargs):
        # serialize ``obj`` into an open file-like object
        pass

    @abstractmethod
    def dumps(self, obj, **kwargs):
        # serialize ``obj`` and return the result as a string
        pass

View File

@@ -0,0 +1,35 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import numpy as np
from .base import FormatHandler
def set_default(obj):
    """Fallback serializer passed as ``default=`` to ``json.dump``.

    Converts ``set``, ``range`` and ``np.ndarray`` values to plain lists,
    and numpy scalars (``np.generic``, e.g. ``np.int32``/``np.float32``)
    to their python built-in equivalents; any other type is rejected.
    """
    if isinstance(obj, (set, range)):
        return list(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'{type(obj)} is unsupported for json dump')
class JsonHandler(FormatHandler):
    """Format handler that (de)serializes objects as JSON text."""

    def load(self, file, **kwargs):
        # consistency fix: accept **kwargs as declared by the
        # FormatHandler interface (and implemented by YamlHandler)
        return json.load(file, **kwargs)

    def dump(self, obj, file, **kwargs):
        # fall back to ``set_default`` for non-JSON-native values
        # (set/range/ndarray/numpy scalars) unless the caller overrides
        kwargs.setdefault('default', set_default)
        json.dump(obj, file, **kwargs)

    def dumps(self, obj, **kwargs):
        kwargs.setdefault('default', set_default)
        return json.dumps(obj, **kwargs)

View File

@@ -0,0 +1,25 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import yaml
try:
from yaml import CDumper as Dumper
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader, Dumper # type: ignore
from .base import FormatHandler # isort:skip
class YamlHandler(FormatHandler):
    """Format handler that (de)serializes objects as YAML text.

    Uses the libyaml-accelerated ``CLoader``/``CDumper`` when available,
    falling back to the pure-python ones (see module imports).
    """

    def load(self, file, **kwargs):
        # default to the resolved Loader unless the caller overrides it
        kwargs.setdefault('Loader', Loader)
        return yaml.load(file, **kwargs)

    def dump(self, obj, file, **kwargs):
        kwargs.setdefault('Dumper', Dumper)
        yaml.dump(obj, file, **kwargs)

    def dumps(self, obj, **kwargs):
        # yaml.dump returns the string when no stream is given
        kwargs.setdefault('Dumper', Dumper)
        return yaml.dump(obj, **kwargs)

127
maas_lib/fileio/io.py Normal file
View File

@@ -0,0 +1,127 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) OpenMMLab. All rights reserved.
from io import BytesIO, StringIO
from pathlib import Path
from .file import File
from .format import JsonHandler, YamlHandler
# Maps a file extension / format name to a handler instance able to
# (de)serialize it; "yaml" and "yml" both map to a YamlHandler.
format_handlers = {
    'json': JsonHandler(),
    'yaml': YamlHandler(),
    'yml': YamlHandler(),
}
def load(file, file_format=None, **kwargs):
    """Load data from json/yaml files.

    This method provides a unified api for loading data from serialized
    files, whether addressed by a path/uri or given as an already-open
    file-like object.

    Args:
        file (str or :obj:`Path` or file-like object): Filename or a
            file-like object.
        file_format (str, optional): If not specified, the file format will
            be inferred from the file extension, otherwise use the specified
            one. Currently supported formats include "json", "yaml/yml".

    Examples:
        >>> load('/path/of/your/file')  # file stored on local disk
        >>> load('https://path/of/your/file')  # file served over http(s)
        >>> load('oss://path/of/your/file')  # file stored in oss

    Returns:
        The content from the file.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None and isinstance(file, str):
        # infer format from the extension
        file_format = file.split('.')[-1]
    if file_format not in format_handlers:
        raise TypeError(f'Unsupported format: {file_format}')
    handler = format_handlers[file_format]

    if isinstance(file, str):
        # fetch the raw content through the File facade, then wrap it in
        # an in-memory buffer so the handler sees a file-like object
        if handler.text_mode:
            buffer = StringIO(File.read_text(file))
        else:
            buffer = BytesIO(File.read(file))
        with buffer as f:
            return handler.load(f, **kwargs)
    if hasattr(file, 'read'):
        return handler.load(file, **kwargs)
    raise TypeError('"file" must be a filepath str or a file-object')
def dump(obj, file=None, file_format=None, **kwargs):
    """Dump data to json/yaml strings or files.

    This method provides a unified api for dumping data as strings or to
    files.

    Args:
        obj (any): The python object to be dumped.
        file (str or :obj:`Path` or file-like object, optional): If not
            specified, then the object is dumped to a str, otherwise to a file
            specified by the filename or file-like object.
        file_format (str, optional): Same as :func:`load`.

    Examples:
        >>> dump('hello world', '/path/of/your/file')  # disk
        >>> dump('hello world', 'oss://path/of/your/file')  # oss

    Returns:
        The serialized string when ``file`` is None, otherwise ``None``
        (the data is written out through ``file``).
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None:
        if isinstance(file, str):
            file_format = file.split('.')[-1]
        elif file is None:
            raise ValueError(
                'file_format must be specified since file is None')
    if file_format not in format_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = format_handlers[file_format]
    if file is None:
        # bug fix: handlers implement ``dumps`` (see FormatHandler); the
        # previously-called ``dump_to_str`` does not exist and raised
        # AttributeError for every in-memory dump
        return handler.dumps(obj, **kwargs)
    elif isinstance(file, str):
        # serialize into an in-memory buffer, then write through the
        # unified File facade so oss/http/local paths all work
        if handler.text_mode:
            with StringIO() as f:
                handler.dump(obj, f, **kwargs)
                File.write_text(f.getvalue(), file)
        else:
            with BytesIO() as f:
                handler.dump(obj, f, **kwargs)
                File.write(f.getvalue(), file)
    elif hasattr(file, 'write'):
        handler.dump(obj, file, **kwargs)
    else:
        raise TypeError('"file" must be a filename str or a file-object')
def dumps(obj, format, **kwargs):
    """Serialize a python object to a string.

    Args:
        obj (any): The python object to be dumped.
        format (str): Target format, same as ``file_format`` in
            :func:`load` ("json", "yaml"/"yml").

    Examples:
        >>> dumps({'a': 1}, 'json')
        >>> dumps({'a': 1}, 'yaml')

    Returns:
        str: The serialized string.
    """
    if format not in format_handlers:
        raise TypeError(f'Unsupported format: {format}')
    return format_handlers[format].dumps(obj, **kwargs)

472
maas_lib/utils/config.py Normal file
View File

@@ -0,0 +1,472 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import ast
import copy
import os
import os.path as osp
import platform
import shutil
import sys
import tempfile
import types
import uuid
from importlib import import_module
from pathlib import Path
from typing import Dict
import addict
from yapf.yapflib.yapf_api import FormatCode
from maas_lib.utils.logger import get_logger
from maas_lib.utils.pymod import (import_modules, import_modules_from_file,
validate_py_syntax)
if platform.system() == 'Windows':
import regex as re # type: ignore
else:
import re # type: ignore
logger = get_logger()
BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
DEPRECATION_KEY = '_deprecation_'
RESERVED_KEYS = ['filename', 'text', 'pretty_text']
class ConfigDict(addict.Dict):
    """ Dict which support get value through getattr

    Unlike a plain ``addict.Dict``, missing keys raise instead of being
    silently auto-created.

    Examples:
        >>> cdict = ConfigDict({'a':1232})
        >>> print(cdict.a)
        1232
    """

    def __missing__(self, name):
        # surface missing keys as KeyError rather than auto-vivifying
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            # translate to the AttributeError callers of attribute
            # access expect
            err = AttributeError(f"'{self.__class__.__name__}' object has no "
                                 f"attribute '{name}'")
        else:
            return value
        raise err
class Config:
    """A facility for config and config files.

    It supports common file formats as configs: python/json/yaml. The interface
    is the same as a dict object and also allows access config values as
    attributes.

    Example:
        >>> cfg = Config(dict(a=1, b=dict(c=[1,2,3], d='dd')))
        >>> cfg.a
        1
        >>> cfg.b
        {'c': [1, 2, 3], 'd': 'dd'}
        >>> cfg.b.d
        'dd'
        >>> cfg = Config.from_file('configs/examples/config.json')
        >>> cfg.filename
        'configs/examples/config.json'
        >>> cfg.b
        {'c': [1, 2, 3], 'd': 'dd'}
        >>> cfg = Config.from_file('configs/examples/config.py')
        >>> cfg.filename
        "configs/examples/config.py"
        >>> cfg = Config.from_file('configs/examples/config.yaml')
        >>> cfg.filename
        "configs/examples/config.yaml"
    """

    @staticmethod
    def _file2dict(filename):
        """Parse a py/json/yaml config file into ``(cfg_dict, cfg_text)``."""
        filename = osp.abspath(osp.expanduser(filename))
        if not osp.exists(filename):
            # bug fix: the original message contained a fixed placeholder
            # instead of the offending path
            raise ValueError(f'File does not exist: {filename}')
        fileExtname = osp.splitext(filename)[1]
        if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
            raise IOError('Only py/yml/yaml/json type are supported now!')

        # work on a temp copy so importing a .py config never pollutes the
        # original location
        with tempfile.TemporaryDirectory() as tmp_cfg_dir:
            tmp_cfg_file = tempfile.NamedTemporaryFile(
                dir=tmp_cfg_dir, suffix=fileExtname)
            if platform.system() == 'Windows':
                # windows cannot reopen a file that is still held open
                tmp_cfg_file.close()
            tmp_cfg_name = osp.basename(tmp_cfg_file.name)
            shutil.copyfile(filename, tmp_cfg_file.name)

            if filename.endswith('.py'):
                module_name, mod = import_modules_from_file(
                    osp.join(tmp_cfg_dir, tmp_cfg_name))
                # keep plain config values only; drop dunders, imported
                # modules and helper functions
                cfg_dict = {}
                for name, value in mod.__dict__.items():
                    if not name.startswith('__') and \
                            not isinstance(value, types.ModuleType) and \
                            not isinstance(value, types.FunctionType):
                        cfg_dict[name] = value
                # delete imported module
                del sys.modules[module_name]
            elif filename.endswith(('.yml', '.yaml', '.json')):
                from maas_lib.fileio import load
                cfg_dict = load(tmp_cfg_file.name)
            # close temp file
            tmp_cfg_file.close()

        cfg_text = filename + '\n'
        with open(filename, 'r', encoding='utf-8') as f:
            # Setting encoding explicitly to resolve coding issue on windows
            cfg_text += f.read()

        return cfg_dict, cfg_text

    @staticmethod
    def from_file(filename):
        """Build a :obj:`Config` from a py/json/yaml file path."""
        if isinstance(filename, Path):
            filename = str(filename)
        cfg_dict, cfg_text = Config._file2dict(filename)
        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def from_string(cfg_str, file_format):
        """Generate config from config str.

        Args:
            cfg_str (str): Config str.
            file_format (str): Config file format corresponding to the
                config str. Only py/yml/yaml/json type are supported now!

        Returns:
            :obj:`Config`: Config obj.
        """
        if file_format not in ['.py', '.json', '.yaml', '.yml']:
            raise IOError('Only py/yml/yaml/json type are supported now!')
        if file_format != '.py' and 'dict(' in cfg_str:
            # check if users specify a wrong suffix for python
            logger.warning(
                'Please check "file_format", the file format may be .py')

        with tempfile.NamedTemporaryFile(
                'w', encoding='utf-8', suffix=file_format,
                delete=False) as temp_file:
            temp_file.write(cfg_str)
        # on windows, previous implementation cause error
        # see PR 1077 for details
        cfg = Config.from_file(temp_file.name)
        os.remove(temp_file.name)
        return cfg

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        """Wrap ``cfg_dict`` (defaults to empty) with optional source text
        and filename metadata."""
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f'{key} is reserved for config file')

        if isinstance(filename, Path):
            filename = str(filename)

        # bypass our own __setattr__, which writes into _cfg_dict
        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(Config, self).__setattr__('_filename', filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, 'r') as f:
                text = f.read()
        else:
            text = ''
        super(Config, self).__setattr__('_text', text)

    @property
    def filename(self):
        return self._filename

    @property
    def text(self):
        return self._text

    @property
    def pretty_text(self):
        """Return the config rendered as formatted python source."""
        indent = 4

        def _indent(s_, num_spaces):
            # indent every line after the first by ``num_spaces``
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        def _format_basic_types(k, v, use_mapping=False):
            # scalar values: render as k=v or 'k': v
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)

            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f'{k_str}: {v_str}'
            else:
                attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent)
            return attr_str

        def _format_list(k, v, use_mapping=False):
            # check if all items in the list are dict
            if all(isinstance(_, dict) for _ in v):
                v_str = '[\n'
                v_str += '\n'.join(
                    f'dict({_indent(_format_dict(v_), indent)}),'
                    for v_ in v).rstrip(',')
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f'{k_str}: {v_str}'
                else:
                    attr_str = f'{str(k)}={v_str}'
                attr_str = _indent(attr_str, indent) + ']'
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            # keys that are not valid python identifiers force dict-literal
            # ({'k': v}) rendering instead of dict(k=v)
            contain_invalid_identifier = False
            for key_name in dict_str:
                contain_invalid_identifier |= \
                    (not str(key_name).isidentifier())
            return contain_invalid_identifier

        def _format_dict(input_dict, outest_level=False):
            r = ''
            s = []

            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += '{'
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                end = '' if outest_level or is_last else ','
                if isinstance(v, dict):
                    v_str = '\n' + _format_dict(v)
                    if use_mapping:
                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
                        attr_str = f'{k_str}: dict({v_str}'
                    else:
                        attr_str = f'{str(k)}=dict({v_str}'
                    attr_str = _indent(attr_str, indent) + ')' + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end

                s.append(attr_str)
            r += '\n'.join(s)
            if use_mapping:
                r += '}'
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)
        # copied from setup.cfg
        yapf_style = dict(
            based_on_style='pep8',
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True)
        # NOTE(review): ``verify`` was removed in newer yapf releases —
        # pin yapf or drop the kwarg when upgrading.
        text, _ = FormatCode(text, style_config=yapf_style, verify=True)

        return text

    def __repr__(self):
        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def __getstate__(self):
        return (self._cfg_dict, self._filename, self._text)

    def __copy__(self):
        cls = self.__class__
        other = cls.__new__(cls)
        other.__dict__.update(self.__dict__)
        return other

    def __deepcopy__(self, memo):
        cls = self.__class__
        other = cls.__new__(cls)
        memo[id(self)] = other

        for key, value in self.__dict__.items():
            super(Config, other).__setattr__(key, copy.deepcopy(value, memo))

        return other

    def __setstate__(self, state):
        _cfg_dict, _filename, _text = state
        super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
        super(Config, self).__setattr__('_filename', _filename)
        super(Config, self).__setattr__('_text', _text)

    def dump(self, file: str = None):
        """Dumps config into a file or returns a string representation of the
        config.

        If a file argument is given, saves the config to that file using the
        format defined by the file argument extension.

        Otherwise, returns a string representing the config. The formatting of
        this returned string is defined by the extension of `self.filename`. If
        `self.filename` is not defined, returns a string representation of a
        dict (lowercased and using ' for strings).

        Examples:
            >>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0),
            ...     item3=True, item4='test')
            >>> cfg = Config(cfg_dict=cfg_dict)
            >>> dump_file = "a.py"
            >>> cfg.dump(dump_file)

        Args:
            file (str, optional): Path of the output file where the config
                will be dumped. Defaults to None.
        """
        from maas_lib.fileio import dump
        cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
        if file is None:
            if self.filename is None or self.filename.endswith('.py'):
                return self.pretty_text
            else:
                file_format = self.filename.split('.')[-1]
                return dump(cfg_dict, file_format=file_format)
        elif file.endswith('.py'):
            with open(file, 'w', encoding='utf-8') as f:
                f.write(self.pretty_text)
        else:
            file_format = file.split('.')[-1]
            return dump(cfg_dict, file=file, file_format=file_format)

    @staticmethod
    def _merge_a_into_b(a, b, allow_list_keys=False):
        """Recursively merge dict ``a`` into a copy of dict (or list) ``b``.

        bug fix: this helper is called by ``merge_from_dict`` but was
        missing from the class, so every merge raised AttributeError.

        Args:
            a (dict): Source mapping; its values take precedence.
            b (dict or list): Destination container (not mutated).
            allow_list_keys (bool): If True, digit string keys in ``a``
                index into lists in ``b``.

        Returns:
            The merged copy of ``b``.
        """
        b = b.copy()
        for k, v in a.items():
            if allow_list_keys and k.isdigit() and isinstance(b, list):
                k = int(k)
                if len(b) <= k:
                    raise KeyError(
                        f'Index {k} exceeds the length of list {b}')
                b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
            elif isinstance(v, dict):
                if k in b and not v.pop(DELETE_KEY, False):
                    allowed_types = (dict, list) if allow_list_keys else dict
                    if not isinstance(b[k], allowed_types):
                        raise TypeError(
                            f'{k}={v} in child config cannot inherit from '
                            f'base because {k} is a dict in the child config '
                            f'but is of type {type(b[k])} in base config.')
                    b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
                else:
                    b[k] = ConfigDict(v)
            else:
                b[k] = v
        return b

    def merge_from_dict(self, options, allow_list_keys=True):
        """Merge list into cfg_dict.

        Merge the dict parsed by MultipleKVAction into this cfg.

        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp':True}
            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(
            ...     model=dict(backbone=dict(depth=50, with_cp=True)))

            >>> # Merge list element
            >>> cfg = Config(dict(pipeline=[
            ...     dict(type='Resize'), dict(type='RandomDistortion')]))
            >>> options = dict(pipeline={'0': dict(type='MyResize')})
            >>> cfg.merge_from_dict(options, allow_list_keys=True)
            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(pipeline=[
            ...     dict(type='MyResize'), dict(type='RandomDistortion')])

        Args:
            options (dict): dict of configs to merge from.
            allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
                are allowed in ``options`` and will replace the element of the
                corresponding index in the config if the config is a list.
                Default: True.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            # expand dotted keys into nested dicts
            d = option_cfg_dict
            key_list = full_key.split('.')
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        super(Config, self).__setattr__(
            '_cfg_dict',
            Config._merge_a_into_b(
                option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))

    def to_dict(self) -> Dict:
        """ Convert Config object to python dict
        """
        return self._cfg_dict.to_dict()

    def to_args(self, parse_fn, use_hyphen=True):
        """ Convert config obj to args using parse_fn

        Args:
            parse_fn: a function object, which takes args as input,
                such as ['--foo', 'FOO'] and return parsed args, an
                example is given as follows
                including literal blocks::
                    def parse_fn(args):
                        parser = argparse.ArgumentParser(prog='PROG')
                        parser.add_argument('-x')
                        parser.add_argument('--foo')
                        return parser.parse_args(args)
            use_hyphen (bool, optional): if set true, hyphen in keyname
                will be converted to underscore
        Return:
            args: arg object parsed by argparse.ArgumentParser
        """
        args = []
        for k, v in self._cfg_dict.items():
            arg_name = f'--{k}'
            if use_hyphen:
                arg_name = arg_name.replace('_', '-')
            if isinstance(v, bool) and v:
                args.append(arg_name)
            elif isinstance(v, (int, str, float)):
                args.append(arg_name)
                args.append(str(v))
            elif isinstance(v, list):
                # bug fix: the original asserted that the list itself was a
                # scalar (always failing) and appended the whole list as a
                # single token; emit one token per element instead
                args.append(arg_name)
                for element in v:
                    assert isinstance(element, (int, str, float, bool)), \
                        'Element type in list is expected to be either ' \
                        f'int,str,float, but got type {type(element)}'
                    args.append(str(element))
            else:
                raise ValueError(
                    'type in config file which supported to be '
                    'converted to args should be either bool, '
                    f'int, str, float or list of them but got type {v}')

        return parse_fn(args)

View File

@@ -0,0 +1,34 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
class Fields(object):
    """ Names for different application fields.

    Attribute values are the canonical string identifiers used across
    the library.
    """
    image = 'image'
    video = 'video'
    nlp = 'nlp'
    audio = 'audio'
    multi_modal = 'multi_modal'
class Tasks(object):
    """ Names for tasks supported by maas lib.

    Holds the standard task name to use for identifying different tasks.
    This should be used to register models, pipelines, trainers.
    """
    # vision tasks
    # NOTE(review): attribute name is misspelled ('classfication');
    # renaming would break existing references — consider adding a
    # correctly-spelled alias in a follow-up.
    image_classfication = 'image-classification'
    object_detection = 'object-detection'

    # nlp tasks
    sentiment_analysis = 'sentiment-analysis'
    fill_mask = 'fill-mask'
class InputFields(object):
    """ Names for input data fields in the input data for pipelines
    """
    img = 'img'
    text = 'text'
    audio = 'audio'

45
maas_lib/utils/logger.py Normal file
View File

@@ -0,0 +1,45 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
from typing import Optional
# Tracks logger names that were already configured, so repeated calls
# return the same logger without stacking duplicate handlers.
init_loggers = {}


def get_logger(log_file: Optional[str] = None,
               log_level: int = logging.INFO,
               file_mode: str = 'w'):
    """ Get logging logger

    Args:
        log_file: Log filename, if specified, file handler will be added to
            logger
        log_level: Logging level.
        file_mode: Specifies the mode to open the file, if filename is
            specified (if filemode is unspecified, it defaults to 'w').
    """
    logger_name = __name__.split('.')[0]
    logger = logging.getLogger(logger_name)

    if logger_name in init_loggers:
        # already configured — later log_file/log_level arguments are
        # ignored by design
        return logger

    # TODO @wenmeng.zwm add logger setting for distributed environment
    handlers = [logging.StreamHandler()]
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file, file_mode))

    fmt = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(fmt)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)
    init_loggers[logger_name] = True
    return logger

90
maas_lib/utils/pymod.py Normal file
View File

@@ -0,0 +1,90 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import ast
import os
import os.path as osp
import sys
import types
from importlib import import_module
from maas_lib.utils.logger import get_logger
logger = get_logger()
def import_modules_from_file(py_file: str):
    """ Import module from a certain file

    Args:
        py_file: path to a python file to be imported

    Return:
        A tuple ``(module_name, module)`` of the imported module's name
        and the module object itself.
    """
    dirname, basefile = os.path.split(py_file)
    if dirname == '':
        # bug fix: the original used ``==`` (a no-op comparison) instead
        # of assignment, leaving dirname empty for bare filenames
        dirname = './'
    module_name = osp.splitext(basefile)[0]
    # make the file's directory importable, validate it, then import
    sys.path.insert(0, dirname)
    validate_py_syntax(py_file)
    mod = import_module(module_name)
    sys.path.pop(0)
    return module_name, mod
def import_modules(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None. Otherwise, the original ImportError is re-raised.
            Default: False.

    Returns:
        list[module] | module | None: The imported modules.

    Examples:
        >>> osp, sys = import_modules(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if allow_failed_imports:
                logger.warning(f'{imp} failed to import and is ignored.')
                imported_tmp = None
            else:
                # Fix: re-raise the caught exception so the failing module
                # name and cause are preserved; the original bare
                # `raise ImportError` discarded all context.
                raise
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported
def validate_py_syntax(filename):
    """Check that ``filename`` contains syntactically valid python code.

    Args:
        filename: Path of the python source file to check.

    Raises:
        SyntaxError: If the file cannot be parsed; the message includes the
            offending file name.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        # Setting encoding explicitly to resolve coding issue on windows
        content = f.read()
    try:
        ast.parse(content)
    except SyntaxError as e:
        # Fix: report which file failed instead of the hard-coded
        # '(unknown)' placeholder, and chain the original error.
        raise SyntaxError('There are syntax errors in config '
                          f'file {filename}: {e}') from e

183
maas_lib/utils/registry.py Normal file
View File

@@ -0,0 +1,183 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
from maas_lib.utils.logger import get_logger
default_group = 'default'
logger = get_logger()
class Registry(object):
    """Registry which supports registering modules and grouping them by a
    key name.

    If a group name is not provided, modules are registered to the default
    group.
    """

    def __init__(self, name: str):
        # Human-readable registry name, e.g. 'models'.
        self._name = name
        # Two-level mapping: group_key -> {module_name -> module_cls}.
        self._modules = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + f'({self._name})\n'
        for group_name, group in self._modules.items():
            format_str += f'group_name={group_name}, '\
                f'modules={list(group.keys())}\n'
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def modules(self):
        return self._modules

    def list(self):
        """Log the list of modules in the current registry, per group."""
        for group_name, group in self._modules.items():
            logger.info(f'group_name={group_name}')
            for m in group.keys():
                logger.info(f'\t{m}')
            logger.info('')

    def get(self, module_key, group_key=default_group):
        """Return the class registered as ``module_key`` under
        ``group_key``, or None when either is unknown."""
        if group_key not in self._modules:
            return None
        else:
            return self._modules[group_key].get(module_key, None)

    def _register_module(self,
                         group_key=default_group,
                         module_name=None,
                         module_cls=None):
        """Perform the actual registration of ``module_cls``.

        Raises:
            TypeError: If ``module_cls`` is not a class.
            KeyError: If ``module_name`` is already taken in the group.
        """
        assert isinstance(group_key,
                          str), 'group_key is required and must be str'

        if group_key not in self._modules:
            self._modules[group_key] = dict()

        if not inspect.isclass(module_cls):
            raise TypeError(f'module is not a class type: {type(module_cls)}')

        if module_name is None:
            # Default the registration key to the class name.
            module_name = module_cls.__name__

        if module_name in self._modules[group_key]:
            # Fix: the original message was missing the space between
            # 'in' and the registry name ("...registered inmodels[...]").
            raise KeyError(f'{module_name} is already registered in '
                           f'{self._name}[{group_key}]')
        self._modules[group_key][module_name] = module_cls

    def register_module(self,
                        group_key: str = default_group,
                        module_name: str = None,
                        module_cls: type = None):
        """Register a module class, either directly or as a class decorator.

        Example:
            >>> models = Registry('models')
            >>> @models.register_module('image-classification', 'SwinT')
            >>> class SwinTransformer:
            >>>     pass
            >>> # Fix: the original example passed 'SwinDefault' as the
            >>> # first positional arg, which sets group_key, not the name.
            >>> @models.register_module(module_name='SwinDefault')
            >>> class SwinTransformerDefaultGroup:
            >>>     pass

        Args:
            group_key: Group name of which module will be registered,
                default group name is 'default'
            module_name: Module name, defaults to the class name.
            module_cls: Module class object; when omitted, a decorator
                performing the registration is returned instead.
        """
        if not (module_name is None or isinstance(module_name, str)):
            # Fix: added the missing space before 'got' in the message.
            raise TypeError(f'module_name must be either of None, str, '
                            f'got {type(module_name)}')
        if module_cls is not None:
            self._register_module(
                group_key=group_key,
                module_name=module_name,
                module_cls=module_cls)
            return module_cls

        # if module_cls is None, return a decorator function
        def _register(module_cls):
            self._register_module(
                group_key=group_key,
                module_name=module_name,
                module_cls=module_cls)
            return module_cls

        return _register
def build_from_cfg(cfg,
                   registry: Registry,
                   group_key: str = default_group,
                   default_args: dict = None) -> object:
    """Build a module from config dict when it is a class configuration, or
    call a function from config dict when it is a function configuration.

    Example:
        >>> models = Registry('models')
        >>> @models.register_module('image-classification', 'SwinT')
        >>> class SwinTransformer:
        >>>     pass
        >>> swint = build_from_cfg(dict(type='SwinT'), models,
        ...                        'image-classification')
        >>> # Returns an instantiated object
        >>>
        >>> # NOTE(review): registering a plain function as below requires
        >>> # Registry to accept non-class modules — confirm, since the
        >>> # current Registry._register_module only accepts classes.
        >>> @models.register_module()
        >>> def swin_transformer():
        >>>     pass
        >>> ret = build_from_cfg(dict(type='swin_transformer'), models)
        >>> # Returns the result of calling the function

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        group_key (str, optional): The name of registry group from which
            module should be searched.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: When an argument has the wrong type, or when the
            resolved ``type`` is neither a string nor a class/function.
        KeyError: When "type" is missing or not registered in the group.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an maas_lib.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # Merge defaults without overriding explicitly configured values.
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type, group_key=group_key)
        if obj_cls is None:
            # Fix: added the missing space before 'registry group'
            # (the message previously read "...the modelsregistry group").
            raise KeyError(f'{obj_type} is not in the {registry.name} '
                           f'registry group {group_key}')
    elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name. Fix: chain the
        # original exception so its traceback is preserved.
        raise type(e)(f'{obj_cls.__name__}: {e}') from e

1
requirements.txt Normal file
View File

@@ -0,0 +1 @@
-r requirements/runtime.txt

6
requirements/docs.txt Normal file
View File

@@ -0,0 +1,6 @@
docutils==0.16.0
recommonmark
sphinx==4.0.2
sphinx-copybutton
sphinx_markdown_tables
sphinx_rtd_theme==0.5.2

5
requirements/runtime.txt Normal file
View File

@@ -0,0 +1,5 @@
addict
numpy
pyyaml
requests
yapf

5
requirements/tests.txt Normal file
View File

@@ -0,0 +1,5 @@
expecttest
flake8
isort==4.3.21
pre-commit
yapf==0.30.0

24
setup.cfg Normal file
View File

@@ -0,0 +1,24 @@
[isort]
line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = maas_lib
known_third_party = json,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
[yapf]
BASED_ON_STYLE = pep8
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
[codespell]
skip = *.ipynb
quiet-level = 3
ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids
[flake8]
select = B,C,E,F,P,T4,W,B9
max-line-length = 120
ignore = F401
exclude = docs/src,*.pyi,.git

0
tests/__init__.py Normal file
View File

0
tests/fileio/__init__.py Normal file
View File

70
tests/fileio/test_file.py Normal file
View File

@@ -0,0 +1,70 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
import unittest
from requests import HTTPError
from maas_lib.fileio.file import File, HTTPStorage, LocalStorage
class FileTest(unittest.TestCase):
    """Tests for the fileio storage backends and the File facade."""

    def test_local_storage(self):
        # Round-trip binary and text content through LocalStorage.
        storage = LocalStorage()
        # NOTE(review): tempfile._get_candidate_names is a private CPython
        # API and may break across versions — consider NamedTemporaryFile.
        temp_name = tempfile.gettempdir() + '/' + next(
            tempfile._get_candidate_names())
        binary_content = b'12345'
        storage.write(binary_content, temp_name)
        self.assertEqual(binary_content, storage.read(temp_name))

        content = '12345'
        storage.write_text(content, temp_name)
        self.assertEqual(content, storage.read_text(temp_name))
        os.remove(temp_name)

    def test_http_storage(self):
        # Requires network access to the public OSS bucket below.
        storage = HTTPStorage()
        url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com' \
            '/data/test/data.txt'
        content = 'this is test data'
        self.assertEqual(content.encode('utf8'), storage.read(url))
        self.assertEqual(content, storage.read_text(url))

        # as_local_path should expose the remote content as a local file.
        with storage.as_local_path(url) as local_file:
            with open(local_file, 'r') as infile:
                self.assertEqual(content, infile.read())

        # HTTP storage is read-only.
        with self.assertRaises(NotImplementedError):
            storage.write('dfad', url)

        # A bogus URL must surface the underlying HTTP error.
        with self.assertRaises(HTTPError):
            storage.read(url + 'df')

    def test_file(self):
        # File dispatches to the proper backend based on the path scheme.
        url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com'\
            '/data/test/data.txt'
        content = 'this is test data'
        self.assertEqual(content.encode('utf8'), File.read(url))
        with File.as_local_path(url) as local_file:
            with open(local_file, 'r') as infile:
                self.assertEqual(content, infile.read())

        with self.assertRaises(NotImplementedError):
            File.write('dfad', url)

        with self.assertRaises(HTTPError):
            File.read(url + 'df')

        # Local paths go through the local backend.
        temp_name = tempfile.gettempdir() + '/' + next(
            tempfile._get_candidate_names())
        binary_content = b'12345'
        File.write(binary_content, temp_name)
        self.assertEqual(binary_content, File.read(temp_name))
        os.remove(temp_name)


if __name__ == '__main__':
    unittest.main()

32
tests/fileio/test_io.py Normal file
View File

@@ -0,0 +1,32 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import tempfile
import unittest
from maas_lib.fileio.io import dump, dumps, load
class FileIOTest(unittest.TestCase):
    """Round-trip tests for the serialization helpers dump/dumps/load."""

    def test_format(self, format='json'):
        # Serialize a sample object both to a string and to a file, then
        # verify the two agree and the file loads back unchanged.
        sample = [1, 2, 3, 'str', {'model': 'resnet'}]
        serialized = dumps(sample, format)
        out_path = tempfile.gettempdir() + '/' + next(
            tempfile._get_candidate_names()) + '.' + format
        dump(sample, out_path)
        self.assertEqual(load(out_path), sample)
        with open(out_path, 'r') as infile:
            self.assertEqual(serialized, infile.read())

        # An unsupported file suffix must be rejected both ways.
        bad_path = out_path + 's'
        with self.assertRaises(TypeError):
            load(bad_path)
        with self.assertRaises(TypeError):
            dump(sample, bad_path)

    def test_yaml(self):
        # Same round-trip, driven through the yaml format.
        self.test_format('yaml')


if __name__ == '__main__':
    unittest.main()

53
tests/run.py Normal file
View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
import sys
import unittest
from fnmatch import fnmatch
def gather_test_cases(test_dir, pattern, list_tests):
case_list = []
for dirpath, dirnames, filenames in os.walk(test_dir):
for file in filenames:
if fnmatch(file, pattern):
case_list.append(file)
test_suite = unittest.TestSuite()
for case in case_list:
test_case = unittest.defaultTestLoader.discover(
start_dir=test_dir, pattern=case)
test_suite.addTest(test_case)
if hasattr(test_case, '__iter__'):
for subcase in test_case:
if list_tests:
print(subcase)
else:
if list_tests:
print(test_case)
return test_suite
def main(args):
    """Gather the suite and run it; exit with status 1 on test failures."""
    suite = gather_test_cases(
        os.path.abspath(args.test_dir), args.pattern, args.list_tests)
    if args.list_tests:
        # Listing mode: gather_test_cases already printed the cases.
        return
    outcome = unittest.TextTestRunner().run(suite)
    if len(outcome.failures) > 0:
        sys.exit(1)
if __name__ == '__main__':
    # Command-line entry point for the test runner.
    parser = argparse.ArgumentParser('test runner')
    parser.add_argument(
        '--list_tests', action='store_true', help='list all tests')
    parser.add_argument(
        '--pattern', default='test_*.py', help='test file pattern')
    parser.add_argument(
        '--test_dir', default='tests', help='directory to be tested')
    main(parser.parse_args())

0
tests/utils/__init__.py Normal file
View File

View File

@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os.path as osp
import tempfile
import unittest
from pathlib import Path
from maas_lib.fileio import dump, load
from maas_lib.utils.config import Config
# Expected payload mirrored by the configs/examples/config.* fixtures.
obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}


class ConfigTest(unittest.TestCase):
    """Tests for Config loading and dumping across json, yaml and py.

    NOTE(review): these tests read fixture files under configs/examples
    and therefore assume the repository root as the working directory.
    """

    def test_json(self):
        config_file = 'configs/examples/config.json'
        cfg = Config.from_file(config_file)
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])

    def test_yaml(self):
        config_file = 'configs/examples/config.yaml'
        cfg = Config.from_file(config_file)
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])

    def test_py(self):
        config_file = 'configs/examples/config.py'
        cfg = Config.from_file(config_file)
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])

    def test_dump(self):
        # Dumping without a target returns python pretty text; dumping to
        # a file serializes according to the file suffix.
        config_file = 'configs/examples/config.py'
        cfg = Config.from_file(config_file)
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])

        pretty_text = 'a = 1\n'
        pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n"
        json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}'
        yaml_str = 'a: 1\nb:\n c:\n - 1\n - 2\n - 3\n d: dd\n'
        with tempfile.NamedTemporaryFile(suffix='.json') as ofile:
            self.assertEqual(pretty_text, cfg.dump())
            cfg.dump(ofile.name)
            with open(ofile.name, 'r') as infile:
                self.assertEqual(json_str, infile.read())

        with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile:
            cfg.dump(ofile.name)
            with open(ofile.name, 'r') as infile:
                self.assertEqual(yaml_str, infile.read())

    def test_to_dict(self):
        config_file = 'configs/examples/config.json'
        cfg = Config.from_file(config_file)
        d = cfg.to_dict()
        print(d)
        self.assertTrue(isinstance(d, dict))

    def test_to_args(self):
        # parse_fn mimics a user-supplied argparse front-end; the values
        # asserted below come from configs/examples/plain_args.yaml and
        # must override the parser defaults.
        def parse_fn(args):
            parser = argparse.ArgumentParser(prog='PROG')
            parser.add_argument('--model-dir', default='')
            parser.add_argument('--lr', type=float, default=0.001)
            parser.add_argument('--optimizer', default='')
            parser.add_argument('--weight-decay', type=float, default=1e-7)
            parser.add_argument(
                '--save-checkpoint-epochs', type=int, default=30)
            return parser.parse_args(args)

        cfg = Config.from_file('configs/examples/plain_args.yaml')
        args = cfg.to_args(parse_fn)

        self.assertEqual(args.model_dir, 'path/to/model')
        self.assertAlmostEqual(args.lr, 0.01)
        self.assertAlmostEqual(args.weight_decay, 1e-6)
        self.assertEqual(args.optimizer, 'Adam')
        self.assertEqual(args.save_checkpoint_epochs, 20)


if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,91 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from maas_lib.utils.constant import Tasks
from maas_lib.utils.registry import Registry, build_from_cfg, default_group
class RegistryTest(unittest.TestCase):
    """Tests for Registry registration, lookup and build_from_cfg."""

    def test_register_class_no_task(self):
        MODELS = Registry('models')
        self.assertTrue(MODELS.name == 'models')
        self.assertTrue(MODELS.modules == {})
        self.assertEqual(len(MODELS.modules), 0)

        # Without a group key, registration lands in the default group.
        @MODELS.register_module(module_name='cls-resnet')
        class ResNetForCls(object):
            pass

        self.assertTrue(default_group in MODELS.modules)
        self.assertTrue(MODELS.get('cls-resnet') is ResNetForCls)

    def test_register_class_with_task(self):
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls(object):
            pass

        self.assertTrue(Tasks.image_classfication in MODELS.modules)
        self.assertTrue(
            MODELS.get('SwinT', Tasks.image_classfication) is SwinTForCls)

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        self.assertTrue(Tasks.sentiment_analysis in MODELS.modules)
        self.assertTrue(
            MODELS.get('Bert', Tasks.sentiment_analysis) is
            BertForSentimentAnalysis)

        # When no module name is given, the class name is used.
        @MODELS.register_module(Tasks.object_detection)
        class DETR(object):
            pass

        self.assertTrue(Tasks.object_detection in MODELS.modules)
        self.assertTrue(MODELS.get('DETR', Tasks.object_detection) is DETR)
        self.assertEqual(len(MODELS.modules), 3)

    def test_list(self):
        # Only checks that listing and printing the registry do not raise.
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls(object):
            pass

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        MODELS.list()
        print(MODELS)

    def test_build(self):
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls(object):
            pass

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        cfg = dict(type='SwinT')
        model = build_from_cfg(cfg, MODELS, Tasks.image_classfication)
        self.assertTrue(isinstance(model, SwinTForCls))

        cfg = dict(type='Bert')
        model = build_from_cfg(cfg, MODELS, Tasks.sentiment_analysis)
        self.assertTrue(isinstance(model, BertForSentimentAnalysis))

        # Looking up a module under the wrong group must fail.
        with self.assertRaises(KeyError):
            cfg = dict(type='Bert')
            model = build_from_cfg(cfg, MODELS, Tasks.image_classfication)


if __name__ == '__main__':
    unittest.main()