mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
refactor
This commit is contained in:
@@ -24,46 +24,46 @@ import deepspeed
|
||||
def add_model_config_args(parser):
|
||||
"""Model arguments"""
|
||||
|
||||
group = parser.add_argument_group('model', 'model configuration')
|
||||
# group = parser.add_argument_group('model', 'model configuration')
|
||||
|
||||
group.add_argument('--pretrained-bert', action='store_true',
|
||||
help='use a pretrained bert-large-uncased model instead'
|
||||
'of initializing from scratch. See '
|
||||
'--tokenizer-model-type to specify which pretrained '
|
||||
'BERT model to use')
|
||||
group.add_argument('--attention-dropout', type=float, default=0.1,
|
||||
help='dropout probability for attention weights')
|
||||
group.add_argument('--num-attention-heads', type=int, default=16,
|
||||
help='num of transformer attention heads')
|
||||
group.add_argument('--hidden-size', type=int, default=1024,
|
||||
help='tansformer hidden size')
|
||||
group.add_argument('--intermediate-size', type=int, default=None,
|
||||
help='transformer embedding dimension for FFN'
|
||||
'set to 4*`--hidden-size` if it is None')
|
||||
group.add_argument('--num-layers', type=int, default=24,
|
||||
help='num decoder layers')
|
||||
group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
|
||||
help='layer norm epsilon')
|
||||
group.add_argument('--hidden-dropout', type=float, default=0.1,
|
||||
help='dropout probability for hidden state transformer')
|
||||
group.add_argument('--max-position-embeddings', type=int, default=512,
|
||||
help='maximum number of position embeddings to use')
|
||||
group.add_argument('--vocab-size', type=int, default=30522,
|
||||
help='vocab size to use for non-character-level '
|
||||
'tokenization. This value will only be used when '
|
||||
'creating a tokenizer')
|
||||
group.add_argument('--deep-init', action='store_true',
|
||||
help='initialize bert model similar to gpt2 model.'
|
||||
'scales initialization of projection layers by a '
|
||||
'factor of 1/sqrt(2N). Necessary to train bert '
|
||||
'models larger than BERT-Large.')
|
||||
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
|
||||
help='Pad the vocab size to be divisible by this value.'
|
||||
'This is added for computational efficieny reasons.')
|
||||
group.add_argument('--cpu-optimizer', action='store_true',
|
||||
help='Run optimizer on CPU')
|
||||
group.add_argument('--cpu_torch_adam', action='store_true',
|
||||
help='Use Torch Adam as optimizer on CPU.')
|
||||
# group.add_argument('--pretrained-bert', action='store_true',
|
||||
# help='use a pretrained bert-large-uncased model instead'
|
||||
# 'of initializing from scratch. See '
|
||||
# '--tokenizer-model-type to specify which pretrained '
|
||||
# 'BERT model to use')
|
||||
# group.add_argument('--attention-dropout', type=float, default=0.1,
|
||||
# help='dropout probability for attention weights')
|
||||
# group.add_argument('--num-attention-heads', type=int, default=16,
|
||||
# help='num of transformer attention heads')
|
||||
# group.add_argument('--hidden-size', type=int, default=1024,
|
||||
# help='tansformer hidden size')
|
||||
# group.add_argument('--intermediate-size', type=int, default=None,
|
||||
# help='transformer embedding dimension for FFN'
|
||||
# 'set to 4*`--hidden-size` if it is None')
|
||||
# group.add_argument('--num-layers', type=int, default=24,
|
||||
# help='num decoder layers')
|
||||
# group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
|
||||
# help='layer norm epsilon')
|
||||
# group.add_argument('--hidden-dropout', type=float, default=0.1,
|
||||
# help='dropout probability for hidden state transformer')
|
||||
# group.add_argument('--max-position-embeddings', type=int, default=512,
|
||||
# help='maximum number of position embeddings to use')
|
||||
# group.add_argument('--vocab-size', type=int, default=30522,
|
||||
# help='vocab size to use for non-character-level '
|
||||
# 'tokenization. This value will only be used when '
|
||||
# 'creating a tokenizer')
|
||||
# group.add_argument('--deep-init', action='store_true',
|
||||
# help='initialize bert model similar to gpt2 model.'
|
||||
# 'scales initialization of projection layers by a '
|
||||
# 'factor of 1/sqrt(2N). Necessary to train bert '
|
||||
# 'models larger than BERT-Large.')
|
||||
# group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
|
||||
# help='Pad the vocab size to be divisible by this value.'
|
||||
# 'This is added for computational efficieny reasons.')
|
||||
# group.add_argument('--cpu-optimizer', action='store_true',
|
||||
# help='Run optimizer on CPU')
|
||||
# group.add_argument('--cpu_torch_adam', action='store_true',
|
||||
# help='Use Torch Adam as optimizer on CPU.')
|
||||
|
||||
return parser
|
||||
|
||||
@@ -71,28 +71,28 @@ def add_model_config_args(parser):
|
||||
def add_fp16_config_args(parser):
|
||||
"""Mixed precision arguments."""
|
||||
|
||||
group = parser.add_argument_group('fp16', 'fp16 configurations')
|
||||
# group = parser.add_argument_group('fp16', 'fp16 configurations')
|
||||
|
||||
group.add_argument('--fp16', action='store_true',
|
||||
help='Run model in fp16 mode')
|
||||
group.add_argument('--fp32-embedding', action='store_true',
|
||||
help='embedding in fp32')
|
||||
group.add_argument('--fp32-layernorm', action='store_true',
|
||||
help='layer norm in fp32')
|
||||
group.add_argument('--fp32-tokentypes', action='store_true',
|
||||
help='embedding token types in fp32')
|
||||
group.add_argument('--fp32-allreduce', action='store_true',
|
||||
help='all-reduce in fp32')
|
||||
group.add_argument('--hysteresis', type=int, default=2,
|
||||
help='hysteresis for dynamic loss scaling')
|
||||
group.add_argument('--loss-scale', type=float, default=None,
|
||||
help='Static loss scaling, positive power of 2 '
|
||||
'values can improve fp16 convergence. If None, dynamic'
|
||||
'loss scaling is used.')
|
||||
group.add_argument('--loss-scale-window', type=float, default=1000,
|
||||
help='Window over which to raise/lower dynamic scale')
|
||||
group.add_argument('--min-scale', type=float, default=1,
|
||||
help='Minimum loss scale for dynamic loss scale')
|
||||
# group.add_argument('--fp16', action='store_true',
|
||||
# help='Run model in fp16 mode')
|
||||
# group.add_argument('--fp32-embedding', action='store_true',
|
||||
# help='embedding in fp32')
|
||||
# group.add_argument('--fp32-layernorm', action='store_true',
|
||||
# help='layer norm in fp32')
|
||||
# group.add_argument('--fp32-tokentypes', action='store_true',
|
||||
# help='embedding token types in fp32')
|
||||
# group.add_argument('--fp32-allreduce', action='store_true',
|
||||
# help='all-reduce in fp32')
|
||||
# group.add_argument('--hysteresis', type=int, default=2,
|
||||
# help='hysteresis for dynamic loss scaling')
|
||||
# group.add_argument('--loss-scale', type=float, default=None,
|
||||
# help='Static loss scaling, positive power of 2 '
|
||||
# 'values can improve fp16 convergence. If None, dynamic'
|
||||
# 'loss scaling is used.')
|
||||
# group.add_argument('--loss-scale-window', type=float, default=1000,
|
||||
# help='Window over which to raise/lower dynamic scale')
|
||||
# group.add_argument('--min-scale', type=float, default=1,
|
||||
# help='Minimum loss scale for dynamic loss scale')
|
||||
|
||||
return parser
|
||||
|
||||
@@ -100,91 +100,91 @@ def add_fp16_config_args(parser):
|
||||
def add_training_args(parser):
|
||||
"""Training arguments."""
|
||||
|
||||
group = parser.add_argument_group('train', 'training configurations')
|
||||
# group = parser.add_argument_group('train', 'training configurations')
|
||||
|
||||
group.add_argument('--batch-size', type=int, default=4,
|
||||
help='Data Loader batch size')
|
||||
group.add_argument('--weight-decay', type=float, default=0.01,
|
||||
help='weight decay coefficient for L2 regularization')
|
||||
group.add_argument('--checkpoint-activations', action='store_true',
|
||||
help='checkpoint activation to allow for training '
|
||||
'with larger models and sequences')
|
||||
group.add_argument('--checkpoint-num-layers', type=int, default=1,
|
||||
help='chunk size (number of layers) for checkpointing')
|
||||
group.add_argument('--deepspeed-activation-checkpointing', action='store_true',
|
||||
help='uses activation checkpointing from deepspeed')
|
||||
group.add_argument('--clip-grad', type=float, default=1.0,
|
||||
help='gradient clipping')
|
||||
group.add_argument('--train-iters', type=int, default=1000000,
|
||||
help='total number of iterations to train over all training runs')
|
||||
group.add_argument('--log-interval', type=int, default=100,
|
||||
help='report interval')
|
||||
group.add_argument('--exit-interval', type=int, default=None,
|
||||
help='Exit the program after this many new iterations.')
|
||||
|
||||
group.add_argument('--seed', type=int, default=1234,
|
||||
help='random seed')
|
||||
# Batch prodecuer arguments
|
||||
group.add_argument('--reset-position-ids', action='store_true',
|
||||
help='Reset posistion ids after end-of-document token.')
|
||||
group.add_argument('--reset-attention-mask', action='store_true',
|
||||
help='Reset self attention maske after '
|
||||
'end-of-document token.')
|
||||
|
||||
# Learning rate.
|
||||
group.add_argument('--lr-decay-iters', type=int, default=None,
|
||||
help='number of iterations to decay LR over,'
|
||||
' If None defaults to `--train-iters`*`--epochs`')
|
||||
group.add_argument('--lr-decay-style', type=str, default='linear',
|
||||
choices=['constant', 'linear', 'cosine', 'exponential'],
|
||||
help='learning rate decay function')
|
||||
group.add_argument('--lr', type=float, default=1.0e-4,
|
||||
help='initial learning rate')
|
||||
group.add_argument('--warmup', type=float, default=0.01,
|
||||
help='percentage of data to warmup on (.01 = 1% of all '
|
||||
'training iters). Default 0.01')
|
||||
group.add_argument('--batch-warmup', type=float, default=0.01,
|
||||
help='percentage of data to warmup on (.01 = 1% of all '
|
||||
'training iters). Default 0.01')
|
||||
group.add_argument('--length-warmup', type=float, default=0.01,
|
||||
help='percentage of data to warmup on (.01 = 1% of all '
|
||||
'training iters). Default 0.01')
|
||||
# model checkpointing
|
||||
group.add_argument('--save', type=str, default=None,
|
||||
help='Output directory to save checkpoints to.')
|
||||
group.add_argument('--save-interval', type=int, default=2000,
|
||||
help='number of iterations between saves')
|
||||
group.add_argument('--no-save-optim', action='store_true',
|
||||
help='Do not save current optimizer.')
|
||||
group.add_argument('--no-save-rng', action='store_true',
|
||||
help='Do not save current rng state.')
|
||||
group.add_argument('--load', type=str, default=None,
|
||||
help='Path to a directory containing a model checkpoint.')
|
||||
group.add_argument('--load-iteration', type=str, default=0,
|
||||
help='Load iteration of a model checkpoint.')
|
||||
group.add_argument('--pre-load', action='store_true',
|
||||
help='Use pre-load instead of deepspeed load.')
|
||||
group.add_argument('--no-load-optim', action='store_true',
|
||||
help='Do not load optimizer when loading checkpoint.')
|
||||
group.add_argument('--no-load-rng', action='store_true',
|
||||
help='Do not load rng state when loading checkpoint.')
|
||||
group.add_argument('--no-load-lr', action='store_true',
|
||||
help='Do not load lr schedule when loading checkpoint.')
|
||||
group.add_argument('--finetune', action='store_true',
|
||||
help='Load model for finetuning. Do not load optimizer '
|
||||
'or rng state from checkpoint and set iteration to 0. '
|
||||
'Assumed when loading a release checkpoint.')
|
||||
group.add_argument('--resume-dataloader', action='store_true',
|
||||
help='Resume the dataloader when resuming training. '
|
||||
'Does not apply to tfrecords dataloader, try resuming'
|
||||
'with a different seed in this case.')
|
||||
# distributed training args
|
||||
group.add_argument('--distributed-backend', default='nccl',
|
||||
help='which backend to use for distributed '
|
||||
'training. One of [gloo, nccl]')
|
||||
|
||||
group.add_argument('--local_rank', type=int, default=None,
|
||||
help='local rank passed from distributed launcher')
|
||||
# group.add_argument('--batch-size', type=int, default=4,
|
||||
# help='Data Loader batch size')
|
||||
# group.add_argument('--weight-decay', type=float, default=0.01,
|
||||
# help='weight decay coefficient for L2 regularization')
|
||||
# group.add_argument('--checkpoint-activations', action='store_true',
|
||||
# help='checkpoint activation to allow for training '
|
||||
# 'with larger models and sequences')
|
||||
# group.add_argument('--checkpoint-num-layers', type=int, default=1,
|
||||
# help='chunk size (number of layers) for checkpointing')
|
||||
# group.add_argument('--deepspeed-activation-checkpointing', action='store_true',
|
||||
# help='uses activation checkpointing from deepspeed')
|
||||
# group.add_argument('--clip-grad', type=float, default=1.0,
|
||||
# help='gradient clipping')
|
||||
# group.add_argument('--train-iters', type=int, default=1000000,
|
||||
# help='total number of iterations to train over all training runs')
|
||||
# group.add_argument('--log-interval', type=int, default=100,
|
||||
# help='report interval')
|
||||
# group.add_argument('--exit-interval', type=int, default=None,
|
||||
# help='Exit the program after this many new iterations.')
|
||||
#
|
||||
# group.add_argument('--seed', type=int, default=1234,
|
||||
# help='random seed')
|
||||
# # Batch prodecuer arguments
|
||||
# group.add_argument('--reset-position-ids', action='store_true',
|
||||
# help='Reset posistion ids after end-of-document token.')
|
||||
# group.add_argument('--reset-attention-mask', action='store_true',
|
||||
# help='Reset self attention maske after '
|
||||
# 'end-of-document token.')
|
||||
#
|
||||
# # Learning rate.
|
||||
# group.add_argument('--lr-decay-iters', type=int, default=None,
|
||||
# help='number of iterations to decay LR over,'
|
||||
# ' If None defaults to `--train-iters`*`--epochs`')
|
||||
# group.add_argument('--lr-decay-style', type=str, default='linear',
|
||||
# choices=['constant', 'linear', 'cosine', 'exponential'],
|
||||
# help='learning rate decay function')
|
||||
# group.add_argument('--lr', type=float, default=1.0e-4,
|
||||
# help='initial learning rate')
|
||||
# group.add_argument('--warmup', type=float, default=0.01,
|
||||
# help='percentage of data to warmup on (.01 = 1% of all '
|
||||
# 'training iters). Default 0.01')
|
||||
# group.add_argument('--batch-warmup', type=float, default=0.01,
|
||||
# help='percentage of data to warmup on (.01 = 1% of all '
|
||||
# 'training iters). Default 0.01')
|
||||
# group.add_argument('--length-warmup', type=float, default=0.01,
|
||||
# help='percentage of data to warmup on (.01 = 1% of all '
|
||||
# 'training iters). Default 0.01')
|
||||
# # model checkpointing
|
||||
# group.add_argument('--save', type=str, default=None,
|
||||
# help='Output directory to save checkpoints to.')
|
||||
# group.add_argument('--save-interval', type=int, default=2000,
|
||||
# help='number of iterations between saves')
|
||||
# group.add_argument('--no-save-optim', action='store_true',
|
||||
# help='Do not save current optimizer.')
|
||||
# group.add_argument('--no-save-rng', action='store_true',
|
||||
# help='Do not save current rng state.')
|
||||
# group.add_argument('--load', type=str, default=None,
|
||||
# help='Path to a directory containing a model checkpoint.')
|
||||
# group.add_argument('--load-iteration', type=str, default=0,
|
||||
# help='Load iteration of a model checkpoint.')
|
||||
# group.add_argument('--pre-load', action='store_true',
|
||||
# help='Use pre-load instead of deepspeed load.')
|
||||
# group.add_argument('--no-load-optim', action='store_true',
|
||||
# help='Do not load optimizer when loading checkpoint.')
|
||||
# group.add_argument('--no-load-rng', action='store_true',
|
||||
# help='Do not load rng state when loading checkpoint.')
|
||||
# group.add_argument('--no-load-lr', action='store_true',
|
||||
# help='Do not load lr schedule when loading checkpoint.')
|
||||
# group.add_argument('--finetune', action='store_true',
|
||||
# help='Load model for finetuning. Do not load optimizer '
|
||||
# 'or rng state from checkpoint and set iteration to 0. '
|
||||
# 'Assumed when loading a release checkpoint.')
|
||||
# group.add_argument('--resume-dataloader', action='store_true',
|
||||
# help='Resume the dataloader when resuming training. '
|
||||
# 'Does not apply to tfrecords dataloader, try resuming'
|
||||
# 'with a different seed in this case.')
|
||||
# # distributed training args
|
||||
# group.add_argument('--distributed-backend', default='nccl',
|
||||
# help='which backend to use for distributed '
|
||||
# 'training. One of [gloo, nccl]')
|
||||
#
|
||||
# group.add_argument('--local_rank', type=int, default=None,
|
||||
# help='local rank passed from distributed launcher')
|
||||
|
||||
return parser
|
||||
|
||||
@@ -192,164 +192,164 @@ def add_training_args(parser):
|
||||
def add_evaluation_args(parser):
|
||||
"""Evaluation arguments."""
|
||||
|
||||
group = parser.add_argument_group('validation', 'validation configurations')
|
||||
# group = parser.add_argument_group('validation', 'validation configurations')
|
||||
|
||||
group.add_argument('--eval-batch-size', type=int, default=None,
|
||||
help='Data Loader batch size for evaluation datasets.'
|
||||
'Defaults to `--batch-size`')
|
||||
group.add_argument('--eval-iters', type=int, default=100,
|
||||
help='number of iterations to run for evaluation'
|
||||
'validation/test for')
|
||||
group.add_argument('--eval-interval', type=int, default=1000,
|
||||
help='interval between running evaluation on validation set')
|
||||
group.add_argument('--eval-seq-length', type=int, default=None,
|
||||
help='Maximum sequence length to process for '
|
||||
'evaluation. Defaults to `--seq-length`')
|
||||
group.add_argument('--eval-max-preds-per-seq', type=int, default=None,
|
||||
help='Maximum number of predictions to use for '
|
||||
'evaluation. Defaults to '
|
||||
'math.ceil(`--eval-seq-length`*.15/10)*10')
|
||||
group.add_argument('--overlapping-eval', type=int, default=32,
|
||||
help='sliding window for overlapping eval ')
|
||||
group.add_argument('--cloze-eval', action='store_true',
|
||||
help='Evaluation dataset from `--valid-data` is a cloze task')
|
||||
group.add_argument('--eval-hf', action='store_true',
|
||||
help='perform evaluation with huggingface openai model.'
|
||||
'use `--load` to specify weights path to be loaded')
|
||||
group.add_argument('--load-openai', action='store_true',
|
||||
help='load openai weights into our model. Use `--load` '
|
||||
'to specify weights path to be loaded')
|
||||
# group.add_argument('--eval-batch-size', type=int, default=None,
|
||||
# help='Data Loader batch size for evaluation datasets.'
|
||||
# 'Defaults to `--batch-size`')
|
||||
# group.add_argument('--eval-iters', type=int, default=100,
|
||||
# help='number of iterations to run for evaluation'
|
||||
# 'validation/test for')
|
||||
# group.add_argument('--eval-interval', type=int, default=1000,
|
||||
# help='interval between running evaluation on validation set')
|
||||
# group.add_argument('--eval-seq-length', type=int, default=None,
|
||||
# help='Maximum sequence length to process for '
|
||||
# 'evaluation. Defaults to `--seq-length`')
|
||||
# group.add_argument('--eval-max-preds-per-seq', type=int, default=None,
|
||||
# help='Maximum number of predictions to use for '
|
||||
# 'evaluation. Defaults to '
|
||||
# 'math.ceil(`--eval-seq-length`*.15/10)*10')
|
||||
# group.add_argument('--overlapping-eval', type=int, default=32,
|
||||
# help='sliding window for overlapping eval ')
|
||||
# group.add_argument('--cloze-eval', action='store_true',
|
||||
# help='Evaluation dataset from `--valid-data` is a cloze task')
|
||||
# group.add_argument('--eval-hf', action='store_true',
|
||||
# help='perform evaluation with huggingface openai model.'
|
||||
# 'use `--load` to specify weights path to be loaded')
|
||||
# group.add_argument('--load-openai', action='store_true',
|
||||
# help='load openai weights into our model. Use `--load` '
|
||||
# 'to specify weights path to be loaded')
|
||||
|
||||
return parser
|
||||
|
||||
def add_text_generate_args(parser):
|
||||
"""Text generate arguments."""
|
||||
|
||||
group = parser.add_argument_group('Text generation', 'configurations')
|
||||
group.add_argument("--temperature", type=float, default=1.0)
|
||||
group.add_argument("--top_p", type=float, default=0.0)
|
||||
group.add_argument("--top_k", type=int, default=0)
|
||||
group.add_argument("--out-seq-length", type=int, default=256)
|
||||
# group = parser.add_argument_group('Text generation', 'configurations')
|
||||
# group.add_argument("--temperature", type=float, default=1.0)
|
||||
# group.add_argument("--top_p", type=float, default=0.0)
|
||||
# group.add_argument("--top_k", type=int, default=0)
|
||||
# group.add_argument("--out-seq-length", type=int, default=256)
|
||||
return parser
|
||||
|
||||
def add_struct_args(parser):
|
||||
group = parser.add_argument_group('struct', 'struct configurations')
|
||||
group.add_argument("--gradient-accumulation-steps", type=int, default=1,
|
||||
help='Not Imp yet.')
|
||||
group.add_argument("--num-epochs", type=int, default=1,
|
||||
help='Not Imp yet.')
|
||||
group.add_argument("--struct-bert-dataset", action='store_true', default=False,
|
||||
help='Use struct bert dataset or not.')
|
||||
# group = parser.add_argument_group('struct', 'struct configurations')
|
||||
# group.add_argument("--gradient-accumulation-steps", type=int, default=1,
|
||||
# help='Not Imp yet.')
|
||||
# group.add_argument("--num-epochs", type=int, default=1,
|
||||
# help='Not Imp yet.')
|
||||
# group.add_argument("--struct-bert-dataset", action='store_true', default=False,
|
||||
# help='Use struct bert dataset or not.')
|
||||
return parser
|
||||
|
||||
def add_palm_args(parser):
|
||||
group = parser.add_argument_group('palm', 'struct configurations')
|
||||
group.add_argument('--dec-layers', type=int, default=6,
|
||||
help='num decoder layers')
|
||||
group.add_argument('--tgt-length', type=int, default=100,
|
||||
help='num decoder layers')
|
||||
group.add_argument('--vae-size', type=int, default=8192,
|
||||
help='vae code vocab size')
|
||||
group.add_argument('--max-image-position', type=int, default=1025,
|
||||
help='max image decode position')
|
||||
group.add_argument("--palm-dataset", action='store_true', default=False,
|
||||
help='Use struct bert dataset or not.')
|
||||
group.add_argument("--image-dataset", action='store_true', default=False,
|
||||
help='Use struct bert dataset or not.')
|
||||
group.add_argument("--do-mask-lm", action='store_true', default=False,
|
||||
help='Do mask lm task or not.')
|
||||
group.add_argument('--vae-enc-model', type=str, default=None,
|
||||
help='Path to a directory containing a model checkpoint.')
|
||||
# group = parser.add_argument_group('palm', 'struct configurations')
|
||||
# group.add_argument('--dec-layers', type=int, default=6,
|
||||
# help='num decoder layers')
|
||||
# group.add_argument('--tgt-length', type=int, default=100,
|
||||
# help='num decoder layers')
|
||||
# group.add_argument('--vae-size', type=int, default=8192,
|
||||
# help='vae code vocab size')
|
||||
# group.add_argument('--max-image-position', type=int, default=1025,
|
||||
# help='max image decode position')
|
||||
# group.add_argument("--palm-dataset", action='store_true', default=False,
|
||||
# help='Use struct bert dataset or not.')
|
||||
# group.add_argument("--image-dataset", action='store_true', default=False,
|
||||
# help='Use struct bert dataset or not.')
|
||||
# group.add_argument("--do-mask-lm", action='store_true', default=False,
|
||||
# help='Do mask lm task or not.')
|
||||
# group.add_argument('--vae-enc-model', type=str, default=None,
|
||||
# help='Path to a directory containing a model checkpoint.')
|
||||
return parser
|
||||
|
||||
def add_downstream_args(parser):
|
||||
group = parser.add_argument_group('downstream', 'struct configurations')
|
||||
group.add_argument("--downstream-dataset", action='store_true', default=False,
|
||||
help='Use struct bert dataset or not.')
|
||||
group.add_argument("--task-name", default='ocnli', type=str)
|
||||
# group = parser.add_argument_group('downstream', 'struct configurations')
|
||||
# group.add_argument("--downstream-dataset", action='store_true', default=False,
|
||||
# help='Use struct bert dataset or not.')
|
||||
# group.add_argument("--task-name", default='ocnli', type=str)
|
||||
return parser
|
||||
|
||||
def add_data_args(parser):
|
||||
"""Train/valid/test data arguments."""
|
||||
|
||||
group = parser.add_argument_group('data', 'data configurations')
|
||||
|
||||
group.add_argument('--model-parallel-size', type=int, default=1,
|
||||
help='size of the model parallel.')
|
||||
group.add_argument('--shuffle', action='store_true',
|
||||
help='Shuffle data. Shuffling is deterministic '
|
||||
'based on seed and current epoch.')
|
||||
group.add_argument('--train-data', nargs='+', default=None,
|
||||
help='Whitespace separated filenames or corpora names '
|
||||
'for training.')
|
||||
|
||||
group.add_argument('--use-npy-data-loader', action='store_true',
|
||||
help='Use the numpy data loader. If set, then'
|
||||
'train-data-path, val-data-path, and test-data-path'
|
||||
'should also be provided.')
|
||||
group.add_argument('--train-data-path', type=str, default='',
|
||||
help='path to the training data')
|
||||
group.add_argument('--val-data-path', type=str, default='',
|
||||
help='path to the validation data')
|
||||
group.add_argument('--test-data-path', type=str, default='',
|
||||
help='path to the test data')
|
||||
group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
|
||||
help='the filename containing all the shards sizes')
|
||||
|
||||
group.add_argument('--delim', default=',',
|
||||
help='delimiter used to parse csv data files')
|
||||
group.add_argument('--text-key', default='sentence',
|
||||
help='key to use to extract text from json/csv')
|
||||
group.add_argument('--eval-text-key', default=None,
|
||||
help='key to use to extract text from '
|
||||
'json/csv evaluation datasets')
|
||||
group.add_argument('--valid-data', nargs='*', default=None,
|
||||
help="""Filename for validation data.""")
|
||||
group.add_argument('--split', default='1000,1,1',
|
||||
help='comma-separated list of proportions for training,'
|
||||
' validation, and test split')
|
||||
group.add_argument('--test-data', nargs='*', default=None,
|
||||
help="""Filename for testing""")
|
||||
|
||||
group.add_argument('--lazy-loader', action='store_true',
|
||||
help='whether to lazy read the data set')
|
||||
group.add_argument('--loose-json', action='store_true',
|
||||
help='Use loose json (one json-formatted string per '
|
||||
'newline), instead of tight json (data file is one '
|
||||
'json string)')
|
||||
group.add_argument('--presplit-sentences', action='store_true',
|
||||
help='Dataset content consists of documents where '
|
||||
'each document consists of newline separated sentences')
|
||||
group.add_argument('--num-workers', type=int, default=2,
|
||||
help="""Number of workers to use for dataloading""")
|
||||
group.add_argument('--tokenizer-model-type', type=str,
|
||||
default='bert-large-uncased',
|
||||
help="Model type to use for sentencepiece tokenization \
|
||||
(one of ['bpe', 'char', 'unigram', 'word']) or \
|
||||
bert vocab to use for BertWordPieceTokenizer (one of \
|
||||
['bert-large-uncased', 'bert-large-cased', etc.])")
|
||||
group.add_argument('--tokenizer-path', type=str, default='tokenizer.model',
|
||||
help='path used to save/load sentencepiece tokenization '
|
||||
'models')
|
||||
group.add_argument('--tokenizer-type', type=str,
|
||||
default='BertWordPieceTokenizer',
|
||||
choices=['CharacterLevelTokenizer',
|
||||
'SentencePieceTokenizer',
|
||||
'BertWordPieceTokenizer',
|
||||
'GPT2BPETokenizer'],
|
||||
help='what type of tokenizer to use')
|
||||
group.add_argument("--cache-dir", default=None, type=str,
|
||||
help="Where to store pre-trained BERT downloads")
|
||||
group.add_argument('--use-tfrecords', action='store_true',
|
||||
help='load `--train-data`, `--valid-data`, '
|
||||
'`--test-data` from BERT tf records instead of '
|
||||
'normal data pipeline')
|
||||
group.add_argument('--seq-length', type=int, default=512,
|
||||
help="Maximum sequence length to process")
|
||||
group.add_argument('--max-preds-per-seq', type=int, default=None,
|
||||
help='Maximum number of predictions to use per sequence.'
|
||||
'Defaults to math.ceil(`--seq-length`*.15/10)*10.'
|
||||
'MUST BE SPECIFIED IF `--use-tfrecords` is True.')
|
||||
# group = parser.add_argument_group('data', 'data configurations')
|
||||
#
|
||||
# group.add_argument('--model-parallel-size', type=int, default=1,
|
||||
# help='size of the model parallel.')
|
||||
# group.add_argument('--shuffle', action='store_true',
|
||||
# help='Shuffle data. Shuffling is deterministic '
|
||||
# 'based on seed and current epoch.')
|
||||
# group.add_argument('--train-data', nargs='+', default=None,
|
||||
# help='Whitespace separated filenames or corpora names '
|
||||
# 'for training.')
|
||||
#
|
||||
# group.add_argument('--use-npy-data-loader', action='store_true',
|
||||
# help='Use the numpy data loader. If set, then'
|
||||
# 'train-data-path, val-data-path, and test-data-path'
|
||||
# 'should also be provided.')
|
||||
# group.add_argument('--train-data-path', type=str, default='',
|
||||
# help='path to the training data')
|
||||
# group.add_argument('--val-data-path', type=str, default='',
|
||||
# help='path to the validation data')
|
||||
# group.add_argument('--test-data-path', type=str, default='',
|
||||
# help='path to the test data')
|
||||
# group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
|
||||
# help='the filename containing all the shards sizes')
|
||||
#
|
||||
# group.add_argument('--delim', default=',',
|
||||
# help='delimiter used to parse csv data files')
|
||||
# group.add_argument('--text-key', default='sentence',
|
||||
# help='key to use to extract text from json/csv')
|
||||
# group.add_argument('--eval-text-key', default=None,
|
||||
# help='key to use to extract text from '
|
||||
# 'json/csv evaluation datasets')
|
||||
# group.add_argument('--valid-data', nargs='*', default=None,
|
||||
# help="""Filename for validation data.""")
|
||||
# group.add_argument('--split', default='1000,1,1',
|
||||
# help='comma-separated list of proportions for training,'
|
||||
# ' validation, and test split')
|
||||
# group.add_argument('--test-data', nargs='*', default=None,
|
||||
# help="""Filename for testing""")
|
||||
#
|
||||
# group.add_argument('--lazy-loader', action='store_true',
|
||||
# help='whether to lazy read the data set')
|
||||
# group.add_argument('--loose-json', action='store_true',
|
||||
# help='Use loose json (one json-formatted string per '
|
||||
# 'newline), instead of tight json (data file is one '
|
||||
# 'json string)')
|
||||
# group.add_argument('--presplit-sentences', action='store_true',
|
||||
# help='Dataset content consists of documents where '
|
||||
# 'each document consists of newline separated sentences')
|
||||
# group.add_argument('--num-workers', type=int, default=2,
|
||||
# help="""Number of workers to use for dataloading""")
|
||||
# group.add_argument('--tokenizer-model-type', type=str,
|
||||
# default='bert-large-uncased',
|
||||
# help="Model type to use for sentencepiece tokenization \
|
||||
# (one of ['bpe', 'char', 'unigram', 'word']) or \
|
||||
# bert vocab to use for BertWordPieceTokenizer (one of \
|
||||
# ['bert-large-uncased', 'bert-large-cased', etc.])")
|
||||
# group.add_argument('--tokenizer-path', type=str, default='tokenizer.model',
|
||||
# help='path used to save/load sentencepiece tokenization '
|
||||
# 'models')
|
||||
# group.add_argument('--tokenizer-type', type=str,
|
||||
# default='BertWordPieceTokenizer',
|
||||
# choices=['CharacterLevelTokenizer',
|
||||
# 'SentencePieceTokenizer',
|
||||
# 'BertWordPieceTokenizer',
|
||||
# 'GPT2BPETokenizer'],
|
||||
# help='what type of tokenizer to use')
|
||||
# group.add_argument("--cache-dir", default=None, type=str,
|
||||
# help="Where to store pre-trained BERT downloads")
|
||||
# group.add_argument('--use-tfrecords', action='store_true',
|
||||
# help='load `--train-data`, `--valid-data`, '
|
||||
# '`--test-data` from BERT tf records instead of '
|
||||
# 'normal data pipeline')
|
||||
# group.add_argument('--seq-length', type=int, default=512,
|
||||
# help="Maximum sequence length to process")
|
||||
# group.add_argument('--max-preds-per-seq', type=int, default=None,
|
||||
# help='Maximum number of predictions to use per sequence.'
|
||||
# 'Defaults to math.ceil(`--seq-length`*.15/10)*10.'
|
||||
# 'MUST BE SPECIFIED IF `--use-tfrecords` is True.')
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
53
modelscope/models/nlp/plug/config.json
Normal file
53
modelscope/models/nlp/plug/config.json
Normal file
@@ -0,0 +1,53 @@
|
||||
{
|
||||
"hidden_size": 8192,
|
||||
"intermediate_size": 32768,
|
||||
"num_hidden_layers": 1,
|
||||
"dec_hidden_layers": 1,
|
||||
"num_attention_heads": 128,
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"hidden_act": "gelu",
|
||||
"type_vocab_size": 3,
|
||||
"vocab_size": 21504,
|
||||
"original_vocab_size": 21128,
|
||||
"max_position_embeddings": 2048,
|
||||
"lr_decay_style": "linear",
|
||||
"lr": 3e-5,
|
||||
"weight_decay": 1e-2,
|
||||
"clip_grad": 1.0,
|
||||
"warmup": 0.0333,
|
||||
"layernorm_epsilon": 1e-5,
|
||||
"layer_norm_eps": 1e-5,
|
||||
"fp32_embedding": false,
|
||||
"fp32_tokentypes": false,
|
||||
"fp32_layernorm": true,
|
||||
"fp16": true,
|
||||
"pruning_method": "pest_block",
|
||||
"pruning_initial_threshold": 0.5,
|
||||
"pruning_final_threshold": 0.01,
|
||||
"pruning_initial_warmup": 1,
|
||||
"pruning_final_warmup": 9,
|
||||
"pruning_module": "decoder",
|
||||
"pruning_decay_step": 76,
|
||||
"pruning_decay_type": "exp",
|
||||
"pruning_mask_init": "constant",
|
||||
"pruning_mask_scale": 0.0,
|
||||
"LR_weight_rank": 16,
|
||||
"LR_mask_rank": 16,
|
||||
"LR_weight_alpha": 32,
|
||||
"LR_mask_alpha": 16,
|
||||
"attn_separate": false,
|
||||
"ft_module": "0",
|
||||
"model_type": "plug_nlg_original-8192",
|
||||
"load_iteration": 28000,
|
||||
"distributed_backend": "nccl",
|
||||
"deep_init": true,
|
||||
"deepspeed": true,
|
||||
"deepspeed_config": "ds_zero-offload_10B_config.json",
|
||||
"nni": true,
|
||||
"cpu_optimizer": true,
|
||||
"no_load_optim": true,
|
||||
"pre_load": true,
|
||||
"pre_ln": true,
|
||||
"attention-dropout": 0.1
|
||||
}
|
||||
69
modelscope/models/nlp/plug/configuration.json
Normal file
69
modelscope/models/nlp/plug/configuration.json
Normal file
@@ -0,0 +1,69 @@
|
||||
{
|
||||
"framework": "pytorch",
|
||||
"task": "text-generation",
|
||||
"preprocessor": {
|
||||
"type": "text-gen-tokenizer"
|
||||
},
|
||||
"model": {
|
||||
"type": "plug",
|
||||
"world_size": 8,
|
||||
"model_parallel_size": 8,
|
||||
"pre_load": true,
|
||||
"distributed_backend": "nccl",
|
||||
"checkpoint_activations": true,
|
||||
"top_k": 20,
|
||||
"top_p": 0.0,
|
||||
"temperature": 0.9,
|
||||
"seed": 42,
|
||||
"output_sequence_length": 128
|
||||
},
|
||||
"pipeline": {
|
||||
"type": "text-generation"
|
||||
},
|
||||
"train": {
|
||||
"work_dir": "/tmp",
|
||||
"max_epochs": 3,
|
||||
"dataloader": {
|
||||
"batch_size_per_gpu": 2,
|
||||
"workers_per_gpu": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "SGD",
|
||||
"lr": 0.01,
|
||||
"options": {
|
||||
"grad_clip": {
|
||||
"max_norm": 2.0
|
||||
}
|
||||
}
|
||||
},
|
||||
"lr_scheduler": {
|
||||
"type": "StepLR",
|
||||
"step_size": 2,
|
||||
"options": {
|
||||
"warmup": {
|
||||
"type": "LinearWarmup",
|
||||
"warmup_iters": 2
|
||||
}
|
||||
}
|
||||
},
|
||||
"hooks": [{
|
||||
"type": "CheckpointHook",
|
||||
"interval": 1
|
||||
}, {
|
||||
"type": "TextLoggerHook",
|
||||
"interval": 1
|
||||
}, {
|
||||
"type": "IterTimerHook"
|
||||
}, {
|
||||
"type": "EvaluationHook",
|
||||
"interval": 1
|
||||
}]
|
||||
},
|
||||
"evaluation": {
|
||||
"dataloader": {
|
||||
"batch_size_per_gpu": 2,
|
||||
"workers_per_gpu": 1,
|
||||
"shuffle": false
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,8 +18,6 @@ import json
|
||||
import copy
|
||||
|
||||
""" BERT model configuration """
|
||||
from collections import OrderedDict
|
||||
from typing import Mapping
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from modelscope.utils import logger as logging
|
||||
@@ -94,11 +92,12 @@ class PlugNLUConfig(PretrainedConfig):
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
"""
|
||||
model_type="plugNLU"
|
||||
model_type = "plugNLU"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=21504,
|
||||
original_vocab_size=21128,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
@@ -141,6 +140,7 @@ class PlugNLUConfig(PretrainedConfig):
|
||||
super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.original_vocab_size = original_vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
@@ -1,188 +1,136 @@
|
||||
import random
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from . import PlugModel
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models import TorchModel
|
||||
from modelscope.models.base import Tensor
|
||||
from modelscope.utils.nlp import mpu
|
||||
from modelscope.utils.nlp.utils import print_rank_0
|
||||
from modelscope.utils.nlp.fp16 import FP16_Module
|
||||
from modelscope.utils.nlp.distributed import DistributedDataParallel as DDP
|
||||
from modelscope.utils.nlp.load_checkpoint import pre_load
|
||||
from modelscope.utils.nlp.utils import print_rank_0
|
||||
from modelscope.utils.torch_utils import set_random_seed_mpu
|
||||
from modelscope.utils.logger import get_logger
|
||||
from . import PlugModel
|
||||
from .configuration_plug import PlugNLGConfig
|
||||
from modelscope.models.nlp.utils.distributed import initialize_distributed
|
||||
|
||||
import os
|
||||
from modelscope.utils.torch_utils import init_dist
|
||||
def initialize_distributed(rank):
|
||||
"""Initialize torch.distributed."""
|
||||
# Manually set the device ids.
|
||||
#torch.multiprocessing.set_start_method("spawn")
|
||||
device = rank % torch.cuda.device_count()
|
||||
torch.cuda.set_device(device)
|
||||
# Call the init process
|
||||
init_method = 'tcp://'
|
||||
master_ip = os.getenv('MASTER_ADDR', '127.0.0.1')
|
||||
master_port = os.getenv('MASTER_PORT', '12345')
|
||||
init_method += master_ip + ':' + master_port
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl",
|
||||
world_size=8, rank=rank,
|
||||
init_method=init_method)
|
||||
# Set the model-parallel communicators.
|
||||
mpu.initialize_model_parallel(8)
|
||||
|
||||
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
||||
# This function has been mostly taken from huggingface conversational ai code at
|
||||
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
|
||||
|
||||
if top_k > 0:
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
if top_p > 0.0:
|
||||
#convert to 1D
|
||||
logits=logits.view(logits.size()[1]).contiguous()
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above the threshold
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
# Shift the indices to the right to keep also the first token above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||||
logits[indices_to_remove] = filter_value
|
||||
#going back to 2D
|
||||
logits=logits.view(1, -1).contiguous()
|
||||
|
||||
return logits
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DistributedPlug(TorchModel):
|
||||
|
||||
class DistributedPlug:
|
||||
@classmethod
|
||||
def init(cls, rank, model_dir, model_config, args):
|
||||
#def init(cls, rank):
|
||||
#torch.backends.cudnn.enabled = False
|
||||
#
|
||||
cls.rank = rank
|
||||
cls.args = args
|
||||
cls.config = model_config
|
||||
cls.model_dir = model_dir
|
||||
initialize_distributed(rank)
|
||||
cls.set_random_seed(cls, args.seed)
|
||||
cls.setup_model(cls, path_load_tag='model')
|
||||
def __init__(self, model_dir, rank, **kwargs):
|
||||
super().__init__(model_dir, **kwargs)
|
||||
self.rank = rank
|
||||
self.model_cfg = kwargs
|
||||
self.config = PlugNLGConfig.from_pretrained(model_dir)
|
||||
initialize_distributed(rank, mpu, kwargs['world_size'], kwargs['model_parallel_size'])
|
||||
if 'seed' in kwargs:
|
||||
set_random_seed_mpu(kwargs['seed'])
|
||||
self.iteration = 0
|
||||
self.dist_model = self.initialize_model(path_load_tag='model')
|
||||
|
||||
def set_random_seed(cls, seed):
|
||||
if seed is not None and seed > 0:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
mpu.model_parallel_cuda_manual_seed(seed)
|
||||
|
||||
def get_model(cls):
|
||||
def initialize_model(self, path_load_tag='model'):
|
||||
"""Build the model."""
|
||||
|
||||
print_rank_0('Building Plug model. It will take a few minutes ...')
|
||||
model = PlugModel(cls.config)
|
||||
model = PlugModel(self.config)
|
||||
|
||||
if mpu.get_data_parallel_rank() == 0:
|
||||
print(' > number of parameters on model parallel rank {}: {}'.format(
|
||||
logger.info(' > number of parameters on model parallel rank {}: {}'.format(
|
||||
mpu.get_model_parallel_rank(),
|
||||
sum([p.nelement() for p in model.parameters()])), flush=True)
|
||||
sum([p.nelement() for p in model.parameters()])))
|
||||
|
||||
if cls.args.deepspeed and cls.args.fp16:
|
||||
model.half()
|
||||
if self.config.deepspeed and self.args.fp16:
|
||||
model.half()
|
||||
|
||||
# GPU allocation.
|
||||
model.cuda(torch.cuda.current_device())
|
||||
|
||||
# Fp16 conversion.
|
||||
if cls.args.fp16:
|
||||
if self.config.fp16:
|
||||
model = FP16_Module(model)
|
||||
if cls.args.fp32_embedding:
|
||||
if self.config.fp32_embedding:
|
||||
model.module.model.bert.embeddings.word_embeddings.float()
|
||||
model.module.model.bert.embeddings.position_embeddings.float()
|
||||
model.module.model.bert.embeddings.token_type_embeddings.float()
|
||||
if cls.args.fp32_tokentypes:
|
||||
if self.config.fp32_tokentypes:
|
||||
model.module.model.bert.embeddings.token_type_embeddings.float()
|
||||
if cls.args.fp32_layernorm:
|
||||
if self.config.fp32_layernorm:
|
||||
for name, _module in model.named_modules():
|
||||
if 'LayerNorm' in name:
|
||||
_module.float()
|
||||
|
||||
# model = DDP(model)
|
||||
|
||||
load_model = pre_load(mpu, self.model_dir, tag=path_load_tag)
|
||||
model_dict = model.module.model.state_dict()
|
||||
for key in load_model:
|
||||
if key not in model_dict.keys():
|
||||
print_rank_0('Skip key: ' + key)
|
||||
else:
|
||||
print_rank_0('Loading key: ' + key)
|
||||
model.module.model.load_state_dict(load_model, strict=False)
|
||||
return model
|
||||
|
||||
def setup_model(cls, path_load_tag='model'):
|
||||
dist_model = cls.get_model(cls)
|
||||
if cls.model_dir is not None:
|
||||
from modelscope.utils.nlp.load_checkpoint import pre_load
|
||||
load_model = pre_load(mpu, cls.model_dir, tag=path_load_tag)
|
||||
# model_dict = dist_model.module.module.model.state_dict()
|
||||
model_dict = dist_model.module.model.state_dict()
|
||||
for key in load_model:
|
||||
if key not in model_dict.keys():
|
||||
print_rank_0('Skip key: '+key)
|
||||
else:
|
||||
print_rank_0('Loading key: '+key)
|
||||
# dist_model.module.module.model.load_state_dict(load_model, strict=False)
|
||||
dist_model.module.model.load_state_dict(load_model, strict=False)
|
||||
cls.args.iteration = 0
|
||||
cls.dist_model = dist_model
|
||||
@staticmethod
|
||||
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
||||
# This function has been mostly taken from huggingface conversational ai code at
|
||||
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-
|
||||
# conversational-ai-with-transfer-learning-2d818ac26313
|
||||
|
||||
@classmethod
|
||||
def forward(cls, input:Dict[str, Tensor]):
|
||||
if top_k > 0:
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
if top_p > 0.0:
|
||||
# convert to 1D
|
||||
logits = logits.view(logits.size()[1]).contiguous()
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above the threshold
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
# Shift the indices to the right to keep also the first token above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||||
logits[indices_to_remove] = filter_value
|
||||
# going back to 2D
|
||||
logits = logits.view(1, -1).contiguous()
|
||||
return logits
|
||||
|
||||
def forward(self, input: Dict[str, Tensor]):
|
||||
device = torch.cuda.current_device()
|
||||
batch_size = input["input_ids"].shape[0]
|
||||
tokens = input["input_ids"].to(device)
|
||||
dec_input_ids = input["dec_input_ids"].to(device)
|
||||
attention_mask = input["attention_mask"].to(device)
|
||||
cls.dist_model.eval()
|
||||
seq_length = 128
|
||||
self.dist_model.eval()
|
||||
seq_length = self.model_cfg['output_sequence_length']
|
||||
with torch.no_grad():
|
||||
all_generate_tokens = []
|
||||
generate_tokens = []
|
||||
counter = 0
|
||||
sequence_output = None
|
||||
vocab_size = 21128
|
||||
#tokens, attention_mask, types, dec_input_ids = get_batch(context_tokens_tensor, device, args)
|
||||
vocab_size = self.config.original_vocab_size
|
||||
while counter < seq_length:
|
||||
# if counter % 128 == 0 and counter != 0:
|
||||
# generate_tokens.append(tokenizer.vocab[args.sep_token])
|
||||
# start = (context_tokens_tensor == 102).nonzero(as_tuple=True)[-1]
|
||||
# if start + len(generate_tokens) >= 512:
|
||||
# context_tokens_tensor = torch.cat([context_tokens_tensor[:start], torch.cuda.LongTensor(generate_tokens)], -1)[-512:]
|
||||
# else:
|
||||
# context_tokens_tensor[start:start+len(generate_tokens)] = torch.cuda.LongTensor(generate_tokens)
|
||||
# tokens, attention_mask, types, dec_input_ids = get_batch(context_tokens_tensor, device, args)
|
||||
# generate_tokens = []
|
||||
# sequence_output = None
|
||||
|
||||
position_ids = torch.full([cls.args.batch_size, 1], len(generate_tokens), dtype=torch.long, device=device)
|
||||
_, logits, sequence_output = cls.dist_model(tokens, None, attention_mask, dec_input_ids, attention_mask, position_ids, is_infer=True, sequence_output=sequence_output, parallel_output=False)
|
||||
|
||||
partition_vocab_size = logits.size()[-1]
|
||||
|
||||
position_ids = torch.full([batch_size, 1], len(generate_tokens),
|
||||
dtype=torch.long, device=device)
|
||||
_, logits, sequence_output = self.dist_model(tokens, None, attention_mask, dec_input_ids,
|
||||
attention_mask, position_ids, is_infer=True,
|
||||
sequence_output=sequence_output, parallel_output=False)
|
||||
logits = logits[:, -1, :]
|
||||
logits = logits / cls.args.temperature
|
||||
logits = top_k_logits(logits, top_k=cls.args.top_k, top_p=cls.args.top_p)
|
||||
logits = logits / self.model_cfg['temperature']
|
||||
logits = self.top_k_logits(logits, top_k=self.model_cfg['top_k'], top_p=self.model_cfg['top_p'])
|
||||
log_probs = F.softmax(logits, dim=-1)
|
||||
prev = torch.multinomial(log_probs, num_samples=1)
|
||||
prev_token = prev[0].item()
|
||||
if prev_token >= vocab_size: #or prev_token == 102:
|
||||
if prev_token >= vocab_size:
|
||||
prev_token = 100
|
||||
prev[0] = 100
|
||||
# if prev_token == 102 and len(all_generate_tokens) > int(max(1, length) * 0.8):
|
||||
if prev_token == 102:
|
||||
break
|
||||
#if prev_token == 102:
|
||||
# counter += 1
|
||||
# continue
|
||||
#if prev_token == 100:
|
||||
# counter += 1
|
||||
# continue
|
||||
dec_input_ids = torch.cat([dec_input_ids, prev], dim=1)
|
||||
generate_tokens.append(prev_token)
|
||||
all_generate_tokens.append(prev_token)
|
||||
|
||||
@@ -18,113 +18,21 @@
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import logging
|
||||
import tarfile
|
||||
import tempfile
|
||||
import shutil
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from deepspeed.utils.timer import SynchronizedWallClockTimer
|
||||
from modelscope.models.nlp.utils.distributed import normal_init_method, scaled_init_method
|
||||
from torch import nn
|
||||
|
||||
from .configuration_plug import PlugNLUConfig, PlugNLGConfig
|
||||
from ....utils.nlp import mpu#, cached_path
|
||||
import copy
|
||||
from deepspeed.utils.timer import SynchronizedWallClockTimer
|
||||
|
||||
def normal_init_method(mean, std):
|
||||
def init_(tensor):
|
||||
return torch.nn.init.normal_(tensor, mean=mean, std=std)
|
||||
return init_
|
||||
|
||||
def scaled_init_method(mean, std, num_layers):
|
||||
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
|
||||
std = std / math.sqrt(2.0 * num_layers)
|
||||
def init_(tensor):
|
||||
return torch.nn.init.normal_(tensor, mean=mean, std=std)
|
||||
|
||||
return init_
|
||||
from ....utils.nlp import mpu # , cached_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
|
||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
|
||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
|
||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
|
||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
|
||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
|
||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
|
||||
}
|
||||
CONFIG_NAME = 'bert_config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name.split('/')
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(n in ["adam_v", "adam_m"] for n in name):
|
||||
print("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
l = re.split(r'_(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'output_weights':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
else:
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
if m_name[-11:] == '_embeddings':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif m_name == 'kernel':
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
|
||||
def gelu(x):
|
||||
"""Implementation of the gelu activation function.
|
||||
@@ -140,6 +48,7 @@ def swish(x):
|
||||
|
||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
||||
|
||||
|
||||
class BertLayerNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-12):
|
||||
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
||||
@@ -155,9 +64,11 @@ class BertLayerNorm(nn.Module):
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
return self.weight * x + self.bias
|
||||
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertEmbeddings, self).__init__()
|
||||
self.word_embeddings = mpu.VocabParallelEmbedding(
|
||||
@@ -201,7 +112,7 @@ class BertEmbeddings(nn.Module):
|
||||
else:
|
||||
embeddings = embeddings.type(previous_type)
|
||||
else:
|
||||
embeddings = words_embeddings.float() + position_embeddings.float() + token_type_embeddings.float()
|
||||
embeddings = words_embeddings.float() + position_embeddings.float() + token_type_embeddings.float()
|
||||
if self.fp32_tokentypes and not self.fp32_layernorm:
|
||||
embeddings = embeddings.half()
|
||||
previous_type = embeddings.type()
|
||||
@@ -234,7 +145,9 @@ class BertSelfOutput(nn.Module):
|
||||
input_is_parallel=True,
|
||||
stride=1,
|
||||
init_method=init_method,
|
||||
pruning_method=config.pruning_method if config.pruning_module in ['all', 'encoder', 'encoder_self', 'encoder_selfvo', 'encoder_selfo'] else None,
|
||||
pruning_method=config.pruning_method if config.pruning_module in ['all', 'encoder', 'encoder_self',
|
||||
'encoder_selfvo',
|
||||
'encoder_selfo'] else None,
|
||||
pruning_mask_init=config.pruning_mask_init,
|
||||
pruning_mask_scale=config.pruning_mask_scale,
|
||||
LR_weight_rank=config.LR_weight_rank,
|
||||
@@ -246,8 +159,8 @@ class BertSelfOutput(nn.Module):
|
||||
self.LayerNorm = None
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor, pruning_threshold=None,):
|
||||
hidden_states = self.dense(hidden_states, pruning_threshold=pruning_threshold,)
|
||||
def forward(self, hidden_states, input_tensor, pruning_threshold=None, ):
|
||||
hidden_states = self.dense(hidden_states, pruning_threshold=pruning_threshold, )
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
ln_input = hidden_states + input_tensor
|
||||
if self.LayerNorm is not None:
|
||||
@@ -261,6 +174,7 @@ class BertSelfOutput(nn.Module):
|
||||
hidden_states = ln_input
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertAttention, self).__init__()
|
||||
@@ -285,7 +199,7 @@ class BertAttention(nn.Module):
|
||||
LR_mask_rank=config.LR_mask_rank)
|
||||
self.output = BertSelfOutput(config)
|
||||
|
||||
def forward(self, input_tensor, attention_mask, pruning_threshold=None,):
|
||||
def forward(self, input_tensor, attention_mask, pruning_threshold=None, ):
|
||||
if self.LayerNorm is not None:
|
||||
ln_input = input_tensor
|
||||
previous_type = input_tensor.type()
|
||||
@@ -294,15 +208,15 @@ class BertAttention(nn.Module):
|
||||
ln_output = self.LayerNorm(ln_input)
|
||||
if self.fp32_layernorm:
|
||||
ln_output = ln_output.type(previous_type)
|
||||
self_output = self.self(ln_output, attention_mask, pruning_threshold=pruning_threshold,)
|
||||
self_output = self.self(ln_output, attention_mask, pruning_threshold=pruning_threshold, )
|
||||
else:
|
||||
self_output = self.self(input_tensor, attention_mask, pruning_threshold=pruning_threshold,)
|
||||
# output_pruning_threshold = 1 - (1 - pruning_threshold)/0.99*0.95
|
||||
self_output = self.self(input_tensor, attention_mask, pruning_threshold=pruning_threshold, )
|
||||
output_pruning_threshold = pruning_threshold
|
||||
|
||||
attention_output = self.output(self_output, input_tensor, pruning_threshold=output_pruning_threshold,)
|
||||
attention_output = self.output(self_output, input_tensor, pruning_threshold=output_pruning_threshold, )
|
||||
return attention_output
|
||||
|
||||
|
||||
class BertIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertIntermediate, self).__init__()
|
||||
@@ -313,7 +227,8 @@ class BertIntermediate(nn.Module):
|
||||
gather_output=False,
|
||||
stride=1,
|
||||
init_method=normal_init_method(mean=0.0, std=config.initializer_range),
|
||||
pruning_method=config.pruning_method if config.pruning_module in ['all', 'encoder', 'encoder_ffn'] else None,
|
||||
pruning_method=config.pruning_method if config.pruning_module in ['all', 'encoder',
|
||||
'encoder_ffn'] else None,
|
||||
pruning_mask_init=config.pruning_mask_init,
|
||||
pruning_mask_scale=config.pruning_mask_scale,
|
||||
LR_weight_rank=config.LR_weight_rank,
|
||||
@@ -321,15 +236,16 @@ class BertIntermediate(nn.Module):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
|
||||
if isinstance(config.hidden_act, str) else config.hidden_act
|
||||
|
||||
def forward(self, hidden_states, pruning_threshold=None,):
|
||||
hidden_states = self.dense(hidden_states, pruning_threshold=pruning_threshold,)
|
||||
def forward(self, hidden_states, pruning_threshold=None, ):
|
||||
hidden_states = self.dense(hidden_states, pruning_threshold=pruning_threshold, )
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertOutput, self).__init__()
|
||||
if hasattr(config, 'deep_init') and config.deep_init:
|
||||
if hasattr(config, 'deep_init') and config.deep_init:
|
||||
init_method = scaled_init_method(mean=0.0,
|
||||
std=config.initializer_range,
|
||||
num_layers=config.num_hidden_layers)
|
||||
@@ -343,7 +259,8 @@ class BertOutput(nn.Module):
|
||||
input_is_parallel=True,
|
||||
stride=1,
|
||||
init_method=init_method,
|
||||
pruning_method=config.pruning_method if config.pruning_module in ['all', 'encoder', 'encoder_ffn'] else None,
|
||||
pruning_method=config.pruning_method if config.pruning_module in ['all', 'encoder',
|
||||
'encoder_ffn'] else None,
|
||||
pruning_mask_init=config.pruning_mask_init,
|
||||
pruning_mask_scale=config.pruning_mask_scale,
|
||||
LR_weight_rank=config.LR_weight_rank,
|
||||
@@ -355,11 +272,11 @@ class BertOutput(nn.Module):
|
||||
self.LayerNorm = None
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor, pruning_threshold=None,):
|
||||
hidden_states = self.dense(hidden_states, pruning_threshold=pruning_threshold,)
|
||||
def forward(self, hidden_states, input_tensor, pruning_threshold=None, ):
|
||||
hidden_states = self.dense(hidden_states, pruning_threshold=pruning_threshold, )
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
ln_input = hidden_states + input_tensor
|
||||
if self.LayerNorm is not None:
|
||||
if self.LayerNorm is not None:
|
||||
previous_type = ln_input.type()
|
||||
if self.fp32_layernorm:
|
||||
ln_input = ln_input.float()
|
||||
@@ -370,6 +287,7 @@ class BertOutput(nn.Module):
|
||||
hidden_states = ln_input
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertLayer, self).__init__()
|
||||
@@ -382,7 +300,7 @@ class BertLayer(nn.Module):
|
||||
else:
|
||||
self.LayerNorm = None
|
||||
|
||||
def forward(self, hidden_states, attention_mask, pruning_threshold=None,):
|
||||
def forward(self, hidden_states, attention_mask, pruning_threshold=None, ):
|
||||
attention_output = self.attention(hidden_states, attention_mask, pruning_threshold=pruning_threshold)
|
||||
if self.LayerNorm is not None:
|
||||
ln_input = attention_output
|
||||
@@ -398,6 +316,7 @@ class BertLayer(nn.Module):
|
||||
layer_output = self.output(intermediate_output, attention_output, pruning_threshold=pruning_threshold)
|
||||
return layer_output
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertEncoder, self).__init__()
|
||||
@@ -408,8 +327,10 @@ class BertEncoder(nn.Module):
|
||||
else:
|
||||
self.LayerNorm = None
|
||||
|
||||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, checkpoint_activations=False, detach_index=-1, pruning_threshold=None,):
|
||||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, checkpoint_activations=False,
|
||||
detach_index=-1, pruning_threshold=None, ):
|
||||
all_encoder_layers = []
|
||||
|
||||
def custom(start, end):
|
||||
def custom_forward(*inputs):
|
||||
layers = self.layer[start:end]
|
||||
@@ -417,20 +338,21 @@ class BertEncoder(nn.Module):
|
||||
for layer in layers:
|
||||
x_ = layer(x_, inputs[1], pruning_threshold=pruning_threshold)
|
||||
return x_
|
||||
|
||||
return custom_forward
|
||||
|
||||
if checkpoint_activations:
|
||||
l = 0
|
||||
num_layers = len(self.layer)
|
||||
chunk_length = 1 #math.ceil(math.sqrt(num_layers))
|
||||
chunk_length = 1
|
||||
while l < num_layers:
|
||||
hidden_states = mpu.checkpoint(custom(l, l+chunk_length), hidden_states, attention_mask*1)
|
||||
hidden_states = mpu.checkpoint(custom(l, l + chunk_length), hidden_states, attention_mask * 1)
|
||||
if detach_index == l:
|
||||
hidden_states.detach_()
|
||||
l += chunk_length
|
||||
# decoder layers
|
||||
else:
|
||||
for i,layer_module in enumerate(self.layer):
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
hidden_states = layer_module(hidden_states, attention_mask)
|
||||
if detach_index == i:
|
||||
hidden_states.detach_()
|
||||
@@ -455,6 +377,7 @@ class BertEncoder(nn.Module):
|
||||
all_encoder_layers.append(hidden_states)
|
||||
return all_encoder_layers
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertPooler, self).__init__()
|
||||
@@ -469,6 +392,7 @@ class BertPooler(nn.Module):
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
|
||||
class BertPredictionHeadTransform(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertPredictionHeadTransform, self).__init__()
|
||||
@@ -489,6 +413,7 @@ class BertPredictionHeadTransform(nn.Module):
|
||||
hidden_states = hidden_states.type(previous_type)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertLMPredictionHead(nn.Module):
|
||||
def __init__(self, config, bert_model_embedding_weights):
|
||||
super(BertLMPredictionHead, self).__init__()
|
||||
@@ -496,19 +421,18 @@ class BertLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
#self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
|
||||
# bert_model_embedding_weights.size(0),
|
||||
# bias=False)
|
||||
self.decoder_weight = bert_model_embedding_weights
|
||||
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
|
||||
self.bias.model_parallel = True
|
||||
self.fp32_embedding = config.fp32_embedding
|
||||
self.fp32_layernorm = config.fp32_layernorm
|
||||
|
||||
def convert_to_type(tensor):
|
||||
if self.fp32_embedding:
|
||||
return tensor.half()
|
||||
else:
|
||||
return tensor
|
||||
|
||||
self.type_converter = convert_to_type
|
||||
self.converted = False
|
||||
self.timers = SynchronizedWallClockTimer()
|
||||
@@ -521,14 +445,12 @@ class BertLMPredictionHead(nn.Module):
|
||||
if self.fp32_layernorm:
|
||||
self.transform.LayerNorm.float()
|
||||
hidden_states = self.transform(self.type_converter(hidden_states))
|
||||
# hidden_states = self.decoder(hidden_states) + self.bias
|
||||
self.timers('final linear gather').start()
|
||||
hidden_states = mpu.copy_to_model_parallel_region(hidden_states)
|
||||
self.timers('final linear gather').stop()
|
||||
hidden_states = F.linear(self.type_converter(hidden_states),
|
||||
self.type_converter(self.decoder_weight),
|
||||
self.type_converter(self.bias))
|
||||
#self.timers.log(names=['final linear gather'])
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -547,10 +469,12 @@ class BertPreTrainingHeads(nn.Module):
|
||||
seq_relationship_score = self.seq_relationship(pooled_output)
|
||||
return prediction_scores, seq_relationship_score
|
||||
|
||||
|
||||
class PreTrainedBertModel(nn.Module):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(PreTrainedBertModel, self).__init__()
|
||||
if not isinstance(config, PlugNLUConfig) and not isinstance(config, PlugNLGConfig):
|
||||
@@ -575,119 +499,6 @@ class PreTrainedBertModel(nn.Module):
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
#@classmethod
|
||||
#def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
|
||||
# fp32_layernorm=False, fp32_embedding=False, layernorm_epsilon=1e-12,
|
||||
# fp32_tokentypes=False, *inputs, **kwargs):
|
||||
# """
|
||||
# Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
|
||||
# Download and cache the pre-trained model file if needed.
|
||||
|
||||
# Params:
|
||||
# pretrained_model_name: either:
|
||||
# - a str with the name of a pre-trained model to load selected in the list of:
|
||||
# . `bert-base-uncased`
|
||||
# . `bert-large-uncased`
|
||||
# . `bert-base-cased`
|
||||
# . `bert-large-cased`
|
||||
# . `bert-base-multilingual-uncased`
|
||||
# . `bert-base-multilingual-cased`
|
||||
# . `bert-base-chinese`
|
||||
# - a path or url to a pretrained model archive containing:
|
||||
# . `bert_config.json` a configuration file for the model
|
||||
# . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
||||
# cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||
# state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
|
||||
# *inputs, **kwargs: additional input for the specific Bert class
|
||||
# (ex: num_labels for BertForSequenceClassification)
|
||||
# """
|
||||
# if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||
# archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
|
||||
# else:
|
||||
# archive_file = pretrained_model_name
|
||||
# # redirect to the cache, if necessary
|
||||
# try:
|
||||
# resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
# except FileNotFoundError:
|
||||
# logger.error(
|
||||
# "Model name '{}' was not found in model name list ({}). "
|
||||
# "We assumed '{}' was a path or url but couldn't find any file "
|
||||
# "associated to this path or url.".format(
|
||||
# pretrained_model_name,
|
||||
# ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
||||
# archive_file))
|
||||
# return None
|
||||
# if resolved_archive_file == archive_file:
|
||||
# logger.info("loading archive file {}".format(archive_file))
|
||||
# else:
|
||||
# logger.info("loading archive file {} from cache at {}".format(
|
||||
# archive_file, resolved_archive_file))
|
||||
# tempdir = None
|
||||
# if os.path.isdir(resolved_archive_file):
|
||||
# serialization_dir = resolved_archive_file
|
||||
# else:
|
||||
# # Extract archive to temp dir
|
||||
# tempdir = tempfile.mkdtemp()
|
||||
# logger.info("extracting archive file {} to temp dir {}".format(
|
||||
# resolved_archive_file, tempdir))
|
||||
# with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
||||
# archive.extractall(tempdir)
|
||||
# serialization_dir = tempdir
|
||||
# # Load config
|
||||
# config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
||||
# config = PlugNLUConfig.from_json_file(config_file)
|
||||
# config.fp32_layernorm = fp32_layernorm
|
||||
# config.fp32_embedding = fp32_embedding
|
||||
# config.layernorm_epsilon = layernorm_epsilon
|
||||
# config.fp32_tokentypes = fp32_tokentypes
|
||||
# logger.info("Model config {}".format(config))
|
||||
# # Instantiate model.
|
||||
# model = cls(config, *inputs, **kwargs)
|
||||
# if state_dict is None:
|
||||
# weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||
# state_dict = torch.load(weights_path)
|
||||
|
||||
# old_keys = []
|
||||
# new_keys = []
|
||||
# for key in state_dict.keys():
|
||||
# new_key = None
|
||||
# if 'gamma' in key:
|
||||
# new_key = key.replace('gamma', 'weight')
|
||||
# if 'beta' in key:
|
||||
# new_key = key.replace('beta', 'bias')
|
||||
# if new_key:
|
||||
# old_keys.append(key)
|
||||
# new_keys.append(new_key)
|
||||
# for old_key, new_key in zip(old_keys, new_keys):
|
||||
# state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
# missing_keys = []
|
||||
# unexpected_keys = []
|
||||
# error_msgs = []
|
||||
# # copy state_dict so _load_from_state_dict can modify it
|
||||
# metadata = getattr(state_dict, '_metadata', None)
|
||||
# state_dict = state_dict.copy()
|
||||
# if metadata is not None:
|
||||
# state_dict._metadata = metadata
|
||||
|
||||
# def load(module, prefix=''):
|
||||
# local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
# module._load_from_state_dict(
|
||||
# state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
||||
# for name, child in module._modules.items():
|
||||
# if child is not None:
|
||||
# load(child, prefix + name + '.')
|
||||
# load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
|
||||
# if len(missing_keys) > 0:
|
||||
# logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
||||
# model.__class__.__name__, missing_keys))
|
||||
# if len(unexpected_keys) > 0:
|
||||
# logger.info("Weights from pretrained model not used in {}: {}".format(
|
||||
# model.__class__.__name__, unexpected_keys))
|
||||
# if tempdir:
|
||||
# # Clean up temp dir
|
||||
# shutil.rmtree(tempdir)
|
||||
# return model
|
||||
|
||||
class BertModel(PreTrainedBertModel):
|
||||
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
|
||||
@@ -733,6 +544,7 @@ class BertModel(PreTrainedBertModel):
|
||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertModel, self).__init__(config)
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
@@ -740,7 +552,8 @@ class BertModel(PreTrainedBertModel):
|
||||
self.pooler = BertPooler(config)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, checkpoint_activations=False, detach_index=-1, pruning_threshold=None,):
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True,
|
||||
checkpoint_activations=False, detach_index=-1, pruning_threshold=None, ):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
@@ -758,7 +571,8 @@ class BertModel(PreTrainedBertModel):
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.encoder.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = extended_attention_mask.to(
|
||||
dtype=next(self.encoder.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
embedding_output = self.embeddings(input_ids, token_type_ids)
|
||||
@@ -774,7 +588,7 @@ class BertModel(PreTrainedBertModel):
|
||||
continue
|
||||
sequence_output = sequence_output.type_as(p)
|
||||
break
|
||||
#pooled_output = self.pooler(sequence_output)
|
||||
|
||||
pooled_output = sequence_output[:, 0]
|
||||
if not output_all_encoded_layers or checkpoint_activations:
|
||||
encoded_layers = encoded_layers[-1]
|
||||
@@ -784,11 +598,11 @@ class BertModel(PreTrainedBertModel):
|
||||
class DecodeLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(DecodeLayer, self).__init__()
|
||||
init_method = normal_init_method(mean=0.0,std=config.initializer_range)
|
||||
init_method = normal_init_method(mean=0.0, std=config.initializer_range)
|
||||
output_layer_init_method = scaled_init_method(mean=0.0,
|
||||
std=config.initializer_range,
|
||||
num_layers=config.num_hidden_layers)
|
||||
|
||||
std=config.initializer_range,
|
||||
num_layers=config.num_hidden_layers)
|
||||
|
||||
self_pruning_method = config.pruning_method
|
||||
cross_pruning_method = config.pruning_method
|
||||
ffn_pruning_method = config.pruning_method
|
||||
@@ -808,11 +622,12 @@ class DecodeLayer(nn.Module):
|
||||
output_dropout_prob=config.hidden_dropout_prob,
|
||||
init_method=init_method,
|
||||
output_layer_init_method=output_layer_init_method,
|
||||
pruning_method=self_pruning_method if config.pruning_module in ['all', 'decoder', 'decoder_self', 'decoder_self+ffn'] else None,
|
||||
pruning_method=self_pruning_method if config.pruning_module in ['all', 'decoder', 'decoder_self',
|
||||
'decoder_self+ffn'] else None,
|
||||
pruning_mask_init=config.pruning_mask_init, pruning_mask_scale=config.pruning_mask_scale,
|
||||
LR_weight_rank=config.LR_weight_rank,
|
||||
LR_mask_rank=config.LR_mask_rank,
|
||||
)
|
||||
)
|
||||
|
||||
self.cross_attention = mpu.PalmParallelCrossAttention(
|
||||
hidden_size=config.hidden_size,
|
||||
@@ -821,55 +636,65 @@ class DecodeLayer(nn.Module):
|
||||
output_dropout_prob=config.hidden_dropout_prob,
|
||||
init_method=init_method, attn_separate=False,
|
||||
output_layer_init_method=output_layer_init_method,
|
||||
pruning_method=cross_pruning_method, pruning_mask_init=config.pruning_mask_init,
|
||||
pruning_method=cross_pruning_method, pruning_mask_init=config.pruning_mask_init,
|
||||
pruning_mask_scale=config.pruning_mask_scale, pruning_module=config.pruning_module,
|
||||
LR_weight_rank=config.LR_weight_rank,
|
||||
LR_mask_rank=config.LR_mask_rank,)
|
||||
|
||||
LR_mask_rank=config.LR_mask_rank, )
|
||||
|
||||
self.input_layernorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
|
||||
self.post_attention_layernorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
|
||||
self.post_cross_attention_layernorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
|
||||
|
||||
self.intermediate = mpu.ColumnParallelLinear(config.hidden_size, config.intermediate_size, gather_output=False, init_method=init_method,
|
||||
pruning_method=ffn_pruning_method if config.pruning_module in ['all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None,
|
||||
pruning_mask_init=config.pruning_mask_init, pruning_mask_scale=config.pruning_mask_scale,
|
||||
LR_weight_rank=config.LR_weight_rank, LR_mask_rank=config.LR_mask_rank,)
|
||||
self.intermediate = mpu.ColumnParallelLinear(config.hidden_size, config.intermediate_size, gather_output=False,
|
||||
init_method=init_method,
|
||||
pruning_method=ffn_pruning_method if config.pruning_module in [
|
||||
'all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None,
|
||||
pruning_mask_init=config.pruning_mask_init,
|
||||
pruning_mask_scale=config.pruning_mask_scale,
|
||||
LR_weight_rank=config.LR_weight_rank,
|
||||
LR_mask_rank=config.LR_mask_rank, )
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
|
||||
if isinstance(config.hidden_act, str) else config.hidden_act
|
||||
self.output = mpu.RowParallelLinear(config.intermediate_size, config.hidden_size, input_is_parallel=True, init_method=output_layer_init_method,
|
||||
pruning_method=ffn_pruning_method if config.pruning_module in ['all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None,
|
||||
pruning_mask_init=config.pruning_mask_init, pruning_mask_scale=config.pruning_mask_scale,
|
||||
LR_weight_rank=config.LR_weight_rank, LR_mask_rank=config.LR_mask_rank,)
|
||||
|
||||
self.output = mpu.RowParallelLinear(config.intermediate_size, config.hidden_size, input_is_parallel=True,
|
||||
init_method=output_layer_init_method,
|
||||
pruning_method=ffn_pruning_method if config.pruning_module in ['all',
|
||||
'decoder',
|
||||
'decoder_ffn',
|
||||
'decoder_self+ffn'] else None,
|
||||
pruning_mask_init=config.pruning_mask_init,
|
||||
pruning_mask_scale=config.pruning_mask_scale,
|
||||
LR_weight_rank=config.LR_weight_rank, LR_mask_rank=config.LR_mask_rank, )
|
||||
|
||||
self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
|
||||
self.fp32_layernorm = config.fp32_layernorm
|
||||
|
||||
def convert_to_type(tensor):
|
||||
if self.fp32_layernorm:
|
||||
return tensor.float()
|
||||
else:
|
||||
return tensor
|
||||
self.type_converter = convert_to_type
|
||||
|
||||
|
||||
#def forward(self, hidden_states, enc_attn_mask, dec_attn_mask):
|
||||
def forward(self, hidden_states, enc_hidden_states, enc_attn_mask, dec_attn_mask, is_infer=False, pruning_threshold=None):
|
||||
self.type_converter = convert_to_type
|
||||
|
||||
# def forward(self, hidden_states, enc_attn_mask, dec_attn_mask):
|
||||
def forward(self, hidden_states, enc_hidden_states, enc_attn_mask, dec_attn_mask, is_infer=False,
|
||||
pruning_threshold=None):
|
||||
residual = hidden_states
|
||||
previous_type = hidden_states.type()
|
||||
hidden_states = self.input_layernorm(self.type_converter(hidden_states))
|
||||
if self.fp32_layernorm:
|
||||
hidden_states = hidden_states.type(previous_type)
|
||||
hidden_states = self.attention(hidden_states, dec_attn_mask, is_infer=is_infer, pruning_threshold=pruning_threshold)
|
||||
# add dropout?
|
||||
# hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.attention(hidden_states, dec_attn_mask, is_infer=is_infer,
|
||||
pruning_threshold=pruning_threshold)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(self.type_converter(hidden_states))
|
||||
if self.fp32_layernorm:
|
||||
# same to the output of BertAttention
|
||||
hidden_states = hidden_states.type(previous_type)
|
||||
hidden_states = self.cross_attention(hidden_states, enc_hidden_states, enc_attn_mask, pruning_threshold=pruning_threshold)
|
||||
# hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.cross_attention(hidden_states, enc_hidden_states, enc_attn_mask,
|
||||
pruning_threshold=pruning_threshold)
|
||||
hidden_states = residual + hidden_states
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_cross_attention_layernorm(self.type_converter(hidden_states))
|
||||
@@ -877,55 +702,61 @@ class DecodeLayer(nn.Module):
|
||||
hidden_states = hidden_states.type(previous_type)
|
||||
hidden_states = self.intermediate(hidden_states, pruning_threshold=pruning_threshold)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
# hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = self.output(hidden_states, pruning_threshold=pruning_threshold)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertDecoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertDecoder, self).__init__()
|
||||
self.layer = nn.ModuleList([DecodeLayer(config) for _ in range(config.dec_hidden_layers)])
|
||||
|
||||
|
||||
self.final_layernorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
|
||||
self.fp32_layernorm = config.fp32_layernorm
|
||||
|
||||
def forward(self, hidden_states, enc_hidden_states, enc_attn_mask, dec_attn_mask, checkpoint_activations=False, output_all_encoded_layers=False, is_infer=False, pruning_threshold=None):
|
||||
all_encoder_layers = []
|
||||
def forward(self, hidden_states, enc_hidden_states, enc_attn_mask, dec_attn_mask, checkpoint_activations=False,
|
||||
output_all_encoded_layers=False, is_infer=False, pruning_threshold=None):
|
||||
|
||||
def custom(start, end):
|
||||
def custom_forward(*inputs):
|
||||
layers = self.layer[start:end]
|
||||
x_ = inputs[0]
|
||||
for layer in layers:
|
||||
x_ = layer(x_, inputs[1], inputs[2], dec_attn_mask*1, is_infer=is_infer, pruning_threshold=pruning_threshold)
|
||||
x_ = layer(x_, inputs[1], inputs[2], dec_attn_mask * 1, is_infer=is_infer,
|
||||
pruning_threshold=pruning_threshold)
|
||||
return x_
|
||||
|
||||
return custom_forward
|
||||
|
||||
pre_enc_hidden= enc_hidden_states.data
|
||||
pre_enc_hidden = enc_hidden_states.data
|
||||
if checkpoint_activations:
|
||||
l = 0
|
||||
num_layers = len(self.layer)
|
||||
chunk_length = 1 #math.ceil(math.sqrt(num_layers))
|
||||
chunk_length = 1
|
||||
while l < num_layers:
|
||||
hidden_states = mpu.checkpoint(custom(l, l+chunk_length), hidden_states, enc_hidden_states, enc_attn_mask*1)
|
||||
hidden_states = mpu.checkpoint(custom(l, l + chunk_length), hidden_states, enc_hidden_states,
|
||||
enc_attn_mask * 1)
|
||||
enc_hidden_states.data = pre_enc_hidden
|
||||
l += chunk_length
|
||||
else:
|
||||
for i,layer_module in enumerate(self.layer):
|
||||
hidden_states = layer_module(hidden_states, enc_hidden_states, enc_attn_mask, dec_attn_mask, is_infer=is_infer, pruning_threshold=pruning_threshold)
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
hidden_states = layer_module(hidden_states, enc_hidden_states, enc_attn_mask, dec_attn_mask,
|
||||
is_infer=is_infer, pruning_threshold=pruning_threshold)
|
||||
|
||||
previous_type = hidden_states.type()
|
||||
if self.fp32_layernorm:
|
||||
hidden_states = hidden_states.float()
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
if self.fp32_layernorm:
|
||||
hidden_states = hidden_states.type(previous_type)
|
||||
|
||||
|
||||
return [hidden_states]
|
||||
|
||||
|
||||
class DecodeModel(PreTrainedBertModel):
|
||||
|
||||
def __init__(self, config):
|
||||
@@ -933,22 +764,24 @@ class DecodeModel(PreTrainedBertModel):
|
||||
self.decoder = BertDecoder(config)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, embeddings, sequence_output, decode_input_ids, position_ids=None, enc_attn_mask=None, dec_attn_mask=None, checkpoint_activations=False, is_infer=False, pruning_threshold=None):
|
||||
|
||||
def forward(self, embeddings, sequence_output, decode_input_ids, position_ids=None, enc_attn_mask=None,
|
||||
dec_attn_mask=None, checkpoint_activations=False, is_infer=False, pruning_threshold=None):
|
||||
extended_attention_mask = enc_attn_mask.unsqueeze(1).unsqueeze(2)
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.decoder.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = extended_attention_mask.to(
|
||||
dtype=next(self.decoder.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
embedding_output = embeddings(decode_input_ids)
|
||||
sequence_output = self.decoder(embedding_output,
|
||||
sequence_output,
|
||||
extended_attention_mask,
|
||||
dec_attn_mask,
|
||||
checkpoint_activations=False,
|
||||
is_infer=is_infer,
|
||||
pruning_threshold=pruning_threshold)
|
||||
sequence_output,
|
||||
extended_attention_mask,
|
||||
dec_attn_mask,
|
||||
checkpoint_activations=False,
|
||||
is_infer=is_infer,
|
||||
pruning_threshold=pruning_threshold)
|
||||
return sequence_output[-1]
|
||||
|
||||
|
||||
class PalmForPreTraining(PreTrainedBertModel):
|
||||
def __init__(self, config):
|
||||
super(PalmForPreTraining, self).__init__(config)
|
||||
@@ -957,65 +790,52 @@ class PalmForPreTraining(PreTrainedBertModel):
|
||||
self.decoder = DecodeModel(config)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, decode_input_ids=None, position_ids=None, decode_attention_mask=None, lm_labels=None, checkpoint_activations=False, is_infer=False, sequence_output=None, parallel_output=True, pruning_threshold=None):
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, decode_input_ids=None, position_ids=None,
|
||||
decode_attention_mask=None, lm_labels=None, checkpoint_activations=False, is_infer=False,
|
||||
sequence_output=None, parallel_output=True, pruning_threshold=None):
|
||||
if sequence_output is None:
|
||||
sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
|
||||
output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations, pruning_threshold=pruning_threshold)
|
||||
output_all_encoded_layers=False,
|
||||
checkpoint_activations=checkpoint_activations,
|
||||
pruning_threshold=pruning_threshold)
|
||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||
else:
|
||||
prediction_scores = None
|
||||
seq_relationship_score = None
|
||||
sequence_output = sequence_output.to(dtype=next(self.decoder.parameters()).dtype)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
decode_output = self.decoder(self.bert.embeddings, sequence_output, decode_input_ids, position_ids, attention_mask, decode_attention_mask, checkpoint_activations=checkpoint_activations, is_infer=is_infer, pruning_threshold=pruning_threshold)
|
||||
decode_output = self.decoder(self.bert.embeddings, sequence_output, decode_input_ids, position_ids,
|
||||
attention_mask, decode_attention_mask,
|
||||
checkpoint_activations=checkpoint_activations, is_infer=is_infer,
|
||||
pruning_threshold=pruning_threshold)
|
||||
|
||||
#prediction_scores = self.cls(decode_output)
|
||||
|
||||
transformer_output_parallel = mpu.copy_to_model_parallel_region(
|
||||
decode_output)
|
||||
|
||||
logits_parallel = F.linear(transformer_output_parallel,
|
||||
self.bert.embeddings.word_embeddings.weight)
|
||||
|
||||
|
||||
if parallel_output:
|
||||
return prediction_scores, logits_parallel
|
||||
if is_infer:
|
||||
return prediction_scores, mpu.gather_from_model_parallel_region(logits_parallel), sequence_output
|
||||
return prediction_scores, mpu.gather_from_model_parallel_region(logits_parallel)
|
||||
|
||||
|
||||
class PlugModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(PlugModel, self).__init__()
|
||||
if config.intermediate_size is None:
|
||||
intermediate_size = 4 * config.hidden_size
|
||||
else:
|
||||
intermediate_size = config.intermediate_size
|
||||
self.config = config
|
||||
# self.config = BertConfig(
|
||||
# args.tokenizer_num_tokens,
|
||||
# hidden_size=args.hidden_size,
|
||||
# num_hidden_layers=args.num_layers,
|
||||
# num_attention_heads=args.num_attention_heads,
|
||||
# intermediate_size=intermediate_size,
|
||||
# hidden_dropout_prob=args.hidden_dropout,
|
||||
# attention_probs_dropout_prob=args.attention_dropout,
|
||||
# max_position_embeddings=args.max_position_embeddings,
|
||||
# type_vocab_size=args.tokenizer_num_type_tokens,
|
||||
# fp32_layernorm=args.fp32_layernorm,
|
||||
# fp32_embedding=args.fp32_embedding,
|
||||
# fp32_tokentypes=args.fp32_tokentypes,
|
||||
# layernorm_epsilon=args.layernorm_epsilon,
|
||||
# deep_init=args.deep_init,
|
||||
# dec_hidden_layers=args.dec_layers)
|
||||
self.model = PalmForPreTraining(self.config)
|
||||
|
||||
def forward(self, input_tokens, token_type_ids=None,
|
||||
attention_mask=None, target_tokens=None, position_ids=None, decode_attention_mask=None, checkpoint_activations=False, is_infer=False, sequence_output=None, parallel_output=True):
|
||||
attention_mask=None, target_tokens=None, position_ids=None, decode_attention_mask=None,
|
||||
checkpoint_activations=False, is_infer=False, sequence_output=None, parallel_output=True):
|
||||
return self.model(
|
||||
input_tokens, token_type_ids, attention_mask, target_tokens, position_ids,
|
||||
decode_attention_mask, checkpoint_activations=checkpoint_activations, is_infer=is_infer, sequence_output=sequence_output, parallel_output=parallel_output)
|
||||
input_tokens, token_type_ids, attention_mask, target_tokens, position_ids,
|
||||
decode_attention_mask, checkpoint_activations=checkpoint_activations, is_infer=is_infer,
|
||||
sequence_output=sequence_output, parallel_output=parallel_output)
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
return self.model.state_dict(destination=destination, prefix=prefix,
|
||||
@@ -1023,5 +843,3 @@ class PlugModel(torch.nn.Module):
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
return self.model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
|
||||
@@ -1,59 +1,35 @@
|
||||
import torch
|
||||
from typing import Dict
|
||||
from functools import partial
|
||||
from typing import Dict, Any
|
||||
|
||||
from . import DistributedPlug
|
||||
from ...base import Tensor, TorchModel
|
||||
import torch
|
||||
from ...builder import MODELS
|
||||
from ....metainfo import Models
|
||||
from ....outputs import OutputKeys
|
||||
from ....utils.constant import Tasks
|
||||
from modelscope.models.nlp.structbert import SbertTokenizer
|
||||
from modelscope.models.nlp.utils.distributed import DistributedTorchModel
|
||||
from . import DistributedPlug
|
||||
from ...base import Tensor
|
||||
|
||||
__all__ = ['PlugForTextGeneration']
|
||||
|
||||
__all__ = ['PLUGForTextGeneration']
|
||||
|
||||
@MODELS.register_module(Tasks.text_generation, module_name=Models.plug)
|
||||
class PlugForTextGeneration(TorchModel):
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
import torch
|
||||
class PlugForTextGeneration(DistributedTorchModel):
|
||||
|
||||
from transformers import BertTokenizer
|
||||
from multiprocessing import Pool
|
||||
from .arguments import get_args
|
||||
from . import PlugNLGConfig
|
||||
import torch
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
self.tokenizer = BertTokenizer.from_pretrained(model_dir)
|
||||
model_config = PlugNLGConfig.from_pretrained(model_dir)
|
||||
|
||||
# TODO(suluyan): Arguments
|
||||
args = get_args()
|
||||
args.world_size = 8
|
||||
args.model_parallel_size = 8
|
||||
args.pre_load = True
|
||||
args.distributed_backend = 'nccl'
|
||||
args.fp16 = True
|
||||
args.fp32_layernorm = True
|
||||
args.checkpoint_activations = True
|
||||
args.batch_size = 1
|
||||
args.top_k = 20
|
||||
args.top_p = 0.0
|
||||
args.temperature = 0.9
|
||||
self.args = args
|
||||
|
||||
self.world_size = args.world_size
|
||||
ranks = list(range(self.world_size))
|
||||
self.model_pool = Pool(self.world_size)
|
||||
self.model_pool.map(partial(DistributedPlug.init, model_dir=model_dir, model_config=model_config, args=args), ranks)
|
||||
|
||||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
def _forward_one(self, input: Dict[str, Any]) -> Dict[str, Tensor]:
|
||||
return self.model(**input)
|
||||
|
||||
def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
dec_input_ids = torch.full([self.args.batch_size, 1], self.tokenizer.cls_token_id, dtype=torch.long)
|
||||
batch_size = input['input_ids'].shape[0]
|
||||
dec_input_ids = torch.full([batch_size, 1], self.cls_token_id, dtype=torch.long)
|
||||
input["dec_input_ids"] = dec_input_ids
|
||||
res = self.model_pool.map(DistributedPlug.forward, [input]*self.world_size)
|
||||
return res[0]
|
||||
|
||||
def _instantiate_one(self, model_dir, rank):
|
||||
tokenizer = SbertTokenizer.from_pretrained(model_dir)
|
||||
self.cls_token_id = tokenizer.cls_token_id
|
||||
self.model = DistributedPlug.instantiate(model_dir, rank)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
0
modelscope/utils/nlp/mpu/tests/__init__.py → modelscope/models/nlp/utils/__init__.py
Executable file → Normal file
0
modelscope/utils/nlp/mpu/tests/__init__.py → modelscope/models/nlp/utils/__init__.py
Executable file → Normal file
72
modelscope/models/nlp/utils/distributed.py
Normal file
72
modelscope/models/nlp/utils/distributed.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from modelscope.models import Model
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
from typing import Dict, Any
|
||||
from modelscope.utils.hub import read_config
|
||||
from modelscope.utils.torch_utils import init_dist, _is_free_port, _find_free_port
|
||||
import torch
|
||||
import os
|
||||
import math
|
||||
|
||||
|
||||
class DistributedTorchModel(Model):
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
self.model_pool = None
|
||||
self.world_size = None
|
||||
|
||||
@classmethod
|
||||
def _instantiate(cls, model_dir):
|
||||
model = DistributedTorchModel(model_dir=model_dir)
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
cfg = read_config(model_dir)
|
||||
model.world_size = cfg.model.word_size
|
||||
ranks = list(range(model.world_size))
|
||||
model.model_pool = Pool(model.world_size)
|
||||
model.model_pool.map(partial(model._instantiate_one, model_dir=model_dir), ranks)
|
||||
return model
|
||||
|
||||
def _instantiate_one(self, model_dir, rank):
|
||||
pass
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
res = self.model_pool.map(self._forward_one, [input]*self.world_size)
|
||||
return res[0]
|
||||
|
||||
def _forward_one(self, input):
|
||||
pass
|
||||
|
||||
|
||||
def initialize_distributed(rank, mpu, world_size, model_parallel_size):
|
||||
"""Initialize torch.distributed."""
|
||||
# Manually set the device ids.
|
||||
device = rank % torch.cuda.device_count()
|
||||
torch.cuda.set_device(device)
|
||||
# Call the init process
|
||||
init_method = 'tcp://'
|
||||
master_ip = os.getenv('MASTER_ADDR', '127.0.0.1')
|
||||
master_port = os.getenv('MASTER_PORT', '29500')
|
||||
if not _is_free_port(int(master_port)):
|
||||
master_port = str(_find_free_port())
|
||||
init_method += master_ip + ':' + master_port
|
||||
init_dist('pytorch', world_size=world_size, rank=rank, init_method=init_method)
|
||||
# Set the model-parallel communicators.
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
|
||||
|
||||
def normal_init_method(mean, std):
|
||||
def init_(tensor):
|
||||
return torch.nn.init.normal_(tensor, mean=mean, std=std)
|
||||
|
||||
return init_
|
||||
|
||||
|
||||
def scaled_init_method(mean, std, num_layers):
    """Return an initializer drawing from N(mean, (std/sqrt(2*num_layers))**2).

    The 1/sqrt(2*num_layers) scaling is the Megatron-style depth scaling
    applied to output projections of deep transformers.
    """
    scaled_std = std / math.sqrt(2.0 * num_layers)

    def _init(tensor):
        return torch.nn.init.normal_(tensor, mean=mean, std=scaled_std)

    return _init
|
||||
@@ -10,6 +10,7 @@ from typing import Callable, List, Optional, Tuple, Union
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from modelscope.utils.torch_utils import set_random_seed
|
||||
from addict import Dict
|
||||
from torch import distributed as dist
|
||||
from torch import nn
|
||||
@@ -816,6 +817,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed):
|
||||
# The seed of each worker equals to
|
||||
# num_worker * rank + worker_id + user_seed
|
||||
worker_seed = num_workers * rank + worker_id + seed
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
set_random_seed(worker_seed)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from sofa.utils import mpu
|
||||
from modelscope.utils.nlp import mpu
|
||||
|
||||
# item() is a recent addition, so this helps with backward compatibility.
|
||||
def to_python_float(t):
|
||||
|
||||
@@ -24,9 +24,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
|
||||
|
||||
from .initialize import get_model_parallel_rank
|
||||
from .initialize import get_model_parallel_world_size
|
||||
from .mappings import copy_to_model_parallel_region
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
import mpu
|
||||
|
||||
|
||||
class IdentityLayer(torch.nn.Module):
    """Module whose forward pass returns its learnable weight unchanged.

    Handy in tests: after backprop, ``weight.grad`` exposes the gradient
    flowing into the layer's output directly.
    """

    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        # Random initial weight, scaled by the caller-supplied factor.
        self.weight = torch.nn.Parameter(torch.randn(size) * scale)

    def forward(self):
        return self.weight
