modify ofatrainer

This commit is contained in:
翎航
2022-10-20 22:32:41 +08:00
parent 63a62c3151
commit 9b8cfc4ece
2 changed files with 40 additions and 10 deletions

View File

@@ -24,12 +24,13 @@ from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
@TRAINERS.register_module(module_name=Trainers.ofa_tasks)
class OFATrainer(EpochBasedTrainer):
def __init__(self, model: str, *args, **kwargs):
def __init__(self, model: str, cfg_file, work_dir, train_dataset,
eval_dataset, *args, **kwargs):
model = Model.from_pretrained(model)
model_dir = model.model_dir
cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
# cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
cfg = Config.from_file(cfg_file)
dataset = self._build_dataset_with_config(cfg)
# dataset = self._build_dataset_with_config(cfg)
preprocessor = {
ConfigKeys.train:
OfaPreprocessor(
@@ -41,7 +42,7 @@ class OFATrainer(EpochBasedTrainer):
# use torchrun launch
world_size = int(os.environ.get('WORLD_SIZE', 1))
epoch_steps = math.ceil(
len(dataset['train']) / # noqa
len(train_dataset) / # noqa
(cfg.train.dataloader.batch_size_per_gpu * world_size)) # noqa
cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs
cfg.train.criterion.tokenizer = model.tokenizer
@@ -68,11 +69,11 @@ class OFATrainer(EpochBasedTrainer):
cfg_file=cfg_file,
model=model,
data_collator=collator,
train_dataset=dataset['train'],
eval_dataset=dataset['valid'],
train_dataset=train_dataset,
eval_dataset=eval_dataset,
preprocessor=preprocessor,
optimizers=(optimizer, lr_scheduler),
work_dir=cfg.train.work_dir,
work_dir=work_dir,
*args,
**kwargs,
)

View File

@@ -3,22 +3,51 @@ import glob
import os
import os.path as osp
import shutil
import tempfile
import unittest
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode
from modelscope.utils.test_utils import test_level
class TestOfaTrainer(unittest.TestCase):
def setUp(self):
column_map = {'premise': 'text', 'hypothesis': 'text2'}
data_train = MsDataset.load(
dataset_name='glue',
subset_name='mnli',
namespace='modelscope',
split='train[:100]',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
self.train_dataset = MsDataset.from_hf_dataset(
data_train._hf_ds.rename_columns(column_map))
data_eval = MsDataset.load(
dataset_name='glue',
subset_name='mnli',
namespace='modelscope',
split='validation_matched[:8]',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
self.test_dataset = MsDataset.from_hf_dataset(
data_eval._hf_ds.rename_columns(column_map))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_trainer(self):
os.environ['LOCAL_RANK'] = '0'
model_id = 'damo/ofa_text-classification_mnli_large_en'
default_args = {'model': model_id}
trainer = build_trainer(
name=Trainers.ofa_tasks, default_args=default_args)
kwargs = dict(
model=model_id,
cfg_file=
'/Users/running_you/.cache/modelscope/hub/damo/ofa_text-classification_mnli_large_en//configuration.json',
train_dataset=self.train_dataset,
eval_dataset=self.test_dataset,
work_dir='/Users/running_you/.cache/modelscope/hub/work/mnli')
trainer = build_trainer(name=Trainers.ofa_tasks, default_args=kwargs)
os.makedirs(trainer.work_dir, exist_ok=True)
trainer.train()
assert len(