mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
[to #46993990]feat: run ci cases base on code diff to reduct ci test time
This commit is contained in:
@@ -28,7 +28,7 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
|
||||
awk -F: '/^[^#]/ { print $1 }' requirements/science.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
|
||||
# test with install
|
||||
python setup.py install
|
||||
pip install .
|
||||
else
|
||||
echo "Running case in release image, run case directly!"
|
||||
fi
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from modelscope.utils.constant import Fields
|
||||
|
||||
|
||||
class Models(object):
|
||||
@@ -12,7 +13,6 @@ class Models(object):
|
||||
# tinynas models
|
||||
tinynas_detection = 'tinynas-detection'
|
||||
tinynas_damoyolo = 'tinynas-damoyolo'
|
||||
|
||||
# vision models
|
||||
detection = 'detection'
|
||||
image_restoration = 'image-restoration'
|
||||
@@ -397,26 +397,7 @@ class Pipelines(object):
|
||||
protein_structure = 'unifold-protein-structure'
|
||||
|
||||
|
||||
class Trainers(object):
|
||||
""" Names for different trainer.
|
||||
|
||||
Holds the standard trainer name to use for identifying different trainer.
|
||||
This should be used to register trainers.
|
||||
|
||||
For a general Trainer, you can use EpochBasedTrainer.
|
||||
For a model specific Trainer, you can use ${ModelName}-${Task}-trainer.
|
||||
"""
|
||||
|
||||
default = 'trainer'
|
||||
easycv = 'easycv'
|
||||
tinynas_damoyolo = 'tinynas-damoyolo'
|
||||
|
||||
# multi-modal trainers
|
||||
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
|
||||
ofa = 'ofa'
|
||||
mplug = 'mplug'
|
||||
mgeo_ranking_trainer = 'mgeo-ranking-trainer'
|
||||
|
||||
class CVTrainers(object):
|
||||
# cv trainers
|
||||
image_instance_segmentation = 'image-instance-segmentation'
|
||||
image_portrait_enhancement = 'image-portrait-enhancement'
|
||||
@@ -430,6 +411,8 @@ class Trainers(object):
|
||||
image_classification = 'image-classification'
|
||||
image_fewshot_detection = 'image-fewshot-detection'
|
||||
|
||||
|
||||
class NLPTrainers(object):
|
||||
# nlp trainers
|
||||
bert_sentiment_analysis = 'bert-sentiment-analysis'
|
||||
dialog_modeling_trainer = 'dialog-modeling-trainer'
|
||||
@@ -444,7 +427,15 @@ class Trainers(object):
|
||||
gpt_moe_trainer = 'nlp-gpt-moe-trainer'
|
||||
table_question_answering_trainer = 'table-question-answering-trainer'
|
||||
|
||||
# audio trainers
|
||||
|
||||
class MultiModalTrainers(object):
|
||||
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
|
||||
ofa = 'ofa'
|
||||
mplug = 'mplug'
|
||||
mgeo_ranking_trainer = 'mgeo-ranking-trainer'
|
||||
|
||||
|
||||
class AudioTrainers(object):
|
||||
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
|
||||
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
|
||||
speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
|
||||
@@ -453,6 +444,45 @@ class Trainers(object):
|
||||
speech_separation = 'speech-separation'
|
||||
|
||||
|
||||
class Trainers(CVTrainers, NLPTrainers, MultiModalTrainers, AudioTrainers):
|
||||
""" Names for different trainer.
|
||||
|
||||
Holds the standard trainer name to use for identifying different trainer.
|
||||
This should be used to register trainers.
|
||||
|
||||
For a general Trainer, you can use EpochBasedTrainer.
|
||||
For a model specific Trainer, you can use ${ModelName}-${Task}-trainer.
|
||||
"""
|
||||
|
||||
default = 'trainer'
|
||||
easycv = 'easycv'
|
||||
tinynas_damoyolo = 'tinynas-damoyolo'
|
||||
|
||||
@staticmethod
|
||||
def get_trainer_domain(attribute_or_value):
|
||||
if attribute_or_value in vars(
|
||||
CVTrainers) or attribute_or_value in vars(CVTrainers).values():
|
||||
return Fields.cv
|
||||
elif attribute_or_value in vars(
|
||||
NLPTrainers) or attribute_or_value in vars(
|
||||
NLPTrainers).values():
|
||||
return Fields.nlp
|
||||
elif attribute_or_value in vars(
|
||||
AudioTrainers) or attribute_or_value in vars(
|
||||
AudioTrainers).values():
|
||||
return Fields.audio
|
||||
elif attribute_or_value in vars(
|
||||
MultiModalTrainers) or attribute_or_value in vars(
|
||||
MultiModalTrainers).values():
|
||||
return Fields.multi_modal
|
||||
elif attribute_or_value == Trainers.default:
|
||||
return Trainers.default
|
||||
elif attribute_or_value == Trainers.easycv:
|
||||
return Trainers.easycv
|
||||
else:
|
||||
return 'unknown'
|
||||
|
||||
|
||||
class Preprocessors(object):
|
||||
""" Names for different preprocessor.
|
||||
|
||||
|
||||
@@ -6,68 +6,75 @@ import os
|
||||
import json
|
||||
import nltk
|
||||
|
||||
nltk.download('punkt')
|
||||
|
||||
|
||||
class NLTKSegmenter:
|
||||
|
||||
def __init(self):
|
||||
pass
|
||||
download_nltk()
|
||||
|
||||
@staticmethod
|
||||
def segment_string(article):
|
||||
return nltk.tokenize.sent_tokenize(article)
|
||||
|
||||
|
||||
wiki_path = 'data/extracted'
|
||||
output_path = 'formatted/wiki-key.txt'
|
||||
segmenter = NLTKSegmenter()
|
||||
with open(output_path, 'w') as output:
|
||||
for dirname in glob.glob(os.path.join(wiki_path, '*'), recursive=False):
|
||||
for filename in glob.glob(
|
||||
os.path.join(dirname, 'wiki_*'), recursive=True):
|
||||
print(filename)
|
||||
article_lines = []
|
||||
article_open = False
|
||||
with open(
|
||||
filename, mode='r', newline='\n',
|
||||
encoding='utf-8') as file:
|
||||
for line in file:
|
||||
line = line.rstrip()
|
||||
if '<doc id=' in line:
|
||||
article_open = True
|
||||
elif '</doc>' in line:
|
||||
key_sentences, contents = [], []
|
||||
key, content = None, []
|
||||
for sentences in article_lines[1:]:
|
||||
if len(sentences) > 1:
|
||||
if key:
|
||||
if len(content) > 0 or len(contents) == 0:
|
||||
key_sentences.append(key)
|
||||
contents.append(content)
|
||||
def download_nltk():
|
||||
nltk.download('punkt')
|
||||
wiki_path = 'data/extracted'
|
||||
output_path = 'formatted/wiki-key.txt'
|
||||
segmenter = NLTKSegmenter()
|
||||
with open(output_path, 'w') as output:
|
||||
for dirname in glob.glob(
|
||||
os.path.join(wiki_path, '*'), recursive=False):
|
||||
for filename in glob.glob(
|
||||
os.path.join(dirname, 'wiki_*'), recursive=True):
|
||||
print(filename)
|
||||
article_lines = []
|
||||
article_open = False
|
||||
with open(
|
||||
filename, mode='r', newline='\n',
|
||||
encoding='utf-8') as file:
|
||||
for line in file:
|
||||
line = line.rstrip()
|
||||
if '<doc id=' in line:
|
||||
article_open = True
|
||||
elif '</doc>' in line:
|
||||
key_sentences, contents = [], []
|
||||
key, content = None, []
|
||||
for sentences in article_lines[1:]:
|
||||
if len(sentences) > 1:
|
||||
if key:
|
||||
if len(content) > 0 or len(
|
||||
contents) == 0:
|
||||
key_sentences.append(key)
|
||||
contents.append(content)
|
||||
else:
|
||||
contents[-1].append(key)
|
||||
key, content = None, []
|
||||
key_sentences.append(sentences[0])
|
||||
contents.append(sentences[1:])
|
||||
elif len(sentences) > 0:
|
||||
if key:
|
||||
content.append(sentences[0])
|
||||
else:
|
||||
contents[-1].append(key)
|
||||
key, content = None, []
|
||||
key_sentences.append(sentences[0])
|
||||
contents.append(sentences[1:])
|
||||
elif len(sentences) > 0:
|
||||
if key:
|
||||
content.append(sentences[0])
|
||||
key = sentences[0]
|
||||
if key:
|
||||
if len(content) > 0 or len(contents) == 0:
|
||||
key_sentences.append(key)
|
||||
contents.append(content)
|
||||
else:
|
||||
key = sentences[0]
|
||||
if key:
|
||||
if len(content) > 0 or len(contents) == 0:
|
||||
key_sentences.append(key)
|
||||
contents.append(content)
|
||||
else:
|
||||
contents[-1].append(key)
|
||||
contents = [' '.join(content) for content in contents]
|
||||
article = {'key': key_sentences, 'content': contents}
|
||||
output.write(json.dumps(article))
|
||||
output.write('\n')
|
||||
article_open = False
|
||||
article_lines = []
|
||||
else:
|
||||
if article_open and line:
|
||||
sentences = segmenter.segment_string(line)
|
||||
article_lines.append(sentences)
|
||||
contents[-1].append(key)
|
||||
contents = [
|
||||
' '.join(content) for content in contents
|
||||
]
|
||||
article = {
|
||||
'key': key_sentences,
|
||||
'content': contents
|
||||
}
|
||||
output.write(json.dumps(article))
|
||||
output.write('\n')
|
||||
article_open = False
|
||||
article_lines = []
|
||||
else:
|
||||
if article_open and line:
|
||||
sentences = segmenter.segment_string(line)
|
||||
article_lines.append(sentences)
|
||||
|
||||
@@ -37,7 +37,7 @@ class SpaceModelBase(nn.Module):
|
||||
return
|
||||
|
||||
def _create_parameters(self):
|
||||
""" Create model's paramters. """
|
||||
""" Create model's parameters. """
|
||||
raise NotImplementedError
|
||||
|
||||
def _forward(self, inputs, is_training, with_label):
|
||||
|
||||
@@ -107,7 +107,7 @@ class UnifiedTransformer(SpaceModelBase):
|
||||
return
|
||||
|
||||
def _create_parameters(self):
|
||||
""" Create model's paramters. """
|
||||
""" Create model's parameters. """
|
||||
sequence_mask = np.tri(
|
||||
self.num_pos_embeddings,
|
||||
self.num_pos_embeddings,
|
||||
|
||||
@@ -199,7 +199,7 @@ class BaseTaskModel(TorchModel, ABC):
|
||||
]
|
||||
|
||||
if _fast_init:
|
||||
# retrieve unintialized modules and initialize
|
||||
# retrieve uninitialized modules and initialize
|
||||
uninitialized_modules = self.retrieve_modules_from_names(
|
||||
missing_keys,
|
||||
prefix=prefix,
|
||||
|
||||
@@ -4,11 +4,11 @@ from typing import TYPE_CHECKING
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .easycv_detection import DetDataset, DetImagesMixDataset
|
||||
from .detection_dataset import DetDataset, DetImagesMixDataset
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'easycv_detection': ['DetDataset', 'DetImagesMixDataset']
|
||||
'detection_dataset': ['DetDataset', 'DetImagesMixDataset']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
@@ -68,7 +68,7 @@ def _nn_variable(name, shape, init_method, collection=None, **kwargs):
|
||||
shape: variable shape
|
||||
init_method: 'zero', 'kaiming', 'xavier', or (mean, std)
|
||||
collection: if not none, add variable to this collection
|
||||
kwargs: extra paramters passed to tf.get_variable
|
||||
kwargs: extra parameters passed to tf.get_variable
|
||||
RETURN
|
||||
var: a new or existing variable
|
||||
"""
|
||||
|
||||
@@ -55,7 +55,7 @@ class BaseTrainer(ABC):
|
||||
""" Train (and evaluate) process
|
||||
|
||||
Train process should be implemented for specific task or
|
||||
model, releated paramters have been intialized in
|
||||
model, related parameters have been initialized in
|
||||
``BaseTrainer.__init__`` and should be used in this function
|
||||
"""
|
||||
pass
|
||||
@@ -66,7 +66,7 @@ class BaseTrainer(ABC):
|
||||
""" Evaluation process
|
||||
|
||||
Evaluation process should be implemented for specific task or
|
||||
model, releated paramters have been intialized in
|
||||
model, related parameters have been initialized in
|
||||
``BaseTrainer.__init__`` and should be used in this function
|
||||
"""
|
||||
pass
|
||||
@@ -87,7 +87,7 @@ class DummyTrainer(BaseTrainer):
|
||||
""" Train (and evaluate) process
|
||||
|
||||
Train process should be implemented for specific task or
|
||||
model, releated paramters have been intialized in
|
||||
model, related parameters have been initialized in
|
||||
``BaseTrainer.__init__`` and should be used in this function
|
||||
"""
|
||||
cfg = self.cfg.train
|
||||
@@ -100,7 +100,7 @@ class DummyTrainer(BaseTrainer):
|
||||
""" Evaluation process
|
||||
|
||||
Evaluation process should be implemented for specific task or
|
||||
model, releated paramters have been intialized in
|
||||
model, related parameters have been initialized in
|
||||
``BaseTrainer.__init__`` and should be used in this function
|
||||
"""
|
||||
cfg = self.cfg.evaluation
|
||||
|
||||
@@ -38,7 +38,7 @@ class TableQuestionAnsweringTrainer(BaseTrainer):
|
||||
num_training_steps,
|
||||
last_epoch=-1):
|
||||
"""
|
||||
set scheduler
|
||||
set scheduler.
|
||||
"""
|
||||
|
||||
def lr_lambda(current_step: int):
|
||||
|
||||
@@ -107,7 +107,7 @@ class CliArgumentParser(ArgumentParser):
|
||||
|
||||
Args:
|
||||
arg_dict (dict of `ArgAttr` or list of them): dict or list of dict which defines different
|
||||
paramters for training.
|
||||
parameters for training.
|
||||
"""
|
||||
|
||||
def __init__(self, arg_dict: Union[Dict[str, ArgAttr],
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
# docstyle-ignore
|
||||
AUDIO_IMPORT_ERROR = """
|
||||
Audio model import failed: {0}, if you want to use audio releated function, please execute
|
||||
Audio model import failed: {0}, if you want to use audio related function, please execute
|
||||
`pip install modelscope[audio] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html`
|
||||
"""
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ numpy
|
||||
oss2
|
||||
Pillow>=6.2.0
|
||||
# pyarrow 9.0.0 introduced event_loop core dump
|
||||
pyarrow>=6.0.0,<9.0.0
|
||||
pyarrow>=6.0.0,!=9.0.0
|
||||
pyyaml
|
||||
requests
|
||||
scipy
|
||||
|
||||
@@ -49,7 +49,7 @@ class HubOperationTest(unittest.TestCase):
|
||||
repo.tag_and_push(self.revision, 'Test revision')
|
||||
|
||||
def test_model_repo_creation(self):
|
||||
# change to proper model names before use
|
||||
# change to proper model names before use.
|
||||
try:
|
||||
info = self.api.get_model(model_id=self.model_id)
|
||||
assert info['Name'] == self.model_name
|
||||
|
||||
45
tests/run.py
45
tests/run.py
@@ -3,6 +3,7 @@
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import importlib
|
||||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
@@ -362,10 +363,43 @@ def run_non_parallelizable_test_suites(suites, result_dir):
|
||||
run_command_with_popen(cmd)
|
||||
|
||||
|
||||
# Selected cases:
|
||||
def get_selected_cases():
|
||||
cmd = ['python', '-u', 'tests/run_analysis.py']
|
||||
selected_cases = []
|
||||
with subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
bufsize=1,
|
||||
encoding='utf8') as sub_process:
|
||||
for line in iter(sub_process.stdout.readline, ''):
|
||||
sys.stdout.write(line)
|
||||
if line.startswith('Selected cases:'):
|
||||
line = line.replace('Selected cases:', '').strip()
|
||||
selected_cases = line.split(',')
|
||||
return selected_cases
|
||||
|
||||
|
||||
def run_in_subprocess(args):
|
||||
# only case args.isolated_cases run in subporcess, all other run in a subprocess
|
||||
test_suite_files = gather_test_suites_files(
|
||||
os.path.abspath(args.test_dir), args.pattern)
|
||||
if not args.no_diff: # run based on git diff
|
||||
try:
|
||||
test_suite_files = get_selected_cases()
|
||||
logger.info('Tests suite to run: ')
|
||||
for f in test_suite_files:
|
||||
logger.info(f)
|
||||
except Exception:
|
||||
logger.error('Get test suite based diff exception!')
|
||||
test_suite_files = gather_test_suites_files(
|
||||
os.path.abspath(args.test_dir), args.pattern)
|
||||
if len(test_suite_files) == 0:
|
||||
logger.error('Get no test suite based on diff, run all the cases.')
|
||||
test_suite_files = gather_test_suites_files(
|
||||
os.path.abspath(args.test_dir), args.pattern)
|
||||
else:
|
||||
test_suite_files = gather_test_suites_files(
|
||||
os.path.abspath(args.test_dir), args.pattern)
|
||||
|
||||
non_parallelizable_suites = [
|
||||
'test_download_dataset.py',
|
||||
@@ -579,11 +613,18 @@ if __name__ == '__main__':
|
||||
type=int,
|
||||
help='Set case parallels, default single process, set with gpu number.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-diff',
|
||||
action='store_true',
|
||||
help=
|
||||
'Default running case based on git diff(with master), disable with --no-diff)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--suites',
|
||||
nargs='*',
|
||||
help='Run specified test suites(test suite files list split by space)')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
set_test_level(args.level)
|
||||
os.environ['REGRESSION_BASELINE'] = '1'
|
||||
logger.info(f'TEST LEVEL: {test_level()}')
|
||||
|
||||
337
tests/run_analysis.py
Normal file
337
tests/run_analysis.py
Normal file
@@ -0,0 +1,337 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from fnmatch import fnmatch
|
||||
|
||||
from trainers.model_trainer_map import model_trainer_map
|
||||
from utils.case_file_analyzer import get_pipelines_trainers_test_info
|
||||
from utils.source_file_analyzer import (get_all_register_modules,
|
||||
get_file_register_modules,
|
||||
get_import_map)
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.errors import NotExistError
|
||||
from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.hub.utils.utils import get_cache_dir
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def get_models_info(groups: list) -> dict:
|
||||
models = []
|
||||
api = HubApi()
|
||||
for group in groups:
|
||||
page = 1
|
||||
while True:
|
||||
query_result = api.list_models(group, page, 100)
|
||||
models.extend(query_result['Models'])
|
||||
if len(models) >= query_result['TotalCount']:
|
||||
break
|
||||
page += 1
|
||||
cache_root = get_cache_dir()
|
||||
models_info = {} # key model id, value model info
|
||||
for model_info in models:
|
||||
model_id = '%s/%s' % (group, model_info['Name'])
|
||||
configuration_file = os.path.join(cache_root, model_id,
|
||||
ModelFile.CONFIGURATION)
|
||||
if not os.path.exists(configuration_file):
|
||||
model_revisions = api.list_model_revisions(model_id=model_id)
|
||||
if len(model_revisions) == 0:
|
||||
logger.warn('Model: %s has no revision' % model_id)
|
||||
continue
|
||||
# get latest revision
|
||||
try:
|
||||
configuration_file = model_file_download(
|
||||
model_id=model_id,
|
||||
file_path=ModelFile.CONFIGURATION,
|
||||
revision=model_revisions[0])
|
||||
except NotExistError:
|
||||
logger.warn('Model: %s has no configuration file %s' %
|
||||
(model_id, ModelFile.CONFIGURATION))
|
||||
continue
|
||||
cfg = Config.from_file(configuration_file)
|
||||
model_info = {}
|
||||
model_info['framework'] = cfg.safe_get('framework')
|
||||
model_info['task'] = cfg.safe_get('task')
|
||||
model_info['model_type'] = cfg.safe_get('model.type')
|
||||
model_info['pipeline_type'] = cfg.safe_get('pipeline.type')
|
||||
model_info['preprocessor_type'] = cfg.safe_get('preprocessor.type')
|
||||
train_hooks_type = []
|
||||
train_hooks = cfg.safe_get('train.hooks')
|
||||
if train_hooks is not None:
|
||||
for train_hook in train_hooks:
|
||||
train_hooks_type.append(train_hook.type)
|
||||
model_info['train_hooks_type'] = train_hooks_type
|
||||
model_info['datasets'] = cfg.safe_get('dataset')
|
||||
|
||||
model_info['evaluation_metics'] = cfg.safe_get('evaluation.metrics',
|
||||
[]) # metrics name list
|
||||
"""
|
||||
print('framework: %s, task: %s, model_type: %s, pipeline_type: %s, \
|
||||
preprocessor_type: %s, train_hooks_type: %s, \
|
||||
dataset: %s, evaluation_metics: %s'%(
|
||||
framework, task, model_type, pipeline_type,
|
||||
preprocessor_type, ','.join(train_hooks_type),
|
||||
datasets, evaluation_metics))
|
||||
"""
|
||||
models_info[model_id] = model_info
|
||||
return models_info
|
||||
|
||||
|
||||
def gather_test_suites_files_full_path(test_dir='./tests',
|
||||
pattern='test_*.py'):
|
||||
case_file_list = []
|
||||
for dirpath, dirnames, filenames in os.walk(test_dir):
|
||||
for file in filenames:
|
||||
if fnmatch(file, pattern):
|
||||
case_file_list.append(os.path.join(dirpath, file))
|
||||
|
||||
return case_file_list
|
||||
|
||||
|
||||
def run_command_get_output(cmd):
|
||||
response = subprocess.run(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
try:
|
||||
response.check_returncode()
|
||||
output = response.stdout.decode('utf8')
|
||||
return output
|
||||
except subprocess.CalledProcessError as error:
|
||||
logger.error(
|
||||
'stdout: %s, stderr: %s' %
|
||||
(response.stdout.decode('utf8'), error.stderr.decode('utf8')))
|
||||
return None
|
||||
|
||||
|
||||
def get_modified_files():
|
||||
cmd = ['git', 'diff', '--name-only', 'origin/master...']
|
||||
cmd_output = run_command_get_output(cmd)
|
||||
logger.info('Modified files: ')
|
||||
logger.info(cmd_output)
|
||||
return cmd_output.splitlines()
|
||||
|
||||
|
||||
def analysis_diff():
|
||||
"""Get modified files and their imported files modified modules
|
||||
"""
|
||||
modified_register_modules = []
|
||||
modified_cases = []
|
||||
modified_files_imported_by = []
|
||||
modified_files = get_modified_files()
|
||||
logger.info('Modified files:\n %s' % '\n'.join(modified_files))
|
||||
|
||||
logger.info('Starting get import map')
|
||||
import_map = get_import_map()
|
||||
logger.info('Finished get import map')
|
||||
for modified_file in modified_files:
|
||||
if modified_file.startswith('./modelscope') or \
|
||||
modified_file.startswith('modelscope'): # is source file
|
||||
for k, v in import_map.items():
|
||||
if modified_file in v and modified_file != k:
|
||||
modified_files_imported_by.append(k)
|
||||
logger.info('There are affected files: %s'
|
||||
% len(modified_files_imported_by))
|
||||
for f in modified_files_imported_by:
|
||||
logger.info(f)
|
||||
modified_files.extend(modified_files_imported_by) # add imported by file
|
||||
for modified_file in modified_files:
|
||||
if modified_file.startswith('./modelscope') or \
|
||||
modified_file.startswith('modelscope'):
|
||||
modified_register_modules.extend(
|
||||
get_file_register_modules(modified_file))
|
||||
elif ((modified_file.startswith('./tests')
|
||||
or modified_file.startswith('tests'))
|
||||
and os.path.basename(modified_file).startswith('test_')):
|
||||
modified_cases.append(modified_file)
|
||||
|
||||
return modified_register_modules, modified_cases
|
||||
|
||||
|
||||
def split_test_suites():
|
||||
test_suite_full_paths = gather_test_suites_files_full_path()
|
||||
pipeline_test_suites = []
|
||||
trainer_test_suites = []
|
||||
other_test_suites = []
|
||||
for test_suite in test_suite_full_paths:
|
||||
if test_suite.find('tests/trainers') != -1:
|
||||
trainer_test_suites.append(test_suite)
|
||||
elif test_suite.find('tests/pipelines') != -1:
|
||||
pipeline_test_suites.append(test_suite)
|
||||
else:
|
||||
other_test_suites.append(test_suite)
|
||||
|
||||
return pipeline_test_suites, trainer_test_suites, other_test_suites
|
||||
|
||||
|
||||
def get_test_suites_to_run():
|
||||
affected_register_modules, modified_cases = analysis_diff()
|
||||
# affected_register_modules list of modified file and dependent file's register_module.
|
||||
# ("MODULES|PIPELINES|TRAINERS|...", '', '', model_class_name)
|
||||
# modified_cases, modified case file.
|
||||
all_register_modules = get_all_register_modules()
|
||||
_, _, other_test_suites = split_test_suites()
|
||||
task_pipeline_test_suite_map, trainer_test_suite_map = get_pipelines_trainers_test_info(
|
||||
all_register_modules)
|
||||
# task_pipeline_test_suite_map key: pipeline task, value: case file path
|
||||
# trainer_test_suite_map key: trainer_name, value: case file path
|
||||
models_info = get_models_info(['damo'])
|
||||
# model_info key: model_id, value: model info such as framework task etc.
|
||||
affected_pipeline_cases = []
|
||||
affected_trainer_cases = []
|
||||
for affected_register_module in affected_register_modules:
|
||||
# affected_register_module PIPELINE structure
|
||||
# ["PIPELINES", "acoustic_noise_suppression", "speech_frcrn_ans_cirm_16k", "ANSPipeline"]
|
||||
# ["PIPELINES", task, pipeline_name, pipeline_class_name]
|
||||
if affected_register_module[0] == 'PIPELINES':
|
||||
if affected_register_module[1] in task_pipeline_test_suite_map:
|
||||
affected_pipeline_cases.extend(
|
||||
task_pipeline_test_suite_map[affected_register_module[1]])
|
||||
else:
|
||||
logger.warn('Pipeline task: %s has no test case!'
|
||||
% affected_register_module[1])
|
||||
elif affected_register_module[0] == 'MODELS':
|
||||
# ["MODELS", "keyword_spotting", "kws_kwsbp", "GenericKeyWordSpotting"],
|
||||
# ["MODELS", task, model_name, model_class_name]
|
||||
if affected_register_module[1] in task_pipeline_test_suite_map:
|
||||
affected_pipeline_cases.extend(
|
||||
task_pipeline_test_suite_map[affected_register_module[1]])
|
||||
else:
|
||||
logger.warn('Pipeline task: %s has no test case!'
|
||||
% affected_register_module[1])
|
||||
elif affected_register_module[0] == 'TRAINERS':
|
||||
# ["TRAINERS", "", "nlp_base_trainer", "NlpEpochBasedTrainer"],
|
||||
# ["TRAINERS", "", trainer_name, trainer_class_name]
|
||||
if affected_register_module[2] in trainer_test_suite_map:
|
||||
affected_trainer_cases.extend(
|
||||
trainer_test_suite_map[affected_register_module[2]])
|
||||
else:
|
||||
logger.warn('Trainer %s his no case' %
|
||||
(affected_register_module[2]))
|
||||
elif affected_register_module[0] == 'PREPROCESSORS':
|
||||
# ["PREPROCESSORS", "cv", "object_detection_scrfd", "SCRFDPreprocessor"]
|
||||
# ["PREPROCESSORS", domain, preprocessor_name, class_name]
|
||||
task = model_info['task']
|
||||
for model_id, model_info in models_info.items():
|
||||
if model_info['preprocessor_type'] is not None and model_info[
|
||||
'preprocessor_type'] == affected_register_module[2]:
|
||||
if task in task_pipeline_test_suite_map:
|
||||
affected_pipeline_cases.extend(
|
||||
task_pipeline_test_suite_map[task])
|
||||
if model_id in model_trainer_map:
|
||||
affected_trainer_cases.extend(
|
||||
model_trainer_map[model_id])
|
||||
elif (affected_register_module[0] == 'HOOKS'
|
||||
or affected_register_module[0] == 'TASK_DATASETS'):
|
||||
# ["HOOKS", "", "CheckpointHook", "CheckpointHook"]
|
||||
# ["HOOKS", "", hook_name, class_name]
|
||||
# HOOKS, DATASETS modify run all trainer cases
|
||||
for _, cases in trainer_test_suite_map.items():
|
||||
affected_trainer_cases.extend(cases)
|
||||
elif affected_register_module[0] == 'METRICS':
|
||||
# ["METRICS", "default_group", "accuracy", "AccuracyMetric"]
|
||||
# ["METRICS", group, metric_name, class_name]
|
||||
for model_id, model_info in models_info.items():
|
||||
if affected_register_module[2] in model_info[
|
||||
'evaluation_metics']:
|
||||
if model_id in model_trainer_map:
|
||||
affected_trainer_cases.extend(
|
||||
model_trainer_map[model_id])
|
||||
|
||||
# deduplication
|
||||
affected_pipeline_cases = list(set(affected_pipeline_cases))
|
||||
affected_trainer_cases = list(set(affected_trainer_cases))
|
||||
test_suites_to_run = []
|
||||
for test_suite in other_test_suites:
|
||||
test_suites_to_run.append(os.path.basename(test_suite))
|
||||
for test_suite in affected_pipeline_cases:
|
||||
test_suites_to_run.append(os.path.basename(test_suite))
|
||||
for test_suite in affected_trainer_cases:
|
||||
test_suites_to_run.append(os.path.basename(test_suite))
|
||||
|
||||
for modified_case in modified_cases:
|
||||
if modified_case not in test_suites_to_run:
|
||||
test_suites_to_run.append(os.path.basename(modified_case))
|
||||
return test_suites_to_run
|
||||
|
||||
|
||||
def get_files_related_modules(files):
|
||||
register_modules = []
|
||||
for single_file in files:
|
||||
if single_file.startswith('./modelscope') or \
|
||||
single_file.startswith('modelscope'):
|
||||
register_modules.extend(get_file_register_modules(single_file))
|
||||
|
||||
return register_modules
|
||||
|
||||
|
||||
def get_modules_related_cases(register_modules, task_pipeline_test_suite_map,
|
||||
trainer_test_suite_map):
|
||||
affected_pipeline_cases = []
|
||||
affected_trainer_cases = []
|
||||
for register_module in register_modules:
|
||||
if register_module[0] == 'PIPELINES' or \
|
||||
register_module[0] == 'MODELS':
|
||||
if register_module[1] in task_pipeline_test_suite_map:
|
||||
affected_pipeline_cases.extend(
|
||||
task_pipeline_test_suite_map[register_module[1]])
|
||||
else:
|
||||
logger.warn('Pipeline task: %s has no test case!'
|
||||
% register_module[1])
|
||||
elif register_module[0] == 'TRAINERS':
|
||||
if register_module[2] in trainer_test_suite_map:
|
||||
affected_trainer_cases.extend(
|
||||
trainer_test_suite_map[register_module[2]])
|
||||
else:
|
||||
logger.warn('Trainer %s his no case' % (register_module[2]))
|
||||
return affected_pipeline_cases, affected_trainer_cases
|
||||
|
||||
|
||||
def get_all_file_test_info():
|
||||
all_files = [
|
||||
os.path.relpath(os.path.join(dp, f), os.getcwd())
|
||||
for dp, dn, filenames in os.walk(
|
||||
os.path.join(os.getcwd(), 'modelscope')) for f in filenames
|
||||
if os.path.splitext(f)[1] == '.py'
|
||||
]
|
||||
import_map = get_import_map()
|
||||
all_register_modules = get_all_register_modules()
|
||||
task_pipeline_test_suite_map, trainer_test_suite_map = get_pipelines_trainers_test_info(
|
||||
all_register_modules)
|
||||
reverse_depend_map = {}
|
||||
for f in all_files:
|
||||
depend_by = []
|
||||
for k, v in import_map.items():
|
||||
if f in v and f != k:
|
||||
depend_by.append(k)
|
||||
reverse_depend_map[f] = depend_by
|
||||
# get cases.
|
||||
test_info = {}
|
||||
for f in all_files:
|
||||
file_test_info = {}
|
||||
file_test_info['imports'] = import_map[f]
|
||||
file_test_info['imported_by'] = reverse_depend_map[f]
|
||||
register_modules = get_files_related_modules([f]
|
||||
+ reverse_depend_map[f])
|
||||
file_test_info['relate_modules'] = register_modules
|
||||
affected_pipeline_cases, affected_trainer_cases = get_modules_related_cases(
|
||||
register_modules, task_pipeline_test_suite_map,
|
||||
trainer_test_suite_map)
|
||||
file_test_info['pipeline_cases'] = affected_pipeline_cases
|
||||
file_test_info['trainer_cases'] = affected_trainer_cases
|
||||
file_relative_path = os.path.relpath(f, os.getcwd())
|
||||
test_info[file_relative_path] = file_test_info
|
||||
|
||||
with open('./test_relate_info.json', 'w') as f:
|
||||
import json
|
||||
json.dump(test_info, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_suites_to_run = get_test_suites_to_run()
|
||||
msg = ','.join(test_suites_to_run)
|
||||
print('Selected cases: %s' % msg)
|
||||
145
tests/trainers/model_trainer_map.py
Normal file
145
tests/trainers/model_trainer_map.py
Normal file
@@ -0,0 +1,145 @@
|
||||
model_trainer_map = {
|
||||
'damo/speech_frcrn_ans_cirm_16k':
|
||||
['tests/trainers/audio/test_ans_trainer.py'],
|
||||
'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch':
|
||||
['tests/trainers/audio/test_asr_trainer.py'],
|
||||
'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya':
|
||||
['tests/trainers/audio/test_kws_farfield_trainer.py'],
|
||||
'damo/speech_charctc_kws_phone-xiaoyun':
|
||||
['tests/trainers/audio/test_kws_nearfield_trainer.py'],
|
||||
'damo/speech_mossformer_separation_temporal_8k':
|
||||
['tests/trainers/audio/test_separation_trainer.py'],
|
||||
'speech_tts/speech_sambert-hifigan_tts_zh-cn_multisp_pretrain_16k':
|
||||
['tests/trainers/audio/test_tts_trainer.py'],
|
||||
'damo/cv_mobilenet_face-2d-keypoints_alignment':
|
||||
['tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py'],
|
||||
'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody':
|
||||
['tests/trainers/easycv/test_easycv_trainer_hand_2d_keypoints.py'],
|
||||
'damo/cv_yolox-pai_hand-detection':
|
||||
['tests/trainers/easycv/test_easycv_trainer_hand_detection.py'],
|
||||
'damo/cv_r50_panoptic-segmentation_cocopan':
|
||||
['tests/trainers/easycv/test_easycv_trainer_panoptic_mask2former.py'],
|
||||
'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k':
|
||||
['tests/trainers/easycv/test_segformer.py'],
|
||||
'damo/cv_resnet_carddetection_scrfd34gkps':
|
||||
['tests/trainers/test_card_detection_scrfd_trainer.py'],
|
||||
'damo/multi-modal_clip-vit-base-patch16_zh': [
|
||||
'tests/trainers/test_clip_trainer.py'
|
||||
],
|
||||
'damo/nlp_space_pretrained-dialog-model': [
|
||||
'tests/trainers/test_dialog_intent_trainer.py'
|
||||
],
|
||||
'damo/cv_resnet_facedetection_scrfd10gkps': [
|
||||
'tests/trainers/test_face_detection_scrfd_trainer.py'
|
||||
],
|
||||
'damo/nlp_structbert_faq-question-answering_chinese-base': [
|
||||
'tests/trainers/test_finetune_faq_question_answering.py'
|
||||
],
|
||||
'PAI/nlp_gpt3_text-generation_0.35B_MoE-64': [
|
||||
'tests/trainers/test_finetune_gpt_moe.py'
|
||||
],
|
||||
'damo/nlp_gpt3_text-generation_1.3B': [
|
||||
'tests/trainers/test_finetune_gpt3.py'
|
||||
],
|
||||
'damo/mgeo_backbone_chinese_base': [
|
||||
'tests/trainers/test_finetune_mgeo.py'
|
||||
],
|
||||
'damo/mplug_backbone_base_en': ['tests/trainers/test_finetune_mplug.py'],
|
||||
'damo/nlp_structbert_backbone_base_std': [
|
||||
'tests/trainers/test_finetune_sequence_classification.py',
|
||||
'tests/trainers/test_finetune_token_classification.py'
|
||||
],
|
||||
'damo/nlp_palm2.0_text-generation_english-base': [
|
||||
'tests/trainers/test_finetune_text_generation.py'
|
||||
],
|
||||
'damo/nlp_gpt3_text-generation_chinese-base': [
|
||||
'tests/trainers/test_finetune_text_generation.py'
|
||||
],
|
||||
'damo/nlp_palm2.0_text-generation_chinese-base': [
|
||||
'tests/trainers/test_finetune_text_generation.py'
|
||||
],
|
||||
'damo/nlp_corom_passage-ranking_english-base': [
|
||||
'tests/trainers/test_finetune_text_ranking.py'
|
||||
],
|
||||
'damo/nlp_rom_passage-ranking_chinese-base': [
|
||||
'tests/trainers/test_finetune_text_ranking.py'
|
||||
],
|
||||
'damo/cv_nextvit-small_image-classification_Dailylife-labels': [
|
||||
'tests/trainers/test_general_image_classification_trainer.py'
|
||||
],
|
||||
'damo/cv_convnext-base_image-classification_garbage': [
|
||||
'tests/trainers/test_general_image_classification_trainer.py'
|
||||
],
|
||||
'damo/cv_beitv2-base_image-classification_patch16_224_pt1k_ft22k_in1k': [
|
||||
'tests/trainers/test_general_image_classification_trainer.py'
|
||||
],
|
||||
'damo/cv_csrnet_image-color-enhance-models': [
|
||||
'tests/trainers/test_image_color_enhance_trainer.py'
|
||||
],
|
||||
'damo/cv_nafnet_image-deblur_gopro': [
|
||||
'tests/trainers/test_image_deblur_trainer.py'
|
||||
],
|
||||
'damo/cv_resnet101_detection_fewshot-defrcn': [
|
||||
'tests/trainers/test_image_defrcn_fewshot_trainer.py'
|
||||
],
|
||||
'damo/cv_nafnet_image-denoise_sidd': [
|
||||
'tests/trainers/test_image_denoise_trainer.py'
|
||||
],
|
||||
'damo/cv_fft_inpainting_lama': [
|
||||
'tests/trainers/test_image_inpainting_trainer.py'
|
||||
],
|
||||
'damo/cv_swin-b_image-instance-segmentation_coco': [
|
||||
'tests/trainers/test_image_instance_segmentation_trainer.py'
|
||||
],
|
||||
'damo/cv_gpen_image-portrait-enhancement': [
|
||||
'tests/trainers/test_image_portrait_enhancement_trainer.py'
|
||||
],
|
||||
'damo/cv_clip-it_video-summarization_language-guided_en': [
|
||||
'tests/trainers/test_language_guided_video_summarization_trainer.py'
|
||||
],
|
||||
'damo/cv_resnet50-bert_video-scene-segmentation_movienet': [
|
||||
'tests/trainers/test_movie_scene_segmentation_trainer.py'
|
||||
],
|
||||
'damo/ofa_mmspeech_pretrain_base_zh': [
|
||||
'tests/trainers/test_ofa_mmspeech_trainer.py'
|
||||
],
|
||||
'damo/ofa_ocr-recognition_scene_base_zh': [
|
||||
'tests/trainers/test_ofa_trainer.py'
|
||||
],
|
||||
'damo/nlp_plug_text-generation_27B': [
|
||||
'tests/trainers/test_plug_finetune_text_generation.py'
|
||||
],
|
||||
'damo/cv_swin-t_referring_video-object-segmentation': [
|
||||
'tests/trainers/test_referring_video_object_segmentation_trainer.py'
|
||||
],
|
||||
'damo/nlp_convai_text2sql_pretrain_cn': [
|
||||
'tests/trainers/test_table_question_answering_trainer.py'
|
||||
],
|
||||
'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity': [
|
||||
'tests/trainers/test_team_transfer_trainer.py'
|
||||
],
|
||||
'damo/cv_tinynas_object-detection_damoyolo': [
|
||||
'tests/trainers/test_tinynas_damoyolo_trainer.py'
|
||||
],
|
||||
'damo/nlp_structbert_sentence-similarity_chinese-tiny': [
|
||||
'tests/trainers/test_trainer_with_nlp.py'
|
||||
],
|
||||
'damo/nlp_structbert_sentiment-classification_chinese-base': [
|
||||
'tests/trainers/test_trainer_with_nlp.py'
|
||||
],
|
||||
'damo/nlp_structbert_sentence-similarity_chinese-base': [
|
||||
'tests/trainers/test_trainer_with_nlp.py'
|
||||
],
|
||||
'damo/nlp_csanmt_translation_en2zh': [
|
||||
'tests/trainers/test_translation_trainer.py'
|
||||
],
|
||||
'damo/nlp_csanmt_translation_en2fr': [
|
||||
'tests/trainers/test_translation_trainer.py'
|
||||
],
|
||||
'damo/nlp_csanmt_translation_en2es': [
|
||||
'tests/trainers/test_translation_trainer.py'
|
||||
],
|
||||
'damo/cv_googlenet_pgl-video-summarization': [
|
||||
'tests/trainers/test_video_summarization_trainer.py'
|
||||
],
|
||||
}
|
||||
415
tests/utils/case_file_analyzer.py
Normal file
415
tests/utils/case_file_analyzer.py
Normal file
@@ -0,0 +1,415 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from __future__ import print_function
|
||||
import ast
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
SYSTEM_TRAINER_BUILDER_FINCTION_NAME = 'build_trainer'
|
||||
SYSTEM_TRAINER_BUILDER_PARAMETER_NAME = 'name'
|
||||
SYSTEM_PIPELINE_BUILDER_FUNCTION_NAME = 'pipeline'
|
||||
SYSTEM_PIPELINE_BUILDER_PARAMETER_NAME = 'task'
|
||||
|
||||
|
||||
class AnalysisTestFile(ast.NodeVisitor):
|
||||
"""Analysis test suite files.
|
||||
Get global function and test class
|
||||
|
||||
Args:
|
||||
ast (NodeVisitor): The ast node.
|
||||
Examples:
|
||||
with open(test_suite_file, "rb") as f:
|
||||
src = f.read()
|
||||
analyzer = AnalysisTestFile(test_suite_file)
|
||||
analyzer.visit(ast.parse(src, filename=test_suite_file))
|
||||
"""
|
||||
|
||||
def __init__(self, test_suite_file, builder_function_name) -> None:
|
||||
super().__init__()
|
||||
self.test_classes = []
|
||||
self.builder_function_name = builder_function_name
|
||||
self.global_functions = []
|
||||
self.custom_global_builders = [
|
||||
] # global trainer builder method(call build_trainer)
|
||||
self.custom_global_builder_calls = [] # the builder call statement
|
||||
|
||||
def visit_ClassDef(self, node) -> bool:
|
||||
"""Check if the class is a unittest suite.
|
||||
|
||||
Args:
|
||||
node (ast.Node): the ast node
|
||||
|
||||
Returns: True if is a test class.
|
||||
"""
|
||||
for base in node.bases:
|
||||
if isinstance(base, ast.Attribute) and base.attr == 'TestCase':
|
||||
self.test_classes.append(node)
|
||||
elif isinstance(base, ast.Name) and 'TestCase' in base.id:
|
||||
self.test_classes.append(node)
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef):
|
||||
self.global_functions.append(node)
|
||||
for statement in ast.walk(node):
|
||||
if isinstance(statement, ast.Call) and \
|
||||
isinstance(statement.func, ast.Name):
|
||||
if statement.func.id == self.builder_function_name:
|
||||
self.custom_global_builders.append(node)
|
||||
self.custom_global_builder_calls.append(statement)
|
||||
|
||||
|
||||
class AnalysisTestClass(ast.NodeVisitor):
|
||||
|
||||
def __init__(self, test_class_node, builder_function_name) -> None:
|
||||
super().__init__()
|
||||
self.test_class_node = test_class_node
|
||||
self.builder_function_name = builder_function_name
|
||||
self.setup_variables = {}
|
||||
self.test_methods = []
|
||||
self.custom_class_method_builders = [
|
||||
] # class method trainer builder(call build_trainer)
|
||||
self.custom_class_method_builder_calls = [
|
||||
] # the builder call statement
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||
if node.name.startswith('setUp'):
|
||||
for statement in node.body:
|
||||
if isinstance(statement, ast.Assign):
|
||||
if len(statement.targets) == 1 and \
|
||||
isinstance(statement.targets[0], ast.Attribute) and \
|
||||
isinstance(statement.value, ast.Attribute):
|
||||
self.setup_variables[str(
|
||||
statement.targets[0].attr)] = str(
|
||||
statement.value.attr)
|
||||
elif node.name.startswith('test_'):
|
||||
self.test_methods.append(node)
|
||||
else:
|
||||
for statement in ast.walk(node):
|
||||
if isinstance(statement, ast.Call) and \
|
||||
isinstance(statement.func, ast.Name):
|
||||
if statement.func.id == self.builder_function_name:
|
||||
self.custom_class_method_builders.append(node)
|
||||
self.custom_class_method_builder_calls.append(
|
||||
statement)
|
||||
|
||||
|
||||
def get_local_arg_value(target_method, args_name):
|
||||
for statement in target_method.body:
|
||||
if isinstance(statement, ast.Assign):
|
||||
for target in statement.targets:
|
||||
if isinstance(target, ast.Name) and target.id == args_name:
|
||||
if isinstance(statement.value, ast.Attribute):
|
||||
return statement.value.attr
|
||||
elif isinstance(statement.value, ast.Str):
|
||||
return statement.value.s
|
||||
return None
|
||||
|
||||
|
||||
def get_custom_builder_parameter_name(args, keywords, builder, builder_call,
|
||||
builder_arg_name):
|
||||
# get build_trainer call name argument name.
|
||||
arg_name = None
|
||||
if len(builder_call.args) > 0:
|
||||
if isinstance(builder_call.args[0], ast.Name):
|
||||
# build_trainer name is a variable
|
||||
arg_name = builder_call.args[0].id
|
||||
elif isinstance(builder_call.args[0], ast.Attribute):
|
||||
# Attribute access, such as Trainers.image_classification_team
|
||||
return builder_call.args[0].attr
|
||||
else:
|
||||
raise Exception('Invalid argument name')
|
||||
else:
|
||||
use_default_name = True
|
||||
for kw in builder_call.keywords:
|
||||
if kw.arg == builder_arg_name:
|
||||
use_default_name = False
|
||||
if isinstance(kw.value, ast.Attribute):
|
||||
return kw.value.attr
|
||||
elif isinstance(kw.value,
|
||||
ast.Name) and kw.arg == builder_arg_name:
|
||||
arg_name = kw.value.id
|
||||
else:
|
||||
raise Exception('Invalid keyword argument')
|
||||
if use_default_name:
|
||||
return 'default'
|
||||
|
||||
if arg_name is None:
|
||||
raise Exception('Invalid build_trainer call')
|
||||
|
||||
arg_value = get_local_arg_value(builder, arg_name)
|
||||
if arg_value is not None: # trainer_name is a local variable
|
||||
return arg_value
|
||||
# get build_trainer name parameter, if it's passed
|
||||
default_name = None
|
||||
arg_idx = 100000
|
||||
for idx, arg in enumerate(builder.args.args):
|
||||
if arg.arg == arg_name:
|
||||
arg_idx = idx
|
||||
if idx >= len(builder.args.args) - len(builder.args.defaults):
|
||||
default_name = builder.args.defaults[idx - (
|
||||
len(builder.args.args) - len(builder.args.defaults))].attr
|
||||
break
|
||||
if len(builder.args.args
|
||||
) > 0 and builder.args.args[0].arg == 'self': # class method
|
||||
if len(args) > arg_idx - 1: # - self
|
||||
if isinstance(args[arg_idx - 1], ast.Attribute):
|
||||
return args[arg_idx - 1].attr
|
||||
|
||||
for keyword in keywords:
|
||||
if keyword.arg == arg_name:
|
||||
if isinstance(keyword.value, ast.Attribute):
|
||||
return keyword.value.attr
|
||||
|
||||
return default_name
|
||||
|
||||
|
||||
def get_system_builder_parameter_value(builder_call, test_method,
|
||||
setup_attributes,
|
||||
builder_parameter_name):
|
||||
if len(builder_call.args) > 0:
|
||||
if isinstance(builder_call.args[0], ast.Name):
|
||||
return get_local_arg_value(test_method, builder_call.args[0].id)
|
||||
elif isinstance(builder_call.args[0], ast.Attribute):
|
||||
if builder_call.args[0].attr in setup_attributes:
|
||||
return setup_attributes[builder_call.args[0].attr]
|
||||
return builder_call.args[0].attr
|
||||
elif isinstance(builder_call.args[0], ast.Str): # TODO check py38
|
||||
return builder_call.args[0].s
|
||||
|
||||
for kw in builder_call.keywords:
|
||||
if kw.arg == builder_parameter_name:
|
||||
if isinstance(kw.value, ast.Attribute):
|
||||
if kw.value.attr in setup_attributes:
|
||||
return setup_attributes[kw.value.attr]
|
||||
else:
|
||||
return kw.value.attr
|
||||
elif isinstance(kw.value,
|
||||
ast.Name) and kw.arg == builder_parameter_name:
|
||||
return kw.value.id
|
||||
|
||||
return 'default' # use build_trainer default argument.
|
||||
|
||||
|
||||
def get_builder_parameter_value(test_method, setup_variables, builder,
|
||||
builder_call, system_builder_func_name,
|
||||
builder_parameter_name):
|
||||
"""
|
||||
get target builder parameter name, for tariner we get trainer name, for pipeline we get pipeline task
|
||||
"""
|
||||
for node in ast.walk(test_method):
|
||||
if builder is None: # direct call build_trainer
|
||||
for node in ast.walk(test_method):
|
||||
if (isinstance(node, ast.Call)
|
||||
and isinstance(node.func, ast.Name)
|
||||
and node.func.id == system_builder_func_name):
|
||||
return get_system_builder_parameter_value(
|
||||
node, test_method, setup_variables,
|
||||
builder_parameter_name)
|
||||
elif (isinstance(node, ast.Call)
|
||||
and isinstance(node.func, ast.Attribute)
|
||||
and node.func.attr == builder.name):
|
||||
return get_custom_builder_parameter_name(node.args, node.keywords,
|
||||
builder, builder_call,
|
||||
builder_parameter_name)
|
||||
elif (isinstance(node, ast.Expr) and isinstance(node.value, ast.Call)
|
||||
and isinstance(node.value.func, ast.Name)
|
||||
and node.value.func.id == builder.name):
|
||||
return get_custom_builder_parameter_name(node.value.args,
|
||||
node.value.keywords,
|
||||
builder, builder_call,
|
||||
builder_parameter_name)
|
||||
elif (isinstance(node, ast.Expr) and isinstance(node.value, ast.Call)
|
||||
and isinstance(node.value.func, ast.Attribute)
|
||||
and node.value.func.attr == builder.name):
|
||||
# self.class_method_builder
|
||||
return get_custom_builder_parameter_name(node.value.args,
|
||||
node.value.keywords,
|
||||
builder, builder_call,
|
||||
builder_parameter_name)
|
||||
elif isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
|
||||
for arg in node.value.args:
|
||||
if isinstance(arg, ast.Name) and arg.id == builder.name:
|
||||
# self.start(train_func, num_gpus=2, **kwargs)
|
||||
return get_custom_builder_parameter_name(
|
||||
None, None, builder, builder_call,
|
||||
builder_parameter_name)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_class_constructor(test_method, modified_register_modules, module_name):
|
||||
# module_name 'TRAINERS' | 'PIPELINES'
|
||||
for node in ast.walk(test_method):
|
||||
if isinstance(node, ast.Assign) and isinstance(node.value, ast.Call):
|
||||
# trainer = CsanmtTranslationTrainer(model=model_id)
|
||||
for modified_register_module in modified_register_modules:
|
||||
if isinstance(node.value.func, ast.Name) and \
|
||||
node.value.func.id == modified_register_module[3] and \
|
||||
modified_register_module[0] == module_name:
|
||||
if module_name == 'TRAINERS':
|
||||
return modified_register_module[2]
|
||||
elif module_name == 'PIPELINES':
|
||||
return modified_register_module[1] # pipeline
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def analysis_trainer_test_suite(test_file, modified_register_modules):
|
||||
tested_trainers = []
|
||||
with open(test_file, 'rb') as tsf:
|
||||
src = tsf.read()
|
||||
# get test file global function and test class
|
||||
test_suite_root = ast.parse(src, test_file)
|
||||
test_suite_analyzer = AnalysisTestFile(
|
||||
test_file, SYSTEM_TRAINER_BUILDER_FINCTION_NAME)
|
||||
test_suite_analyzer.visit(test_suite_root)
|
||||
|
||||
for test_class in test_suite_analyzer.test_classes:
|
||||
test_class_analyzer = AnalysisTestClass(
|
||||
test_class, SYSTEM_TRAINER_BUILDER_FINCTION_NAME)
|
||||
test_class_analyzer.visit(test_class)
|
||||
for test_method in test_class_analyzer.test_methods:
|
||||
for idx, custom_global_builder in enumerate(
|
||||
test_suite_analyzer.custom_global_builders
|
||||
): # custom test method is global method
|
||||
trainer_name = get_builder_parameter_value(
|
||||
test_method, test_class_analyzer.setup_variables,
|
||||
custom_global_builder,
|
||||
test_suite_analyzer.custom_global_builder_calls[idx],
|
||||
SYSTEM_TRAINER_BUILDER_FINCTION_NAME,
|
||||
SYSTEM_TRAINER_BUILDER_PARAMETER_NAME)
|
||||
if trainer_name is not None:
|
||||
tested_trainers.append(trainer_name)
|
||||
for idx, custom_class_method_builder in enumerate(
|
||||
test_class_analyzer.custom_class_method_builders
|
||||
): # custom class method builder.
|
||||
trainer_name = get_builder_parameter_value(
|
||||
test_method, test_class_analyzer.setup_variables,
|
||||
custom_class_method_builder,
|
||||
test_class_analyzer.custom_class_method_builder_calls[idx],
|
||||
SYSTEM_TRAINER_BUILDER_FINCTION_NAME,
|
||||
SYSTEM_TRAINER_BUILDER_PARAMETER_NAME)
|
||||
if trainer_name is not None:
|
||||
tested_trainers.append(trainer_name)
|
||||
|
||||
trainer_name = get_builder_parameter_value(
|
||||
test_method, test_class_analyzer.setup_variables, None, None,
|
||||
SYSTEM_TRAINER_BUILDER_FINCTION_NAME,
|
||||
SYSTEM_TRAINER_BUILDER_PARAMETER_NAME
|
||||
) # direct call the build_trainer
|
||||
if trainer_name is not None:
|
||||
tested_trainers.append(trainer_name)
|
||||
|
||||
if len(tested_trainers
|
||||
) == 0: # suppose no builder call is direct construct.
|
||||
trainer_name = get_class_constructor(
|
||||
test_method, modified_register_modules, 'TRAINERS')
|
||||
if trainer_name is not None:
|
||||
tested_trainers.append(trainer_name)
|
||||
|
||||
return tested_trainers
|
||||
|
||||
|
||||
def analysis_pipeline_test_suite(test_file, modified_register_modules):
|
||||
tested_tasks = []
|
||||
with open(test_file, 'rb') as tsf:
|
||||
src = tsf.read()
|
||||
# get test file global function and test class
|
||||
test_suite_root = ast.parse(src, test_file)
|
||||
test_suite_analyzer = AnalysisTestFile(
|
||||
test_file, SYSTEM_PIPELINE_BUILDER_FUNCTION_NAME)
|
||||
test_suite_analyzer.visit(test_suite_root)
|
||||
|
||||
for test_class in test_suite_analyzer.test_classes:
|
||||
test_class_analyzer = AnalysisTestClass(
|
||||
test_class, SYSTEM_PIPELINE_BUILDER_FUNCTION_NAME)
|
||||
test_class_analyzer.visit(test_class)
|
||||
for test_method in test_class_analyzer.test_methods:
|
||||
for idx, custom_global_builder in enumerate(
|
||||
test_suite_analyzer.custom_global_builders
|
||||
): # custom test method is global method
|
||||
task_name = get_builder_parameter_value(
|
||||
test_method, test_class_analyzer.setup_variables,
|
||||
custom_global_builder,
|
||||
test_suite_analyzer.custom_global_builder_calls[idx],
|
||||
SYSTEM_PIPELINE_BUILDER_FUNCTION_NAME,
|
||||
SYSTEM_PIPELINE_BUILDER_PARAMETER_NAME)
|
||||
if task_name is not None:
|
||||
tested_tasks.append(task_name)
|
||||
for idx, custom_class_method_builder in enumerate(
|
||||
test_class_analyzer.custom_class_method_builders
|
||||
): # custom class method builder.
|
||||
task_name = get_builder_parameter_value(
|
||||
test_method, test_class_analyzer.setup_variables,
|
||||
custom_class_method_builder,
|
||||
test_class_analyzer.custom_class_method_builder_calls[idx],
|
||||
SYSTEM_PIPELINE_BUILDER_FUNCTION_NAME,
|
||||
SYSTEM_PIPELINE_BUILDER_PARAMETER_NAME)
|
||||
if task_name is not None:
|
||||
tested_tasks.append(task_name)
|
||||
|
||||
task_name = get_builder_parameter_value(
|
||||
test_method, test_class_analyzer.setup_variables, None, None,
|
||||
SYSTEM_PIPELINE_BUILDER_FUNCTION_NAME,
|
||||
SYSTEM_PIPELINE_BUILDER_PARAMETER_NAME
|
||||
) # direct call the build_trainer
|
||||
if task_name is not None:
|
||||
tested_tasks.append(task_name)
|
||||
|
||||
if len(tested_tasks
|
||||
) == 0: # suppose no builder call is direct construct.
|
||||
task_name = get_class_constructor(test_method,
|
||||
modified_register_modules,
|
||||
'PIPELINES')
|
||||
if task_name is not None:
|
||||
tested_tasks.append(task_name)
|
||||
|
||||
return tested_tasks
|
||||
|
||||
|
||||
def get_pipelines_trainers_test_info(register_modules):
|
||||
all_trainer_cases = [
|
||||
os.path.join(dp, f) for dp, dn, filenames in os.walk(
|
||||
os.path.join(os.getcwd(), 'tests', 'trainers')) for f in filenames
|
||||
if os.path.splitext(f)[1] == '.py'
|
||||
]
|
||||
trainer_test_info = {}
|
||||
for test_file in all_trainer_cases:
|
||||
tested_trainers = analysis_trainer_test_suite(test_file,
|
||||
register_modules)
|
||||
if len(tested_trainers) == 0:
|
||||
logger.warn('test_suite: %s has no trainer name' % test_file)
|
||||
else:
|
||||
tested_trainers = list(set(tested_trainers))
|
||||
for trainer_name in tested_trainers:
|
||||
if trainer_name not in trainer_test_info:
|
||||
trainer_test_info[trainer_name] = []
|
||||
trainer_test_info[trainer_name].append(test_file)
|
||||
|
||||
pipeline_test_info = {}
|
||||
all_pipeline_cases = [
|
||||
os.path.join(dp, f) for dp, dn, filenames in os.walk(
|
||||
os.path.join(os.getcwd(), 'tests', 'pipelines')) for f in filenames
|
||||
if os.path.splitext(f)[1] == '.py'
|
||||
]
|
||||
for test_file in all_pipeline_cases:
|
||||
tested_pipelines = analysis_pipeline_test_suite(
|
||||
test_file, register_modules)
|
||||
if len(tested_pipelines) == 0:
|
||||
logger.warn('test_suite: %s has no pipeline task' % test_file)
|
||||
else:
|
||||
tested_pipelines = list(set(tested_pipelines))
|
||||
for pipeline_task in tested_pipelines:
|
||||
if pipeline_task not in pipeline_test_info:
|
||||
pipeline_test_info[pipeline_task] = []
|
||||
pipeline_test_info[pipeline_task].append(test_file)
|
||||
return pipeline_test_info, trainer_test_info
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_file = 'tests/pipelines/test_action_detection.py'
|
||||
tasks = analysis_pipeline_test_suite(test_file, None)
|
||||
|
||||
print(tasks)
|
||||
294
tests/utils/source_file_analyzer.py
Normal file
294
tests/utils/source_file_analyzer.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from __future__ import print_function
|
||||
import ast
|
||||
import importlib.util
|
||||
import os
|
||||
import pkgutil
|
||||
import site
|
||||
import sys
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def is_relative_import(path):
|
||||
# from .x import y or from ..x import y
|
||||
return path.startswith('.')
|
||||
|
||||
|
||||
def resolve_import(module_name):
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
return spec and spec.origin
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def convert_to_path(name):
|
||||
if name.startswith('.'):
|
||||
remainder = name.lstrip('.')
|
||||
dot_count = (len(name) - len(remainder))
|
||||
prefix = '../' * (dot_count - 1)
|
||||
else:
|
||||
remainder = name
|
||||
dot_count = 0
|
||||
prefix = ''
|
||||
filename = prefix + os.path.join(*remainder.split('.'))
|
||||
return filename
|
||||
|
||||
|
||||
def resolve_relative_import(source_file_path, module_name):
|
||||
current_package = os.path.dirname(source_file_path).replace('/', '.')
|
||||
absolute_name = importlib.util.resolve_name(module_name,
|
||||
current_package) # get
|
||||
return resolve_absolute_import(absolute_name)
|
||||
|
||||
|
||||
def onerror(name):
|
||||
logger.error('Importing module %s error!' % name)
|
||||
|
||||
|
||||
def resolve_absolute_import(module_name):
|
||||
module_file_path = resolve_import(module_name)
|
||||
if module_file_path is None:
|
||||
# find from base module.
|
||||
parent_module, sub_module = module_name.rsplit('.', 1)
|
||||
if parent_module in sys.modules:
|
||||
if hasattr(sys.modules[parent_module], '_import_structure'):
|
||||
import_structure = sys.modules[parent_module]._import_structure
|
||||
for k, v in import_structure.items():
|
||||
if sub_module in v:
|
||||
parent_module = parent_module + '.' + k
|
||||
break
|
||||
module_file_path = resolve_absolute_import(parent_module)
|
||||
# the parent_module is a package, we need find the module_name's file
|
||||
if os.path.basename(module_file_path) == '__init__.py' and \
|
||||
(os.path.relpath(module_file_path, site.getsitepackages()[0]) != 'modelscope/__init__.py'
|
||||
or os.path.relpath(module_file_path, os.getcwd()) != 'modelscope/__init__.py'):
|
||||
for _, sub_module_name, _ in pkgutil.walk_packages(
|
||||
[os.path.dirname(module_file_path)],
|
||||
parent_module + '.',
|
||||
onerror=onerror):
|
||||
try:
|
||||
module_ = importlib.import_module(sub_module_name)
|
||||
for k, v in module_.__dict__.items():
|
||||
if k == sub_module and v.__module__ == module_.__name__:
|
||||
module_file_path = module_.__file__
|
||||
break
|
||||
except ModuleNotFoundError as e:
|
||||
logger.warn(
|
||||
'Import error in %s, ModuleNotFoundError: %s' %
|
||||
(sub_module_name, e))
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warn('Import error in %s, Exception: %s' %
|
||||
(sub_module_name, e))
|
||||
continue
|
||||
else:
|
||||
return module_file_path
|
||||
else:
|
||||
module_file_path = resolve_absolute_import(parent_module)
|
||||
return module_file_path
|
||||
|
||||
|
||||
class AnalysisSourceFileImports(ast.NodeVisitor):
|
||||
"""Analysis source file imports
|
||||
List imports of the modelscope.
|
||||
"""
|
||||
|
||||
def __init__(self, source_file_path) -> None:
|
||||
super().__init__()
|
||||
self.imports = []
|
||||
self.source_file_path = source_file_path
|
||||
|
||||
def visit_Import(self, node):
|
||||
"""Processing import x,y,z or import os.path as osp"""
|
||||
for alias in node.names:
|
||||
if alias.name.startswith('modelscope'):
|
||||
file_path = resolve_absolute_import(alias.name)
|
||||
if file_path.startswith(site.getsitepackages()[0]):
|
||||
self.imports.append(
|
||||
os.path.relpath(file_path,
|
||||
site.getsitepackages()[0]))
|
||||
else:
|
||||
self.imports.append(
|
||||
os.path.relpath(file_path, os.getcwd()))
|
||||
|
||||
def visit_ImportFrom(self, node):
|
||||
# level 0 absolute import such as from os.path import join
|
||||
# level 1 from .x import y
|
||||
# level 2 from ..x import y
|
||||
module_name = '.' * node.level + (node.module or '')
|
||||
for alias in node.names:
|
||||
if alias.name == '*': # from x import *
|
||||
if is_relative_import(module_name):
|
||||
# resolve model path.
|
||||
file_path = resolve_relative_import(
|
||||
self.source_file_path, module_name)
|
||||
elif module_name.startswith('modelscope'):
|
||||
file_path = resolve_absolute_import(module_name)
|
||||
else:
|
||||
file_path = None # ignore other package.
|
||||
else:
|
||||
if not module_name.endswith('.'):
|
||||
module_name = module_name + '.'
|
||||
name = module_name + alias.name
|
||||
if is_relative_import(name):
|
||||
# resolve model path.
|
||||
file_path = resolve_relative_import(
|
||||
self.source_file_path, name)
|
||||
elif name.startswith('modelscope'):
|
||||
file_path = resolve_absolute_import(name)
|
||||
else:
|
||||
file_path = None # ignore other package.
|
||||
|
||||
if file_path is not None:
|
||||
if file_path.startswith(site.getsitepackages()[0]):
|
||||
self.imports.append(
|
||||
os.path.relpath(file_path,
|
||||
site.getsitepackages()[0]))
|
||||
else:
|
||||
self.imports.append(
|
||||
os.path.relpath(file_path, os.getcwd()))
|
||||
|
||||
|
||||
class AnalysisSourceFileRegisterModules(ast.NodeVisitor):
|
||||
"""Get register_module call of the python source file.
|
||||
|
||||
|
||||
Args:
|
||||
ast (NodeVisitor): The ast node.
|
||||
|
||||
Examples:
|
||||
with open(source_file_path, "rb") as f:
|
||||
src = f.read()
|
||||
analyzer = AnalysisSourceFileRegisterModules(source_file_path)
|
||||
analyzer.visit(ast.parse(src, filename=source_file_path))
|
||||
"""
|
||||
|
||||
def __init__(self, source_file_path) -> None:
|
||||
super().__init__()
|
||||
self.source_file_path = source_file_path
|
||||
self.register_modules = []
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef):
|
||||
if len(node.decorator_list) > 0:
|
||||
for dec in node.decorator_list:
|
||||
if isinstance(dec, ast.Call):
|
||||
target_name = ''
|
||||
module_name_param = ''
|
||||
task_param = ''
|
||||
if isinstance(dec.func, ast.Attribute
|
||||
) and dec.func.attr == 'register_module':
|
||||
target_name = dec.func.value.id # MODELS
|
||||
if len(dec.args) > 0:
|
||||
if isinstance(dec.args[0], ast.Attribute):
|
||||
task_param = dec.args[0].attr
|
||||
elif isinstance(dec.args[0], ast.Constant):
|
||||
task_param = dec.args[0].value
|
||||
if len(dec.keywords) > 0:
|
||||
for kw in dec.keywords:
|
||||
if kw.arg == 'module_name':
|
||||
if isinstance(kw.value, ast.Str):
|
||||
module_name_param = kw.value.s
|
||||
else:
|
||||
module_name_param = kw.value.attr
|
||||
elif kw.arg == 'group_key':
|
||||
if isinstance(kw.value, ast.Str):
|
||||
task_param = kw.value.s
|
||||
elif isinstance(kw.value, ast.Name):
|
||||
task_param = kw.value.id
|
||||
else:
|
||||
task_param = kw.value.attr
|
||||
if task_param == '' and module_name_param == '':
|
||||
logger.warn(
|
||||
'File %s %s.register_module has no parameters'
|
||||
% (self.source_file_path, target_name))
|
||||
continue
|
||||
if target_name == 'PIPELINES' and task_param == '':
|
||||
logger.warn(
|
||||
'File %s %s.register_module has no task_param'
|
||||
% (self.source_file_path, target_name))
|
||||
self.register_modules.append(
|
||||
(target_name, task_param, module_name_param,
|
||||
node.name)) # PIPELINES, task, module, class_name
|
||||
|
||||
|
||||
def get_imported_files(file_path):
|
||||
"""Get file dependencies.
|
||||
"""
|
||||
print('Getting %s imports' % file_path)
|
||||
if os.path.isabs(file_path):
|
||||
file_path = os.path.relpath(file_path, os.getcwd())
|
||||
with open(file_path, 'rb') as f:
|
||||
src = f.read()
|
||||
analyzer = AnalysisSourceFileImports(file_path)
|
||||
analyzer.visit(ast.parse(src, filename=file_path))
|
||||
return list(set(analyzer.imports))
|
||||
|
||||
|
||||
def path_to_module_name(file_path):
|
||||
if os.path.isabs(file_path):
|
||||
file_path = os.path.relpath(file_path, os.getcwd())
|
||||
module_name = os.path.dirname(file_path).replace('/', '.')
|
||||
return module_name
|
||||
|
||||
|
||||
def get_file_register_modules(file_path):
|
||||
logger.info('Get file: %s register_module' % file_path)
|
||||
with open(file_path, 'rb') as f:
|
||||
src = f.read()
|
||||
analyzer = AnalysisSourceFileRegisterModules(file_path)
|
||||
analyzer.visit(ast.parse(src, filename=file_path))
|
||||
return analyzer.register_modules
|
||||
|
||||
|
||||
def get_import_map():
|
||||
all_files = [
|
||||
os.path.join(dp, f) for dp, dn, filenames in os.walk(
|
||||
os.path.join(os.getcwd(), 'modelscope')) for f in filenames
|
||||
if os.path.splitext(f)[1] == '.py'
|
||||
]
|
||||
import_map = {}
|
||||
for f in all_files:
|
||||
files = get_imported_files(f)
|
||||
import_map[os.path.relpath(f, os.getcwd())] = files
|
||||
|
||||
return import_map
|
||||
|
||||
|
||||
def get_reverse_import_map():
|
||||
all_files = [
|
||||
os.path.join(dp, f) for dp, dn, filenames in os.walk(
|
||||
os.path.join(os.getcwd(), 'modelscope')) for f in filenames
|
||||
if os.path.splitext(f)[1] == '.py'
|
||||
]
|
||||
import_map = get_import_map()
|
||||
|
||||
reverse_depend_map = {}
|
||||
for f in all_files:
|
||||
depend_by = []
|
||||
for k, v in import_map.items():
|
||||
if f in v and f != k:
|
||||
depend_by.append(k)
|
||||
reverse_depend_map[f] = depend_by
|
||||
|
||||
return reverse_depend_map, import_map
|
||||
|
||||
|
||||
def get_all_register_modules():
|
||||
all_files = [
|
||||
os.path.join(dp, f) for dp, dn, filenames in os.walk(
|
||||
os.path.join(os.getcwd(), 'modelscope')) for f in filenames
|
||||
if os.path.splitext(f)[1] == '.py'
|
||||
]
|
||||
all_register_modules = []
|
||||
for f in all_files:
|
||||
all_register_modules.extend(get_file_register_modules(f))
|
||||
return all_register_modules
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
Reference in New Issue
Block a user