Fix loading checkpoint for palm

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11607362
This commit is contained in:
hemu.zp
2023-02-09 12:24:51 +00:00
parent 9619c5fc83
commit b56b2ecc28


@@ -638,10 +638,7 @@ class AbsSummarizer(PalmPreTrainedModel): # Model
         self.generator.dense.weight = self.decoder.embeddings.weight
         if checkpoint is not None:
-            if 'model' in checkpoint:
-                checkpoint = checkpoint['model']
-            for key in list(checkpoint.keys()):
-                checkpoint[key.replace('model.palm.', '')] = checkpoint[key]
+            checkpoint = self._unwrap_checkpoint(checkpoint)
             self.load_state_dict(checkpoint, strict=False)
         else:
             for module in self.decoder.modules():
@@ -673,6 +670,17 @@ class AbsSummarizer(PalmPreTrainedModel): # Model
         self.decoder.embeddings = tgt_embeddings
         self.generator.dense.weight = self.decoder.embeddings.weight
 
+    @staticmethod
+    def _unwrap_checkpoint(checkpoint: Dict):
+        wrap_names = ('model', 'palm')
+        for name in wrap_names:
+            if name in checkpoint:
+                checkpoint = checkpoint[name]
+        for name in wrap_names:
+            checkpoint = {(k[len(name) + 1:] if k.startswith(name) else k): v
+                          for k, v in checkpoint.items()}
+        return checkpoint
+
     def forward(self, src, tgt, mask_src):
         top_vec, _ = self.bert(src, mask_src, return_dict=False)
         state = TransformerDecoderState(src)
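
For context on what the new helper does: judging from the removed lines, the old inline logic only handled a checkpoint nested under a 'model' key and keys carrying the combined 'model.palm.' prefix, and it copied the renamed keys next to the originals rather than replacing them. The new _unwrap_checkpoint helper instead peels off each wrapper level ('model', then 'palm') and strips each key prefix in turn. Below is a minimal standalone sketch of its behavior; the sample state-dict keys are made up for illustration:

    from typing import Dict


    def unwrap_checkpoint(checkpoint: Dict) -> Dict:
        # Standalone copy of AbsSummarizer._unwrap_checkpoint, for illustration only.
        wrap_names = ('model', 'palm')
        # 1) Descend through wrapper dicts, e.g. ckpt['model']['palm'] -> state dict.
        for name in wrap_names:
            if name in checkpoint:
                checkpoint = checkpoint[name]
        # 2) Strip the 'model.' and then the 'palm.' prefix from the remaining keys.
        for name in wrap_names:
            checkpoint = {(k[len(name) + 1:] if k.startswith(name) else k): v
                          for k, v in checkpoint.items()}
        return checkpoint


    # Hypothetical checkpoint mixing both wrapping conventions:
    ckpt = {'model': {'model.palm.encoder.weight': 1, 'palm.decoder.weight': 2}}
    print(unwrap_checkpoint(ckpt))
    # -> {'encoder.weight': 1, 'decoder.weight': 2}

Since load_state_dict is called with strict=False, any keys the unwrap leaves unmatched are tolerated, so the helper does not need to cover every possible wrapping.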