mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 13:15:06 +02:00
Refactor hooks
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11651547
This commit is contained in:
@@ -9,7 +9,6 @@ from tensorflow.python.tools import freeze_graph
|
||||
from modelscope.exporters.builder import EXPORTERS
|
||||
from modelscope.exporters.tf_model_exporter import TfModelExporter
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.pipelines.nlp.translation_pipeline import TranslationPipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.test_utils import compare_arguments_nested
|
||||
@@ -18,7 +17,6 @@ logger = get_logger(__name__)
|
||||
|
||||
if tf.__version__ >= '2.0':
|
||||
tf = tf.compat.v1
|
||||
tf.disable_eager_execution()
|
||||
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
@@ -27,7 +25,10 @@ tf.logging.set_verbosity(tf.logging.INFO)
|
||||
class CsanmtForTranslationExporter(TfModelExporter):
|
||||
|
||||
def __init__(self, model=None):
|
||||
tf.disable_eager_execution()
|
||||
super().__init__(model)
|
||||
|
||||
from modelscope.pipelines.nlp.translation_pipeline import TranslationPipeline
|
||||
self.pipeline = TranslationPipeline(self.model)
|
||||
|
||||
def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Callable, Dict, Mapping
|
||||
|
||||
import tensorflow as tf
|
||||
@@ -30,13 +29,11 @@ class TfModelExporter(Exporter):
|
||||
self._tf2_export_onnx(model, onnx_file, opset=opset, **kwargs)
|
||||
return {'model': onnx_file}
|
||||
|
||||
@abstractmethod
|
||||
def export_saved_model(self, output_dir: str, **kwargs):
|
||||
pass
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def export_frozen_graph_def(self, output_dir: str, **kwargs):
|
||||
pass
|
||||
raise NotImplementedError()
|
||||
|
||||
def _tf2_export_onnx(self,
|
||||
model,
|
||||
|
||||
@@ -1079,6 +1079,7 @@ class Hooks(object):
|
||||
# train
|
||||
EarlyStopHook = 'EarlyStopHook'
|
||||
DeepspeedHook = 'DeepspeedHook'
|
||||
MegatronHook = 'MegatronHook'
|
||||
|
||||
|
||||
class LR_Schedulers(object):
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -94,7 +95,8 @@ class TorchModel(Model, torch.nn.Module):
|
||||
def save_pretrained(self,
|
||||
target_folder: Union[str, os.PathLike],
|
||||
save_checkpoint_names: Union[str, List[str]] = None,
|
||||
save_function: Callable = save_checkpoint,
|
||||
save_function: Callable = partial(
|
||||
save_checkpoint, with_meta=False),
|
||||
config: Optional[dict] = None,
|
||||
save_config_function: Callable = save_configuration,
|
||||
**kwargs):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Union
|
||||
|
||||
from modelscope.metainfo import Models, Preprocessors, TaskModels
|
||||
@@ -348,14 +347,13 @@ class Preprocessor(ABC):
|
||||
|
||||
Args:
|
||||
target_folder (Union[str, os.PathLike]):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
|
||||
config (Optional[dict], optional):
|
||||
The config for the configuration.json
|
||||
The config for the configuration.json
|
||||
|
||||
save_config_function (Callable): The function used to save the configuration, call this function
|
||||
after the config is updated.
|
||||
|
||||
"""
|
||||
if config is None and hasattr(self, 'cfg'):
|
||||
config = self.cfg
|
||||
|
||||
@@ -27,9 +27,9 @@ class ReferringVideoObjectSegmentationTrainer(EpochBasedTrainer):
|
||||
super().train(*args, **kwargs)
|
||||
|
||||
def evaluate(self, checkpoint_path=None):
|
||||
if checkpoint_path is not None and os.path.isfile(checkpoint_path):
|
||||
from modelscope.trainers.hooks import CheckpointHook
|
||||
CheckpointHook.load_checkpoint(checkpoint_path, self)
|
||||
if checkpoint_path is not None:
|
||||
from modelscope.trainers.hooks import LoadCheckpointHook
|
||||
LoadCheckpointHook.load_checkpoint(checkpoint_path, self)
|
||||
self.model.eval()
|
||||
self._mode = ModeKeys.EVAL
|
||||
if self.eval_dataset is None:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
@@ -155,9 +156,10 @@ class EasyCVEpochBasedTrainer(EpochBasedTrainer):
|
||||
|
||||
def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
|
||||
if self.cfg.get('parallel', None) is not None:
|
||||
self.cfg.parallel.update(
|
||||
dp_cfg = deepcopy(self.cfg['parallel'])
|
||||
dp_cfg.update(
|
||||
dict(module=model, device_ids=[torch.cuda.current_device()]))
|
||||
return build_parallel(self.cfg.parallel)
|
||||
return build_parallel(dp_cfg)
|
||||
|
||||
dp_cfg = dict(
|
||||
type='MMDistributedDataParallel',
|
||||
|
||||
@@ -19,7 +19,7 @@ class AddLrLogHook(LrSchedulerHook):
|
||||
def before_run(self, trainer):
|
||||
pass
|
||||
|
||||
def before_train_iter(self, trainer):
|
||||
def after_train_iter(self, trainer):
|
||||
trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer)
|
||||
|
||||
def before_train_epoch(self, trainer):
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import importlib
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from shutil import rmtree
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from modelscope import __version__
|
||||
from modelscope.metainfo import Hooks, Pipelines
|
||||
@@ -38,6 +40,10 @@ class CheckpointHook(Hook):
|
||||
|
||||
PRIORITY = Priority.LOW
|
||||
|
||||
TRAINER_STATE_SUFFIX = '_trainer_state.pth'
|
||||
|
||||
MODEL_STATE_SUFFIX = '.pth'
|
||||
|
||||
def __init__(self,
|
||||
interval=0,
|
||||
by_epoch=True,
|
||||
@@ -63,8 +69,8 @@ class CheckpointHook(Hook):
|
||||
if not self.save_dir:
|
||||
self.save_dir = trainer.work_dir
|
||||
|
||||
if not os.path.exists(self.save_dir) and is_master():
|
||||
os.makedirs(self.save_dir)
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
|
||||
if not hasattr(trainer, 'logger'):
|
||||
self.logger = get_logger()
|
||||
@@ -72,34 +78,85 @@ class CheckpointHook(Hook):
|
||||
self.logger = trainer.logger
|
||||
|
||||
if is_master():
|
||||
output_dir = os.path.join(self.save_dir, self.output_sub_dir)
|
||||
# only global master prepares the output folder
|
||||
self.prepare_output(trainer, output_dir)
|
||||
self.logger.info(f'Checkpoints will be saved to {self.save_dir}')
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
if not self.by_epoch:
|
||||
return
|
||||
|
||||
if self._should_save(trainer):
|
||||
if is_master() or trainer.cfg.model.get('model_parallel_size',
|
||||
1) != 1:
|
||||
if self._should_save(trainer) and self.should_save_on_rank(trainer):
|
||||
if is_master():
|
||||
self.logger.info(
|
||||
f'Saving checkpoint at {trainer.epoch + 1} epoch')
|
||||
self._save_checkpoint(trainer)
|
||||
self._save_checkpoint(trainer)
|
||||
|
||||
def after_train_iter(self, trainer):
|
||||
if self.by_epoch:
|
||||
return
|
||||
|
||||
if self._should_save(trainer) and self.should_save_on_rank(trainer):
|
||||
if is_master():
|
||||
self.logger.info(
|
||||
f'Saving checkpoint at {trainer.iter + 1} epoch')
|
||||
self._save_checkpoint(trainer)
|
||||
|
||||
def _save_checkpoint(self, trainer):
|
||||
if self.by_epoch:
|
||||
cur_save_name = os.path.join(
|
||||
self.save_dir, f'{LogKeys.EPOCH}_{trainer.epoch + 1}.pth')
|
||||
else:
|
||||
cur_save_name = os.path.join(
|
||||
self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')
|
||||
cur_save_name = extend_save_name_for_parallel(cur_save_name)
|
||||
"""Save checkpoint files and remove obsolete ones
|
||||
"""
|
||||
|
||||
if self.by_epoch:
|
||||
checkpoint_path_prefix = os.path.join(
|
||||
self.save_dir, f'{LogKeys.EPOCH}_{trainer.epoch + 1}')
|
||||
else:
|
||||
checkpoint_path_prefix = os.path.join(
|
||||
self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}')
|
||||
|
||||
meta = self._create_training_state(trainer)
|
||||
self.save_checkpoints(trainer, checkpoint_path_prefix,
|
||||
self.output_sub_dir, meta)
|
||||
self.history_checkpoints.append(checkpoint_path_prefix)
|
||||
self._remove_obsolete_checkpoints(trainer)
|
||||
|
||||
def _remove_obsolete_checkpoints(self, trainer):
|
||||
if self.max_checkpoint_num is not None and \
|
||||
len(self.history_checkpoints) > self.max_checkpoint_num:
|
||||
history_checkpoints = [ckpt for ckpt in self.history_checkpoints]
|
||||
self.history_checkpoints.clear()
|
||||
for i, checkpoint_path_prefix in enumerate(history_checkpoints):
|
||||
if i < len(history_checkpoints) - self.max_checkpoint_num:
|
||||
self.logger.info(
|
||||
f'deleting checkpoint: {checkpoint_path_prefix}')
|
||||
self.remove_checkpoints(
|
||||
trainer, checkpoint_path_prefix=checkpoint_path_prefix)
|
||||
else:
|
||||
self.history_checkpoints.append(checkpoint_path_prefix)
|
||||
|
||||
def _should_save(self, trainer):
|
||||
if self.by_epoch:
|
||||
check_last = self.is_last_epoch
|
||||
check_frequency = self.every_n_epochs
|
||||
else:
|
||||
check_last = self.is_last_iter
|
||||
check_frequency = self.every_n_iters
|
||||
|
||||
if check_frequency(trainer,
|
||||
self.interval) or (self.save_last
|
||||
and check_last(trainer)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _create_training_state(self, trainer):
|
||||
self.rng_state = {
|
||||
'random': random.getstate(),
|
||||
'numpy': np.random.get_state(),
|
||||
'cpu': torch.random.get_rng_state(),
|
||||
'cuda': torch.cuda.get_rng_state_all(),
|
||||
}
|
||||
|
||||
# keep epoch/iter/inner_iter/random_state
|
||||
meta = {
|
||||
'epoch': trainer.epoch,
|
||||
'iter': trainer.iter + 1,
|
||||
@@ -107,6 +164,7 @@ class CheckpointHook(Hook):
|
||||
'rng_state': self.rng_state,
|
||||
}
|
||||
|
||||
# keep hooks state
|
||||
i = 0
|
||||
for hook in trainer.hooks:
|
||||
if hasattr(hook, 'state_dict') and getattr(hook, '_should_save',
|
||||
@@ -114,54 +172,13 @@ class CheckpointHook(Hook):
|
||||
meta[f'{hook.__class__}-{i}'] = hook.state_dict()
|
||||
i += 1
|
||||
|
||||
save_checkpoint(
|
||||
trainer.model,
|
||||
cur_save_name,
|
||||
trainer.optimizer,
|
||||
trainer.lr_scheduler,
|
||||
meta=meta)
|
||||
if (self.is_last_epoch(trainer)
|
||||
and self.by_epoch) or (self.is_last_iter(trainer)
|
||||
and not self.by_epoch):
|
||||
self._save_pretrained(trainer)
|
||||
return meta
|
||||
|
||||
self.history_checkpoints.append(cur_save_name)
|
||||
self.remove_obsolete_checkpoints()
|
||||
|
||||
def remove_obsolete_checkpoints(self):
|
||||
if self.max_checkpoint_num is not None and \
|
||||
len(self.history_checkpoints) > self.max_checkpoint_num:
|
||||
history_checkpoints = [ckpt for ckpt in self.history_checkpoints]
|
||||
self.history_checkpoints.clear()
|
||||
for i, ckpt_file in enumerate(history_checkpoints):
|
||||
if i < len(history_checkpoints) - self.max_checkpoint_num:
|
||||
if os.path.isfile(ckpt_file):
|
||||
os.remove(ckpt_file)
|
||||
else:
|
||||
self.history_checkpoints.append(ckpt_file)
|
||||
|
||||
def _save_pretrained(self, trainer):
|
||||
output_dir = os.path.join(self.save_dir, self.output_sub_dir)
|
||||
from modelscope.trainers.parallel.utils import is_parallel
|
||||
|
||||
if is_parallel(trainer.model):
|
||||
model = trainer.model.module
|
||||
else:
|
||||
model = trainer.model
|
||||
|
||||
config = trainer.cfg.to_dict()
|
||||
# override pipeline by tasks name after finetune done,
|
||||
# avoid case like fill mask pipeline with a text cls task
|
||||
if config['task'] in [
|
||||
getattr(Pipelines, attr) for attr in dir(Pipelines)
|
||||
if not attr.startswith('__')
|
||||
]:
|
||||
# TODO a temp fix to avoid pipeline_name and task mismatch
|
||||
config['pipeline'] = {'type': config['task']}
|
||||
|
||||
# remove parallel module that is not JSON serializable
|
||||
if 'parallel' in config and 'module' in config['parallel']:
|
||||
del config['parallel']['module']
|
||||
@staticmethod
|
||||
def copy_files_and_dump_config(trainer, output_dir, config, bin_file):
|
||||
"""Copy useful files to target output folder and dumps the target configuration.json.
|
||||
"""
|
||||
model = trainer.unwrap_module(trainer.model)
|
||||
|
||||
class SaveConfig:
|
||||
|
||||
@@ -178,20 +195,14 @@ class CheckpointHook(Hook):
|
||||
save_config_fn = SaveConfig(output_dir, config)
|
||||
|
||||
if hasattr(model, 'save_pretrained'):
|
||||
# Now support two binary files: pytorch_model.bin and pytorch_model.pt
|
||||
default_bin_file = ModelFile.TORCH_MODEL_BIN_FILE
|
||||
if hasattr(
|
||||
model,
|
||||
'model_dir') and ModelFile.TORCH_MODEL_FILE in os.listdir(
|
||||
model.model_dir):
|
||||
default_bin_file = ModelFile.TORCH_MODEL_FILE
|
||||
# Save pretrained of model, skip saving checkpoint
|
||||
model.save_pretrained(
|
||||
output_dir,
|
||||
default_bin_file,
|
||||
save_function=save_checkpoint,
|
||||
bin_file,
|
||||
save_function=lambda *args, **kwargs: None,
|
||||
config=save_config_fn.config,
|
||||
save_config_function=save_config_fn,
|
||||
with_meta=False)
|
||||
save_config_function=save_config_fn)
|
||||
|
||||
if trainer.train_preprocessor is not None:
|
||||
trainer.train_preprocessor.save_pretrained(
|
||||
output_dir,
|
||||
@@ -204,30 +215,141 @@ class CheckpointHook(Hook):
|
||||
save_config_function=save_config_fn)
|
||||
save_config_fn.save_config()
|
||||
|
||||
def after_train_iter(self, trainer):
|
||||
if self.by_epoch:
|
||||
return
|
||||
@staticmethod
|
||||
def _bin_file(model):
|
||||
"""Get bin file path.
|
||||
"""
|
||||
default_bin_file = ModelFile.TORCH_MODEL_BIN_FILE
|
||||
if hasattr(model,
|
||||
'model_dir') and ModelFile.TORCH_MODEL_FILE in os.listdir(
|
||||
model.model_dir):
|
||||
default_bin_file = ModelFile.TORCH_MODEL_FILE
|
||||
return default_bin_file
|
||||
|
||||
if self._should_save(trainer):
|
||||
if is_master() or trainer.cfg.model.get('model_parallel_size',
|
||||
1) != 1:
|
||||
self.logger.info(
|
||||
f'Saving checkpoint at {trainer.iter + 1} iterations')
|
||||
self._save_checkpoint(trainer)
|
||||
@Hook.overload_func(name='CheckpointHook.prepare_output')
|
||||
def prepare_output(self, trainer, output_dir):
|
||||
"""Prepares the output of target folder.
|
||||
|
||||
def _should_save(self, trainer):
|
||||
if self.by_epoch:
|
||||
check_last = self.is_last_epoch
|
||||
check_frequency = self.every_n_epochs
|
||||
else:
|
||||
check_last = self.is_last_iter
|
||||
check_frequency = self.every_n_iters
|
||||
This is a strategic function which can be registered by other hook's function.
|
||||
|
||||
if check_frequency(trainer,
|
||||
self.interval) or (self.save_last
|
||||
and check_last(trainer)):
|
||||
return True
|
||||
return False
|
||||
Args:
|
||||
trainer: The trainer instance.
|
||||
output_dir: The target folder used in inference.
|
||||
"""
|
||||
model = trainer.unwrap_module(trainer.model)
|
||||
config = trainer.cfg.to_dict()
|
||||
|
||||
# override pipeline by tasks name after finetune done,
|
||||
# avoid case like fill mask pipeline with a text cls task
|
||||
if config['task'] in [
|
||||
getattr(Pipelines, attr) for attr in dir(Pipelines)
|
||||
if not attr.startswith('__')
|
||||
]:
|
||||
# TODO a temp fix to avoid pipeline_name and task mismatch
|
||||
config['pipeline'] = {'type': config['task']}
|
||||
|
||||
self.copy_files_and_dump_config(trainer, output_dir, config,
|
||||
self._bin_file(model))
|
||||
|
||||
def link(self, model, src_file, output_dir):
|
||||
"""Links the src bin file to the output folder.
|
||||
|
||||
Args:
|
||||
model: The model instance.
|
||||
src_file: The src bin file path.
|
||||
output_dir: The target folder used in inference.
|
||||
"""
|
||||
|
||||
bin_file = self._bin_file(model)
|
||||
dest_file = os.path.join(output_dir, bin_file)
|
||||
if os.path.isfile(dest_file):
|
||||
os.unlink(dest_file)
|
||||
|
||||
os.link(src_file, dest_file)
|
||||
|
||||
def save_trainer_state(self, trainer, model, train_state_file, meta):
|
||||
"""Save the trainer state, including optimizer/lr_scheduler's state dict, random states etc.
|
||||
|
||||
Args:
|
||||
trainer: The trainer instance.
|
||||
model: The model instance.
|
||||
train_state_file: The target file name for saving trainer states.
|
||||
meta: Some extra meta info.
|
||||
"""
|
||||
save_checkpoint(
|
||||
model,
|
||||
train_state_file,
|
||||
trainer.optimizer,
|
||||
trainer.lr_scheduler,
|
||||
meta=meta,
|
||||
with_model=False)
|
||||
|
||||
def save_model_state(self, model, model_file):
|
||||
"""Save the model state.
|
||||
|
||||
Args:
|
||||
model: The model instance.
|
||||
model_file: The target file name for saving model states.
|
||||
"""
|
||||
save_checkpoint(
|
||||
model, model_file, None, None, meta=None, with_meta=False)
|
||||
|
||||
@Hook.overload_func(name='CheckpointHook.save_checkpoints')
|
||||
def save_checkpoints(self,
|
||||
trainer,
|
||||
checkpoint_path_prefix,
|
||||
output_sub_dir,
|
||||
meta=None):
|
||||
"""Save the state dict for trainer and model.
|
||||
|
||||
This is a strategic function which can be registered by other hook's function.
|
||||
|
||||
Args:
|
||||
trainer(`EpochBasedTrainer`): The trainer instance.
|
||||
checkpoint_path_prefix(`str`): The saving dir with a prefix.
|
||||
like: /tmp/test/epoch_0
|
||||
output_sub_dir(`str`): The sub-dir in the saving dir used in inference.
|
||||
meta: (`dict`): The meta info needed to be saved into files.
|
||||
"""
|
||||
model = trainer.unwrap_module(trainer.model)
|
||||
_model_file, _train_state_file = _get_state_file_name(
|
||||
checkpoint_path_prefix)
|
||||
|
||||
# Save pth file without model state_dict
|
||||
self.save_trainer_state(trainer, model, _train_state_file, meta)
|
||||
self.save_model_state(model, _model_file)
|
||||
output_dir = os.path.join(self.save_dir, output_sub_dir)
|
||||
self.link(model, _model_file, output_dir)
|
||||
|
||||
@Hook.overload_func(name='CheckpointHook.remove_checkpoints')
|
||||
def remove_checkpoints(self, trainer, checkpoint_path_prefix):
|
||||
"""Remove obsolete checkpoint files.
|
||||
|
||||
This is a strategic function which can be registered by other hook's function.
|
||||
|
||||
Args:
|
||||
trainer(`EpochBasedTrainer`): The trainer instance.
|
||||
checkpoint_path_prefix(`str`): The saving dir with a prefix.
|
||||
like: /tmp/test/epoch_0
|
||||
"""
|
||||
_model_file, _train_state_file = _get_state_file_name(
|
||||
checkpoint_path_prefix)
|
||||
if os.path.isfile(_train_state_file):
|
||||
os.remove(_train_state_file)
|
||||
|
||||
if os.path.isfile(_model_file):
|
||||
os.remove(_model_file)
|
||||
|
||||
@Hook.overload_func(name='CheckpointHook.should_save_on_rank')
|
||||
def should_save_on_rank(self, trainer):
|
||||
"""Used in ddp or other distributed training scenario, returns whether do saving in current rank.
|
||||
|
||||
This is a strategic function which can be registered by other hook's function.
|
||||
|
||||
Args:
|
||||
trainer(`EpochBasedTrainer`): The trainer instance.
|
||||
"""
|
||||
return is_master()
|
||||
|
||||
|
||||
@HOOKS.register_module(module_name=Hooks.BestCkptSaverHook)
|
||||
@@ -306,52 +428,33 @@ class BestCkptSaverHook(CheckpointHook):
|
||||
return False
|
||||
|
||||
def _save_checkpoint(self, trainer):
|
||||
cur_save_name = self.save_file_name
|
||||
if cur_save_name is None:
|
||||
checkpoint_path_prefix = self.save_file_name
|
||||
if checkpoint_path_prefix is None:
|
||||
if self.by_epoch:
|
||||
cur_save_name = os.path.join(
|
||||
checkpoint_path_prefix = os.path.join(
|
||||
self.save_dir,
|
||||
f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}.pth'
|
||||
f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}'
|
||||
)
|
||||
else:
|
||||
cur_save_name = os.path.join(
|
||||
checkpoint_path_prefix = os.path.join(
|
||||
self.save_dir,
|
||||
f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth'
|
||||
f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}'
|
||||
)
|
||||
else:
|
||||
if '.' not in cur_save_name:
|
||||
cur_save_name = f'{cur_save_name}.pth'
|
||||
cur_save_name = os.path.join(self.save_dir, cur_save_name)
|
||||
cur_save_name = extend_save_name_for_parallel(cur_save_name)
|
||||
checkpoint_path_prefix = os.path.join(self.save_dir,
|
||||
checkpoint_path_prefix)
|
||||
|
||||
meta = {
|
||||
'epoch': trainer.epoch,
|
||||
'iter': trainer.iter + 1,
|
||||
'inner_iter': trainer.inner_iter + 1,
|
||||
'rng_state': self.rng_state,
|
||||
}
|
||||
self._best_ckpt_file = checkpoint_path_prefix
|
||||
meta = self._create_training_state(trainer)
|
||||
self.save_checkpoints(trainer, checkpoint_path_prefix,
|
||||
self.output_sub_dir, meta)
|
||||
self.history_checkpoints.add(checkpoint_path_prefix)
|
||||
self._remove_obsolete_checkpoints(trainer)
|
||||
|
||||
i = 0
|
||||
for hook in trainer.hooks:
|
||||
if hasattr(hook, 'state_dict') and getattr(hook, '_should_save',
|
||||
True):
|
||||
meta[f'{hook.__class__}-{i}'] = hook.state_dict()
|
||||
i += 1
|
||||
|
||||
if os.path.isfile(cur_save_name):
|
||||
os.remove(cur_save_name)
|
||||
save_checkpoint(trainer.model, cur_save_name, trainer.optimizer,
|
||||
trainer.lr_scheduler, meta)
|
||||
self._best_ckpt_file = cur_save_name
|
||||
self._save_pretrained(trainer)
|
||||
self.history_checkpoints.add(cur_save_name)
|
||||
self.remove_obsolete_checkpoints()
|
||||
|
||||
def remove_obsolete_checkpoints(self):
|
||||
def _remove_obsolete_checkpoints(self, trainer):
|
||||
|
||||
def extract_metric_from_filename(name1):
|
||||
metric1 = float('.'.join(
|
||||
name1.split(self.metric_key)[1].split('.')[:-1]))
|
||||
metric1 = float(name1.split(self.metric_key)[1])
|
||||
if self.rule == 'max':
|
||||
return -metric1
|
||||
else:
|
||||
@@ -362,11 +465,14 @@ class BestCkptSaverHook(CheckpointHook):
|
||||
history_checkpoints = sorted(
|
||||
self.history_checkpoints, key=extract_metric_from_filename)
|
||||
self.history_checkpoints.clear()
|
||||
for i, ckpt_file in enumerate(history_checkpoints):
|
||||
for i, checkpoint_path_prefix in enumerate(history_checkpoints):
|
||||
if i < self.max_checkpoint_num:
|
||||
self.history_checkpoints.add(ckpt_file)
|
||||
elif os.path.isfile(ckpt_file):
|
||||
os.remove(ckpt_file)
|
||||
self.history_checkpoints.add(checkpoint_path_prefix)
|
||||
else:
|
||||
self.logger.info(
|
||||
f'deleting checkpoint: {checkpoint_path_prefix}')
|
||||
self.remove_checkpoints(
|
||||
trainer, checkpoint_path_prefix=checkpoint_path_prefix)
|
||||
|
||||
def state_dict(self):
|
||||
return {
|
||||
@@ -383,9 +489,9 @@ class BestCkptSaverHook(CheckpointHook):
|
||||
|
||||
def after_run(self, trainer):
|
||||
if self.restore_best:
|
||||
if is_master():
|
||||
LoadCheckpointHook.load_checkpoint(self._best_ckpt_file,
|
||||
trainer)
|
||||
# If restore_best is True, will call the LoadCheckpointHook to load the best checkpoint
|
||||
# for later evaluation or prediction.
|
||||
LoadCheckpointHook.load_checkpoint(self._best_ckpt_file, trainer)
|
||||
|
||||
|
||||
@HOOKS.register_module(module_name=Hooks.LoadCheckpointHook)
|
||||
@@ -403,21 +509,26 @@ class LoadCheckpointHook(Hook):
|
||||
checkpoint_file (str): The checkpoint file to be loaded.
|
||||
load_all_state (bool): Load all states(optimizer, epoch, lr_scheduler, random_state, etc.) when loading old
|
||||
training state file or not. The model's state dict will only be loaded if False.
|
||||
strict (bool): If strict, any unmatched keys will cause an error.
|
||||
"""
|
||||
|
||||
PRIORITY = Priority.HIGH
|
||||
|
||||
_should_save = False
|
||||
|
||||
_TWO_PTH_FILE_VERSION = '1.3.1'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_file=None,
|
||||
load_all_state=True,
|
||||
strict=False,
|
||||
):
|
||||
self.checkpoint_file = checkpoint_file
|
||||
self.rng_state = None
|
||||
self.need_load_rng_state = False
|
||||
self.load_all_state = load_all_state
|
||||
self.strict = strict
|
||||
|
||||
def before_run(self, trainer):
|
||||
if not hasattr(trainer, 'logger'):
|
||||
@@ -425,10 +536,9 @@ class LoadCheckpointHook(Hook):
|
||||
else:
|
||||
self.logger = trainer.logger
|
||||
|
||||
if self.checkpoint_file is not None and os.path.isfile(
|
||||
self.checkpoint_file):
|
||||
if self.checkpoint_file is not None:
|
||||
meta = self.load_checkpoint(self.checkpoint_file, trainer,
|
||||
self.load_all_state)
|
||||
self.load_all_state, self.strict)
|
||||
self.rng_state = meta.get('rng_state')
|
||||
self.need_load_rng_state = self.load_all_state
|
||||
|
||||
@@ -442,69 +552,136 @@ class LoadCheckpointHook(Hook):
|
||||
torch.cuda.random.set_rng_state_all(self.rng_state['cuda'])
|
||||
self.need_load_rng_state = False
|
||||
else:
|
||||
self.logger.warning(
|
||||
self.logger.info(
|
||||
'Random state cannot be found in checkpoint file, '
|
||||
'this may cause a random data order or model initialization.'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _restore_training_state(trainer, meta):
|
||||
trainer._epoch = meta.get('epoch', trainer._epoch)
|
||||
trainer._iter = meta.get('iter', trainer._iter)
|
||||
trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter)
|
||||
|
||||
i = 0
|
||||
for hook in trainer.hooks:
|
||||
if hasattr(hook, 'load_state_dict') and getattr(
|
||||
hook, '_should_save', True):
|
||||
key = f'{hook.__class__}-{i}'
|
||||
if key in meta:
|
||||
hook.load_state_dict(meta.get(key, {}))
|
||||
else:
|
||||
trainer.logger.warning(
|
||||
f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.'
|
||||
)
|
||||
i += 1
|
||||
|
||||
@classmethod
|
||||
def load_checkpoint(cls, filename, trainer, load_all_state=True):
|
||||
from modelscope.trainers.parallel.utils import is_parallel
|
||||
if is_parallel(trainer.model):
|
||||
model = trainer.model.module
|
||||
else:
|
||||
model = trainer.model
|
||||
meta = load_checkpoint(
|
||||
filename, model,
|
||||
getattr(trainer, 'optimizer', None) if load_all_state else None,
|
||||
getattr(trainer, 'lr_scheduler', None) if load_all_state else None)
|
||||
def load_checkpoint(cls,
|
||||
filename,
|
||||
trainer,
|
||||
load_all_state=True,
|
||||
strict=False):
|
||||
"""A static method to load checkpoint files.
|
||||
|
||||
Args:
|
||||
filename(str): An absolute model bin file(pth or bin) or a dir path with a file prefix(like epoch_1).
|
||||
trainer(`EpochBasedTrainer`): The trainer instance.
|
||||
load_all_state(`bool`): Load all states including the trainer states.
|
||||
strict(`bool`): Load module state dict strictly.
|
||||
|
||||
Returns:
|
||||
A dict containing the train states saved by `_create_training_state`
|
||||
"""
|
||||
meta = cls().load_checkpoints(filename, trainer, load_all_state,
|
||||
strict)
|
||||
if load_all_state:
|
||||
trainer._epoch = meta.get('epoch', trainer._epoch)
|
||||
trainer._iter = meta.get('iter', trainer._iter)
|
||||
trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter)
|
||||
cls._restore_training_state(trainer, meta)
|
||||
|
||||
i = 0
|
||||
for hook in trainer.hooks:
|
||||
if hasattr(hook, 'load_state_dict') and getattr(
|
||||
hook, '_should_save', True):
|
||||
key = f'{hook.__class__}-{i}'
|
||||
if key in meta:
|
||||
hook.load_state_dict(meta.get(key, {}))
|
||||
else:
|
||||
trainer.logger.warning(
|
||||
f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.'
|
||||
)
|
||||
i += 1
|
||||
|
||||
version = meta.get('modelscope')
|
||||
if version != __version__:
|
||||
trainer.logger.warning(
|
||||
f'The modelscope version of loaded checkpoint does not match the runtime version. '
|
||||
f'The saved version: {version}, runtime version: {__version__}'
|
||||
if meta is not None:
|
||||
_version = meta.get('modelscope')
|
||||
if _version is not None and version.parse(
|
||||
_version) < version.parse(
|
||||
LoadCheckpointHook._TWO_PTH_FILE_VERSION):
|
||||
trainer.logger.warning(
|
||||
'The unique pth file is split into a model file and '
|
||||
f'a trainer file since version {LoadCheckpointHook._TWO_PTH_FILE_VERSION},'
|
||||
'consider re-training your model or '
|
||||
'using a converting script to split the single pth file into two.'
|
||||
)
|
||||
trainer.logger.info(
|
||||
f'Checkpoint {filename} saving time: {meta.get("time")}, modelscope version: {_version}'
|
||||
)
|
||||
trainer.logger.info(
|
||||
f'Checkpoint {filename} saving time: {meta.get("time")}')
|
||||
return meta
|
||||
|
||||
@staticmethod
|
||||
def load_trainer_state(trainer, train_state_file, load_all_state):
|
||||
"""Load trainer state file.
|
||||
"""
|
||||
|
||||
optimizer = getattr(trainer, 'optimizer',
|
||||
None) if load_all_state else None
|
||||
lr_scheduler = getattr(trainer, 'lr_scheduler',
|
||||
None) if load_all_state else None
|
||||
return load_checkpoint(train_state_file, None, optimizer, lr_scheduler)
|
||||
|
||||
def load_model_state(self, trainer, model_file, strict):
|
||||
"""Load model state file.
|
||||
"""
|
||||
return load_checkpoint(model_file,
|
||||
trainer.unwrap_module(trainer.model), None,
|
||||
None)
|
||||
|
||||
@Hook.overload_func(name='LoadCheckpointHook.load_checkpoints')
|
||||
def load_checkpoints(self, checkpoint_path_prefix, trainer, load_all_state,
|
||||
strict):
|
||||
"""Load checkpoint files of trainer state and model state.
|
||||
|
||||
This is a strategic function which can be registered by other hook's function.
|
||||
|
||||
Args:
|
||||
checkpoint_path_prefix(str): The checkpoint dir with prefix or a model state file.
|
||||
Example: '/tmp/test/epoch_0' or '/tmp/test/epoch_0.pth'
|
||||
trainer(`EpochBasedTrainer`): The trainer instance.
|
||||
load_all_state(`boolean`): Load all states (else load only module states).
|
||||
strict(`boolean`): If strict, any unmatched keys will cause an error.
|
||||
|
||||
Returns:
|
||||
The meta info in json.
|
||||
"""
|
||||
_model_file, _train_state_file = _get_state_file_name(
|
||||
checkpoint_path_prefix)
|
||||
meta = {}
|
||||
if os.path.isfile(_train_state_file):
|
||||
meta = self.load_trainer_state(trainer, _train_state_file,
|
||||
load_all_state)
|
||||
else:
|
||||
print(f'No trainer state file {_train_state_file} found, skip.')
|
||||
self.load_model_state(trainer, _model_file, strict)
|
||||
return meta
|
||||
|
||||
|
||||
def extend_save_name_for_parallel(cur_save_name: str) -> str:
|
||||
"""Saving model parameters during tensor parallel training
|
||||
requires each process to save its own parameters,
|
||||
This function will try to get the local rank of the process
|
||||
and extend save name for multi-slice model.
|
||||
def _get_state_file_name(checkpoint_path_prefix):
|
||||
"""Get the default file name for state files.
|
||||
|
||||
If the input is a checkpoint dir with prefix, this function will append suffix for both checkpoint files.
|
||||
If the input is an absolute file name, this function will return it as the model file name, and append
|
||||
suffix for the trainer file name.
|
||||
|
||||
NOTE: a best checkpoint filename with float or int metric value inside
|
||||
will not be judged as having a extension file name. like: '/tmp/test/epoch_0_accuracy0.85'
|
||||
|
||||
Args:
|
||||
cur_save_name (str): Original save name.
|
||||
checkpoint_path_prefix(`str`): The checkpoint dir with prefix or a model state file with extension file name.
|
||||
like: '/tmp/test/epoch_0'
|
||||
|
||||
Returns:
|
||||
str: Extended save name.
|
||||
A tuple of model state file name and trainer state file name.
|
||||
"""
|
||||
try:
|
||||
mpu = importlib.import_module('megatron_util.mpu')
|
||||
tp_world_size = mpu.get_tensor_model_parallel_world_size()
|
||||
if tp_world_size == 1:
|
||||
return cur_save_name
|
||||
mp_rank = mpu.get_tensor_model_parallel_rank()
|
||||
return cur_save_name.replace('.', '_mp_rank_{:02d}.'.format(mp_rank))
|
||||
except (ImportError, AssertionError):
|
||||
return cur_save_name
|
||||
base, ext = os.path.splitext(checkpoint_path_prefix)
|
||||
if len(ext) == 0 or re.match(r'^\d+$', ext[1:]):
|
||||
return checkpoint_path_prefix + CheckpointHook.MODEL_STATE_SUFFIX, \
|
||||
checkpoint_path_prefix + CheckpointHook.TRAINER_STATE_SUFFIX
|
||||
else:
|
||||
return checkpoint_path_prefix, base + CheckpointHook.TRAINER_STATE_SUFFIX.split(
|
||||
'.')[0] + '.' + ext[1:]
|
||||
|
||||
@@ -1,35 +1,149 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from types import MethodType
|
||||
import shutil
|
||||
|
||||
import deepspeed
|
||||
from megatron_util import mpu
|
||||
import torch
|
||||
from deepspeed import DeepSpeedEngine
|
||||
from megatron_util import mpu, print_rank_0
|
||||
|
||||
from modelscope.metainfo import Hooks
|
||||
from modelscope.trainers.hooks import (BestCkptSaverHook, CheckpointHook,
|
||||
LrSchedulerHook, NoneLrSchedulerHook,
|
||||
NoneOptimizerHook, OptimizerHook)
|
||||
from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
|
||||
from modelscope.utils.constant import LogKeys, ModelFile
|
||||
from modelscope.utils.torch_utils import is_master
|
||||
from .builder import HOOKS
|
||||
from .hook import Hook
|
||||
from .priority import Priority
|
||||
from modelscope.trainers.hooks.builder import HOOKS
|
||||
from modelscope.trainers.hooks.hook import Hook
|
||||
from modelscope.trainers.hooks.priority import Priority
|
||||
from modelscope.utils.checkpoint import save_checkpoint
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .checkpoint_hook import CheckpointHook, LoadCheckpointHook
|
||||
from .megatron_hook import MegatronHook
|
||||
|
||||
|
||||
@HOOKS.register_module(module_name=Hooks.DeepspeedHook)
|
||||
class DeepspeedHook(Hook):
|
||||
class DeepspeedHook(MegatronHook):
|
||||
PRIORITY = Priority.VERY_HIGH
|
||||
|
||||
def __init__(self,
|
||||
deepspeed_activation_checkpointing=True,
|
||||
save_zero_checkpoint=False,
|
||||
loss_key='loss'):
|
||||
with_mpu=True):
|
||||
self.save_zero_checkpoint = save_zero_checkpoint
|
||||
self.loss_key = loss_key
|
||||
self.deepspeed_activation_checkpointing = deepspeed_activation_checkpointing
|
||||
# TODO without mpu
|
||||
self.with_mpu = with_mpu
|
||||
assert with_mpu, 'DeepspeedHook now is only for mpu models.'
|
||||
|
||||
def register_strategy(self):
|
||||
Hook.overload(name='OptimizerHook.backward', function=self.backward)
|
||||
Hook.overload(
|
||||
name='OptimizerHook.initialize_optimizer', function=self.idle)
|
||||
Hook.overload(name='LrSchedulerHook.step', function=self.idle)
|
||||
Hook.overload(
|
||||
name='CheckpointHook.save_checkpoints',
|
||||
function=self.save_checkpoints)
|
||||
Hook.overload(
|
||||
name='LoadCheckpointHook.load_checkpoints',
|
||||
function=self.load_checkpoints)
|
||||
Hook.overload(
|
||||
name='CheckpointHook.remove_checkpoints',
|
||||
function=self.remove_checkpoints)
|
||||
Hook.overload(
|
||||
name='CheckpointHook.prepare_output', function=self.prepare_output)
|
||||
if self.with_mpu:
|
||||
Hook.overload(
|
||||
name='CheckpointHook.should_save_on_rank',
|
||||
function=self.should_save_on_rank)
|
||||
|
||||
def backward(self, trainer, loss_keys, cumulative_iters, grad_clip):
|
||||
# assert cumulative_iters == 1, 'DeepSpeed only support cumulative_iters=1'
|
||||
# The `trainer.model` here is actually a deepspeed engine object.
|
||||
# backward step
|
||||
for k in loss_keys:
|
||||
loss = trainer.train_outputs[k]
|
||||
trainer.model.backward(loss)
|
||||
|
||||
# update parameters
|
||||
trainer.model.step()
|
||||
|
||||
def idle(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def save_checkpoints(self,
|
||||
trainer,
|
||||
checkpoint_path_prefix,
|
||||
output_sub_dir,
|
||||
meta=None):
|
||||
model = trainer.unwrap_module(trainer.model)
|
||||
_train_state_file = checkpoint_path_prefix + self.rank_name(
|
||||
) + CheckpointHook.TRAINER_STATE_SUFFIX
|
||||
# Save pth file without model state_dict
|
||||
save_checkpoint(
|
||||
model, _train_state_file, None, None, meta=meta, with_model=False)
|
||||
|
||||
save_dir = os.path.dirname(checkpoint_path_prefix)
|
||||
prefix = os.path.basename(checkpoint_path_prefix)
|
||||
trainer.model.save_checkpoint(save_dir, prefix)
|
||||
|
||||
bin_file = self.get_bin_file()
|
||||
src_file = os.path.join(checkpoint_path_prefix, bin_file)
|
||||
dest_file = os.path.join(save_dir, output_sub_dir, self._BIN_FILE_DIR,
|
||||
bin_file)
|
||||
if os.path.isfile(dest_file):
|
||||
os.unlink(dest_file)
|
||||
|
||||
os.link(src_file, dest_file)
|
||||
|
||||
def remove_checkpoints(self, trainer, checkpoint_path_prefix):
|
||||
_train_state_file = checkpoint_path_prefix + self.rank_name(
|
||||
) + CheckpointHook.TRAINER_STATE_SUFFIX
|
||||
if os.path.isfile(_train_state_file):
|
||||
os.remove(_train_state_file)
|
||||
|
||||
shutil.rmtree(checkpoint_path_prefix, ignore_errors=True)
|
||||
|
||||
def load_checkpoints(self, checkpoint_path_prefix, trainer, load_all_state,
|
||||
strict):
|
||||
assert os.path.isdir(checkpoint_path_prefix)
|
||||
path = os.path.dirname(checkpoint_path_prefix)
|
||||
tag = os.path.basename(checkpoint_path_prefix)
|
||||
|
||||
meta = {}
|
||||
_train_state_file = checkpoint_path_prefix + self.rank_name(
|
||||
) + CheckpointHook.TRAINER_STATE_SUFFIX
|
||||
if os.path.isfile(_train_state_file):
|
||||
meta = LoadCheckpointHook.load_trainer_state(
|
||||
trainer, _train_state_file, load_all_state)
|
||||
|
||||
if isinstance(trainer.model, DeepSpeedEngine):
|
||||
# DeepSpeedEngine is initialized
|
||||
trainer.model.load_checkpoint(
|
||||
path,
|
||||
tag,
|
||||
load_module_strict=strict,
|
||||
load_module_only=not load_all_state,
|
||||
)
|
||||
else:
|
||||
# in eval or prediction
|
||||
save_dir = checkpoint_path_prefix
|
||||
bin_file = self.get_bin_file()
|
||||
model_file = os.path.join(save_dir, bin_file)
|
||||
checkpoint = torch.load(
|
||||
model_file, map_location=lambda storage, loc: storage)
|
||||
checkpoint = checkpoint['module']
|
||||
model_dict = trainer.unwrap_module(trainer.model).state_dict()
|
||||
for key in checkpoint:
|
||||
if key not in model_dict.keys():
|
||||
print_rank_0('Skip key: ' + key)
|
||||
else:
|
||||
print_rank_0('Loading key: ' + key)
|
||||
trainer.unwrap_module(trainer.model).load_state_dict(
|
||||
checkpoint, strict=strict)
|
||||
return meta
|
||||
|
||||
def before_run(self, trainer):
|
||||
if not hasattr(trainer, 'logger'):
|
||||
self.logger = get_logger()
|
||||
else:
|
||||
self.logger = trainer.logger
|
||||
|
||||
# deepspeed init
|
||||
args = trainer.cfg.train
|
||||
args.deepspeed_config = os.path.join(trainer.model_dir,
|
||||
@@ -45,9 +159,7 @@ class DeepspeedHook(Hook):
|
||||
trainer.model.save_zero_checkpoint = self.save_zero_checkpoint
|
||||
|
||||
if self.deepspeed_activation_checkpointing:
|
||||
model = trainer.model
|
||||
while hasattr(model, 'module'):
|
||||
model = model.module
|
||||
model = trainer.unwrap_module(trainer.model)
|
||||
deepspeed.checkpointing.configure(
|
||||
mpu,
|
||||
deepspeed_config=args.deepspeed_config,
|
||||
@@ -56,61 +168,3 @@ class DeepspeedHook(Hook):
|
||||
mpu.checkpoint = deepspeed.checkpointing.checkpoint
|
||||
mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
|
||||
mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed
|
||||
|
||||
# modify hooks
|
||||
for i, hook in enumerate(trainer._hooks):
|
||||
# backward & step
|
||||
if isinstance(hook, OptimizerHook):
|
||||
trainer._hooks[i] = NoneOptimizerHook()
|
||||
if isinstance(hook, LrSchedulerHook):
|
||||
trainer._hooks[i] = NoneLrSchedulerHook()
|
||||
|
||||
# save checkpoint
|
||||
if isinstance(hook, CheckpointHook):
|
||||
|
||||
def _save_checkpoint(self, trainer):
|
||||
if self.by_epoch:
|
||||
cur_save_dir = os.path.join(
|
||||
self.save_dir,
|
||||
f'{LogKeys.EPOCH}_{trainer.epoch + 1}')
|
||||
else:
|
||||
cur_save_dir = os.path.join(
|
||||
self.save_dir,
|
||||
f'{LogKeys.ITER}_{trainer.iter + 1}')
|
||||
if (self.is_last_epoch(trainer)
|
||||
and self.by_epoch) or (self.is_last_iter(trainer)
|
||||
and not self.by_epoch):
|
||||
cur_save_dir = os.path.join(self.save_dir,
|
||||
ModelFile.TRAIN_OUTPUT_DIR)
|
||||
trainer.model.save_checkpoint(cur_save_dir)
|
||||
|
||||
trainer._hooks[i]._save_checkpoint = MethodType(
|
||||
_save_checkpoint, trainer._hooks[i])
|
||||
|
||||
if isinstance(hook, BestCkptSaverHook):
|
||||
|
||||
def _save_checkpoint(self, trainer):
|
||||
if self.by_epoch:
|
||||
cur_save_dir = os.path.join(
|
||||
self.save_dir,
|
||||
f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}'
|
||||
)
|
||||
else:
|
||||
cur_save_dir = os.path.join(
|
||||
self.save_dir,
|
||||
f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth'
|
||||
)
|
||||
trainer.model.save_checkpoint(cur_save_dir)
|
||||
self._best_ckpt_file = cur_save_dir
|
||||
|
||||
trainer._hooks[i]._save_checkpoint = MethodType(
|
||||
_save_checkpoint, trainer._hooks[i])
|
||||
|
||||
def after_train_iter(self, trainer):
|
||||
# The `trainer.model` here is actually a deepspeed engine object.
|
||||
# backward step
|
||||
loss = trainer.train_outputs[self.loss_key]
|
||||
trainer.model.backward(loss)
|
||||
|
||||
# update parameters
|
||||
trainer.model.step()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from functools import wraps
|
||||
|
||||
from modelscope.utils.constant import TrainerStages
|
||||
from modelscope.utils.import_utils import is_method_overridden
|
||||
from .priority import Priority
|
||||
@@ -18,6 +20,9 @@ class Hook:
|
||||
|
||||
PRIORITY = Priority.NORMAL
|
||||
|
||||
# The strategic function dict.
|
||||
_strategies = dict()
|
||||
|
||||
def before_run(self, trainer):
|
||||
"""
|
||||
Will be called before any loop begins.
|
||||
@@ -221,3 +226,54 @@ class Hook:
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def clear_strategies():
|
||||
Hook._strategies.clear()
|
||||
|
||||
@staticmethod
|
||||
def overload(function, name=None):
|
||||
"""Register a function to a strategic function.
|
||||
|
||||
Args:
|
||||
function(`method` or `Callable`): The function instance.
|
||||
name(`str`): The name of the strategic function, which specifies by the method `consume`
|
||||
"""
|
||||
|
||||
_name = name or function.__name__
|
||||
if _name not in Hook._strategies:
|
||||
Hook._strategies[_name] = []
|
||||
|
||||
Hook._strategies[_name].append(function)
|
||||
|
||||
@staticmethod
|
||||
def overload_func(name=None):
|
||||
"""Declare a function as a strategic function, which can be replaced by some other functions.
|
||||
|
||||
This function should be used in annotations.
|
||||
|
||||
Args:
|
||||
name(str): The strategic function name.
|
||||
"""
|
||||
|
||||
def _register(function):
|
||||
|
||||
@wraps(function)
|
||||
def _call(*args, **kwargs):
|
||||
_name = name or function.__name__
|
||||
producers = Hook._strategies.get(_name, [])
|
||||
|
||||
if len(producers) == 0:
|
||||
return function(*args, **kwargs)
|
||||
else:
|
||||
if len(producers) > 1:
|
||||
raise ValueError(
|
||||
f'Multiple functions registered to {_name}, '
|
||||
f'here is the list: {producers}')
|
||||
if isinstance(args[0], Hook):
|
||||
args = args[1:]
|
||||
return producers[0](*args, **kwargs)
|
||||
|
||||
return _call
|
||||
|
||||
return _register
|
||||
|
||||
@@ -8,16 +8,13 @@ if TYPE_CHECKING:
|
||||
from .base import LoggerHook
|
||||
from .tensorboard_hook import TensorboardHook
|
||||
from .text_logger_hook import TextLoggerHook
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'base': ['LoggerHook'],
|
||||
'tensorboard_hook': ['TensorboardHook'],
|
||||
'text_logger_hook': ['TextLoggerHook']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
|
||||
@@ -9,12 +9,11 @@ import torch
|
||||
from torch import distributed as dist
|
||||
|
||||
from modelscope.metainfo import Hooks
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.trainers.hooks.builder import HOOKS
|
||||
from modelscope.trainers.hooks.logger.base import LoggerHook
|
||||
from modelscope.utils.constant import LogKeys, ModeKeys
|
||||
from modelscope.utils.json_utils import EnhancedEncoder
|
||||
from modelscope.utils.torch_utils import get_dist_info, is_master
|
||||
from modelscope.utils.torch_utils import is_master
|
||||
from .base import LoggerHook
|
||||
|
||||
|
||||
@HOOKS.register_module(module_name=Hooks.TextLoggerHook)
|
||||
@@ -22,7 +21,7 @@ class TextLoggerHook(LoggerHook):
|
||||
"""Logger hook in text, Output log to both console and local json file.
|
||||
|
||||
Args:
|
||||
by_epoch (bool, optional): Whether EpochBasedtrainer is used.
|
||||
by_epoch (bool, optional): Whether EpochBasedTrainer is used.
|
||||
Default: True.
|
||||
interval (int, optional): Logging interval (every k iterations).
|
||||
It is interval of iterations even by_epoch is true. Default: 10.
|
||||
@@ -79,9 +78,7 @@ class TextLoggerHook(LoggerHook):
|
||||
mem_mb = torch.tensor([mem / (1024 * 1024)],
|
||||
dtype=torch.int,
|
||||
device=device)
|
||||
_, world_size = get_dist_info()
|
||||
if world_size > 1 and getattr(trainer.cfg.model, 'model_parallel_size',
|
||||
1) < world_size:
|
||||
if trainer._dist:
|
||||
dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
|
||||
return mem_mb.item()
|
||||
|
||||
|
||||
@@ -17,21 +17,41 @@ class LrSchedulerHook(Hook):
|
||||
by_epoch (bool): Whether lr changes by epoch
|
||||
warmup (dict): warm up config
|
||||
"""
|
||||
PRIORITY = Priority.VERY_HIGH
|
||||
PRIORITY = Priority.LOW
|
||||
|
||||
def __init__(self, by_epoch=True, warmup=None) -> None:
|
||||
def __init__(self, by_epoch=True, warmup=None, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.by_epoch = by_epoch
|
||||
self.warmup = warmup
|
||||
self.warmup_lr_scheduler = None
|
||||
|
||||
def before_run(self, trainer):
|
||||
self.initialize_lr_scheduler(trainer)
|
||||
if self.warmup is not None:
|
||||
assert isinstance(self.warmup, dict) and 'type' in self.warmup
|
||||
self.warmup_lr_scheduler = build_lr_scheduler(
|
||||
cfg=self.warmup,
|
||||
default_args={'base_scheduler': trainer.lr_scheduler})
|
||||
|
||||
@Hook.overload_func(name='LrSchedulerHook.initialize_lr_scheduler')
|
||||
def initialize_lr_scheduler(self, trainer):
|
||||
"""Initialize the lr scheduler.
|
||||
|
||||
This is a strategic function which can be registered by other hook's function.
|
||||
"""
|
||||
pass
|
||||
|
||||
@Hook.overload_func(name='LrSchedulerHook.step')
|
||||
def step(self, trainer):
|
||||
"""Do lr scheduler's step.
|
||||
|
||||
This is a strategic function which can be registered by other hook's function.
|
||||
"""
|
||||
if self.warmup_lr_scheduler is not None:
|
||||
self.warmup_lr_scheduler.step()
|
||||
else:
|
||||
trainer.lr_scheduler.step()
|
||||
|
||||
def get_current_lr(self, trainer):
|
||||
import torch
|
||||
|
||||
@@ -46,13 +66,10 @@ class LrSchedulerHook(Hook):
|
||||
'lr is not applicable because optimizer does not exist.')
|
||||
return lr
|
||||
|
||||
def before_train_iter(self, trainer):
|
||||
def after_train_iter(self, trainer):
|
||||
if not self.by_epoch and trainer.iter >= getattr(
|
||||
trainer, 'cumulative_iters', 1):
|
||||
if self.warmup_lr_scheduler is not None:
|
||||
self.warmup_lr_scheduler.step()
|
||||
else:
|
||||
trainer.lr_scheduler.step()
|
||||
trainer, 'cumulative_iters', 1) - 1:
|
||||
self.step(trainer)
|
||||
trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer)
|
||||
|
||||
def before_train_epoch(self, trainer):
|
||||
@@ -60,10 +77,7 @@ class LrSchedulerHook(Hook):
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
if self.by_epoch:
|
||||
if self.warmup_lr_scheduler is not None:
|
||||
self.warmup_lr_scheduler.step()
|
||||
else:
|
||||
trainer.lr_scheduler.step()
|
||||
self.step(trainer)
|
||||
|
||||
def _get_log_lr(self, trainer):
|
||||
cur_lr = self.get_current_lr(trainer)
|
||||
@@ -81,29 +95,28 @@ class LrSchedulerHook(Hook):
|
||||
|
||||
|
||||
@HOOKS.register_module(module_name=Hooks.PlateauLrSchedulerHook)
|
||||
class PlateauLrSchedulerHook(LrSchedulerHook):
|
||||
class PlateauLrSchedulerHook(Hook):
|
||||
"""Lr scheduler hook for `ReduceLROnPlateau`.
|
||||
|
||||
Args:
|
||||
metric_key (str): Metric key returned from `trainer.metric_values`,
|
||||
get the value of metric key and pass it to `ReduceLROnPlateau.step`.
|
||||
by_epoch (bool): Whether lr changes by epoch
|
||||
warmup (dict): warm up config
|
||||
"""
|
||||
PRIORITY = Priority.LOW # should be after EvaluationHook
|
||||
|
||||
def __init__(self, metric_key, by_epoch=True, warmup=None) -> None:
|
||||
super().__init__(by_epoch=by_epoch, warmup=warmup)
|
||||
def __init__(self, metric_key, **kwargs):
|
||||
self.metric_key = metric_key
|
||||
|
||||
def register_strategy(self):
|
||||
Hook.overload(name='LrSchedulerHook.step', function=self.step)
|
||||
|
||||
def before_run(self, trainer):
|
||||
super().before_run(trainer)
|
||||
if not hasattr(trainer, 'logger'):
|
||||
self.logger = get_logger()
|
||||
else:
|
||||
self.logger = trainer.logger
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
def step(self, trainer):
|
||||
# adapt to evaluation intervel is greater than 1
|
||||
if trainer.metric_values is None:
|
||||
if is_master():
|
||||
@@ -113,10 +126,10 @@ class PlateauLrSchedulerHook(LrSchedulerHook):
|
||||
return
|
||||
|
||||
metrics = trainer.metric_values[self.metric_key]
|
||||
|
||||
if self.by_epoch:
|
||||
if self.warmup_lr_scheduler is not None:
|
||||
self.warmup_lr_scheduler.step(metrics=metrics)
|
||||
lr_scheduler_hook = trainer.get_hook(LrSchedulerHook)[0]
|
||||
if lr_scheduler_hook.by_epoch:
|
||||
if lr_scheduler_hook.warmup_lr_scheduler is not None:
|
||||
lr_scheduler_hook.warmup_lr_scheduler.step(metrics=metrics)
|
||||
else:
|
||||
trainer.lr_scheduler.step(metrics=metrics)
|
||||
|
||||
|
||||
126
modelscope/trainers/hooks/megatron_hook.py
Normal file
126
modelscope/trainers/hooks/megatron_hook.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from megatron_util import mpu
|
||||
|
||||
from modelscope.metainfo import Hooks
|
||||
from modelscope.trainers.hooks.builder import HOOKS
|
||||
from modelscope.trainers.hooks.hook import Hook
|
||||
from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint
|
||||
from .checkpoint_hook import CheckpointHook, LoadCheckpointHook
|
||||
|
||||
|
||||
@HOOKS.register_module(module_name=Hooks.MegatronHook)
|
||||
class MegatronHook(Hook):
|
||||
|
||||
_BIN_FILE_DIR = 'model'
|
||||
|
||||
def register_strategy(self):
|
||||
Hook.overload(
|
||||
name='CheckpointHook.should_save_on_rank',
|
||||
function=self.should_save_on_rank)
|
||||
Hook.overload(
|
||||
name='CheckpointHook.save_checkpoints',
|
||||
function=self.save_checkpoints)
|
||||
Hook.overload(
|
||||
name='LoadCheckpointHook.load_checkpoints',
|
||||
function=self.load_checkpoints)
|
||||
Hook.overload(
|
||||
name='CheckpointHook.remove_checkpoints',
|
||||
function=self.remove_checkpoints)
|
||||
Hook.overload(
|
||||
name='CheckpointHook.prepare_output', function=self.prepare_output)
|
||||
|
||||
def should_save_on_rank(self, trainer):
|
||||
# TODO
|
||||
return (not torch.distributed.is_initialized()
|
||||
) or mpu.get_data_parallel_rank() == 0
|
||||
|
||||
def rank_name(self):
|
||||
# TODO
|
||||
try:
|
||||
tp_world_size = mpu.get_tensor_model_parallel_world_size()
|
||||
if tp_world_size == 1:
|
||||
return ''
|
||||
mp_rank = mpu.get_tensor_model_parallel_rank()
|
||||
return '_mp_rank_{:02d}'.format(mp_rank)
|
||||
except (ImportError, AssertionError):
|
||||
return ''
|
||||
|
||||
def get_bin_file(self):
|
||||
mp_rank = mpu.get_tensor_model_parallel_rank()
|
||||
rank = '{:02d}'.format(mp_rank)
|
||||
return f'mp_rank_{rank}_model_states.pt'
|
||||
|
||||
def save_checkpoints(self,
|
||||
trainer,
|
||||
checkpoint_path_prefix,
|
||||
output_sub_dir,
|
||||
meta=None):
|
||||
model = trainer.unwrap_module(trainer.model)
|
||||
_train_state_file = checkpoint_path_prefix + self.rank_name(
|
||||
) + CheckpointHook.TRAINER_STATE_SUFFIX
|
||||
# Save pth file without model state_dict
|
||||
save_checkpoint(
|
||||
model,
|
||||
_train_state_file,
|
||||
trainer.optimizer,
|
||||
trainer.lr_scheduler,
|
||||
meta=meta,
|
||||
with_model=False)
|
||||
|
||||
save_dir = os.path.dirname(checkpoint_path_prefix)
|
||||
prefix = os.path.basename(checkpoint_path_prefix)
|
||||
bin_file = self.get_bin_file()
|
||||
prefix_bin_file = os.path.join(save_dir, prefix + '_' + bin_file)
|
||||
save_checkpoint(model, prefix_bin_file, with_meta=False)
|
||||
|
||||
src_file = prefix_bin_file
|
||||
dest_file = os.path.join(save_dir, output_sub_dir, self._BIN_FILE_DIR,
|
||||
bin_file)
|
||||
if os.path.isfile(dest_file):
|
||||
os.unlink(dest_file)
|
||||
|
||||
os.link(src_file, dest_file)
|
||||
|
||||
def remove_checkpoints(self, trainer, checkpoint_path_prefix):
|
||||
_train_state_file = checkpoint_path_prefix + self.rank_name(
|
||||
) + CheckpointHook.TRAINER_STATE_SUFFIX
|
||||
if os.path.isfile(_train_state_file):
|
||||
os.remove(_train_state_file)
|
||||
|
||||
save_dir = os.path.dirname(checkpoint_path_prefix)
|
||||
prefix = os.path.basename(checkpoint_path_prefix)
|
||||
bin_file = self.get_bin_file()
|
||||
absolute_file = os.path.join(save_dir, prefix + '_' + bin_file)
|
||||
if os.path.isfile(absolute_file):
|
||||
os.remove(absolute_file)
|
||||
|
||||
def load_checkpoints(self, checkpoint_path_prefix, trainer, load_all_state,
|
||||
strict):
|
||||
model = trainer.unwrap_module(trainer.model)
|
||||
if os.path.isdir(checkpoint_path_prefix):
|
||||
save_dir = checkpoint_path_prefix
|
||||
bin_file = self.get_bin_file()
|
||||
model_file = os.path.join(save_dir, bin_file)
|
||||
load_checkpoint(model_file, model, None, None)
|
||||
else:
|
||||
_train_state_file = checkpoint_path_prefix + self.rank_name(
|
||||
) + CheckpointHook.TRAINER_STATE_SUFFIX
|
||||
meta = LoadCheckpointHook.load_trainer_state(
|
||||
trainer, _train_state_file, load_all_state)
|
||||
|
||||
save_dir = os.path.dirname(checkpoint_path_prefix)
|
||||
prefix = os.path.basename(checkpoint_path_prefix)
|
||||
bin_file = self.get_bin_file()
|
||||
|
||||
model_file = os.path.join(save_dir, prefix + '_' + bin_file)
|
||||
load_checkpoint(model_file, model, None, None)
|
||||
return meta
|
||||
|
||||
def prepare_output(self, trainer, output_dir):
|
||||
config = trainer.cfg.to_dict()
|
||||
CheckpointHook.copy_files_and_dump_config(trainer, output_dir, config,
|
||||
self._BIN_FILE_DIR)
|
||||
os.makedirs(
|
||||
os.path.join(output_dir, self._BIN_FILE_DIR), exist_ok=True)
|
||||
@@ -7,16 +7,13 @@ if TYPE_CHECKING:
|
||||
from .apex_optimizer_hook import ApexAMPOptimizerHook
|
||||
from .base import OptimizerHook, NoneOptimizerHook
|
||||
from .torch_optimizer_hook import TorchAMPOptimizerHook
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'apex_optimizer_hook': ['ApexAMPOptimizerHook'],
|
||||
'base': ['OptimizerHook', 'NoneOptimizerHook'],
|
||||
'torch_optimizer_hook': ['TorchAMPOptimizerHook']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
|
||||
@@ -1,39 +1,31 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from modelscope.metainfo import Hooks
|
||||
from modelscope.trainers.hooks import Hook
|
||||
from modelscope.trainers.hooks.builder import HOOKS
|
||||
from .base import OptimizerHook
|
||||
|
||||
|
||||
@HOOKS.register_module(module_name=Hooks.ApexAMPOptimizerHook)
|
||||
class ApexAMPOptimizerHook(OptimizerHook):
|
||||
class ApexAMPOptimizerHook(Hook):
|
||||
"""
|
||||
Fp16 optimizer, if torch version is less than 1.6.0,
|
||||
you must install apex (https://www.github.com/nvidia/apex) else use torch.cuda.amp by default
|
||||
|
||||
Args:
|
||||
cumulative_iters (int): interval of gradients accumulation. Default: 1
|
||||
grad_clip (dict): Default None. Containing keys:
|
||||
max_norm (float or int): max norm of the gradients
|
||||
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
|
||||
More details please refer to `torch.nn.utils.clip_grad.clip_grad_norm_`
|
||||
loss_keys (str | list): keys list of loss
|
||||
opt_level (str): "O0" and "O3" are not true mixed precision,
|
||||
but they are useful for establishing accuracy and speed baselines, respectively.
|
||||
"O1" and "O2" are different implementations of mixed precision.
|
||||
Try both, and see what gives the best speedup and accuracy for your model.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
cumulative_iters=1,
|
||||
grad_clip=None,
|
||||
loss_keys='loss',
|
||||
opt_level='O1'):
|
||||
PRIORITY = OptimizerHook.PRIORITY
|
||||
|
||||
super(ApexAMPOptimizerHook, self).__init__(
|
||||
grad_clip=grad_clip, loss_keys=loss_keys)
|
||||
self.cumulative_iters = cumulative_iters
|
||||
def __init__(self, opt_level='O1', **kwargs):
|
||||
self.opt_level = opt_level
|
||||
|
||||
try:
|
||||
@@ -43,35 +35,43 @@ class ApexAMPOptimizerHook(OptimizerHook):
|
||||
'apex not installed, please install apex from https://www.github.com/nvidia/apex.'
|
||||
)
|
||||
|
||||
def before_run(self, trainer):
|
||||
def register_strategy(self):
|
||||
Hook.overload(
|
||||
name='OptimizerHook.initialize_optimizer',
|
||||
function=self.initialize_optimizer)
|
||||
Hook.overload(name='OptimizerHook.backward', function=self.backward)
|
||||
|
||||
def initialize_optimizer(self, trainer):
|
||||
from apex import amp
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('1.9.0'):
|
||||
trainer.logger.warning(
|
||||
'ApexAMPOptimizerHook is only tested on torch version 1.8.x,'
|
||||
'if it works abnormally please consider downgrading your torch version to 1.8.x.'
|
||||
)
|
||||
|
||||
logging.info('open fp16')
|
||||
# TODO: fix it should initialze amp with model not wrapper by DDP or DP
|
||||
if hasattr(trainer.model, 'module'):
|
||||
trainer.model, trainer.optimizer = amp.initialize(
|
||||
trainer.model.module,
|
||||
trainer.optimizer,
|
||||
opt_level=self.opt_level)
|
||||
else:
|
||||
trainer.model, trainer.optimizer = amp.initialize(
|
||||
trainer.model, trainer.optimizer, opt_level=self.opt_level)
|
||||
model = trainer.unwrap_module(trainer.model)
|
||||
trainer.model, trainer.optimizer = amp.initialize(
|
||||
model, trainer.optimizer, opt_level=self.opt_level)
|
||||
|
||||
trainer.optimizer.zero_grad()
|
||||
|
||||
def after_train_iter(self, trainer):
|
||||
for k in self.loss_keys:
|
||||
trainer.train_outputs[k] /= self.cumulative_iters
|
||||
def backward(self, trainer, loss_keys, cumulative_iters, grad_clip):
|
||||
for k in loss_keys:
|
||||
trainer.train_outputs[k] /= cumulative_iters
|
||||
|
||||
from apex import amp
|
||||
for k in self.loss_keys:
|
||||
for k in loss_keys:
|
||||
with amp.scale_loss(trainer.train_outputs[k],
|
||||
trainer.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
|
||||
if self.every_n_iters(trainer, self.cumulative_iters):
|
||||
if self.grad_clip is not None:
|
||||
self.clip_grads(trainer.model.parameters(), **self.grad_clip)
|
||||
if self.every_n_iters(trainer, cumulative_iters):
|
||||
if grad_clip is not None:
|
||||
OptimizerHook.clip_grads(trainer.model.parameters(),
|
||||
**grad_clip)
|
||||
|
||||
trainer.optimizer.step()
|
||||
trainer.optimizer.zero_grad()
|
||||
|
||||
@@ -28,7 +28,8 @@ class OptimizerHook(Hook):
|
||||
def __init__(self,
|
||||
cumulative_iters=1,
|
||||
grad_clip=None,
|
||||
loss_keys=OutputKeys.LOSS) -> None:
|
||||
loss_keys=OutputKeys.LOSS,
|
||||
**kwargs) -> None:
|
||||
if isinstance(loss_keys, str):
|
||||
loss_keys = [loss_keys]
|
||||
assert isinstance(loss_keys, (tuple, list))
|
||||
@@ -36,28 +37,52 @@ class OptimizerHook(Hook):
|
||||
self.cumulative_iters = cumulative_iters
|
||||
self.grad_clip = grad_clip
|
||||
|
||||
def clip_grads(self, params, **clip_args):
|
||||
@staticmethod
|
||||
def clip_grads(params, **clip_args):
|
||||
params = list(
|
||||
filter(lambda p: p.requires_grad and p.grad is not None, params))
|
||||
if len(params) > 0:
|
||||
return clip_grad.clip_grad_norm_(params, **clip_args)
|
||||
|
||||
def before_run(self, trainer):
|
||||
@Hook.overload_func(name='OptimizerHook.initialize_optimizer')
|
||||
def initialize_optimizer(self, trainer):
|
||||
"""Initialize the optimizer.
|
||||
|
||||
This is a strategic function which can be registered by other hook's function.
|
||||
"""
|
||||
trainer.optimizer.zero_grad()
|
||||
|
||||
def before_run(self, trainer):
|
||||
self.initialize_optimizer(trainer)
|
||||
trainer.cumulative_iters = self.cumulative_iters
|
||||
|
||||
def after_train_iter(self, trainer):
|
||||
for k in self.loss_keys:
|
||||
trainer.train_outputs[k] /= self.cumulative_iters
|
||||
@Hook.overload_func(name='OptimizerHook.backward')
|
||||
def backward(self, trainer, loss_keys, cumulative_iters, grad_clip):
|
||||
"""Do module backward, optimizer's step and zero_grad and clip the grads.
|
||||
|
||||
This is a strategic function which can be registered by other hook's function.
|
||||
|
||||
Args:
|
||||
trainer(`EpochBasedTrainer`): The trainer instance.
|
||||
loss_keys(`list`): The list of loss keys.
|
||||
cumulative_iters(`int`): The cumulative iters for gradients.
|
||||
grad_clip(`dict`): The grad clipping options.
|
||||
"""
|
||||
for k in loss_keys:
|
||||
trainer.train_outputs[k] /= cumulative_iters
|
||||
trainer.train_outputs[k].backward()
|
||||
|
||||
if self.every_n_iters(trainer, self.cumulative_iters):
|
||||
if self.grad_clip is not None:
|
||||
self.clip_grads(trainer.model.parameters(), **self.grad_clip)
|
||||
if self.every_n_iters(trainer, cumulative_iters):
|
||||
if grad_clip is not None:
|
||||
self.clip_grads(trainer.model.parameters(), **grad_clip)
|
||||
|
||||
trainer.optimizer.step()
|
||||
trainer.optimizer.zero_grad()
|
||||
|
||||
def after_train_iter(self, trainer):
|
||||
self.backward(trainer, self.loss_keys, self.cumulative_iters,
|
||||
self.grad_clip)
|
||||
|
||||
|
||||
@HOOKS.register_module(module_name=Hooks.NoneOptimizerHook)
|
||||
class NoneOptimizerHook(OptimizerHook):
|
||||
|
||||
@@ -2,12 +2,13 @@
|
||||
import logging
|
||||
|
||||
from modelscope.metainfo import Hooks
|
||||
from modelscope.trainers.hooks import Hook
|
||||
from modelscope.trainers.hooks.builder import HOOKS
|
||||
from .base import OptimizerHook
|
||||
|
||||
|
||||
@HOOKS.register_module(module_name=Hooks.TorchAMPOptimizerHook)
|
||||
class TorchAMPOptimizerHook(OptimizerHook):
|
||||
class TorchAMPOptimizerHook(Hook):
|
||||
"""
|
||||
Fp16 optimizer, if torch version is less than 1.6.0,
|
||||
you must install apex (https://www.github.com/nvidia/apex) else use torch.cuda.amp by default
|
||||
@@ -26,15 +27,9 @@ class TorchAMPOptimizerHook(OptimizerHook):
|
||||
please refer to: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler for the parameters.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
cumulative_iters=1,
|
||||
grad_clip=None,
|
||||
loss_keys='loss',
|
||||
loss_scale={}):
|
||||
PRIORITY = OptimizerHook.PRIORITY
|
||||
|
||||
super(TorchAMPOptimizerHook, self).__init__(
|
||||
grad_clip=grad_clip, loss_keys=loss_keys)
|
||||
self.cumulative_iters = cumulative_iters
|
||||
def __init__(self, loss_scale={}, **kwargs):
|
||||
self._scale_update_param = None
|
||||
|
||||
from torch.cuda import amp
|
||||
@@ -49,34 +44,36 @@ class TorchAMPOptimizerHook(OptimizerHook):
|
||||
'`loss_scale` type must be in [float, dict], but got {loss_scale}'
|
||||
)
|
||||
|
||||
def before_run(self, trainer):
|
||||
def register_strategy(self):
|
||||
Hook.overload(
|
||||
name='OptimizerHook.initialize_optimizer',
|
||||
function=self.initialize_optimizer)
|
||||
Hook.overload(name='OptimizerHook.backward', function=self.backward)
|
||||
|
||||
def initialize_optimizer(self, trainer):
|
||||
logging.info('open fp16')
|
||||
trainer.optimizer.zero_grad()
|
||||
|
||||
if hasattr(trainer.model, 'module'):
|
||||
self._ori_model_forward = trainer.model.module.forward
|
||||
self._model = trainer.model.module
|
||||
else:
|
||||
self._ori_model_forward = trainer.model.forward
|
||||
self._model = trainer.model
|
||||
|
||||
self.ori_model_forward = trainer.model.forward
|
||||
model = trainer.unwrap_module(trainer.model)
|
||||
self._ori_model_forward = model.forward
|
||||
self._model = model
|
||||
|
||||
def before_train_iter(self, trainer):
|
||||
from torch.cuda import amp
|
||||
setattr(self._model, 'forward', amp.autocast()(self._model.forward))
|
||||
|
||||
def after_train_iter(self, trainer):
|
||||
for k in self.loss_keys:
|
||||
trainer.train_outputs[k] /= self.cumulative_iters
|
||||
def backward(self, trainer, loss_keys, cumulative_iters, grad_clip):
|
||||
for k in loss_keys:
|
||||
trainer.train_outputs[k] /= cumulative_iters
|
||||
|
||||
for k in self.loss_keys:
|
||||
for k in loss_keys:
|
||||
self.scaler.scale(trainer.train_outputs[k]).backward()
|
||||
|
||||
if self.every_n_iters(trainer, self.cumulative_iters):
|
||||
if self.every_n_iters(trainer, cumulative_iters):
|
||||
self.scaler.unscale_(trainer.optimizer)
|
||||
if self.grad_clip is not None:
|
||||
self.clip_grads(trainer.model.parameters(), **self.grad_clip)
|
||||
if grad_clip is not None:
|
||||
OptimizerHook.clip_grads(trainer.model.parameters(),
|
||||
**grad_clip)
|
||||
|
||||
self.scaler.step(trainer.optimizer)
|
||||
self.scaler.update(self._scale_update_param)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import os
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from deepspeed import DeepSpeedEngine
|
||||
from megatron_util import mpu
|
||||
from torch import nn
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models.base import Model, TorchModel
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.nlp.plug import DistributedPlug
|
||||
from modelscope.models.nlp.plug.backbone import BertLayerNorm
|
||||
from modelscope.models.nlp.plug.generator import TextGenerator
|
||||
@@ -28,6 +29,7 @@ class PlugTrainer(NlpEpochBasedTrainer):
|
||||
master_ip=master_ip,
|
||||
master_port=master_port,
|
||||
**self.cfg.model)
|
||||
self.unwrap_module(model.model).model_dir = self.model_dir
|
||||
return model.model
|
||||
|
||||
def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
|
||||
@@ -160,11 +162,16 @@ class PlugTrainer(NlpEpochBasedTrainer):
|
||||
|
||||
def evaluation_step(self, data):
|
||||
# wapper 1: DeepspeedEngine, wapper 2: DDP
|
||||
model = self.model.module
|
||||
# model = self.model.module
|
||||
if isinstance(self.model, DeepSpeedEngine):
|
||||
model = self.model.module
|
||||
else:
|
||||
model = self.model
|
||||
|
||||
model.eval()
|
||||
|
||||
# model: fp16 wapper; model.module : distributedPlug
|
||||
vocab_size = model.module.config.original_vocab_size
|
||||
vocab_size = self.unwrap_module(self.model).config.original_vocab_size
|
||||
batch_size = data['input_ids'].shape[0]
|
||||
beam_generator = TextGenerator(model,
|
||||
self.eval_preprocessor.nlp_tokenizer,
|
||||
|
||||
@@ -151,9 +151,9 @@ class VecoTrainer(NlpEpochBasedTrainer):
|
||||
|
||||
"""
|
||||
from modelscope.msdatasets.task_datasets import VecoDataset
|
||||
if checkpoint_path is not None and os.path.isfile(checkpoint_path):
|
||||
from modelscope.trainers.hooks import CheckpointHook
|
||||
CheckpointHook.load_checkpoint(checkpoint_path, self)
|
||||
if checkpoint_path is not None:
|
||||
from modelscope.trainers.hooks import LoadCheckpointHook
|
||||
LoadCheckpointHook.load_checkpoint(checkpoint_path, self)
|
||||
self.model.eval()
|
||||
self._mode = ModeKeys.EVAL
|
||||
metric_values = {}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import inspect
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from distutils.version import LooseVersion
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
@@ -40,9 +40,9 @@ from modelscope.utils.file_utils import func_receive_dict_inputs
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.megatron_utils import is_megatron_initialized
|
||||
from modelscope.utils.registry import build_from_cfg
|
||||
from modelscope.utils.torch_utils import (broadcast, get_dist_info,
|
||||
get_local_rank, init_dist, is_dist,
|
||||
is_master, set_random_seed)
|
||||
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
|
||||
init_dist, is_dist, is_master,
|
||||
set_random_seed)
|
||||
from .base import BaseTrainer
|
||||
from .builder import TRAINERS
|
||||
from .default_config import merge_cfg, merge_hooks
|
||||
@@ -140,15 +140,6 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
# add default config
|
||||
merge_cfg(self.cfg)
|
||||
self.cfg = self.rebuild_config(self.cfg)
|
||||
self.logger = get_logger(log_level=self.cfg.get('log_level', 'INFO'))
|
||||
self.logger.info(
|
||||
'==========================Training Config Start=========================='
|
||||
)
|
||||
self.logger.info(
|
||||
json.dumps(self.cfg._cfg_dict, indent=4, cls=JSONIteratorEncoder))
|
||||
self.logger.info(
|
||||
'===========================Training Config End==========================='
|
||||
)
|
||||
if 'cfg_options' in kwargs:
|
||||
self.cfg.merge_from_dict(kwargs['cfg_options'])
|
||||
|
||||
@@ -177,6 +168,17 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
self.logger = get_logger(
|
||||
log_file=log_file, log_level=self.cfg.get('log_level', 'INFO'))
|
||||
|
||||
if is_master():
|
||||
self.logger.info(
|
||||
'==========================Training Config Start=========================='
|
||||
)
|
||||
self.logger.info(
|
||||
json.dumps(
|
||||
self.cfg._cfg_dict, indent=4, cls=JSONIteratorEncoder))
|
||||
self.logger.info(
|
||||
'===========================Training Config End==========================='
|
||||
)
|
||||
|
||||
self.train_dataset = self.to_task_dataset(
|
||||
train_dataset,
|
||||
mode=ModeKeys.TRAIN,
|
||||
@@ -208,6 +210,8 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
# model placement
|
||||
self.place_model()
|
||||
|
||||
Hook.clear_strategies()
|
||||
|
||||
def place_model(self):
|
||||
"""Place model to device, or to DDP
|
||||
"""
|
||||
@@ -504,22 +508,20 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
metrics = [metrics]
|
||||
return metrics
|
||||
|
||||
def set_checkpoint_file_to_hook(self, checkpoint_path, load_all_state):
|
||||
def set_checkpoint_file_to_hook(self, checkpoint_path, load_all_state,
|
||||
strict):
|
||||
if checkpoint_path is not None:
|
||||
if os.path.isfile(checkpoint_path):
|
||||
from modelscope.trainers.hooks import LoadCheckpointHook
|
||||
load_ckpt_hooks = list(
|
||||
filter(lambda hook: isinstance(hook, LoadCheckpointHook),
|
||||
self.hooks))
|
||||
if len(load_ckpt_hooks) == 0:
|
||||
load_ckpt_hook = LoadCheckpointHook()
|
||||
self.hooks.append(load_ckpt_hook)
|
||||
load_ckpt_hooks.append(load_ckpt_hook)
|
||||
load_ckpt_hooks[0].checkpoint_file = checkpoint_path
|
||||
load_ckpt_hooks[0].load_all_state = load_all_state
|
||||
else:
|
||||
self.logger.error(
|
||||
f'No {checkpoint_path} found in local file system.')
|
||||
from modelscope.trainers.hooks import LoadCheckpointHook
|
||||
load_ckpt_hooks = list(
|
||||
filter(lambda hook: isinstance(hook, LoadCheckpointHook),
|
||||
self.hooks))
|
||||
if len(load_ckpt_hooks) == 0:
|
||||
load_ckpt_hook = LoadCheckpointHook()
|
||||
self.register_hook(load_ckpt_hook)
|
||||
load_ckpt_hooks.append(load_ckpt_hook)
|
||||
load_ckpt_hooks[0].checkpoint_file = checkpoint_path
|
||||
load_ckpt_hooks[0].load_all_state = load_all_state
|
||||
load_ckpt_hooks[0].strict = strict
|
||||
|
||||
def train(self,
|
||||
checkpoint_path=None,
|
||||
@@ -534,6 +536,8 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
load_all_state(`bool`: `optional`): Load all state out of the `checkpoint_path` file, including the
|
||||
state dict of model, optimizer, lr_scheduler, the random state and epoch/iter number. If False, only
|
||||
the model's state dict will be read, and model will be trained again.
|
||||
kwargs:
|
||||
strict(`boolean`): If strict, any unmatched keys will cause an error.
|
||||
"""
|
||||
|
||||
self._mode = ModeKeys.TRAIN
|
||||
@@ -542,7 +546,10 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
self.register_optimizers_hook()
|
||||
hooks = merge_hooks(self.cfg)
|
||||
self.register_hook_from_cfg(hooks)
|
||||
self.set_checkpoint_file_to_hook(checkpoint_path, load_all_state)
|
||||
if is_master():
|
||||
self.logger.info(self.get_hook_info())
|
||||
self.set_checkpoint_file_to_hook(checkpoint_path, load_all_state,
|
||||
kwargs.get('strict', False))
|
||||
self.model.train()
|
||||
|
||||
self.train_loop(self.train_dataloader)
|
||||
@@ -550,7 +557,8 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
def predict(self,
|
||||
predict_datasets: Union[Dataset, List[Dataset]],
|
||||
saving_fn,
|
||||
checkpoint_path=None):
|
||||
checkpoint_path=None,
|
||||
strict=False):
|
||||
"""Start prediction.
|
||||
|
||||
Args:
|
||||
@@ -575,11 +583,18 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
checkpoint_path(`str`, `optional`): The previous saving checkpoint to read,
|
||||
usually it's a `some-file-name.pth` file or a pure PyTorch `some-file.bin` file
|
||||
generated by this trainer.
|
||||
"""
|
||||
|
||||
if checkpoint_path is not None and os.path.isfile(checkpoint_path):
|
||||
strict(`boolean`): If strict, any unmatched keys will cause an error.
|
||||
"""
|
||||
if not self._hooks:
|
||||
hooks = merge_hooks(self.cfg)
|
||||
self.register_hook_from_cfg(hooks)
|
||||
if is_master():
|
||||
self.logger.info(self.get_hook_info())
|
||||
if checkpoint_path is not None:
|
||||
from modelscope.trainers.hooks import LoadCheckpointHook
|
||||
LoadCheckpointHook.load_checkpoint(checkpoint_path, self)
|
||||
LoadCheckpointHook.load_checkpoint(
|
||||
checkpoint_path, self, strict=strict)
|
||||
self.model.eval()
|
||||
self._mode = ModeKeys.EVAL
|
||||
predict_dataloader = self.get_predict_data_loader(predict_datasets)
|
||||
@@ -610,10 +625,18 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
>>> with open(self.filename, 'a') as f:
|
||||
>>> for id, pred in zip(ids, predictions):
|
||||
>>> f.writelines(f'{id}, {pred}')
|
||||
kwargs:
|
||||
strict(`boolean`): If strict, any unmatched keys will cause an error.
|
||||
"""
|
||||
if checkpoint_path is not None and os.path.isfile(checkpoint_path):
|
||||
if not self._hooks:
|
||||
hooks = merge_hooks(self.cfg)
|
||||
self.register_hook_from_cfg(hooks)
|
||||
if is_master():
|
||||
self.logger.info(self.get_hook_info())
|
||||
if checkpoint_path is not None:
|
||||
from modelscope.trainers.hooks import LoadCheckpointHook
|
||||
LoadCheckpointHook.load_checkpoint(checkpoint_path, self)
|
||||
LoadCheckpointHook.load_checkpoint(
|
||||
checkpoint_path, self, strict=kwargs.get('strict', False))
|
||||
self.model.eval()
|
||||
self._mode = ModeKeys.EVAL
|
||||
self.eval_dataloader = self.get_eval_data_loader()
|
||||
@@ -650,9 +673,10 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
|
||||
# config format to reserve custom ddp
|
||||
if self.cfg.get('parallel', None) is not None:
|
||||
self.cfg.parallel.update(
|
||||
dp_cfg = deepcopy(self.cfg['parallel'])
|
||||
dp_cfg.update(
|
||||
dict(module=model, device_ids=[torch.cuda.current_device()]))
|
||||
return build_parallel(self.cfg.parallel)
|
||||
return build_parallel(dp_cfg)
|
||||
|
||||
dp_cfg = dict(
|
||||
type='DistributedDataParallel',
|
||||
@@ -669,6 +693,18 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
|
||||
return build_parallel(dp_cfg)
|
||||
|
||||
def unwrap_module(self, model) -> Union[nn.Module, TorchModel]:
|
||||
"""Unwrap the model until it's a naked nn.Module.
|
||||
|
||||
Args:
|
||||
model: An module.
|
||||
"""
|
||||
if hasattr(model, 'module'):
|
||||
return self.unwrap_module(model.module)
|
||||
else:
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
return model
|
||||
|
||||
def train_step(self, model, inputs):
|
||||
""" Perform a training step on a batch of inputs.
|
||||
|
||||
@@ -691,11 +727,8 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
self._mode = ModeKeys.TRAIN
|
||||
# call model forward but not __call__ to skip postprocess
|
||||
|
||||
if is_parallel(model):
|
||||
receive_dict_inputs = func_receive_dict_inputs(
|
||||
model.module.forward)
|
||||
else:
|
||||
receive_dict_inputs = func_receive_dict_inputs(model.forward)
|
||||
receive_dict_inputs = func_receive_dict_inputs(
|
||||
self.unwrap_module(self.model).forward)
|
||||
|
||||
if isinstance(inputs, Mapping) and not receive_dict_inputs:
|
||||
train_outputs = model.forward(**inputs)
|
||||
@@ -730,25 +763,9 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
self.train_outputs = train_outputs
|
||||
|
||||
def prediction_step(self, model, inputs):
|
||||
""" Perform forward step by `model` using `inputs`.
|
||||
|
||||
Args:
|
||||
model (`TorchModel`): The model to evaluate.
|
||||
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
|
||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||
argument `labels`. Check your model's documentation for all accepted arguments.
|
||||
prediction_loss_only (`bool`):
|
||||
Whether or not to return the loss only.
|
||||
ignore_keys (`Lst[str]`, *optional*):
|
||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||
gathering predictions.
|
||||
|
||||
Return:
|
||||
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
|
||||
logits and labels (each being optional).
|
||||
"""Deprecated method
|
||||
"""
|
||||
self.logger.warn('This prediction_step method is deprecated.')
|
||||
raise NotImplementedError
|
||||
|
||||
def get_train_dataloader(self):
|
||||
@@ -842,7 +859,9 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
|
||||
try:
|
||||
return build_optimizer(
|
||||
self.model, cfg=cfg, default_args=default_args)
|
||||
self.unwrap_module(self.model),
|
||||
cfg=cfg,
|
||||
default_args=default_args)
|
||||
except KeyError as e:
|
||||
self.logger.error(
|
||||
f'Build optimizer error, the optimizer {cfg} is a torch native component, '
|
||||
@@ -900,12 +919,12 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
_, lr_scheduler, optim_options, lr_options = self.create_optimizer_and_scheduler(
|
||||
)
|
||||
|
||||
optim_hook = self.cfg.train.get('optimizer_hook', None)
|
||||
lr_hook = self.cfg.train.get('lr_scheduler_hook', None)
|
||||
optim_hook = self.cfg.train.get('optimizer_hook', {})
|
||||
lr_hook = self.cfg.train.get('lr_scheduler_hook', {})
|
||||
|
||||
# adapt to `ReduceLROnPlateau`
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
if isinstance(lr_scheduler, ReduceLROnPlateau) and lr_hook is None:
|
||||
if isinstance(lr_scheduler, ReduceLROnPlateau) and not lr_hook:
|
||||
plateau_cfg = {
|
||||
'train': {
|
||||
'lr_scheduler_hook': {
|
||||
@@ -921,16 +940,54 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
'Must add `lr_scheduler_hook` to configuration for `ReduceLROnPlateau` lr scheduler as follows:'
|
||||
+ '\n' + plateau_cfg)
|
||||
|
||||
if lr_hook is None:
|
||||
lr_hook = dict(type='LrSchedulerHook', **lr_options)
|
||||
if optim_hook is None:
|
||||
if self.use_fp16:
|
||||
optim_hook = dict(
|
||||
type='TorchAMPOptimizerHook', **optim_options)
|
||||
else:
|
||||
optim_hook = dict(type='OptimizerHook', **optim_options)
|
||||
def _fit_to_old_keys():
|
||||
"""This function used to fit `optimizer_hook` key and `lr_scheduler_hook` key for easycv configs.
|
||||
|
||||
self.register_hook_from_cfg([lr_hook, optim_hook])
|
||||
The logic is:
|
||||
If the optimizer_hook is provided and it's not TorchAMPOptimizerHook or ApexAMPOptimizerHook,
|
||||
(which means the hook is a complete one for optimization, which does not need the OptimizerHook),
|
||||
The OptimizerHook will not be registered, or else the OptimizerHook will be registered.
|
||||
|
||||
Same logic to the LrSchedulerHook, the only difference is the condition of lr_scheduler_hook is
|
||||
PlateauLrSchedulerHook.
|
||||
|
||||
If TorchAMPOptimizerHook or ApexAMPOptimizerHook is provided, self.use_fp16 will be set to False
|
||||
in case of the duplication of registration.
|
||||
|
||||
"""
|
||||
if lr_hook:
|
||||
self.register_hook_from_cfg([lr_hook])
|
||||
|
||||
_lr_options = None
|
||||
if not lr_hook or lr_hook.get('type') == 'PlateauLrSchedulerHook':
|
||||
lr_hook.pop('type', None)
|
||||
_lr_options = {**lr_options, **lr_hook}
|
||||
|
||||
if optim_hook:
|
||||
self.register_hook_from_cfg([optim_hook])
|
||||
|
||||
_optim_options = None
|
||||
if optim_hook.get('type') in ('TorchAMPOptimizerHook',
|
||||
'ApexAMPOptimizerHook'):
|
||||
self.use_fp16 = False
|
||||
if not optim_hook or optim_hook.get('type') in (
|
||||
'TorchAMPOptimizerHook', 'ApexAMPOptimizerHook'):
|
||||
optim_hook.pop('type', None)
|
||||
_optim_options = {**optim_options, **optim_hook}
|
||||
|
||||
return _optim_options, _lr_options
|
||||
|
||||
optim_options, lr_options = _fit_to_old_keys()
|
||||
|
||||
if optim_options is not None:
|
||||
self.register_hook_from_cfg(
|
||||
[dict(type='OptimizerHook', **optim_options)])
|
||||
if lr_options is not None:
|
||||
self.register_hook_from_cfg(
|
||||
[dict(type='LrSchedulerHook', **lr_options)])
|
||||
if self.use_fp16:
|
||||
self.register_hook_from_cfg(
|
||||
[dict(type='TorchAMPOptimizerHook', **optim_options)])
|
||||
|
||||
def _build_dataloader_with_dataset(self,
|
||||
dataset: Dataset,
|
||||
@@ -979,10 +1036,7 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
batch_size = batch_size_per_gpu
|
||||
num_workers = workers_per_gpu
|
||||
|
||||
if dist and not isinstance(
|
||||
dataset,
|
||||
torch.utils.data.IterableDataset) and self.cfg.model.get(
|
||||
'model_parallel_size', 1) == 1:
|
||||
if dist and not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||
sampler = DistributedSampler(
|
||||
dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
|
||||
else:
|
||||
@@ -1054,20 +1108,16 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
"""
|
||||
model = self.model.module if self._dist else self.model
|
||||
model.eval()
|
||||
self.model.eval()
|
||||
|
||||
if is_parallel(model):
|
||||
receive_dict_inputs = func_receive_dict_inputs(
|
||||
model.module.forward)
|
||||
else:
|
||||
receive_dict_inputs = func_receive_dict_inputs(model.forward)
|
||||
receive_dict_inputs = func_receive_dict_inputs(
|
||||
self.unwrap_module(self.model).forward)
|
||||
|
||||
with torch.no_grad():
|
||||
if isinstance(data, Mapping) and not receive_dict_inputs:
|
||||
result = model.forward(**data)
|
||||
result = self.model.forward(**data)
|
||||
else:
|
||||
result = model.forward(data)
|
||||
result = self.model.forward(data)
|
||||
return result
|
||||
|
||||
def evaluation_loop(self, data_loader, metric_classes):
|
||||
@@ -1080,7 +1130,7 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
vis_closure = partial(
|
||||
self.visualization, dataset=self.eval_dataset, **vis_cfg)
|
||||
|
||||
if self._dist and self.cfg.model.get('model_parallel_size', 1) == 1:
|
||||
if self._dist:
|
||||
from modelscope.trainers.utils.inference import multi_gpu_test
|
||||
# list of batched result and data samples
|
||||
metric_values = multi_gpu_test(
|
||||
@@ -1151,7 +1201,7 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
if not inserted:
|
||||
self._hooks.insert(0, hook)
|
||||
|
||||
def register_hook_from_cfg(self, hook_cfg: List) -> None:
|
||||
def register_hook_from_cfg(self, hook_cfg: List) -> List:
|
||||
"""Register a hook from its cfg.
|
||||
|
||||
Args:
|
||||
@@ -1161,12 +1211,23 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
Note:
|
||||
The specific hook class to register should not use 'type' and
|
||||
'priority' arguments during initialization.
|
||||
|
||||
Returns:
|
||||
A list of instances of registered hooks.
|
||||
"""
|
||||
hook_cfg = hook_cfg.copy()
|
||||
assert isinstance(hook_cfg, list)
|
||||
hooks = []
|
||||
for cfg_i in hook_cfg:
|
||||
hook = build_from_cfg(cfg_i, HOOKS)
|
||||
if hasattr(hook, 'register_strategy'):
|
||||
hook.register_strategy()
|
||||
self.register_hook(hook)
|
||||
hooks.append(hook)
|
||||
return hooks
|
||||
|
||||
def get_hook(self, cls):
|
||||
return [h for h in self._hooks if h.__class__ == cls]
|
||||
|
||||
def invoke_hook(self, fn_name: str) -> None:
|
||||
"""Call all hooks.
|
||||
@@ -1183,9 +1244,9 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
|
||||
for hook in self.hooks:
|
||||
try:
|
||||
priority = Priority(hook.priority).name # type: ignore
|
||||
except ValueError:
|
||||
priority = hook.priority # type: ignore
|
||||
priority = Priority(hook.PRIORITY).name # type: ignore
|
||||
except Exception:
|
||||
priority = Priority.NORMAL # type: ignore
|
||||
classname = hook.__class__.__name__
|
||||
hook_info = f'({priority:<12}) {classname:<35}'
|
||||
for trigger_stage in hook.get_triggered_stages():
|
||||
@@ -1195,11 +1256,19 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
for stage in Hook.stages:
|
||||
hook_infos = stage_hook_map[stage]
|
||||
if len(hook_infos) > 0:
|
||||
info = f'{stage}:\n'
|
||||
info += '\n'.join(hook_infos)
|
||||
info = f'Stage: {stage}:\n '
|
||||
info += '\n '.join(hook_infos)
|
||||
info += '\n -------------------- '
|
||||
stage_hook_infos.append(info)
|
||||
return '\n'.join(stage_hook_infos)
|
||||
stage_hook_infos = '\n'.join(stage_hook_infos)
|
||||
|
||||
strategy_info = '\n --- Hook strategies info --- \n'
|
||||
for consumer, methods in Hook._strategies.items():
|
||||
strategy_info += f'Method: {consumer} ' \
|
||||
f'replaced by: ' \
|
||||
f'{[method.__self__.__class__.__name__ + "." + method.__name__ for method in methods]}\n'
|
||||
strategy_info += '\n --- Hook strategies info end --- \n'
|
||||
return stage_hook_infos + strategy_info
|
||||
|
||||
|
||||
def worker_init_fn(worker_id, num_workers, rank, seed):
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from shutil import copytree, ignore_patterns, rmtree
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
|
||||
@@ -16,7 +17,7 @@ from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
from modelscope import __version__
|
||||
from modelscope.fileio import File, LocalStorage
|
||||
from modelscope.utils.config import JSONIteratorEncoder
|
||||
from modelscope.utils.config import Config, JSONIteratorEncoder
|
||||
from modelscope.utils.constant import ConfigFields, ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.torch_utils import is_master
|
||||
@@ -48,7 +49,8 @@ def save_checkpoint(model: torch.nn.Module,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
lr_scheduler: Optional[_LRScheduler] = None,
|
||||
meta: Optional[dict] = None,
|
||||
with_meta: bool = True) -> None:
|
||||
with_meta: bool = True,
|
||||
with_model: bool = True) -> None:
|
||||
"""Save checkpoint to file.
|
||||
|
||||
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
|
||||
@@ -60,26 +62,30 @@ def save_checkpoint(model: torch.nn.Module,
|
||||
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
|
||||
lr_scheduler(:obj:`_LRScheduler`, optional): LRScheduler to be saved.
|
||||
meta (dict, optional): Metadata to be saved in checkpoint.
|
||||
with_meta (bool, optional):
|
||||
with_meta (bool, optional): Save meta info.
|
||||
with_model(bool, optional): Save model states.
|
||||
"""
|
||||
if meta is None:
|
||||
meta = {}
|
||||
elif not isinstance(meta, dict):
|
||||
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
|
||||
meta.update(modelscope=__version__, time=time.asctime())
|
||||
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
model = model.module
|
||||
|
||||
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
|
||||
# save class name to the meta
|
||||
meta.update(CLASSES=model.CLASSES)
|
||||
checkpoint = {}
|
||||
if not with_meta and not with_model:
|
||||
raise ValueError(
|
||||
'Save meta by "with_meta=True" or model by "with_model=True"')
|
||||
|
||||
if with_meta:
|
||||
checkpoint = {
|
||||
'meta': meta,
|
||||
'state_dict': weights_to_cpu(model.state_dict())
|
||||
}
|
||||
if meta is None:
|
||||
meta = {}
|
||||
elif not isinstance(meta, dict):
|
||||
raise TypeError(
|
||||
f'meta must be a dict or None, but got {type(meta)}')
|
||||
meta.update(modelscope=__version__, time=time.asctime())
|
||||
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
model = model.module
|
||||
|
||||
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
|
||||
# save class name to the meta
|
||||
meta.update(CLASSES=model.CLASSES)
|
||||
|
||||
checkpoint['meta'] = meta
|
||||
|
||||
# save optimizer state dict in the checkpoint
|
||||
if isinstance(optimizer, Optimizer):
|
||||
@@ -92,8 +98,13 @@ def save_checkpoint(model: torch.nn.Module,
|
||||
# save lr_scheduler state dict in the checkpoint
|
||||
if lr_scheduler is not None and hasattr(lr_scheduler, 'state_dict'):
|
||||
checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
|
||||
else:
|
||||
checkpoint = weights_to_cpu(model.state_dict())
|
||||
|
||||
if with_model:
|
||||
_weights = weights_to_cpu(model.state_dict())
|
||||
if not with_meta:
|
||||
checkpoint = _weights
|
||||
else:
|
||||
checkpoint['state_dict'] = _weights
|
||||
|
||||
with io.BytesIO() as f:
|
||||
torch.save(checkpoint, f)
|
||||
@@ -134,9 +145,10 @@ def load_checkpoint(filename,
|
||||
f'The state dict of lr_scheduler cannot be found in checkpoint file: {filename}'
|
||||
)
|
||||
|
||||
state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[
|
||||
'state_dict']
|
||||
model.load_state_dict(state_dict)
|
||||
if model is not None:
|
||||
state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[
|
||||
'state_dict']
|
||||
model.load_state_dict(state_dict)
|
||||
return checkpoint.get('meta', {})
|
||||
|
||||
|
||||
@@ -521,7 +533,6 @@ def load_task_model_checkpoint(model_to_load,
|
||||
|
||||
|
||||
def save_configuration(target_folder, config: Dict):
|
||||
from modelscope.utils.config import Config
|
||||
if isinstance(config, Config):
|
||||
config = config.to_dict()
|
||||
if ConfigFields.pipeline not in config:
|
||||
|
||||
@@ -11,6 +11,9 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.utils.regress_test_utils import (compare_arguments_nested,
|
||||
numpify_tensor_nested)
|
||||
|
||||
|
||||
class TorchBaseTest(unittest.TestCase):
|
||||
@@ -69,8 +72,14 @@ class TorchBaseTest(unittest.TestCase):
|
||||
self.assertTrue(np.all(out.detach().numpy() > (add_bias - 10)))
|
||||
|
||||
def test_save_pretrained(self):
|
||||
preprocessor = Preprocessor.from_pretrained(
|
||||
'damo/nlp_structbert_sentence-similarity_chinese-tiny')
|
||||
model = TorchModel.from_pretrained(
|
||||
'damo/nlp_structbert_sentence-similarity_chinese-tiny')
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
res1 = numpify_tensor_nested(
|
||||
model(**preprocessor(('test1', 'test2'))))
|
||||
save_path = os.path.join(self.tmp_dir, 'test_save_pretrained')
|
||||
model.save_pretrained(
|
||||
save_path, save_checkpoint_names='pytorch_model.bin')
|
||||
@@ -79,6 +88,12 @@ class TorchBaseTest(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
os.path.isfile(os.path.join(save_path, 'configuration.json')))
|
||||
self.assertTrue(os.path.isfile(os.path.join(save_path, 'vocab.txt')))
|
||||
model = TorchModel.from_pretrained(save_path)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
res2 = numpify_tensor_nested(
|
||||
model(**preprocessor(('test1', 'test2'))))
|
||||
self.assertTrue(compare_arguments_nested('', res1, res2))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -118,7 +118,7 @@ class EasyCVTrainerTestSingleGpu(unittest.TestCase):
|
||||
LogKeys.MODE: ModeKeys.TRAIN,
|
||||
LogKeys.EPOCH: 1,
|
||||
LogKeys.ITER: 3,
|
||||
LogKeys.LR: 0.00013
|
||||
LogKeys.LR: 0.00029
|
||||
}, json.loads(lines[0]))
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
@@ -131,7 +131,7 @@ class EasyCVTrainerTestSingleGpu(unittest.TestCase):
|
||||
LogKeys.MODE: ModeKeys.TRAIN,
|
||||
LogKeys.EPOCH: 2,
|
||||
LogKeys.ITER: 3,
|
||||
LogKeys.LR: 0.00157
|
||||
LogKeys.LR: 0.00205
|
||||
}, json.loads(lines[2]))
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
|
||||
@@ -12,7 +12,7 @@ from torch.optim import SGD
|
||||
from torch.optim.lr_scheduler import MultiStepLR
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.utils.constant import ModelFile, TrainerStages
|
||||
from modelscope.utils.test_utils import create_dummy_test_dataset
|
||||
@@ -21,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset(
|
||||
np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10)
|
||||
|
||||
|
||||
class DummyModel(nn.Module, Model):
|
||||
class DummyModel(TorchModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.optim import SGD
|
||||
from torch.optim.lr_scheduler import MultiStepLR
|
||||
from torch.optim.lr_scheduler import LinearLR, MultiStepLR
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.metrics.builder import METRICS, MetricKeys
|
||||
@@ -128,6 +128,86 @@ class LrSchedulerHookTest(unittest.TestCase):
|
||||
self.assertListEqual(log_lrs, target_lrs)
|
||||
self.assertListEqual(optim_lrs, target_lrs)
|
||||
|
||||
def test_accumulation_step(self):
|
||||
json_cfg = {
|
||||
'task': 'image_classification',
|
||||
'train': {
|
||||
'work_dir': self.tmp_dir,
|
||||
'dataloader': {
|
||||
'batch_size_per_gpu': 2,
|
||||
'workers_per_gpu': 1
|
||||
},
|
||||
'optimizer': {
|
||||
'type': 'SGD',
|
||||
'lr': 0.01,
|
||||
'options': {
|
||||
'cumulative_iters': 4,
|
||||
}
|
||||
},
|
||||
'lr_scheduler': {
|
||||
'type': 'LinearLR',
|
||||
'start_factor': 1.0,
|
||||
'end_factor': 0.0,
|
||||
'total_iters': int(8 * len(dummy_dataset) / 2),
|
||||
'options': {
|
||||
'by_epoch': False,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(json_cfg, f)
|
||||
|
||||
model = DummyModel()
|
||||
trainer_name = Trainers.default
|
||||
kwargs = dict(
|
||||
cfg_file=config_path,
|
||||
model=model,
|
||||
train_dataset=dummy_dataset,
|
||||
max_epochs=8,
|
||||
device='cpu')
|
||||
|
||||
trainer = build_trainer(trainer_name, kwargs)
|
||||
train_dataloader = trainer._build_dataloader_with_dataset(
|
||||
trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
|
||||
trainer.register_optimizers_hook()
|
||||
|
||||
trainer.invoke_hook(TrainerStages.before_run)
|
||||
log_lrs = []
|
||||
optim_lrs = []
|
||||
for epoch in range(trainer._epoch, trainer._max_epochs):
|
||||
trainer.invoke_hook(TrainerStages.before_train_epoch)
|
||||
for iter, data_batch in enumerate(train_dataloader):
|
||||
trainer.invoke_hook(TrainerStages.before_train_iter)
|
||||
trainer.train_step(trainer.model, data_batch)
|
||||
trainer.invoke_hook(TrainerStages.after_train_iter)
|
||||
|
||||
if (trainer.iter + 1) % 4 == 0:
|
||||
log_lrs.append(trainer.log_buffer.output[LogKeys.LR])
|
||||
optim_lrs.append(trainer.optimizer.param_groups[0]['lr'])
|
||||
|
||||
trainer._iter += 1
|
||||
|
||||
trainer.invoke_hook(TrainerStages.after_train_epoch)
|
||||
trainer._epoch += 1
|
||||
trainer.invoke_hook(TrainerStages.after_run)
|
||||
lr = 0.01
|
||||
decay = 0.01 / 40
|
||||
target_lrs = []
|
||||
for i in range(40):
|
||||
if i >= 3:
|
||||
lr -= decay
|
||||
target_lrs.append(lr)
|
||||
else:
|
||||
target_lrs.append(lr)
|
||||
target_lrs = [
|
||||
i for idx, i in enumerate(target_lrs) if (idx + 1) % 4 == 0
|
||||
]
|
||||
self.assertTrue(all(np.isclose(log_lrs, target_lrs)))
|
||||
self.assertTrue(all(np.isclose(optim_lrs, target_lrs)))
|
||||
|
||||
def test_warmup_lr_scheduler_hook(self):
|
||||
global _global_iter
|
||||
_global_iter = 0
|
||||
|
||||
@@ -7,6 +7,7 @@ import unittest
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.optim import SGD
|
||||
from torch.optim.lr_scheduler import MultiStepLR
|
||||
@@ -31,7 +32,7 @@ class DummyModel(nn.Module, Model):
|
||||
def forward(self, feat, labels):
|
||||
x = self.linear(feat)
|
||||
x = self.bn(x)
|
||||
loss = torch.sum(x)
|
||||
loss = F.cross_entropy(x, labels.to(torch.long).squeeze())
|
||||
return dict(logits=x, loss=loss)
|
||||
|
||||
|
||||
@@ -177,5 +178,78 @@ class TorchAMPOptimizerHookTest(unittest.TestCase):
|
||||
trainer.invoke_hook(TrainerStages.after_run)
|
||||
|
||||
|
||||
class TorchApexOptimizerHookTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
self.tmp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(self.tmp_dir):
|
||||
os.makedirs(self.tmp_dir)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
|
||||
@unittest.skip('Apex works abnormally with torch 1.13')
|
||||
def test_apex_optimizer_hook(self):
|
||||
json_cfg = {
|
||||
'task': 'image_classification',
|
||||
'train': {
|
||||
'work_dir': self.tmp_dir,
|
||||
'dataloader': {
|
||||
'batch_size_per_gpu': 2,
|
||||
'workers_per_gpu': 1
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(json_cfg, f)
|
||||
|
||||
model = DummyModel().cuda()
|
||||
optimizer = SGD(model.parameters(), lr=0.01)
|
||||
lr_scheduler = MultiStepLR(optimizer, milestones=[1, 2])
|
||||
trainer_name = Trainers.default
|
||||
kwargs = dict(
|
||||
cfg_file=config_path,
|
||||
model=model,
|
||||
train_dataset=dummy_dataset,
|
||||
optimizers=(optimizer, lr_scheduler),
|
||||
max_epochs=2)
|
||||
|
||||
trainer = build_trainer(trainer_name, kwargs)
|
||||
train_dataloader = trainer._build_dataloader_with_dataset(
|
||||
trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
|
||||
trainer.register_optimizers_hook()
|
||||
trainer.register_hook_from_cfg([{'type': 'ApexAMPOptimizerHook'}])
|
||||
trainer.invoke_hook(TrainerStages.before_run)
|
||||
|
||||
for _ in range(trainer._epoch, trainer._max_epochs):
|
||||
trainer.invoke_hook(TrainerStages.before_train_epoch)
|
||||
for _, data_batch in enumerate(train_dataloader):
|
||||
for k, v in data_batch.items():
|
||||
data_batch[k] = v.cuda()
|
||||
trainer.invoke_hook(TrainerStages.before_train_iter)
|
||||
trainer.train_step(trainer.model, data_batch)
|
||||
trainer.invoke_hook(TrainerStages.after_train_iter)
|
||||
|
||||
self.assertEqual(trainer.train_outputs['logits'].dtype,
|
||||
torch.float16)
|
||||
|
||||
# test if `after_train_iter`, whether the model is reset to fp32
|
||||
trainer.train_step(trainer.model, data_batch)
|
||||
|
||||
self.assertEqual(
|
||||
len(trainer.optimizer.param_groups[0]['params']), 4)
|
||||
for i in range(4):
|
||||
self.assertTrue(trainer.optimizer.param_groups[0]['params']
|
||||
[i].requires_grad)
|
||||
|
||||
trainer.invoke_hook(TrainerStages.after_train_epoch)
|
||||
trainer._epoch += 1
|
||||
trainer.invoke_hook(TrainerStages.after_run)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -12,7 +12,7 @@ from torch.optim import SGD
|
||||
from torch.optim.lr_scheduler import MultiStepLR
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages
|
||||
from modelscope.utils.test_utils import create_dummy_test_dataset
|
||||
@@ -21,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset(
|
||||
np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10)
|
||||
|
||||
|
||||
class DummyModel(nn.Module, Model):
|
||||
class DummyModel(TorchModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import hashlib
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -121,6 +123,20 @@ class TestTrainerWithNlp(unittest.TestCase):
|
||||
checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth'))
|
||||
self.assertTrue(Metrics.accuracy in eval_results)
|
||||
|
||||
def saving_fn(inputs, outputs):
|
||||
with open(f'{self.tmp_dir}/predicts.txt', 'a') as f:
|
||||
labels = inputs['labels'].cpu().numpy()
|
||||
predictions = np.argmax(
|
||||
outputs['logits'].cpu().numpy(), axis=1)
|
||||
for label, pred in zip(labels, predictions):
|
||||
f.writelines(f'{label}, {pred}\n')
|
||||
|
||||
trainer.predict(
|
||||
predict_datasets=self.dataset,
|
||||
saving_fn=saving_fn,
|
||||
checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10'))
|
||||
self.assertTrue(os.path.isfile(f'{self.tmp_dir}/predicts.txt'))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_trainer_save_best_ckpt(self):
|
||||
|
||||
@@ -208,6 +224,24 @@ class TestTrainerWithNlp(unittest.TestCase):
|
||||
os.path.isfile(
|
||||
os.path.join(self.tmp_dir, 'output_best',
|
||||
'pytorch_model.bin')))
|
||||
md51 = hashlib.md5(
|
||||
pathlib.Path(
|
||||
os.path.join(self.tmp_dir, 'output',
|
||||
'pytorch_model.bin')).read_bytes()).hexdigest()
|
||||
md52 = hashlib.md5(
|
||||
pathlib.Path(os.path.join(
|
||||
self.tmp_dir, 'epoch_10.pth')).read_bytes()).hexdigest()
|
||||
self.assertEqual(md51, md52)
|
||||
md51 = hashlib.md5(
|
||||
pathlib.Path(
|
||||
os.path.join(self.tmp_dir, 'output_best',
|
||||
'pytorch_model.bin')).read_bytes()).hexdigest()
|
||||
md52 = hashlib.md5(
|
||||
pathlib.Path(
|
||||
os.path.join(
|
||||
self.tmp_dir,
|
||||
'best_iter19_accuracy28.pth')).read_bytes()).hexdigest()
|
||||
self.assertEqual(md51, md52)
|
||||
|
||||
@unittest.skip('skip for now before test is re-configured')
|
||||
def test_trainer_with_configured_datasets(self):
|
||||
@@ -313,7 +347,7 @@ class TestTrainerWithNlp(unittest.TestCase):
|
||||
regress_tool = MsRegressTool(baseline=False)
|
||||
with regress_tool.monitor_ms_train(
|
||||
trainer, 'trainer_continue_train', level='strict'):
|
||||
trainer.train(os.path.join(self.tmp_dir, 'iter_3.pth'))
|
||||
trainer.train(os.path.join(self.tmp_dir, 'iter_3'))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_trainer_with_new_style_configuration(self):
|
||||
@@ -489,6 +523,49 @@ class TestTrainerWithNlp(unittest.TestCase):
|
||||
for i in range(2):
|
||||
self.assertIn(f'epoch_{i + 1}.pth', results_files)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_trainer_with_hook_register(self):
|
||||
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
cfg.train.hooks.append({'type': 'TorchAMPOptimizerHook'})
|
||||
return cfg
|
||||
|
||||
kwargs = dict(
|
||||
model=model_id,
|
||||
train_dataset=self.dataset,
|
||||
eval_dataset=self.dataset,
|
||||
cfg_modify_fn=cfg_modify_fn,
|
||||
work_dir=self.tmp_dir)
|
||||
|
||||
trainer = build_trainer(default_args=kwargs)
|
||||
trainer.train()
|
||||
results_files = os.listdir(self.tmp_dir)
|
||||
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
|
||||
for i in range(10):
|
||||
self.assertIn(f'epoch_{i + 1}.pth', results_files)
|
||||
|
||||
output_files = os.listdir(
|
||||
os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR))
|
||||
self.assertIn(ModelFile.CONFIGURATION, output_files)
|
||||
self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files)
|
||||
copy_src_files = os.listdir(trainer.model_dir)
|
||||
|
||||
print(f'copy_src_files are {copy_src_files}')
|
||||
print(f'output_files are {output_files}')
|
||||
for item in copy_src_files:
|
||||
if not item.startswith('.'):
|
||||
self.assertIn(item, output_files)
|
||||
|
||||
def pipeline_sentence_similarity(model_dir):
|
||||
model = Model.from_pretrained(model_dir)
|
||||
pipeline_ins = pipeline(
|
||||
task=Tasks.sentence_similarity, model=model)
|
||||
print(pipeline_ins(input=(self.sentence1, self.sentence2)))
|
||||
|
||||
output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)
|
||||
pipeline_sentence_similarity(output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user