This commit is contained in:
yzhao
2022-08-25 15:20:14 +08:00
parent dc9050325f
commit 5e3446db4b
21 changed files with 765 additions and 1921 deletions

View File

@@ -24,46 +24,46 @@ import deepspeed
def add_model_config_args(parser):
    """Model arguments.

    Registers the 'model' argument group (architecture hyper-parameters and
    optimizer-placement flags) on *parser*.

    Args:
        parser: an argparse.ArgumentParser to extend.

    Returns:
        The same parser, for chaining.
    """
    group = parser.add_argument_group('model', 'model configuration')

    # Trailing spaces inside the concatenated string literals below are
    # deliberate: adjacent literals are joined with no separator.
    group.add_argument('--pretrained-bert', action='store_true',
                       help='use a pretrained bert-large-uncased model instead '
                       'of initializing from scratch. See '
                       '--tokenizer-model-type to specify which pretrained '
                       'BERT model to use')
    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='dropout probability for attention weights')
    group.add_argument('--num-attention-heads', type=int, default=16,
                       help='num of transformer attention heads')
    group.add_argument('--hidden-size', type=int, default=1024,
                       help='transformer hidden size')
    group.add_argument('--intermediate-size', type=int, default=None,
                       help='transformer embedding dimension for FFN; '
                       'set to 4*`--hidden-size` if it is None')
    group.add_argument('--num-layers', type=int, default=24,
                       help='num decoder layers')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='layer norm epsilon')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='dropout probability for hidden state transformer')
    group.add_argument('--max-position-embeddings', type=int, default=512,
                       help='maximum number of position embeddings to use')
    group.add_argument('--vocab-size', type=int, default=30522,
                       help='vocab size to use for non-character-level '
                       'tokenization. This value will only be used when '
                       'creating a tokenizer')
    group.add_argument('--deep-init', action='store_true',
                       help='initialize bert model similar to gpt2 model. '
                       'Scales initialization of projection layers by a '
                       'factor of 1/sqrt(2N). Necessary to train bert '
                       'models larger than BERT-Large.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value. '
                       'This is added for computational efficiency reasons.')
    group.add_argument('--cpu-optimizer', action='store_true',
                       help='Run optimizer on CPU')
    group.add_argument('--cpu_torch_adam', action='store_true',
                       help='Use Torch Adam as optimizer on CPU.')
    return parser
@@ -71,28 +71,28 @@ def add_model_config_args(parser):
def add_fp16_config_args(parser):
    """Mixed precision arguments.

    Registers the 'fp16' argument group (fp16 execution flags and dynamic
    loss-scaling options) on *parser*.

    Args:
        parser: an argparse.ArgumentParser to extend.

    Returns:
        The same parser, for chaining.
    """
    group = parser.add_argument_group('fp16', 'fp16 configurations')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode')
    group.add_argument('--fp32-embedding', action='store_true',
                       help='embedding in fp32')
    group.add_argument('--fp32-layernorm', action='store_true',
                       help='layer norm in fp32')
    group.add_argument('--fp32-tokentypes', action='store_true',
                       help='embedding token types in fp32')
    group.add_argument('--fp32-allreduce', action='store_true',
                       help='all-reduce in fp32')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic '
                       'loss scaling is used.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale')
    group.add_argument('--min-scale', type=float, default=1,
                       help='Minimum loss scale for dynamic loss scale')
    return parser
@@ -100,91 +100,91 @@ def add_fp16_config_args(parser):
def add_training_args(parser):
    """Training arguments.

    Registers the 'train' argument group: optimization hyper-parameters,
    batch-producer flags, learning-rate schedule, checkpoint save/load
    options, and distributed-training settings.

    Args:
        parser: an argparse.ArgumentParser to extend.

    Returns:
        The same parser, for chaining.
    """
    group = parser.add_argument_group('train', 'training configurations')

    group.add_argument('--batch-size', type=int, default=4,
                       help='Data Loader batch size')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='weight decay coefficient for L2 regularization')
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='checkpoint activation to allow for training '
                       'with larger models and sequences')
    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                       help='chunk size (number of layers) for checkpointing')
    group.add_argument('--deepspeed-activation-checkpointing', action='store_true',
                       help='uses activation checkpointing from deepspeed')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='gradient clipping')
    group.add_argument('--train-iters', type=int, default=1000000,
                       help='total number of iterations to train over all training runs')
    group.add_argument('--log-interval', type=int, default=100,
                       help='report interval')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after this many new iterations.')
    group.add_argument('--seed', type=int, default=1234,
                       help='random seed')

    # Batch producer arguments
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention mask after '
                       'end-of-document token.')

    # Learning rate.
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay LR over,'
                       ' If None defaults to `--train-iters`*`--epochs`')
    group.add_argument('--lr-decay-style', type=str, default='linear',
                       choices=['constant', 'linear', 'cosine', 'exponential'],
                       help='learning rate decay function')
    group.add_argument('--lr', type=float, default=1.0e-4,
                       help='initial learning rate')
    # '%%' is required: argparse applies %-formatting to help text, so a
    # bare '%' raises ValueError when --help is rendered.
    group.add_argument('--warmup', type=float, default=0.01,
                       help='percentage of data to warmup on (.01 = 1%% of all '
                       'training iters). Default 0.01')
    group.add_argument('--batch-warmup', type=float, default=0.01,
                       help='percentage of data to warmup on (.01 = 1%% of all '
                       'training iters). Default 0.01')
    group.add_argument('--length-warmup', type=float, default=0.01,
                       help='percentage of data to warmup on (.01 = 1%% of all '
                       'training iters). Default 0.01')

    # model checkpointing
    group.add_argument('--save', type=str, default=None,
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=2000,
                       help='number of iterations between saves')
    group.add_argument('--no-save-optim', action='store_true',
                       help='Do not save current optimizer.')
    group.add_argument('--no-save-rng', action='store_true',
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Path to a directory containing a model checkpoint.')
    # NOTE(review): type=str with an int default (0) — a command-line value
    # arrives as a string while the default stays an int; confirm downstream
    # comparisons handle both before changing.
    group.add_argument('--load-iteration', type=str, default=0,
                       help='Load iteration of a model checkpoint.')
    group.add_argument('--pre-load', action='store_true',
                       help='Use pre-load instead of deepspeed load.')
    group.add_argument('--no-load-optim', action='store_true',
                       help='Do not load optimizer when loading checkpoint.')
    group.add_argument('--no-load-rng', action='store_true',
                       help='Do not load rng state when loading checkpoint.')
    group.add_argument('--no-load-lr', action='store_true',
                       help='Do not load lr schedule when loading checkpoint.')
    group.add_argument('--finetune', action='store_true',
                       help='Load model for finetuning. Do not load optimizer '
                       'or rng state from checkpoint and set iteration to 0. '
                       'Assumed when loading a release checkpoint.')
    group.add_argument('--resume-dataloader', action='store_true',
                       help='Resume the dataloader when resuming training. '
                       'Does not apply to tfrecords dataloader, try resuming '
                       'with a different seed in this case.')

    # distributed training args
    group.add_argument('--distributed-backend', default='nccl',
                       help='which backend to use for distributed '
                       'training. One of [gloo, nccl]')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher')
    return parser
@@ -192,164 +192,164 @@ def add_training_args(parser):
def add_evaluation_args(parser):
    """Evaluation arguments.

    Registers the 'validation' argument group (evaluation batch/iteration
    settings and evaluation modes) on *parser*.

    Args:
        parser: an argparse.ArgumentParser to extend.

    Returns:
        The same parser, for chaining.
    """
    group = parser.add_argument_group('validation', 'validation configurations')

    group.add_argument('--eval-batch-size', type=int, default=None,
                       help='Data Loader batch size for evaluation datasets. '
                       'Defaults to `--batch-size`')
    group.add_argument('--eval-iters', type=int, default=100,
                       help='number of iterations to run for evaluation '
                       'validation/test for')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='interval between running evaluation on validation set')
    group.add_argument('--eval-seq-length', type=int, default=None,
                       help='Maximum sequence length to process for '
                       'evaluation. Defaults to `--seq-length`')
    group.add_argument('--eval-max-preds-per-seq', type=int, default=None,
                       help='Maximum number of predictions to use for '
                       'evaluation. Defaults to '
                       'math.ceil(`--eval-seq-length`*.15/10)*10')
    group.add_argument('--overlapping-eval', type=int, default=32,
                       help='sliding window for overlapping eval')
    group.add_argument('--cloze-eval', action='store_true',
                       help='Evaluation dataset from `--valid-data` is a cloze task')
    group.add_argument('--eval-hf', action='store_true',
                       help='perform evaluation with huggingface openai model. '
                       'Use `--load` to specify weights path to be loaded')
    group.add_argument('--load-openai', action='store_true',
                       help='load openai weights into our model. Use `--load` '
                       'to specify weights path to be loaded')
    return parser
def add_text_generate_args(parser):
    """Text generate arguments.

    Registers the 'Text generation' argument group (sampling controls and
    output length) on *parser*.

    Args:
        parser: an argparse.ArgumentParser to extend.

    Returns:
        The same parser, for chaining.
    """
    group = parser.add_argument_group('Text generation', 'configurations')
    group.add_argument("--temperature", type=float, default=1.0)
    group.add_argument("--top_p", type=float, default=0.0)
    group.add_argument("--top_k", type=int, default=0)
    group.add_argument("--out-seq-length", type=int, default=256)
    return parser
def add_struct_args(parser):
    """Struct-BERT arguments.

    Registers the 'struct' argument group on *parser*.

    Args:
        parser: an argparse.ArgumentParser to extend.

    Returns:
        The same parser, for chaining.
    """
    group = parser.add_argument_group('struct', 'struct configurations')
    group.add_argument("--gradient-accumulation-steps", type=int, default=1,
                       help='Not implemented yet.')
    group.add_argument("--num-epochs", type=int, default=1,
                       help='Not implemented yet.')
    group.add_argument("--struct-bert-dataset", action='store_true', default=False,
                       help='Use struct bert dataset or not.')
    return parser
