modify ofatrainer
@@ -24,12 +24,13 @@ from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
 @TRAINERS.register_module(module_name=Trainers.ofa_tasks)
 class OFATrainer(EpochBasedTrainer):
 
-    def __init__(self, model: str, *args, **kwargs):
+    def __init__(self, model: str, cfg_file, work_dir, train_dataset,
+                 eval_dataset, *args, **kwargs):
         model = Model.from_pretrained(model)
         model_dir = model.model_dir
-        cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
+        # cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
         cfg = Config.from_file(cfg_file)
-        dataset = self._build_dataset_with_config(cfg)
+        # dataset = self._build_dataset_with_config(cfg)
         preprocessor = {
             ConfigKeys.train:
                 OfaPreprocessor(
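With this hunk, OFATrainer no longer resolves cfg_file from the model directory or builds datasets from the configuration; the caller supplies cfg_file, work_dir, train_dataset, and eval_dataset explicitly. A minimal sketch of the new call site — paths are placeholders, and it assumes the dataset columns already match what OfaPreprocessor expects (the test further below renames GLUE columns first):

    from modelscope.metainfo import Trainers
    from modelscope.msdatasets import MsDataset
    from modelscope.trainers import build_trainer

    # Small illustrative splits, mirroring the test in this commit.
    train_ds = MsDataset.load(
        dataset_name='glue', subset_name='mnli',
        namespace='modelscope', split='train[:100]')
    eval_ds = MsDataset.load(
        dataset_name='glue', subset_name='mnli',
        namespace='modelscope', split='validation_matched[:8]')

    trainer = build_trainer(
        name=Trainers.ofa_tasks,
        default_args=dict(
            model='damo/ofa_text-classification_mnli_large_en',
            cfg_file='/path/to/configuration.json',  # no longer read from model_dir
            work_dir='/path/to/work_dir',            # no longer cfg.train.work_dir
            train_dataset=train_ds,
            eval_dataset=eval_ds))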
@@ -41,7 +42,7 @@ class OFATrainer(EpochBasedTrainer):
         # use torchrun launch
         world_size = int(os.environ.get('WORLD_SIZE', 1))
         epoch_steps = math.ceil(
-            len(dataset['train']) /  # noqa
+            len(train_dataset) /  # noqa
             (cfg.train.dataloader.batch_size_per_gpu * world_size))  # noqa
         cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs
         cfg.train.criterion.tokenizer = model.tokenizer
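For concreteness, the step count above is ceil(len(train_dataset) / (batch_size_per_gpu * world_size)), now taken from the dataset passed in rather than one built from the config. A worked example using the 100-sample train split from the test below; batch_size_per_gpu=2, WORLD_SIZE=4, and max_epochs=3 are assumed values for illustration, not from this commit:

    import math

    world_size = 4                                   # e.g. torchrun --nproc_per_node=4
    epoch_steps = math.ceil(100 / (2 * world_size))  # ceil(100 / 8) = 13
    num_train_steps = epoch_steps * 3                # 13 * 3 = 39 scheduler steps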
@@ -68,11 +69,11 @@ class OFATrainer(EpochBasedTrainer):
             cfg_file=cfg_file,
             model=model,
             data_collator=collator,
-            train_dataset=dataset['train'],
-            eval_dataset=dataset['valid'],
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
             preprocessor=preprocessor,
             optimizers=(optimizer, lr_scheduler),
-            work_dir=cfg.train.work_dir,
+            work_dir=work_dir,
             *args,
             **kwargs,
         )
@@ -3,22 +3,51 @@ import glob
 import os
 import os.path as osp
 import shutil
 import tempfile
 import unittest
 
 from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
+from modelscope.utils.constant import DownloadMode
 from modelscope.utils.test_utils import test_level
 
 
 class TestOfaTrainer(unittest.TestCase):
 
+    def setUp(self):
+        column_map = {'premise': 'text', 'hypothesis': 'text2'}
+        data_train = MsDataset.load(
+            dataset_name='glue',
+            subset_name='mnli',
+            namespace='modelscope',
+            split='train[:100]',
+            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
+        self.train_dataset = MsDataset.from_hf_dataset(
+            data_train._hf_ds.rename_columns(column_map))
+        data_eval = MsDataset.load(
+            dataset_name='glue',
+            subset_name='mnli',
+            namespace='modelscope',
+            split='validation_matched[:8]',
+            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
+        self.test_dataset = MsDataset.from_hf_dataset(
+            data_eval._hf_ds.rename_columns(column_map))
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_trainer(self):
+        os.environ['LOCAL_RANK'] = '0'
         model_id = 'damo/ofa_text-classification_mnli_large_en'
-        default_args = {'model': model_id}
-        trainer = build_trainer(
-            name=Trainers.ofa_tasks, default_args=default_args)
+
+        kwargs = dict(
+            model=model_id,
+            cfg_file=
+            '/Users/running_you/.cache/modelscope/hub/damo/ofa_text-classification_mnli_large_en//configuration.json',
+            train_dataset=self.train_dataset,
+            eval_dataset=self.test_dataset,
+            work_dir='/Users/running_you/.cache/modelscope/hub/work/mnli')
+
+        trainer = build_trainer(name=Trainers.ofa_tasks, default_args=kwargs)
         os.makedirs(trainer.work_dir, exist_ok=True)
         trainer.train()
         assert len(
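The new setUp maps GLUE/MNLI's premise and hypothesis columns onto text and text2, presumably the field names the OFA text-classification preprocessor expects. A standalone sketch of the same rename_columns operation using Hugging Face datasets directly — the example row is made up; in the test the data comes from MsDataset.load:

    from datasets import Dataset

    hf_ds = Dataset.from_dict({
        'premise': ['A soccer game with multiple males playing.'],
        'hypothesis': ['Some men are playing a sport.'],
        'label': [0],
    })
    renamed = hf_ds.rename_columns({'premise': 'text', 'hypothesis': 'text2'})
    print(renamed.column_names)  # ['text', 'text2', 'label']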