mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
[to #41402703] add basic modules
* add constant * add logger module * add registry and builder module * add fileio module * add requirements and setup.cfg * add config module and tests * add citest script Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8718998
This commit is contained in:
13
.dev_scripts/citest.sh
Normal file
13
.dev_scripts/citest.sh
Normal file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
# CI test script: install dependencies, run the linter, then the unit tests.

pip install -r requirements/runtime.txt
pip install -r requirements/tests.txt

# linter test
# use internal project for pre-commit due to the network problem
# NOTE: `exit -1` is non-portable (exit codes are 0-255); use 1 instead.
if ! pre-commit run --all-files; then
    echo "linter test failed, please run 'pre-commit run --all-files' to check"
    exit 1
fi

PYTHONPATH=. python tests/run.py
126
.gitignore
vendored
Normal file
126
.gitignore
vendored
Normal file
@@ -0,0 +1,126 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
/package
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
data
|
||||
.vscode
|
||||
.idea
|
||||
|
||||
# custom
|
||||
*.pkl
|
||||
*.pkl.json
|
||||
*.log.json
|
||||
*.whl
|
||||
*.tar.gz
|
||||
*.swp
|
||||
*.log
|
||||
*.tar.gz
|
||||
source.sh
|
||||
tensorboard.sh
|
||||
.DS_Store
|
||||
replace.sh
|
||||
|
||||
# Pytorch
|
||||
*.pth
|
||||
1
configs/README.md
Normal file
1
configs/README.md
Normal file
@@ -0,0 +1 @@
|
||||
This folder will host example configs for each model supported by maas_lib.
|
||||
7
configs/examples/config.json
Normal file
7
configs/examples/config.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"a": 1,
|
||||
"b" : {
|
||||
"c": [1,2,3],
|
||||
"d" : "dd"
|
||||
}
|
||||
}
|
||||
2
configs/examples/config.py
Normal file
2
configs/examples/config.py
Normal file
@@ -0,0 +1,2 @@
|
||||
a = 1
|
||||
b = dict(c=[1,2,3], d='dd')
|
||||
4
configs/examples/config.yaml
Normal file
4
configs/examples/config.yaml
Normal file
@@ -0,0 +1,4 @@
|
||||
a: 1
|
||||
b:
|
||||
c: [1,2,3]
|
||||
d: dd
|
||||
5
configs/examples/plain_args.yaml
Normal file
5
configs/examples/plain_args.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
model_dir: path/to/model
|
||||
lr: 0.01
|
||||
optimizer: Adam
|
||||
weight_decay: 1e-6
|
||||
save_checkpoint_epochs: 20
|
||||
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.

# Re-export the package version string as the module's only public name.
from .version import __version__

__all__ = ['__version__']
||||
1
maas_lib/fileio/__init__.py
Normal file
1
maas_lib/fileio/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Public serialization helpers of the fileio subpackage.
from .io import dump, dumps, load

# Declare the public API explicitly.
__all__ = ['dump', 'dumps', 'load']
325
maas_lib/fileio/file.py
Normal file
325
maas_lib/fileio/file.py
Normal file
@@ -0,0 +1,325 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import tempfile
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Generator, Union
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class Storage(metaclass=ABCMeta):
    """Abstract class of storage.

    All backends need to implement four apis: ``read()``, ``read_text()``,
    ``write()`` and ``write_text()``. ``read()`` reads the file as a byte
    stream, ``read_text()`` reads the file as texts, and ``write()`` /
    ``write_text()`` persist bytes / text back to the storage.
    """

    @abstractmethod
    def read(self, filepath: str):
        """Read bytes from ``filepath``."""
        pass

    @abstractmethod
    def read_text(self, filepath: str):
        """Read text from ``filepath``."""
        pass

    @abstractmethod
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write bytes ``obj`` to ``filepath``."""
        pass

    @abstractmethod
    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Write text ``obj`` to ``filepath`` using ``encoding``."""
        pass
|
||||
|
||||
class LocalStorage(Storage):
    """Storage backend backed by the local hard disk."""

    def read(self, filepath: Union[str, Path]) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.
        """
        with open(filepath, 'rb') as infile:
            return infile.read()

    def read_text(self,
                  filepath: Union[str, Path],
                  encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        with open(filepath, 'r', encoding=encoding) as infile:
            return infile.read()

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            The parent directory of ``filepath`` is created first when it
            does not exist yet.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        parent = os.path.dirname(filepath)
        if parent and not os.path.exists(parent):
            os.makedirs(parent)
        with open(filepath, 'wb') as outfile:
            outfile.write(obj)

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            The parent directory of ``filepath`` is created first when it
            does not exist yet.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.
        """
        parent = os.path.dirname(filepath)
        if parent and not os.path.exists(parent):
            os.makedirs(parent)
        with open(filepath, 'w', encoding=encoding) as outfile:
            outfile.write(obj)

    @contextlib.contextmanager
    def as_local_path(
            self,
            filepath: Union[str,
                            Path]) -> Generator[Union[str, Path], None, None]:
        """Yield ``filepath`` unchanged; exists only for API uniformity."""
        yield filepath
