diff --git a/modelscope/models/nlp/gpt3/distributed_gpt3.py b/modelscope/models/nlp/gpt3/distributed_gpt3.py index ceb8c218..e469f866 100644 --- a/modelscope/models/nlp/gpt3/distributed_gpt3.py +++ b/modelscope/models/nlp/gpt3/distributed_gpt3.py @@ -798,8 +798,8 @@ class GPT3Model(PreTrainedModel): if labels is not None: # [b s] => [s b] labels = labels.transpose(0, 1).contiguous() - losses = mpu.vocab_parallel_cross_entropy(logits_parallel.float(), - labels) + losses = mpu.vocab_parallel_cross_entropy( + logits_parallel.clone().float(), labels) # [s b] => [b s] losses = losses.transpose(0, 1).contiguous() @@ -1011,7 +1011,8 @@ class DistributedGPT3(TorchModel): attention_mask=None, position_ids=None, labels=None, - prompt_length=None): + prompt_length=None, + is_pair=(False, )): logits, losses = self.dist_model( tokens, @@ -1026,6 +1027,9 @@ class DistributedGPT3(TorchModel): else: loss_mask = torch.ones( tokens.size(), dtype=torch.float, device=tokens.device) + if is_pair[0]: + for i, length in enumerate(prompt_length): + loss_mask[i, :length] = 0 losses = losses.float() loss_mask = loss_mask.view(-1).float() @@ -1033,13 +1037,13 @@ class DistributedGPT3(TorchModel): return TextGenerationModelOutput(logits=logits, loss=loss) - def generate(self, - tokens, - temperature=1.0, - use_eod_token_for_early_termination=True, - stop_on_double_eol=False, - stop_on_eol=False, - **kwargs): + def sample(self, + tokens, + temperature=1.0, + use_eod_token_for_early_termination=True, + stop_on_double_eol=False, + stop_on_eol=False, + **kwargs): batch_size = tokens.size(0) lengths = kwargs.pop( 'prompt_length', @@ -1074,68 +1078,182 @@ class DistributedGPT3(TorchModel): # Run infernece # ============= - with torch.no_grad(): - attention_mask, position_ids = \ - GPT3Model.build_attention_mask_and_position_ids(tokens) - prev_context_length = 0 - for context_length in range(min_prompt_length, - max_sequence_length): + attention_mask, position_ids = \ + GPT3Model.build_attention_mask_and_position_ids(tokens) + prev_context_length = 0 + for context_length in range(min_prompt_length, max_sequence_length): - # Pick the slice that we need to pass through the network. - tokens2use = tokens[:, prev_context_length:context_length] - positions2use = position_ids[:, prev_context_length: - context_length] - attention_mask2use = attention_mask[ - ..., prev_context_length:context_length, :context_length] + # Pick the slice that we need to pass through the network. + tokens2use = tokens[:, prev_context_length:context_length] + positions2use = position_ids[:, prev_context_length:context_length] + attention_mask2use = attention_mask[ + ..., prev_context_length:context_length, :context_length] - # logits will be meanigful only in the last pipeline stage. - logits = self(tokens2use, attention_mask2use, - positions2use).logits + # logits will be meanigful only in the last pipeline stage. + logits = self(tokens2use, attention_mask2use, positions2use).logits - # Sample. - last_token_logits = logits[:, -1, :] - new_sample = sample( - last_token_logits, - top_k=self.config.top_k, - top_p=self.config.top_p, - temperature=temperature, - vocab_size=self.config.vocab_size) + # Sample. 
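+            # Only the logits at the final position are needed for the next
+            # token: sample() applies the configured top_k/top_p filtering and
+            # the requested temperature before picking one token per sequence.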
+ last_token_logits = logits[:, -1, :] + new_sample = sample( + last_token_logits, + top_k=self.config.top_k, + top_p=self.config.top_p, + temperature=temperature, + vocab_size=self.config.vocab_size) - # If a prompt length is smaller or equal th current context - # length, it means we have started generating tokens - started = lengths <= context_length - # Update the tokens. - tokens[started, context_length] = new_sample[started] + # If a prompt length is smaller or equal th current context + # length, it means we have started generating tokens + started = lengths <= context_length + # Update the tokens. + tokens[started, context_length] = new_sample[started] - # Update the context length for the next token generation. - prev_context_length = context_length + # Update the context length for the next token generation. + prev_context_length = context_length - # instead tokenization should be in the inference loop so stop sequences can be used - if stop_on_double_eol: - hit_double_eol = (new_sample - == 628).byte() & started.byte() - hit_two_eols = (new_sample == 198).byte() & ( - tokens[:, context_length - 1] - == 198).byte() & started.byte() - done_token = hit_double_eol | hit_two_eols - elif stop_on_eol: - hit_double_eol = (new_sample - == 628).byte() & started.byte() - hit_eol = (new_sample == 198).byte() & started.byte() - done_token = hit_double_eol | hit_eol - else: - done_token = (new_sample == termination_id).byte() & \ - started.byte() + # instead tokenization should be in the inference loop so stop sequences can be used + if stop_on_double_eol: + hit_double_eol = (new_sample == 628).byte() & started.byte() + hit_two_eols = (new_sample == 198).byte() & ( + tokens[:, + context_length - 1] == 198).byte() & started.byte() + done_token = hit_double_eol | hit_two_eols + elif stop_on_eol: + hit_double_eol = (new_sample == 628).byte() & started.byte() + hit_eol = (new_sample == 198).byte() & started.byte() + done_token = hit_double_eol | hit_eol + else: + done_token = (new_sample == termination_id).byte() & \ + started.byte() - is_generation_done = is_generation_done | done_token - done = torch.all(is_generation_done) + is_generation_done = is_generation_done | done_token + done = torch.all(is_generation_done) - if use_eod_token_for_early_termination and done: - break + if use_eod_token_for_early_termination and done: + break tokens = tokens[:, :(context_length + 1)] return TokenGeneratorOutput(sequences=tokens) + def beam_search(self, tokens, beam_size=5, num_return_gen=1, **kwargs): + batch_size = tokens.size(0) + assert (batch_size == 1) + prompt_length = kwargs.pop( + 'prompt_length', + torch.tensor([tokens.size(1)], device=tokens.device)).item() + stop_token = self.config.eod_id + pads = torch.ones( + 1, self.config.tokens_to_generate, + device=tokens.device).long() * stop_token + tokens = torch.cat((tokens, pads), dim=-1) + final_sequence_length = tokens.size(1) + final_sequence_length = min(final_sequence_length, + self.config.max_position_embeddings) + + # If the context is too big, this happens + if prompt_length >= final_sequence_length: + raise ValueError('context length + tokens_to_generate too large') + + # Initialize inference parameters. 
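+        # The inference cache is sized for beam_size hypotheses of up to
+        # final_sequence_length tokens, so cached key/value states can be
+        # reordered later via swap_key_value_dict when beams are re-ranked.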
+ self.inference_params = InferenceParams(beam_size, + final_sequence_length) + + beam_hyp = BeamHypotheses(beam_size) + done = False + scores = torch.zeros( + beam_size, dtype=torch.float32, + device=torch.cuda.current_device()).unsqueeze(1) + + # ============= + # Run infernece + # ============= + tokens = tokens.repeat(beam_size, 1) + attention_mask, position_ids = \ + GPT3Model.build_attention_mask_and_position_ids(tokens) + prev_context_length = 0 + for context_length in range(prompt_length, final_sequence_length): + + # Pick the slice that we need to pass through the network. + tokens2use = tokens[:, prev_context_length:context_length] + positions2use = position_ids[:, prev_context_length:context_length] + attention_mask2use = attention_mask[ + ..., prev_context_length:context_length, :context_length] + + # logits will be meanigful only in the last pipeline stage. + logits = self(tokens2use, attention_mask2use, positions2use).logits + + vocab_size = logits.size(2) + log_probs = F.log_softmax(logits, dim=2) + new_scores = log_probs[:, -1, :] + scores + + if context_length == prompt_length: # if this is the first one + sorted_scores, indices = torch.sort( + new_scores[0, :], descending=True) + else: + sorted_scores, indices = torch.sort( + new_scores.view(-1), descending=True) + + best_beam_ids = torch.div(indices[:2 * beam_size], + vocab_size).trunc().long() + best_words = indices[:2 * beam_size] % vocab_size + best_scores = sorted_scores[:2 * beam_size] + + next_beams = [] + for beam_token_rank, (token_id, beam_score, beam_id) in enumerate( + zip(best_words, best_scores, best_beam_ids)): + if token_id.item() == stop_token: + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size + if is_beam_token_worse_than_top_num_beams: + continue + beam_hyp.add(tokens[beam_id].clone(), beam_score, + context_length + 1 - prompt_length) + else: + # add next predicted token since it is not eos_token + next_beams.append((token_id, beam_score, beam_id)) + + if len(next_beams) == beam_size: + break + + if beam_hyp.is_done(best_scores.max().item(), + context_length + 1 - prompt_length): + done = True + break + + best_batches = tokens.new([item[2] for item in next_beams]) + tokens = tokens[best_batches, :] + tokens[:, context_length] = tokens.new( + [item[0] for item in next_beams]) + scores = scores.new([item[1] for item in next_beams]).unsqueeze(1) + + # set inference key values to make it consistent with best beam index + self.inference_params.swap_key_value_dict(best_batches) + + # Update the context length for the next token generation. 
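+            # On the next iteration only tokens[:, prev_context_length:context_length]
+            # is fed through the model; earlier positions are served from the
+            # cached inference state rather than being recomputed.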
+ prev_context_length = context_length + + # if cannot find stop token, add open beams to hyps + if not done: + for beam_id in range(beam_size): + beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], + context_length + 1 - prompt_length) + + # rank based on scores + sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True) + num_return_gen = min(num_return_gen, len(sorted_hyps)) + scores = [sorted_hyps[i][0] for i in range(num_return_gen)] + tokens = [sorted_hyps[i][1] for i in range(num_return_gen)] + scores = torch.stack(scores, dim=0) + tokens = torch.stack(tokens, dim=0) + + return TokenGeneratorOutput(sequences=tokens, scores=scores) + + @torch.no_grad() + def generate(self, tokens, do_sample=True, *args, **kwargs): + if do_sample: + return self.sample(tokens, *args, **kwargs) + else: + return self.beam_search(tokens, *args, **kwargs) + def state_dict(self, destination=None, prefix='', keep_vars=False): return self.dist_model.state_dict(destination, prefix, keep_vars) @@ -1154,3 +1272,59 @@ class DistributedGPT3(TorchModel): return super().save_pretrained(target_folder, save_checkpoint_names, save_checkpoint, config, **kwargs) + + +class BeamHypotheses: + + def __init__(self, + num_beams: int, + length_penalty: float = 1.0, + early_stopping: bool = False): + """ + Initialize n-best list of hypotheses. + """ + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.num_beams = num_beams + self.beams = [] + self.worst_score = 1e9 + + def __len__(self): + """ + Number of hypotheses in the list. + """ + return len(self.beams) + + def add(self, + hyp: torch.LongTensor, + sum_logprobs: float, + beam_indices: Optional[torch.LongTensor] = None): + """ + Add a new hypothesis to the list. + """ + score = sum_logprobs / (hyp.shape[-1]**self.length_penalty) + if len(self) < self.num_beams or score > self.worst_score: + self.beams.append((score, hyp, beam_indices)) + if len(self) > self.num_beams: + sorted_next_scores = sorted([ + (s, idx) for idx, (s, _, _) in enumerate(self.beams) + ]) + del self.beams[sorted_next_scores[0][1]] + self.worst_score = sorted_next_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: + """ + If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst + one in the heap, then we are done with this sentence. 
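+        The comparison normalizes best_sum_logprobs by cur_len ** length_penalty;
+        when early_stopping is disabled, the search continues until that
+        normalized score can no longer beat the worst finished hypothesis.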
+ """ + + if len(self) < self.num_beams: + return False + elif self.early_stopping: + return True + else: + cur_score = best_sum_logprobs / cur_len**self.length_penalty + ret = self.worst_score >= cur_score + return ret diff --git a/modelscope/models/nlp/gpt3/text_generation.py b/modelscope/models/nlp/gpt3/text_generation.py index 9361f0a2..27ce09d6 100644 --- a/modelscope/models/nlp/gpt3/text_generation.py +++ b/modelscope/models/nlp/gpt3/text_generation.py @@ -27,7 +27,7 @@ class GPT3ForTextGeneration(TorchModel): # Temporarily compatible with DistributedGPT3 and GPT3Model, # the base/large model based on GPT3Model will be replaced in the future, # and GPT3Model will be deprecated - if 'model_parallel_size' in kwargs: + if 'world_size' in kwargs: from modelscope.models.nlp import DistributedGPT3 self.model = DistributedGPT3(model_dir, **kwargs) else: diff --git a/modelscope/preprocessors/nlp/text_generation_preprocessor.py b/modelscope/preprocessors/nlp/text_generation_preprocessor.py index 65697b72..5f30b70a 100644 --- a/modelscope/preprocessors/nlp/text_generation_preprocessor.py +++ b/modelscope/preprocessors/nlp/text_generation_preprocessor.py @@ -194,16 +194,11 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase): model_dir: str, mode: str = ModeKeys.INFERENCE, src_txt='src_txt', - tgt_txt=None, + tgt_txt='tgt_txt', sequence_length: int = 128, use_fast=None): from modelscope.models.nlp.gpt3 import JiebaBPETokenizer super().__init__(mode, src_txt, tgt_txt) - if self.tgt_txt is not None: - logger.warning( - f'TextGenerationJiebaPreprocessor currently does not support training, ' - f'the {self.tgt_txt} of the tgt_txt field will be ignored.') - self.src_txt = src_txt self.tokenizer = JiebaBPETokenizer( osp.join(model_dir, 'tokenizer.json')) self.max_length = sequence_length @@ -252,6 +247,7 @@ class TextGenerationJiebaPreprocessor(TextGenerationPreprocessorBase): 'tokens': tokens[:-1], 'labels': tokens[1:], 'prompt_length': prompt_length, + 'is_pair': int(sequence2 is not None), } diff --git a/modelscope/trainers/nlp/gpt3_trainer.py b/modelscope/trainers/nlp/gpt3_trainer.py index d4db0b3d..afda6424 100644 --- a/modelscope/trainers/nlp/gpt3_trainer.py +++ b/modelscope/trainers/nlp/gpt3_trainer.py @@ -2,27 +2,25 @@ import os from collections.abc import Mapping -from typing import List +from typing import Any, Dict, List import torch from megatron_util import mpu from modelscope.metainfo import Trainers from modelscope.models import TorchModel +from modelscope.models.nlp import GPT3ForTextGeneration from modelscope.trainers.builder import TRAINERS from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer from modelscope.utils.config import Config -from modelscope.utils.file_utils import func_receive_dict_inputs @TRAINERS.register_module(module_name=Trainers.gpt3_trainer) class GPT3Trainer(NlpEpochBasedTrainer): def rebuild_config(self, cfg: Config): - super().rebuild_config(cfg) - cfg.model.rank = int(os.environ.get('LOCAL_RANK', -1)) - cfg.model.master_ip = os.environ.get('MASTER_ADDR', '127.0.0.1') - cfg.model.master_port = os.environ.get('MASTER_PORT', '29500') + cfg = super().rebuild_config(cfg) + cfg.model.rank = int(os.environ.get('RANK', 0)) return cfg def train_step(self, model: TorchModel, inputs: Mapping): @@ -39,13 +37,19 @@ class GPT3Trainer(NlpEpochBasedTrainer): model = self.model.module if self._dist else self.model model.eval() - with torch.no_grad(): - if isinstance( - data, - Mapping) and not func_receive_dict_inputs(model.generate): - result = 
model.generate(**data) - else: - result = model.generate(data) + if self._is_pair(data): + return self._generate_eval(model, data) + else: + return self._forward_eval(model, data) + + @staticmethod + def _is_pair(data: Dict[str, Any]) -> bool: + return 'is_pair' in data and bool(data['is_pair'][0]) + + def _generate_eval(self, model: GPT3ForTextGeneration, + data: Dict[str, Any]) -> Dict[str, Any]: + data['do_sample'] = False + result = model.generate(data) prompt_length: List[int] = data['prompt_length'] result['preds'] = [ @@ -56,6 +60,8 @@ class GPT3Trainer(NlpEpochBasedTrainer): self._decode(seq[skip_len - 1:]) for seq, skip_len in zip(data['labels'], prompt_length) ] - assert len(result['preds']) == len(data['tgts']) - return result + + def _forward_eval(self, model: GPT3ForTextGeneration, + data: Dict[str, Any]) -> Dict[str, Any]: + return model.forward(data) diff --git a/tests/trainers/test_finetune_gpt3.py b/tests/trainers/test_finetune_gpt3.py index 7a9e03d0..563d271c 100644 --- a/tests/trainers/test_finetune_gpt3.py +++ b/tests/trainers/test_finetune_gpt3.py @@ -52,6 +52,16 @@ class TestFinetuneTextGeneration(unittest.TestCase): 'batch_size_per_gpu': 16, 'workers_per_gpu': 1 } + cfg.train.hooks.append({ + 'type': 'EvaluationHook', + 'by_epoch': True, + 'interval': 1 + }) + cfg.evaluation.dataloader = { + 'batch_size_per_gpu': 8, + 'workers_per_gpu': 1 + } + cfg.evaluation.metrics = 'ppl' return cfg kwargs = dict( @@ -73,6 +83,7 @@ class TestFinetuneTextGeneration(unittest.TestCase): def test_finetune_dureader(self): # DuReader_robust-QG is an example data set, # users can also use their own data set for training + dataset_dict = MsDataset.load('DuReader_robust-QG') train_dataset = dataset_dict['train'].remap_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) \ @@ -81,6 +92,7 @@ class TestFinetuneTextGeneration(unittest.TestCase): .map(lambda example: {'src_txt': example['src_txt'].replace('[SEP]', '') + '\n'}) max_epochs = 10 + tmp_dir = './gpt3_dureader' num_warmup_steps = 200 @@ -98,7 +110,7 @@ class TestFinetuneTextGeneration(unittest.TestCase): 'by_epoch': False } } - cfg.train.optimizer = {'type': 'AdamW', 'lr': 3e-4} + cfg.train.optimizer = {'type': 'AdamW', 'lr': 1e-4} cfg.train.dataloader = { 'batch_size_per_gpu': 16, 'workers_per_gpu': 1