mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
[to #42322933]add face 2d keypoints finetune test case
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10421808 * add face 2d keypoints & human wholebody keypoint finrtune test case
This commit is contained in:
@@ -452,9 +452,9 @@ class Datasets(object):
|
||||
""" Names for different datasets.
|
||||
"""
|
||||
ClsDataset = 'ClsDataset'
|
||||
Face2dKeypointsDataset = 'Face2dKeypointsDataset'
|
||||
Face2dKeypointsDataset = 'FaceKeypointDataset'
|
||||
HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset'
|
||||
HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset'
|
||||
HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset'
|
||||
SegDataset = 'SegDataset'
|
||||
DetDataset = 'DetDataset'
|
||||
DetImagesMixDataset = 'DetImagesMixDataset'
|
||||
|
||||
@@ -26,11 +26,16 @@ class EasyCVBaseDataset(object):
|
||||
if self.split_config is not None:
|
||||
self._update_data_source(kwargs['data_source'])
|
||||
|
||||
def _update_data_root(self, input_dict, data_root):
|
||||
for k, v in input_dict.items():
|
||||
if isinstance(v, str) and self.DATA_ROOT_PATTERN in v:
|
||||
input_dict.update(
|
||||
{k: v.replace(self.DATA_ROOT_PATTERN, data_root)})
|
||||
elif isinstance(v, dict):
|
||||
self._update_data_root(v, data_root)
|
||||
|
||||
def _update_data_source(self, data_source):
|
||||
data_root = next(iter(self.split_config.values()))
|
||||
data_root = data_root.rstrip(osp.sep)
|
||||
|
||||
for k, v in data_source.items():
|
||||
if isinstance(v, str) and self.DATA_ROOT_PATTERN in v:
|
||||
data_source.update(
|
||||
{k: v.replace(self.DATA_ROOT_PATTERN, data_root)})
|
||||
self._update_data_root(data_source, data_root)
|
||||
|
||||
@@ -19,7 +19,7 @@ moviepy>=1.0.3
|
||||
networkx>=2.5
|
||||
numba
|
||||
onnxruntime>=1.10
|
||||
pai-easycv>=0.6.3.7
|
||||
pai-easycv>=0.6.3.9
|
||||
pandas
|
||||
psutil
|
||||
regex
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.utils.constant import DownloadMode, LogKeys, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
|
||||
class EasyCVTrainerTestFace2DKeypoints(unittest.TestCase):
|
||||
model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment'
|
||||
|
||||
def setUp(self):
|
||||
self.logger = get_logger()
|
||||
self.logger.info(('Testing %s.%s' %
|
||||
(type(self).__name__, self._testMethodName)))
|
||||
|
||||
def _train(self, tmp_dir):
|
||||
cfg_options = {'train.max_epochs': 2}
|
||||
|
||||
trainer_name = Trainers.easycv
|
||||
|
||||
train_dataset = MsDataset.load(
|
||||
dataset_name='face_2d_keypoints_dataset',
|
||||
namespace='modelscope',
|
||||
split='train',
|
||||
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
|
||||
eval_dataset = MsDataset.load(
|
||||
dataset_name='face_2d_keypoints_dataset',
|
||||
namespace='modelscope',
|
||||
split='train',
|
||||
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
|
||||
|
||||
kwargs = dict(
|
||||
model=self.model_id,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
work_dir=tmp_dir,
|
||||
cfg_options=cfg_options)
|
||||
|
||||
trainer = build_trainer(trainer_name, kwargs)
|
||||
trainer.train()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_trainer_single_gpu(self):
|
||||
temp_file_dir = tempfile.TemporaryDirectory()
|
||||
tmp_dir = temp_file_dir.name
|
||||
if not os.path.exists(tmp_dir):
|
||||
os.makedirs(tmp_dir)
|
||||
|
||||
self._train(tmp_dir)
|
||||
|
||||
results_files = os.listdir(tmp_dir)
|
||||
json_files = glob.glob(os.path.join(tmp_dir, '*.log.json'))
|
||||
self.assertEqual(len(json_files), 1)
|
||||
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
|
||||
|
||||
temp_file_dir.cleanup()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user