Files
modelscope/tests/run_analysis.py
liuyhwangyh 82ee20f447 fix issue #845 (#861)
* fix #845

Co-authored-by: mulin.lyh <mulin.lyh@taobao.com>
2024-05-23 20:34:52 +08:00

400 lines
17 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import subprocess
from fnmatch import fnmatch
from trainers.model_trainer_map import model_trainer_map
from utils.case_file_analyzer import get_pipelines_trainers_test_info
from utils.source_file_analyzer import (get_all_register_modules,
get_file_register_modules,
get_import_map)
from modelscope.hub.api import HubApi
from modelscope.hub.file_download import model_file_download
from modelscope.hub.utils.utils import model_id_to_group_owner_name
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.file_utils import get_model_cache_dir
from modelscope.utils.logger import get_logger
# Module-level logger used by the functions in this script.
logger = get_logger()
def get_models_info(groups: list) -> dict:
    """Collect configuration metadata for every model listed under ``groups``.

    Each group is listed page by page through ``HubApi.list_models``. For
    every listed model the configuration file is read from the local model
    cache or, when it is not cached yet, downloaded at the first revision
    returned by ``HubApi.list_model_revisions``. Models without a revision,
    whose configuration cannot be downloaded, or whose configuration cannot
    be parsed are reported on stdout and skipped.

    Args:
        groups: Group (owner) names whose models should be inspected.

    Returns:
        dict: Maps ``'<group>/<model name>'`` to a dict with the keys
        ``framework``, ``task``, ``model_type``, ``pipeline_type``,
        ``preprocessor_type``, ``train_hooks_type``, ``datasets`` and
        ``evaluation_metics`` read from the model configuration.
    """
    api = HubApi()
    page_size = 100
    # Keep the owning group next to every entry so the model id is built from
    # the group the model was listed under, not from the last group iterated.
    listed_models = []
    for group in groups:
        page = 1
        fetched = 0  # counted per group, TotalCount is per group as well
        while True:
            query_result = api.list_models(group, page, page_size)
            page_models = query_result['Models']
            if not page_models:
                break
            listed_models.extend((group, entry) for entry in page_models)
            fetched += len(page_models)
            # NOTE(review): assumes the response is a dict carrying the
            # group's total under 'TotalCount' - confirm against HubApi.
            total_count = query_result.get('TotalCount') or 0
            if fetched >= total_count:
                break
            page += 1

    models_info = {}  # key model id, value model info
    for group, hub_model in listed_models:
        model_id = '%s/%s' % (group, hub_model['Name'])
        configuration_file = os.path.join(
            get_model_cache_dir(model_id), ModelFile.CONFIGURATION)
        if not os.path.exists(configuration_file):
            try:
                model_revisions = api.list_model_revisions(model_id=model_id)
                if len(model_revisions) == 0:
                    print('Model: %s has no revision' % model_id)
                    continue
                # get latest revision
                configuration_file = model_file_download(
                    model_id=model_id,
                    file_path=ModelFile.CONFIGURATION,
                    revision=model_revisions[0])
            except Exception as e:
                print('Download model: %s configuration file exception'
                      % model_id)
                print('Exception: %s' % e)
                continue
        try:
            cfg = Config.from_file(configuration_file)
        except Exception as e:
            print('Resolve model: %s configuration file failed!' % model_id)
            print(('Exception: %s' % e))
            # Without a parsed config there is nothing to record; skipping
            # avoids reusing the previous model's ``cfg``.
            continue

        train_hooks = cfg.safe_get('train.hooks')
        train_hooks_type = []
        if train_hooks is not None:
            train_hooks_type = [train_hook.type for train_hook in train_hooks]

        models_info[model_id] = {
            'framework': cfg.safe_get('framework'),
            'task': cfg.safe_get('task'),
            'model_type': cfg.safe_get('model.type'),
            'pipeline_type': cfg.safe_get('pipeline.type'),
            'preprocessor_type': cfg.safe_get('preprocessor.type'),
            'train_hooks_type': train_hooks_type,
            'datasets': cfg.safe_get('dataset'),
            # metrics name list
            'evaluation_metics': cfg.safe_get('evaluation.metrics', []),
        }
    return models_info