def add_palm_args(parser):
    """PALM arguments.

    Registers the 'palm' argument group (decoder, VAE and dataset options)
    on *parser*.

    Args:
        parser: an argparse.ArgumentParser to extend.

    Returns:
        The same parser, for chaining.
    """
    # Fixed copy-pasted group description ('struct configurations') and the
    # help texts duplicated verbatim from other arguments.
    group = parser.add_argument_group('palm', 'palm configurations')
    group.add_argument('--dec-layers', type=int, default=6,
                       help='num decoder layers')
    group.add_argument('--tgt-length', type=int, default=100,
                       help='maximum target sequence length')
    group.add_argument('--vae-size', type=int, default=8192,
                       help='vae code vocab size')
    group.add_argument('--max-image-position', type=int, default=1025,
                       help='max image decode position')
    group.add_argument("--palm-dataset", action='store_true', default=False,
                       help='Use palm dataset or not.')
    group.add_argument("--image-dataset", action='store_true', default=False,
                       help='Use image dataset or not.')
    group.add_argument("--do-mask-lm", action='store_true', default=False,
                       help='Do mask lm task or not.')
    group.add_argument('--vae-enc-model', type=str, default=None,
                       help='Path to a directory containing a model checkpoint.')
    return parser
def add_downstream_args(parser):
    """Downstream-task arguments.

    Registers the 'downstream' argument group on *parser*.

    Args:
        parser: an argparse.ArgumentParser to extend.

    Returns:
        The same parser, for chaining.
    """
    # Fixed copy-pasted group description and help text inherited from
    # add_struct_args.
    group = parser.add_argument_group('downstream', 'downstream configurations')
    group.add_argument("--downstream-dataset", action='store_true', default=False,
                       help='Use downstream dataset or not.')
    group.add_argument("--task-name", default='ocnli', type=str)
    return parser
def add_data_args(parser):
    """Train/valid/test data arguments.

    Registers the 'data' argument group (data sources, loaders, splits and
    tokenizer options) on *parser*.

    Args:
        parser: an argparse.ArgumentParser to extend.

    Returns:
        The same parser, for chaining.
    """
    group = parser.add_argument_group('data', 'data configurations')

    group.add_argument('--model-parallel-size', type=int, default=1,
                       help='size of the model parallel.')
    group.add_argument('--shuffle', action='store_true',
                       help='Shuffle data. Shuffling is deterministic '
                       'based on seed and current epoch.')
    group.add_argument('--train-data', nargs='+', default=None,
                       help='Whitespace separated filenames or corpora names '
                       'for training.')

    group.add_argument('--use-npy-data-loader', action='store_true',
                       help='Use the numpy data loader. If set, then '
                       'train-data-path, val-data-path, and test-data-path '
                       'should also be provided.')
    group.add_argument('--train-data-path', type=str, default='',
                       help='path to the training data')
    group.add_argument('--val-data-path', type=str, default='',
                       help='path to the validation data')
    group.add_argument('--test-data-path', type=str, default='',
                       help='path to the test data')
    group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
                       help='the filename containing all the shards sizes')

    group.add_argument('--delim', default=',',
                       help='delimiter used to parse csv data files')
    group.add_argument('--text-key', default='sentence',
                       help='key to use to extract text from json/csv')
    group.add_argument('--eval-text-key', default=None,
                       help='key to use to extract text from '
                       'json/csv evaluation datasets')
    group.add_argument('--valid-data', nargs='*', default=None,
                       help='Filename for validation data.')
    group.add_argument('--split', default='1000,1,1',
                       help='comma-separated list of proportions for training,'
                       ' validation, and test split')
    group.add_argument('--test-data', nargs='*', default=None,
                       help='Filename for testing')

    group.add_argument('--lazy-loader', action='store_true',
                       help='whether to lazy read the data set')
    group.add_argument('--loose-json', action='store_true',
                       help='Use loose json (one json-formatted string per '
                       'newline), instead of tight json (data file is one '
                       'json string)')
    group.add_argument('--presplit-sentences', action='store_true',
                       help='Dataset content consists of documents where '
                       'each document consists of newline separated sentences')
    group.add_argument('--num-workers', type=int, default=2,
                       help='Number of workers to use for dataloading')
    group.add_argument('--tokenizer-model-type', type=str,
                       default='bert-large-uncased',
                       help="Model type to use for sentencepiece tokenization "
                       "(one of ['bpe', 'char', 'unigram', 'word']) or "
                       "bert vocab to use for BertWordPieceTokenizer (one of "
                       "['bert-large-uncased', 'bert-large-cased', etc.])")
    group.add_argument('--tokenizer-path', type=str, default='tokenizer.model',
                       help='path used to save/load sentencepiece tokenization '
                       'models')
    group.add_argument('--tokenizer-type', type=str,
                       default='BertWordPieceTokenizer',
                       choices=['CharacterLevelTokenizer',
                                'SentencePieceTokenizer',
                                'BertWordPieceTokenizer',
                                'GPT2BPETokenizer'],
                       help='what type of tokenizer to use')
    group.add_argument("--cache-dir", default=None, type=str,
                       help="Where to store pre-trained BERT downloads")
    group.add_argument('--use-tfrecords', action='store_true',
                       help='load `--train-data`, `--valid-data`, '
                       '`--test-data` from BERT tf records instead of '
                       'normal data pipeline')
    group.add_argument('--seq-length', type=int, default=512,
                       help="Maximum sequence length to process")
    group.add_argument('--max-preds-per-seq', type=int, default=None,
                       help='Maximum number of predictions to use per sequence. '
                       'Defaults to math.ceil(`--seq-length`*.15/10)*10. '
                       'MUST BE SPECIFIED IF `--use-tfrecords` is True.')
    return parser

View File

@@ -0,0 +1,53 @@
{
"hidden_size": 8192,
"intermediate_size": 32768,
"num_hidden_layers": 1,
"dec_hidden_layers": 1,
"num_attention_heads": 128,
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"type_vocab_size": 3,
"vocab_size": 21504,
"original_vocab_size": 21128,
"max_position_embeddings": 2048,
"lr_decay_style": "linear",
"lr": 3e-5,
"weight_decay": 1e-2,
"clip_grad": 1.0,
"warmup": 0.0333,
"layernorm_epsilon": 1e-5,
"layer_norm_eps": 1e-5,
"fp32_embedding": false,
"fp32_tokentypes": false,
"fp32_layernorm": true,
"fp16": true,
"pruning_method": "pest_block",
"pruning_initial_threshold": 0.5,
"pruning_final_threshold": 0.01,
"pruning_initial_warmup": 1,
"pruning_final_warmup": 9,
"pruning_module": "decoder",
"pruning_decay_step": 76,
"pruning_decay_type": "exp",
"pruning_mask_init": "constant",
"pruning_mask_scale": 0.0,
"LR_weight_rank": 16,
"LR_mask_rank": 16,
"LR_weight_alpha": 32,
"LR_mask_alpha": 16,
"attn_separate": false,
"ft_module": "0",
"model_type": "plug_nlg_original-8192",
"load_iteration": 28000,
"distributed_backend": "nccl",
"deep_init": true,
"deepspeed": true,
"deepspeed_config": "ds_zero-offload_10B_config.json",
"nni": true,
"cpu_optimizer": true,
"no_load_optim": true,
"pre_load": true,
"pre_ln": true,
"attention-dropout": 0.1
}

View File

@@ -0,0 +1,69 @@
{
"framework": "pytorch",
"task": "text-generation",
"preprocessor": {
"type": "text-gen-tokenizer"
},
"model": {
"type": "plug",
"world_size": 8,
"model_parallel_size": 8,
"pre_load": true,
"distributed_backend": "nccl",
"checkpoint_activations": true,
"top_k": 20,
"top_p": 0.0,
"temperature": 0.9,
"seed": 42,
"output_sequence_length": 128
},
"pipeline": {
"type": "text-generation"
},
"train": {
"work_dir": "/tmp",
"max_epochs": 3,
"dataloader": {
"batch_size_per_gpu": 2,
"workers_per_gpu": 1
},
"optimizer": {
"type": "SGD",
"lr": 0.01,
"options": {
"grad_clip": {
"max_norm": 2.0
}
}
},
"lr_scheduler": {
"type": "StepLR",
"step_size": 2,
"options": {
"warmup": {
"type": "LinearWarmup",
"warmup_iters": 2
}
}
},
"hooks": [{
"type": "CheckpointHook",
"interval": 1
}, {
"type": "TextLoggerHook",
"interval": 1
}, {
"type": "IterTimerHook"
}, {
"type": "EvaluationHook",
"interval": 1
}]
},
"evaluation": {
"dataloader": {
"batch_size_per_gpu": 2,
"workers_per_gpu": 1,
"shuffle": false
}
}
}

View File

@@ -18,8 +18,6 @@ import json
import copy
""" BERT model configuration """
from collections import OrderedDict
from typing import Mapping
from transformers import PretrainedConfig
from modelscope.utils import logger as logging
@@ -94,11 +92,12 @@ class PlugNLUConfig(PretrainedConfig):
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type="plugNLU"
model_type = "plugNLU"
def __init__(
self,
vocab_size=21504,
original_vocab_size=21128,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
@@ -141,6 +140,7 @@ class PlugNLUConfig(PretrainedConfig):
super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)
self.vocab_size = vocab_size
self.original_vocab_size = original_vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads

View File

