diff --git a/modelscope/models/nlp/gpt3/distributed_gpt3.py b/modelscope/models/nlp/gpt3/distributed_gpt3.py
index 424e43b4..232e1187 100644
--- a/modelscope/models/nlp/gpt3/distributed_gpt3.py
+++ b/modelscope/models/nlp/gpt3/distributed_gpt3.py
@@ -957,7 +957,7 @@ def split_state_dict(state_dict: Dict[str, torch.Tensor], model: GPT3Model,
     return state_dict
 
 
-def save_checkpoint(model: torch.nn.Module, filename: str) -> None:
+def save_checkpoint(model: torch.nn.Module, filename: str, **kwargs) -> None:
     if isinstance(model, torch.nn.parallel.DistributedDataParallel):
         model = model.module
 
@@ -1147,14 +1147,21 @@ class DistributedGPT3(TorchModel):
         tokens = tokens[:, :(context_length + 1)]
         return TokenGeneratorOutput(sequences=tokens)
 
-    def state_dict(self):
-        return self.dist_model.state_dict()
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        return self.dist_model.state_dict(destination, prefix, keep_vars)
 
     def save_pretrained(self,
                         target_folder: Union[str, os.PathLike],
                         save_checkpoint_names: Union[str, List[str]] = None,
-                        save_function: Callable = save_checkpoint,
+                        save_function: Callable = None,
                         config: Optional[dict] = None,
                         **kwargs):
+        # DistributedPipeline type is different from task name
+        config['pipeline']['type'] = 'gpt3-generation'
+        # a temp fix for master_ip, master_port and rank
+        # can be removed after refactoring megatron_util
+        for unused_key in ('master_ip', 'master_port', 'rank'):
+            config['model'].pop(unused_key, None)
+
         return super().save_pretrained(target_folder, save_checkpoint_names,
-                                       save_function, config, **kwargs)
+                                       save_checkpoint, config, **kwargs)
diff --git a/modelscope/models/nlp/gpt3/text_generation.py b/modelscope/models/nlp/gpt3/text_generation.py
index 74335de6..0d6f33b5 100644
--- a/modelscope/models/nlp/gpt3/text_generation.py
+++ b/modelscope/models/nlp/gpt3/text_generation.py
@@ -70,3 +70,8 @@ class GPT3ForTextGeneration(TorchModel):
         gen_params['top_p'] = input.pop('top_p', None)
         sample_output = self.model.generate(**gen_params)
         return {'sequences': sample_output[0]}
+
+    def save_pretrained(self, *args, **kwargs):
+        if not isinstance(self.model, GPT3Model):
+            return self.model.save_pretrained(*args, **kwargs)
+        return super().save_pretrained(*args, **kwargs)