def gather_test_suites_files(test_dir='./tests',
                             pattern='test_*.py',
                             is_full_path=True):
    """Walk ``test_dir`` recursively and list the files matching ``pattern``.

    Args:
        test_dir: Root directory that is walked.
        pattern: ``fnmatch`` pattern applied to each file name.
        is_full_path: When true each entry is the walked directory joined
            with the file name, otherwise only the bare file name.

    Returns:
        list: Matching files in ``os.walk`` order.
    """
    return [
        os.path.join(root, name) if is_full_path else name
        for root, _subdirs, names in os.walk(test_dir)
        for name in names
        if fnmatch(name, pattern)
    ]
def run_command_get_output(cmd):
    """Run ``cmd`` and return its standard output decoded as UTF-8.

    Args:
        cmd: Command passed to ``subprocess.run``.

    Returns:
        The decoded stdout when the command exits with status 0. On a
        non-zero exit status the captured stdout and stderr are printed and
        ``None`` is returned.
    """
    completed = subprocess.run(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if completed.returncode:
        failure_report = (completed.stdout.decode('utf8'),
                          completed.stderr.decode('utf8'))
        print('stdout: %s, stderr: %s' % failure_report)
        return None
    return completed.stdout.decode('utf8')
def get_current_branch():
    """Return the name git reports for the currently checked out ``HEAD``.

    Returns:
        str: Stripped output of ``git rev-parse --abbrev-ref HEAD``.

    Raises:
        RuntimeError: If the git command fails, instead of the opaque
            ``AttributeError`` raised by calling ``strip`` on ``None``.
    """
    cmd = ['git', 'rev-parse', '--abbrev-ref', 'HEAD']
    output = run_command_get_output(cmd)
    if output is None:
        # run_command_get_output signals a failed command with None.
        raise RuntimeError('Failed to get current branch with command: %s'
                           % ' '.join(cmd))
    branch = output.strip()
    logger.info('Testing branch: %s' % branch)
    return branch
def get_modified_files():
    """Return the changed files that still exist in the working tree.

    The file list comes from the ``PR_CHANGED_FILES`` environment variable
    ('#'-separated) when it is set and not blank, otherwise from
    ``git diff --name-only origin/master...``. Paths that no longer exist
    (deleted files) are dropped.

    Returns:
        list: Stripped paths of the existing modified files.

    Raises:
        RuntimeError: If the git diff command fails, instead of the opaque
            ``AttributeError`` raised by calling ``splitlines`` on ``None``.
    """
    pr_changed_files = os.environ.get('PR_CHANGED_FILES', '')
    if pr_changed_files.strip() != '':
        logger.info('Getting PR modified files.')
        # get modify file from environment
        diff_files = pr_changed_files.replace('#', '\n')
    else:
        logger.info('Getting diff of branch.')
        cmd = ['git', 'diff', '--name-only', 'origin/master...']
        diff_files = run_command_get_output(cmd)
        if diff_files is None:
            # run_command_get_output signals a failed command with None.
            raise RuntimeError('Failed to get modified files with command: %s'
                               % ' '.join(cmd))
    logger.info('Diff files: ')
    logger.info(diff_files)
    # remove the deleted file.
    candidates = (line.strip() for line in diff_files.splitlines())
    return [path for path in candidates if os.path.exists(path)]
def analysis_diff():
    """Derive affected register modules and modified test cases from the diff.

    The modified files are extended with every file whose import-map entry
    contains a modified source file. Register modules are then gathered from
    all source files in that list, and modified ``test_*`` files under
    ``tests`` are collected as cases.

    Returns:
        tuple: ``(modified_register_modules, modified_cases)``.
    """
    # ignore diff for constant define files, these files import by all pipeline, trainer
    ignore_files = [
        'modelscope/utils/constant.py', 'modelscope/metainfo.py',
        'modelscope/pipeline_inputs.py', 'modelscope/outputs/outputs.py'
    ]
    source_prefixes = ('./modelscope', 'modelscope')
    test_prefixes = ('./tests', 'tests')

    changed_files = get_modified_files()
    logger.info('Modified files:\n %s' % '\n'.join(changed_files))
    logger.info('Starting get import map')
    import_map = get_import_map()
    logger.info('Finished get import map')

    importers = []
    for changed in changed_files:
        # only source files outside the ignore list widen the affected set
        if not changed.startswith(source_prefixes) or changed in ignore_files:
            continue
        importers.extend(
            importer for importer, imported in import_map.items()
            if changed in imported and changed != importer)

    logger.info('There are affected files: %s' % len(importers))
    for importer in importers:
        logger.info(importer)
    changed_files.extend(importers)  # add imported by file

    register_modules = []
    cases = []
    for changed in changed_files:
        if changed.startswith(source_prefixes):
            register_modules.extend(get_file_register_modules(changed))
        elif (changed.startswith(test_prefixes)
              and os.path.basename(changed).startswith('test_')):
            cases.append(changed)
    return register_modules, cases
def split_test_suites():
    """Partition the gathered test suite files by the path fragment they contain.

    Returns:
        tuple: ``(pipeline_test_suites, trainer_test_suites,
        other_test_suites)``. A path containing ``tests/trainers`` is a
        trainer suite, one containing ``tests/pipelines`` is a pipeline
        suite, and everything else lands in the last list.
    """
    pipeline_suites, trainer_suites, remaining_suites = [], [], []
    for suite_path in gather_test_suites_files():
        if 'tests/trainers' in suite_path:
            bucket = trainer_suites
        elif 'tests/pipelines' in suite_path:
            bucket = pipeline_suites
        else:
            bucket = remaining_suites
        bucket.append(suite_path)
    return pipeline_suites, trainer_suites, remaining_suites
def get_test_suites_to_run():
    """Select the test suite files to run for the current branch.

    On ``master`` every test file is returned. On any other branch the
    result contains the suites outside ``tests/pipelines`` and
    ``tests/trainers``, the pipeline and trainer suites mapped from the
    register modules affected by the diff, and the modified test files
    themselves.

    Returns:
        list: Base names of the selected test suite files.
    """
    branch = get_current_branch()
    if branch == 'master':
        # when run with master, run all the cases
        return gather_test_suites_files(is_full_path=False)

    affected_register_modules, modified_cases = analysis_diff()
    # affected_register_modules list of modified file and dependent file's register_module.
    # ("MODULES|PIPELINES|TRAINERS|...", '', '', model_class_name)
    # modified_cases, modified case file.
    all_register_modules = get_all_register_modules()
    _, _, other_test_suites = split_test_suites()
    task_pipeline_test_suite_map, trainer_test_suite_map = \
        get_pipelines_trainers_test_info(all_register_modules)
    # task_pipeline_test_suite_map key: pipeline task, value: case file path
    # trainer_test_suite_map key: trainer_name, value: case file path

    iic_models_info = get_models_info(['iic'])
    models_info = {}
    # compatible model info: re-key every 'iic' model under the 'damo' group
    for model_id, model_info in iic_models_info.items():
        _, model_name = model_id_to_group_owner_name(model_id)
        # Store the per-model info, not the mapping itself.
        models_info['damo/%s' % model_name] = model_info
    # models_info key: model_id, value: model info such as framework task etc.

    affected_pipeline_cases = []
    affected_trainer_cases = []
    for affected_register_module in affected_register_modules:
        module_kind = affected_register_module[0]
        if module_kind in ('PIPELINES', 'MODELS'):
            # ["PIPELINES", "acoustic_noise_suppression", "speech_frcrn_ans_cirm_16k", "ANSPipeline"]
            # ["PIPELINES", task, pipeline_name, pipeline_class_name]
            # ["MODELS", "keyword_spotting", "kws_kwsbp", "GenericKeyWordSpotting"],
            # ["MODELS", task, model_name, model_class_name]
            task = affected_register_module[1]
            if task in task_pipeline_test_suite_map:
                affected_pipeline_cases.extend(
                    task_pipeline_test_suite_map[task])
            else:
                logger.warning('Pipeline task: %s has no test case!' % task)
        elif module_kind == 'TRAINERS':
            # ["TRAINERS", "", "nlp_base_trainer", "NlpEpochBasedTrainer"],
            # ["TRAINERS", "", trainer_name, trainer_class_name]
            trainer_name = affected_register_module[2]
            if trainer_name in trainer_test_suite_map:
                affected_trainer_cases.extend(
                    trainer_test_suite_map[trainer_name])
            else:
                logger.warning('Trainer %s his no case' % (trainer_name))
        elif module_kind == 'PREPROCESSORS':
            # ["PREPROCESSORS", "cv", "object_detection_scrfd", "SCRFDPreprocessor"]
            # ["PREPROCESSORS", domain, preprocessor_name, class_name]
            preprocessor_name = affected_register_module[2]
            for model_id, model_info in models_info.items():
                preprocessor_type = model_info.get('preprocessor_type')
                if (preprocessor_type is None
                        or preprocessor_type != preprocessor_name):
                    continue
                task = model_info['task']
                if task in task_pipeline_test_suite_map:
                    affected_pipeline_cases.extend(
                        task_pipeline_test_suite_map[task])
                if model_id in model_trainer_map:
                    affected_trainer_cases.extend(
                        model_trainer_map[model_id])
        elif module_kind in ('HOOKS', 'CUSTOM_DATASETS'):
            # ["HOOKS", "", "CheckpointHook", "CheckpointHook"]
            # ["HOOKS", "", hook_name, class_name]
            # HOOKS, DATASETS modify run all trainer cases
            for cases in trainer_test_suite_map.values():
                affected_trainer_cases.extend(cases)
        elif module_kind == 'METRICS':
            # ["METRICS", "default_group", "accuracy", "AccuracyMetric"]
            # ["METRICS", group, metric_name, class_name]
            metric_name = affected_register_module[2]
            for model_id, model_info in models_info.items():
                # A configuration may carry no metrics at all.
                metrics = model_info.get('evaluation_metics') or []
                if metric_name in metrics and model_id in model_trainer_map:
                    affected_trainer_cases.extend(
                        model_trainer_map[model_id])

    # deduplication; sorted so the selection order is reproducible
    affected_pipeline_cases = sorted(set(affected_pipeline_cases))
    affected_trainer_cases = sorted(set(affected_trainer_cases))

    test_suites_to_run = []
    for test_suite in other_test_suites:
        test_suites_to_run.append(os.path.basename(test_suite))
    for test_suite in affected_pipeline_cases:
        test_suites_to_run.append(os.path.basename(test_suite))
    for test_suite in affected_trainer_cases:
        test_suites_to_run.append(os.path.basename(test_suite))
    for modified_case in modified_cases:
        # The list holds base names, so compare the base name as well.
        case_name = os.path.basename(modified_case)
        if case_name not in test_suites_to_run:
            test_suites_to_run.append(case_name)
    return test_suites_to_run
def get_files_related_modules(files, reverse_import_map):
    """Collect the register modules declared by ``files`` or by their importers.

    When none of ``files`` declares a register module, the search moves one
    level up at a time through ``reverse_import_map`` (file -> files that
    import it), starting from the files that lie more than four path
    segments deep under ``modelscope``, until some register module is found
    or no new importer is left.

    Args:
        files: File paths to inspect.
        reverse_import_map: Maps a file path to the list of files importing
            it.

    Returns:
        list: Register modules gathered via ``get_file_register_modules``.
    """
    register_modules = []
    for single_file in files:
        if single_file.startswith(('./modelscope', 'modelscope')):
            register_modules.extend(get_file_register_modules(single_file))

    # Files already inspected; keeps an import cycle from looping forever.
    visited = set(files)
    while not register_modules:
        logger.warning('There is no affected register module')
        deeper_imported_by = []
        for source_file in files:
            if len(source_file.split('/')) > 4 and source_file.startswith(
                    'modelscope'):
                for importer in reverse_import_map.get(source_file, []):
                    if importer not in visited:
                        visited.add(importer)
                        deeper_imported_by.append(importer)
        if not deeper_imported_by:
            break
        for file in deeper_imported_by:
            # Accumulate across importers instead of keeping only the last.
            register_modules.extend(get_file_register_modules(file))
        files = deeper_imported_by
    return register_modules
def get_modules_related_cases(register_modules, task_pipeline_test_suite_map,
                              trainer_test_suite_map):
    """Map register modules to the pipeline and trainer cases covering them.

    Args:
        register_modules: Sequences whose element 0 is the registry name,
            element 1 the task and element 2 the registered name.
        task_pipeline_test_suite_map: Maps a task to its pipeline cases.
        trainer_test_suite_map: Maps a trainer name to its trainer cases.

    Returns:
        tuple: ``(affected_pipeline_cases, affected_trainer_cases)``; cases
        are appended once per matching module, so duplicates are possible.
        Modules of other registries are ignored.
    """
    affected_pipeline_cases = []
    affected_trainer_cases = []
    for register_module in register_modules:
        module_kind = register_module[0]
        if module_kind in ('PIPELINES', 'MODELS'):
            task = register_module[1]
            if task in task_pipeline_test_suite_map:
                affected_pipeline_cases.extend(
                    task_pipeline_test_suite_map[task])
            else:
                logger.warning('Pipeline task: %s has no test case!' % task)
        elif module_kind == 'TRAINERS':
            trainer_name = register_module[2]
            if trainer_name in trainer_test_suite_map:
                affected_trainer_cases.extend(
                    trainer_test_suite_map[trainer_name])
            else:
                logger.warning('Trainer %s his no case' % (trainer_name))
    return affected_pipeline_cases, affected_trainer_cases
def get_all_file_test_info():
    """Write per-file test relation info to ``./test_relate_info.json``.

    For every ``.py`` file below ``<cwd>/modelscope`` the dumped entry holds
    its imports, the files importing it, the related register modules and
    the pipeline and trainer cases mapped from those modules.
    """
    import json

    package_root = os.path.join(os.getcwd(), 'modelscope')
    all_files = []
    for dir_path, _dir_names, file_names in os.walk(package_root):
        for file_name in file_names:
            if os.path.splitext(file_name)[1] == '.py':
                all_files.append(
                    os.path.relpath(
                        os.path.join(dir_path, file_name), os.getcwd()))

    import_map = get_import_map()
    all_register_modules = get_all_register_modules()
    task_pipeline_test_suite_map, trainer_test_suite_map = \
        get_pipelines_trainers_test_info(all_register_modules)

    # file -> files whose import-map entry contains it (itself excluded)
    reverse_depend_map = {
        source_file: [
            importer for importer, imported in import_map.items()
            if source_file in imported and source_file != importer
        ]
        for source_file in all_files
    }

    # get cases.
    test_info = {}
    for source_file in all_files:
        imports = import_map[source_file]
        imported_by = reverse_depend_map[source_file]
        register_modules = get_files_related_modules(
            [source_file] + imported_by, reverse_depend_map)
        pipeline_cases, trainer_cases = get_modules_related_cases(
            register_modules, task_pipeline_test_suite_map,
            trainer_test_suite_map)
        test_info[os.path.relpath(source_file, os.getcwd())] = {
            'imports': imports,
            'imported_by': imported_by,
            'relate_modules': register_modules,
            'pipeline_cases': pipeline_cases,
            'trainer_cases': trainer_cases,
        }

    with open('./test_relate_info.json', 'w') as output_file:
        json.dump(test_info, output_file)
if __name__ == '__main__':
    # Print the selected test suite files as one comma-separated line.
    print('Selected cases: %s' % ','.join(get_test_suites_to_run()))