@@ -1,188 +1,136 @@
import random
import torch
import numpy as np
import torch.nn.functional as F
import os
from typing import Dict
from . import PlugModel
import torch
import torch.nn.functional as F
from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.utils.nlp import mpu
from modelscope.utils.nlp.utils import print_rank_0
from modelscope.utils.nlp.fp16 import FP16_Module
from modelscope.utils.nlp.distributed import DistributedDataParallel as DDP
from modelscope.utils.nlp.load_checkpoint import pre_load
from modelscope.utils.nlp.utils import print_rank_0
from modelscope.utils.torch_utils import set_random_seed_mpu
from modelscope.utils.logger import get_logger
from . import PlugModel
from .configuration_plug import PlugNLGConfig
from modelscope.models.nlp.utils.distributed import initialize_distributed
import os
from modelscope.utils.torch_utils import init_dist
def initialize_distributed(rank):
"""Initialize torch.distributed."""
# Manually set the device ids.
#torch.multiprocessing.set_start_method("spawn")
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', '127.0.0.1')
master_port = os.getenv('MASTER_PORT', '12345')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend="nccl",
world_size=8, rank=rank,
init_method=init_method)
# Set the model-parallel communicators.
mpu.initialize_model_parallel(8)
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
#convert to 1D
logits=logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value
#going back to 2D
logits=logits.view(1, -1).contiguous()
return logits
logger = get_logger(__name__)
class DistributedPlug(TorchModel):
class DistributedPlug:
@classmethod
def init(cls, rank, model_dir, model_config, args):
#def init(cls, rank):
#torch.backends.cudnn.enabled = False
#
cls.rank = rank
cls.args = args
cls.config = model_config
cls.model_dir = model_dir
initialize_distributed(rank)
cls.set_random_seed(cls, args.seed)
cls.setup_model(cls, path_load_tag='model')
def __init__(self, model_dir, rank, **kwargs):
super().__init__(model_dir, **kwargs)
self.rank = rank
self.model_cfg = kwargs
self.config = PlugNLGConfig.from_pretrained(model_dir)
initialize_distributed(rank, mpu, kwargs['world_size'], kwargs['model_parallel_size'])
if 'seed' in kwargs:
set_random_seed_mpu(kwargs['seed'])
self.iteration = 0
self.dist_model = self.initialize_model(path_load_tag='model')
def set_random_seed(cls, seed):
if seed is not None and seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def get_model(cls):
def initialize_model(self, path_load_tag='model'):
"""Build the model."""
print_rank_0('Building Plug model. It will take a few minutes ...')
model = PlugModel(cls.config)
model = PlugModel(self.config)
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on model parallel rank {}: {}'.format(
logger.info(' > number of parameters on model parallel rank {}: {}'.format(
mpu.get_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True)
sum([p.nelement() for p in model.parameters()])))
if cls.args.deepspeed and cls.args.fp16:
model.half()
if self.config.deepspeed and self.args.fp16:
model.half()
# GPU allocation.
model.cuda(torch.cuda.current_device())
# Fp16 conversion.
if cls.args.fp16:
if self.config.fp16:
model = FP16_Module(model)
if cls.args.fp32_embedding:
if self.config.fp32_embedding:
model.module.model.bert.embeddings.word_embeddings.float()
model.module.model.bert.embeddings.position_embeddings.float()
model.module.model.bert.embeddings.token_type_embeddings.float()
if cls.args.fp32_tokentypes:
if self.config.fp32_tokentypes:
model.module.model.bert.embeddings.token_type_embeddings.float()
if cls.args.fp32_layernorm:
if self.config.fp32_layernorm:
for name, _module in model.named_modules():
if 'LayerNorm' in name:
_module.float()
# model = DDP(model)
load_model = pre_load(mpu, self.model_dir, tag=path_load_tag)
model_dict = model.module.model.state_dict()
for key in load_model:
if key not in model_dict.keys():
print_rank_0('Skip key: ' + key)
else:
print_rank_0('Loading key: ' + key)
model.module.model.load_state_dict(load_model, strict=False)
return model
def setup_model(cls, path_load_tag='model'):
dist_model = cls.get_model(cls)
if cls.model_dir is not None:
from modelscope.utils.nlp.load_checkpoint import pre_load
load_model = pre_load(mpu, cls.model_dir, tag=path_load_tag)
# model_dict = dist_model.module.module.model.state_dict()
model_dict = dist_model.module.model.state_dict()
for key in load_model:
if key not in model_dict.keys():
print_rank_0('Skip key: '+key)
else:
print_rank_0('Loading key: '+key)
# dist_model.module.module.model.load_state_dict(load_model, strict=False)
dist_model.module.model.load_state_dict(load_model, strict=False)
cls.args.iteration = 0
cls.dist_model = dist_model
@staticmethod
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-
# conversational-ai-with-transfer-learning-2d818ac26313
@classmethod
def forward(cls, input:Dict[str, Tensor]):
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
# convert to 1D
logits = logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value
# going back to 2D
logits = logits.view(1, -1).contiguous()
return logits
def forward(self, input: Dict[str, Tensor]):
device = torch.cuda.current_device()
batch_size = input["input_ids"].shape[0]
tokens = input["input_ids"].to(device)
dec_input_ids = input["dec_input_ids"].to(device)
attention_mask = input["attention_mask"].to(device)
cls.dist_model.eval()
seq_length = 128
self.dist_model.eval()
seq_length = self.model_cfg['output_sequence_length']
with torch.no_grad():
all_generate_tokens = []
generate_tokens = []
counter = 0
sequence_output = None
vocab_size = 21128
#tokens, attention_mask, types, dec_input_ids = get_batch(context_tokens_tensor, device, args)
vocab_size = self.config.original_vocab_size
while counter < seq_length:
# if counter % 128 == 0 and counter != 0:
# generate_tokens.append(tokenizer.vocab[args.sep_token])
# start = (context_tokens_tensor == 102).nonzero(as_tuple=True)[-1]
# if start + len(generate_tokens) >= 512:
# context_tokens_tensor = torch.cat([context_tokens_tensor[:start], torch.cuda.LongTensor(generate_tokens)], -1)[-512:]
# else:
# context_tokens_tensor[start:start+len(generate_tokens)] = torch.cuda.LongTensor(generate_tokens)
# tokens, attention_mask, types, dec_input_ids = get_batch(context_tokens_tensor, device, args)
# generate_tokens = []
# sequence_output = None
position_ids = torch.full([cls.args.batch_size, 1], len(generate_tokens), dtype=torch.long, device=device)
_, logits, sequence_output = cls.dist_model(tokens, None, attention_mask, dec_input_ids, attention_mask, position_ids, is_infer=True, sequence_output=sequence_output, parallel_output=False)
partition_vocab_size = logits.size()[-1]
position_ids = torch.full([batch_size, 1], len(generate_tokens),
dtype=torch.long, device=device)
_, logits, sequence_output = self.dist_model(tokens, None, attention_mask, dec_input_ids,
attention_mask, position_ids, is_infer=True,
sequence_output=sequence_output, parallel_output=False)
logits = logits[:, -1, :]
logits = logits / cls.args.temperature
logits = top_k_logits(logits, top_k=cls.args.top_k, top_p=cls.args.top_p)
logits = logits / self.model_cfg['temperature']
logits = self.top_k_logits(logits, top_k=self.model_cfg['top_k'], top_p=self.model_cfg['top_p'])
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1)
prev_token = prev[0].item()
if prev_token >= vocab_size: #or prev_token == 102:
if prev_token >= vocab_size:
prev_token = 100
prev[0] = 100
# if prev_token == 102 and len(all_generate_tokens) > int(max(1, length) * 0.8):
if prev_token == 102:
break
#if prev_token == 102:
# counter += 1
# continue
#if prev_token == 100:
# counter += 1
# continue
dec_input_ids = torch.cat([dec_input_ids, prev], dim=1)
generate_tokens.append(prev_token)
all_generate_tokens.append(prev_token)

View File

