[to #42322933] update ast_index logic

This commit is contained in:
zhangzhicheng.zzc
2023-01-11 10:43:56 +08:00
parent 346af6773f
commit 42898badf7
6 changed files with 127 additions and 31 deletions

View File

@@ -3,17 +3,31 @@ repos:
rev: 4.0.0
hooks:
- id: flake8
exclude: thirdparty/|examples/
exclude: |
(?x)^(
thirdparty/|
examples/|
modelscope/utils/ast_index_file.py
)$
- repo: https://github.com/PyCQA/isort.git
rev: 4.3.21
hooks:
- id: isort
exclude: examples
exclude: |
(?x)^(
examples/|
modelscope/utils/ast_index_file.py
)$
- repo: https://github.com/pre-commit/mirrors-yapf.git
rev: v0.30.0
hooks:
- id: yapf
exclude: thirdparty/|examples/
exclude: |
(?x)^(
thirdparty/|
examples/|
modelscope/utils/ast_index_file.py
)$
- repo: https://github.com/pre-commit/pre-commit-hooks.git
rev: v3.1.0
hooks:

View File

@@ -3,17 +3,31 @@ repos:
rev: 4.0.0
hooks:
- id: flake8
exclude: thirdparty/|examples/
exclude: |
(?x)^(
thirdparty/|
examples/|
modelscope/utils/ast_index_file.py
)$
- repo: /home/admin/pre-commit/isort
rev: 4.3.21
hooks:
- id: isort
exclude: examples
exclude: |
(?x)^(
examples/|
modelscope/utils/ast_index_file.py
)$
- repo: /home/admin/pre-commit/mirrors-yapf
rev: v0.30.0
hooks:
- id: yapf
exclude: thirdparty/|examples/
exclude: |
(?x)^(
thirdparty/|
examples/|
modelscope/utils/ast_index_file.py
)$
- repo: /home/admin/pre-commit/pre-commit-hooks
rev: v3.1.0
hooks:

View File

@@ -1,29 +1,48 @@
import os
from modelscope.metainfo import Trainers
from modelscope.msdatasets.ms_dataset import MsDataset
from modelscope.trainers.builder import build_trainer
from modelscope.trainers.training_args import ArgAttr, CliArgumentParser, training_args
from modelscope.trainers.training_args import (ArgAttr, CliArgumentParser,
training_args)
def define_parser():
training_args.num_classes = ArgAttr(cfg_node_name=['model.mm_model.head.num_classes',
'model.mm_model.train_cfg.augments.0.num_classes',
'model.mm_model.train_cfg.augments.1.num_classes'],
type=int, help='number of classes')
training_args.num_classes = ArgAttr(
cfg_node_name=[
'model.mm_model.head.num_classes',
'model.mm_model.train_cfg.augments.0.num_classes',
'model.mm_model.train_cfg.augments.1.num_classes'
],
type=int,
help='number of classes')
training_args.train_batch_size.default = 16
training_args.train_data_worker.default = 1
training_args.max_epochs.default = 1
training_args.optimizer.default = 'AdamW'
training_args.lr.default = 1e-4
training_args.warmup_iters = ArgAttr('train.lr_config.warmup_iters', type=int, default=1, help='number of warmup epochs')
training_args.topk = ArgAttr(cfg_node_name=['train.evaluation.metric_options.topk',
'evaluation.metric_options.topk'],
default=(1,), help='evaluation using topk, tuple format, eg (1,), (1,5)')
training_args.warmup_iters = ArgAttr(
'train.lr_config.warmup_iters',
type=int,
default=1,
help='number of warmup epochs')
training_args.topk = ArgAttr(
cfg_node_name=[
'train.evaluation.metric_options.topk',
'evaluation.metric_options.topk'
],
default=(1, ),
help='evaluation using topk, tuple format, eg (1,), (1,5)')
training_args.train_data = ArgAttr(type=str, default='tany0699/cats_and_dogs', help='train dataset')
training_args.validation_data = ArgAttr(type=str, default='tany0699/cats_and_dogs', help='validation dataset')
training_args.model_id = ArgAttr(type=str, default='damo/cv_vit-base_image-classification_ImageNet-labels', help='model name')
training_args.train_data = ArgAttr(
type=str, default='tany0699/cats_and_dogs', help='train dataset')
training_args.validation_data = ArgAttr(
type=str, default='tany0699/cats_and_dogs', help='validation dataset')
training_args.model_id = ArgAttr(
type=str,
default='damo/cv_vit-base_image-classification_ImageNet-labels',
help='model name')
parser = CliArgumentParser(training_args)
return parser
@@ -31,9 +50,8 @@ def define_parser():
def create_dataset(name, split):
namespace, dataset_name = name.split('/')
return MsDataset.load(dataset_name, namespace=namespace,
subset_name='default',
split=split)
return MsDataset.load(
dataset_name, namespace=namespace, subset_name='default', split=split)
def train(parser):
@@ -47,17 +65,18 @@ def train(parser):
return cfg
kwargs = dict(
model=args.model_id, # model id
model=args.model_id, # model id
train_dataset=train_dataset, # training dataset
eval_dataset=val_dataset, # validation dataset
cfg_modify_fn=cfg_modify_fn # callback to modify configuration
eval_dataset=val_dataset, # validation dataset
cfg_modify_fn=cfg_modify_fn # callback to modify configuration
)
# in distributed training, specify pytorch launcher
if 'MASTER_ADDR' in os.environ:
kwargs['launcher'] = 'pytorch'
trainer = build_trainer(name=Trainers.image_classification, default_args=kwargs)
trainer = build_trainer(
name=Trainers.image_classification, default_args=kwargs)
# start to train
trainer.train()

View File

@@ -54,6 +54,8 @@ CLASS_NAME = 'class_name'
GROUP_KEY = 'group_key'
MODULE_NAME = 'module_name'
MODULE_CLS = 'module_cls'
TEMPLATE_PATH = 'TEMPLATE_PATH'
TEMPLATE_FILE = 'ast_index_file.py'
class AstScaning(object):
@@ -611,7 +613,7 @@ class FilesAstScaning(object):
file_scanner = FilesAstScaning()
def _save_index(index, file_path, file_list=None):
def _save_index(index, file_path, file_list=None, with_template=False):
# convert tuple key to str key
index[INDEX_KEY] = {str(k): v for k, v in index[INDEX_KEY].items()}
index[VERSION_KEY] = __version__
@@ -619,6 +621,9 @@ def _save_index(index, file_path, file_list=None):
file_list=file_list)
index[MODELSCOPE_PATH_KEY] = MODELSCOPE_PATH.as_posix()
json_index = json.dumps(index)
if with_template:
json_index = json_index.replace(MODELSCOPE_PATH.as_posix(),
TEMPLATE_PATH)
storage.write(json_index.encode(), file_path)
index[INDEX_KEY] = {
ast.literal_eval(k): v
@@ -626,8 +631,11 @@ def _save_index(index, file_path, file_list=None):
}
def _load_index(file_path):
def _load_index(file_path, with_template=False):
bytes_index = storage.read(file_path)
if with_template:
bytes_index = bytes_index.decode().replace(TEMPLATE_PATH,
MODELSCOPE_PATH.as_posix())
wrapped_index = json.loads(bytes_index)
# convert str key to tuple key
wrapped_index[INDEX_KEY] = {
@@ -733,14 +741,21 @@ def load_index(
if full_index_flag:
if force_rebuild:
logger.info('Force rebuilding ast index')
logger.info('Force rebuilding ast index from scanning every file!')
index = file_scanner.get_files_scan_results(file_list)
else:
logger.info(
f'No valid ast index found from {file_path}, rebuilding ast index!'
f'No valid ast index found from {file_path}, generating ast index from prebuilt!'
)
index = file_scanner.get_files_scan_results(file_list)
index = load_from_prebuilt()
if index is None:
index = file_scanner.get_files_scan_results(file_list)
_save_index(index, file_path, file_list)
elif local_changed and not full_index_flag:
logger.info(
'Updating the files for the changes of local files, '
'first time updating will take longer time! Please wait till updating done!'
)
_update_index(index, files_mtime)
_save_index(index, file_path, file_list)
@@ -760,5 +775,28 @@ def check_import_module_avaliable(module_dicts: dict) -> list:
return missed_module
def load_from_prebuilt(file_path=None):
if file_path is None:
local_path = p.resolve().parents[0]
file_path = os.path.join(local_path, TEMPLATE_FILE)
if os.path.exists(file_path):
index = _load_index(file_path, with_template=True)
else:
index = None
return index
def generate_ast_template(file_path=None, force_rebuild=True):
index = load_index(force_rebuild=force_rebuild)
if file_path is None:
local_path = p.resolve().parents[0]
file_path = os.path.join(local_path, TEMPLATE_FILE)
_save_index(index, file_path, with_template=True)
if not os.path.exists(file_path):
raise Exception(
'The index file is not create correctly, please double check')
return index
if __name__ == '__main__':
index = load_index()

View File

@@ -5,6 +5,7 @@ import shutil
import subprocess
from setuptools import find_packages, setup
from modelscope.utils.ast_utils import generate_ast_template
from modelscope.utils.constant import Fields
@@ -168,6 +169,7 @@ def pack_resource():
if __name__ == '__main__':
# write_version_py()
generate_ast_template()
pack_resource()
os.chdir('package')
install_requires, deps_link = parse_requirements('requirements.txt')

View File

@@ -10,7 +10,8 @@ from pathlib import Path
from modelscope.utils.ast_utils import (FILES_MTIME_KEY, INDEX_KEY, MD5_KEY,
MODELSCOPE_PATH_KEY, REQUIREMENT_KEY,
VERSION_KEY, AstScaning,
FilesAstScaning, load_index)
FilesAstScaning, generate_ast_template,
load_from_prebuilt, load_index)
p = Path(__file__)
@@ -134,6 +135,14 @@ class AstScaningTest(unittest.TestCase):
self.assertIsInstance(output[VERSION_KEY], str)
self.assertIsInstance(output[FILES_MTIME_KEY], dict)
# generate ast_template
file_path = os.path.join(self.tmp_dir, 'index_file.py')
index = generate_ast_template(file_path=file_path, force_rebuild=False)
self.assertTrue(os.path.exists(file_path))
self.assertEqual(output, index)
index_from_prebuilt = load_from_prebuilt(file_path)
self.assertEqual(index, index_from_prebuilt)
def test_update_load_index_method(self):
file_number = 20
file_list = []