|
||||
|
||||
|
||||
def set_random_seed(seed):
    """Seed every RNG source (python, numpy, torch, model-parallel CUDA)."""
    for seeder in (random.seed, numpy.random.seed, torch.manual_seed,
                   mpu.model_parallel_cuda_manual_seed):
        seeder(seed)
|
||||
|
||||
|
||||
def initialize_distributed(backend='nccl'):
    """Initialize torch.distributed."""
    # Get local rank in case it is provided.
    # NOTE(review): parse_args() reads the global sys.argv, so this helper
    # only works when the process was started by a distributed launcher
    # that passes --local_rank (it errors out on unrelated argv).
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher')
    args = parser.parse_args()
    local_rank = args.local_rank

    # Get rank and world size from the launcher-provided environment.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv("WORLD_SIZE", '1'))

    print('> initializing torch.distributed with local rank: {}, '
          'rank: {}, world size: {}'.format(local_rank, rank, world_size))

    # Set the device id: round-robin over visible GPUs unless the launcher
    # supplied an explicit local rank.
    device = rank % torch.cuda.device_count()
    if local_rank is not None:
        device = local_rank
    # NOTE(review): `device` is computed but set_device is commented out,
    # so the CUDA device is never actually pinned here — confirm whether
    # callers rely on the default device instead.
    #torch.cuda.set_device(device)

    # Call the init process over a TCP rendezvous built from MASTER_ADDR /
    # MASTER_PORT (defaults suit single-node testing).
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method=init_method)
|
||||
|
||||
|
||||
def print_separator(message):
    """Print a dashed banner on rank 0 only, barrier-synchronized on both sides."""
    torch.distributed.barrier()
    pad = '-' * ((78 - len(message)) // 2)
    banner = '\n' + pad + ' {} '.format(message) + pad
    if torch.distributed.get_rank() == 0:
        print(banner, flush=True)
    torch.distributed.barrier()
|
||||
@@ -1,110 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import sys
|
||||
sys.path.append("../..")
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import mpu
|
||||
from mpu.cross_entropy import vocab_parallel_cross_entropy
|
||||
|
||||
from commons import initialize_distributed
|
||||
from commons import print_separator
|
||||
from commons import IdentityLayer
|
||||
from commons import set_random_seed
|
||||
|
||||
|
||||
def torch_cross_entropy(batch_size, seq_length, vocab_size,
|
||||
logits_scale, seed):
|
||||
set_random_seed(seed)
|
||||
identity = IdentityLayer((batch_size, seq_length, vocab_size),
|
||||
scale=logits_scale).cuda()
|
||||
logits = identity()
|
||||
target = torch.cuda.LongTensor(
|
||||
size=(batch_size, seq_length)).random_(0, vocab_size)
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size()[-1]),
|
||||
target.view(-1),
|
||||
reduction='none').view_as(target).mean()
|
||||
loss.backward()
|
||||
return loss, identity.weight.grad
|
||||
|
||||
|
||||
def mpu_cross_entropy(batch_size, seq_length, vocab_size,
|
||||
logits_scale, seed):
|
||||
set_random_seed(seed)
|
||||
identity = IdentityLayer((batch_size, seq_length, vocab_size),
|
||||
scale=logits_scale).cuda()
|
||||
logits = identity()
|
||||
logits_parallel = mpu.scatter_to_model_parallel_region(logits)
|
||||
target = torch.cuda.LongTensor(
|
||||
size=(batch_size, seq_length)).random_(0, vocab_size)
|
||||
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
|
||||
loss.backward()
|
||||
return loss, identity.weight.grad
|
||||
|
||||
|
||||
def test_cross_entropy(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing cross entropy with model parallel size {} ...'.
|
||||
format(model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
batch_size = 13
|
||||
seq_length = 17
|
||||
vocab_size_per_partition = 11
|
||||
logits_scale = 1000.0
|
||||
vocab_size = vocab_size_per_partition * model_parallel_size
|
||||
seed = 1234
|
||||
|
||||
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
|
||||
vocab_size, logits_scale,
|
||||
seed)
|
||||
loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length,
|
||||
vocab_size, logits_scale,
|
||||
seed)
|
||||
|
||||
error = loss_torch.sub_(loss_mpu).abs().max()
|
||||
print(' max error in loss on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
error = grad_torch.sub_(grad_mpu).abs().max()
|
||||
print(' max error in grad on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
initialize_distributed()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
print_separator('test cross entropy')
|
||||
test_cross_entropy(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
@@ -1,92 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import operator
|
||||
import sys
|
||||
sys.path.append("../..")
|
||||
|
||||
import torch
|
||||
import mpu
|
||||
from mpu import data as data_utils
|
||||
|
||||
from commons import initialize_distributed
|
||||
from commons import print_separator
|
||||
|
||||
|
||||
def test_boradcast_data(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing boradcast_data with model parallel size {} ...'.
|
||||
format(model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
torch.manual_seed(1234 + mpu.get_data_parallel_rank())
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
key_size_t = {'key1': [7, 11],
|
||||
'key2': [8, 2, 1],
|
||||
'key3': [13],
|
||||
'key4': [5, 1, 2],
|
||||
'key5': [5, 12]}
|
||||
keys = list(key_size_t.keys())
|
||||
|
||||
data = {}
|
||||
data_t = {}
|
||||
for key in key_size_t:
|
||||
data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
|
||||
data_t[key] = data[key].clone()
|
||||
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
|
||||
data_t['keyX'] = data['keyX'].clone()
|
||||
if mpu.get_model_parallel_rank() != 0:
|
||||
data = None
|
||||
|
||||
data_utils._check_data_types(keys, data_t, torch.int64)
|
||||
key_size, key_numel, \
|
||||
total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
|
||||
for key in keys:
|
||||
assert key_size[key] == key_size_t[key]
|
||||
total_numel_t = 0
|
||||
for key in keys:
|
||||
target_size = functools.reduce(operator.mul, key_size_t[key], 1)
|
||||
assert key_numel[key] == target_size
|
||||
total_numel_t += target_size
|
||||
assert total_numel == total_numel_t
|
||||
|
||||
data_b = data_utils.broadcast_data(keys, data, torch.int64)
|
||||
for key in keys:
|
||||
tensor = data_t[key].cuda()
|
||||
assert data_b[key].sub(tensor).abs().max() == 0
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
initialize_distributed()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
print_separator('test test boradcast data')
|
||||
test_boradcast_data(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
sys.path.append("../..")
|
||||
|
||||
import torch
|
||||
import mpu
|
||||
|
||||
from commons import initialize_distributed
|
||||
from commons import print_separator
|
||||
|
||||
|
||||
def test_initialize_model_parallel(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing initialize_model_parallel with size {} ...'.format(
|
||||
model_parallel_size))
|
||||
model_parallel_size_ = min(model_parallel_size,
|
||||
torch.distributed.get_world_size())
|
||||
assert not mpu.model_parallel_is_initialized()
|
||||
mpu.initialize_model_parallel(model_parallel_size_)
|
||||
assert mpu.model_parallel_is_initialized()
|
||||
|
||||
# Checks.
|
||||
def check(group, world_size, rank):
|
||||
assert world_size == torch.distributed.get_world_size(group=group)
|
||||
assert rank == torch.distributed.get_rank(group=group)
|
||||
|
||||
# Model parallel.
|
||||
world_size = model_parallel_size_
|
||||
rank = torch.distributed.get_rank() % model_parallel_size_
|
||||
assert world_size == mpu.get_model_parallel_world_size()
|
||||
assert rank == mpu.get_model_parallel_rank()
|
||||
check(mpu.get_model_parallel_group(), world_size, rank)
|
||||
|
||||
|
||||
# Data parallel.
|
||||
world_size = torch.distributed.get_world_size() // model_parallel_size_
|
||||
rank = torch.distributed.get_rank() // model_parallel_size
|
||||
assert world_size == mpu.get_data_parallel_world_size()
|
||||
assert rank == mpu.get_data_parallel_rank()
|
||||
check(mpu.get_data_parallel_group(), world_size, rank)
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
def test_get_model_parallel_src_rank(model_parallel_size_):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing get_model_parallel_src_rank with size {} ...'.format(
|
||||
model_parallel_size_))
|
||||
model_parallel_size = min(model_parallel_size_,
|
||||
torch.distributed.get_world_size())
|
||||
assert not mpu.model_parallel_is_initialized()
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
assert mpu.model_parallel_is_initialized()
|
||||
|
||||
# Checks
|
||||
src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
|
||||
assert mpu.get_model_parallel_src_rank() == src_rank
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
initialize_distributed()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
print_separator('test initialize model parallel')
|
||||
test_initialize_model_parallel(model_parallel_size)
|
||||
print_separator('test model parallel source rank')
|
||||
test_get_model_parallel_src_rank(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
@@ -1,529 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import sys
|
||||
sys.path.append("../..")
|
||||
|
||||
import torch
|
||||
import torch.nn.init as init
|
||||
from torch.nn.parameter import Parameter
|
||||
import mpu
|
||||
|
||||
from commons import initialize_distributed
|
||||
from commons import print_separator
|
||||
from commons import set_random_seed
|
||||
from mpu import layers
|
||||
|
||||
|
||||
def test_parallel_embedding(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing parallel embedding with model parallel size {} ...'.
|
||||
format(model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
batch_size = 17
|
||||
seq_length = 23
|
||||
vocab_size = 48
|
||||
hidden_size = 16
|
||||
seed = 1236
|
||||
|
||||
set_random_seed(123)
|
||||
input_data = torch.LongTensor(
|
||||
size=(batch_size,seq_length)).random_(0, vocab_size).cuda()
|
||||
loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()
|
||||
|
||||
set_random_seed(seed)
|
||||
embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
|
||||
|
||||
output = embedding_original(input_data)
|
||||
loss_original = torch.mul(output, loss_weight).sum()
|
||||
loss_original.backward()
|
||||
|
||||
set_random_seed(seed)
|
||||
embedding_parallel = layers.ParallelEmbedding(
|
||||
vocab_size, hidden_size, init_method=init.normal_).cuda()
|
||||
output = embedding_parallel(input_data)
|
||||
loss_parallel = torch.mul(output, loss_weight).sum()
|
||||
loss_parallel.backward()
|
||||
|
||||
set_random_seed(seed)
|
||||
embedding_vocab_parallel = layers.VocabParallelEmbedding(
|
||||
vocab_size, hidden_size, init_method=init.normal_).cuda()
|
||||
output = embedding_vocab_parallel(input_data)
|
||||
loss_vocab_parallel = torch.mul(output, loss_weight).sum()
|
||||
loss_vocab_parallel.backward()
|
||||
|
||||
torch.distributed.barrier()
|
||||
error = loss_parallel.sub(loss_original).abs()
|
||||
print(' error in loss (parallel) on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-12, 'error: {}'.format(error)
|
||||
|
||||
torch.distributed.barrier()
|
||||
error = loss_vocab_parallel.sub(loss_original).abs()
|
||||
print(' error in loss (vocab parallel) on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-12, 'error: {}'.format(error)
|
||||
|
||||
weight_grad_orig = torch.split(embedding_original.weight.grad,
|
||||
hidden_size // model_parallel_size,
|
||||
1)[mpu.get_model_parallel_rank()]
|
||||
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
|
||||
print(' error in grad (parallel) on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-12, 'error: {}'.format(error)
|
||||
|
||||
weight_grad_orig = torch.split(embedding_original.weight.grad,
|
||||
vocab_size // model_parallel_size,
|
||||
0)[mpu.get_model_parallel_rank()]
|
||||
error = embedding_vocab_parallel.weight.grad.sub(
|
||||
weight_grad_orig).abs().max()
|
||||
print(' error in grad (vocab parallel) on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-12, 'error: {}'.format(error)
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
def test_initialize_affine_weight(model_parallel_size):
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing initialize_affine_weight with model parallel '
|
||||
'size: {}'.format(model_parallel_size))
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
seed = 12345
|
||||
input_size_coeff = 13
|
||||
input_size = input_size_coeff * model_parallel_size
|
||||
output_size_coeff = 17
|
||||
output_size = output_size_coeff * model_parallel_size
|
||||
|
||||
# ---------------
|
||||
# Column parallel
|
||||
# ---------------
|
||||
weight = torch.empty(output_size_coeff, input_size)
|
||||
set_random_seed(seed)
|
||||
layers._initialize_affine_weight(weight, output_size, input_size,
|
||||
|
||||
output_size_coeff, 0,
|
||||
torch.nn.init.normal_)
|
||||
# Target.
|
||||
set_random_seed(seed)
|
||||
master_weight = torch.empty(output_size, input_size)
|
||||
torch.nn.init.normal_(master_weight)
|
||||
rank = mpu.get_model_parallel_rank()
|
||||
my_weight = torch.split(master_weight, output_size_coeff,
|
||||
dim=0)[rank].contiguous().clone()
|
||||
|
||||
# Compare.
|
||||
error = weight.sub(my_weight).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' column parallel max error (should be zero) on global rank '
|
||||
'{}: {}'.format(torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# ------------
|
||||
# Row parallel
|
||||
# ------------
|
||||
weight = torch.empty(output_size, input_size_coeff)
|
||||
set_random_seed(seed)
|
||||
mpu.layers._initialize_affine_weight(weight, output_size, input_size,
|
||||
input_size_coeff, 1,
|
||||
torch.nn.init.normal_)
|
||||
# Target.
|
||||
set_random_seed(seed)
|
||||
master_weight = torch.empty(output_size, input_size)
|
||||
torch.nn.init.normal_(master_weight)
|
||||
rank = mpu.get_model_parallel_rank()
|
||||
my_weight = torch.split(master_weight, input_size_coeff,
|
||||
dim=1)[rank].contiguous().clone()
|
||||
|
||||
# Compare.
|
||||
error = weight.sub(my_weight).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' row parallel max error (should be zero) on global rank '
|
||||
'{}: {}'.format(torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(' >> passed the test :-)')
|
||||
|
||||
|
||||
class IdentityLayer2D(torch.nn.Module):
    """Module returning a learnable (m, n) weight from its forward pass.

    Test helper: the Xavier-initialized weight acts as the "input" whose
    gradient can be read back after backprop.
    """

    def __init__(self, m, n):
        super(IdentityLayer2D, self).__init__()
        self.weight = Parameter(torch.empty(m, n))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight
|
||||
|
||||
|
||||
def test_column_parallel_linear(model_parallel_size):
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing ColumnParallelLinear with model parallel '
|
||||
'size: {}'.format(model_parallel_size))
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
seed = 12345
|
||||
set_random_seed(seed)
|
||||
input_size_coeff = 13
|
||||
input_size = input_size_coeff * model_parallel_size
|
||||
output_size_coeff = 17
|
||||
output_size = output_size_coeff * model_parallel_size
|
||||
batch_size = 7
|
||||
|
||||
# Network
|
||||
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
|
||||
linear_layer = mpu.ColumnParallelLinear(
|
||||
input_size, output_size, keep_master_weight_for_test=True).cuda()
|
||||
loss_weight = torch.randn([batch_size, output_size]).cuda()
|
||||
# Forward
|
||||
input_ = identity_layer()
|
||||
output = linear_layer(input_)
|
||||
loss = torch.mul(output, loss_weight).sum()
|
||||
# Backward
|
||||
loss.backward()
|
||||
|
||||
# Values.
|
||||
dLdY = loss_weight
|
||||
X = identity_layer.weight
|
||||
A = linear_layer.master_weight.cuda()
|
||||
dLdA = torch.matmul(dLdY.t(), X)
|
||||
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
|
||||
dLdX = torch.matmul(dLdY, A)
|
||||
|
||||
rank = mpu.get_model_parallel_rank()
|
||||
my_dLdA = torch.split(dLdA, output_size_coeff,
|
||||
dim=0)[rank].contiguous().clone()
|
||||
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' error in dLdA on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
my_dLdb = torch.split(dLdb, output_size_coeff,
|
||||
dim=0)[rank].contiguous().clone()
|
||||
error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' error in dLdb on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
error = dLdX.sub(identity_layer.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' error in dLdX on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(' >> passed the test :-)')
|
||||
|
||||
|
||||
def test_row_parallel_linear(model_parallel_size):
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing RowParallelLinear with model parallel '
|
||||
'size: {}'.format(model_parallel_size))
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
seed = 12345
|
||||
set_random_seed(seed)
|
||||
input_size_coeff = 13
|
||||
input_size = input_size_coeff * model_parallel_size
|
||||
output_size_coeff = 17
|
||||
output_size = output_size_coeff * model_parallel_size
|
||||
batch_size = 7
|
||||
|
||||
# Network
|
||||
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
|
||||
linear_layer = mpu.RowParallelLinear(
|
||||
input_size, output_size, keep_master_weight_for_test=True).cuda()
|
||||
loss_weight = torch.randn([batch_size, output_size]).cuda()
|
||||
# Forward
|
||||
input_ = identity_layer()
|
||||
output = linear_layer(input_)
|
||||
loss = torch.mul(output, loss_weight).sum()
|
||||
# Backward
|
||||
loss.backward()
|
||||
|
||||
# Values.
|
||||
dLdY = loss_weight
|
||||
X = identity_layer.weight
|
||||
A = linear_layer.master_weight.cuda()
|
||||
dLdA = torch.matmul(dLdY.t(), X)
|
||||
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
|
||||
dLdX = torch.matmul(dLdY, A)
|
||||
|
||||
rank = mpu.get_model_parallel_rank()
|
||||
my_dLdA = torch.split(dLdA, input_size_coeff,
|
||||
dim=1)[rank].contiguous().clone()
|
||||
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' error in dLdA on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
error = dLdb.sub(linear_layer.bias.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' error in dLdb on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
error = dLdX.sub(identity_layer.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' error in dLdX on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(' >> passed the test :-)')
|
||||
|
||||
|
||||
class IdentityLayer3D(torch.nn.Module):
    """Module returning a learnable (m, n, k) weight from its forward pass.

    Test helper: stands in for a (batch, seq, hidden) activation tensor
    whose gradient is inspectable after backprop.
    """

    def __init__(self, m, n, k):
        super(IdentityLayer3D, self).__init__()
        self.weight = Parameter(torch.empty(m, n, k))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight
|
||||
|
||||
|
||||
def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
|
||||
hidden_size_per_att_head, dropout_prob, batch_size,
|
||||
sequence_length):
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
seed = 12345
|
||||
set_random_seed(seed)
|
||||
|
||||
num_att_heads = num_att_heads_per_partition * \
|
||||
torch.distributed.get_world_size()
|
||||
hidden_size = hidden_size_per_att_head * num_att_heads
|
||||
|
||||
# Network
|
||||
identity_layer = IdentityLayer3D(batch_size, sequence_length,
|
||||
hidden_size).cuda()
|
||||
attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
|
||||
dropout_prob).cuda()
|
||||
loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
|
||||
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
|
||||
# Forward
|
||||
input_ = identity_layer()
|
||||
output = attention_layer(input_, attention_mask)
|
||||
loss = torch.mul(output, loss_weight).sum()
|
||||
# Backward
|
||||
loss.backward()
|
||||
|
||||
rank = mpu.get_model_parallel_rank()
|
||||
mpu.destroy_model_parallel()
|
||||
return rank, hidden_size, model_parallel_size, loss, \
|
||||
attention_layer, identity_layer
|
||||
|
||||
|
||||
def test_parallel_self_attention(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing ParallelSelfAttention with model parallel '
|
||||
'size: {}'.format(model_parallel_size))
|
||||
|
||||
num_att_heads_per_partition = 3
|
||||
hidden_size_per_att_head = 7
|
||||
dropout_prob = 0.0 # has to be zero
|
||||
batch_size = 5
|
||||
sequence_length = 13
|
||||
|
||||
rank_1, hideen_size_1, model_parallel_size_1, loss_1, \
|
||||
attention_layer_1, identity_layer_1 =parallel_self_attention(
|
||||
1, num_att_heads_per_partition,
|
||||
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
|
||||
|
||||
rank, hidden_size, model_parallel_size, loss, \
|
||||
attention_layer, identity_layer =parallel_self_attention(
|
||||
model_parallel_size, num_att_heads_per_partition,
|
||||
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
|
||||
assert hideen_size_1 == hidden_size
|
||||
|
||||
error = loss_1.sub(loss).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' loss error on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 5.0e-6
|
||||
|
||||
my_lin_grad_list = torch.split(
|
||||
attention_layer_1.query_key_value.weight.grad,
|
||||
hidden_size // model_parallel_size, 0)[rank::model_parallel_size]
|
||||
my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
|
||||
error = my_lin_grad.sub(
|
||||
attention_layer.query_key_value.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' weight gradient error on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 5.0e-6
|
||||
|
||||
error = identity_layer_1.weight.grad.sub(
|
||||
identity_layer.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' input gradient error on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 5.0e-6
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(' >> passed the test :-)')
|
||||
|
||||
def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):
    """Run one forward/backward pass through a model-parallel BERT
    transformer layer.

    Builds an ``IdentityLayer3D`` input and a
    ``mpu.BertParallelTransformerLayer`` under a fixed seed with dropout
    disabled, so that runs at different model-parallel sizes can be
    compared element-wise by the caller.

    Returns:
        Tuple of ``(rank, hidden_size, model_parallel_size, loss,
        transformer_layer, identity_layer)`` for this process's partition.
    """
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    # Fixed seed so the size-1 reference run and the partitioned run
    # see identical randomness.
    seed = 12345
    set_random_seed(seed)

    # Head count spans the whole world so the total hidden size is the
    # same for every model-parallel size.
    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    # Both dropout probabilities are 0.0 so outputs are deterministic.
    transformer_layer = mpu.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()

    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_model_parallel_rank()
    # Tear the groups down so the next test can re-initialize them.
    mpu.destroy_model_parallel()
    return rank, hidden_size, model_parallel_size, loss, \
        transformer_layer, identity_layer
|
||||
|
||||
def test_parallel_transformer_layer(model_parallel_size):
    """Check that a model-parallel BertParallelTransformerLayer matches a
    single-partition run: the loss and the gradient w.r.t. the input must
    agree within floating-point tolerance."""
    if torch.distributed.get_rank() == 0:
        print('> testing ParallelTransformerLayer with model parallel '
              'size: {}'.format(model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    batch_size = 5
    sequence_length = 13

    # Reference run with model-parallel size 1 (no partitioning).
    rank_1, hidden_size_1, model_parallel_size_1, loss_1, \
        transformer_layer_1, identity_layer_1 = parallel_transformer(
            1, num_att_heads_per_partition,
            hidden_size_per_att_head, batch_size, sequence_length)

    # Partitioned run with the requested model-parallel size.
    rank, hidden_size, model_parallel_size, loss, \
        transformer_layer, identity_layer = parallel_transformer(
            model_parallel_size, num_att_heads_per_partition,
            hidden_size_per_att_head, batch_size, sequence_length)

    # Losses should agree; tolerance is looser (5e-5) than the attention
    # test because the full layer accumulates more float error.
    error = loss_1.sub(loss).abs().max()
    torch.distributed.barrier()
    print(' loss error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-5, 'error: {}'.format(error)

    # Gradients w.r.t. the input should also agree.
    error = identity_layer_1.weight.grad.sub(
        identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' input gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-5, 'error: {}'.format(error)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
||||
|
||||
|
||||
if __name__ == '__main__':

    # Force deterministic cuDNN kernels so serial and model-parallel runs
    # can be compared element-wise.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    # Each test below sweeps model-parallel sizes 1, 2, 4, ... up to the
    # world size.
    print_separator('test initialize affine weight')
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_initialize_affine_weight(model_parallel_size)
        model_parallel_size *= 2

    model_parallel_size = 1
    while model_parallel_size <= world_size:
        print_separator('test parallel embedding')
        test_parallel_embedding(model_parallel_size)
        model_parallel_size *= 2

    print_separator('test column-parallel linear')
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_column_parallel_linear(model_parallel_size)
        model_parallel_size *= 2

    print_separator('test row-parallel linear')
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_row_parallel_linear(model_parallel_size)
        model_parallel_size *= 2

    print_separator('test parallel self-attention')
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_parallel_self_attention(model_parallel_size)
        model_parallel_size *= 2

    print_separator('test parallel transformer')
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_parallel_transformer_layer(model_parallel_size)
        model_parallel_size *= 2
|
||||
@@ -1,207 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
sys.path.append("../..")
|
||||
|
||||
import torch
|
||||
import mpu
|
||||
|
||||
from commons import initialize_distributed
|
||||
from commons import print_separator
|
||||
|
||||
|
||||
def test_set_cuda_rng_state(model_parallel_size):
    """Verify that ``mpu.random._set_cuda_rng_state`` restores the CUDA
    RNG exactly and never mutates the state tensor handed to it."""
    if torch.distributed.get_rank() == 0:
        print('> testing set_rng_state with size {} ...'.
              format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(1234)
    tensor = torch.cuda.FloatTensor(size)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    # The saved state must be untouched while the live state advanced.
    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), max_diff))
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    # Restoring twice proves the saved state can be reused repeatedly.
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(' max error in rng state (should be zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), error))
    assert error == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
||||
|
||||
|
||||
def test_cuda_rng_tracker(model_parallel_size):
    """Verify that the mpu CUDA RNG tracker keeps an independent named
    stream: interleaving the default stream (seed_1) with a forked
    'test' stream (seed_2) must reproduce the exact tensors obtained
    when each seed runs alone."""
    if torch.distributed.get_rank() == 0:
        print('> testing cuda rng tracker with size {} ...'.
              format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    mpu.get_cuda_rng_tracker().add('test', seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    # The two streams must differ from each other ...
    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print(' max diff in generated tensors (should be non-zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6
    # ... and each must match its single-seed reference exactly.
    error = max(result_11.sub(target_11).abs().max(),
                result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
|
||||
|
||||
|
||||
def test_model_parallel_cuda_manual_seed(model_parallel_size):
    """Verify mpu.model_parallel_cuda_manual_seed: the default CUDA seed
    equals the given seed, while the forked model-parallel stream uses
    seed + 2718 + model_parallel_rank (the offset mpu applies so each
    partition draws distinct randomness)."""
    if torch.distributed.get_rank() == 0:
        print('> testing model parallel cuda manual seed with size {} ...'.
              format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    mpu.model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with mpu.get_cuda_rng_tracker().fork():
        assert torch.cuda.initial_seed() == (12345 + 2718 +
                                             mpu.get_model_parallel_rank())

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
|
||||
|
||||
|
||||
if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    # Sweep model-parallel sizes 1, 2, 4, ... up to the world size for
    # each RNG test.
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        print_separator('test set rng state')
        test_set_cuda_rng_state(model_parallel_size)
        model_parallel_size *= 2

    model_parallel_size = 1
    while model_parallel_size <= world_size:
        print_separator('test cuda rng tracker')
        test_cuda_rng_tracker(model_parallel_size)
        model_parallel_size *= 2

    model_parallel_size = 1
    while model_parallel_size <= world_size:
        print_separator('test model parallel cuda manual seed')
        test_model_parallel_cuda_manual_seed(model_parallel_size)
        model_parallel_size *= 2
|
||||
|
||||
@@ -1,21 +1,16 @@
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import numpy as np
|
||||
from modelscope.utils.logger import get_logger
|
||||
import torch
|
||||
|
||||
"""Utilities for logging and serialization"""
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def get_log_constant(user_log):
    """Return the marker prefix for user-generated log lines.

    A truthy *user_log* yields the '[user log]' tag; anything falsy
    yields the empty string.
    """
    if user_log:
        return '[user log]'
    return ''
|
||||
|
||||
def print_rank_0(message):
    """Print and log *message* exactly once per job.

    Under torch.distributed only global rank 0 emits the message; in a
    non-distributed run every (single) process does.
    """
    distributed = torch.distributed.is_initialized()
    if not distributed or torch.distributed.get_rank() == 0:
        print(message, flush=True)
        logger.info(message)
|
||||
|
||||
|
||||
def print_args(args):
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
import functools
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import numpy as np
|
||||
import socket
|
||||
import subprocess
|
||||
import tempfile
|
||||
@@ -11,8 +13,7 @@ from typing import Callable, List, Optional, Tuple
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from torch import distributed as dist
|
||||
from torch._utils import (_flatten_dense_tensors, _take_tensors,
|
||||
_unflatten_dense_tensors)
|
||||
from modelscope.utils.nlp import mpu
|
||||
|
||||
|
||||
def _find_free_port() -> str:
|
||||
@@ -106,6 +107,25 @@ def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
|
||||
dist.init_process_group(backend=backend)
|
||||
|
||||
|
||||
def initialize_distributed(rank, world_size=8, backend='nccl'):
    """Initialize torch.distributed and the mpu model-parallel groups.

    Args:
        rank: Global rank of this process.
        world_size: Total number of processes. Defaults to 8, the value
            previously hard-coded here, so existing callers are
            unaffected.
        backend: torch.distributed backend name (default 'nccl').
    """
    # Manually set the device id: spread ranks round-robin over the
    # locally visible GPUs.
    device = rank % torch.cuda.device_count()
    torch.cuda.set_device(device)
    # Build the TCP init method from the usual launcher env variables.
    master_ip = os.getenv('MASTER_ADDR', '127.0.0.1')
    master_port = os.getenv('MASTER_PORT', '12345')
    init_method = 'tcp://' + master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size, rank=rank,
        init_method=init_method)
    # Set the model-parallel communicators: a single group spanning all
    # ranks, matching the previous hard-coded initialize_model_parallel(8).
    mpu.initialize_model_parallel(world_size)
|
||||
|
||||
|
||||
def get_dist_info() -> Tuple[int, int]:
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
rank = dist.get_rank()
|
||||
@@ -191,3 +211,17 @@ def broadcast(inputs, src):
|
||||
dist.broadcast(inputs_tensor, src)
|
||||
|
||||
return pickle.loads(inputs_tensor.cpu().numpy().tobytes())
|
||||
|
||||
|
||||
def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs with the same value.

    Args:
        seed: A strictly positive integer seed.

    Raises:
        ValueError: If *seed* is None or not strictly positive.
    """
    if seed is None or seed <= 0:
        raise ValueError(f'Random seed should be positive, current seed is {seed}')
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
|
||||
|
||||
|
||||
def set_random_seed_mpu(seed):
    """Seed all base RNGs plus the mpu model-parallel CUDA RNG tracker.

    Delegates validation (positive, non-None seed) to set_random_seed.
    """
    set_random_seed(seed)
    mpu.model_parallel_cuda_manual_seed(seed)
|
||||
|
||||
@@ -8,3 +8,4 @@ seqeval
|
||||
spacy>=2.3.5
|
||||
tokenizers
|
||||
transformers>=4.12.0
|
||||
deepspeed
|
||||
|
||||
Reference in New Issue
Block a user