mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
updated
This commit is contained in:
@@ -1,115 +0,0 @@
|
||||
# Copyright (c) 2022 Zhipu.AI
|
||||
import csv
|
||||
import traceback
|
||||
from io import StringIO
|
||||
from urllib import parse
|
||||
|
||||
from flask import Response, jsonify, request, send_file
|
||||
|
||||
|
||||
class APIException(Exception):
|
||||
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class IllegalParamException(APIException):
|
||||
|
||||
def __init__(self, error):
|
||||
self.error = error
|
||||
super(IllegalParamException, self).__init__(error)
|
||||
|
||||
|
||||
class InputTooLongException(APIException):
|
||||
|
||||
def __init__(self, message, payload=None):
|
||||
self.payload = payload
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class CanNotReturnException(APIException):
|
||||
|
||||
def __init__(self, message, payload=None):
|
||||
self.payload = payload
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class MongoDBException(APIException):
|
||||
|
||||
def __init__(self, error):
|
||||
self.error = error
|
||||
super(MongoDBException, self).__init__(error)
|
||||
|
||||
|
||||
class MissParameterException(APIException):
|
||||
|
||||
def __init__(self, error):
|
||||
self.error = error
|
||||
super(MissParameterException, self).__init__(error)
|
||||
|
||||
|
||||
class HttpUtil:
|
||||
|
||||
@staticmethod
|
||||
def http_response(status=0, message='success', data=None, total=False):
|
||||
# if status and not isinstance(data, APIException):
|
||||
# sm.send_content(request.url_rule, traceback.format_exc(), request.data)
|
||||
if isinstance(data, Exception):
|
||||
data = str(data)
|
||||
r = {'status': status, 'message': message, 'result': data or []}
|
||||
if total and type(data) == list:
|
||||
if type(total) == int:
|
||||
r['total'] = total
|
||||
else:
|
||||
r['total'] = len(data)
|
||||
return jsonify(r)
|
||||
|
||||
@staticmethod
|
||||
def check_param(
|
||||
name,
|
||||
request, # noqa
|
||||
method=0,
|
||||
param_type=None,
|
||||
default=None,
|
||||
required=True):
|
||||
if method == 0:
|
||||
param = request.args.get(name)
|
||||
else:
|
||||
try:
|
||||
param = request.json.get(name)
|
||||
except Exception as e: # noqa
|
||||
raise IllegalParamException('data format json')
|
||||
|
||||
if param is None:
|
||||
if not required:
|
||||
return default
|
||||
raise IllegalParamException('param {} is required'.format(name))
|
||||
else:
|
||||
if param_type and type(param) != param_type:
|
||||
try:
|
||||
return param_type(param)
|
||||
except ValueError:
|
||||
raise IllegalParamException(
|
||||
'param {}: type wrong, not {}'.format(
|
||||
name, param_type))
|
||||
else:
|
||||
return param
|
||||
|
||||
@staticmethod
|
||||
def csv_file_response(data, filename):
|
||||
response = Response(HttpUtil.get_csv_stream(data), mimetype='text/csv')
|
||||
response.headers[
|
||||
'Content-Disposition'] = f'attachment; filename={parse.quote(filename)}.csv'
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def get_csv_stream(data):
|
||||
line = StringIO()
|
||||
csv_writer = csv.writer(line)
|
||||
csv_writer.writerow(['name', 'org', 'position', 'email', 'phone'])
|
||||
for p in data:
|
||||
csv_writer.writerow(
|
||||
[p['name'], p['aff'], p['position'], p['email'], p['phone']])
|
||||
res = line.getvalue()
|
||||
line.close()
|
||||
return res
|
||||
@@ -16,9 +16,7 @@ import torch.nn.functional as F
|
||||
from pypinyin import FINALS, FINALS_TONE, TONE3, pinyin
|
||||
|
||||
from .arguments import get_args
|
||||
from .com_utils.http_utils import (CanNotReturnException,
|
||||
InputTooLongException,
|
||||
MissParameterException)
|
||||
|
||||
from .gpt2 import mpu
|
||||
from .gpt2.configure_data import configure_data
|
||||
from .gpt2.data_utils import make_tokenizer
|
||||
@@ -30,6 +28,23 @@ from .gpt2.utils import (Timers, get_checkpoint_iteration, load_checkpoint,
|
||||
|
||||
open_old_pronounce = 1
|
||||
|
||||
class APIException(Exception):
|
||||
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class CanNotReturnException(APIException):
|
||||
|
||||
def __init__(self, message, payload=None):
|
||||
self.payload = payload
|
||||
super().__init__(message)
|
||||
|
||||
class InputTooLongException(APIException):
|
||||
|
||||
def __init__(self, message, payload=None):
|
||||
self.payload = payload
|
||||
super().__init__(message)
|
||||
|
||||
def get_model(args):
|
||||
"""Build the model."""
|
||||
|
||||
0
modelscope/models/nlp/txl_poem/gpt2/data_utils/__init__.py
Executable file → Normal file
0
modelscope/models/nlp/txl_poem/gpt2/data_utils/__init__.py
Executable file → Normal file
0
modelscope/models/nlp/txl_poem/gpt2/data_utils/corpora.py
Executable file → Normal file
0
modelscope/models/nlp/txl_poem/gpt2/data_utils/corpora.py
Executable file → Normal file
0
modelscope/models/nlp/txl_poem/gpt2/mpu/cross_entropy.py
Executable file → Normal file
0
modelscope/models/nlp/txl_poem/gpt2/mpu/cross_entropy.py
Executable file → Normal file
@@ -1,86 +0,0 @@
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
|
||||
import gpt2.mpu as mpu
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
|
||||
class IdentityLayer(torch.nn.Module):
|
||||
|
||||
def __init__(self, size, scale=1.0):
|
||||
super(IdentityLayer, self).__init__()
|
||||
self.weight = torch.nn.Parameter(scale * torch.randn(size))
|
||||
|
||||
def forward(self):
|
||||
return self.weight
|
||||
|
||||
|
||||
def set_random_seed(seed):
|
||||
"""Set random seed for reproducability."""
|
||||
random.seed(seed)
|
||||
numpy.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
mpu.model_parallel_cuda_manual_seed(seed)
|
||||
|
||||
|
||||
def initialize_distributed(backend='nccl'):
|
||||
"""Initialize torch.distributed."""
|
||||
# Get local rank in case it is provided.
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--local_rank',
|
||||
type=int,
|
||||
default=None,
|
||||
help='local rank passed from distributed launcher')
|
||||
args = parser.parse_args()
|
||||
local_rank = args.local_rank
|
||||
|
||||
# Get rank and world size.
|
||||
rank = int(os.getenv('RANK', '0'))
|
||||
world_size = int(os.getenv('WORLD_SIZE', '1'))
|
||||
|
||||
print('> initializing torch.distributed with local rank: {}, '
|
||||
'rank: {}, world size: {}'.format(local_rank, rank, world_size))
|
||||
|
||||
# Set the device id.
|
||||
device = rank % torch.cuda.device_count()
|
||||
if local_rank is not None:
|
||||
device = local_rank
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Call the init process.
|
||||
init_method = 'tcp://'
|
||||
master_ip = os.getenv('MASTER_ADDR', 'localhost')
|
||||
master_port = os.getenv('MASTER_PORT', '6000')
|
||||
init_method += master_ip + ':' + master_port
|
||||
torch.distributed.init_process_group(
|
||||
backend=backend,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
init_method=init_method)
|
||||
|
||||
|
||||
def print_separator(message):
|
||||
torch.distributed.barrier()
|
||||
filler_len = (78 - len(message)) // 2
|
||||
filler = '-' * filler_len
|
||||
string = '\n' + filler + ' {} '.format(message) + filler
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(string, flush=True)
|
||||
torch.distributed.barrier()
|
||||
@@ -1,106 +0,0 @@
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import sys
|
||||
|
||||
import gpt2.mpu as mpu
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from commons import (IdentityLayer, initialize_distributed, print_separator,
|
||||
set_random_seed)
|
||||
from mpu.cross_entropy import vocab_parallel_cross_entropy
|
||||
|
||||
sys.path.append('../..')
|
||||
|
||||
|
||||
def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale,
|
||||
seed):
|
||||
set_random_seed(seed)
|
||||
identity = IdentityLayer((batch_size, seq_length, vocab_size),
|
||||
scale=logits_scale).cuda()
|
||||
logits = identity()
|
||||
target = torch.cuda.LongTensor(size=(batch_size,
|
||||
seq_length)).random_(0, vocab_size)
|
||||
loss = F.cross_entropy(
|
||||
logits.view(-1,
|
||||
logits.size()[-1]), target.view(-1),
|
||||
reduction='none').view_as(target).mean()
|
||||
loss.backward()
|
||||
return loss, identity.weight.grad
|
||||
|
||||
|
||||
def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
|
||||
set_random_seed(seed)
|
||||
identity = IdentityLayer((batch_size, seq_length, vocab_size),
|
||||
scale=logits_scale).cuda()
|
||||
logits = identity()
|
||||
logits_parallel = mpu.scatter_to_model_parallel_region(logits)
|
||||
target = torch.cuda.LongTensor(size=(batch_size,
|
||||
seq_length)).random_(0, vocab_size)
|
||||
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
|
||||
loss.backward()
|
||||
return loss, identity.weight.grad
|
||||
|
||||
|
||||
def test_cross_entropy(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing cross entropy with model parallel size {} ...'.format(
|
||||
model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
batch_size = 13
|
||||
seq_length = 17
|
||||
vocab_size_per_partition = 11
|
||||
logits_scale = 1000.0
|
||||
vocab_size = vocab_size_per_partition * model_parallel_size
|
||||
seed = 1234
|
||||
|
||||
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
|
||||
vocab_size, logits_scale,
|
||||
seed)
|
||||
loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, vocab_size,
|
||||
logits_scale, seed)
|
||||
|
||||
error = loss_torch.sub_(loss_mpu).abs().max()
|
||||
print(' max error in loss on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
error = grad_torch.sub_(grad_mpu).abs().max()
|
||||
print(' max error in grad on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
initialize_distributed()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
print_separator('test cross entropy')
|
||||
test_cross_entropy(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
@@ -1,91 +0,0 @@
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import operator
|
||||
import sys
|
||||
|
||||
import gpt2.mpu as mpu
|
||||
import torch
|
||||
from commons import initialize_distributed, print_separator
|
||||
from mpu import data as data_utils
|
||||
|
||||
sys.path.append('../..')
|
||||
|
||||
|
||||
def test_boradcast_data(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(
|
||||
'> testing boradcast_data with model parallel size {} ...'.format(
|
||||
model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
torch.manual_seed(1234 + mpu.get_data_parallel_rank())
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
key_size_t = {
|
||||
'key1': [7, 11],
|
||||
'key2': [8, 2, 1],
|
||||
'key3': [13],
|
||||
'key4': [5, 1, 2],
|
||||
'key5': [5, 12]
|
||||
}
|
||||
keys = list(key_size_t.keys())
|
||||
|
||||
data = {}
|
||||
data_t = {}
|
||||
for key in key_size_t:
|
||||
data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
|
||||
data_t[key] = data[key].clone()
|
||||
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
|
||||
data_t['keyX'] = data['keyX'].clone()
|
||||
if mpu.get_model_parallel_rank() != 0:
|
||||
data = None
|
||||
|
||||
data_utils._check_data_types(keys, data_t, torch.int64)
|
||||
key_size, key_numel, \
|
||||
total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
|
||||
for key in keys:
|
||||
assert key_size[key] == key_size_t[key]
|
||||
total_numel_t = 0
|
||||
for key in keys:
|
||||
target_size = functools.reduce(operator.mul, key_size_t[key], 1)
|
||||
assert key_numel[key] == target_size
|
||||
total_numel_t += target_size
|
||||
assert total_numel == total_numel_t
|
||||
|
||||
data_b = data_utils.broadcast_data(keys, data, torch.int64)
|
||||
for key in keys:
|
||||
tensor = data_t[key].cuda()
|
||||
assert data_b[key].sub(tensor).abs().max() == 0
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
initialize_distributed()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
print_separator('test test boradcast data')
|
||||
test_boradcast_data(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
@@ -1,95 +0,0 @@
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
|
||||
import gpt2.mpu as mpu
|
||||
import torch
|
||||
from commons import initialize_distributed, print_separator
|
||||
|
||||
sys.path.append('../..')
|
||||
|
||||
|
||||
def test_initialize_model_parallel(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing initialize_model_parallel with size {} ...'.format(
|
||||
model_parallel_size))
|
||||
model_parallel_size_ = min(model_parallel_size,
|
||||
torch.distributed.get_world_size())
|
||||
assert not mpu.model_parallel_is_initialized()
|
||||
mpu.initialize_model_parallel(model_parallel_size_)
|
||||
assert mpu.model_parallel_is_initialized()
|
||||
|
||||
# Checks.
|
||||
def check(group, world_size, rank):
|
||||
assert world_size == torch.distributed.get_world_size(group=group)
|
||||
assert rank == torch.distributed.get_rank(group=group)
|
||||
|
||||
# Model parallel.
|
||||
world_size = model_parallel_size_
|
||||
rank = torch.distributed.get_rank() % model_parallel_size_
|
||||
assert world_size == mpu.get_model_parallel_world_size()
|
||||
assert rank == mpu.get_model_parallel_rank()
|
||||
check(mpu.get_model_parallel_group(), world_size, rank)
|
||||
|
||||
# Data parallel.
|
||||
world_size = torch.distributed.get_world_size() // model_parallel_size_
|
||||
rank = torch.distributed.get_rank() // model_parallel_size
|
||||
assert world_size == mpu.get_data_parallel_world_size()
|
||||
assert rank == mpu.get_data_parallel_rank()
|
||||
check(mpu.get_data_parallel_group(), world_size, rank)
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
def test_get_model_parallel_src_rank(model_parallel_size_):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing get_model_parallel_src_rank with size {} ...'.format(
|
||||
model_parallel_size_))
|
||||
model_parallel_size = min(model_parallel_size_,
|
||||
torch.distributed.get_world_size())
|
||||
assert not mpu.model_parallel_is_initialized()
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
assert mpu.model_parallel_is_initialized()
|
||||
|
||||
# Checks
|
||||
src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
|
||||
assert mpu.get_model_parallel_src_rank() == src_rank
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
initialize_distributed()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
print_separator('test initialize model parallel')
|
||||
test_initialize_model_parallel(model_parallel_size)
|
||||
print_separator('test model parallel source rank')
|
||||
test_get_model_parallel_src_rank(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
@@ -1,533 +0,0 @@
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import sys
|
||||
|
||||
import gpt2.mpu as mpu
|
||||
import torch
|
||||
import torch.nn.init as init
|
||||
from commons import initialize_distributed, print_separator, set_random_seed
|
||||
from mpu import layers
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
sys.path.append('../..')
|
||||
|
||||
|
||||
def test_parallel_embedding(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing parallel embedding with model parallel size {} ...'.
|
||||
format(model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
batch_size = 17
|
||||
seq_length = 23
|
||||
vocab_size = 48
|
||||
hidden_size = 16
|
||||
seed = 1236
|
||||
|
||||
set_random_seed(123)
|
||||
input_data = torch.LongTensor(size=(batch_size, seq_length)).random_(
|
||||
0, vocab_size).cuda()
|
||||
loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()
|
||||
|
||||
set_random_seed(seed)
|
||||
embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
|
||||
|
||||
output = embedding_original(input_data)
|
||||
loss_original = torch.mul(output, loss_weight).sum()
|
||||
loss_original.backward()
|
||||
|
||||
set_random_seed(seed)
|
||||
embedding_parallel = layers.ParallelEmbedding(
|
||||
vocab_size, hidden_size, init_method=init.normal_).cuda()
|
||||
output = embedding_parallel(input_data)
|
||||
loss_parallel = torch.mul(output, loss_weight).sum()
|
||||
loss_parallel.backward()
|
||||
|
||||
set_random_seed(seed)
|
||||
embedding_vocab_parallel = layers.VocabParallelEmbedding(
|
||||
vocab_size, hidden_size, init_method=init.normal_).cuda()
|
||||
output = embedding_vocab_parallel(input_data)
|
||||
loss_vocab_parallel = torch.mul(output, loss_weight).sum()
|
||||
loss_vocab_parallel.backward()
|
||||
|
||||
torch.distributed.barrier()
|
||||
error = loss_parallel.sub(loss_original).abs()
|
||||
print(' error in loss (parallel) on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-12, 'error: {}'.format(error)
|
||||
|
||||
torch.distributed.barrier()
|
||||
error = loss_vocab_parallel.sub(loss_original).abs()
|
||||
print(' error in loss (vocab parallel) on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-12, 'error: {}'.format(error)
|
||||
|
||||
weight_grad_orig = torch.split(embedding_original.weight.grad,
|
||||
hidden_size // model_parallel_size,
|
||||
1)[mpu.get_model_parallel_rank()]
|
||||
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
|
||||
print(' error in grad (parallel) on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-12, 'error: {}'.format(error)
|
||||
|
||||
weight_grad_orig = torch.split(embedding_original.weight.grad,
|
||||
vocab_size // model_parallel_size,
|
||||
0)[mpu.get_model_parallel_rank()]
|
||||
error = embedding_vocab_parallel.weight.grad.sub(
|
||||
weight_grad_orig).abs().max()
|
||||
print(' error in grad (vocab parallel) on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-12, 'error: {}'.format(error)
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
def test_initialize_affine_weight(model_parallel_size):
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing initialize_affine_weight with model parallel '
|
||||
'size: {}'.format(model_parallel_size))
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
seed = 12345
|
||||
input_size_coeff = 13
|
||||
input_size = input_size_coeff * model_parallel_size
|
||||
output_size_coeff = 17
|
||||
output_size = output_size_coeff * model_parallel_size
|
||||
|
||||
# ---------------
|
||||
# Column parallel
|
||||
# ---------------
|
||||
weight = torch.empty(output_size_coeff, input_size)
|
||||
set_random_seed(seed)
|
||||
layers._initialize_affine_weight(weight, output_size, input_size,
|
||||
output_size_coeff, 0,
|
||||
torch.nn.init.normal_)
|
||||
# Target.
|
||||
set_random_seed(seed)
|
||||
master_weight = torch.empty(output_size, input_size)
|
||||
torch.nn.init.normal_(master_weight)
|
||||
rank = mpu.get_model_parallel_rank()
|
||||
my_weight = torch.split(
|
||||
master_weight, output_size_coeff, dim=0)[rank].contiguous().clone()
|
||||
|
||||
# Compare.
|
||||
error = weight.sub(my_weight).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' column parallel max error (should be zero) on global rank '
|
||||
'{}: {}'.format(torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# ------------
|
||||
# Row parallel
|
||||
# ------------
|
||||
weight = torch.empty(output_size, input_size_coeff)
|
||||
set_random_seed(seed)
|
||||
mpu.layers._initialize_affine_weight(weight, output_size, input_size,
|
||||
input_size_coeff, 1,
|
||||
torch.nn.init.normal_)
|
||||
# Target.
|
||||
set_random_seed(seed)
|
||||
master_weight = torch.empty(output_size, input_size)
|
||||
torch.nn.init.normal_(master_weight)
|
||||
rank = mpu.get_model_parallel_rank()
|
||||
my_weight = torch.split(
|
||||
master_weight, input_size_coeff, dim=1)[rank].contiguous().clone()
|
||||
|
||||
# Compare.
|
||||
error = weight.sub(my_weight).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' row parallel max error (should be zero) on global rank '
|
||||
'{}: {}'.format(torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(' >> passed the test :-)')
|
||||
|
||||
|
||||
class IdentityLayer2D(torch.nn.Module):
|
||||
|
||||
def __init__(self, m, n):
|
||||
super(IdentityLayer2D, self).__init__()
|
||||
self.weight = Parameter(torch.Tensor(m, n))
|
||||
torch.nn.init.xavier_normal_(self.weight)
|
||||
|
||||
def forward(self):
|
||||
return self.weight
|
||||
|
||||
|
||||
def test_column_parallel_linear(model_parallel_size):
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing ColumnParallelLinear with model parallel '
|
||||
'size: {}'.format(model_parallel_size))
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
seed = 12345
|
||||
set_random_seed(seed)
|
||||
input_size_coeff = 13
|
||||
input_size = input_size_coeff * model_parallel_size
|
||||
output_size_coeff = 17
|
||||
output_size = output_size_coeff * model_parallel_size
|
||||
batch_size = 7
|
||||
|
||||
# Network
|
||||
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
|
||||
linear_layer = mpu.ColumnParallelLinear(
|
||||
input_size, output_size, keep_master_weight_for_test=True).cuda()
|
||||
loss_weight = torch.randn([batch_size, output_size]).cuda()
|
||||
# Forward
|
||||
input_ = identity_layer()
|
||||
output = linear_layer(input_)
|
||||
loss = torch.mul(output, loss_weight).sum()
|
||||
# Backward
|
||||
loss.backward()
|
||||
|
||||
# Values.
|
||||
dLdY = loss_weight
|
||||
X = identity_layer.weight
|
||||
A = linear_layer.master_weight.cuda()
|
||||
dLdA = torch.matmul(dLdY.t(), X)
|
||||
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
|
||||
dLdX = torch.matmul(dLdY, A)
|
||||
|
||||
rank = mpu.get_model_parallel_rank()
|
||||
my_dLdA = torch.split(
|
||||
dLdA, output_size_coeff, dim=0)[rank].contiguous().clone()
|
||||
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' error in dLdA on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
my_dLdb = torch.split(
|
||||
dLdb, output_size_coeff, dim=0)[rank].contiguous().clone()
|
||||
error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' error in dLdb on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
error = dLdX.sub(identity_layer.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' error in dLdX on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(' >> passed the test :-)')
|
||||
|
||||
|
||||
def test_row_parallel_linear(model_parallel_size):
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing RowParallelLinear with model parallel '
|
||||
'size: {}'.format(model_parallel_size))
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
seed = 12345
|
||||
set_random_seed(seed)
|
||||
input_size_coeff = 13
|
||||
input_size = input_size_coeff * model_parallel_size
|
||||
output_size_coeff = 17
|
||||
output_size = output_size_coeff * model_parallel_size
|
||||
batch_size = 7
|
||||
|
||||
# Network
|
||||
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
|
||||
linear_layer = mpu.RowParallelLinear(
|
||||
input_size, output_size, keep_master_weight_for_test=True).cuda()
|
||||
loss_weight = torch.randn([batch_size, output_size]).cuda()
|
||||
# Forward
|
||||
input_ = identity_layer()
|
||||
output = linear_layer(input_)
|
||||
loss = torch.mul(output, loss_weight).sum()
|
||||
# Backward
|
||||
loss.backward()
|
||||
|
||||
# Values.
|
||||
dLdY = loss_weight
|
||||
X = identity_layer.weight
|
||||
A = linear_layer.master_weight.cuda()
|
||||
dLdA = torch.matmul(dLdY.t(), X)
|
||||
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
|
||||
dLdX = torch.matmul(dLdY, A)
|
||||
|
||||
rank = mpu.get_model_parallel_rank()
|
||||
my_dLdA = torch.split(
|
||||
dLdA, input_size_coeff, dim=1)[rank].contiguous().clone()
|
||||
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' error in dLdA on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
error = dLdb.sub(linear_layer.bias.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' error in dLdb on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
error = dLdX.sub(identity_layer.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' error in dLdX on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(' >> passed the test :-)')
|
||||
|
||||
|
||||
class IdentityLayer3D(torch.nn.Module):
|
||||
|
||||
def __init__(self, m, n, k):
|
||||
super(IdentityLayer3D, self).__init__()
|
||||
self.weight = Parameter(torch.Tensor(m, n, k))
|
||||
torch.nn.init.xavier_normal_(self.weight)
|
||||
|
||||
def forward(self):
|
||||
return self.weight
|
||||
|
||||
|
||||
def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
|
||||
hidden_size_per_att_head, dropout_prob, batch_size,
|
||||
sequence_length):
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
seed = 12345
|
||||
set_random_seed(seed)
|
||||
|
||||
num_att_heads = num_att_heads_per_partition * \
|
||||
torch.distributed.get_world_size() # noqa
|
||||
hidden_size = hidden_size_per_att_head * num_att_heads
|
||||
|
||||
# Network
|
||||
identity_layer = IdentityLayer3D(batch_size, sequence_length,
|
||||
hidden_size).cuda()
|
||||
attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
|
||||
dropout_prob).cuda()
|
||||
loss_weight = torch.randn([batch_size, sequence_length,
|
||||
hidden_size]).cuda()
|
||||
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
|
||||
# Forward
|
||||
input_ = identity_layer()
|
||||
output = attention_layer(input_, attention_mask)
|
||||
loss = torch.mul(output, loss_weight).sum()
|
||||
# Backward
|
||||
loss.backward()
|
||||
|
||||
rank = mpu.get_model_parallel_rank()
|
||||
mpu.destroy_model_parallel()
|
||||
return rank, hidden_size, model_parallel_size, loss, \
|
||||
attention_layer, identity_layer
|
||||
|
||||
|
||||
def test_parallel_self_attention(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing ParallelSelfAttention with model parallel '
|
||||
'size: {}'.format(model_parallel_size))
|
||||
|
||||
num_att_heads_per_partition = 3
|
||||
hidden_size_per_att_head = 7
|
||||
dropout_prob = 0.0 # has to be zero
|
||||
batch_size = 5
|
||||
sequence_length = 13
|
||||
|
||||
rank_1, hideen_size_1, model_parallel_size_1, loss_1, \
|
||||
attention_layer_1, identity_layer_1 = parallel_self_attention(
|
||||
1, num_att_heads_per_partition,
|
||||
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) # noqa
|
||||
|
||||
rank, hidden_size, model_parallel_size, loss, \
|
||||
attention_layer, identity_layer = parallel_self_attention(
|
||||
model_parallel_size, num_att_heads_per_partition,
|
||||
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) # noqa
|
||||
assert hideen_size_1 == hidden_size
|
||||
|
||||
error = loss_1.sub(loss).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' loss error on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 5.0e-6
|
||||
|
||||
my_lin_grad_list = torch.split(
|
||||
attention_layer_1.query_key_value.weight.grad,
|
||||
hidden_size // model_parallel_size, 0)[rank::model_parallel_size]
|
||||
my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
|
||||
error = my_lin_grad.sub(
|
||||
attention_layer.query_key_value.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' weight gradient error on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 5.0e-6
|
||||
|
||||
error = identity_layer_1.weight.grad.sub(
|
||||
identity_layer.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' input gradient error on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 5.0e-6
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(' >> passed the test :-)')
|
||||
|
||||
|
||||
def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
|
||||
hidden_size_per_att_head, batch_size,
|
||||
sequence_length):
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
seed = 12345
|
||||
set_random_seed(seed)
|
||||
|
||||
num_att_heads = num_att_heads_per_partition * \
|
||||
torch.distributed.get_world_size() # noqa
|
||||
hidden_size = hidden_size_per_att_head * num_att_heads
|
||||
intermediate_size = 4 * hidden_size
|
||||
|
||||
# Network
|
||||
identity_layer = IdentityLayer3D(batch_size, sequence_length,
|
||||
hidden_size).cuda()
|
||||
transformer_layer = mpu.BertParallelTransformerLayer(
|
||||
hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
|
||||
torch.nn.functional.relu, 1.0e-5).cuda()
|
||||
|
||||
loss_weight = torch.randn([batch_size, sequence_length,
|
||||
hidden_size]).cuda()
|
||||
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
|
||||
# Forward
|
||||
input_ = identity_layer()
|
||||
output = transformer_layer(input_, attention_mask)
|
||||
loss = torch.mul(output, loss_weight).sum()
|
||||
# Backward
|
||||
loss.backward()
|
||||
|
||||
rank = mpu.get_model_parallel_rank()
|
||||
mpu.destroy_model_parallel()
|
||||
return rank, hidden_size, model_parallel_size, loss, \
|
||||
transformer_layer, identity_layer
|
||||
|
||||
|
||||
def test_parallel_transformer_layer(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing ParallelTransformerLayer with model parallel '
|
||||
'size: {}'.format(model_parallel_size))
|
||||
|
||||
num_att_heads_per_partition = 3
|
||||
hidden_size_per_att_head = 7
|
||||
batch_size = 5
|
||||
sequence_length = 13
|
||||
|
||||
rank_1, hidden_size_1, model_parallel_size_1, loss_1, \
|
||||
transformer_layer_1, identity_layer_1 = parallel_transformer(
|
||||
1, num_att_heads_per_partition,
|
||||
hidden_size_per_att_head, batch_size, sequence_length)
|
||||
|
||||
rank, hidden_size, model_parallel_size, loss, \
|
||||
transformer_layer, identity_layer = parallel_transformer(
|
||||
model_parallel_size, num_att_heads_per_partition,
|
||||
hidden_size_per_att_head, batch_size, sequence_length)
|
||||
|
||||
error = loss_1.sub(loss).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' loss error on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 5.0e-5, 'error: {}'.format(error)
|
||||
|
||||
error = identity_layer_1.weight.grad.sub(
|
||||
identity_layer.weight.grad).abs().max()
|
||||
torch.distributed.barrier()
|
||||
print(' input gradient error on global rank {}: {}'.format(
|
||||
torch.distributed.get_rank(), error))
|
||||
assert error < 5.0e-5, 'error: {}'.format(error)
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print(' >> passed the test :-)')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
initialize_distributed()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
print_separator('test initialize affine weight')
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
test_initialize_affine_weight(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
print_separator('test parallel embedding')
|
||||
test_parallel_embedding(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
|
||||
print_separator('test column-parallel linear')
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
test_column_parallel_linear(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
|
||||
print_separator('test row-parallel linear')
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
test_row_parallel_linear(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
|
||||
print_separator('test parallel self-attention')
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
test_parallel_self_attention(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
|
||||
print_separator('test parallel transformer')
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
test_parallel_transformer_layer(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
@@ -1,206 +0,0 @@
|
||||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
|
||||
import gpt2.mpu as mpu
|
||||
import torch
|
||||
from commons import initialize_distributed, print_separator
|
||||
|
||||
sys.path.append('../..')
|
||||
|
||||
|
||||
def test_set_cuda_rng_state(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing set_rng_state with size {} ...'.format(
|
||||
model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
size = 123
|
||||
seed = 1234 # noqa
|
||||
torch.cuda.manual_seed(1234)
|
||||
tensor = torch.cuda.FloatTensor(size)
|
||||
|
||||
# Get the state
|
||||
rng_state = torch.cuda.get_rng_state()
|
||||
rng_state_copy = rng_state.clone()
|
||||
|
||||
# Do some stuff.
|
||||
for _ in range(5):
|
||||
torch.randn(size, out=tensor)
|
||||
result_1 = tensor.clone()
|
||||
|
||||
assert rng_state.sub(rng_state_copy).max() == 0
|
||||
assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0
|
||||
|
||||
# State should be different.
|
||||
new_rng_state = torch.cuda.get_rng_state()
|
||||
max_diff = new_rng_state.sub(rng_state).max()
|
||||
print(
|
||||
' max diff in rng state (should be non-zero) on global rank {}: {}'.
|
||||
format(torch.distributed.get_rank(), max_diff))
|
||||
assert max_diff > 0
|
||||
|
||||
# Reset the rng state and do the same stuff.
|
||||
mpu.random._set_cuda_rng_state(rng_state)
|
||||
for _ in range(5):
|
||||
torch.randn(size, out=tensor)
|
||||
mpu.random._set_cuda_rng_state(rng_state)
|
||||
for _ in range(5):
|
||||
torch.randn(size, out=tensor)
|
||||
result_2 = tensor.clone()
|
||||
|
||||
# Results should be the same
|
||||
error = result_2.sub(result_1).abs().max()
|
||||
print(' max error in generated tensors (should be zero) on '
|
||||
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# Input state should have remained intact.
|
||||
error = rng_state.sub(rng_state_copy).max()
|
||||
print(' max error in rng state (should be zero) on global rank {}: {}'.
|
||||
format(torch.distributed.get_rank(), error))
|
||||
assert error == 0
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
def test_cuda_rng_tracker(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing cuda rng tracker with size {} ...'.format(
|
||||
model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
seed_1 = 1234
|
||||
seed_2 = 4321
|
||||
size = [12, 21]
|
||||
tensor = torch.cuda.FloatTensor(size)
|
||||
|
||||
# Set to seed_1 and generate two tensors.
|
||||
torch.cuda.manual_seed(seed_1)
|
||||
torch.randn(size, out=tensor)
|
||||
target_11 = tensor.clone()
|
||||
torch.randn(size, out=tensor)
|
||||
target_12 = tensor.clone()
|
||||
|
||||
# Set to seed_2 and generate two tensors.
|
||||
torch.cuda.manual_seed(seed_2)
|
||||
torch.randn(size, out=tensor)
|
||||
target_21 = tensor.clone()
|
||||
torch.randn(size, out=tensor)
|
||||
target_22 = tensor.clone()
|
||||
|
||||
# Now if we interleave seed_1 and seed_2,
|
||||
# we should still get the same tensors
|
||||
torch.cuda.manual_seed(seed_1)
|
||||
mpu.get_cuda_rng_tracker().add('test', seed_2)
|
||||
|
||||
torch.randn(size, out=tensor)
|
||||
result_11 = tensor.clone()
|
||||
|
||||
with mpu.get_cuda_rng_tracker().fork('test'):
|
||||
torch.randn(size, out=tensor)
|
||||
result_21 = tensor.clone()
|
||||
|
||||
torch.randn(size, out=tensor)
|
||||
result_12 = tensor.clone()
|
||||
|
||||
with mpu.get_cuda_rng_tracker().fork('test'):
|
||||
torch.randn(size, out=tensor)
|
||||
result_22 = tensor.clone()
|
||||
|
||||
diff = result_11.sub(result_21).abs().max()
|
||||
diff = min(diff, result_12.sub(result_22).abs().max())
|
||||
print(' max diff in generated tensors (should be non-zero) on '
|
||||
'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
|
||||
assert diff > 1.0e-6
|
||||
error = max(
|
||||
result_11.sub(target_11).abs().max(),
|
||||
result_12.sub(target_12).abs().max())
|
||||
error = max(error, result_21.sub(target_21).abs().max())
|
||||
error = max(error, result_22.sub(target_22).abs().max())
|
||||
print(' max error in generated tensors (should be zero) on '
|
||||
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
|
||||
assert error < 1.0e-6
|
||||
|
||||
# Reset the tracker
|
||||
mpu.get_cuda_rng_tracker().reset()
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
def test_model_parallel_cuda_manual_seed(model_parallel_size):
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('> testing model parallel cuda manual seed with size {} ...'.
|
||||
format(model_parallel_size))
|
||||
|
||||
mpu.initialize_model_parallel(model_parallel_size)
|
||||
model_parallel_size = mpu.get_model_parallel_world_size()
|
||||
|
||||
mpu.model_parallel_cuda_manual_seed(12345)
|
||||
assert torch.cuda.initial_seed() == 12345
|
||||
with mpu.get_cuda_rng_tracker().fork():
|
||||
assert torch.cuda.initial_seed() == (12345 + 2718
|
||||
+ mpu.get_model_parallel_rank())
|
||||
|
||||
# Reset the tracker
|
||||
mpu.get_cuda_rng_tracker().reset()
|
||||
|
||||
# Reset groups
|
||||
mpu.destroy_model_parallel()
|
||||
|
||||
torch.distributed.barrier()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('>> passed the test :-)')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
initialize_distributed()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
print_separator('test set rng state')
|
||||
test_set_cuda_rng_state(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
print_separator('test cuda rng tracker')
|
||||
test_cuda_rng_tracker(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
|
||||
model_parallel_size = 1
|
||||
while model_parallel_size <= world_size:
|
||||
print_separator('test model parallel cuda manual seed')
|
||||
test_model_parallel_cuda_manual_seed(model_parallel_size)
|
||||
model_parallel_size *= 2
|
||||
Reference in New Issue
Block a user