[to #42322933]支持从dataset json文件中获取参数

* dataset json file add args
This commit is contained in:
feiwu.yfw
2022-08-30 15:15:15 +08:00
parent 745bd5a9e0
commit 2b64cf2bb6
6 changed files with 29 additions and 37 deletions

View File

@@ -15,7 +15,7 @@ from modelscope.msdatasets.task_datasets import \
ImageInstanceSegmentationCocoDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import ModelFile
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.test_utils import test_level
@@ -41,38 +41,26 @@ class TestImageInstanceSegmentationTrainer(unittest.TestCase):
if train_data_cfg is None:
# use default toy data
train_data_cfg = ConfigDict(
name='pets_small',
split='train',
classes=('Cat', 'Dog'),
folder_name='Pets',
test_mode=False)
name='pets_small', split='train', test_mode=False)
if val_data_cfg is None:
val_data_cfg = ConfigDict(
name='pets_small',
split='validation',
classes=('Cat', 'Dog'),
folder_name='Pets',
test_mode=True)
name='pets_small', split='validation', test_mode=True)
self.train_dataset = MsDataset.load(
dataset_name=train_data_cfg.name,
split=train_data_cfg.split,
classes=train_data_cfg.classes,
folder_name=train_data_cfg.folder_name,
test_mode=train_data_cfg.test_mode)
assert self.train_dataset.config_kwargs[
'classes'] == train_data_cfg.classes
test_mode=train_data_cfg.test_mode,
download_mode=DownloadMode.FORCE_REDOWNLOAD)
assert self.train_dataset.config_kwargs['classes']
assert next(
iter(self.train_dataset.config_kwargs['split_config'].values()))
self.eval_dataset = MsDataset.load(
dataset_name=val_data_cfg.name,
split=val_data_cfg.split,
classes=val_data_cfg.classes,
folder_name=val_data_cfg.folder_name,
test_mode=val_data_cfg.test_mode)
assert self.eval_dataset.config_kwargs[
'classes'] == val_data_cfg.classes
test_mode=val_data_cfg.test_mode,
download_mode=DownloadMode.FORCE_REDOWNLOAD)
assert self.eval_dataset.config_kwargs['classes']
assert next(
iter(self.eval_dataset.config_kwargs['split_config'].values()))