||||
|
||||
|
||||
class HTTPStorage(Storage):
    """HTTP and HTTPS storage (read-only)."""

    def read(self, url):
        """Fetch ``url`` and return the response body as bytes.

        Raises an HTTP error for non-success status codes.
        """
        r = requests.get(url)
        r.raise_for_status()
        return r.content

    def read_text(self, url):
        """Fetch ``url`` and return the response body as text."""
        r = requests.get(url)
        r.raise_for_status()
        return r.text

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath`` to a temporary local path.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`.
        It can be called with a ``with`` statement, and when exiting from the
        ``with`` statement, the temporary path will be released.

        Args:
            filepath (str): Download a file from ``filepath``.

        Examples:
            >>> storage = HTTPStorage()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('http://path/to/file') as path:
            ...     pass  # do something here
        """
        # Download before creating the temporary file so that a failed
        # request leaves nothing behind and the cleanup below can never
        # reference an unassigned variable.
        content = self.read(filepath)
        tmp_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            tmp_file.write(content)
            tmp_file.close()
            yield tmp_file.name
        finally:
            os.remove(tmp_file.name)

    def write(self, obj: bytes, url: Union[str, Path]) -> None:
        """Unsupported: HTTP storage is read-only."""
        raise NotImplementedError('write is not supported by HTTP Storage')

    def write_text(self,
                   obj: str,
                   url: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Unsupported: HTTP storage is read-only."""
        raise NotImplementedError(
            'write_text is not supported by HTTP Storage')
||||
|
||||
|
||||
class OSSStorage(Storage):
    """OSS storage (all operations still to be implemented)."""

    def __init__(self, oss_config_file=None):
        # read from config file or env var
        raise NotImplementedError(
            'OSSStorage.__init__ to be implemented in the future')

    def read(self, filepath):
        """Read bytes from an oss uri (not implemented yet)."""
        raise NotImplementedError(
            'OSSStorage.read to be implemented in the future')

    def read_text(self, filepath, encoding='utf-8'):
        """Read text from an oss uri (not implemented yet)."""
        raise NotImplementedError(
            'OSSStorage.read_text to be implemented in the future')

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath`` to a temporary local path.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`.
        It can be called with a ``with`` statement, and when exiting from the
        ``with`` statement, the temporary path will be released.

        Args:
            filepath (str): Download a file from ``filepath``.

        Examples:
            >>> storage = OSSStorage()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('oss://path/to/file') as path:
            ...     pass  # do something here
        """
        # Fetch before creating the temporary file so a failed read leaves
        # nothing behind on disk.
        content = self.read(filepath)
        tmp_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            tmp_file.write(content)
            tmp_file.close()
            yield tmp_file.name
        finally:
            os.remove(tmp_file.name)

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write bytes to an oss uri (not implemented yet)."""
        raise NotImplementedError(
            'OSSStorage.write to be implemented in the future')

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Write text to an oss uri (not implemented yet)."""
        raise NotImplementedError(
            'OSSStorage.write_text to be implemented in the future')
||||
|
||||
|
||||
# Cache of instantiated storage backends, keyed by uri prefix.
G_STORAGES = {}


class File(object):
    """Unified facade that dispatches file operations by uri scheme."""

    _prefix_to_storage: dict = {
        'oss': OSSStorage,
        'http': HTTPStorage,
        'https': HTTPStorage,
        'local': LocalStorage,
    }

    @staticmethod
    def _get_storage(uri):
        """Return (lazily creating and caching) the backend for ``uri``.

        A uri without a ``scheme://`` prefix is treated as a local path.
        """
        assert isinstance(uri,
                          str), f'uri should be str type, but got {type(uri)}'

        if '://' not in uri:
            # local path
            storage_type = 'local'
        else:
            prefix, _ = uri.split('://')
            storage_type = prefix

        assert storage_type in File._prefix_to_storage, \
            f'Unsupported uri {uri}, valid prefixs: '\
            f'{list(File._prefix_to_storage.keys())}'

        if storage_type not in G_STORAGES:
            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()

        return G_STORAGES[storage_type]

    @staticmethod
    def read(uri: str) -> bytes:
        """Read data from a given ``uri`` with 'rb' mode.

        Args:
            uri (str): Uri (local path, http(s) or oss) to read data from.

        Returns:
            bytes: Expected bytes object.
        """
        storage = File._get_storage(uri)
        return storage.read(uri)

    @staticmethod
    def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
        """Read data from a given ``uri`` with 'r' mode.

        Args:
            uri (str or Path): Uri to read data from.
            encoding (str): The encoding format used to open the ``uri``.
                Default: 'utf-8'.

        Returns:
            str: Expected text reading from ``uri``.
        """
        storage = File._get_storage(uri)
        # NOTE(review): ``encoding`` is not forwarded to the backend because
        # HTTPStorage.read_text() does not accept it -- align the backend
        # signatures before forwarding. TODO confirm intended behavior.
        return storage.read_text(uri)

    @staticmethod
    def write(obj: bytes, uri: Union[str, Path]) -> None:
        """Write data to a given ``uri`` with 'wb' mode.

        Note:
            Local backends create the parent directory of ``uri`` when it
            does not exist.

        Args:
            obj (bytes): Data to be written.
            uri (str or Path): Uri to write data to.
        """
        storage = File._get_storage(uri)
        return storage.write(obj, uri)

    @staticmethod
    def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
        """Write data to a given ``uri`` with 'w' mode.

        Note:
            Local backends create the parent directory of ``uri`` when it
            does not exist.

        Args:
            obj (str): Data to be written.
            uri (str): Uri to write data to.
            encoding (str): The encoding format used to open the ``uri``.
                Default: 'utf-8'.
        """
        storage = File._get_storage(uri)
        # Forward the caller-supplied encoding (it was silently dropped
        # before); every backend's write_text accepts this keyword.
        return storage.write_text(obj, uri, encoding=encoding)

    @staticmethod
    @contextlib.contextmanager
    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
        """Provide ``uri`` as a local path, delegating to the backend.

        Remote backends download to a temporary file that is released when
        the ``with`` block exits; the local backend yields the path as-is.
        """
        storage = File._get_storage(uri)
        with storage.as_local_path(uri) as local_path:
            yield local_path
||||
3
maas_lib/fileio/format/__init__.py
Normal file
3
maas_lib/fileio/format/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Public format handlers of the fileio.format subpackage.
from .base import FormatHandler
from .json import JsonHandler
from .yaml import YamlHandler

# Declare the public API explicitly.
__all__ = ['FormatHandler', 'JsonHandler', 'YamlHandler']
||||
20
maas_lib/fileio/format/base.py
Normal file
20
maas_lib/fileio/format/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class FormatHandler(metaclass=ABCMeta):
    """Abstract interface for (de)serializing objects in one file format."""

    # If ``text_mode`` is True, files should be opened in text mode,
    # otherwise in binary mode.  (The original comment referenced a
    # nonexistent ``text_format`` attribute.)
    text_mode = True

    @abstractmethod
    def load(self, file, **kwargs):
        """Deserialize and return the object read from ``file``."""
        pass

    @abstractmethod
    def dump(self, obj, file, **kwargs):
        """Serialize ``obj`` into the file-like object ``file``."""
        pass

    @abstractmethod
    def dumps(self, obj, **kwargs):
        """Serialize ``obj`` and return the result as a string."""
        pass
||||
35
maas_lib/fileio/format/json.py
Normal file
35
maas_lib/fileio/format/json.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from .base import FormatHandler
|
||||
|
||||
|
||||
def set_default(obj):
    """Set default json values for non-serializable values.

    It helps convert ``set``, ``range`` and ``np.ndarray`` data types to
    list.  It also converts ``np.generic`` (including ``np.int32``,
    ``np.float32``, etc.) into plain numbers of plain python built-in types.
    """
    # The branches are mutually exclusive, so the check order is free.
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (set, range)):
        return list(obj)
    raise TypeError(f'{type(obj)} is unsupported for json dump')
||||
|
||||
|
||||
class JsonHandler(FormatHandler):
    """FormatHandler implementation backed by the stdlib ``json`` module."""

    def load(self, file, **kwargs):
        """Deserialize a json document from the file-like object ``file``.

        ``**kwargs`` are forwarded to :func:`json.load`; accepting them
        matches the ``FormatHandler.load`` interface (the original signature
        silently dropped them).
        """
        return json.load(file, **kwargs)

    def dump(self, obj, file, **kwargs):
        """Serialize ``obj`` as json into the file-like object ``file``.

        ``set_default`` is installed as the ``default`` fallback so sets,
        ranges and numpy values serialize cleanly.
        """
        kwargs.setdefault('default', set_default)
        json.dump(obj, file, **kwargs)

    def dumps(self, obj, **kwargs):
        """Serialize ``obj`` and return the json document as a string."""
        kwargs.setdefault('default', set_default)
        return json.dumps(obj, **kwargs)
||||
25
maas_lib/fileio/format/yaml.py
Normal file
25
maas_lib/fileio/format/yaml.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import yaml
|
||||
|
||||
try:
|
||||
from yaml import CDumper as Dumper
|
||||
from yaml import CLoader as Loader
|
||||
except ImportError:
|
||||
from yaml import Loader, Dumper # type: ignore
|
||||
|
||||
from .base import FormatHandler # isort:skip
|
||||
|
||||
|
||||
class YamlHandler(FormatHandler):
    """FormatHandler implementation backed by PyYAML.

    Uses the C-accelerated Loader/Dumper when available (selected at
    import time in this module's header).
    """

    def load(self, file, **kwargs):
        """Deserialize a yaml document from the file-like object ``file``."""
        kwargs.setdefault('Loader', Loader)
        return yaml.load(file, **kwargs)

    def dump(self, obj, file, **kwargs):
        """Serialize ``obj`` as yaml into the file-like object ``file``."""
        kwargs.setdefault('Dumper', Dumper)
        yaml.dump(obj, file, **kwargs)

    def dumps(self, obj, **kwargs):
        """Serialize ``obj`` and return the yaml document as a string."""
        kwargs.setdefault('Dumper', Dumper)
        return yaml.dump(obj, **kwargs)
||||
127
maas_lib/fileio/io.py
Normal file
127
maas_lib/fileio/io.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from io import BytesIO, StringIO
|
||||
from pathlib import Path
|
||||
|
||||
from .file import File
|
||||
from .format import JsonHandler, YamlHandler
|
||||
|
||||
# Registry mapping a file extension to the handler instance that parses it.
format_handlers = {
    'json': JsonHandler(),
    'yaml': YamlHandler(),
    'yml': YamlHandler(),
}
||||
|
||||
|
||||
def load(file, file_format=None, **kwargs):
    """Load data from json/yaml/pickle files.

    This method provides a unified api for loading data from serialized
    files.

    Args:
        file (str or :obj:`Path` or file-like object): Filename or a
            file-like object.
        file_format (str, optional): If not specified, the file format will
            be inferred from the file extension, otherwise use the specified
            one. Currently supported formats include "json", "yaml/yml".

    Examples:
        >>> load('/path/of/your/file')  # file is storaged in disk
        >>> load('https://path/of/your/file')  # file is storaged in Internet
        >>> load('oss://path/of/your/file')  # file is storaged in petrel

    Returns:
        The content from the file.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None and isinstance(file, str):
        # Infer the format from the extension.
        file_format = file.split('.')[-1]
    if file_format not in format_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = format_handlers[file_format]
    if isinstance(file, str):
        # Fetch the whole content through the File facade, then parse it
        # from an in-memory buffer in the mode the handler expects.
        if handler.text_mode:
            buffer = StringIO(File.read_text(file))
        else:
            buffer = BytesIO(File.read(file))
        with buffer as f:
            return handler.load(f, **kwargs)
    if hasattr(file, 'read'):
        # Already a file-like object: hand it to the handler directly.
        return handler.load(file, **kwargs)
    raise TypeError('"file" must be a filepath str or a file-object')
||||
|
||||
|
||||
def dump(obj, file=None, file_format=None, **kwargs):
    """Dump data to json/yaml strings or files.

    This method provides a unified api for dumping data as strings or to
    files.

    Args:
        obj (any): The python object to be dumped.
        file (str or :obj:`Path` or file-like object, optional): If not
            specified, then the object is dumped to a str, otherwise to a
            file specified by the filename or file-like object.
        file_format (str, optional): Same as :func:`load`.

    Examples:
        >>> dump('hello world', '/path/of/your/file')  # disk
        >>> dump('hello world', 'oss://path/of/your/file')  # oss

    Returns:
        str or None: The serialized string when ``file`` is None, otherwise
            None after the data has been written out.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None:
        if isinstance(file, str):
            file_format = file.split('.')[-1]
        elif file is None:
            raise ValueError(
                'file_format must be specified since file is None')
    if file_format not in format_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = format_handlers[file_format]
    if file is None:
        # BUG FIX: handlers expose ``dumps`` (see FormatHandler), not
        # ``dump_to_str`` -- the previous call raised AttributeError.
        return handler.dumps(obj, **kwargs)
    elif isinstance(file, str):
        # Serialize into an in-memory buffer, then persist through the
        # File facade so every uri scheme is supported.
        if handler.text_mode:
            with StringIO() as f:
                handler.dump(obj, f, **kwargs)
                File.write_text(f.getvalue(), file)
        else:
            with BytesIO() as f:
                handler.dump(obj, f, **kwargs)
                File.write(f.getvalue(), file)
    elif hasattr(file, 'write'):
        handler.dump(obj, file, **kwargs)
    else:
        raise TypeError('"file" must be a filename str or a file-object')
||||
|
||||
|
||||
def dumps(obj, format, **kwargs):
    """Dump data to a json/yaml string.

    This method provides a unified api for dumping data as a string.

    Args:
        obj (any): The python object to be dumped.
        format (str): Serialization format, same as ``file_format`` of
            :func:`load`.

    Examples:
        >>> dumps({'hello': 'world'}, 'json')
        >>> dumps({'hello': 'world'}, 'yaml')

    Returns:
        str: The serialized string.
    """
    if format not in format_handlers:
        raise TypeError(f'Unsupported format: {format}')

    handler = format_handlers[format]
    return handler.dumps(obj, **kwargs)
||||
472
maas_lib/utils/config.py
Normal file
472
maas_lib/utils/config.py
Normal file
@@ -0,0 +1,472 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import ast
|
||||
import copy
|
||||
import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import uuid
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import addict
|
||||
from yapf.yapflib.yapf_api import FormatCode
|
||||
|
||||
from maas_lib.utils.logger import get_logger
|
||||
from maas_lib.utils.pymod import (import_modules, import_modules_from_file,
|
||||
validate_py_syntax)
|
||||
|
||||
if platform.system() == 'Windows':
|
||||
import regex as re # type: ignore
|
||||
else:
|
||||
import re # type: ignore
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
BASE_KEY = '_base_'
|
||||
DELETE_KEY = '_delete_'
|
||||
DEPRECATION_KEY = '_deprecation_'
|
||||
RESERVED_KEYS = ['filename', 'text', 'pretty_text']
|
||||
|
||||
|
||||
class ConfigDict(addict.Dict):
    """Dict which support get value through getattr

    Examples:
        >>> cdict = ConfigDict({'a':1232})
        >>> print(cdict.a)
        1232
    """

    def __missing__(self, name):
        # Disable addict's auto-creation of empty children: a missing key
        # is a hard error here.
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            # Translate the missing key into the conventional AttributeError.
            err = AttributeError(f"'{self.__class__.__name__}' object has no "
                                 f"attribute '{name}'")
        except Exception as e:
            err = e
        else:
            return value
        # Raise outside the except block to avoid implicit chaining.
        raise err
||||
|
||||
|
||||
class Config:
|
||||
"""A facility for config and config files.
|
||||
|
||||
It supports common file formats as configs: python/json/yaml. The interface
|
||||
is the same as a dict object and also allows access config values as
|
||||
attributes.
|
||||
|
||||
Example:
|
||||
>>> cfg = Config(dict(a=1, b=dict(c=[1,2,3], d='dd')))
|
||||
>>> cfg.a
|
||||
1
|
||||
>>> cfg.b
|
||||
{'c': [1, 2, 3], 'd': 'dd'}
|
||||
>>> cfg.b.d
|
||||
'dd'
|
||||
>>> cfg = Config.from_file('configs/examples/config.json')
|
||||
>>> cfg.filename
|
||||
'configs/examples/config.json'
|
||||
>>> cfg.b
|
||||
{'c': [1, 2, 3], 'd': 'dd'}
|
||||
>>> cfg = Config.from_file('configs/examples/config.py')
|
||||
>>> cfg.filename
|
||||
"configs/examples/config.py"
|
||||
>>> cfg = Config.from_file('configs/examples/config.yaml')
|
||||
>>> cfg.filename
|
||||
"configs/examples/config.yaml"
|
||||
"""
|
||||
|
||||
@staticmethod
def _file2dict(filename):
    """Parse a config file into ``(cfg_dict, cfg_text)``.

    The file is first copied into a temporary directory so that python
    configs can be imported without permanently polluting ``sys.modules``.

    Args:
        filename (str): Path of a ``.py``/``.json``/``.yaml``/``.yml``
            config file.

    Returns:
        tuple: ``(cfg_dict, cfg_text)`` where ``cfg_dict`` holds the parsed
            values and ``cfg_text`` is the raw file content prefixed with
            the file path.
    """
    filename = osp.abspath(osp.expanduser(filename))
    if not osp.exists(filename):
        # BUG FIX: the old message was an f-string with no placeholder and
        # never reported which file was missing.
        raise ValueError(f'File does not exist: {filename}')
    fileExtname = osp.splitext(filename)[1]
    if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
        raise IOError('Only py/yml/yaml/json type are supported now!')

    with tempfile.TemporaryDirectory() as tmp_cfg_dir:
        tmp_cfg_file = tempfile.NamedTemporaryFile(
            dir=tmp_cfg_dir, suffix=fileExtname)
        if platform.system() == 'Windows':
            # Windows cannot reopen a still-open NamedTemporaryFile.
            tmp_cfg_file.close()
        tmp_cfg_name = osp.basename(tmp_cfg_file.name)
        shutil.copyfile(filename, tmp_cfg_file.name)

        if filename.endswith('.py'):
            module_name, mod = import_modules_from_file(
                osp.join(tmp_cfg_dir, tmp_cfg_name))
            # Keep only plain config values; skip dunders, modules and
            # functions defined in the config file.
            cfg_dict = {}
            for name, value in mod.__dict__.items():
                if not name.startswith('__') and \
                        not isinstance(value, types.ModuleType) and \
                        not isinstance(value, types.FunctionType):
                    cfg_dict[name] = value

            # delete imported module
            del sys.modules[module_name]
        elif filename.endswith(('.yml', '.yaml', '.json')):
            from maas_lib.fileio import load
            cfg_dict = load(tmp_cfg_file.name)
        # close temp file
        tmp_cfg_file.close()

    cfg_text = filename + '\n'
    with open(filename, 'r', encoding='utf-8') as f:
        # Setting encoding explicitly to resolve coding issue on windows
        cfg_text += f.read()

    return cfg_dict, cfg_text
||||
|
||||
@staticmethod
def from_file(filename):
    """Build a :obj:`Config` from a .py/.json/.yaml/.yml file path."""
    if isinstance(filename, Path):
        filename = str(filename)
    parsed_dict, raw_text = Config._file2dict(filename)
    return Config(parsed_dict, cfg_text=raw_text, filename=filename)
||||
|
||||
@staticmethod
def from_string(cfg_str, file_format):
    """Generate config from config str.

    Args:
        cfg_str (str): Config str.
        file_format (str): Config file format corresponding to the
            config str. Only py/yml/yaml/json type are supported now!

    Returns:
        :obj:`Config`: Config obj.
    """
    if file_format not in ['.py', '.json', '.yaml', '.yml']:
        raise IOError('Only py/yml/yaml/json type are supported now!')
    if file_format != '.py' and 'dict(' in cfg_str:
        # check if users specify a wrong suffix for python
        logger.warning(
            'Please check "file_format", the file format may be .py')

    # Persist the string to a closed temp file so it can be reopened.
    with tempfile.NamedTemporaryFile(
            'w', encoding='utf-8', suffix=file_format,
            delete=False) as temp_file:
        temp_file.write(cfg_str)
    # on windows, previous implementation cause error
    # see PR 1077 for details
    cfg = Config.from_file(temp_file.name)
    os.remove(temp_file.name)
    return cfg
||||
|
||||
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
    """Initialize from an optional dict, raw config text and source path."""
    if cfg_dict is None:
        cfg_dict = dict()
    elif not isinstance(cfg_dict, dict):
        raise TypeError('cfg_dict must be a dict, but '
                        f'got {type(cfg_dict)}')
    for key in cfg_dict:
        if key in RESERVED_KEYS:
            raise KeyError(f'{key} is reserved for config file')

    if isinstance(filename, Path):
        filename = str(filename)

    # Bypass our own __setattr__ so these attributes never land inside
    # the wrapped ConfigDict.
    super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
    super(Config, self).__setattr__('_filename', filename)
    if cfg_text:
        text = cfg_text
    elif filename:
        with open(filename, 'r') as f:
            text = f.read()
    else:
        text = ''
    super(Config, self).__setattr__('_text', text)
||||
|
||||
@property
def filename(self):
    """str: Path the config was loaded from, or None."""
    return self._filename
||||
|
||||
@property
def text(self):
    """str: Raw text content of the config source."""
    return self._text
||||
|
||||
@property
def pretty_text(self):
    """str: The config rendered as yapf-formatted python source."""

    indent_width = 4

    def _reindent(text_, num_spaces):
        # Indent every line except the first by ``num_spaces`` columns.
        lines = text_.split('\n')
        if len(lines) == 1:
            return text_
        head = lines.pop(0)
        lines = [(num_spaces * ' ') + line for line in lines]
        return head + '\n' + '\n'.join(lines)

    def _basic_entry(k, v, use_mapping=False):
        # Render a scalar either as a mapping entry ('k': v) or kwarg (k=v).
        v_str = f"'{v}'" if isinstance(v, str) else str(v)
        if use_mapping:
            k_str = f"'{k}'" if isinstance(k, str) else str(k)
            entry = f'{k_str}: {v_str}'
        else:
            entry = f'{str(k)}={v_str}'
        return _reindent(entry, indent_width)

    def _list_entry(k, v, use_mapping=False):
        # A list whose items are all dicts is expanded one dict(...) per
        # line; any other list renders like a scalar.
        if all(isinstance(item, dict) for item in v):
            v_str = '[\n' + '\n'.join(
                f'dict({_reindent(_render_dict(item), indent_width)}),'
                for item in v).rstrip(',')
            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                entry = f'{k_str}: {v_str}'
            else:
                entry = f'{str(k)}={v_str}'
            return _reindent(entry, indent_width) + ']'
        return _basic_entry(k, v, use_mapping)

    def _has_invalid_identifier(d):
        # True when any key cannot be written as a python identifier, in
        # which case dict-literal syntax must be used instead of dict().
        return any(not str(key).isidentifier() for key in d)

    def _render_dict(input_dict, outest_level=False):
        rendered = ''
        entries = []

        use_mapping = _has_invalid_identifier(input_dict)
        if use_mapping:
            rendered += '{'
        for idx, (k, v) in enumerate(input_dict.items()):
            is_last = idx >= len(input_dict) - 1
            end = '' if outest_level or is_last else ','
            if isinstance(v, dict):
                v_str = '\n' + _render_dict(v)
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    entry = f'{k_str}: dict({v_str}'
                else:
                    entry = f'{str(k)}=dict({v_str}'
                entry = _reindent(entry, indent_width) + ')' + end
            elif isinstance(v, list):
                entry = _list_entry(k, v, use_mapping) + end
            else:
                entry = _basic_entry(k, v, use_mapping) + end

            entries.append(entry)
        rendered += '\n'.join(entries)
        if use_mapping:
            rendered += '}'
        return rendered

    cfg_dict = self._cfg_dict.to_dict()
    text = _render_dict(cfg_dict, outest_level=True)
    # copied from setup.cfg
    yapf_style = dict(
        based_on_style='pep8',
        blank_line_before_nested_class_or_def=True,
        split_before_expression_after_opening_paren=True)
    text, _ = FormatCode(text, style_config=yapf_style, verify=True)

    return text
||||
|
||||
def __repr__(self):
    """Show the source path together with the wrapped dict's repr."""
    return f'Config (path: {self.filename}): {repr(self._cfg_dict)}'
||||
|
||||
def __len__(self):
    """Number of top-level keys in the config."""
    return len(self._cfg_dict)
||||
|
||||
def __getattr__(self, name):
    """Delegate unknown attribute lookups to the wrapped ConfigDict."""
    return getattr(self._cfg_dict, name)
||||
|
||||
def __getitem__(self, name):
    """Delegate item access to the wrapped ConfigDict."""
    return self._cfg_dict.__getitem__(name)
||||
|
||||
def __setattr__(self, name, value):
    """Set an attribute on the wrapped dict.

    Plain dicts are promoted to ConfigDict so nested values stay
    attribute-accessible.
    """
    if isinstance(value, dict):
        value = ConfigDict(value)
    setattr(self._cfg_dict, name, value)
|
||||
|
||||
def __setitem__(self, name, value):
    """Set an item on the wrapped dict.

    Plain dicts are promoted to ConfigDict so nested values stay
    attribute-accessible.
    """
    if isinstance(value, dict):
        value = ConfigDict(value)
    self._cfg_dict[name] = value
|
||||
|
||||
def __iter__(self):
    """Iterate over the top-level keys of the config."""
    return self._cfg_dict.__iter__()
|
||||
|
||||
def __getstate__(self):
    """Pickle support: state is the (cfg_dict, filename, text) triple."""
    return self._cfg_dict, self._filename, self._text
|
||||
|
||||
def __copy__(self):
    """Shallow copy: the clone shares the underlying ConfigDict."""
    cls = type(self)
    clone = cls.__new__(cls)
    clone.__dict__.update(self.__dict__)
    return clone
|
||||
|
||||
def __deepcopy__(self, memo):
    """Deep copy; registers the clone in ``memo`` to handle cycles."""
    cls = type(self)
    clone = cls.__new__(cls)
    # Record the clone before recursing so self-references resolve.
    memo[id(self)] = clone

    # Bypass our own __setattr__, which would redirect writes into
    # the wrapped ConfigDict instead of the instance __dict__.
    for attr, val in self.__dict__.items():
        super(Config, clone).__setattr__(attr, copy.deepcopy(val, memo))

    return clone
|
||||
|
||||
def __setstate__(self, state):
    """Restore pickled state, bypassing the attribute delegation."""
    cfg_dict, filename, text = state
    set_base = super(Config, self).__setattr__
    set_base('_cfg_dict', cfg_dict)
    set_base('_filename', filename)
    set_base('_text', text)
|
||||
|
||||
def dump(self, file: str = None):
    """Dump the config into a file or return it as a string.

    If ``file`` is given, the config is saved there in the format
    implied by the file extension ('.py' files receive the pretty
    text; other extensions are handed to the fileio dumper).

    If ``file`` is None, a string is returned, formatted according to
    the extension of ``self.filename``; when ``self.filename`` is not
    defined (or is a '.py' file), the pretty-text representation is
    returned.

    Examples:
        >>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0),
        ...                 item3=True, item4='test')
        >>> cfg = Config(cfg_dict=cfg_dict)
        >>> dump_file = "a.py"
        >>> cfg.dump(dump_file)

    Args:
        file (str, optional): Path of the output file where the config
            will be dumped. Defaults to None.
    """
    from maas_lib.fileio import dump
    cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()

    if file is None:
        if self.filename is None or self.filename.endswith('.py'):
            return self.pretty_text
        suffix = self.filename.split('.')[-1]
        return dump(cfg_dict, file_format=suffix)

    if file.endswith('.py'):
        # Pretty text is already valid python source; write it verbatim.
        with open(file, 'w', encoding='utf-8') as outfile:
            outfile.write(self.pretty_text)
        return None

    suffix = file.split('.')[-1]
    return dump(cfg_dict, file=file, file_format=suffix)
|
||||
|
||||
def merge_from_dict(self, options, allow_list_keys=True):
    """Merge a flat dict of dotted keys into this config.

    Merge the dict parsed by MultipleKVAction into this cfg.

    Examples:
        >>> options = {'model.backbone.depth': 50,
        ...            'model.backbone.with_cp': True}
        >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
        >>> cfg.merge_from_dict(options)

        >>> # Merge list element
        >>> cfg = Config(dict(pipeline=[
        ...     dict(type='Resize'), dict(type='RandomDistortion')]))
        >>> options = dict(pipeline={'0': dict(type='MyResize')})
        >>> cfg.merge_from_dict(options, allow_list_keys=True)

    Args:
        options (dict): dict of configs to merge from.
        allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
            are allowed in ``options`` and will replace the element of
            the corresponding index in the config if the config is a
            list. Default: True.
    """
    # Expand dotted keys into a nested ConfigDict tree.
    option_cfg_dict = {}
    for full_key, value in options.items():
        *parent_keys, leaf_key = full_key.split('.')
        node = option_cfg_dict
        for part in parent_keys:
            node = node.setdefault(part, ConfigDict())
        node[leaf_key] = value

    cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
    merged = Config._merge_a_into_b(
        option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)
    super(Config, self).__setattr__('_cfg_dict', merged)
|
||||
|
||||
def to_dict(self) -> Dict:
    """Convert this Config object to a plain python dict.

    Returns:
        dict: The configuration data as a plain dict.
    """
    plain = self._cfg_dict.to_dict()
    return plain
|
||||
|
||||
def to_args(self, parse_fn, use_hyphen=True):
    """Convert this config to a command-line args object using ``parse_fn``.

    Args:
        parse_fn: A callable taking a list of command-line tokens such as
            ``['--foo', 'FOO']`` and returning parsed args; an example is
            given as follows, including literal blocks::

                def parse_fn(args):
                    parser = argparse.ArgumentParser(prog='PROG')
                    parser.add_argument('-x')
                    parser.add_argument('--foo')
                    return parser.parse_args(args)

        use_hyphen (bool, optional): If set true, underscores in key names
            will be converted to hyphens when building option names.

    Return:
        args: arg object parsed by argparse.ArgumentParser

    Raises:
        ValueError: If a config value (or list element) is not bool, int,
            str or float.
    """
    args = []
    for k, v in self._cfg_dict.items():
        arg_name = f'--{k}'
        if use_hyphen:
            arg_name = arg_name.replace('_', '-')
        if isinstance(v, bool) and v:
            # store_true-style flag: presence alone encodes the value.
            args.append(arg_name)
        elif isinstance(v, (int, str, float)):
            # NOTE: a False bool reaches this branch (bool is a subclass
            # of int) and is emitted as the token 'False' — preserved for
            # backward compatibility.
            args.append(arg_name)
            args.append(str(v))
        elif isinstance(v, list):
            args.append(arg_name)
            # Fix: the original asserted isinstance on the list object
            # itself (always failing for lists) and appended the whole
            # list as a single token. Validate and emit each element.
            for item in v:
                if not isinstance(item, (int, str, float, bool)):
                    raise ValueError(
                        'Element type in list is expected to be either '
                        f'int, str, float or bool, but got {type(item)}')
                args.append(str(item))
        else:
            raise ValueError(
                'type in config file which supported to be '
                'converted to args should be either bool, '
                f'int, str, float or list of them but got type {v}')

    return parse_fn(args)
|
||||
34
maas_lib/utils/constant.py
Normal file
34
maas_lib/utils/constant.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
|
||||
class Fields(object):
    """Canonical string identifiers for the supported application fields.

    Use these constants instead of raw literals when registering or
    looking up components by field.
    """
    image = 'image'
    video = 'video'
    nlp = 'nlp'
    audio = 'audio'
    multi_modal = 'multi_modal'
|
||||
|
||||
|
||||
class Tasks(object):
    """Standard task names supported by maas lib.

    Holds the canonical task identifiers used to register models,
    pipelines and trainers.
    """
    # vision tasks
    # NOTE(review): attribute name has a typo ('classfication') but is
    # already referenced as public API elsewhere in this commit; keep as-is.
    image_classfication = 'image-classification'
    object_detection = 'object-detection'

    # nlp tasks
    sentiment_analysis = 'sentiment-analysis'
    fill_mask = 'fill-mask'
|
||||
|
||||
|
||||
class InputFields(object):
    """Names of the input data fields consumed by pipelines."""
    img = 'img'
    text = 'text'
    audio = 'audio'
|
||||
45
maas_lib/utils/logger.py
Normal file
45
maas_lib/utils/logger.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
# Tracks logger names that have already been configured, so repeated
# calls return the cached logger without stacking duplicate handlers.
init_loggers = {}


def get_logger(log_file: Optional[str] = None,
               log_level: int = logging.INFO,
               file_mode: str = 'w'):
    """Return the package-level logger, configuring it on first use.

    Args:
        log_file: Log filename; if specified, a file handler writing to
            this path is attached in addition to the stream handler.
        log_level: Logging level applied to the logger and all handlers.
        file_mode: Mode used to open ``log_file`` (defaults to 'w').

    Returns:
        logging.Logger: The configured logger.
    """
    logger_name = __name__.split('.')[0]
    logger = logging.getLogger(logger_name)

    # Already configured: hand back the cached logger unchanged. Note
    # that later log_file/log_level arguments are ignored in this case.
    if logger_name in init_loggers:
        return logger

    handlers = [logging.StreamHandler()]

    # TODO @wenmeng.zwm add logger setting for distributed environment
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file, file_mode))

    fmt = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(fmt)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level)
    init_loggers[logger_name] = True

    return logger