@@ -18,113 +18,21 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import math
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from deepspeed.utils.timer import SynchronizedWallClockTimer
from modelscope.models.nlp.utils.distributed import normal_init_method, scaled_init_method
from torch import nn
from .configuration_plug import PlugNLUConfig, PlugNLGConfig
from ....utils.nlp import mpu#, cached_path
import copy
from deepspeed.utils.timer import SynchronizedWallClockTimer
def normal_init_method(mean, std):
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=mean, std=std)
return init_
def scaled_init_method(mean, std, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = std / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=mean, std=std)
return init_
from ....utils.nlp import mpu # , cached_path
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_bert(model, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
for name, array in zip(names, arrays):
name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(n in ["adam_v", "adam_m"] for n in name):
print("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
l = re.split(r'_(\d+)', m_name)
else:
l = [m_name]
if l[0] == 'kernel' or l[0] == 'gamma':
pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias' or l[0] == 'beta':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
else:
pointer = getattr(pointer, l[0])
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
array = np.transpose(array)
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model
def gelu(x):
"""Implementation of the gelu activation function.
@@ -140,6 +48,7 @@ def swish(x):
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
@@ -155,9 +64,11 @@ class BertLayerNorm(nn.Module):
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(BertEmbeddings, self).__init__()
self.word_embeddings = mpu.VocabParallelEmbedding(
@@ -201,7 +112,7 @@ class BertEmbeddings(nn.Module):
else:
embeddings = embeddings.type(previous_type)
else:
embeddings = words_embeddings.float() + position_embeddings.float() + token_type_embeddings.float()
embeddings = words_embeddings.float() + position_embeddings.float() + token_type_embeddings.float()
if self.fp32_tokentypes and not self.fp32_layernorm:
embeddings = embeddings.half()
previous_type = embeddings.type()
@@ -234,7 +145,9 @@ class BertSelfOutput(nn.Module):
input_is_parallel=True,
stride=1,
init_method=init_method,
pruning_method=config.pruning_method if config.pruning_module in ['all', 'encoder', 'encoder_self', 'encoder_selfvo', 'encoder_selfo'] else None,
pruning_method=config.pruning_method if config.pruning_module in ['all', 'encoder', 'encoder_self',
'encoder_selfvo',
'encoder_selfo'] else None,
pruning_mask_init=config.pruning_mask_init,
pruning_mask_scale=config.pruning_mask_scale,
LR_weight_rank=config.LR_weight_rank,
@@ -246,8 +159,8 @@ class BertSelfOutput(nn.Module):
self.LayerNorm = None
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor, pruning_threshold=None,):
hidden_states = self.dense(hidden_states, pruning_threshold=pruning_threshold,)
def forward(self, hidden_states, input_tensor, pruning_threshold=None, ):
hidden_states = self.dense(hidden_states, pruning_threshold=pruning_threshold, )
hidden_states = self.dropout(hidden_states)
ln_input = hidden_states + input_tensor
if self.LayerNorm is not None:
@@ -261,6 +174,7 @@ class BertSelfOutput(nn.Module):
hidden_states = ln_input
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
@@ -285,7 +199,7 @@ class BertAttention(nn.Module):
LR_mask_rank=config.LR_mask_rank)
self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask, pruning_threshold=None,):
def forward(self, input_tensor, attention_mask, pruning_threshold=None, ):
if self.LayerNorm is not None:
ln_input = input_tensor
previous_type = input_tensor.type()
@@ -294,15 +208,15 @@ class BertAttention(nn.Module):
ln_output = self.LayerNorm(ln_input)
if self.fp32_layernorm:
ln_output = ln_output.type(previous_type)
self_output = self.self(ln_output, attention_mask, pruning_threshold=pruning_threshold,)
self_output = self.self(ln_output, attention_mask, pruning_threshold=pruning_threshold, )
else:
self_output = self.self(input_tensor, attention_mask, pruning_threshold=pruning_threshold,)
# output_pruning_threshold = 1 - (1 - pruning_threshold)/0.99*0.95
self_output = self.self(input_tensor, attention_mask, pruning_threshold=pruning_threshold, )
output_pruning_threshold = pruning_threshold
attention_output = self.output(self_output, input_tensor, pruning_threshold=output_pruning_threshold,)
attention_output = self.output(self_output, input_tensor, pruning_threshold=output_pruning_threshold, )
return attention_output
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
@@ -313,7 +227,8 @@ class BertIntermediate(nn.Module):
gather_output=False,
stride=1,
init_method=normal_init_method(mean=0.0, std=config.initializer_range),
pruning_method=config.pruning_method if config.pruning_module in ['all', 'encoder', 'encoder_ffn'] else None,
pruning_method=config.pruning_method if config.pruning_module in ['all', 'encoder',
'encoder_ffn'] else None,
pruning_mask_init=config.pruning_mask_init,
pruning_mask_scale=config.pruning_mask_scale,
LR_weight_rank=config.LR_weight_rank,
@@ -321,15 +236,16 @@ class BertIntermediate(nn.Module):
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
def forward(self, hidden_states, pruning_threshold=None,):
hidden_states = self.dense(hidden_states, pruning_threshold=pruning_threshold,)
def forward(self, hidden_states, pruning_threshold=None, ):
hidden_states = self.dense(hidden_states, pruning_threshold=pruning_threshold, )
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
if hasattr(config, 'deep_init') and config.deep_init:
if hasattr(config, 'deep_init') and config.deep_init:
init_method = scaled_init_method(mean=0.0,
std=config.initializer_range,
num_layers=config.num_hidden_layers)
@@ -343,7 +259,8 @@ class BertOutput(nn.Module):
input_is_parallel=True,
stride=1,
init_method=init_method,
pruning_method=config.pruning_method if config.pruning_module in ['all', 'encoder', 'encoder_ffn'] else None,
pruning_method=config.pruning_method if config.pruning_module in ['all', 'encoder',
'encoder_ffn'] else None,
pruning_mask_init=config.pruning_mask_init,
pruning_mask_scale=config.pruning_mask_scale,
LR_weight_rank=config.LR_weight_rank,
@@ -355,11 +272,11 @@ class BertOutput(nn.Module):
self.LayerNorm = None
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor, pruning_threshold=None,):
hidden_states = self.dense(hidden_states, pruning_threshold=pruning_threshold,)
def forward(self, hidden_states, input_tensor, pruning_threshold=None, ):
hidden_states = self.dense(hidden_states, pruning_threshold=pruning_threshold, )
hidden_states = self.dropout(hidden_states)
ln_input = hidden_states + input_tensor
if self.LayerNorm is not None:
if self.LayerNorm is not None:
previous_type = ln_input.type()
if self.fp32_layernorm:
ln_input = ln_input.float()
@@ -370,6 +287,7 @@ class BertOutput(nn.Module):
hidden_states = ln_input
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
@@ -382,7 +300,7 @@ class BertLayer(nn.Module):
else:
self.LayerNorm = None
def forward(self, hidden_states, attention_mask, pruning_threshold=None,):
def forward(self, hidden_states, attention_mask, pruning_threshold=None, ):
attention_output = self.attention(hidden_states, attention_mask, pruning_threshold=pruning_threshold)
if self.LayerNorm is not None:
ln_input = attention_output
@@ -398,6 +316,7 @@ class BertLayer(nn.Module):
layer_output = self.output(intermediate_output, attention_output, pruning_threshold=pruning_threshold)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
@@ -408,8 +327,10 @@ class BertEncoder(nn.Module):
else:
self.LayerNorm = None
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, checkpoint_activations=False, detach_index=-1, pruning_threshold=None,):
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, checkpoint_activations=False,
detach_index=-1, pruning_threshold=None, ):
all_encoder_layers = []
def custom(start, end):
def custom_forward(*inputs):
layers = self.layer[start:end]
@@ -417,20 +338,21 @@ class BertEncoder(nn.Module):
for layer in layers:
x_ = layer(x_, inputs[1], pruning_threshold=pruning_threshold)
return x_
return custom_forward
if checkpoint_activations:
l = 0
num_layers = len(self.layer)
chunk_length = 1 #math.ceil(math.sqrt(num_layers))
chunk_length = 1
while l < num_layers:
hidden_states = mpu.checkpoint(custom(l, l+chunk_length), hidden_states, attention_mask*1)
hidden_states = mpu.checkpoint(custom(l, l + chunk_length), hidden_states, attention_mask * 1)
if detach_index == l:
hidden_states.detach_()
l += chunk_length
# decoder layers
else:
for i,layer_module in enumerate(self.layer):
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask)
if detach_index == i:
hidden_states.detach_()
@@ -455,6 +377,7 @@ class BertEncoder(nn.Module):
all_encoder_layers.append(hidden_states)
return all_encoder_layers
class BertPooler(nn.Module):
def __init__(self, config):
super(BertPooler, self).__init__()
@@ -469,6 +392,7 @@ class BertPooler(nn.Module):
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__()
@@ -489,6 +413,7 @@ class BertPredictionHeadTransform(nn.Module):
hidden_states = hidden_states.type(previous_type)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertLMPredictionHead, self).__init__()
@@ -496,19 +421,18 @@ class BertLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
#self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
# bert_model_embedding_weights.size(0),
# bias=False)
self.decoder_weight = bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
self.bias.model_parallel = True
self.fp32_embedding = config.fp32_embedding
self.fp32_layernorm = config.fp32_layernorm
def convert_to_type(tensor):
if self.fp32_embedding:
return tensor.half()
else:
return tensor
self.type_converter = convert_to_type
self.converted = False
self.timers = SynchronizedWallClockTimer()
@@ -521,14 +445,12 @@ class BertLMPredictionHead(nn.Module):
if self.fp32_layernorm:
self.transform.LayerNorm.float()
hidden_states = self.transform(self.type_converter(hidden_states))
# hidden_states = self.decoder(hidden_states) + self.bias
self.timers('final linear gather').start()
hidden_states = mpu.copy_to_model_parallel_region(hidden_states)
self.timers('final linear gather').stop()
hidden_states = F.linear(self.type_converter(hidden_states),
self.type_converter(self.decoder_weight),
self.type_converter(self.bias))
#self.timers.log(names=['final linear gather'])
return hidden_states
@@ -547,10 +469,12 @@ class BertPreTrainingHeads(nn.Module):
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class PreTrainedBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedBertModel, self).__init__()
if not isinstance(config, PlugNLUConfig) and not isinstance(config, PlugNLGConfig):
@@ -575,119 +499,6 @@ class PreTrainedBertModel(nn.Module):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
#@classmethod
#def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
# fp32_layernorm=False, fp32_embedding=False, layernorm_epsilon=1e-12,
# fp32_tokentypes=False, *inputs, **kwargs):
# """
# Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
# Download and cache the pre-trained model file if needed.
# Params:
# pretrained_model_name: either:
# - a str with the name of a pre-trained model to load selected in the list of:
# . `bert-base-uncased`
# . `bert-large-uncased`
# . `bert-base-cased`
# . `bert-large-cased`
# . `bert-base-multilingual-uncased`
# . `bert-base-multilingual-cased`
# . `bert-base-chinese`
# - a path or url to a pretrained model archive containing:
# . `bert_config.json` a configuration file for the model
# . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
# cache_dir: an optional path to a folder in which the pre-trained models will be cached.
# state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
# *inputs, **kwargs: additional input for the specific Bert class
# (ex: num_labels for BertForSequenceClassification)
# """
# if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
# archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
# else:
# archive_file = pretrained_model_name
# # redirect to the cache, if necessary
# try:
# resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
# except FileNotFoundError:
# logger.error(
# "Model name '{}' was not found in model name list ({}). "
# "We assumed '{}' was a path or url but couldn't find any file "
# "associated to this path or url.".format(
# pretrained_model_name,
# ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
# archive_file))
# return None
# if resolved_archive_file == archive_file:
# logger.info("loading archive file {}".format(archive_file))
# else:
# logger.info("loading archive file {} from cache at {}".format(
# archive_file, resolved_archive_file))
# tempdir = None
# if os.path.isdir(resolved_archive_file):
# serialization_dir = resolved_archive_file
# else:
# # Extract archive to temp dir
# tempdir = tempfile.mkdtemp()
# logger.info("extracting archive file {} to temp dir {}".format(
# resolved_archive_file, tempdir))
# with tarfile.open(resolved_archive_file, 'r:gz') as archive:
# archive.extractall(tempdir)
# serialization_dir = tempdir
# # Load config
# config_file = os.path.join(serialization_dir, CONFIG_NAME)
# config = PlugNLUConfig.from_json_file(config_file)
# config.fp32_layernorm = fp32_layernorm
# config.fp32_embedding = fp32_embedding
# config.layernorm_epsilon = layernorm_epsilon
# config.fp32_tokentypes = fp32_tokentypes
# logger.info("Model config {}".format(config))
# # Instantiate model.
# model = cls(config, *inputs, **kwargs)
# if state_dict is None:
# weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
# state_dict = torch.load(weights_path)
# old_keys = []
# new_keys = []
# for key in state_dict.keys():
# new_key = None
# if 'gamma' in key:
# new_key = key.replace('gamma', 'weight')
# if 'beta' in key:
# new_key = key.replace('beta', 'bias')
# if new_key:
# old_keys.append(key)
# new_keys.append(new_key)
# for old_key, new_key in zip(old_keys, new_keys):
# state_dict[new_key] = state_dict.pop(old_key)
# missing_keys = []
# unexpected_keys = []
# error_msgs = []
# # copy state_dict so _load_from_state_dict can modify it
# metadata = getattr(state_dict, '_metadata', None)
# state_dict = state_dict.copy()
# if metadata is not None:
# state_dict._metadata = metadata
# def load(module, prefix=''):
# local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
# module._load_from_state_dict(
# state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
# for name, child in module._modules.items():
# if child is not None:
# load(child, prefix + name + '.')
# load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
# if len(missing_keys) > 0:
# logger.info("Weights of {} not initialized from pretrained model: {}".format(
# model.__class__.__name__, missing_keys))
# if len(unexpected_keys) > 0:
# logger.info("Weights from pretrained model not used in {}: {}".format(
# model.__class__.__name__, unexpected_keys))
# if tempdir:
# # Clean up temp dir
# shutil.rmtree(tempdir)
# return model
class BertModel(PreTrainedBertModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
@@ -733,6 +544,7 @@ class BertModel(PreTrainedBertModel):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertModel, self).__init__(config)
self.embeddings = BertEmbeddings(config)
@@ -740,7 +552,8 @@ class BertModel(PreTrainedBertModel):
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, checkpoint_activations=False, detach_index=-1, pruning_threshold=None,):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True,
checkpoint_activations=False, detach_index=-1, pruning_threshold=None, ):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
@@ -758,7 +571,8 @@ class BertModel(PreTrainedBertModel):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.encoder.parameters()).dtype) # fp16 compatibility
extended_attention_mask = extended_attention_mask.to(
dtype=next(self.encoder.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids)
@@ -774,7 +588,7 @@ class BertModel(PreTrainedBertModel):
continue
sequence_output = sequence_output.type_as(p)
break
#pooled_output = self.pooler(sequence_output)
pooled_output = sequence_output[:, 0]
if not output_all_encoded_layers or checkpoint_activations:
encoded_layers = encoded_layers[-1]
@@ -784,11 +598,11 @@ class BertModel(PreTrainedBertModel):
class DecodeLayer(nn.Module):
def __init__(self, config):
super(DecodeLayer, self).__init__()
init_method = normal_init_method(mean=0.0,std=config.initializer_range)
init_method = normal_init_method(mean=0.0, std=config.initializer_range)
output_layer_init_method = scaled_init_method(mean=0.0,
std=config.initializer_range,
num_layers=config.num_hidden_layers)
std=config.initializer_range,
num_layers=config.num_hidden_layers)
self_pruning_method = config.pruning_method
cross_pruning_method = config.pruning_method
ffn_pruning_method = config.pruning_method
@@ -808,11 +622,12 @@ class DecodeLayer(nn.Module):
output_dropout_prob=config.hidden_dropout_prob,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
pruning_method=self_pruning_method if config.pruning_module in ['all', 'decoder', 'decoder_self', 'decoder_self+ffn'] else None,
pruning_method=self_pruning_method if config.pruning_module in ['all', 'decoder', 'decoder_self',
'decoder_self+ffn'] else None,
pruning_mask_init=config.pruning_mask_init, pruning_mask_scale=config.pruning_mask_scale,
LR_weight_rank=config.LR_weight_rank,
LR_mask_rank=config.LR_mask_rank,
)
)
self.cross_attention = mpu.PalmParallelCrossAttention(
hidden_size=config.hidden_size,
@@ -821,55 +636,65 @@ class DecodeLayer(nn.Module):
output_dropout_prob=config.hidden_dropout_prob,
init_method=init_method, attn_separate=False,
output_layer_init_method=output_layer_init_method,
pruning_method=cross_pruning_method, pruning_mask_init=config.pruning_mask_init,
pruning_method=cross_pruning_method, pruning_mask_init=config.pruning_mask_init,
pruning_mask_scale=config.pruning_mask_scale, pruning_module=config.pruning_module,
LR_weight_rank=config.LR_weight_rank,
LR_mask_rank=config.LR_mask_rank,)
LR_mask_rank=config.LR_mask_rank, )
self.input_layernorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.post_attention_layernorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.post_cross_attention_layernorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.intermediate = mpu.ColumnParallelLinear(config.hidden_size, config.intermediate_size, gather_output=False, init_method=init_method,
pruning_method=ffn_pruning_method if config.pruning_module in ['all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None,
pruning_mask_init=config.pruning_mask_init, pruning_mask_scale=config.pruning_mask_scale,
LR_weight_rank=config.LR_weight_rank, LR_mask_rank=config.LR_mask_rank,)
self.intermediate = mpu.ColumnParallelLinear(config.hidden_size, config.intermediate_size, gather_output=False,
init_method=init_method,
pruning_method=ffn_pruning_method if config.pruning_module in [
'all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None,
pruning_mask_init=config.pruning_mask_init,
pruning_mask_scale=config.pruning_mask_scale,
LR_weight_rank=config.LR_weight_rank,
LR_mask_rank=config.LR_mask_rank, )
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.output = mpu.RowParallelLinear(config.intermediate_size, config.hidden_size, input_is_parallel=True, init_method=output_layer_init_method,
pruning_method=ffn_pruning_method if config.pruning_module in ['all', 'decoder', 'decoder_ffn', 'decoder_self+ffn'] else None,
pruning_mask_init=config.pruning_mask_init, pruning_mask_scale=config.pruning_mask_scale,
LR_weight_rank=config.LR_weight_rank, LR_mask_rank=config.LR_mask_rank,)
self.output = mpu.RowParallelLinear(config.intermediate_size, config.hidden_size, input_is_parallel=True,
init_method=output_layer_init_method,
pruning_method=ffn_pruning_method if config.pruning_module in ['all',
'decoder',
'decoder_ffn',
'decoder_self+ffn'] else None,
pruning_mask_init=config.pruning_mask_init,
pruning_mask_scale=config.pruning_mask_scale,
LR_weight_rank=config.LR_weight_rank, LR_mask_rank=config.LR_mask_rank, )
self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
self.fp32_layernorm = config.fp32_layernorm
def convert_to_type(tensor):
if self.fp32_layernorm:
return tensor.float()
else:
return tensor
self.type_converter = convert_to_type
#def forward(self, hidden_states, enc_attn_mask, dec_attn_mask):
def forward(self, hidden_states, enc_hidden_states, enc_attn_mask, dec_attn_mask, is_infer=False, pruning_threshold=None):
self.type_converter = convert_to_type
# def forward(self, hidden_states, enc_attn_mask, dec_attn_mask):
def forward(self, hidden_states, enc_hidden_states, enc_attn_mask, dec_attn_mask, is_infer=False,
pruning_threshold=None):
residual = hidden_states
previous_type = hidden_states.type()
hidden_states = self.input_layernorm(self.type_converter(hidden_states))
if self.fp32_layernorm:
hidden_states = hidden_states.type(previous_type)
hidden_states = self.attention(hidden_states, dec_attn_mask, is_infer=is_infer, pruning_threshold=pruning_threshold)
# add dropout?
# hidden_states = self.dropout(hidden_states)
hidden_states = self.attention(hidden_states, dec_attn_mask, is_infer=is_infer,
pruning_threshold=pruning_threshold)
hidden_states = residual + hidden_states
residual = hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(self.type_converter(hidden_states))
if self.fp32_layernorm:
# same to the output of BertAttention
hidden_states = hidden_states.type(previous_type)
hidden_states = self.cross_attention(hidden_states, enc_hidden_states, enc_attn_mask, pruning_threshold=pruning_threshold)
# hidden_states = self.dropout(hidden_states)
hidden_states = self.cross_attention(hidden_states, enc_hidden_states, enc_attn_mask,
pruning_threshold=pruning_threshold)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_cross_attention_layernorm(self.type_converter(hidden_states))
@@ -877,55 +702,61 @@ class DecodeLayer(nn.Module):
hidden_states = hidden_states.type(previous_type)
hidden_states = self.intermediate(hidden_states, pruning_threshold=pruning_threshold)
hidden_states = self.intermediate_act_fn(hidden_states)
# hidden_states = self.dropout(hidden_states)
hidden_states = self.output(hidden_states, pruning_threshold=pruning_threshold)
hidden_states = self.dropout(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class BertDecoder(nn.Module):
def __init__(self, config):
super(BertDecoder, self).__init__()
self.layer = nn.ModuleList([DecodeLayer(config) for _ in range(config.dec_hidden_layers)])
self.final_layernorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.fp32_layernorm = config.fp32_layernorm
def forward(self, hidden_states, enc_hidden_states, enc_attn_mask, dec_attn_mask, checkpoint_activations=False, output_all_encoded_layers=False, is_infer=False, pruning_threshold=None):
all_encoder_layers = []
def forward(self, hidden_states, enc_hidden_states, enc_attn_mask, dec_attn_mask, checkpoint_activations=False,
output_all_encoded_layers=False, is_infer=False, pruning_threshold=None):
def custom(start, end):
def custom_forward(*inputs):
layers = self.layer[start:end]
x_ = inputs[0]
for layer in layers:
x_ = layer(x_, inputs[1], inputs[2], dec_attn_mask*1, is_infer=is_infer, pruning_threshold=pruning_threshold)
x_ = layer(x_, inputs[1], inputs[2], dec_attn_mask * 1, is_infer=is_infer,
pruning_threshold=pruning_threshold)
return x_
return custom_forward
pre_enc_hidden= enc_hidden_states.data
pre_enc_hidden = enc_hidden_states.data
if checkpoint_activations:
l = 0
num_layers = len(self.layer)
chunk_length = 1 #math.ceil(math.sqrt(num_layers))
chunk_length = 1
while l < num_layers:
hidden_states = mpu.checkpoint(custom(l, l+chunk_length), hidden_states, enc_hidden_states, enc_attn_mask*1)
hidden_states = mpu.checkpoint(custom(l, l + chunk_length), hidden_states, enc_hidden_states,
enc_attn_mask * 1)
enc_hidden_states.data = pre_enc_hidden
l += chunk_length
else:
for i,layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, enc_hidden_states, enc_attn_mask, dec_attn_mask, is_infer=is_infer, pruning_threshold=pruning_threshold)
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, enc_hidden_states, enc_attn_mask, dec_attn_mask,
is_infer=is_infer, pruning_threshold=pruning_threshold)
previous_type = hidden_states.type()
if self.fp32_layernorm:
hidden_states = hidden_states.float()
hidden_states = self.final_layernorm(hidden_states)
if self.fp32_layernorm:
hidden_states = hidden_states.type(previous_type)
return [hidden_states]
class DecodeModel(PreTrainedBertModel):
def __init__(self, config):
@@ -933,22 +764,24 @@ class DecodeModel(PreTrainedBertModel):
self.decoder = BertDecoder(config)
self.apply(self.init_bert_weights)
def forward(self, embeddings, sequence_output, decode_input_ids, position_ids=None, enc_attn_mask=None, dec_attn_mask=None, checkpoint_activations=False, is_infer=False, pruning_threshold=None):
def forward(self, embeddings, sequence_output, decode_input_ids, position_ids=None, enc_attn_mask=None,
dec_attn_mask=None, checkpoint_activations=False, is_infer=False, pruning_threshold=None):
extended_attention_mask = enc_attn_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=next(self.decoder.parameters()).dtype) # fp16 compatibility
extended_attention_mask = extended_attention_mask.to(
dtype=next(self.decoder.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = embeddings(decode_input_ids)
sequence_output = self.decoder(embedding_output,
sequence_output,
extended_attention_mask,
dec_attn_mask,
checkpoint_activations=False,
is_infer=is_infer,
pruning_threshold=pruning_threshold)
sequence_output,
extended_attention_mask,
dec_attn_mask,
checkpoint_activations=False,
is_infer=is_infer,
pruning_threshold=pruning_threshold)
return sequence_output[-1]
class PalmForPreTraining(PreTrainedBertModel):
def __init__(self, config):
super(PalmForPreTraining, self).__init__(config)
@@ -957,65 +790,52 @@ class PalmForPreTraining(PreTrainedBertModel):
self.decoder = DecodeModel(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, decode_input_ids=None, position_ids=None, decode_attention_mask=None, lm_labels=None, checkpoint_activations=False, is_infer=False, sequence_output=None, parallel_output=True, pruning_threshold=None):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, decode_input_ids=None, position_ids=None,
decode_attention_mask=None, lm_labels=None, checkpoint_activations=False, is_infer=False,
sequence_output=None, parallel_output=True, pruning_threshold=None):
if sequence_output is None:
sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations, pruning_threshold=pruning_threshold)
output_all_encoded_layers=False,
checkpoint_activations=checkpoint_activations,
pruning_threshold=pruning_threshold)
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
else:
prediction_scores = None
seq_relationship_score = None
sequence_output = sequence_output.to(dtype=next(self.decoder.parameters()).dtype)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
decode_output = self.decoder(self.bert.embeddings, sequence_output, decode_input_ids, position_ids, attention_mask, decode_attention_mask, checkpoint_activations=checkpoint_activations, is_infer=is_infer, pruning_threshold=pruning_threshold)
decode_output = self.decoder(self.bert.embeddings, sequence_output, decode_input_ids, position_ids,
attention_mask, decode_attention_mask,
checkpoint_activations=checkpoint_activations, is_infer=is_infer,
pruning_threshold=pruning_threshold)
#prediction_scores = self.cls(decode_output)
transformer_output_parallel = mpu.copy_to_model_parallel_region(
decode_output)
logits_parallel = F.linear(transformer_output_parallel,
self.bert.embeddings.word_embeddings.weight)
if parallel_output:
return prediction_scores, logits_parallel
if is_infer:
return prediction_scores, mpu.gather_from_model_parallel_region(logits_parallel), sequence_output
return prediction_scores, mpu.gather_from_model_parallel_region(logits_parallel)
class PlugModel(torch.nn.Module):
def __init__(self, config):
super(PlugModel, self).__init__()
if config.intermediate_size is None:
intermediate_size = 4 * config.hidden_size
else:
intermediate_size = config.intermediate_size
self.config = config
# self.config = BertConfig(
# args.tokenizer_num_tokens,
# hidden_size=args.hidden_size,
# num_hidden_layers=args.num_layers,
# num_attention_heads=args.num_attention_heads,
# intermediate_size=intermediate_size,
# hidden_dropout_prob=args.hidden_dropout,
# attention_probs_dropout_prob=args.attention_dropout,
# max_position_embeddings=args.max_position_embeddings,
# type_vocab_size=args.tokenizer_num_type_tokens,
# fp32_layernorm=args.fp32_layernorm,
# fp32_embedding=args.fp32_embedding,
# fp32_tokentypes=args.fp32_tokentypes,
# layernorm_epsilon=args.layernorm_epsilon,
# deep_init=args.deep_init,
# dec_hidden_layers=args.dec_layers)
self.model = PalmForPreTraining(self.config)
def forward(self, input_tokens, token_type_ids=None,
attention_mask=None, target_tokens=None, position_ids=None, decode_attention_mask=None, checkpoint_activations=False, is_infer=False, sequence_output=None, parallel_output=True):
attention_mask=None, target_tokens=None, position_ids=None, decode_attention_mask=None,
checkpoint_activations=False, is_infer=False, sequence_output=None, parallel_output=True):
return self.model(
input_tokens, token_type_ids, attention_mask, target_tokens, position_ids,
decode_attention_mask, checkpoint_activations=checkpoint_activations, is_infer=is_infer, sequence_output=sequence_output, parallel_output=parallel_output)
input_tokens, token_type_ids, attention_mask, target_tokens, position_ids,
decode_attention_mask, checkpoint_activations=checkpoint_activations, is_infer=is_infer,
sequence_output=sequence_output, parallel_output=parallel_output)
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.model.state_dict(destination=destination, prefix=prefix,
@@ -1023,5 +843,3 @@ class PlugModel(torch.nn.Module):
def load_state_dict(self, state_dict, strict=True):
return self.model.load_state_dict(state_dict, strict=strict)

View File

@@ -1,59 +1,35 @@
import torch
from typing import Dict
from functools import partial
from typing import Dict, Any
from . import DistributedPlug
from ...base import Tensor, TorchModel
import torch
from ...builder import MODELS
from ....metainfo import Models
from ....outputs import OutputKeys
from ....utils.constant import Tasks
from modelscope.models.nlp.structbert import SbertTokenizer
from modelscope.models.nlp.utils.distributed import DistributedTorchModel
from . import DistributedPlug
from ...base import Tensor
__all__ = ['PlugForTextGeneration']
__all__ = ['PLUGForTextGeneration']
@MODELS.register_module(Tasks.text_generation, module_name=Models.plug)
class PlugForTextGeneration(TorchModel):
def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
import torch
class PlugForTextGeneration(DistributedTorchModel):
from transformers import BertTokenizer
from multiprocessing import Pool
from .arguments import get_args
from . import PlugNLGConfig
import torch
torch.multiprocessing.set_start_method("spawn")
self.tokenizer = BertTokenizer.from_pretrained(model_dir)
model_config = PlugNLGConfig.from_pretrained(model_dir)
# TODO(suluyan): Arguments
args = get_args()
args.world_size = 8
args.model_parallel_size = 8
args.pre_load = True
args.distributed_backend = 'nccl'
args.fp16 = True
args.fp32_layernorm = True
args.checkpoint_activations = True
args.batch_size = 1
args.top_k = 20
args.top_p = 0.0
args.temperature = 0.9
self.args = args
self.world_size = args.world_size
ranks = list(range(self.world_size))
self.model_pool = Pool(self.world_size)
self.model_pool.map(partial(DistributedPlug.init, model_dir=model_dir, model_config=model_config, args=args), ranks)
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
def _forward_one(self, input: Dict[str, Any]) -> Dict[str, Tensor]:
return self.model(**input)
def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
dec_input_ids = torch.full([self.args.batch_size, 1], self.tokenizer.cls_token_id, dtype=torch.long)
batch_size = input['input_ids'].shape[0]
dec_input_ids = torch.full([batch_size, 1], self.cls_token_id, dtype=torch.long)
input["dec_input_ids"] = dec_input_ids
res = self.model_pool.map(DistributedPlug.forward, [input]*self.world_size)
return res[0]
def _instantiate_one(self, model_dir, rank):
tokenizer = SbertTokenizer.from_pretrained(model_dir)
self.cls_token_id = tokenizer.cls_token_id
self.model = DistributedPlug.instantiate(model_dir, rank)

View File

@@ -0,0 +1,72 @@
from modelscope.models import Model
from multiprocessing import Pool
from functools import partial
from typing import Dict, Any
from modelscope.utils.hub import read_config
from modelscope.utils.torch_utils import init_dist, _is_free_port, _find_free_port
import torch
import os
import math
class DistributedTorchModel(Model):
def __init__(self, model_dir, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
self.model_pool = None
self.world_size = None
@classmethod
def _instantiate(cls, model_dir):
model = DistributedTorchModel(model_dir=model_dir)
torch.multiprocessing.set_start_method("spawn")
cfg = read_config(model_dir)
model.world_size = cfg.model.word_size
ranks = list(range(model.world_size))
model.model_pool = Pool(model.world_size)
model.model_pool.map(partial(model._instantiate_one, model_dir=model_dir), ranks)
return model
def _instantiate_one(self, model_dir, rank):
pass
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
res = self.model_pool.map(self._forward_one, [input]*self.world_size)
return res[0]
def _forward_one(self, input):
pass
def initialize_distributed(rank, mpu, world_size, model_parallel_size):
"""Initialize torch.distributed."""
# Manually set the device ids.
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', '127.0.0.1')
master_port = os.getenv('MASTER_PORT', '29500')
if not _is_free_port(int(master_port)):
master_port = str(_find_free_port())
init_method += master_ip + ':' + master_port
init_dist('pytorch', world_size=world_size, rank=rank, init_method=init_method)
# Set the model-parallel communicators.
mpu.initialize_model_parallel(model_parallel_size)
def normal_init_method(mean, std):
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=mean, std=std)
return init_
def scaled_init_method(mean, std, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = std / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=mean, std=std)
return init_

View File

@@ -10,6 +10,7 @@ from typing import Callable, List, Optional, Tuple, Union
import json
import numpy as np
import torch
from modelscope.utils.torch_utils import set_random_seed
from addict import Dict
from torch import distributed as dist
from torch import nn
@@ -816,6 +817,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed):
# The seed of each worker equals to
# num_worker * rank + worker_id + user_seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
torch.manual_seed(worker_seed)
set_random_seed(worker_seed)

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import torch
from sofa.utils import mpu
from modelscope.utils.nlp import mpu
# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):

