[to #42322933] Add beam search and pair finetune for GPT-3

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11397726

* test finetune weather

* support ppl and generation metrics
Author: hemu.zp (2023-01-11 22:04:11 +08:00)
Committed by: wenmeng.zwm
Parent: 0e54e80b23
Commit: a277b343af
5 changed files with 271 additions and 83 deletions

View File

@@ -798,8 +798,8 @@ class GPT3Model(PreTrainedModel):
        if labels is not None:
            # [b s] => [s b]
            labels = labels.transpose(0, 1).contiguous()
            losses = mpu.vocab_parallel_cross_entropy(
                logits_parallel.clone().float(), labels)
            # [s b] => [b s]
            losses = losses.transpose(0, 1).contiguous()
@@ -1011,7 +1011,8 @@ class DistributedGPT3(TorchModel):
                attention_mask=None,
                position_ids=None,
                labels=None,
                prompt_length=None,
                is_pair=(False, )):
        logits, losses = self.dist_model(
            tokens,
@@ -1026,6 +1027,9 @@ class DistributedGPT3(TorchModel):
        else:
            loss_mask = torch.ones(
                tokens.size(), dtype=torch.float, device=tokens.device)
        if is_pair[0]:
            for i, length in enumerate(prompt_length):
                loss_mask[i, :length] = 0
        losses = losses.float()
        loss_mask = loss_mask.view(-1).float()
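For pair finetuning, the prompt (source) positions are excluded from the loss so only the target continuation is optimized. A minimal standalone sketch of that masking, using hypothetical tensors rather than the model's real outputs:

import torch

# Hypothetical batch: 2 sequences of length 6 with prompts of length 3 and 4.
tokens = torch.randint(0, 100, (2, 6))
prompt_length = torch.tensor([3, 4])
losses = torch.rand(2, 6)  # per-token cross-entropy as produced by the model

loss_mask = torch.ones(tokens.size(), dtype=torch.float)
for i, length in enumerate(prompt_length):
    loss_mask[i, :length] = 0  # zero out the prompt tokens

loss = torch.sum(losses.view(-1) * loss_mask.view(-1)) / loss_mask.sum()

With `is_pair` false the mask stays all-ones and every position contributes, which matches the plain language-modeling objective.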
@@ -1033,13 +1037,13 @@ class DistributedGPT3(TorchModel):
        return TextGenerationModelOutput(logits=logits, loss=loss)

    def sample(self,
               tokens,
               temperature=1.0,
               use_eod_token_for_early_termination=True,
               stop_on_double_eol=False,
               stop_on_eol=False,
               **kwargs):
        batch_size = tokens.size(0)
        lengths = kwargs.pop(
            'prompt_length',
@@ -1074,68 +1078,182 @@ class DistributedGPT3(TorchModel):
        # Run inference
        # =============
        attention_mask, position_ids = \
            GPT3Model.build_attention_mask_and_position_ids(tokens)
        prev_context_length = 0
        for context_length in range(min_prompt_length, max_sequence_length):

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = self(tokens2use, attention_mask2use, positions2use).logits

            # Sample.
            last_token_logits = logits[:, -1, :]
            new_sample = sample(
                last_token_logits,
                top_k=self.config.top_k,
                top_p=self.config.top_p,
                temperature=temperature,
                vocab_size=self.config.vocab_size)

            # If a prompt length is smaller than or equal to the current
            # context length, it means we have started generating tokens.
            started = lengths <= context_length
            # Update the tokens.
            tokens[started, context_length] = new_sample[started]

            # Update the context length for the next token generation.
            prev_context_length = context_length

            # Tokenization should instead happen in the inference loop so
            # that stop sequences can be used.
            if stop_on_double_eol:
                hit_double_eol = (new_sample == 628).byte() & started.byte()
                hit_two_eols = (new_sample == 198).byte() & (
                    tokens[:, context_length - 1]
                    == 198).byte() & started.byte()
                done_token = hit_double_eol | hit_two_eols
            elif stop_on_eol:
                hit_double_eol = (new_sample == 628).byte() & started.byte()
                hit_eol = (new_sample == 198).byte() & started.byte()
                done_token = hit_double_eol | hit_eol
            else:
                done_token = (new_sample == termination_id).byte() & \
                    started.byte()

            is_generation_done = is_generation_done | done_token
            done = torch.all(is_generation_done)
            if use_eod_token_for_early_termination and done:
                break

        tokens = tokens[:, :(context_length + 1)]
        return TokenGeneratorOutput(sequences=tokens)
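For intuition about the incremental decoding above: the prompt is encoded once, and every later step only feeds the single newly generated position, relying on the cached key/values for the earlier context. A tiny standalone sketch with hypothetical lengths:

min_prompt_length, max_sequence_length = 4, 8
prev_context_length = 0
for context_length in range(min_prompt_length, max_sequence_length):
    print(f'feed tokens[:, {prev_context_length}:{context_length}]')
    prev_context_length = context_length
# feed tokens[:, 0:4]  <- the whole prompt on the first step
# feed tokens[:, 4:5]  <- then one position at a time
# feed tokens[:, 5:6]
# feed tokens[:, 6:7]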
    def beam_search(self, tokens, beam_size=5, num_return_gen=1, **kwargs):
        batch_size = tokens.size(0)
        assert (batch_size == 1)
        prompt_length = kwargs.pop(
            'prompt_length',
            torch.tensor([tokens.size(1)], device=tokens.device)).item()
        stop_token = self.config.eod_id
        pads = torch.ones(
            1, self.config.tokens_to_generate,
            device=tokens.device).long() * stop_token
        tokens = torch.cat((tokens, pads), dim=-1)
        final_sequence_length = tokens.size(1)
        final_sequence_length = min(final_sequence_length,
                                    self.config.max_position_embeddings)

        # If the context is too big, this happens
        if prompt_length >= final_sequence_length:
            raise ValueError('context length + tokens_to_generate too large')

        # Initialize inference parameters.
        self.inference_params = InferenceParams(beam_size,
                                                final_sequence_length)

        beam_hyp = BeamHypotheses(beam_size)
        done = False
        scores = torch.zeros(
            beam_size, dtype=torch.float32,
            device=torch.cuda.current_device()).unsqueeze(1)

        # =============
        # Run inference
        # =============
        tokens = tokens.repeat(beam_size, 1)
        attention_mask, position_ids = \
            GPT3Model.build_attention_mask_and_position_ids(tokens)
        prev_context_length = 0
        for context_length in range(prompt_length, final_sequence_length):

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = self(tokens2use, attention_mask2use, positions2use).logits

            vocab_size = logits.size(2)
            log_probs = F.log_softmax(logits, dim=2)
            new_scores = log_probs[:, -1, :] + scores

            if context_length == prompt_length:  # if this is the first step
                sorted_scores, indices = torch.sort(
                    new_scores[0, :], descending=True)
            else:
                sorted_scores, indices = torch.sort(
                    new_scores.view(-1), descending=True)

            best_beam_ids = torch.div(indices[:2 * beam_size],
                                      vocab_size).trunc().long()
            best_words = indices[:2 * beam_size] % vocab_size
            best_scores = sorted_scores[:2 * beam_size]

            next_beams = []
            for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
                    zip(best_words, best_scores, best_beam_ids)):
                if token_id.item() == stop_token:
                    # If the beam token does not belong to the top num_beams
                    # tokens, it should not be added.
                    is_beam_token_worse_than_top_num_beams = \
                        beam_token_rank >= beam_size
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    beam_hyp.add(tokens[beam_id].clone(), beam_score,
                                 context_length + 1 - prompt_length)
                else:
                    # Add the next predicted token since it is not eos_token.
                    next_beams.append((token_id, beam_score, beam_id))

                if len(next_beams) == beam_size:
                    break

            if beam_hyp.is_done(best_scores.max().item(),
                                context_length + 1 - prompt_length):
                done = True
                break

            best_batches = tokens.new([item[2] for item in next_beams])
            tokens = tokens[best_batches, :]
            tokens[:, context_length] = tokens.new(
                [item[0] for item in next_beams])
            scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)

            # Set the inference key/values to be consistent with the chosen
            # beam indices.
            self.inference_params.swap_key_value_dict(best_batches)

            # Update the context length for the next token generation.
            prev_context_length = context_length

        # If no stop token was found, add the open beams to the hypotheses.
        if not done:
            for beam_id in range(beam_size):
                beam_hyp.add(tokens[beam_id].clone(), scores[beam_id],
                             context_length + 1 - prompt_length)

        # Rank based on scores.
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
        num_return_gen = min(num_return_gen, len(sorted_hyps))
        scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
        tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
        scores = torch.stack(scores, dim=0)
        tokens = torch.stack(tokens, dim=0)

        return TokenGeneratorOutput(sequences=tokens, scores=scores)

    @torch.no_grad()
    def generate(self, tokens, do_sample=True, *args, **kwargs):
        if do_sample:
            return self.sample(tokens, *args, **kwargs)
        else:
            return self.beam_search(tokens, *args, **kwargs)
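A usage sketch of the new dispatcher (hypothetical variables; assumes `model` is a built DistributedGPT3 instance and `tokens` is a [1, prompt_len] LongTensor of input ids):

sampled = model.generate(tokens, do_sample=True, temperature=0.9)
beamed = model.generate(tokens, do_sample=False, beam_size=5, num_return_gen=1)

print(sampled.sequences.shape)                 # top-k / top-p sampling path
print(beamed.sequences.shape, beamed.scores)   # beam search path with scores

Note that the beam search path currently asserts a batch size of 1, while the sampling path handles full batches.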
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.dist_model.state_dict(destination, prefix, keep_vars)
@@ -1154,3 +1272,59 @@ class DistributedGPT3(TorchModel):
        return super().save_pretrained(target_folder, save_checkpoint_names,
                                       save_checkpoint, config, **kwargs)
class BeamHypotheses:

    def __init__(self,
                 num_beams: int,
                 length_penalty: float = 1.0,
                 early_stopping: bool = False):
        """
        Initialize n-best list of hypotheses.
        """
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self,
            hyp: torch.LongTensor,
            sum_logprobs: float,
            beam_indices: Optional[torch.LongTensor] = None):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / (hyp.shape[-1]**self.length_penalty)
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp, beam_indices))
            if len(self) > self.num_beams:
                sorted_next_scores = sorted([
                    (s, idx) for idx, (s, _, _) in enumerate(self.beams)
                ])
                del self.beams[sorted_next_scores[0][1]]
                self.worst_score = sorted_next_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
        """
        If there are enough hypotheses and none of the hypotheses being
        generated can become better than the worst one in the heap, then we
        are done with this sentence.
        """
        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            cur_score = best_sum_logprobs / cur_len**self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
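A small usage illustration (hypothetical token ids and log-probabilities): the length penalty normalizes each hypothesis by len ** length_penalty, so a slightly lower total log-probability can still win if it is spread over more tokens:

import torch

hyps = BeamHypotheses(num_beams=2, length_penalty=1.0)
hyps.add(torch.tensor([5, 7, 9]), sum_logprobs=-3.0)     # score = -3.0 / 3 = -1.0
hyps.add(torch.tensor([5, 7, 9, 2]), sum_logprobs=-3.2)  # score = -3.2 / 4 = -0.8
hyps.add(torch.tensor([5, 7]), sum_logprobs=-4.0)        # score = -2.0, rejected

best_score, best_tokens, _ = sorted(
    hyps.beams, key=lambda x: x[0], reverse=True)[0]
print(best_score)  # -0.8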

View File

@@ -27,7 +27,7 @@ class GPT3ForTextGeneration(TorchModel):
        # Temporarily compatible with DistributedGPT3 and GPT3Model,
        # the base/large model based on GPT3Model will be replaced in the future,
        # and GPT3Model will be deprecated
        if 'world_size' in kwargs:
            from modelscope.models.nlp import DistributedGPT3
            self.model = DistributedGPT3(model_dir, **kwargs)
        else:

View File

@@ -194,16 +194,11 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
                 model_dir: str,
                 mode: str = ModeKeys.INFERENCE,
                 src_txt='src_txt',
                 tgt_txt='tgt_txt',
                 sequence_length: int = 128,
                 use_fast=None):
        from modelscope.models.nlp.gpt3 import JiebaBPETokenizer
        super().__init__(mode, src_txt, tgt_txt)
        self.tokenizer = JiebaBPETokenizer(
            osp.join(model_dir, 'tokenizer.json'))
        self.max_length = sequence_length
@@ -252,6 +247,7 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
            'tokens': tokens[:-1],
            'labels': tokens[1:],
            'prompt_length': prompt_length,
            'is_pair': int(sequence2 is not None),
        }
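A sketch of the resulting feature dict for a source/target pair (hypothetical token ids, not the real Jieba BPE tokenizer output):

# Suppose prompt + target encode to these ids, the first 4 coming from src_txt.
tokens = [101, 23, 57, 8, 412, 9, 102]
prompt_length = 4

features = {
    'tokens': tokens[:-1],          # model input
    'labels': tokens[1:],           # next-token targets, shifted by one
    'prompt_length': prompt_length,
    'is_pair': 1,                   # tgt_txt was provided
}

The `is_pair` flag is what later tells the model to mask prompt positions out of the loss and tells the trainer to evaluate by generation.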

View File

@@ -2,27 +2,25 @@
import os
from collections.abc import Mapping
from typing import Any, Dict, List

import torch
from megatron_util import mpu

from modelscope.metainfo import Trainers
from modelscope.models import TorchModel
from modelscope.models.nlp import GPT3ForTextGeneration
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
from modelscope.utils.config import Config
from modelscope.utils.file_utils import func_receive_dict_inputs
@TRAINERS.register_module(module_name=Trainers.gpt3_trainer)
class GPT3Trainer(NlpEpochBasedTrainer):

    def rebuild_config(self, cfg: Config):
        cfg = super().rebuild_config(cfg)
        cfg.model.rank = int(os.environ.get('RANK', 0))
        return cfg

    def train_step(self, model: TorchModel, inputs: Mapping):
@@ -39,13 +37,19 @@ class GPT3Trainer(NlpEpochBasedTrainer):
        model = self.model.module if self._dist else self.model
        model.eval()
        with torch.no_grad():
            if self._is_pair(data):
                return self._generate_eval(model, data)
            else:
                return self._forward_eval(model, data)

    @staticmethod
    def _is_pair(data: Dict[str, Any]) -> bool:
        return 'is_pair' in data and bool(data['is_pair'][0])

    def _generate_eval(self, model: GPT3ForTextGeneration,
                       data: Dict[str, Any]) -> Dict[str, Any]:
        data['do_sample'] = False
        result = model.generate(data)

        prompt_length: List[int] = data['prompt_length']
        result['preds'] = [
@@ -56,6 +60,8 @@ class GPT3Trainer(NlpEpochBasedTrainer):
            self._decode(seq[skip_len - 1:])
            for seq, skip_len in zip(data['labels'], prompt_length)
        ]
        assert len(result['preds']) == len(data['tgts'])
        return result

    def _forward_eval(self, model: GPT3ForTextGeneration,
                      data: Dict[str, Any]) -> Dict[str, Any]:
        return model.forward(data)
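A rough sketch of the evaluation dispatch (hypothetical batch dicts): pair batches go through beam-search generation so the text metrics can compare `preds` against `tgts`, while plain language-modeling batches go through `forward` for the ppl metric:

pair_batch = {'tokens': ..., 'labels': ..., 'prompt_length': [4], 'is_pair': [1]}
lm_batch = {'tokens': ..., 'labels': ..., 'prompt_length': [7], 'is_pair': [0]}

GPT3Trainer._is_pair(pair_batch)  # True  -> _generate_eval (do_sample=False)
GPT3Trainer._is_pair(lm_batch)    # False -> _forward_eval (loss for ppl)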

View File

@@ -52,6 +52,16 @@ class TestFinetuneTextGeneration(unittest.TestCase):
                'batch_size_per_gpu': 16,
                'workers_per_gpu': 1
            }
            cfg.train.hooks.append({
                'type': 'EvaluationHook',
                'by_epoch': True,
                'interval': 1
            })
            cfg.evaluation.dataloader = {
                'batch_size_per_gpu': 8,
                'workers_per_gpu': 1
            }
            cfg.evaluation.metrics = 'ppl'
            return cfg

        kwargs = dict(
@@ -73,6 +83,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
    def test_finetune_dureader(self):
        # DuReader_robust-QG is an example data set,
        # users can also use their own data set for training
        dataset_dict = MsDataset.load('DuReader_robust-QG')

        train_dataset = dataset_dict['train'].remap_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \
@@ -81,6 +92,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
            .map(lambda example: {'src_txt': example['src_txt'].replace('[SEP]', '<sep>') + '\n'})

        max_epochs = 10
        tmp_dir = './gpt3_dureader'
        num_warmup_steps = 200
@@ -98,7 +110,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
                    'by_epoch': False
                }
            }
            cfg.train.optimizer = {'type': 'AdamW', 'lr': 1e-4}
            cfg.train.dataloader = {
                'batch_size_per_gpu': 16,
                'workers_per_gpu': 1
            }