diff --git a/modelscope/models/nlp/palm_v2/text_generation.py b/modelscope/models/nlp/palm_v2/text_generation.py
index 5bb446b5..a87b5cdd 100644
--- a/modelscope/models/nlp/palm_v2/text_generation.py
+++ b/modelscope/models/nlp/palm_v2/text_generation.py
@@ -638,10 +638,7 @@ class AbsSummarizer(PalmPreTrainedModel):  # Model
         self.generator.dense.weight = self.decoder.embeddings.weight

         if checkpoint is not None:
-            if 'model' in checkpoint:
-                checkpoint = checkpoint['model']
-            for key in list(checkpoint.keys()):
-                checkpoint[key.replace('model.palm.', '')] = checkpoint[key]
+            checkpoint = self._unwrap_checkpoint(checkpoint)
             self.load_state_dict(checkpoint, strict=False)
         else:
             for module in self.decoder.modules():
@@ -673,6 +670,17 @@ class AbsSummarizer(PalmPreTrainedModel):  # Model
             self.decoder.embeddings = tgt_embeddings
             self.generator.dense.weight = self.decoder.embeddings.weight

+    @staticmethod
+    def _unwrap_checkpoint(checkpoint: Dict):
+        wrap_names = ('model', 'palm')
+        for name in wrap_names:
+            if name in checkpoint:
+                checkpoint = checkpoint[name]
+        for name in wrap_names:
+            checkpoint = {(k[len(name) + 1:] if k.startswith(name) else k): v
+                          for k, v in checkpoint.items()}
+        return checkpoint
+
     def forward(self, src, tgt, mask_src):
         top_vec, _ = self.bert(src, mask_src, return_dict=False)
         state = TransformerDecoderState(src)
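For reference, a minimal usage sketch (not part of the patch) of what the new _unwrap_checkpoint helper does; the checkpoint keys below are hypothetical examples of a wrapped state dict:

    import torch

    from modelscope.models.nlp.palm_v2.text_generation import AbsSummarizer

    # Hypothetical checkpoint saved by a training wrapper: the state dict is nested
    # under a 'model' entry and its keys carry a 'palm.' prefix.
    wrapped = {
        'model': {
            'palm.encoder.layer.0.weight': torch.zeros(1),
            'palm.generator.dense.weight': torch.zeros(1),
        }
    }
    flat = AbsSummarizer._unwrap_checkpoint(wrapped)
    # flat now has keys 'encoder.layer.0.weight' and 'generator.dense.weight':
    # the 'model'/'palm' wrappers are stripped so that
    # load_state_dict(flat, strict=False) can match the module's own parameter names.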