mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
Fix loading checkpoint for palm
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11607362
This commit is contained in:
@@ -638,10 +638,7 @@ class AbsSummarizer(PalmPreTrainedModel): # Model
|
||||
self.generator.dense.weight = self.decoder.embeddings.weight
|
||||
|
||||
if checkpoint is not None:
|
||||
if 'model' in checkpoint:
|
||||
checkpoint = checkpoint['model']
|
||||
for key in list(checkpoint.keys()):
|
||||
checkpoint[key.replace('model.palm.', '')] = checkpoint[key]
|
||||
checkpoint = self._unwrap_checkpoint(checkpoint)
|
||||
self.load_state_dict(checkpoint, strict=False)
|
||||
else:
|
||||
for module in self.decoder.modules():
|
||||
@@ -673,6 +670,17 @@ class AbsSummarizer(PalmPreTrainedModel): # Model
|
||||
self.decoder.embeddings = tgt_embeddings
|
||||
self.generator.dense.weight = self.decoder.embeddings.weight
|
||||
|
||||
@staticmethod
def _unwrap_checkpoint(checkpoint: Dict):
    """Remove ``model`` / ``palm`` wrapping from a loaded checkpoint.

    Two kinds of wrapping are undone, each in the order ``model`` then
    ``palm``:

    1. Nesting: if the mapping contains a key equal to the wrapper name,
       the value stored under that key replaces the mapping.
    2. Key prefixes: a leading ``'<name>.'`` is removed from every key,
       so ``'model.palm.bert.w'`` becomes ``'bert.w'``.

    Args:
        checkpoint: Mapping loaded from a checkpoint file, possibly
            nested under and/or prefixed with the wrapper names.

    Returns:
        A mapping with the wrappers removed. When any prefix stripping
        is attempted a new dict is returned; the input is not mutated.
    """
    wrap_names = ('model', 'palm')

    # Descend through nested containers first (e.g. {'model': {...}}).
    for name in wrap_names:
        if name in checkpoint:
            checkpoint = checkpoint[name]

    # Then strip dotted prefixes. The separator is part of the match so
    # that keys which merely begin with the same letters (for example
    # 'modeling.x' or 'palm_head.y') are left untouched instead of being
    # truncated.
    for name in wrap_names:
        prefix = name + '.'
        checkpoint = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint.items()
        }
    return checkpoint
|
||||
|
||||
def forward(self, src, tgt, mask_src):
|
||||
top_vec, _ = self.bert(src, mask_src, return_dict=False)
|
||||
state = TransformerDecoderState(src)
|
||||
|
||||
Reference in New Issue
Block a user