From 2ebb6db2604813f0863dee4b5730d43dbd34d48e Mon Sep 17 00:00:00 2001 From: "zhangyanzhao.zyz" Date: Tue, 21 Feb 2023 22:42:13 +0800 Subject: [PATCH] fix bug of modelscope.trainers.nlp.sentence_embedding_trainer.get_data_collator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修正sentence_embedding_trainer 初始化bug Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11725227 --- modelscope/trainers/nlp/sentence_embedding_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modelscope/trainers/nlp/sentence_embedding_trainer.py b/modelscope/trainers/nlp/sentence_embedding_trainer.py index b2116443..ebf8b50a 100644 --- a/modelscope/trainers/nlp/sentence_embedding_trainer.py +++ b/modelscope/trainers/nlp/sentence_embedding_trainer.py @@ -86,7 +86,7 @@ class SentenceEmbeddingTrainer(NlpEpochBasedTrainer): model_revision=model_revision, **kwargs) - def get_data_collator(self, data_collator): + def get_data_collator(self, data_collator, **kwargs): """Get the data collator for both training and evaluating. Args: @@ -99,7 +99,7 @@ class SentenceEmbeddingTrainer(NlpEpochBasedTrainer): data_collator = SentenceEmbeddingCollator( tokenizer=self.train_preprocessor.nlp_tokenizer, max_length=self.train_preprocessor.max_length) - return super().get_data_collator(data_collator) + return super().get_data_collator(data_collator, **kwargs) def evauate(self): return {}