|
||||
90
maas_lib/utils/pymod.py
Normal file
90
maas_lib/utils/pymod.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import ast
|
||||
import os
|
||||
import os.path as osp
|
||||
import sys
|
||||
import types
|
||||
from importlib import import_module
|
||||
|
||||
from maas_lib.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def import_modules_from_file(py_file: str):
    """Import a python module from an arbitrary file path.

    Args:
        py_file: Path to the python file to be imported.

    Return:
        tuple: ``(module_name, module)`` where ``module_name`` is the
        file's base name without extension and ``module`` is the
        imported module object.

    Raises:
        SyntaxError: If the file does not parse as valid python.
        ImportError: If the module cannot be imported.
    """
    dirname, basefile = os.path.split(py_file)
    if dirname == '':
        # Fix: the original used '==' (a no-op comparison) instead of an
        # assignment, so bare file names never resolved against cwd.
        dirname = './'
    module_name = osp.splitext(basefile)[0]
    sys.path.insert(0, dirname)
    try:
        validate_py_syntax(py_file)
        mod = import_module(module_name)
    finally:
        # Always restore sys.path, even when validation or import fails.
        sys.path.pop(0)
    return module_name, mod
|
||||
|
||||
|
||||
def import_modules(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None and log a warning. Otherwise, an ImportError is raised.
            Default: False.

    Returns:
        list[module] | module | None: The imported modules, or None when
        ``imports`` is empty.

    Raises:
        TypeError: If ``imports`` is neither a str nor a list of str.
        ImportError: If an import fails and ``allow_failed_imports`` is
            False.

    Examples:
        >>> osp, sys = import_modules(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = isinstance(imports, str)
    if single_import:
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError as e:
            if not allow_failed_imports:
                # Fix: re-raise with the failing module name and chain the
                # original cause, instead of a bare, messageless
                # ImportError that hid what went wrong.
                raise ImportError(f'Failed to import {imp}') from e
            logger.warning(f'{imp} failed to import and is ignored.')
            imported_tmp = None
        imported.append(imported_tmp)
    if single_import:
        return imported[0]
    return imported
|
||||
|
||||
|
||||
def validate_py_syntax(filename):
    """Raise SyntaxError if ``filename`` is not valid python source.

    Args:
        filename: Path of the python file to check.

    Raises:
        SyntaxError: If the file content fails to parse.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        # Setting encoding explicitly to resolve coding issue on windows
        content = f.read()
    try:
        ast.parse(content)
    except SyntaxError as e:
        # Fix: report the actual failing file instead of the '(unknown)'
        # placeholder, and chain the original error for its line details.
        raise SyntaxError('There are syntax errors in config '
                          f'file {filename}: {e}') from e
|
||||
183
maas_lib/utils/registry.py
Normal file
183
maas_lib/utils/registry.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import inspect
|
||||
|
||||
from maas_lib.utils.logger import get_logger
|
||||
|
||||
default_group = 'default'
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class Registry(object):
    """Registry which supports registering modules and grouping them by a
    keyname.

    If group name is not provided, modules will be registered to the
    default group.
    """

    def __init__(self, name: str):
        self._name = name
        # Mapping: group_key -> {module_name -> module_cls}
        self._modules = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + f'({self._name})\n'
        for group_name, group in self._modules.items():
            format_str += f'group_name={group_name}, '\
                f'modules={list(group.keys())}\n'

        return format_str

    @property
    def name(self):
        return self._name

    @property
    def modules(self):
        return self._modules

    def list(self):
        """Log the modules of the current registry, grouped by group name."""
        for group_name, group in self._modules.items():
            logger.info(f'group_name={group_name}')
            for m in group.keys():
                logger.info(f'\t{m}')
            logger.info('')

    def get(self, module_key, group_key=default_group):
        """Return the module registered as (group_key, module_key), or None."""
        if group_key not in self._modules:
            return None
        else:
            return self._modules[group_key].get(module_key, None)

    def _register_module(self,
                         group_key=default_group,
                         module_name=None,
                         module_cls=None):
        assert isinstance(group_key,
                          str), 'group_key is required and must be str'
        if group_key not in self._modules:
            self._modules[group_key] = dict()

        # Fix: also accept plain functions — build_from_cfg documents
        # registering functions, but the original check rejected them.
        if not (inspect.isclass(module_cls)
                or inspect.isfunction(module_cls)):
            raise TypeError(
                f'module is not a class or function type: {type(module_cls)}')

        if module_name is None:
            module_name = module_cls.__name__

        if module_name in self._modules[group_key]:
            # Fix: missing space between 'in' and the registry name.
            raise KeyError(f'{module_name} is already registered in '
                           f'{self._name}[{group_key}]')

        self._modules[group_key][module_name] = module_cls

    def register_module(self,
                        group_key: str = default_group,
                        module_name: str = None,
                        module_cls: type = None):
        """Register a module, either directly or as a class decorator.

        Example:
            >>> models = Registry('models')
            >>> @models.register_module('image-classification', 'SwinT')
            >>> class SwinTransformer:
            >>>     pass

            >>> @models.register_module(module_name='SwinDefault')
            >>> class SwinTransformerDefaultGroup:
            >>>     pass

        Args:
            group_key: Group name of which module will be registered,
                default group name is 'default'
            module_name: Module name; defaults to the class name when None.
            module_cls: Module class object; when None, a decorator is
                returned instead.
        """
        if not (module_name is None or isinstance(module_name, str)):
            raise TypeError(f'module_name must be either of None, str,'
                            f'got {type(module_name)}')

        if module_cls is not None:
            self._register_module(
                group_key=group_key,
                module_name=module_name,
                module_cls=module_cls)
            return module_cls

        # if module_cls is None, should return a decorator function
        def _register(module_cls):
            self._register_module(
                group_key=group_key,
                module_name=module_name,
                module_cls=module_cls)
            return module_cls

        return _register
|
||||
|
||||
|
||||
def build_from_cfg(cfg,
                   registry: Registry,
                   group_key: str = default_group,
                   default_args: dict = None) -> object:
    """Build a module from config dict when it is a class configuration, or
    call a function from config dict when it is a function configuration.

    Example:
        >>> models = Registry('models')
        >>> @models.register_module('image-classification', 'SwinT')
        >>> class SwinTransformer:
        >>>     pass
        >>> swint = build_from_cfg(dict(type='SwinT'), MODELS,
        >>>     'image-classification')
        >>> # Returns an instantiated object
        >>>
        >>> @MODELS.register_module()
        >>> def swin_transformer():
        >>>     pass
        >>> ret = build_from_cfg(dict(type='swin_transformer'), MODELS)
        >>> # Returns a result of the calling function

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        group_key (str, optional): The name of registry group from which
            module should be searched.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg``, ``registry`` or ``default_args`` has a
            wrong type, or ``type`` resolves to neither class nor function.
        KeyError: If "type" is missing, or not found in the registry group.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an maas_lib.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()

    # default_args only fills keys absent from cfg; cfg wins on conflict.
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type, group_key=group_key)
        if obj_cls is None:
            # Fix: missing space between the registry name and 'registry'.
            raise KeyError(f'{obj_type} is not in the {registry.name} '
                           f'registry group {group_key}')
    elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name; re-raise with the
        # class name prepended, chaining the original traceback (fix:
        # the bare re-raise previously discarded it).
        raise type(e)(f'{obj_cls.__name__}: {e}') from e
|
||||
1
requirements.txt
Normal file
1
requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
-r requirements/runtime.txt
|
||||
6
requirements/docs.txt
Normal file
6
requirements/docs.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
docutils==0.16.0
|
||||
recommonmark
|
||||
sphinx==4.0.2
|
||||
sphinx-copybutton
|
||||
sphinx_markdown_tables
|
||||
sphinx_rtd_theme==0.5.2
|
||||
5
requirements/runtime.txt
Normal file
5
requirements/runtime.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
addict
|
||||
numpy
|
||||
pyyaml
|
||||
requests
|
||||
yapf
|
||||
5
requirements/tests.txt
Normal file
5
requirements/tests.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
expecttest
|
||||
flake8
|
||||
isort==4.3.21
|
||||
pre-commit
|
||||
yapf==0.30.0
|
||||
24
setup.cfg
Normal file
24
setup.cfg
Normal file
@@ -0,0 +1,24 @@
|
||||
[isort]
|
||||
line_length = 79
|
||||
multi_line_output = 0
|
||||
known_standard_library = setuptools
|
||||
known_first_party = maas_lib
|
||||
known_third_party = json,yaml
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
||||
[yapf]
|
||||
BASED_ON_STYLE = pep8
|
||||
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
|
||||
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
|
||||
|
||||
[codespell]
|
||||
skip = *.ipynb
|
||||
quiet-level = 3
|
||||
ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids
|
||||
|
||||
[flake8]
|
||||
select = B,C,E,F,P,T4,W,B9
|
||||
max-line-length = 120
|
||||
ignore = F401
|
||||
exclude = docs/src,*.pyi,.git
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/fileio/__init__.py
Normal file
0
tests/fileio/__init__.py
Normal file
70
tests/fileio/test_file.py
Normal file
70
tests/fileio/test_file.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from requests import HTTPError
|
||||
|
||||
from maas_lib.fileio.file import File, HTTPStorage, LocalStorage
|
||||
|
||||
|
||||
class FileTest(unittest.TestCase):
    """Tests for the File facade and its LocalStorage/HTTPStorage backends."""

    def test_local_storage(self):
        # Round-trip binary and text content through a fresh temp path.
        # NOTE(review): tempfile._get_candidate_names() is a private API;
        # consider tempfile.NamedTemporaryFile/mkstemp instead.
        storage = LocalStorage()
        temp_name = tempfile.gettempdir() + '/' + next(
            tempfile._get_candidate_names())
        binary_content = b'12345'
        storage.write(binary_content, temp_name)
        self.assertEqual(binary_content, storage.read(temp_name))

        content = '12345'
        storage.write_text(content, temp_name)
        self.assertEqual(content, storage.read_text(temp_name))

        os.remove(temp_name)

    def test_http_storage(self):
        # NOTE(review): depends on an external OSS URL being reachable and
        # its content being stable — consider mocking for offline CI.
        storage = HTTPStorage()
        url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com' \
              '/data/test/data.txt'
        content = 'this is test data'
        self.assertEqual(content.encode('utf8'), storage.read(url))
        self.assertEqual(content, storage.read_text(url))

        # as_local_path should expose the remote content as a local file.
        with storage.as_local_path(url) as local_file:
            with open(local_file, 'r') as infile:
                self.assertEqual(content, infile.read())

        # Writing over HTTP is unsupported.
        with self.assertRaises(NotImplementedError):
            storage.write('dfad', url)

        # A non-existent URL surfaces the server error as HTTPError.
        with self.assertRaises(HTTPError):
            storage.read(url + 'df')

    def test_file(self):
        # Same scenarios as above, but routed through the File facade,
        # which dispatches on the path scheme.
        url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com'\
            '/data/test/data.txt'
        content = 'this is test data'
        self.assertEqual(content.encode('utf8'), File.read(url))

        with File.as_local_path(url) as local_file:
            with open(local_file, 'r') as infile:
                self.assertEqual(content, infile.read())

        with self.assertRaises(NotImplementedError):
            File.write('dfad', url)

        with self.assertRaises(HTTPError):
            File.read(url + 'df')

        # Local paths should route to local storage and support writes.
        temp_name = tempfile.gettempdir() + '/' + next(
            tempfile._get_candidate_names())
        binary_content = b'12345'
        File.write(binary_content, temp_name)
        self.assertEqual(binary_content, File.read(temp_name))
        os.remove(temp_name)


if __name__ == '__main__':
    unittest.main()
|
||||
32
tests/fileio/test_io.py
Normal file
32
tests/fileio/test_io.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from maas_lib.fileio.io import dump, dumps, load
|
||||
|
||||
|
||||
class FileIOTest(unittest.TestCase):
    """Round-trip tests for the fileio dump/dumps/load helpers."""

    def test_format(self, format='json'):
        # Serialize to string and to file, then verify both agree.
        obj = [1, 2, 3, 'str', {'model': 'resnet'}]
        result_str = dumps(obj, format)
        # NOTE(review): uses the private tempfile._get_candidate_names()
        # and the temp file is never removed — consider mkstemp + cleanup.
        temp_name = tempfile.gettempdir() + '/' + next(
            tempfile._get_candidate_names()) + '.' + format
        dump(obj, temp_name)
        obj_load = load(temp_name)
        self.assertEqual(obj_load, obj)
        with open(temp_name, 'r') as infile:
            self.assertEqual(result_str, infile.read())

        # Unknown extensions must be rejected on load and dump alike.
        with self.assertRaises(TypeError):
            obj_load = load(temp_name + 's')

        with self.assertRaises(TypeError):
            dump(obj, temp_name + 's')

    def test_yaml(self):
        # Re-run the same round-trip with the yaml backend.
        self.test_format('yaml')


if __name__ == '__main__':
    unittest.main()
|
||||
53
tests/run.py
Normal file
53
tests/run.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from fnmatch import fnmatch
|
||||
|
||||
|
||||
def gather_test_cases(test_dir, pattern, list_tests):
    """Collect unittest cases under ``test_dir`` matching ``pattern``.

    Args:
        test_dir: Root directory walked recursively for test files.
        pattern: fnmatch-style file pattern, e.g. 'test_*.py'.
        list_tests: If True, print each discovered case while collecting.

    Returns:
        unittest.TestSuite: Suite containing all discovered cases.
    """
    matched_files = [
        filename for _, _, filenames in os.walk(test_dir)
        for filename in filenames if fnmatch(filename, pattern)
    ]

    test_suite = unittest.TestSuite()

    for case_file in matched_files:
        discovered = unittest.defaultTestLoader.discover(
            start_dir=test_dir, pattern=case_file)
        test_suite.addTest(discovered)
        if hasattr(discovered, '__iter__'):
            for subcase in discovered:
                if list_tests:
                    print(subcase)
        elif list_tests:
            print(discovered)
    return test_suite
|
||||
|
||||
|
||||
def main(args):
    """Gather and run the test suite; exit with status 1 on failures.

    When ``args.list_tests`` is set, tests are only listed, not run.
    """
    runner = unittest.TextTestRunner()
    suite = gather_test_cases(
        os.path.abspath(args.test_dir), args.pattern, args.list_tests)
    if args.list_tests:
        return
    result = runner.run(suite)
    if len(result.failures) > 0:
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: configure the runner and dispatch to main().
    parser = argparse.ArgumentParser('test runner')
    parser.add_argument(
        '--list_tests', action='store_true', help='list all tests')
    parser.add_argument(
        '--pattern', default='test_*.py', help='test file pattern')
    parser.add_argument(
        '--test_dir', default='tests', help='directory to be tested')
    main(parser.parse_args())
|
||||
0
tests/utils/__init__.py
Normal file
0
tests/utils/__init__.py
Normal file
85
tests/utils/test_config.py
Normal file
85
tests/utils/test_config.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import argparse
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from maas_lib.fileio import dump, load
|
||||
from maas_lib.utils.config import Config
|
||||
|
||||
obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}
|
||||
|
||||
|
||||
class ConfigTest(unittest.TestCase):
    """Tests for Config loading, dumping and conversion helpers.

    All loaders are expected to produce the same data as ``obj`` above.
    NOTE(review): relies on fixture files under configs/examples/ relative
    to the working directory — confirm CI runs from the repo root.
    """

    def test_json(self):
        # Load from JSON and check attribute-style access.
        config_file = 'configs/examples/config.json'
        cfg = Config.from_file(config_file)
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])

    def test_yaml(self):
        # Load from YAML; same expected content.
        config_file = 'configs/examples/config.yaml'
        cfg = Config.from_file(config_file)
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])

    def test_py(self):
        # Load from a python config file; same expected content.
        config_file = 'configs/examples/config.py'
        cfg = Config.from_file(config_file)
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])

    def test_dump(self):
        # dump() with no file returns pretty text for .py-backed configs;
        # with a file it serializes in the format of the file extension.
        config_file = 'configs/examples/config.py'
        cfg = Config.from_file(config_file)
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])
        pretty_text = 'a = 1\n'
        pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n"

        json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}'
        yaml_str = 'a: 1\nb:\n  c:\n  - 1\n  - 2\n  - 3\n  d: dd\n'
        with tempfile.NamedTemporaryFile(suffix='.json') as ofile:
            self.assertEqual(pretty_text, cfg.dump())
            cfg.dump(ofile.name)
            with open(ofile.name, 'r') as infile:
                self.assertEqual(json_str, infile.read())

        with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile:
            cfg.dump(ofile.name)
            with open(ofile.name, 'r') as infile:
                self.assertEqual(yaml_str, infile.read())

    def test_to_dict(self):
        # to_dict() must return a plain dict, not a ConfigDict.
        config_file = 'configs/examples/config.json'
        cfg = Config.from_file(config_file)
        d = cfg.to_dict()
        print(d)
        self.assertTrue(isinstance(d, dict))

    def test_to_args(self):
        # to_args() feeds the config through an argparse-style parse_fn;
        # hyphen conversion maps e.g. model_dir -> --model-dir.

        def parse_fn(args):
            parser = argparse.ArgumentParser(prog='PROG')
            parser.add_argument('--model-dir', default='')
            parser.add_argument('--lr', type=float, default=0.001)
            parser.add_argument('--optimizer', default='')
            parser.add_argument('--weight-decay', type=float, default=1e-7)
            parser.add_argument(
                '--save-checkpoint-epochs', type=int, default=30)
            return parser.parse_args(args)

        cfg = Config.from_file('configs/examples/plain_args.yaml')
        args = cfg.to_args(parse_fn)

        self.assertEqual(args.model_dir, 'path/to/model')
        self.assertAlmostEqual(args.lr, 0.01)
        self.assertAlmostEqual(args.weight_decay, 1e-6)
        self.assertEqual(args.optimizer, 'Adam')
        self.assertEqual(args.save_checkpoint_epochs, 20)