View File

@@ -24,9 +24,6 @@ import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .mappings import copy_to_model_parallel_region

View File

@@ -1,82 +0,0 @@
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import random
import numpy
import torch
import mpu
class IdentityLayer(torch.nn.Module):
def __init__(self, size, scale=1.0):
super(IdentityLayer, self).__init__()
self.weight = torch.nn.Parameter(scale * torch.randn(size))
def forward(self):
return self.weight
def set_random_seed(seed):
"""Set random seed for reproducability."""
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def initialize_distributed(backend='nccl'):
"""Initialize torch.distributed."""
# Get local rank in case it is provided.
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher')
args = parser.parse_args()
local_rank = args.local_rank
# Get rank and world size.
rank = int(os.getenv('RANK', '0'))
world_size = int(os.getenv("WORLD_SIZE", '1'))
print('> initializing torch.distributed with local rank: {}, '
'rank: {}, world size: {}'.format(local_rank, rank, world_size))
# Set the device id.
device = rank % torch.cuda.device_count()
if local_rank is not None:
device = local_rank
#torch.cuda.set_device(device)
# Call the init process.
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
init_method=init_method)
def print_separator(message):
torch.distributed.barrier()
filler_len = (78 - len(message)) // 2
filler = '-' * filler_len
string = '\n' + filler + ' {} '.format(message) + filler
if torch.distributed.get_rank() == 0:
print(string, flush=True)
torch.distributed.barrier()

