[to #46993990]feat: run ci cases base on code diff to reduct ci test time

This commit is contained in:
mulin.lyh
2023-02-06 08:00:19 +00:00
parent db2f203e5d
commit e54694690f
19 changed files with 1362 additions and 93 deletions

View File

@@ -28,7 +28,7 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
awk -F: '/^[^#]/ { print $1 }' requirements/science.txt | xargs -n 1 pip install -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
# test with install
python setup.py install
pip install .
else
echo "Running case in release image, run case directly!"
fi

View File

@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.constant import Fields
class Models(object):
@@ -12,7 +13,6 @@ class Models(object):
# tinynas models
tinynas_detection = 'tinynas-detection'
tinynas_damoyolo = 'tinynas-damoyolo'
# vision models
detection = 'detection'
image_restoration = 'image-restoration'
@@ -397,26 +397,7 @@ class Pipelines(object):
protein_structure = 'unifold-protein-structure'
class Trainers(object):
""" Names for different trainer.
Holds the standard trainer name to use for identifying different trainer.
This should be used to register trainers.
For a general Trainer, you can use EpochBasedTrainer.
For a model specific Trainer, you can use ${ModelName}-${Task}-trainer.
"""
default = 'trainer'
easycv = 'easycv'
tinynas_damoyolo = 'tinynas-damoyolo'
# multi-modal trainers
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
ofa = 'ofa'
mplug = 'mplug'
mgeo_ranking_trainer = 'mgeo-ranking-trainer'
class CVTrainers(object):
# cv trainers
image_instance_segmentation = 'image-instance-segmentation'
image_portrait_enhancement = 'image-portrait-enhancement'
@@ -430,6 +411,8 @@ class Trainers(object):
image_classification = 'image-classification'
image_fewshot_detection = 'image-fewshot-detection'
class NLPTrainers(object):
# nlp trainers
bert_sentiment_analysis = 'bert-sentiment-analysis'
dialog_modeling_trainer = 'dialog-modeling-trainer'
@@ -444,7 +427,15 @@ class Trainers(object):
gpt_moe_trainer = 'nlp-gpt-moe-trainer'
table_question_answering_trainer = 'table-question-answering-trainer'
# audio trainers
class MultiModalTrainers(object):
clip_multi_modal_embedding = 'clip-multi-modal-embedding'
ofa = 'ofa'
mplug = 'mplug'
mgeo_ranking_trainer = 'mgeo-ranking-trainer'
class AudioTrainers(object):
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
@@ -453,6 +444,45 @@ class Trainers(object):
speech_separation = 'speech-separation'
class Trainers(CVTrainers, NLPTrainers, MultiModalTrainers, AudioTrainers):
""" Names for different trainer.
Holds the standard trainer name to use for identifying different trainer.
This should be used to register trainers.
For a general Trainer, you can use EpochBasedTrainer.
For a model specific Trainer, you can use ${ModelName}-${Task}-trainer.
"""
default = 'trainer'
easycv = 'easycv'
tinynas_damoyolo = 'tinynas-damoyolo'
@staticmethod
def get_trainer_domain(attribute_or_value):
if attribute_or_value in vars(
CVTrainers) or attribute_or_value in vars(CVTrainers).values():
return Fields.cv
elif attribute_or_value in vars(
NLPTrainers) or attribute_or_value in vars(
NLPTrainers).values():
return Fields.nlp
elif attribute_or_value in vars(
AudioTrainers) or attribute_or_value in vars(
AudioTrainers).values():
return Fields.audio
elif attribute_or_value in vars(
MultiModalTrainers) or attribute_or_value in vars(
MultiModalTrainers).values():
return Fields.multi_modal
elif attribute_or_value == Trainers.default:
return Trainers.default
elif attribute_or_value == Trainers.easycv:
return Trainers.easycv
else:
return 'unknown'
class Preprocessors(object):
""" Names for different preprocessor.

View File

@@ -6,68 +6,75 @@ import os
import json
import nltk
nltk.download('punkt')
class NLTKSegmenter:
def __init(self):
pass
download_nltk()
@staticmethod
def segment_string(article):
return nltk.tokenize.sent_tokenize(article)
wiki_path = 'data/extracted'
output_path = 'formatted/wiki-key.txt'
segmenter = NLTKSegmenter()
with open(output_path, 'w') as output:
for dirname in glob.glob(os.path.join(wiki_path, '*'), recursive=False):
for filename in glob.glob(
os.path.join(dirname, 'wiki_*'), recursive=True):
print(filename)
article_lines = []
article_open = False
with open(
filename, mode='r', newline='\n',
encoding='utf-8') as file:
for line in file:
line = line.rstrip()
if '<doc id=' in line:
article_open = True
elif '</doc>' in line:
key_sentences, contents = [], []
key, content = None, []
for sentences in article_lines[1:]:
if len(sentences) > 1:
if key:
if len(content) > 0 or len(contents) == 0:
key_sentences.append(key)
contents.append(content)
def download_nltk():
nltk.download('punkt')
wiki_path = 'data/extracted'
output_path = 'formatted/wiki-key.txt'
segmenter = NLTKSegmenter()
with open(output_path, 'w') as output:
for dirname in glob.glob(
os.path.join(wiki_path, '*'), recursive=False):
for filename in glob.glob(
os.path.join(dirname, 'wiki_*'), recursive=True):
print(filename)
article_lines = []
article_open = False
with open(
filename, mode='r', newline='\n',
encoding='utf-8') as file:
for line in file:
line = line.rstrip()
if '<doc id=' in line:
article_open = True
elif '</doc>' in line:
key_sentences, contents = [], []
key, content = None, []
for sentences in article_lines[1:]:
if len(sentences) > 1:
if key:
if len(content) > 0 or len(
contents) == 0:
key_sentences.append(key)
contents.append(content)
else:
contents[-1].append(key)
key, content = None, []
key_sentences.append(sentences[0])
contents.append(sentences[1:])
elif len(sentences) > 0:
if key:
content.append(sentences[0])
else:
contents[-1].append(key)
key, content = None, []
key_sentences.append(sentences[0])
contents.append(sentences[1:])
elif len(sentences) > 0:
if key:
content.append(sentences[0])
key = sentences[0]
if key:
if len(content) > 0 or len(contents) == 0:
key_sentences.append(key)
contents.append(content)
else:
key = sentences[0]
if key:
if len(content) > 0 or len(contents) == 0:
key_sentences.append(key)
contents.append(content)
else:
contents[-1].append(key)
contents = [' '.join(content) for content in contents]
article = {'key': key_sentences, 'content': contents}
output.write(json.dumps(article))
output.write('\n')
article_open = False
article_lines = []
else:
if article_open and line:
sentences = segmenter.segment_string(line)
article_lines.append(sentences)
contents[-1].append(key)
contents = [
' '.join(content) for content in contents
]
article = {
'key': key_sentences,
'content': contents
}
output.write(json.dumps(article))
output.write('\n')
article_open = False
article_lines = []
else:
if article_open and line:
sentences = segmenter.segment_string(line)
article_lines.append(sentences)

View File

@@ -37,7 +37,7 @@ class SpaceModelBase(nn.Module):
return
def _create_parameters(self):
""" Create model's paramters. """
""" Create model's parameters. """
raise NotImplementedError
def _forward(self, inputs, is_training, with_label):

View File

@@ -107,7 +107,7 @@ class UnifiedTransformer(SpaceModelBase):
return
def _create_parameters(self):
""" Create model's paramters. """
""" Create model's parameters. """
sequence_mask = np.tri(
self.num_pos_embeddings,
self.num_pos_embeddings,

View File

@@ -199,7 +199,7 @@ class BaseTaskModel(TorchModel, ABC):
]
if _fast_init:
# retrieve unintialized modules and initialize
# retrieve uninitialized modules and initialize
uninitialized_modules = self.retrieve_modules_from_names(
missing_keys,
prefix=prefix,

View File

@@ -4,11 +4,11 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .easycv_detection import DetDataset, DetImagesMixDataset
from .detection_dataset import DetDataset, DetImagesMixDataset
else:
_import_structure = {
'easycv_detection': ['DetDataset', 'DetImagesMixDataset']
'detection_dataset': ['DetDataset', 'DetImagesMixDataset']
}
import sys

View File

@@ -68,7 +68,7 @@ def _nn_variable(name, shape, init_method, collection=None, **kwargs):
shape: variable shape
init_method: 'zero', 'kaiming', 'xavier', or (mean, std)
collection: if not none, add variable to this collection
kwargs: extra paramters passed to tf.get_variable
kwargs: extra parameters passed to tf.get_variable
RETURN
var: a new or existing variable
"""

View File

@@ -55,7 +55,7 @@ class BaseTrainer(ABC):
""" Train (and evaluate) process
Train process should be implemented for specific task or
model, releated paramters have been intialized in
model, related parameters have been initialized in
``BaseTrainer.__init__`` and should be used in this function
"""
pass
@@ -66,7 +66,7 @@ class BaseTrainer(ABC):
""" Evaluation process
Evaluation process should be implemented for specific task or
model, releated paramters have been intialized in
model, related parameters have been initialized in
``BaseTrainer.__init__`` and should be used in this function
"""
pass
@@ -87,7 +87,7 @@ class DummyTrainer(BaseTrainer):
""" Train (and evaluate) process
Train process should be implemented for specific task or
model, releated paramters have been intialized in
model, related parameters have been initialized in
``BaseTrainer.__init__`` and should be used in this function
"""
cfg = self.cfg.train
@@ -100,7 +100,7 @@ class DummyTrainer(BaseTrainer):
""" Evaluation process
Evaluation process should be implemented for specific task or
model, releated paramters have been intialized in
model, related parameters have been initialized in
``BaseTrainer.__init__`` and should be used in this function
"""
cfg = self.cfg.evaluation

View File

@@ -38,7 +38,7 @@ class TableQuestionAnsweringTrainer(BaseTrainer):
num_training_steps,
last_epoch=-1):
"""
set scheduler
set scheduler.
"""
def lr_lambda(current_step: int):

View File

@@ -107,7 +107,7 @@ class CliArgumentParser(ArgumentParser):
Args:
arg_dict (dict of `ArgAttr` or list of them): dict or list of dict which defines different
paramters for training.
parameters for training.
"""
def __init__(self, arg_dict: Union[Dict[str, ArgAttr],

View File

@@ -2,7 +2,7 @@
# docstyle-ignore
AUDIO_IMPORT_ERROR = """
Audio model import failed: {0}, if you want to use audio releated function, please execute
Audio model import failed: {0}, if you want to use audio related function, please execute
`pip install modelscope[audio] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html`
"""

View File

@@ -9,7 +9,7 @@ numpy
oss2
Pillow>=6.2.0
# pyarrow 9.0.0 introduced event_loop core dump
pyarrow>=6.0.0,<9.0.0
pyarrow>=6.0.0,!=9.0.0
pyyaml
requests
scipy

View File

@@ -49,7 +49,7 @@ class HubOperationTest(unittest.TestCase):
repo.tag_and_push(self.revision, 'Test revision')
def test_model_repo_creation(self):
# change to proper model names before use
# change to proper model names before use.
try:
info = self.api.get_model(model_id=self.model_id)
assert info['Name'] == self.model_name

View File

@@ -3,6 +3,7 @@
import argparse
import datetime
import importlib
import math
import multiprocessing
import os
@@ -362,10 +363,43 @@ def run_non_parallelizable_test_suites(suites, result_dir):
run_command_with_popen(cmd)
# Selected cases:
def get_selected_cases():
cmd = ['python', '-u', 'tests/run_analysis.py']
selected_cases = []
with subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
bufsize=1,
encoding='utf8') as sub_process:
for line in iter(sub_process.stdout.readline, ''):
sys.stdout.write(line)
if line.startswith('Selected cases:'):
line = line.replace('Selected cases:', '').strip()
selected_cases = line.split(',')
return selected_cases
def run_in_subprocess(args):
# only case args.isolated_cases run in subporcess, all other run in a subprocess
test_suite_files = gather_test_suites_files(
os.path.abspath(args.test_dir), args.pattern)
if not args.no_diff: # run based on git diff
try:
test_suite_files = get_selected_cases()
logger.info('Tests suite to run: ')
for f in test_suite_files:
logger.info(f)
except Exception:
logger.error('Get test suite based diff exception!')
test_suite_files = gather_test_suites_files(
os.path.abspath(args.test_dir), args.pattern)
if len(test_suite_files) == 0:
logger.error('Get no test suite based on diff, run all the cases.')
test_suite_files = gather_test_suites_files(
os.path.abspath(args.test_dir), args.pattern)
else:
test_suite_files = gather_test_suites_files(
os.path.abspath(args.test_dir), args.pattern)
non_parallelizable_suites = [
'test_download_dataset.py',
@@ -579,11 +613,18 @@ if __name__ == '__main__':
type=int,
help='Set case parallels, default single process, set with gpu number.'
)
parser.add_argument(
'--no-diff',
action='store_true',
help=
'Default running case based on git diff(with master), disable with --no-diff)'
)
parser.add_argument(
'--suites',
nargs='*',
help='Run specified test suites(test suite files list split by space)')
args = parser.parse_args()
print(args)
set_test_level(args.level)
os.environ['REGRESSION_BASELINE'] = '1'
logger.info(f'TEST LEVEL: {test_level()}')

337
tests/run_analysis.py Normal file
View File

@@ -0,0 +1,337 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import subprocess
import sys
from fnmatch import fnmatch
from trainers.model_trainer_map import model_trainer_map
from utils.case_file_analyzer import get_pipelines_trainers_test_info
from utils.source_file_analyzer import (get_all_register_modules,
get_file_register_modules,
get_import_map)
from modelscope.hub.api import HubApi
from modelscope.hub.errors import NotExistError
from modelscope.hub.file_download import model_file_download
from modelscope.hub.utils.utils import get_cache_dir
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
logger = get_logger()
def get_models_info(groups: list) -> dict:
models = []
api = HubApi()
for group in groups:
page = 1
while True:
query_result = api.list_models(group, page, 100)
models.extend(query_result['Models'])
if len(models) >= query_result['TotalCount']:
break
page += 1
cache_root = get_cache_dir()
models_info = {} # key model id, value model info
for model_info in models:
model_id = '%s/%s' % (group, model_info['Name'])
configuration_file = os.path.join(cache_root, model_id,
ModelFile.CONFIGURATION)
if not os.path.exists(configuration_file):
model_revisions = api.list_model_revisions(model_id=model_id)
if len(model_revisions) == 0:
logger.warn('Model: %s has no revision' % model_id)
continue
# get latest revision
try:
configuration_file = model_file_download(
model_id=model_id,
file_path=ModelFile.CONFIGURATION,
revision=model_revisions[0])
except NotExistError:
logger.warn('Model: %s has no configuration file %s' %
(model_id, ModelFile.CONFIGURATION))
continue
cfg = Config.from_file(configuration_file)
model_info = {}
model_info['framework'] = cfg.safe_get('framework')
model_info['task'] = cfg.safe_get('task')
model_info['model_type'] = cfg.safe_get('model.type')
model_info['pipeline_type'] = cfg.safe_get('pipeline.type')
model_info['preprocessor_type'] = cfg.safe_get('preprocessor.type')
train_hooks_type = []
train_hooks = cfg.safe_get('train.hooks')
if train_hooks is not None:
for train_hook in train_hooks:
train_hooks_type.append(train_hook.type)
model_info['train_hooks_type'] = train_hooks_type
model_info['datasets'] = cfg.safe_get('dataset')
model_info['evaluation_metics'] = cfg.safe_get('evaluation.metrics',
[]) # metrics name list
"""
print('framework: %s, task: %s, model_type: %s, pipeline_type: %s, \
preprocessor_type: %s, train_hooks_type: %s, \
dataset: %s, evaluation_metics: %s'%(
framework, task, model_type, pipeline_type,
preprocessor_type, ','.join(train_hooks_type),
datasets, evaluation_metics))
"""
models_info[model_id] = model_info
return models_info
def gather_test_suites_files_full_path(test_dir='./tests',
pattern='test_*.py'):
case_file_list = []
for dirpath, dirnames, filenames in os.walk(test_dir):
for file in filenames:
if fnmatch(file, pattern):
case_file_list.append(os.path.join(dirpath, file))
return case_file_list
def run_command_get_output(cmd):
response = subprocess.run(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
try:
response.check_returncode()
output = response.stdout.decode('utf8')
return output
except subprocess.CalledProcessError as error:
logger.error(
'stdout: %s, stderr: %s' %
(response.stdout.decode('utf8'), error.stderr.decode('utf8')))
return None
def get_modified_files():
cmd = ['git', 'diff', '--name-only', 'origin/master...']
cmd_output = run_command_get_output(cmd)
logger.info('Modified files: ')
logger.info(cmd_output)
return cmd_output.splitlines()
def analysis_diff():
"""Get modified files and their imported files modified modules
"""
modified_register_modules = []
modified_cases = []
modified_files_imported_by = []
modified_files = get_modified_files()
logger.info('Modified files:\n %s' % '\n'.join(modified_files))
logger.info('Starting get import map')
import_map = get_import_map()
logger.info('Finished get import map')
for modified_file in modified_files:
if modified_file.startswith('./modelscope') or \
modified_file.startswith('modelscope'): # is source file
for k, v in import_map.items():
if modified_file in v and modified_file != k:
modified_files_imported_by.append(k)
logger.info('There are affected files: %s'
% len(modified_files_imported_by))
for f in modified_files_imported_by:
logger.info(f)
modified_files.extend(modified_files_imported_by) # add imported by file
for modified_file in modified_files:
if modified_file.startswith('./modelscope') or \
modified_file.startswith('modelscope'):
modified_register_modules.extend(
get_file_register_modules(modified_file))
elif ((modified_file.startswith('./tests')
or modified_file.startswith('tests'))
and os.path.basename(modified_file).startswith('test_')):
modified_cases.append(modified_file)
return modified_register_modules, modified_cases
def split_test_suites():
test_suite_full_paths = gather_test_suites_files_full_path()
pipeline_test_suites = []
trainer_test_suites = []
other_test_suites = []
for test_suite in test_suite_full_paths:
if test_suite.find('tests/trainers') != -1:
trainer_test_suites.append(test_suite)
elif test_suite.find('tests/pipelines') != -1:
pipeline_test_suites.append(test_suite)
else:
other_test_suites.append(test_suite)
return pipeline_test_suites, trainer_test_suites, other_test_suites
def get_test_suites_to_run():
affected_register_modules, modified_cases = analysis_diff()
# affected_register_modules list of modified file and dependent file's register_module.
# ("MODULES|PIPELINES|TRAINERS|...", '', '', model_class_name)
# modified_cases, modified case file.
all_register_modules = get_all_register_modules()
_, _, other_test_suites = split_test_suites()
task_pipeline_test_suite_map, trainer_test_suite_map = get_pipelines_trainers_test_info(
all_register_modules)
# task_pipeline_test_suite_map key: pipeline task, value: case file path
# trainer_test_suite_map key: trainer_name, value: case file path
models_info = get_models_info(['damo'])
# model_info key: model_id, value: model info such as framework task etc.
affected_pipeline_cases = []
affected_trainer_cases = []
for affected_register_module in affected_register_modules:
# affected_register_module PIPELINE structure
# ["PIPELINES", "acoustic_noise_suppression", "speech_frcrn_ans_cirm_16k", "ANSPipeline"]
# ["PIPELINES", task, pipeline_name, pipeline_class_name]
if affected_register_module[0] == 'PIPELINES':
if affected_register_module[1] in task_pipeline_test_suite_map:
affected_pipeline_cases.extend(
task_pipeline_test_suite_map[affected_register_module[1]])
else:
logger.warn('Pipeline task: %s has no test case!'
% affected_register_module[1])
elif affected_register_module[0] == 'MODELS':
# ["MODELS", "keyword_spotting", "kws_kwsbp", "GenericKeyWordSpotting"],
# ["MODELS", task, model_name, model_class_name]
if affected_register_module[1] in task_pipeline_test_suite_map:
affected_pipeline_cases.extend(
task_pipeline_test_suite_map[affected_register_module[1]])
else:
logger.warn('Pipeline task: %s has no test case!'
% affected_register_module[1])
elif affected_register_module[0] == 'TRAINERS':
# ["TRAINERS", "", "nlp_base_trainer", "NlpEpochBasedTrainer"],
# ["TRAINERS", "", trainer_name, trainer_class_name]
if affected_register_module[2] in trainer_test_suite_map:
affected_trainer_cases.extend(
trainer_test_suite_map[affected_register_module[2]])
else:
logger.warn('Trainer %s his no case' %
(affected_register_module[2]))
elif affected_register_module[0] == 'PREPROCESSORS':
# ["PREPROCESSORS", "cv", "object_detection_scrfd", "SCRFDPreprocessor"]
# ["PREPROCESSORS", domain, preprocessor_name, class_name]
task = model_info['task']
for model_id, model_info in models_info.items():
if model_info['preprocessor_type'] is not None and model_info[
'preprocessor_type'] == affected_register_module[2]:
if task in task_pipeline_test_suite_map:
affected_pipeline_cases.extend(
task_pipeline_test_suite_map[task])
if model_id in model_trainer_map:
affected_trainer_cases.extend(
model_trainer_map[model_id])
elif (affected_register_module[0] == 'HOOKS'
or affected_register_module[0] == 'TASK_DATASETS'):
# ["HOOKS", "", "CheckpointHook", "CheckpointHook"]
# ["HOOKS", "", hook_name, class_name]
# HOOKS, DATASETS modify run all trainer cases
for _, cases in trainer_test_suite_map.items():
affected_trainer_cases.extend(cases)
elif affected_register_module[0] == 'METRICS':
# ["METRICS", "default_group", "accuracy", "AccuracyMetric"]
# ["METRICS", group, metric_name, class_name]
for model_id, model_info in models_info.items():
if affected_register_module[2] in model_info[
'evaluation_metics']:
if model_id in model_trainer_map:
affected_trainer_cases.extend(
model_trainer_map[model_id])
# deduplication
affected_pipeline_cases = list(set(affected_pipeline_cases))
affected_trainer_cases = list(set(affected_trainer_cases))
test_suites_to_run = []
for test_suite in other_test_suites:
test_suites_to_run.append(os.path.basename(test_suite))
for test_suite in affected_pipeline_cases:
test_suites_to_run.append(os.path.basename(test_suite))
for test_suite in affected_trainer_cases:
test_suites_to_run.append(os.path.basename(test_suite))
for modified_case in modified_cases:
if modified_case not in test_suites_to_run:
test_suites_to_run.append(os.path.basename(modified_case))
return test_suites_to_run
def get_files_related_modules(files):
register_modules = []
for single_file in files:
if single_file.startswith('./modelscope') or \
single_file.startswith('modelscope'):
register_modules.extend(get_file_register_modules(single_file))
return register_modules
def get_modules_related_cases(register_modules, task_pipeline_test_suite_map,
trainer_test_suite_map):
affected_pipeline_cases = []
affected_trainer_cases = []
for register_module in register_modules:
if register_module[0] == 'PIPELINES' or \
register_module[0] == 'MODELS':
if register_module[1] in task_pipeline_test_suite_map:
affected_pipeline_cases.extend(
task_pipeline_test_suite_map[register_module[1]])
else:
logger.warn('Pipeline task: %s has no test case!'
% register_module[1])
elif register_module[0] == 'TRAINERS':
if register_module[2] in trainer_test_suite_map:
affected_trainer_cases.extend(
trainer_test_suite_map[register_module[2]])
else:
logger.warn('Trainer %s his no case' % (register_module[2]))
return affected_pipeline_cases, affected_trainer_cases
def get_all_file_test_info():
all_files = [
os.path.relpath(os.path.join(dp, f), os.getcwd())
for dp, dn, filenames in os.walk(
os.path.join(os.getcwd(), 'modelscope')) for f in filenames
if os.path.splitext(f)[1] == '.py'
]
import_map = get_import_map()
all_register_modules = get_all_register_modules()
task_pipeline_test_suite_map, trainer_test_suite_map = get_pipelines_trainers_test_info(
all_register_modules)
reverse_depend_map = {}
for f in all_files:
depend_by = []
for k, v in import_map.items():
if f in v and f != k:
depend_by.append(k)
reverse_depend_map[f] = depend_by
# get cases.
test_info = {}
for f in all_files:
file_test_info = {}
file_test_info['imports'] = import_map[f]
file_test_info['imported_by'] = reverse_depend_map[f]
register_modules = get_files_related_modules([f]
+ reverse_depend_map[f])
file_test_info['relate_modules'] = register_modules
affected_pipeline_cases, affected_trainer_cases = get_modules_related_cases(
register_modules, task_pipeline_test_suite_map,
trainer_test_suite_map)
file_test_info['pipeline_cases'] = affected_pipeline_cases
file_test_info['trainer_cases'] = affected_trainer_cases
file_relative_path = os.path.relpath(f, os.getcwd())
test_info[file_relative_path] = file_test_info
with open('./test_relate_info.json', 'w') as f:
import json
json.dump(test_info, f)
if __name__ == '__main__':
test_suites_to_run = get_test_suites_to_run()
msg = ','.join(test_suites_to_run)
print('Selected cases: %s' % msg)

View File

@@ -0,0 +1,145 @@
model_trainer_map = {
'damo/speech_frcrn_ans_cirm_16k':
['tests/trainers/audio/test_ans_trainer.py'],
'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch':
['tests/trainers/audio/test_asr_trainer.py'],
'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya':
['tests/trainers/audio/test_kws_farfield_trainer.py'],
'damo/speech_charctc_kws_phone-xiaoyun':
['tests/trainers/audio/test_kws_nearfield_trainer.py'],
'damo/speech_mossformer_separation_temporal_8k':
['tests/trainers/audio/test_separation_trainer.py'],
'speech_tts/speech_sambert-hifigan_tts_zh-cn_multisp_pretrain_16k':
['tests/trainers/audio/test_tts_trainer.py'],
'damo/cv_mobilenet_face-2d-keypoints_alignment':
['tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py'],
'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody':
['tests/trainers/easycv/test_easycv_trainer_hand_2d_keypoints.py'],
'damo/cv_yolox-pai_hand-detection':
['tests/trainers/easycv/test_easycv_trainer_hand_detection.py'],
'damo/cv_r50_panoptic-segmentation_cocopan':
['tests/trainers/easycv/test_easycv_trainer_panoptic_mask2former.py'],
'damo/cv_segformer-b0_image_semantic-segmentation_coco-stuff164k':
['tests/trainers/easycv/test_segformer.py'],
'damo/cv_resnet_carddetection_scrfd34gkps':
['tests/trainers/test_card_detection_scrfd_trainer.py'],
'damo/multi-modal_clip-vit-base-patch16_zh': [
'tests/trainers/test_clip_trainer.py'
],
'damo/nlp_space_pretrained-dialog-model': [
'tests/trainers/test_dialog_intent_trainer.py'
],
'damo/cv_resnet_facedetection_scrfd10gkps': [
'tests/trainers/test_face_detection_scrfd_trainer.py'
],
'damo/nlp_structbert_faq-question-answering_chinese-base': [
'tests/trainers/test_finetune_faq_question_answering.py'
],
'PAI/nlp_gpt3_text-generation_0.35B_MoE-64': [
'tests/trainers/test_finetune_gpt_moe.py'
],
'damo/nlp_gpt3_text-generation_1.3B': [
'tests/trainers/test_finetune_gpt3.py'
],
'damo/mgeo_backbone_chinese_base': [
'tests/trainers/test_finetune_mgeo.py'
],
'damo/mplug_backbone_base_en': ['tests/trainers/test_finetune_mplug.py'],
'damo/nlp_structbert_backbone_base_std': [
'tests/trainers/test_finetune_sequence_classification.py',
'tests/trainers/test_finetune_token_classification.py'
],
'damo/nlp_palm2.0_text-generation_english-base': [
'tests/trainers/test_finetune_text_generation.py'
],
'damo/nlp_gpt3_text-generation_chinese-base': [
'tests/trainers/test_finetune_text_generation.py'
],
'damo/nlp_palm2.0_text-generation_chinese-base': [
'tests/trainers/test_finetune_text_generation.py'
],
'damo/nlp_corom_passage-ranking_english-base': [
'tests/trainers/test_finetune_text_ranking.py'
],
'damo/nlp_rom_passage-ranking_chinese-base': [
'tests/trainers/test_finetune_text_ranking.py'
],
'damo/cv_nextvit-small_image-classification_Dailylife-labels': [
'tests/trainers/test_general_image_classification_trainer.py'
],
'damo/cv_convnext-base_image-classification_garbage': [
'tests/trainers/test_general_image_classification_trainer.py'
],
'damo/cv_beitv2-base_image-classification_patch16_224_pt1k_ft22k_in1k': [
'tests/trainers/test_general_image_classification_trainer.py'
],
'damo/cv_csrnet_image-color-enhance-models': [
'tests/trainers/test_image_color_enhance_trainer.py'
],
'damo/cv_nafnet_image-deblur_gopro': [
'tests/trainers/test_image_deblur_trainer.py'
],
'damo/cv_resnet101_detection_fewshot-defrcn': [
'tests/trainers/test_image_defrcn_fewshot_trainer.py'
],
'damo/cv_nafnet_image-denoise_sidd': [
'tests/trainers/test_image_denoise_trainer.py'
],
'damo/cv_fft_inpainting_lama': [
'tests/trainers/test_image_inpainting_trainer.py'
],
'damo/cv_swin-b_image-instance-segmentation_coco': [
'tests/trainers/test_image_instance_segmentation_trainer.py'
],
'damo/cv_gpen_image-portrait-enhancement': [
'tests/trainers/test_image_portrait_enhancement_trainer.py'
],
'damo/cv_clip-it_video-summarization_language-guided_en': [
'tests/trainers/test_language_guided_video_summarization_trainer.py'
],
'damo/cv_resnet50-bert_video-scene-segmentation_movienet': [
'tests/trainers/test_movie_scene_segmentation_trainer.py'
],
'damo/ofa_mmspeech_pretrain_base_zh': [
'tests/trainers/test_ofa_mmspeech_trainer.py'
],
'damo/ofa_ocr-recognition_scene_base_zh': [
'tests/trainers/test_ofa_trainer.py'
],
'damo/nlp_plug_text-generation_27B': [
'tests/trainers/test_plug_finetune_text_generation.py'
],
'damo/cv_swin-t_referring_video-object-segmentation': [
'tests/trainers/test_referring_video_object_segmentation_trainer.py'
],
'damo/nlp_convai_text2sql_pretrain_cn': [
'tests/trainers/test_table_question_answering_trainer.py'
],
'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity': [
'tests/trainers/test_team_transfer_trainer.py'
],
'damo/cv_tinynas_object-detection_damoyolo': [
'tests/trainers/test_tinynas_damoyolo_trainer.py'
],
'damo/nlp_structbert_sentence-similarity_chinese-tiny': [
'tests/trainers/test_trainer_with_nlp.py'
],
'damo/nlp_structbert_sentiment-classification_chinese-base': [
'tests/trainers/test_trainer_with_nlp.py'
],
'damo/nlp_structbert_sentence-similarity_chinese-base': [
'tests/trainers/test_trainer_with_nlp.py'
],
'damo/nlp_csanmt_translation_en2zh': [
'tests/trainers/test_translation_trainer.py'
],
'damo/nlp_csanmt_translation_en2fr': [
'tests/trainers/test_translation_trainer.py'
],
'damo/nlp_csanmt_translation_en2es': [
'tests/trainers/test_translation_trainer.py'
],
'damo/cv_googlenet_pgl-video-summarization': [
'tests/trainers/test_video_summarization_trainer.py'
],
}

View File

@@ -0,0 +1,415 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import print_function
import ast
import os
from typing import Any
from modelscope.utils.logger import get_logger
logger = get_logger()
SYSTEM_TRAINER_BUILDER_FINCTION_NAME = 'build_trainer'
SYSTEM_TRAINER_BUILDER_PARAMETER_NAME = 'name'
SYSTEM_PIPELINE_BUILDER_FUNCTION_NAME = 'pipeline'
SYSTEM_PIPELINE_BUILDER_PARAMETER_NAME = 'task'
class AnalysisTestFile(ast.NodeVisitor):
"""Analysis test suite files.
Get global function and test class
Args:
ast (NodeVisitor): The ast node.
Examples:
with open(test_suite_file, "rb") as f:
src = f.read()
analyzer = AnalysisTestFile(test_suite_file)
analyzer.visit(ast.parse(src, filename=test_suite_file))
"""
def __init__(self, test_suite_file, builder_function_name) -> None:
super().__init__()
self.test_classes = []
self.builder_function_name = builder_function_name
self.global_functions = []
self.custom_global_builders = [
] # global trainer builder method(call build_trainer)
self.custom_global_builder_calls = [] # the builder call statement
def visit_ClassDef(self, node) -> bool:
"""Check if the class is a unittest suite.
Args:
node (ast.Node): the ast node
Returns: True if is a test class.
"""
for base in node.bases:
if isinstance(base, ast.Attribute) and base.attr == 'TestCase':
self.test_classes.append(node)
elif isinstance(base, ast.Name) and 'TestCase' in base.id:
self.test_classes.append(node)
def visit_FunctionDef(self, node: ast.FunctionDef):
self.global_functions.append(node)
for statement in ast.walk(node):
if isinstance(statement, ast.Call) and \
isinstance(statement.func, ast.Name):
if statement.func.id == self.builder_function_name:
self.custom_global_builders.append(node)
self.custom_global_builder_calls.append(statement)
class AnalysisTestClass(ast.NodeVisitor):
def __init__(self, test_class_node, builder_function_name) -> None:
super().__init__()
self.test_class_node = test_class_node
self.builder_function_name = builder_function_name
self.setup_variables = {}
self.test_methods = []
self.custom_class_method_builders = [
] # class method trainer builder(call build_trainer)
self.custom_class_method_builder_calls = [
] # the builder call statement
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
if node.name.startswith('setUp'):
for statement in node.body:
if isinstance(statement, ast.Assign):
if len(statement.targets) == 1 and \
isinstance(statement.targets[0], ast.Attribute) and \
isinstance(statement.value, ast.Attribute):
self.setup_variables[str(
statement.targets[0].attr)] = str(
statement.value.attr)
elif node.name.startswith('test_'):
self.test_methods.append(node)
else:
for statement in ast.walk(node):
if isinstance(statement, ast.Call) and \
isinstance(statement.func, ast.Name):
if statement.func.id == self.builder_function_name:
self.custom_class_method_builders.append(node)
self.custom_class_method_builder_calls.append(
statement)
def get_local_arg_value(target_method, args_name):
for statement in target_method.body:
if isinstance(statement, ast.Assign):
for target in statement.targets:
if isinstance(target, ast.Name) and target.id == args_name:
if isinstance(statement.value, ast.Attribute):
return statement.value.attr
elif isinstance(statement.value, ast.Str):
return statement.value.s
return None
def get_custom_builder_parameter_name(args, keywords, builder, builder_call,
builder_arg_name):
# get build_trainer call name argument name.
arg_name = None
if len(builder_call.args) > 0:
if isinstance(builder_call.args[0], ast.Name):
# build_trainer name is a variable
arg_name = builder_call.args[0].id
elif isinstance(builder_call.args[0], ast.Attribute):
# Attribute access, such as Trainers.image_classification_team
return builder_call.args[0].attr
else:
raise Exception('Invalid argument name')
else:
use_default_name = True
for kw in builder_call.keywords:
if kw.arg == builder_arg_name:
use_default_name = False
if isinstance(kw.value, ast.Attribute):
return kw.value.attr
elif isinstance(kw.value,
ast.Name) and kw.arg == builder_arg_name:
arg_name = kw.value.id
else:
raise Exception('Invalid keyword argument')
if use_default_name:
return 'default'
if arg_name is None:
raise Exception('Invalid build_trainer call')
arg_value = get_local_arg_value(builder, arg_name)
if arg_value is not None: # trainer_name is a local variable
return arg_value
# get build_trainer name parameter, if it's passed
default_name = None
arg_idx = 100000
for idx, arg in enumerate(builder.args.args):
if arg.arg == arg_name:
arg_idx = idx
if idx >= len(builder.args.args) - len(builder.args.defaults):
default_name = builder.args.defaults[idx - (
len(builder.args.args) - len(builder.args.defaults))].attr
break
if len(builder.args.args
) > 0 and builder.args.args[0].arg == 'self': # class method
if len(args) > arg_idx - 1: # - self
if isinstance(args[arg_idx - 1], ast.Attribute):
return args[arg_idx - 1].attr
for keyword in keywords:
if keyword.arg == arg_name:
if isinstance(keyword.value, ast.Attribute):
return keyword.value.attr
return default_name
def get_system_builder_parameter_value(builder_call, test_method,
setup_attributes,
builder_parameter_name):
if len(builder_call.args) > 0:
if isinstance(builder_call.args[0], ast.Name):
return get_local_arg_value(test_method, builder_call.args[0].id)
elif isinstance(builder_call.args[0], ast.Attribute):
if builder_call.args[0].attr in setup_attributes:
return setup_attributes[builder_call.args[0].attr]
return builder_call.args[0].attr
elif isinstance(builder_call.args[0], ast.Str): # TODO check py38
return builder_call.args[0].s
for kw in builder_call.keywords:
if kw.arg == builder_parameter_name:
if isinstance(kw.value, ast.Attribute):
if kw.value.attr in setup_attributes:
return setup_attributes[kw.value.attr]
else:
return kw.value.attr
elif isinstance(kw.value,
ast.Name) and kw.arg == builder_parameter_name:
return kw.value.id
return 'default' # use build_trainer default argument.
def get_builder_parameter_value(test_method, setup_variables, builder,
builder_call, system_builder_func_name,
builder_parameter_name):
"""
get target builder parameter name, for tariner we get trainer name, for pipeline we get pipeline task
"""
for node in ast.walk(test_method):
if builder is None: # direct call build_trainer
for node in ast.walk(test_method):
if (isinstance(node, ast.Call)
and isinstance(node.func, ast.Name)
and node.func.id == system_builder_func_name):
return get_system_builder_parameter_value(
node, test_method, setup_variables,
builder_parameter_name)
elif (isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
and node.func.attr == builder.name):
return get_custom_builder_parameter_name(node.args, node.keywords,
builder, builder_call,
builder_parameter_name)
elif (isinstance(node, ast.Expr) and isinstance(node.value, ast.Call)
and isinstance(node.value.func, ast.Name)
and node.value.func.id == builder.name):
return get_custom_builder_parameter_name(node.value.args,
node.value.keywords,
builder, builder_call,
builder_parameter_name)
elif (isinstance(node, ast.Expr) and isinstance(node.value, ast.Call)
and isinstance(node.value.func, ast.Attribute)
and node.value.func.attr == builder.name):
# self.class_method_builder
return get_custom_builder_parameter_name(node.value.args,
node.value.keywords,
builder, builder_call,
builder_parameter_name)
elif isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
for arg in node.value.args:
if isinstance(arg, ast.Name) and arg.id == builder.name:
# self.start(train_func, num_gpus=2, **kwargs)
return get_custom_builder_parameter_name(
None, None, builder, builder_call,
builder_parameter_name)
return None
def get_class_constructor(test_method, modified_register_modules, module_name):
# module_name 'TRAINERS' | 'PIPELINES'
for node in ast.walk(test_method):
if isinstance(node, ast.Assign) and isinstance(node.value, ast.Call):
# trainer = CsanmtTranslationTrainer(model=model_id)
for modified_register_module in modified_register_modules:
if isinstance(node.value.func, ast.Name) and \
node.value.func.id == modified_register_module[3] and \
modified_register_module[0] == module_name:
if module_name == 'TRAINERS':
return modified_register_module[2]
elif module_name == 'PIPELINES':
return modified_register_module[1] # pipeline
return None
def analysis_trainer_test_suite(test_file, modified_register_modules):
tested_trainers = []
with open(test_file, 'rb') as tsf:
src = tsf.read()
# get test file global function and test class
test_suite_root = ast.parse(src, test_file)
test_suite_analyzer = AnalysisTestFile(
test_file, SYSTEM_TRAINER_BUILDER_FINCTION_NAME)
test_suite_analyzer.visit(test_suite_root)
for test_class in test_suite_analyzer.test_classes:
test_class_analyzer = AnalysisTestClass(
test_class, SYSTEM_TRAINER_BUILDER_FINCTION_NAME)
test_class_analyzer.visit(test_class)
for test_method in test_class_analyzer.test_methods:
for idx, custom_global_builder in enumerate(
test_suite_analyzer.custom_global_builders
): # custom test method is global method
trainer_name = get_builder_parameter_value(
test_method, test_class_analyzer.setup_variables,
custom_global_builder,
test_suite_analyzer.custom_global_builder_calls[idx],
SYSTEM_TRAINER_BUILDER_FINCTION_NAME,
SYSTEM_TRAINER_BUILDER_PARAMETER_NAME)
if trainer_name is not None:
tested_trainers.append(trainer_name)
for idx, custom_class_method_builder in enumerate(
test_class_analyzer.custom_class_method_builders
): # custom class method builder.
trainer_name = get_builder_parameter_value(
test_method, test_class_analyzer.setup_variables,
custom_class_method_builder,
test_class_analyzer.custom_class_method_builder_calls[idx],
SYSTEM_TRAINER_BUILDER_FINCTION_NAME,
SYSTEM_TRAINER_BUILDER_PARAMETER_NAME)
if trainer_name is not None:
tested_trainers.append(trainer_name)
trainer_name = get_builder_parameter_value(
test_method, test_class_analyzer.setup_variables, None, None,
SYSTEM_TRAINER_BUILDER_FINCTION_NAME,
SYSTEM_TRAINER_BUILDER_PARAMETER_NAME
) # direct call the build_trainer
if trainer_name is not None:
tested_trainers.append(trainer_name)
if len(tested_trainers
) == 0: # suppose no builder call is direct construct.
trainer_name = get_class_constructor(
test_method, modified_register_modules, 'TRAINERS')
if trainer_name is not None:
tested_trainers.append(trainer_name)
return tested_trainers
def analysis_pipeline_test_suite(test_file, modified_register_modules):
tested_tasks = []
with open(test_file, 'rb') as tsf:
src = tsf.read()
# get test file global function and test class
test_suite_root = ast.parse(src, test_file)
test_suite_analyzer = AnalysisTestFile(
test_file, SYSTEM_PIPELINE_BUILDER_FUNCTION_NAME)
test_suite_analyzer.visit(test_suite_root)
for test_class in test_suite_analyzer.test_classes:
test_class_analyzer = AnalysisTestClass(
test_class, SYSTEM_PIPELINE_BUILDER_FUNCTION_NAME)
test_class_analyzer.visit(test_class)
for test_method in test_class_analyzer.test_methods:
for idx, custom_global_builder in enumerate(
test_suite_analyzer.custom_global_builders
): # custom test method is global method
task_name = get_builder_parameter_value(
test_method, test_class_analyzer.setup_variables,
custom_global_builder,
test_suite_analyzer.custom_global_builder_calls[idx],
SYSTEM_PIPELINE_BUILDER_FUNCTION_NAME,
SYSTEM_PIPELINE_BUILDER_PARAMETER_NAME)
if task_name is not None:
tested_tasks.append(task_name)
for idx, custom_class_method_builder in enumerate(
test_class_analyzer.custom_class_method_builders
): # custom class method builder.
task_name = get_builder_parameter_value(
test_method, test_class_analyzer.setup_variables,
custom_class_method_builder,
test_class_analyzer.custom_class_method_builder_calls[idx],
SYSTEM_PIPELINE_BUILDER_FUNCTION_NAME,
SYSTEM_PIPELINE_BUILDER_PARAMETER_NAME)
if task_name is not None:
tested_tasks.append(task_name)
task_name = get_builder_parameter_value(
test_method, test_class_analyzer.setup_variables, None, None,
SYSTEM_PIPELINE_BUILDER_FUNCTION_NAME,
SYSTEM_PIPELINE_BUILDER_PARAMETER_NAME
) # direct call the build_trainer
if task_name is not None:
tested_tasks.append(task_name)
if len(tested_tasks
) == 0: # suppose no builder call is direct construct.
task_name = get_class_constructor(test_method,
modified_register_modules,
'PIPELINES')
if task_name is not None:
tested_tasks.append(task_name)
return tested_tasks
def get_pipelines_trainers_test_info(register_modules):
all_trainer_cases = [
os.path.join(dp, f) for dp, dn, filenames in os.walk(
os.path.join(os.getcwd(), 'tests', 'trainers')) for f in filenames
if os.path.splitext(f)[1] == '.py'
]
trainer_test_info = {}
for test_file in all_trainer_cases:
tested_trainers = analysis_trainer_test_suite(test_file,
register_modules)
if len(tested_trainers) == 0:
logger.warn('test_suite: %s has no trainer name' % test_file)
else:
tested_trainers = list(set(tested_trainers))
for trainer_name in tested_trainers:
if trainer_name not in trainer_test_info:
trainer_test_info[trainer_name] = []
trainer_test_info[trainer_name].append(test_file)
pipeline_test_info = {}
all_pipeline_cases = [
os.path.join(dp, f) for dp, dn, filenames in os.walk(
os.path.join(os.getcwd(), 'tests', 'pipelines')) for f in filenames
if os.path.splitext(f)[1] == '.py'
]
for test_file in all_pipeline_cases:
tested_pipelines = analysis_pipeline_test_suite(
test_file, register_modules)
if len(tested_pipelines) == 0:
logger.warn('test_suite: %s has no pipeline task' % test_file)
else:
tested_pipelines = list(set(tested_pipelines))
for pipeline_task in tested_pipelines:
if pipeline_task not in pipeline_test_info:
pipeline_test_info[pipeline_task] = []
pipeline_test_info[pipeline_task].append(test_file)
return pipeline_test_info, trainer_test_info
if __name__ == '__main__':
test_file = 'tests/pipelines/test_action_detection.py'
tasks = analysis_pipeline_test_suite(test_file, None)
print(tasks)

View File

@@ -0,0 +1,294 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import print_function
import ast
import importlib.util
import os
import pkgutil
import site
import sys
from modelscope.utils.logger import get_logger
logger = get_logger()
def is_relative_import(path):
# from .x import y or from ..x import y
return path.startswith('.')
def resolve_import(module_name):
try:
spec = importlib.util.find_spec(module_name)
return spec and spec.origin
except Exception:
return None
def convert_to_path(name):
if name.startswith('.'):
remainder = name.lstrip('.')
dot_count = (len(name) - len(remainder))
prefix = '../' * (dot_count - 1)
else:
remainder = name
dot_count = 0
prefix = ''
filename = prefix + os.path.join(*remainder.split('.'))
return filename
def resolve_relative_import(source_file_path, module_name):
current_package = os.path.dirname(source_file_path).replace('/', '.')
absolute_name = importlib.util.resolve_name(module_name,
current_package) # get
return resolve_absolute_import(absolute_name)
def onerror(name):
logger.error('Importing module %s error!' % name)
def resolve_absolute_import(module_name):
module_file_path = resolve_import(module_name)
if module_file_path is None:
# find from base module.
parent_module, sub_module = module_name.rsplit('.', 1)
if parent_module in sys.modules:
if hasattr(sys.modules[parent_module], '_import_structure'):
import_structure = sys.modules[parent_module]._import_structure
for k, v in import_structure.items():
if sub_module in v:
parent_module = parent_module + '.' + k
break
module_file_path = resolve_absolute_import(parent_module)
# the parent_module is a package, we need find the module_name's file
if os.path.basename(module_file_path) == '__init__.py' and \
(os.path.relpath(module_file_path, site.getsitepackages()[0]) != 'modelscope/__init__.py'
or os.path.relpath(module_file_path, os.getcwd()) != 'modelscope/__init__.py'):
for _, sub_module_name, _ in pkgutil.walk_packages(
[os.path.dirname(module_file_path)],
parent_module + '.',
onerror=onerror):
try:
module_ = importlib.import_module(sub_module_name)
for k, v in module_.__dict__.items():
if k == sub_module and v.__module__ == module_.__name__:
module_file_path = module_.__file__
break
except ModuleNotFoundError as e:
logger.warn(
'Import error in %s, ModuleNotFoundError: %s' %
(sub_module_name, e))
continue
except Exception as e:
logger.warn('Import error in %s, Exception: %s' %
(sub_module_name, e))
continue
else:
return module_file_path
else:
module_file_path = resolve_absolute_import(parent_module)
return module_file_path
class AnalysisSourceFileImports(ast.NodeVisitor):
"""Analysis source file imports
List imports of the modelscope.
"""
def __init__(self, source_file_path) -> None:
super().__init__()
self.imports = []
self.source_file_path = source_file_path
def visit_Import(self, node):
"""Processing import x,y,z or import os.path as osp"""
for alias in node.names:
if alias.name.startswith('modelscope'):
file_path = resolve_absolute_import(alias.name)
if file_path.startswith(site.getsitepackages()[0]):
self.imports.append(
os.path.relpath(file_path,
site.getsitepackages()[0]))
else:
self.imports.append(
os.path.relpath(file_path, os.getcwd()))
def visit_ImportFrom(self, node):
# level 0 absolute import such as from os.path import join
# level 1 from .x import y
# level 2 from ..x import y
module_name = '.' * node.level + (node.module or '')
for alias in node.names:
if alias.name == '*': # from x import *
if is_relative_import(module_name):
# resolve model path.
file_path = resolve_relative_import(
self.source_file_path, module_name)
elif module_name.startswith('modelscope'):
file_path = resolve_absolute_import(module_name)
else:
file_path = None # ignore other package.
else:
if not module_name.endswith('.'):
module_name = module_name + '.'
name = module_name + alias.name
if is_relative_import(name):
# resolve model path.
file_path = resolve_relative_import(
self.source_file_path, name)
elif name.startswith('modelscope'):
file_path = resolve_absolute_import(name)
else:
file_path = None # ignore other package.
if file_path is not None:
if file_path.startswith(site.getsitepackages()[0]):
self.imports.append(
os.path.relpath(file_path,
site.getsitepackages()[0]))
else:
self.imports.append(
os.path.relpath(file_path, os.getcwd()))
class AnalysisSourceFileRegisterModules(ast.NodeVisitor):
"""Get register_module call of the python source file.
Args:
ast (NodeVisitor): The ast node.
Examples:
with open(source_file_path, "rb") as f:
src = f.read()
analyzer = AnalysisSourceFileRegisterModules(source_file_path)
analyzer.visit(ast.parse(src, filename=source_file_path))
"""
def __init__(self, source_file_path) -> None:
super().__init__()
self.source_file_path = source_file_path
self.register_modules = []
def visit_ClassDef(self, node: ast.ClassDef):
if len(node.decorator_list) > 0:
for dec in node.decorator_list:
if isinstance(dec, ast.Call):
target_name = ''
module_name_param = ''
task_param = ''
if isinstance(dec.func, ast.Attribute
) and dec.func.attr == 'register_module':
target_name = dec.func.value.id # MODELS
if len(dec.args) > 0:
if isinstance(dec.args[0], ast.Attribute):
task_param = dec.args[0].attr
elif isinstance(dec.args[0], ast.Constant):
task_param = dec.args[0].value
if len(dec.keywords) > 0:
for kw in dec.keywords:
if kw.arg == 'module_name':
if isinstance(kw.value, ast.Str):
module_name_param = kw.value.s
else:
module_name_param = kw.value.attr
elif kw.arg == 'group_key':
if isinstance(kw.value, ast.Str):
task_param = kw.value.s
elif isinstance(kw.value, ast.Name):
task_param = kw.value.id
else:
task_param = kw.value.attr
if task_param == '' and module_name_param == '':
logger.warn(
'File %s %s.register_module has no parameters'
% (self.source_file_path, target_name))
continue
if target_name == 'PIPELINES' and task_param == '':
logger.warn(
'File %s %s.register_module has no task_param'
% (self.source_file_path, target_name))
self.register_modules.append(
(target_name, task_param, module_name_param,
node.name)) # PIPELINES, task, module, class_name
def get_imported_files(file_path):
"""Get file dependencies.
"""
print('Getting %s imports' % file_path)
if os.path.isabs(file_path):
file_path = os.path.relpath(file_path, os.getcwd())
with open(file_path, 'rb') as f:
src = f.read()
analyzer = AnalysisSourceFileImports(file_path)
analyzer.visit(ast.parse(src, filename=file_path))
return list(set(analyzer.imports))
def path_to_module_name(file_path):
if os.path.isabs(file_path):
file_path = os.path.relpath(file_path, os.getcwd())
module_name = os.path.dirname(file_path).replace('/', '.')
return module_name
def get_file_register_modules(file_path):
logger.info('Get file: %s register_module' % file_path)
with open(file_path, 'rb') as f:
src = f.read()
analyzer = AnalysisSourceFileRegisterModules(file_path)
analyzer.visit(ast.parse(src, filename=file_path))
return analyzer.register_modules
def get_import_map():
all_files = [
os.path.join(dp, f) for dp, dn, filenames in os.walk(
os.path.join(os.getcwd(), 'modelscope')) for f in filenames
if os.path.splitext(f)[1] == '.py'
]
import_map = {}
for f in all_files:
files = get_imported_files(f)
import_map[os.path.relpath(f, os.getcwd())] = files
return import_map
def get_reverse_import_map():
all_files = [
os.path.join(dp, f) for dp, dn, filenames in os.walk(
os.path.join(os.getcwd(), 'modelscope')) for f in filenames
if os.path.splitext(f)[1] == '.py'
]
import_map = get_import_map()
reverse_depend_map = {}
for f in all_files:
depend_by = []
for k, v in import_map.items():
if f in v and f != k:
depend_by.append(k)
reverse_depend_map[f] = depend_by
return reverse_depend_map, import_map
def get_all_register_modules():
all_files = [
os.path.join(dp, f) for dp, dn, filenames in os.walk(
os.path.join(os.getcwd(), 'modelscope')) for f in filenames
if os.path.splitext(f)[1] == '.py'
]
all_register_modules = []
for f in all_files:
all_register_modules.extend(get_file_register_modules(f))
return all_register_modules
if __name__ == '__main__':
pass