[to #42322933] update ast_index logic
@@ -3,17 +3,31 @@ repos:
     rev: 4.0.0
     hooks:
       - id: flake8
-        exclude: thirdparty/|examples/
+        exclude: |
+            (?x)^(
+            thirdparty/|
+            examples/|
+            modelscope/utils/ast_index_file.py
+            )$
   - repo: https://github.com/PyCQA/isort.git
     rev: 4.3.21
     hooks:
       - id: isort
-        exclude: examples
+        exclude: |
+            (?x)^(
+            examples/|
+            modelscope/utils/ast_index_file.py
+            )$
   - repo: https://github.com/pre-commit/mirrors-yapf.git
     rev: v0.30.0
     hooks:
       - id: yapf
-        exclude: thirdparty/|examples/
+        exclude: |
+            (?x)^(
+            thirdparty/|
+            examples/|
+            modelscope/utils/ast_index_file.py
+            )$
   - repo: https://github.com/pre-commit/pre-commit-hooks.git
     rev: v3.1.0
     hooks:
@@ -3,17 +3,31 @@ repos:
     rev: 4.0.0
     hooks:
       - id: flake8
-        exclude: thirdparty/|examples/
+        exclude: |
+            (?x)^(
+            thirdparty/|
+            examples/|
+            modelscope/utils/ast_index_file.py
+            )$
   - repo: /home/admin/pre-commit/isort
     rev: 4.3.21
     hooks:
       - id: isort
-        exclude: examples
+        exclude: |
+            (?x)^(
+            examples/|
+            modelscope/utils/ast_index_file.py
+            )$
   - repo: /home/admin/pre-commit/mirrors-yapf
     rev: v0.30.0
     hooks:
       - id: yapf
-        exclude: thirdparty/|examples/
+        exclude: |
+            (?x)^(
+            thirdparty/|
+            examples/|
+            modelscope/utils/ast_index_file.py
+            )$
   - repo: /home/admin/pre-commit/pre-commit-hooks
     rev: v3.1.0
     hooks:
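Both pre-commit configurations above (the second one pointing at locally mirrored hook repositories) switch from a single-line exclude to a verbose-mode regular expression so that the generated modelscope/utils/ast_index_file.py is never linted or reformatted. As a quick illustrative check, not part of the commit, the new pattern can be exercised with Python's re module, which is roughly how pre-commit matches exclude patterns against repo-relative paths:

    import re

    # Exclude pattern copied from the flake8/yapf entries above.
    EXCLUDE = r"""(?x)^(
        thirdparty/|
        examples/|
        modelscope/utils/ast_index_file.py
    )$"""

    pattern = re.compile(EXCLUDE)
    # The generated index file is skipped by the hooks ...
    print(bool(pattern.search('modelscope/utils/ast_index_file.py')))  # True
    # ... while the rest of modelscope/utils is still checked.
    print(bool(pattern.search('modelscope/utils/ast_utils.py')))       # False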
@@ -1,29 +1,48 @@
 import os

 from modelscope.metainfo import Trainers
 from modelscope.msdatasets.ms_dataset import MsDataset
 from modelscope.trainers.builder import build_trainer
-from modelscope.trainers.training_args import ArgAttr, CliArgumentParser, training_args
+from modelscope.trainers.training_args import (ArgAttr, CliArgumentParser,
+                                               training_args)


 def define_parser():
-    training_args.num_classes = ArgAttr(cfg_node_name=['model.mm_model.head.num_classes',
-                                                       'model.mm_model.train_cfg.augments.0.num_classes',
-                                                       'model.mm_model.train_cfg.augments.1.num_classes'],
-                                        type=int, help='number of classes')
+    training_args.num_classes = ArgAttr(
+        cfg_node_name=[
+            'model.mm_model.head.num_classes',
+            'model.mm_model.train_cfg.augments.0.num_classes',
+            'model.mm_model.train_cfg.augments.1.num_classes'
+        ],
+        type=int,
+        help='number of classes')

     training_args.train_batch_size.default = 16
     training_args.train_data_worker.default = 1
     training_args.max_epochs.default = 1
     training_args.optimizer.default = 'AdamW'
     training_args.lr.default = 1e-4
-    training_args.warmup_iters = ArgAttr('train.lr_config.warmup_iters', type=int, default=1, help='number of warmup epochs')
-    training_args.topk = ArgAttr(cfg_node_name=['train.evaluation.metric_options.topk',
-                                                'evaluation.metric_options.topk'],
-                                 default=(1,), help='evaluation using topk, tuple format, eg (1,), (1,5)')
+    training_args.warmup_iters = ArgAttr(
+        'train.lr_config.warmup_iters',
+        type=int,
+        default=1,
+        help='number of warmup epochs')
+    training_args.topk = ArgAttr(
+        cfg_node_name=[
+            'train.evaluation.metric_options.topk',
+            'evaluation.metric_options.topk'
+        ],
+        default=(1, ),
+        help='evaluation using topk, tuple format, eg (1,), (1,5)')

-    training_args.train_data = ArgAttr(type=str, default='tany0699/cats_and_dogs', help='train dataset')
-    training_args.validation_data = ArgAttr(type=str, default='tany0699/cats_and_dogs', help='validation dataset')
-    training_args.model_id = ArgAttr(type=str, default='damo/cv_vit-base_image-classification_ImageNet-labels', help='model name')
+    training_args.train_data = ArgAttr(
+        type=str, default='tany0699/cats_and_dogs', help='train dataset')
+    training_args.validation_data = ArgAttr(
+        type=str, default='tany0699/cats_and_dogs', help='validation dataset')
+    training_args.model_id = ArgAttr(
+        type=str,
+        default='damo/cv_vit-base_image-classification_ImageNet-labels',
+        help='model name')

     parser = CliArgumentParser(training_args)
     return parser
@@ -31,9 +50,8 @@ def define_parser():

 def create_dataset(name, split):
     namespace, dataset_name = name.split('/')
-    return MsDataset.load(dataset_name, namespace=namespace,
-                          subset_name='default',
-                          split=split)
+    return MsDataset.load(
+        dataset_name, namespace=namespace, subset_name='default', split=split)


 def train(parser):
@@ -47,17 +65,18 @@ def train(parser):
         return cfg

     kwargs = dict(
-        model=args.model_id, # model id
+        model=args.model_id,  # model id
         train_dataset=train_dataset,  # training dataset
-        eval_dataset=val_dataset, # validation dataset
-        cfg_modify_fn=cfg_modify_fn # callback to modify configuration
+        eval_dataset=val_dataset,  # validation dataset
+        cfg_modify_fn=cfg_modify_fn  # callback to modify configuration
     )

     # in distributed training, specify pytorch launcher
     if 'MASTER_ADDR' in os.environ:
         kwargs['launcher'] = 'pytorch'

-    trainer = build_trainer(name=Trainers.image_classification, default_args=kwargs)
+    trainer = build_trainer(
+        name=Trainers.image_classification, default_args=kwargs)
     # start to train
     trainer.train()
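The file above is an example image-classification training script; the edits are yapf-style argument wrapping rather than behaviour changes. The cfg_node_name list passed to ArgAttr is the interesting pattern: a single CLI flag fans out to several dotted configuration nodes. Below is a minimal, hypothetical sketch of that fan-out idea using plain argparse and nested dicts, not the real ArgAttr/CliArgumentParser API:

    import argparse

    # Dotted config paths taken from the num_classes ArgAttr above.
    # Note: in the real config, 'augments.0' indexes a list; the sketch
    # treats every segment as a dict key for brevity.
    CFG_NODES = [
        'model.mm_model.head.num_classes',
        'model.mm_model.train_cfg.augments.0.num_classes',
        'model.mm_model.train_cfg.augments.1.num_classes',
    ]


    def set_by_dotted_key(cfg, dotted_key, value):
        """Write value into a nested dict following a dotted key, creating levels as needed."""
        keys = dotted_key.split('.')
        node = cfg
        for key in keys[:-1]:
            node = node.setdefault(key, {})
        node[keys[-1]] = value


    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=2, help='number of classes')
    args = parser.parse_args([])  # use the default value for this demo

    cfg = {}
    for dotted_key in CFG_NODES:
        set_by_dotted_key(cfg, dotted_key, args.num_classes)
    print(cfg)  # one flag updated three config nodes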
@@ -54,6 +54,8 @@ CLASS_NAME = 'class_name'
 GROUP_KEY = 'group_key'
 MODULE_NAME = 'module_name'
 MODULE_CLS = 'module_cls'
+TEMPLATE_PATH = 'TEMPLATE_PATH'
+TEMPLATE_FILE = 'ast_index_file.py'


 class AstScaning(object):
@@ -611,7 +613,7 @@ class FilesAstScaning(object):
 file_scanner = FilesAstScaning()


-def _save_index(index, file_path, file_list=None):
+def _save_index(index, file_path, file_list=None, with_template=False):
     # convert tuple key to str key
     index[INDEX_KEY] = {str(k): v for k, v in index[INDEX_KEY].items()}
     index[VERSION_KEY] = __version__
@@ -619,6 +621,9 @@ def _save_index(index, file_path, file_list=None):
         file_list=file_list)
     index[MODELSCOPE_PATH_KEY] = MODELSCOPE_PATH.as_posix()
     json_index = json.dumps(index)
+    if with_template:
+        json_index = json_index.replace(MODELSCOPE_PATH.as_posix(),
+                                        TEMPLATE_PATH)
     storage.write(json_index.encode(), file_path)
     index[INDEX_KEY] = {
         ast.literal_eval(k): v
@@ -626,8 +631,11 @@ def _save_index(index, file_path, file_list=None):
     }


-def _load_index(file_path):
+def _load_index(file_path, with_template=False):
     bytes_index = storage.read(file_path)
+    if with_template:
+        bytes_index = bytes_index.decode().replace(TEMPLATE_PATH,
+                                                   MODELSCOPE_PATH.as_posix())
     wrapped_index = json.loads(bytes_index)
     # convert str key to tuple key
     wrapped_index[INDEX_KEY] = {
@@ -733,14 +741,21 @@ def load_index(

     if full_index_flag:
         if force_rebuild:
-            logger.info('Force rebuilding ast index')
+            logger.info('Force rebuilding ast index from scanning every file!')
+            index = file_scanner.get_files_scan_results(file_list)
         else:
             logger.info(
-                f'No valid ast index found from {file_path}, rebuilding ast index!'
+                f'No valid ast index found from {file_path}, generating ast index from prebuilt!'
             )
-            index = file_scanner.get_files_scan_results(file_list)
+            index = load_from_prebuilt()
+            if index is None:
+                index = file_scanner.get_files_scan_results(file_list)
         _save_index(index, file_path, file_list)
     elif local_changed and not full_index_flag:
         logger.info(
             'Updating the files for the changes of local files, '
             'first time updating will take longer time! Please wait till updating done!'
         )
         _update_index(index, files_mtime)
         _save_index(index, file_path, file_list)
@@ -760,5 +775,28 @@ def check_import_module_avaliable(module_dicts: dict) -> list:
     return missed_module


+def load_from_prebuilt(file_path=None):
+    if file_path is None:
+        local_path = p.resolve().parents[0]
+        file_path = os.path.join(local_path, TEMPLATE_FILE)
+    if os.path.exists(file_path):
+        index = _load_index(file_path, with_template=True)
+    else:
+        index = None
+    return index
+
+
+def generate_ast_template(file_path=None, force_rebuild=True):
+    index = load_index(force_rebuild=force_rebuild)
+    if file_path is None:
+        local_path = p.resolve().parents[0]
+        file_path = os.path.join(local_path, TEMPLATE_FILE)
+    _save_index(index, file_path, with_template=True)
+    if not os.path.exists(file_path):
+        raise Exception(
+            'The index file is not create correctly, please double check')
+    return index
+
+
 if __name__ == '__main__':
     index = load_index()
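The modelscope.utils.ast_utils changes above add a prebuilt-index path: _save_index(..., with_template=True) replaces the absolute install prefix (MODELSCOPE_PATH) with the TEMPLATE_PATH placeholder before writing ast_index_file.py, and _load_index(..., with_template=True) substitutes the local prefix back in, so one file built at packaging time can be reused on any machine. A stripped-down sketch of that round trip, standalone and with an assumed install path rather than the actual modelscope helpers:

    import json
    from pathlib import Path

    TEMPLATE_PATH = 'TEMPLATE_PATH'  # same placeholder string as in the diff
    # Assumed install location; in modelscope this is MODELSCOPE_PATH.
    INSTALL_PREFIX = Path('/usr/lib/python3/site-packages/modelscope')


    def save_with_template(index, file_path):
        # Serialize, then swap the absolute prefix for the placeholder
        # (mirrors _save_index(..., with_template=True)).
        text = json.dumps(index).replace(INSTALL_PREFIX.as_posix(), TEMPLATE_PATH)
        Path(file_path).write_text(text)


    def load_with_template(file_path):
        # Reverse substitution at load time (mirrors _load_index(..., with_template=True)).
        text = Path(file_path).read_text().replace(TEMPLATE_PATH, INSTALL_PREFIX.as_posix())
        return json.loads(text)


    index = {'files': [f'{INSTALL_PREFIX.as_posix()}/models/base.py']}
    save_with_template(index, '/tmp/ast_index_file.py')
    assert load_with_template('/tmp/ast_index_file.py') == index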
setup.py
@@ -5,6 +5,7 @@ import shutil
 import subprocess

 from setuptools import find_packages, setup

+from modelscope.utils.ast_utils import generate_ast_template
 from modelscope.utils.constant import Fields


@@ -168,6 +169,7 @@ def pack_resource():

 if __name__ == '__main__':
     # write_version_py()
+    generate_ast_template()
     pack_resource()
     os.chdir('package')
     install_requires, deps_link = parse_requirements('requirements.txt')
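With generate_ast_template() invoked from setup.py, the placeholder index is produced at packaging time and shipped next to ast_utils.py, so an installed copy can load it instead of scanning every source file on first import. Roughly how a consumer would see the new behaviour, as a sketch; only load_from_prebuilt and load_index are taken from the diffs above:

    from modelscope.utils.ast_utils import load_from_prebuilt, load_index

    # Returns the prebuilt index shipped with the package, or None if
    # ast_index_file.py was not generated at build time.
    index = load_from_prebuilt()
    if index is None:
        # load_index() now tries the prebuilt template itself before falling
        # back to a full AST scan of the source tree.
        index = load_index()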
@@ -10,7 +10,8 @@ from pathlib import Path
 from modelscope.utils.ast_utils import (FILES_MTIME_KEY, INDEX_KEY, MD5_KEY,
                                         MODELSCOPE_PATH_KEY, REQUIREMENT_KEY,
                                         VERSION_KEY, AstScaning,
-                                        FilesAstScaning, load_index)
+                                        FilesAstScaning, generate_ast_template,
+                                        load_from_prebuilt, load_index)

 p = Path(__file__)

@@ -134,6 +135,14 @@ class AstScaningTest(unittest.TestCase):
         self.assertIsInstance(output[VERSION_KEY], str)
         self.assertIsInstance(output[FILES_MTIME_KEY], dict)

+        # generate ast_template
+        file_path = os.path.join(self.tmp_dir, 'index_file.py')
+        index = generate_ast_template(file_path=file_path, force_rebuild=False)
+        self.assertTrue(os.path.exists(file_path))
+        self.assertEqual(output, index)
+        index_from_prebuilt = load_from_prebuilt(file_path)
+        self.assertEqual(index, index_from_prebuilt)
+
     def test_update_load_index_method(self):
         file_number = 20
         file_list = []