View File

@@ -1,110 +0,0 @@
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import sys
sys.path.append("../..")
import torch
import torch.nn.functional as F
import mpu
from mpu.cross_entropy import vocab_parallel_cross_entropy
from commons import initialize_distributed
from commons import print_separator
from commons import IdentityLayer
from commons import set_random_seed
def torch_cross_entropy(batch_size, seq_length, vocab_size,
logits_scale, seed):
set_random_seed(seed)
identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda()
logits = identity()
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = F.cross_entropy(logits.view(-1, logits.size()[-1]),
target.view(-1),
reduction='none').view_as(target).mean()
loss.backward()
return loss, identity.weight.grad
def mpu_cross_entropy(batch_size, seq_length, vocab_size,
logits_scale, seed):
set_random_seed(seed)
identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda()
logits = identity()
logits_parallel = mpu.scatter_to_model_parallel_region(logits)
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
loss.backward()
return loss, identity.weight.grad
def test_cross_entropy(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cross entropy with model parallel size {} ...'.
format(model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
batch_size = 13
seq_length = 17
vocab_size_per_partition = 11
logits_scale = 1000.0
vocab_size = vocab_size_per_partition * model_parallel_size
seed = 1234
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
vocab_size, logits_scale,
seed)
loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length,
vocab_size, logits_scale,
seed)
error = loss_torch.sub_(loss_mpu).abs().max()
print(' max error in loss on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = grad_torch.sub_(grad_mpu).abs().max()
print(' max error in grad on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test cross entropy')
test_cross_entropy(model_parallel_size)
model_parallel_size *= 2

View File

@@ -1,92 +0,0 @@
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator
import sys
sys.path.append("../..")
import torch
import mpu
from mpu import data as data_utils
from commons import initialize_distributed
from commons import print_separator
def test_boradcast_data(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing boradcast_data with model parallel size {} ...'.
format(model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
torch.manual_seed(1234 + mpu.get_data_parallel_rank())
model_parallel_size = mpu.get_model_parallel_world_size()
key_size_t = {'key1': [7, 11],
'key2': [8, 2, 1],
'key3': [13],
'key4': [5, 1, 2],
'key5': [5, 12]}
keys = list(key_size_t.keys())
data = {}
data_t = {}
for key in key_size_t:
data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
data_t[key] = data[key].clone()
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
data_t['keyX'] = data['keyX'].clone()
if mpu.get_model_parallel_rank() != 0:
data = None
data_utils._check_data_types(keys, data_t, torch.int64)
key_size, key_numel, \
total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
for key in keys:
assert key_size[key] == key_size_t[key]
total_numel_t = 0
for key in keys:
target_size = functools.reduce(operator.mul, key_size_t[key], 1)
assert key_numel[key] == target_size
total_numel_t += target_size
assert total_numel == total_numel_t
data_b = data_utils.broadcast_data(keys, data, torch.int64)
for key in keys:
tensor = data_t[key].cuda()
assert data_b[key].sub(tensor).abs().max() == 0
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test test boradcast data')
test_boradcast_data(model_parallel_size)
model_parallel_size *= 2

View File

@@ -1,98 +0,0 @@
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../..")
import torch
import mpu
from commons import initialize_distributed
from commons import print_separator
def test_initialize_model_parallel(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing initialize_model_parallel with size {} ...'.format(
model_parallel_size))
model_parallel_size_ = min(model_parallel_size,
torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(model_parallel_size_)
assert mpu.model_parallel_is_initialized()
# Checks.
def check(group, world_size, rank):
assert world_size == torch.distributed.get_world_size(group=group)
assert rank == torch.distributed.get_rank(group=group)
# Model parallel.
world_size = model_parallel_size_
rank = torch.distributed.get_rank() % model_parallel_size_
assert world_size == mpu.get_model_parallel_world_size()
assert rank == mpu.get_model_parallel_rank()
check(mpu.get_model_parallel_group(), world_size, rank)
# Data parallel.
world_size = torch.distributed.get_world_size() // model_parallel_size_
rank = torch.distributed.get_rank() // model_parallel_size
assert world_size == mpu.get_data_parallel_world_size()
assert rank == mpu.get_data_parallel_rank()
check(mpu.get_data_parallel_group(), world_size, rank)
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_get_model_parallel_src_rank(model_parallel_size_):
if torch.distributed.get_rank() == 0:
print('> testing get_model_parallel_src_rank with size {} ...'.format(
model_parallel_size_))
model_parallel_size = min(model_parallel_size_,
torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(model_parallel_size)
assert mpu.model_parallel_is_initialized()
# Checks
src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
assert mpu.get_model_parallel_src_rank() == src_rank
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test initialize model parallel')
test_initialize_model_parallel(model_parallel_size)
print_separator('test model parallel source rank')
test_get_model_parallel_src_rank(model_parallel_size)
model_parallel_size *= 2

View File

@@ -1,529 +0,0 @@
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import sys
sys.path.append("../..")
import torch
import torch.nn.init as init
from torch.nn.parameter import Parameter
import mpu
from commons import initialize_distributed
from commons import print_separator
from commons import set_random_seed
from mpu import layers
def test_parallel_embedding(model_parallel_size):
    """Compare ParallelEmbedding and VocabParallelEmbedding against a plain
    torch.nn.Embedding: the losses must match, and each rank's weight
    gradient must equal its slice of the unpartitioned weight gradient.

    Collective test: must run on every rank with torch.distributed
    initialized; requires CUDA.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing parallel embedding with model parallel size {} ...'.
              format(model_parallel_size))
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    # Data seed (distinct from the model seed) so every rank draws the same
    # input ids and loss weights.
    set_random_seed(123)
    input_data = torch.LongTensor(
        size=(batch_size,seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    # Reference: unpartitioned embedding.
    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()

    output = embedding_original(input_data)
    loss_original = torch.mul(output, loss_weight).sum()
    loss_original.backward()

    # Same seed so the parallel embedding starts from the same master weight.
    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_vocab_parallel(input_data)
    loss_vocab_parallel = torch.mul(output, loss_weight).sum()
    loss_vocab_parallel.backward()

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print(' error in loss (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print(' error in loss (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # ParallelEmbedding shards along the hidden dimension (dim 1).
    weight_grad_orig = torch.split(embedding_original.weight.grad,
                                   hidden_size // model_parallel_size,
                                   1)[mpu.get_model_parallel_rank()]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(' error in grad (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # VocabParallelEmbedding shards along the vocab dimension (dim 0).
    weight_grad_orig = torch.split(embedding_original.weight.grad,
                                   vocab_size // model_parallel_size,
                                   0)[mpu.get_model_parallel_rank()]
    error = embedding_vocab_parallel.weight.grad.sub(
        weight_grad_orig).abs().max()
    print(' error in grad (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # Reset groups
    mpu.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
def test_initialize_affine_weight(model_parallel_size):
    """Check _initialize_affine_weight for both partition directions: each
    rank's shard must equal the corresponding split of a master weight
    drawn with the same seed.

    Collective test: must run on every rank with torch.distributed
    initialized.
    """
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_affine_weight with model parallel '
              'size: {}'.format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    # Partition dim 0 (output dimension), one shard of output_size_coeff rows.
    layers._initialize_affine_weight(weight, output_size, input_size,
                                     output_size_coeff, 0,
                                     torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff,
                            dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' column parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    # Partition dim 1 (input dimension), one shard of input_size_coeff cols.
    mpu.layers._initialize_affine_weight(weight, output_size, input_size,
                                         input_size_coeff, 1,
                                         torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_model_parallel_rank()
    my_weight = torch.split(master_weight, input_size_coeff,
                            dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' row parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
class IdentityLayer2D(torch.nn.Module):
    """Module whose forward pass returns its own 2-D weight.

    Serves as a trainable input source in the tests: gradients with respect
    to the network input can then be read from ``self.weight.grad``.
    """

    def __init__(self, m, n):
        super(IdentityLayer2D, self).__init__()
        # Xavier-normal initialized (m, n) parameter; consumes the RNG once.
        self.weight = Parameter(torch.Tensor(m, n))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        """Return the weight itself (identity 'input')."""
        return self.weight
def test_column_parallel_linear(model_parallel_size):
    """Check ColumnParallelLinear gradients against hand-computed values
    (dL/dA, dL/db, dL/dX) derived from the retained master weight.

    Collective test: must run on every rank with torch.distributed
    initialized; requires CUDA.
    """
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing ColumnParallelLinear with model parallel '
              'size: {}'.format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7

    # Network: a trainable identity layer feeds the linear layer so dL/dX
    # can be read back from identity_layer.weight.grad.
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = mpu.ColumnParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values: closed-form gradients of loss = sum(Y * loss_weight); the
    # dLdX formula implies Y = X A^T + b with A the master weight.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    # Column parallel shards weight and bias along dim 0.
    rank = mpu.get_model_parallel_rank()
    my_dLdA = torch.split(dLdA, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    my_dLdb = torch.split(dLdb, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
def test_row_parallel_linear(model_parallel_size):
    """Check RowParallelLinear gradients against hand-computed values
    (dL/dA, dL/db, dL/dX) derived from the retained master weight.

    Collective test: must run on every rank with torch.distributed
    initialized; requires CUDA.
    """
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7

    # Network: a trainable identity layer feeds the linear layer so dL/dX
    # can be read back from identity_layer.weight.grad.
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = mpu.RowParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values: closed-form gradients of loss = sum(Y * loss_weight).
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    # Row parallel shards the weight along dim 1 (input dimension); the
    # bias is not partitioned, so dLdb is compared whole.
    rank = mpu.get_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff,
                          dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
class IdentityLayer3D(torch.nn.Module):
    """Module whose forward pass returns its own 3-D weight.

    Trainable input source for the attention/transformer tests: gradients
    with respect to the input are read from ``self.weight.grad``.
    """

    def __init__(self, m, n, k):
        super(IdentityLayer3D, self).__init__()
        # Xavier-normal initialized (m, n, k) parameter; consumes the RNG once.
        self.weight = Parameter(torch.Tensor(m, n, k))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        """Return the weight itself (identity 'input')."""
        return self.weight
def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    """Build one BertParallelSelfAttention and run a forward/backward pass.

    Returns (rank, hidden_size, model_parallel_size, loss, attention_layer,
    identity_layer) so callers can compare runs at different parallel sizes.
    Destroys the model-parallel groups before returning.
    """
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    # Total head count scales with world size so each partition keeps a
    # fixed number of heads regardless of the parallel size.
    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
                                                    dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, model_parallel_size, loss, \
        attention_layer, identity_layer
def test_parallel_self_attention(model_parallel_size):
    """Compare BertParallelSelfAttention at the given model-parallel size
    against a reference run at parallel size 1: loss, QKV-weight gradient
    shard and input gradient must all agree.

    Collective test: must run on every rank with torch.distributed
    initialized; dropout must be zero for the two runs to be comparable.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing ParallelSelfAttention with model parallel '
              'size: {}'.format(model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    dropout_prob = 0.0  # has to be zero
    batch_size = 5
    sequence_length = 13

    # Reference run (no model parallelism).
    # Fixed typo: local was misspelled 'hideen_size_1'.
    rank_1, hidden_size_1, model_parallel_size_1, loss_1, \
        attention_layer_1, identity_layer_1 = parallel_self_attention(
            1, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size,
            sequence_length)
    # Run under test.
    rank, hidden_size, model_parallel_size, loss, \
        attention_layer, identity_layer = parallel_self_attention(
            model_parallel_size, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size,
            sequence_length)
    assert hidden_size_1 == hidden_size

    error = loss_1.sub(loss).abs().max()
    torch.distributed.barrier()
    print(' loss error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    # The reference QKV weight rows for this rank are strided across the
    # partitions; gather them and compare with the local shard's gradient.
    my_lin_grad_list = torch.split(
        attention_layer_1.query_key_value.weight.grad,
        hidden_size // model_parallel_size, 0)[rank::model_parallel_size]
    my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
    error = my_lin_grad.sub(
        attention_layer.query_key_value.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' weight gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    error = identity_layer_1.weight.grad.sub(
        identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' input gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):
    """Build one BertParallelTransformerLayer and run forward/backward once.

    Returns (rank, hidden_size, model_parallel_size, loss, transformer_layer,
    identity_layer) so callers can compare runs at different parallel sizes.
    Destroys the model-parallel groups before returning.
    """
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    # Total head count scales with world size so each partition keeps a
    # fixed number of heads regardless of the parallel size.
    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network: zero attention/hidden dropout so runs are deterministic.
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    transformer_layer = mpu.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()

    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, model_parallel_size, loss, \
        transformer_layer, identity_layer
def test_parallel_transformer_layer(model_parallel_size):
    """Compare BertParallelTransformerLayer at the given model-parallel size
    against a reference run at parallel size 1: the loss and the input
    gradient must agree within a loose float tolerance.

    Collective test: must run on every rank with torch.distributed
    initialized.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing ParallelTransformerLayer with model parallel '
              'size: {}'.format(model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    batch_size = 5
    sequence_length = 13

    # Reference run (no model parallelism).
    rank_1, hidden_size_1, model_parallel_size_1, loss_1, \
        transformer_layer_1, identity_layer_1 = parallel_transformer(
            1, num_att_heads_per_partition,
            hidden_size_per_att_head, batch_size, sequence_length)
    # Run under test.
    rank, hidden_size, model_parallel_size, loss, \
        transformer_layer, identity_layer = parallel_transformer(
            model_parallel_size, num_att_heads_per_partition,
            hidden_size_per_att_head, batch_size, sequence_length)

    error = loss_1.sub(loss).abs().max()
    torch.distributed.barrier()
    print(' loss error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-5, 'error: {}'.format(error)

    error = identity_layer_1.weight.grad.sub(
        identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' input gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-5, 'error: {}'.format(error)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
if __name__ == '__main__':
    # cuDNN must be deterministic so the serial and parallel runs compared
    # by these tests produce reproducible results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    # Each test sweeps model-parallel sizes 1, 2, 4, ... up to world size.
    print_separator('test initialize affine weight')
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_initialize_affine_weight(model_parallel_size)
        model_parallel_size *= 2

    model_parallel_size = 1
    while model_parallel_size <= world_size:
        print_separator('test parallel embedding')
        test_parallel_embedding(model_parallel_size)
        model_parallel_size *= 2

    print_separator('test column-parallel linear')
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_column_parallel_linear(model_parallel_size)
        model_parallel_size *= 2

    print_separator('test row-parallel linear')
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_row_parallel_linear(model_parallel_size)
        model_parallel_size *= 2

    print_separator('test parallel self-attention')
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_parallel_self_attention(model_parallel_size)
        model_parallel_size *= 2

    print_separator('test parallel transformer')
    model_parallel_size = 1
    while model_parallel_size <= world_size:
        test_parallel_transformer_layer(model_parallel_size)
        model_parallel_size *= 2

View File

@@ -1,207 +0,0 @@
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../..")
import torch
import mpu
from commons import initialize_distributed
from commons import print_separator
def test_set_cuda_rng_state(model_parallel_size):
    """Check mpu.random._set_cuda_rng_state: restoring a saved CUDA RNG
    state must reproduce the same random draws, and the saved state object
    itself must never be mutated.

    Collective test: requires torch.distributed and a CUDA device.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing set_rng_state with size {} ...'.
              format(model_parallel_size))
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    size = 123
    seed = 1234  # NOTE(review): unused; torch.cuda.manual_seed gets a literal
    torch.cuda.manual_seed(1234)
    tensor = torch.cuda.FloatTensor(size)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    # Saved snapshot must be untouched, while the live state has advanced.
    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), max_diff))
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    # NOTE(review): the reset + 5 draws is performed twice; presumably this
    # checks that restoring is repeatable — confirm the duplication is
    # intentional.
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(' max error in rng state (should be zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), error))
    assert error == 0

    # Reset groups
    mpu.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
def test_cuda_rng_tracker(model_parallel_size):
    """Check mpu's CUDA RNG tracker: draws inside a named fork must match a
    dedicated generator seeded the same way, and must not perturb draws on
    the default stream.

    Collective test: requires torch.distributed and a CUDA device.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing cuda rng tracker with size {} ...'.
              format(model_parallel_size))
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    mpu.get_cuda_rng_tracker().add('test', seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    # The two streams must differ from each other ...
    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print(' max diff in generated tensors (should be non-zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6
    # ... and each must match its un-interleaved target exactly.
    error = max(result_11.sub(target_11).abs().max(),
                result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()
    # Reset groups
    mpu.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
def test_model_parallel_cuda_manual_seed(model_parallel_size):
    """Check model_parallel_cuda_manual_seed: the default CUDA seed equals
    the given seed, and the forked model-parallel stream is seeded with a
    per-rank offset.

    Collective test: requires torch.distributed and a CUDA device.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing model parallel cuda manual seed with size {} ...'.
              format(model_parallel_size))
    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    mpu.model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with mpu.get_cuda_rng_tracker().fork():
        # 2718 mirrors the model-parallel seed offset used inside
        # model_parallel_cuda_manual_seed — must stay in sync with it.
        assert torch.cuda.initial_seed() == (12345 + 2718 +
                                             mpu.get_model_parallel_rank())
    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()
    # Reset groups
    mpu.destroy_model_parallel()
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
if __name__ == '__main__':
    # Run each RNG test for every power-of-two model-parallel size from 1
    # up to the world size.
    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    model_parallel_size = 1
    while model_parallel_size <= world_size:
        print_separator('test set rng state')
        test_set_cuda_rng_state(model_parallel_size)
        model_parallel_size *= 2

    model_parallel_size = 1
    while model_parallel_size <= world_size:
        print_separator('test cuda rng tracker')
        test_cuda_rng_tracker(model_parallel_size)
        model_parallel_size *= 2

    model_parallel_size = 1
    while model_parallel_size <= world_size:
        print_separator('test model parallel cuda manual seed')
        test_model_parallel_cuda_manual_seed(model_parallel_size)
        model_parallel_size *= 2

View File

@@ -1,21 +1,16 @@
import os
import random
import time
import numpy as np
from modelscope.utils.logger import get_logger
import torch
"""Utilities for logging and serialization"""
logger = get_logger(__name__)
def get_log_constant(user_log):
    """Return the '[user log]' tag when *user_log* is truthy, else ''."""
    if user_log:
        return '[user log]'
    return ''
def print_rank_0(message):
    """Print and log *message*; when torch.distributed is initialized, only
    rank 0 emits it, otherwise every caller does."""
    emit = (not torch.distributed.is_initialized()
            or torch.distributed.get_rank() == 0)
    if emit:
        print(message, flush=True)
        logger.info(message)
def print_args(args):

View File

@@ -3,6 +3,8 @@
import functools
import os
import pickle
import random
import numpy as np
import socket
import subprocess
import tempfile
@@ -11,8 +13,7 @@ from typing import Callable, List, Optional, Tuple
import torch
import torch.multiprocessing as mp
from torch import distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from modelscope.utils.nlp import mpu
def _find_free_port() -> str:
@@ -106,6 +107,25 @@ def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
dist.init_process_group(backend=backend)
def initialize_distributed(rank):
    """Initialize torch.distributed.

    Binds this process to a GPU (round-robin over visible devices), joins
    the NCCL process group via MASTER_ADDR/MASTER_PORT, and creates the
    model-parallel communicators.
    """
    # Manually set the device ids.
    #torch.multiprocessing.set_start_method("spawn")
    device = rank % torch.cuda.device_count()
    torch.cuda.set_device(device)
    # Call the init process
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', '127.0.0.1')
    master_port = os.getenv('MASTER_PORT', '12345')
    init_method += master_ip + ':' + master_port
    # NOTE(review): world size and model-parallel size are hard-coded to 8;
    # confirm this matches the launcher, or read WORLD_SIZE from the
    # environment instead.
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=8, rank=rank,
        init_method=init_method)
    # Set the model-parallel communicators.
    mpu.initialize_model_parallel(8)
def get_dist_info() -> Tuple[int, int]:
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
@@ -191,3 +211,17 @@ def broadcast(inputs, src):
dist.broadcast(inputs_tensor, src)
return pickle.loads(inputs_tensor.cpu().numpy().tobytes())
def set_random_seed(seed):
    """Seed the python, numpy and torch RNGs.

    Raises ValueError when *seed* is None or not positive.
    """
    if seed is None or seed <= 0:
        raise ValueError(f'Random seed should be positive, current seed is {seed}')
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
def set_random_seed_mpu(seed):
    """Seed python/numpy/torch RNGs and the model-parallel CUDA RNG tracker.

    Raises ValueError (from set_random_seed) for None or non-positive seeds.
    """
    set_random_seed(seed)
    mpu.model_parallel_cuda_manual_seed(seed)

View File

@@ -8,3 +8,4 @@ seqeval
spacy>=2.3.5
tokenizers
transformers>=4.12.0
deepspeed