mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
[to #42322933] Fix saved checkpoint failing to run with the pipeline for GPT-3
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11225426
This commit is contained in:
@@ -957,7 +957,7 @@ def split_state_dict(state_dict: Dict[str, torch.Tensor], model: GPT3Model,
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_checkpoint(model: torch.nn.Module, filename: str) -> None:
|
||||
def save_checkpoint(model: torch.nn.Module, filename: str, **kwargs) -> None:
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
model = model.module
|
||||
|
||||
@@ -1147,14 +1147,21 @@ class DistributedGPT3(TorchModel):
|
||||
tokens = tokens[:, :(context_length + 1)]
|
||||
return TokenGeneratorOutput(sequences=tokens)
|
||||
|
||||
def state_dict(self):
|
||||
return self.dist_model.state_dict()
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
return self.dist_model.state_dict(destination, prefix, keep_vars)
|
||||
|
||||
def save_pretrained(self,
|
||||
target_folder: Union[str, os.PathLike],
|
||||
save_checkpoint_names: Union[str, List[str]] = None,
|
||||
save_function: Callable = save_checkpoint,
|
||||
save_function: Callable = None,
|
||||
config: Optional[dict] = None,
|
||||
**kwargs):
|
||||
# DistributedPipeline type is different from task name
|
||||
config['pipeline']['type'] = 'gpt3-generation'
|
||||
# a temp fix for master_ip, master_port and rank
|
||||
# can be removed after refactoring megatron_util
|
||||
for unused_key in ('master_ip', 'master_port', 'rank'):
|
||||
config['model'].pop(unused_key, None)
|
||||
|
||||
return super().save_pretrained(target_folder, save_checkpoint_names,
|
||||
save_function, config, **kwargs)
|
||||
save_checkpoint, config, **kwargs)
|
||||
|
||||
@@ -70,3 +70,8 @@ class GPT3ForTextGeneration(TorchModel):
|
||||
gen_params['top_p'] = input.pop('top_p', None)
|
||||
sample_output = self.model.generate(**gen_params)
|
||||
return {'sequences': sample_output[0]}
|
||||
|
||||
def save_pretrained(self, *args, **kwargs):
|
||||
if not isinstance(self.model, GPT3Model):
|
||||
return self.model.save_pretrained(*args, **kwargs)
|
||||
return super().save_pretrained(*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user