mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-25 20:49:37 +01:00)
update config
@@ -29,128 +29,6 @@ class DeepSpeedConfig(HfTrainerDeepSpeedConfig):
    same lifespan as the latter.
    """

    def __init__(self, config_file_or_dict):
        super().__init__(config_file_or_dict)
        self._dtype = None
        self.mismatches = []

    def dtype(self):
        if self._dtype is None:
            raise ValueError("trainer_config_process() wasn't called yet to tell dtype")
        return self._dtype

    def is_auto(self, ds_key_long):
        val = self.get_value(ds_key_long)
        if val is None:
            return False
        else:
            return val == "auto"

    def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
        """
        A utility method that massages the config file and can optionally verify that the values match.

        1. Replace "auto" values with the `TrainingArguments` value.

        2. If it wasn't "auto" and `must_match` is true, then check that the DS config matches the Trainer
           config value and, if mismatched, add the entry to `self.mismatches` - will assert during
           `trainer_config_finalize` for one or more mismatches.
        """
        config, ds_key = self.find_config_node(ds_key_long)
        if config is None:
            return

        if config.get(ds_key) == "auto":
            config[ds_key] = hf_val
            return

        if not must_match:
            return

        ds_val = config.get(ds_key)
        if ds_val is not None and ds_val != hf_val:
            self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}")

    fill_only = partialmethod(fill_match, must_match=False)
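Note how fill_match behaves: only literal "auto" entries are overwritten with the trainer value; explicit values are left alone and, when must_match is true, any disagreement is merely recorded in self.mismatches for a later assert. A minimal standalone sketch of that rule follows (the resolve helper and the toy config are hypothetical, not part of this class):

# [editor's sketch, not part of this commit] Standalone illustration of the "auto"
# filling rule; `resolve` and the toy config are hypothetical, not modelscope APIs.
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "optimizer": {"params": {"lr": 1e-4}},
}

def resolve(cfg, dotted_key):
    # walk "optimizer.params.lr" down to the parent dict and the final key
    *parents, key = dotted_key.split(".")
    for p in parents:
        cfg = cfg.get(p, {})
    return cfg, key

mismatches = []
for dotted_key, trainer_val in [("train_micro_batch_size_per_gpu", 4),
                                ("optimizer.params.lr", 2e-5)]:
    node, key = resolve(ds_config, dotted_key)
    if node.get(key) == "auto":
        node[key] = trainer_val          # "auto" is filled from the trainer value
    elif node.get(key) is not None and node.get(key) != trainer_val:
        mismatches.append(f"- ds {dotted_key}={node[key]} vs hf {trainer_val}")

print(ds_config["train_micro_batch_size_per_gpu"])  # 4
print(mismatches)  # the explicit lr differs from the trainer lr -> recorded, not overwritten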
    def trainer_config_process(self, args):
        """
        Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object
        creation.
        """
        batch_size_per_gpu = args.train.dataloader.get("batch_size_per_gpu", 4)
        gradient_accumulation_steps = args.train.get("gradient_accumulation_steps", 8)
        workers_per_gpu = args.train.dataloader.get("workers_per_gpu", 2)
        clip_grad = args.train.get("clip_grad", 1.0)
        lr = args.train.optimizer.get("lr", 2e-5)
        adam_beta1 = args.train.optimizer.get("adam_beta1", 0.9)
        adam_beta2 = args.train.optimizer.get("adam_beta2", 0.999)
        adam_epsilon = args.train.optimizer.get("adam_epsilon", 1e-8)
        weight_decay = args.train.optimizer.get("weight_decay", 0.0)

        # DeepSpeed does:
        # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
        train_batch_size = args.world_size * batch_size_per_gpu * gradient_accumulation_steps

        self.fill_match("train_micro_batch_size_per_gpu", batch_size_per_gpu)
        self.fill_match("gradient_accumulation_steps", gradient_accumulation_steps)
        self.fill_match("train_batch_size", train_batch_size)
        self.fill_match("gradient_clipping", clip_grad)

        self.fill_match("optimizer.params.lr", lr)
        self.fill_match("optimizer.params.betas", [adam_beta1, adam_beta2])
        self.fill_match("optimizer.params.eps", adam_epsilon)
        self.fill_match("optimizer.params.weight_decay", weight_decay)

        self.fill_only("scheduler.params.warmup_min_lr", 0)  # not a trainer arg
        self.fill_match("scheduler.params.warmup_max_lr", lr)
        # total_num_steps - will get set in trainer_config_finalize
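A quick worked example of the batch-size arithmetic above, with purely illustrative numbers:

# [editor's sketch, not part of this commit] Hypothetical numbers, only to
# illustrate the formula used above.
world_size = 4                      # number of GPUs / processes
batch_size_per_gpu = 4              # train.dataloader.batch_size_per_gpu
gradient_accumulation_steps = 8     # train.gradient_accumulation_steps

train_batch_size = world_size * batch_size_per_gpu * gradient_accumulation_steps
print(train_batch_size)  # 128 -> the effective batch size DeepSpeed will report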
        args.fp16 = args.train.get("use_fp16", False)
        args.fp16_full_eval = args.train.get("use_fp16", False)
        args.fp16_backend = args.train.get("fp16_backend", "amp")
        # fp16
        if args.fp16 or args.fp16_full_eval:
            fp16_backend = "apex" if args.fp16_backend == "apex" else "amp"
        else:
            fp16_backend = None

        args.save_on_each_node = args.train.get("save_on_each_node", False)
        if args.save_on_each_node:
            # deepspeed uses shared storage by default. Let's override this setting if save_on_each_node == True
            self.config["checkpoint"] = self.config.get("checkpoint", {})
            self.config["checkpoint"]["use_node_local_storage"] = args.save_on_each_node

        # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set
        # any here unless the user did the work
        self.fill_match(
            "fp16.enabled",
            ((args.fp16 or args.fp16_full_eval) and fp16_backend == "amp"),
            "fp16|fp16_full_eval+fp16_backend(amp)",
        )

        args.fp16_opt_level = args.train.get("fp16_opt_level", None)
        args.fp16_opt_level = next(
            (item["opt_level"] for item in args.train.hooks if item["type"] == "ApexAMPOptimizerHook"),
            args.fp16_opt_level)
        if not args.fp16_opt_level:
            args.fp16_opt_level = "O1"
        # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
        # ZeRO features
        self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
        self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")

        args.bf16 = args.train.get("bf16", False)
        # mirror the fp16 handling so bf16_full_eval exists when referenced below
        args.bf16_full_eval = args.train.get("bf16", False)
        self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval")

        # deepspeed's default mode is fp16 unless there is a config that says differently
        if self.is_true("bf16.enabled"):
            self._dtype = torch.bfloat16
        elif self.is_false("fp16.enabled"):
            self._dtype = torch.float32
        else:
            self._dtype = torch.float16
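The dtype resolution above only looks at the final bf16.enabled / fp16.enabled values and falls back to fp16; a standalone sketch of the same decision (not the method itself):

# [editor's sketch, not part of this commit] Standalone illustration of the dtype choice above.
import torch

def resolve_dtype(bf16_enabled, fp16_enabled):
    if bf16_enabled:
        return torch.bfloat16     # bf16 wins when explicitly enabled
    if fp16_enabled is False:
        return torch.float32      # fp16 explicitly disabled -> full precision
    return torch.float16          # DeepSpeed's default mode is fp16

print(resolve_dtype(False, True))   # torch.float16
print(resolve_dtype(False, False))  # torch.float32
print(resolve_dtype(True, False))   # torch.bfloat16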
    def trainer_config_finalize(self, args, model, num_training_steps):
        """
        This stage is run after we have the model and know num_training_steps.

@@ -267,10 +145,26 @@ class DeepspeedHook(Hook):
            name='CheckpointHook.should_save_on_rank',
            function=self.should_save_on_rank)

    def should_save_on_rank(self, trainer):
        # TODO
        return (not torch.distributed.is_initialized()
                ) or mpu.get_data_parallel_rank() == 0

    def wrap_module(self, trainer):
        # deepspeed initializes its own ddp
        self.wrapped = True

    def rank_name(self):
        # TODO
        try:
            tp_world_size = mpu.get_tensor_model_parallel_world_size()
            if tp_world_size == 1:
                return ''
            mp_rank = mpu.get_tensor_model_parallel_rank()
            return '_mp_rank_{:02d}'.format(mp_rank)
        except (ImportError, AssertionError):
            return ''

    def backward(self, trainer, loss_keys, cumulative_iters, grad_clip):
        # assert cumulative_iters == 1, 'DeepSpeed only support cumulative_iters=1'
        # The `trainer.model` here is actually a deepspeed engine object.
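Since trainer.model is a DeepSpeed engine at this point, the backward pass and optimizer step normally go through the engine rather than plain loss.backward() / optimizer.step(). A hedged sketch of that pattern (the helper name is hypothetical; the engine methods are DeepSpeed's public API):

# [editor's sketch, not part of this commit] The canonical DeepSpeed calls that replace
# loss.backward() / optimizer.step() once trainer.model is an engine.
def deepspeed_backward_step(engine, loss):
    engine.backward(loss)   # DeepSpeed scales the loss and handles gradient accumulation itself
    engine.step()           # optimizer (and lr scheduler) step at accumulation boundaries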
@@ -385,9 +279,30 @@ class DeepspeedHook(Hook):
        trainer.device = f'cuda:{device_id}'
        # trainer.parallel_groups[DistributedParallelType.DP] = None

    def prepare_for_init(self, trainer):
    def prepare_args(self, args):
        args.per_device_train_batch_size = args.train.dataloader.get("batch_size_per_gpu", 4)
        args.gradient_accumulation_steps = args.train.get("gradient_accumulation_steps", 1)
        args.max_grad_norm = args.train.get("clip_grad", 1.0)
        args.learning_rate = args.train.optimizer.get("lr", 2e-5)
        args.adam_beta1 = args.train.optimizer.get("adam_beta1", 0.9)
        args.adam_beta2 = args.train.optimizer.get("adam_beta2", 0.999)
        args.adam_epsilon = args.train.optimizer.get("adam_epsilon", 1e-8)
        args.weight_decay = args.train.optimizer.get("weight_decay", 0.0)
        args.fp16 = args.train.get("use_fp16", False)
        args.fp16_full_eval = args.train.get("use_fp16", False)
        args.fp16_backend = args.train.get("fp16_backend", "amp")
        args.save_on_each_node = args.train.get("save_on_each_node", False)
        args.fp16_opt_level = args.train.get("fp16_opt_level", None)
        args.fp16_opt_level = next(
            (item["opt_level"] for item in args.train.hooks if item["type"] == "ApexAMPOptimizerHook"),
            args.fp16_opt_level)
        if not args.fp16_opt_level:
            args.fp16_opt_level = "O1"
        args.bf16 = args.train.get("bf16", False)
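For reference, a minimal cfg fragment covering the fields prepare_args reads; the values are illustrative defaults, not a recommended setup:

# [editor's sketch, not part of this commit] Illustrative modelscope-style cfg fragment;
# the keys mirror the lookups in prepare_args above.
cfg = {
    "train": {
        "dataloader": {"batch_size_per_gpu": 4, "workers_per_gpu": 2},
        "gradient_accumulation_steps": 1,
        "clip_grad": 1.0,
        "optimizer": {"type": "AdamW", "lr": 2e-5, "weight_decay": 0.0},
        "hooks": [{"type": "ApexAMPOptimizerHook", "opt_level": "O1"}],
        "use_fp16": True,
        "bf16": False,
    }
}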
    def get_deepspeed_config(self, trainer, max_steps):
        args = trainer.cfg
        _, args.world_size = get_dist_info()
        self.prepare_args(args)
        if os.path.exists(self.deepspeed_config):
            deepspeed_config = self.deepspeed_config
        else:
@@ -395,17 +310,23 @@ class DeepspeedHook(Hook):
                self.deepspeed_config)
        self.logger.info(f"Loading deepspeed config from {deepspeed_config}")

        gradient_accumulation_steps = args.train.get("gradient_accumulation_steps", 8)
        num_update_steps_per_epoch = trainer.iters_per_epoch // gradient_accumulation_steps
        max_steps = math.ceil(trainer._max_epochs * num_update_steps_per_epoch)

        ds_config = DeepSpeedConfig(deepspeed_config)
        ds_config.trainer_config_process(args)

        ds_config.trainer_config_finalize(args, trainer.model, max_steps)
        optimizer, lr_scheduler = deepspeed_optim_sched(trainer, ds_config, max_steps)
        config = ds_config.config
        return config, optimizer, lr_scheduler
        return ds_config
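The max_steps value is simply epochs times optimizer updates per epoch; a quick numeric illustration with hypothetical values:

# [editor's sketch, not part of this commit] Hypothetical values, to illustrate
# the max_steps arithmetic used above.
import math

iters_per_epoch = 1000
gradient_accumulation_steps = 8
max_epochs = 3

num_update_steps_per_epoch = iters_per_epoch // gradient_accumulation_steps  # 125
max_steps = math.ceil(max_epochs * num_update_steps_per_epoch)               # 375
print(num_update_steps_per_epoch, max_steps)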
    # def prepare_for_init(self, trainer):
    #     gradient_accumulation_steps = trainer.cfg.train.get("gradient_accumulation_steps", 1)
    #     num_update_steps_per_epoch = trainer.iters_per_epoch // gradient_accumulation_steps
    #     max_steps = math.ceil(trainer._max_epochs * num_update_steps_per_epoch)

    #     ds_config = self.get_deepspeed_config(trainer, max_steps)

    #     optimizer, lr_scheduler = deepspeed_optim_sched(trainer, ds_config, max_steps)
    #     config = ds_config.config
    #     return config, optimizer, lr_scheduler

    def before_run(self, trainer):
        if not hasattr(trainer, 'logger'):
@@ -414,8 +335,15 @@ class DeepspeedHook(Hook):
            self.logger = trainer.logger

        # deepspeed init
        gradient_accumulation_steps = trainer.cfg.train.get("gradient_accumulation_steps", 1)
        num_update_steps_per_epoch = trainer.iters_per_epoch // gradient_accumulation_steps
        max_steps = math.ceil(trainer._max_epochs * num_update_steps_per_epoch)

        config, optimizer, lr_scheduler = self.prepare_for_init(trainer)
        ds_config = self.get_deepspeed_config(trainer, max_steps)

        optimizer, lr_scheduler = deepspeed_optim_sched(trainer, ds_config, max_steps)
        config = ds_config.config

        # TODO: decide whether dist_init and mpu are actually needed instead of hard-coding them
        trainer.model, trainer.optimizer, _, trainer.lr_scheduler = deepspeed.initialize(
            model=trainer.model,
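deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler). The call is truncated by the diff; a hedged sketch of how it is presumably completed with the objects built above (the keyword names follow DeepSpeed's public API, dist_init_required is an assumption):

# [editor's sketch, not part of this commit] Presumed completion of the truncated call,
# reusing the config / optimizer / lr_scheduler computed above.
trainer.model, trainer.optimizer, _, trainer.lr_scheduler = deepspeed.initialize(
    model=trainer.model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    config=config,
    dist_init_required=False,  # assumption: the trainer has already set up torch.distributed
)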
@@ -48,7 +48,7 @@ if __name__ == '__main__':
    def cfg_modify_fn(cfg):
        cfg.train.lr_scheduler = {
            'type': 'CosineAnnealingLR',
            'T_max': 3,
            'T_max': 1,
            'options': {
                'by_epoch': False
            }
@@ -119,7 +119,7 @@ if __name__ == '__main__':
        train_dataset=train_dataset,
        eval_dataset=None,
        data_collator=data_collator,
        max_epochs=3,
        max_epochs=1,
        launcher='pytorch',
        work_dir="/run/model/ms_out",
        cfg_modify_fn=cfg_modify_fn)
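Taken together, these kwargs are typically handed to modelscope's build_trainer; a hedged sketch of the surrounding call (the model id and dataset variables are placeholders, not part of this commit):

# [editor's sketch, not part of this commit] Assembling the kwargs shown above into a
# modelscope trainer; the model id and dataset/collator variables are placeholders.
from modelscope.trainers import build_trainer

kwargs = dict(
    model='<model-id-or-path>',          # placeholder
    train_dataset=train_dataset,
    eval_dataset=None,
    data_collator=data_collator,
    max_epochs=1,
    launcher='pytorch',
    work_dir="/run/model/ms_out",
    cfg_modify_fn=cfg_modify_fn)

trainer = build_trainer(default_args=kwargs)
trainer.train()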