diff --git a/examples/pytorch/llama/finetune_llama.py b/examples/pytorch/llama/finetune_llama.py new file mode 100644 index 00000000..88975e66 --- /dev/null +++ b/examples/pytorch/llama/finetune_llama.py @@ -0,0 +1,263 @@ +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# Copyright (c) Alibaba, Inc. and its affiliates. + +import copy +import logging +import os +import shutil +import tempfile +import unittest +from dataclasses import dataclass, field + +import json +import torch +import utils + +from modelscope import TrainingArgs +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.nlp.llama import LlamaForTextGeneration, LlamaTokenizer +from modelscope.msdatasets.dataset_cls.custom_datasets.torch_custom_dataset import \ + TorchCustomDataset +from modelscope.trainers import build_trainer + +IGNORE_INDEX = -100 +DEFAULT_PAD_TOKEN = '[PAD]' +DEFAULT_EOS_TOKEN = '' +DEFAULT_BOS_TOKEN = '' +DEFAULT_UNK_TOKEN = '' +PROMPT_DICT = { + 'prompt_input': + ('Below is an instruction that describes a task, paired with an input that provides further context. ' + 'Write a response that appropriately completes the request.\n\n' + '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:' + ), + 'prompt_no_input': + ('Below is an instruction that describes a task. ' + 'Write a response that appropriately completes the request.\n\n' + '### Instruction:\n{instruction}\n\n### Response:'), +} + + +@dataclass(init=False) +class TextGenerationArguments(TrainingArgs): + src_txt: str = field( + default=None, + metadata={ + 'help': 'The source text key of preprocessor', + 'cfg_node': 'preprocessor.src_txt' + }) + + deepspeed: str = field( + default=None, + metadata={ + 'help': 'The location of DeepSpeed json config file.', + }) + + work_dir: str = field( + default=None, metadata={ + 'help': 'The location of work dir', + }) + + +def _tokenize_fn(strings, tokenizer): + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors='pt', + padding='longest', + max_length=tokenizer.model_max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [ + tokenized.input_ids[0] for tokenized in tokenized_list + ] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def preprocess(sources, targets, tokenizer): + """Preprocess the data by tokenizing.""" + examples = [s + t for s, t in zip(sources, targets)] + examples_tokenized, sources_tokenized = [ + _tokenize_fn(strings, tokenizer) for strings in (examples, sources) + ] + input_ids = examples_tokenized['input_ids'] + labels = copy.deepcopy(input_ids) + for label, source_len in zip(labels, sources_tokenized['input_ids_lens']): + label[:source_len] = IGNORE_INDEX + return dict(input_ids=input_ids, labels=labels) + + +def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, + model): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +class SupervisedDataset(TorchCustomDataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, tokenizer): + logging.warning('Loading data...') + f = open(data_path, 'r') + list_data_dict = json.load(f) + f.close() + + logging.warning('Formatting inputs...') + prompt_input, prompt_no_input = PROMPT_DICT[ + 'prompt_input'], PROMPT_DICT['prompt_no_input'] + sources = [ + prompt_input.format_map(example) if example.get('input', '') != '' + else prompt_no_input.format_map(example) + for example in list_data_dict + ] + targets = [ + f"{example['output']}{tokenizer.eos_token}" + for example in list_data_dict + ] + + logging.warning('Tokenizing inputs... This may take some time...') + data_dict = preprocess(sources, targets, tokenizer) + + self.input_ids = data_dict['input_ids'] + self.labels = data_dict['labels'] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i): + return dict(input_ids=self.input_ids[i], labels=self.labels[i]) + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: LlamaTokenizer + + def __call__(self, instances): + input_ids, labels = tuple([instance[key] for instance in instances] + for key in ('input_ids', 'labels')) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = torch.nn.utils.rnn.pad_sequence( + labels, batch_first=True, padding_value=IGNORE_INDEX) + return dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + + +config, args = TextGenerationArguments().parse_cli().to_config() + +if __name__ == '__main__': + + def cfg_modify_fn(cfg): + if args.use_model_config: + cfg.merge_from_dict(config) + else: + cfg = config + cfg.train.lr_scheduler = { + 'type': 'CosineAnnealingLR', + 'T_max': 1, + 'options': { + 'by_epoch': False + } + } + cfg.train.optimizer = { + 'type': 'AdamW', + 'lr': 2e-5, + 'weight_decay': 0.0, + 'options': { + 'cumulative_iters': 8, + 'warmup': { + 'type': 'LinearWarmup', + 'warmup_ratio': 0.03 + } + } + } + cfg.train.logging = {'interval': 8, 'by_epoch': False} + cfg.train['bf16'] = True + cfg.train.dataloader = {'batch_size_per_gpu': 4, 'workers_per_gpu': 1} + if 'hooks' not in cfg.train: + cfg.train['hooks'] = [] + cfg.train.hooks.append({ + 'type': 'DeepspeedHook', + 'config': args.deepspeed, + 'save_zero_checkpoint': True, + 'with_mpu': False, + }) + + cfg.preprocessor.sequence_length = 512 + return cfg + + model_path = args.model if os.path.exists( + args.model) else snapshot_download(args.model) + data_path = args.src_txt if args.src_txt else os.path.join( + model_path, 'alpaca_data.json') + model = LlamaForTextGeneration.from_pretrained(model_path) + + tokenizer = LlamaTokenizer.from_pretrained( + model_path, + model_max_length=512, + padding_side='right', + ) + + special_tokens_dict = dict() + special_tokens_dict['pad_token'] = DEFAULT_PAD_TOKEN + special_tokens_dict['eos_token'] = DEFAULT_EOS_TOKEN + special_tokens_dict['bos_token'] = DEFAULT_BOS_TOKEN + special_tokens_dict['unk_token'] = DEFAULT_UNK_TOKEN + + smart_tokenizer_and_embedding_resize( + special_tokens_dict=special_tokens_dict, + tokenizer=tokenizer, + model=model, + ) + + train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_path) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + + kwargs = dict( + model=model, + cfg_file=os.path.join(model_path, 'configuration.json'), + train_dataset=train_dataset, + data_collator=data_collator, + max_epochs=3, + work_dir=args.work_dir, + cfg_modify_fn=cfg_modify_fn) + + # Construct trainer and train + trainer = build_trainer( + name=Trainers.text_generation_trainer, default_args=kwargs) + trainer.train() diff --git a/examples/pytorch/llama/run_train_llama.sh b/examples/pytorch/llama/run_train_llama.sh new file mode 100644 index 00000000..7c860d57 --- /dev/null +++ b/examples/pytorch/llama/run_train_llama.sh @@ -0,0 +1,9 @@ +DATA_PARALLEL_SIZE=4 + + +export PYTHONPATH=$PYTHONPATH:./ +torchrun --nproc_per_node $DATA_PARALLEL_SIZE examples/pytorch/llama/finetune_llama.py \ + --work_dir './tmp' \ + --model 'skyline2006/llama-7b' \ + --deepspeed 'default_offload_opt_param.json' \ + --eval_interval 100 diff --git a/modelscope/trainers/hooks/distributed/deepspeed_hook.py b/modelscope/trainers/hooks/distributed/deepspeed_hook.py index d0a6eb9b..3626a9e3 100644 --- a/modelscope/trainers/hooks/distributed/deepspeed_hook.py +++ b/modelscope/trainers/hooks/distributed/deepspeed_hook.py @@ -1,24 +1,131 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2020 The HuggingFace Team. All rights reserved. +import math import os import shutil +from functools import partialmethod import deepspeed import torch from deepspeed import DeepSpeedEngine from megatron_util import mpu, print_rank_0 +from transformers.deepspeed import HfTrainerDeepSpeedConfig from modelscope.metainfo import Hooks from modelscope.trainers.hooks import LoadCheckpointHook from modelscope.trainers.hooks.builder import HOOKS from modelscope.trainers.hooks.checkpoint.checkpoint_hook import ( BestCkptSaverHook, CheckpointHook) +from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \ + CheckpointProcessor from modelscope.trainers.hooks.hook import Hook +from modelscope.trainers.hooks.lr_scheduler_hook import (LrSchedulerHook, + LrSchedulerProcessor) +from modelscope.trainers.hooks.optimizer.base import (OptimizerHook, + OptimizerProcessor) from modelscope.trainers.hooks.priority import Priority from modelscope.utils.checkpoint import save_checkpoint +from modelscope.utils.constant import DistributedParallelType +from modelscope.utils.device import create_device from modelscope.utils.logger import get_logger -from ..checkpoint.checkpoint_processor import CheckpointProcessor -from ..lr_scheduler_hook import LrSchedulerHook, LrSchedulerProcessor -from ..optimizer.base import OptimizerHook, OptimizerProcessor +from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, + init_dist) + + +class DeepSpeedConfig(HfTrainerDeepSpeedConfig): + """ + The `DeepSpeedConfig` object is meant to be created during `TrainingArguments` object creation and has the + same lifespan as the latter. + """ + + def is_auto(self, ds_key_long): + val = self.get_value(ds_key_long) + if val is None: + return False + else: + return val == 'auto' + + def trainer_config_finalize(self, args, model, num_training_steps): + """ + This stage runs after we have the model and know num_training_steps. + + Now we can complete the configuration process. + """ + # zero + + # deal with config keys that use `auto` value and rely on model's hidden_size + hidden_size_based_keys = [ + 'zero_optimization.reduce_bucket_size', + 'zero_optimization.stage3_prefetch_bucket_size', + 'zero_optimization.stage3_param_persistence_threshold', + ] + hidden_size_auto_keys = [ + x for x in hidden_size_based_keys if self.is_auto(x) + ] + + if len(hidden_size_auto_keys) > 0: + if hasattr(model.config, 'hidden_size'): + hidden_size = model.config.hidden_size + elif hasattr(model.config, 'hidden_sizes'): + # if there are many hidden sizes pick the largest one + hidden_size = max(model.config.hidden_sizes) + else: + raise ValueError( + "The model's config file has neither `hidden_size` nor `hidden_sizes` entry, " + "therefore it's not possible to automatically fill out the following `auto` entries " + f'in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing ' + '`auto` values for these keys with an integer value of your choice.' + ) + + self.fill_only('zero_optimization.reduce_bucket_size', + hidden_size * hidden_size) + if self.is_zero3(): + # automatically assign the optimal config values based on model config + self.fill_only('zero_optimization.stage3_prefetch_bucket_size', + 0.9 * hidden_size * hidden_size) + self.fill_only( + 'zero_optimization.stage3_param_persistence_threshold', + 10 * hidden_size) + + # scheduler + options = args.train.optimizer.get('options', {}) + warmup = options.get('warmup', {}) + warmup_steps = warmup.get('warmup_steps', 0) + warmup_ratio = warmup.get('warmup_ratio', 0.0) + warmup_steps = warmup_steps if warmup_steps > 0 else math.ceil( + num_training_steps * warmup_ratio) + self.fill_match('scheduler.params.total_num_steps', num_training_steps) + self.fill_match('scheduler.params.warmup_num_steps', warmup_steps) + + if len(self.mismatches) > 0: + mismatches = '\n'.join(self.mismatches) + raise ValueError( + 'Please correct the following DeepSpeed config values that mismatch TrainingArguments' + f" values:\n{mismatches}\nThe easiest method is to set these DeepSpeed config values to 'auto'." + ) + + +def deepspeed_optim_sched(trainer, hf_deepspeed_config, num_training_steps): + config = hf_deepspeed_config.config + optimizer = None + if 'optimizer' not in config: + if hf_deepspeed_config.is_offload(): + logger.info( + 'Detected ZeRO Offload and non-DeepSpeed optimizers: This combination should work as long as the' + ' custom optimizer has both CPU and GPU implementation (except LAMB)' + ) + + # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch. + # But trainer uses AdamW by default. + optimizer = trainer.optimizer + # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer` + config['zero_allow_untested_optimizer'] = True + + lr_scheduler = None + if 'scheduler' not in config: + lr_scheduler = trainer.scheduler + + return optimizer, lr_scheduler class DeepspeedProcessor(CheckpointProcessor, LrSchedulerProcessor, @@ -58,6 +165,8 @@ class DeepspeedProcessor(CheckpointProcessor, LrSchedulerProcessor, prefix = os.path.basename(checkpoint_path_prefix) trainer.model.save_checkpoint(save_dir, prefix) + if not self.stage3_gather_16bit_weights_on_model_save: + return bin_file = self.get_bin_file() src_file = os.path.join(checkpoint_path_prefix, bin_file) dest_file = os.path.join(output_dir, self._BIN_FILE_DIR, bin_file) @@ -129,6 +238,8 @@ class DeepspeedProcessor(CheckpointProcessor, LrSchedulerProcessor, trainer.model.backward(loss) # update parameters + # Optimizer step for deepspeed must be called on every step regardless of + # the value of gradient accumulation iters trainer.model.step() def initialize_optimizer(self, trainer): @@ -137,20 +248,36 @@ class DeepspeedProcessor(CheckpointProcessor, LrSchedulerProcessor, def step(self, trainer): pass + def should_save_on_rank(self, trainer): + return True + + def get_current_lr(self, trainer): + if isinstance(trainer.optimizer, torch.optim.Optimizer) or isinstance( + trainer.optimizer, deepspeed.DeepSpeedOptimizer): + lr = [group['lr'] for group in trainer.optimizer.param_groups] + elif isinstance(trainer.optimizer, dict): + lr = dict() + for name, optim in trainer.optimizer.items(): + lr[name] = [group['lr'] for group in optim.param_groups] + else: + raise RuntimeError( + 'lr is not applicable because optimizer does not exist.') + return lr + @HOOKS.register_module(module_name=Hooks.DeepspeedHook) class DeepspeedHook(Hook): PRIORITY = Priority.VERY_HIGH def __init__(self, + config=None, deepspeed_activation_checkpointing=True, save_zero_checkpoint=False, with_mpu=True): self.save_zero_checkpoint = save_zero_checkpoint self.deepspeed_activation_checkpointing = deepspeed_activation_checkpointing - # TODO without mpu self.with_mpu = with_mpu - assert with_mpu, 'DeepspeedHook now is only for mpu models.' + self.deepspeed_config = config def register_processor(self, trainer): processor = DeepspeedProcessor() @@ -158,23 +285,76 @@ class DeepspeedHook(Hook): if len(optimizer_hook) > 0 and not isinstance( optimizer_hook[0].processor, DeepspeedProcessor): optimizer_hook[0].set_processor(processor) - lr_schedular_hook = trainer.get_hook(LrSchedulerHook) - if len(lr_schedular_hook) > 0 and not isinstance( - lr_schedular_hook[0].processor, DeepspeedProcessor): - lr_schedular_hook[0].set_processor(processor) ckpt_hook = trainer.get_hook(CheckpointHook) if len(ckpt_hook) > 0 and not isinstance(ckpt_hook[0].processor, DeepspeedProcessor): ckpt_hook[0].set_processor(processor) + best_ckpt_hook = trainer.get_hook(BestCkptSaverHook) if len(best_ckpt_hook) > 0 and not isinstance( best_ckpt_hook[0].processor, DeepspeedProcessor): best_ckpt_hook[0].set_processor(processor) + load_ckpt_hook = trainer.get_hook(LoadCheckpointHook) if len(load_ckpt_hook) > 0 and not isinstance( load_ckpt_hook[0].processor, DeepspeedProcessor): load_ckpt_hook[0].set_processor(processor) + lr_scheduler_hook = trainer.get_hook(LrSchedulerHook) + if len(lr_scheduler_hook) > 0 and not isinstance( + lr_scheduler_hook[0].processor, DeepspeedProcessor): + lr_scheduler_hook[0].set_processor(processor) + self.processor = processor + + def prepare_args(self, args): + args.per_device_train_batch_size = args.train.dataloader.get( + 'batch_size_per_gpu', 4) + args.max_grad_norm = args.train.get('clip_grad', 1.0) + args.learning_rate = args.train.optimizer.get('lr', 2e-5) + args.adam_beta1 = args.train.optimizer.get('adam_beta1', 0.9) + args.adam_beta2 = args.train.optimizer.get('adam_beta2', 0.999) + args.adam_epsilon = args.train.optimizer.get('adam_epsilon', 1e-8) + args.weight_decay = args.train.optimizer.get('weight_decay', 0.0) + args.fp16 = args.train.get('use_fp16', False) + args.fp16_full_eval = args.train.get('use_fp16', False) + args.fp16_backend = args.train.get('fp16_backend', 'amp') + args.save_on_each_node = args.train.get('save_on_each_node', False) + args.fp16_opt_level = args.train.get('fp16_opt_level', None) + args.fp16_opt_level = next((item.get('opt_level', args.fp16_opt_level) + for item in args.train.hooks + if item['type'] == 'ApexAMPOptimizerHook'), + args.fp16_opt_level) + if not args.fp16_opt_level: + args.fp16_opt_level = 'O1' + args.bf16 = args.train.get('bf16', False) + + def get_deepspeed_config(self, trainer, args, max_steps): + _, args.world_size = get_dist_info() + self.prepare_args(args) + if os.path.exists(self.deepspeed_config): + deepspeed_config = self.deepspeed_config + else: + deepspeed_config = os.path.join(trainer.model_dir, + self.deepspeed_config) + if not os.path.exists(deepspeed_config): + raise RuntimeError( + f'No such DeepSpeed json config file: {self.deepspeed_config}.' + ) + self.logger.info(f'Loading deepspeed config from {deepspeed_config}') + + ds_config = DeepSpeedConfig(deepspeed_config) + ds_config.trainer_config_process(args) + + ds_config.trainer_config_finalize(args, trainer.model, max_steps) + return ds_config + + def after_init(self, trainer): + init_dist('pytorch') + local_rank = get_local_rank() + trainer.device = create_device(f'cuda:{local_rank}') + trainer.model.to(trainer.device) + trainer.parallel_groups[DistributedParallelType.DP] = None + def before_val(self, trainer): pass @@ -185,26 +365,23 @@ class DeepspeedHook(Hook): self.logger = trainer.logger # deepspeed init - args = trainer.cfg.train - args.deepspeed_config = os.path.join(trainer.model_dir, - args.deepspeed_config) + args = trainer.cfg + args.gradient_accumulation_steps = args.train.optimizer.get( + 'options', {}).get('cumulative_iters', 1) + num_update_steps_per_epoch = trainer.iters_per_epoch // args.gradient_accumulation_steps + max_steps = math.ceil(trainer._max_epochs * num_update_steps_per_epoch) - trainer.model, _, _, _ = deepspeed.initialize( + ds_config = self.get_deepspeed_config(trainer, args, max_steps) + + optimizer, lr_scheduler = deepspeed_optim_sched( + trainer, ds_config, max_steps) + config = ds_config.config + self.processor.stage3_gather_16bit_weights_on_model_save = config[ + 'zero_optimization'].get( + 'stage3_gather_16bit_weights_on_model_save', True) + + trainer.model, trainer.optimizer, _, trainer.lr_scheduler = deepspeed.initialize( model=trainer.model, - optimizer=trainer.optimizer, - args=args, - lr_scheduler=trainer.lr_scheduler, - mpu=mpu, - dist_init_required=False) - trainer.model.save_zero_checkpoint = self.save_zero_checkpoint - - if self.deepspeed_activation_checkpointing: - model = trainer.unwrap_module(trainer.model) - deepspeed.checkpointing.configure( - mpu, - deepspeed_config=args.deepspeed_config, - num_checkpoints=model.config.num_hidden_layers) - - mpu.checkpoint = deepspeed.checkpointing.checkpoint - mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker - mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed + optimizer=optimizer, + config=config, + lr_scheduler=lr_scheduler) diff --git a/modelscope/trainers/hooks/lr_scheduler_hook.py b/modelscope/trainers/hooks/lr_scheduler_hook.py index 51a8e858..facf5155 100644 --- a/modelscope/trainers/hooks/lr_scheduler_hook.py +++ b/modelscope/trainers/hooks/lr_scheduler_hook.py @@ -39,6 +39,20 @@ class LrSchedulerProcessor: else: trainer.lr_scheduler.step() + def get_current_lr(self, trainer): + import torch + + if isinstance(trainer.optimizer, torch.optim.Optimizer): + lr = [group['lr'] for group in trainer.optimizer.param_groups] + elif isinstance(trainer.optimizer, dict): + lr = dict() + for name, optim in trainer.optimizer.items(): + lr[name] = [group['lr'] for group in optim.param_groups] + else: + raise RuntimeError( + 'lr is not applicable because optimizer does not exist.') + return lr + class LrStrategy: by_epoch = 'by_epoch' @@ -84,20 +98,6 @@ class LrSchedulerHook(Hook): self.processor.initialize_lr_scheduler(trainer) - def get_current_lr(self, trainer): - import torch - - if isinstance(trainer.optimizer, torch.optim.Optimizer): - lr = [group['lr'] for group in trainer.optimizer.param_groups] - elif isinstance(trainer.optimizer, dict): - lr = dict() - for name, optim in trainer.optimizer.items(): - lr[name] = [group['lr'] for group in optim.param_groups] - else: - raise RuntimeError( - 'lr is not applicable because optimizer does not exist.') - return lr - def after_train_iter(self, trainer): if self.lr_strategy == LrStrategy.by_step and trainer.iter >= getattr( trainer, 'cumulative_iters', 1) - 1: @@ -112,7 +112,7 @@ class LrSchedulerHook(Hook): self.processor.step(trainer) def _get_log_lr(self, trainer): - cur_lr = self.get_current_lr(trainer) + cur_lr = self.processor.get_current_lr(trainer) # only record lr of the first param group if isinstance(cur_lr, list): lr = cur_lr[0] diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index c980de04..fd0fafb8 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -1001,7 +1001,7 @@ class EpochBasedTrainer(BaseTrainer): """ optimizer, lr_scheduler = self.optimizers if optimizer is None: - optimizer_cfg = self.cfg.train.get('optimizer', None) + optimizer_cfg = deepcopy(self.cfg.train.get('optimizer', None)) else: optimizer_cfg = None @@ -1011,7 +1011,8 @@ class EpochBasedTrainer(BaseTrainer): optimizer = self.build_optimizer(cfg=optimizer_cfg) if lr_scheduler is None: - lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None) + lr_scheduler_cfg = deepcopy( + self.cfg.train.get('lr_scheduler', None)) else: lr_scheduler_cfg = None