[to #42322933] Fix saved checkpoint can't run with pipeline for gpt3

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11225426
This commit is contained in:
hemu.zp
2022-12-28 06:37:54 +08:00
committed by yingda.chen
parent 59b7f411b8
commit addda1f613
2 changed files with 17 additions and 5 deletions

View File

@@ -957,7 +957,7 @@ def split_state_dict(state_dict: Dict[str, torch.Tensor], model: GPT3Model,
return state_dict
def save_checkpoint(model: torch.nn.Module, filename: str) -> None:
def save_checkpoint(model: torch.nn.Module, filename: str, **kwargs) -> None:
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
model = model.module
@@ -1147,14 +1147,21 @@ class DistributedGPT3(TorchModel):
tokens = tokens[:, :(context_length + 1)]
return TokenGeneratorOutput(sequences=tokens)
def state_dict(self):
return self.dist_model.state_dict()
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.dist_model.state_dict(destination, prefix, keep_vars)
def save_pretrained(self,
target_folder: Union[str, os.PathLike],
save_checkpoint_names: Union[str, List[str]] = None,
save_function: Callable = save_checkpoint,
save_function: Callable = None,
config: Optional[dict] = None,
**kwargs):
# DistributedPipeline type is different from task name
config['pipeline']['type'] = 'gpt3-generation'
# a temp fix for master_ip, master_port and rank
# can be removed after refactoring megatron_util
for unused_key in ('master_ip', 'master_port', 'rank'):
config['model'].pop(unused_key, None)
return super().save_pretrained(target_folder, save_checkpoint_names,
save_function, config, **kwargs)
save_checkpoint, config, **kwargs)

View File

@@ -70,3 +70,8 @@ class GPT3ForTextGeneration(TorchModel):
gen_params['top_p'] = input.pop('top_p', None)
sample_output = self.model.generate(**gen_params)
return {'sequences': sample_output[0]}
def save_pretrained(self, *args, **kwargs):
if not isinstance(self.model, GPT3Model):
return self.model.save_pretrained(*args, **kwargs)
return super().save_pretrained(*args, **kwargs)