if __name__ == '__main__':
    unittest.main()
|
||||
91
tests/utils/test_registry.py
Normal file
91
tests/utils/test_registry.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import unittest
|
||||
|
||||
from maas_lib.utils.constant import Tasks
|
||||
from maas_lib.utils.registry import Registry, build_from_cfg, default_group
|
||||
|
||||
|
||||
class RegistryTest(unittest.TestCase):
    """Tests for Registry registration, lookup and build_from_cfg."""

    def test_register_class_no_task(self):
        # Registering without a group lands the module in the default group.
        MODELS = Registry('models')
        self.assertTrue(MODELS.name == 'models')
        self.assertTrue(MODELS.modules == {})
        self.assertEqual(len(MODELS.modules), 0)

        @MODELS.register_module(module_name='cls-resnet')
        class ResNetForCls(object):
            pass

        self.assertTrue(default_group in MODELS.modules)
        self.assertTrue(MODELS.get('cls-resnet') is ResNetForCls)

    def test_register_class_with_task(self):
        # Modules registered under different task groups must not collide.
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls(object):
            pass

        self.assertTrue(Tasks.image_classfication in MODELS.modules)
        self.assertTrue(
            MODELS.get('SwinT', Tasks.image_classfication) is SwinTForCls)

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        self.assertTrue(Tasks.sentiment_analysis in MODELS.modules)
        self.assertTrue(
            MODELS.get('Bert', Tasks.sentiment_analysis) is
            BertForSentimentAnalysis)

        # Omitting module_name defaults to the class name.
        @MODELS.register_module(Tasks.object_detection)
        class DETR(object):
            pass

        self.assertTrue(Tasks.object_detection in MODELS.modules)
        self.assertTrue(MODELS.get('DETR', Tasks.object_detection) is DETR)

        self.assertEqual(len(MODELS.modules), 3)

    def test_list(self):
        # list() and __repr__ are logging/printing only; just exercise them.
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls(object):
            pass

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        MODELS.list()
        print(MODELS)

    def test_build(self):
        # build_from_cfg resolves 'type' within the requested group only.
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classfication, 'SwinT')
        class SwinTForCls(object):
            pass

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        cfg = dict(type='SwinT')
        model = build_from_cfg(cfg, MODELS, Tasks.image_classfication)
        self.assertTrue(isinstance(model, SwinTForCls))

        cfg = dict(type='Bert')
        model = build_from_cfg(cfg, MODELS, Tasks.sentiment_analysis)
        self.assertTrue(isinstance(model, BertForSentimentAnalysis))

        # Looking up a name in the wrong group must raise KeyError.
        with self.assertRaises(KeyError):
            cfg = dict(type='Bert')
            model = build_from_cfg(cfg, MODELS, Tasks.image_classfication)


if __name__ == '__main__':
    unittest.main()
|
||||
Reference in New Issue
Block a user