Mirror of https://github.com/modelscope/modelscope.git
Synced 2025-12-24 20:19:22 +01:00

[to #42322933] Add beam search and pair finetune for GPT-3

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11397726
* test finetune weather
* support ppl and generation metrics

@@ -798,8 +798,8 @@ class GPT3Model(PreTrainedModel):
         if labels is not None:
             # [b s] => [s b]
             labels = labels.transpose(0, 1).contiguous()
-            losses = mpu.vocab_parallel_cross_entropy(logits_parallel.float(),
-                                                      labels)
+            losses = mpu.vocab_parallel_cross_entropy(
+                logits_parallel.clone().float(), labels)
             # [s b] => [b s]
             losses = losses.transpose(0, 1).contiguous()

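The per-token loss still runs in [s, b] layout; the only change is that the logits are cloned before the cast, presumably so the parallel cross-entropy (which may normalize its input in place) does not touch the logits that are also returned to the caller. Below is a standalone single-GPU sketch of the same shape handling, with plain F.cross_entropy standing in for the tensor-parallel mpu.vocab_parallel_cross_entropy; the function name and shapes are illustrative, not part of the patch.

    import torch
    import torch.nn.functional as F

    def token_losses(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # logits: [batch, seq, vocab], labels: [batch, seq]
        labels = labels.transpose(0, 1).contiguous()          # [b s] => [s b]
        logits = logits.transpose(0, 1).contiguous().float()  # [b s v] => [s b v]
        losses = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), labels.reshape(-1),
            reduction='none')
        losses = losses.view(labels.size())                   # per-token loss, [s b]
        return losses.transpose(0, 1).contiguous()            # [s b] => [b s]
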
@@ -1011,7 +1011,8 @@ class DistributedGPT3(TorchModel):
                 attention_mask=None,
                 position_ids=None,
                 labels=None,
-                prompt_length=None):
+                prompt_length=None,
+                is_pair=(False, )):

         logits, losses = self.dist_model(
             tokens,

@@ -1026,6 +1027,9 @@ class DistributedGPT3(TorchModel):
         else:
             loss_mask = torch.ones(
                 tokens.size(), dtype=torch.float, device=tokens.device)
+        if is_pair[0]:
+            for i, length in enumerate(prompt_length):
+                loss_mask[i, :length] = 0

         losses = losses.float()
         loss_mask = loss_mask.view(-1).float()

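This is the whole pair-finetuning trick on the model side: prompt tokens are masked out of the loss so only the target continuation contributes to the gradient. A standalone sketch with toy tensors (not part of the patch; the masked-mean reduction at the end is the usual convention and the exact reduction lives outside this hunk):

    import torch

    losses = torch.rand(2, 6)              # [batch, seq] per-token losses
    prompt_length = torch.tensor([4, 2])   # first `length` tokens are the prompt
    loss_mask = torch.ones_like(losses)
    for i, length in enumerate(prompt_length):
        loss_mask[i, :length] = 0          # zero out the prompt positions

    # Masked mean: only target positions drive the gradient.
    loss = torch.sum(losses.view(-1) * loss_mask.view(-1)) / loss_mask.sum()
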
@@ -1033,13 +1037,13 @@ class DistributedGPT3(TorchModel):
         return TextGenerationModelOutput(logits=logits, loss=loss)

-    def generate(self,
-                 tokens,
-                 temperature=1.0,
-                 use_eod_token_for_early_termination=True,
-                 stop_on_double_eol=False,
-                 stop_on_eol=False,
-                 **kwargs):
+    def sample(self,
+               tokens,
+               temperature=1.0,
+               use_eod_token_for_early_termination=True,
+               stop_on_double_eol=False,
+               stop_on_eol=False,
+               **kwargs):
         batch_size = tokens.size(0)
         lengths = kwargs.pop(
             'prompt_length',

@@ -1074,68 +1078,182 @@ class DistributedGPT3(TorchModel):
         # Run inference
         # =============

-        with torch.no_grad():
-            attention_mask, position_ids = \
-                GPT3Model.build_attention_mask_and_position_ids(tokens)
-            prev_context_length = 0
-            for context_length in range(min_prompt_length,
-                                        max_sequence_length):
+        attention_mask, position_ids = \
+            GPT3Model.build_attention_mask_and_position_ids(tokens)
+        prev_context_length = 0
+        for context_length in range(min_prompt_length, max_sequence_length):

-                # Pick the slice that we need to pass through the network.
-                tokens2use = tokens[:, prev_context_length:context_length]
-                positions2use = position_ids[:, prev_context_length:
-                                             context_length]
-                attention_mask2use = attention_mask[
-                    ..., prev_context_length:context_length, :context_length]
+            # Pick the slice that we need to pass through the network.
+            tokens2use = tokens[:, prev_context_length:context_length]
+            positions2use = position_ids[:, prev_context_length:context_length]
+            attention_mask2use = attention_mask[
+                ..., prev_context_length:context_length, :context_length]

-                # logits will be meaningful only in the last pipeline stage.
-                logits = self(tokens2use, attention_mask2use,
-                              positions2use).logits
+            # logits will be meaningful only in the last pipeline stage.
+            logits = self(tokens2use, attention_mask2use, positions2use).logits

-                # Sample.
-                last_token_logits = logits[:, -1, :]
-                new_sample = sample(
-                    last_token_logits,
-                    top_k=self.config.top_k,
-                    top_p=self.config.top_p,
-                    temperature=temperature,
-                    vocab_size=self.config.vocab_size)
+            # Sample.
+            last_token_logits = logits[:, -1, :]
+            new_sample = sample(
+                last_token_logits,
+                top_k=self.config.top_k,
+                top_p=self.config.top_p,
+                temperature=temperature,
+                vocab_size=self.config.vocab_size)

-                # If a prompt length is smaller than or equal to the current
-                # context length, it means we have started generating tokens.
-                started = lengths <= context_length
-                # Update the tokens.
-                tokens[started, context_length] = new_sample[started]
+            # If a prompt length is smaller than or equal to the current
+            # context length, it means we have started generating tokens.
+            started = lengths <= context_length
+            # Update the tokens.
+            tokens[started, context_length] = new_sample[started]

-                # Update the context length for the next token generation.
-                prev_context_length = context_length
+            # Update the context length for the next token generation.
+            prev_context_length = context_length

-                # Instead, tokenization should be in the inference loop so stop sequences can be used.
-                if stop_on_double_eol:
-                    hit_double_eol = (new_sample
-                                      == 628).byte() & started.byte()
-                    hit_two_eols = (new_sample == 198).byte() & (
-                        tokens[:, context_length - 1]
-                        == 198).byte() & started.byte()
-                    done_token = hit_double_eol | hit_two_eols
-                elif stop_on_eol:
-                    hit_double_eol = (new_sample
-                                      == 628).byte() & started.byte()
-                    hit_eol = (new_sample == 198).byte() & started.byte()
-                    done_token = hit_double_eol | hit_eol
-                else:
-                    done_token = (new_sample == termination_id).byte() & \
-                        started.byte()
+            # Instead, tokenization should be in the inference loop so stop sequences can be used.
+            if stop_on_double_eol:
+                hit_double_eol = (new_sample == 628).byte() & started.byte()
+                hit_two_eols = (new_sample == 198).byte() & (
+                    tokens[:,
+                           context_length - 1] == 198).byte() & started.byte()
+                done_token = hit_double_eol | hit_two_eols
+            elif stop_on_eol:
+                hit_double_eol = (new_sample == 628).byte() & started.byte()
+                hit_eol = (new_sample == 198).byte() & started.byte()
+                done_token = hit_double_eol | hit_eol
+            else:
+                done_token = (new_sample == termination_id).byte() & \
+                    started.byte()

-                is_generation_done = is_generation_done | done_token
-                done = torch.all(is_generation_done)
+            is_generation_done = is_generation_done | done_token
+            done = torch.all(is_generation_done)

-                if use_eod_token_for_early_termination and done:
-                    break
+            if use_eod_token_for_early_termination and done:
+                break

         tokens = tokens[:, :(context_length + 1)]
         return TokenGeneratorOutput(sequences=tokens)

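The termination bookkeeping above is plain boolean mask algebra over the batch; here is a standalone toy version (hypothetical token ids and lengths, not the model's API):

    import torch

    termination_id = 2                       # e.g. the EOD token
    lengths = torch.tensor([3, 5])           # prompt length of each sample
    context_length = 4
    new_sample = torch.tensor([2, 7])        # tokens sampled at this step

    # A sample only counts as finished once it has started generating
    # (prompt consumed) and then emits the termination token.
    started = lengths <= context_length                                  # [True, False]
    done_token = (new_sample == termination_id).byte() & started.byte()

    is_generation_done = torch.zeros(2, dtype=torch.uint8)
    is_generation_done = is_generation_done | done_token
    done = torch.all(is_generation_done)     # False: the second sample is still open
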
+    def beam_search(self, tokens, beam_size=5, num_return_gen=1, **kwargs):
+        batch_size = tokens.size(0)
+        assert (batch_size == 1)
+        prompt_length = kwargs.pop(
+            'prompt_length',
+            torch.tensor([tokens.size(1)], device=tokens.device)).item()
+        stop_token = self.config.eod_id
+        pads = torch.ones(
+            1, self.config.tokens_to_generate,
+            device=tokens.device).long() * stop_token
+        tokens = torch.cat((tokens, pads), dim=-1)
+        final_sequence_length = tokens.size(1)
+        final_sequence_length = min(final_sequence_length,
+                                    self.config.max_position_embeddings)
+
+        # If the context is too big, this happens.
+        if prompt_length >= final_sequence_length:
+            raise ValueError('context length + tokens_to_generate too large')
+
+        # Initialize inference parameters.
+        self.inference_params = InferenceParams(beam_size,
+                                                final_sequence_length)
+
+        beam_hyp = BeamHypotheses(beam_size)
+        done = False
+        scores = torch.zeros(
+            beam_size, dtype=torch.float32,
+            device=torch.cuda.current_device()).unsqueeze(1)

+        # =============
+        # Run inference
+        # =============
+        tokens = tokens.repeat(beam_size, 1)
+        attention_mask, position_ids = \
+            GPT3Model.build_attention_mask_and_position_ids(tokens)
+        prev_context_length = 0
+        for context_length in range(prompt_length, final_sequence_length):
+
+            # Pick the slice that we need to pass through the network.
+            tokens2use = tokens[:, prev_context_length:context_length]
+            positions2use = position_ids[:, prev_context_length:context_length]
+            attention_mask2use = attention_mask[
+                ..., prev_context_length:context_length, :context_length]
+
+            # logits will be meaningful only in the last pipeline stage.
+            logits = self(tokens2use, attention_mask2use, positions2use).logits
+
+            vocab_size = logits.size(2)
+            log_probs = F.log_softmax(logits, dim=2)
+            new_scores = log_probs[:, -1, :] + scores
+
+            if context_length == prompt_length:  # if this is the first one
+                sorted_scores, indices = torch.sort(
+                    new_scores[0, :], descending=True)
+            else:
+                sorted_scores, indices = torch.sort(
+                    new_scores.view(-1), descending=True)
+
+            best_beam_ids = torch.div(indices[:2 * beam_size],
+                                      vocab_size).trunc().long()
+            best_words = indices[:2 * beam_size] % vocab_size
+            best_scores = sorted_scores[:2 * beam_size]
+
+            next_beams = []
+            for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
+                    zip(best_words, best_scores, best_beam_ids)):
+                if token_id.item() == stop_token:
+                    # If the beam token does not belong to the top num_beams tokens, it should not be added.
+                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
+                    if is_beam_token_worse_than_top_num_beams:
+                        continue
+                    beam_hyp.add(tokens[beam_id].clone(), beam_score,
+                                 context_length + 1 - prompt_length)
+                else:
+                    # Add the next predicted token since it is not the eos_token.
+                    next_beams.append((token_id, beam_score, beam_id))
+
+                if len(next_beams) == beam_size:
+                    break
+
+            if beam_hyp.is_done(best_scores.max().item(),
+                                context_length + 1 - prompt_length):
+                done = True
+                break

+            best_batches = tokens.new([item[2] for item in next_beams])
+            tokens = tokens[best_batches, :]
+            tokens[:, context_length] = tokens.new(
+                [item[0] for item in next_beams])
+            scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
+
+            # Set the inference key/values to be consistent with the best beam indices.
+            self.inference_params.swap_key_value_dict(best_batches)
+
+            # Update the context length for the next token generation.
+            prev_context_length = context_length
+
+        # If no stop token was found, add the open beams to the hypotheses.
+        if not done:
+            for beam_id in range(beam_size):
+                beam_hyp.add(tokens[beam_id].clone(), scores[beam_id],
+                             context_length + 1 - prompt_length)
+
+        # Rank based on scores.
+        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
+        num_return_gen = min(num_return_gen, len(sorted_hyps))
+        scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
+        tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
+        scores = torch.stack(scores, dim=0)
+        tokens = torch.stack(tokens, dim=0)
+
+        return TokenGeneratorOutput(sequences=tokens, scores=scores)
+
+    @torch.no_grad()
+    def generate(self, tokens, do_sample=True, *args, **kwargs):
+        if do_sample:
+            return self.sample(tokens, *args, **kwargs)
+        else:
+            return self.beam_search(tokens, *args, **kwargs)

     def state_dict(self, destination=None, prefix='', keep_vars=False):
         return self.dist_model.state_dict(destination, prefix, keep_vars)

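The core of each beam step is sorting the joint [beam, vocab] score matrix and mapping the flat indices of the best 2 * beam_size candidates back to (beam id, token id) pairs. A standalone sketch with toy sizes follows (not part of the patch; torch.div(..., rounding_mode='floor') is used here in place of the patch's .trunc().long(), which is equivalent for these non-negative indices):

    import torch

    beam_size, vocab_size = 3, 10
    new_scores = torch.randn(beam_size, vocab_size)   # per-beam cumulative log-probs

    # Flatten, sort, and keep 2 * beam_size candidates so that finished
    # (EOD) candidates can be skipped while still filling beam_size beams.
    sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
    best_beam_ids = torch.div(indices[:2 * beam_size], vocab_size,
                              rounding_mode='floor')
    best_words = indices[:2 * beam_size] % vocab_size
    best_scores = sorted_scores[:2 * beam_size]

    for rank, (tok, score, beam) in enumerate(
            zip(best_words, best_scores, best_beam_ids)):
        print(rank, int(beam), int(tok), float(score))

With the dispatcher added above, callers choose the strategy via generate(tokens, do_sample=False, beam_size=5) for beam search, or do_sample=True (the default) for sampling.
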
@@ -1154,3 +1272,59 @@ class DistributedGPT3(TorchModel):

         return super().save_pretrained(target_folder, save_checkpoint_names,
                                        save_checkpoint, config, **kwargs)
+
+
+class BeamHypotheses:
+
+    def __init__(self,
+                 num_beams: int,
+                 length_penalty: float = 1.0,
+                 early_stopping: bool = False):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.num_beams = num_beams
+        self.beams = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.beams)

+    def add(self,
+            hyp: torch.LongTensor,
+            sum_logprobs: float,
+            beam_indices: Optional[torch.LongTensor] = None):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / (hyp.shape[-1]**self.length_penalty)
+        if len(self) < self.num_beams or score > self.worst_score:
+            self.beams.append((score, hyp, beam_indices))
+            if len(self) > self.num_beams:
+                sorted_next_scores = sorted([
+                    (s, idx) for idx, (s, _, _) in enumerate(self.beams)
+                ])
+                del self.beams[sorted_next_scores[0][1]]
+                self.worst_score = sorted_next_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
+        """
+        If there are enough hypotheses and none of the hypotheses being
+        generated can become better than the worst one in the heap, then we
+        are done with this sentence.
+        """
+
+        if len(self) < self.num_beams:
+            return False
+        elif self.early_stopping:
+            return True
+        else:
+            cur_score = best_sum_logprobs / cur_len**self.length_penalty
+            ret = self.worst_score >= cur_score
+            return ret

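A quick usage sketch of the new container (toy tensors, not part of the patch, and assuming BeamHypotheses from the hunk above is in scope): scores are summed log-probs, add() normalizes them by length, and is_done() answers whether the search can stop.

    import torch

    beam_hyp = BeamHypotheses(num_beams=2)
    beam_hyp.add(torch.tensor([5, 8, 2]), sum_logprobs=-1.2)
    beam_hyp.add(torch.tensor([5, 9, 7, 2]), sum_logprobs=-1.0)

    # True here: two hypotheses are stored and the best open candidate
    # (-3.0 over 4 tokens) cannot beat the worst stored normalized score.
    print(beam_hyp.is_done(best_sum_logprobs=-3.0, cur_len=4))

    # Rank the stored (score, hyp, beam_indices) tuples, best first.
    best_score, best_hyp, _ = sorted(
        beam_hyp.beams, key=lambda x: x[0], reverse=True)[0]
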
@@ -27,7 +27,7 @@ class GPT3ForTextGeneration(TorchModel):
         # Temporarily compatible with DistributedGPT3 and GPT3Model,
         # the base/large model based on GPT3Model will be replaced in the future,
         # and GPT3Model will be deprecated
-        if 'model_parallel_size' in kwargs:
+        if 'world_size' in kwargs:
            from modelscope.models.nlp import DistributedGPT3
            self.model = DistributedGPT3(model_dir, **kwargs)
         else:

@@ -194,16 +194,11 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
                  model_dir: str,
                  mode: str = ModeKeys.INFERENCE,
                  src_txt='src_txt',
-                 tgt_txt=None,
+                 tgt_txt='tgt_txt',
                  sequence_length: int = 128,
                  use_fast=None):
         from modelscope.models.nlp.gpt3 import JiebaBPETokenizer
         super().__init__(mode, src_txt, tgt_txt)
-        if self.tgt_txt is not None:
-            logger.warning(
-                f'TextGenerationJiebaPreprocessor currently does not support training, '
-                f'the {self.tgt_txt} of the tgt_txt field will be ignored.')
-        self.src_txt = src_txt
         self.tokenizer = JiebaBPETokenizer(
             osp.join(model_dir, 'tokenizer.json'))
         self.max_length = sequence_length

@@ -252,6 +247,7 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase):
             'tokens': tokens[:-1],
             'labels': tokens[1:],
             'prompt_length': prompt_length,
+            'is_pair': int(sequence2 is not None),
         }

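For a src/tgt pair, the preprocessor output now looks roughly like the dict below (hypothetical token ids, not produced by the real tokenizer): tokens and labels are the same sequence shifted by one for next-token prediction, prompt_length marks where the target starts, and is_pair tells the trainer whether the prompt should be masked out of the loss and evaluated with generation metrics.

    example = {
        'tokens': [101, 734, 88, 19, 550, 43],   # input ids, all but the last
        'labels': [734, 88, 19, 550, 43, 7],     # input ids, all but the first
        'prompt_length': 4,                      # length of the src_txt part
        'is_pair': 1,                            # 1 when a tgt_txt was supplied
    }
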
@@ -2,27 +2,25 @@

 import os
 from collections.abc import Mapping
-from typing import List
+from typing import Any, Dict, List

 import torch
 from megatron_util import mpu

 from modelscope.metainfo import Trainers
 from modelscope.models import TorchModel
 from modelscope.models.nlp import GPT3ForTextGeneration
 from modelscope.trainers.builder import TRAINERS
 from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
 from modelscope.utils.config import Config
 from modelscope.utils.file_utils import func_receive_dict_inputs


 @TRAINERS.register_module(module_name=Trainers.gpt3_trainer)
 class GPT3Trainer(NlpEpochBasedTrainer):

     def rebuild_config(self, cfg: Config):
-        super().rebuild_config(cfg)
-        cfg.model.rank = int(os.environ.get('LOCAL_RANK', -1))
-        cfg.model.master_ip = os.environ.get('MASTER_ADDR', '127.0.0.1')
-        cfg.model.master_port = os.environ.get('MASTER_PORT', '29500')
+        cfg = super().rebuild_config(cfg)
+        cfg.model.rank = int(os.environ.get('RANK', 0))
         return cfg

     def train_step(self, model: TorchModel, inputs: Mapping):

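For context (not part of the patch): with the standard torch.distributed launchers, RANK is the global worker index across all nodes while LOCAL_RANK is the index within a single node, and MASTER_ADDR/MASTER_PORT are exported by the launcher itself; presumably that is why rebuild_config now only copies the global rank into the model config. A minimal sketch of what the trainer now reads:

    import os

    # torchrun (or an equivalent launcher) exports RANK for every worker;
    # on a plain single-process run the default of 0 is used.
    global_rank = int(os.environ.get('RANK', 0))
    # cfg.model.rank = global_rank   # what GPT3Trainer.rebuild_config does
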
@@ -39,13 +37,19 @@ class GPT3Trainer(NlpEpochBasedTrainer):
         model = self.model.module if self._dist else self.model
         model.eval()

-        with torch.no_grad():
-            if isinstance(
-                    data,
-                    Mapping) and not func_receive_dict_inputs(model.generate):
-                result = model.generate(**data)
-            else:
-                result = model.generate(data)
+        if self._is_pair(data):
+            return self._generate_eval(model, data)
+        else:
+            return self._forward_eval(model, data)
+
+    @staticmethod
+    def _is_pair(data: Dict[str, Any]) -> bool:
+        return 'is_pair' in data and bool(data['is_pair'][0])
+
+    def _generate_eval(self, model: GPT3ForTextGeneration,
+                       data: Dict[str, Any]) -> Dict[str, Any]:
+        data['do_sample'] = False
+        result = model.generate(data)

         prompt_length: List[int] = data['prompt_length']
         result['preds'] = [

@@ -56,6 +60,8 @@ class GPT3Trainer(NlpEpochBasedTrainer):
             self._decode(seq[skip_len - 1:])
             for seq, skip_len in zip(data['labels'], prompt_length)
         ]
+        assert len(result['preds']) == len(data['tgts'])

         return result

+    def _forward_eval(self, model: GPT3ForTextGeneration,
+                      data: Dict[str, Any]) -> Dict[str, Any]:
+        return model.forward(data)

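In short, evaluation now branches on the preprocessor's is_pair flag: pair batches are decoded with deterministic generation and scored with generation metrics, while plain LM batches go through a forward pass for perplexity. A condensed, self-contained restatement of that dispatch (hypothetical batch dicts, not the trainer's API):

    def is_pair(data) -> bool:
        # Same check as GPT3Trainer._is_pair.
        return 'is_pair' in data and bool(data['is_pair'][0])

    pair_batch = {'prompt_length': [4], 'is_pair': [1]}
    lm_batch = {'prompt_length': [6], 'is_pair': [0]}

    assert is_pair(pair_batch)       # -> _generate_eval, do_sample=False
    assert not is_pair(lm_batch)     # -> _forward_eval, used for ppl
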
@@ -52,6 +52,16 @@ class TestFinetuneTextGeneration(unittest.TestCase):
                 'batch_size_per_gpu': 16,
                 'workers_per_gpu': 1
             }
+            cfg.train.hooks.append({
+                'type': 'EvaluationHook',
+                'by_epoch': True,
+                'interval': 1
+            })
+            cfg.evaluation.dataloader = {
+                'batch_size_per_gpu': 8,
+                'workers_per_gpu': 1
+            }
+            cfg.evaluation.metrics = 'ppl'
             return cfg

         kwargs = dict(

@@ -73,6 +83,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
     def test_finetune_dureader(self):
         # DuReader_robust-QG is an example data set,
         # users can also use their own data set for training
         dataset_dict = MsDataset.load('DuReader_robust-QG')

         train_dataset = dataset_dict['train'].remap_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \

@@ -81,6 +92,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
             .map(lambda example: {'src_txt': example['src_txt'].replace('[SEP]', '<sep>') + '\n'})

         max_epochs = 10
         tmp_dir = './gpt3_dureader'

         num_warmup_steps = 200

@@ -98,7 +110,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
                     'by_epoch': False
                 }
             }
-            cfg.train.optimizer = {'type': 'AdamW', 'lr': 3e-4}
+            cfg.train.optimizer = {'type': 'AdamW', 'lr': 1e-4}
             cfg.train.dataloader = {
                 'batch_size_per_gpu': 16,
                 'workers_per_gpu': 1