From 2c994ed7600abb1ffd32ab8bd681dd7909a6e0d8 Mon Sep 17 00:00:00 2001
From: "wenshen.xws"
Date: Wed, 26 Oct 2022 16:18:27 +0800
Subject: [PATCH] [to #42322933]fix tokenizer for faq
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Multilingual FAQ: add tokenizer type detection so the preprocessor can
select XLMRobertaTokenizer instead of BertTokenizer when configured.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10530690
---
 .../nlp/faq_question_answering_preprocessor.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py b/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py
index 72c8ed99..873a8448 100644
--- a/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py
+++ b/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py
@@ -18,11 +18,19 @@ class FaqQuestionAnsweringPreprocessor(NLPBasePreprocessor):
     def __init__(self, model_dir: str, *args, **kwargs):
         super(FaqQuestionAnsweringPreprocessor, self).__init__(
             model_dir, mode=ModeKeys.INFERENCE, **kwargs)
+
         from transformers import BertTokenizer
-        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
+
         preprocessor_config = Config.from_file(
             os.path.join(model_dir, ModelFile.CONFIGURATION)).get(
                 ConfigFields.preprocessor, {})
+        if preprocessor_config.get('tokenizer',
+                                   'BertTokenizer') == 'XLMRoberta':
+            from transformers import XLMRobertaTokenizer
+            self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_dir)
+        else:
+            self.tokenizer = BertTokenizer.from_pretrained(model_dir)
+
         self.MAX_LEN = preprocessor_config.get('max_seq_length', 50)
         self.label_dict = None
 
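
Usage note (illustration only, not part of the patch): a minimal sketch,
assuming a locally downloaded multilingual FAQ model directory whose
configuration.json carries {"preprocessor": {"tokenizer": "XLMRoberta", ...}}
("preprocessor" is what ConfigFields.preprocessor is assumed to resolve to),
and that the directory also holds the matching tokenizer files. The path
below is a placeholder.

    from modelscope.preprocessors.nlp.faq_question_answering_preprocessor import (
        FaqQuestionAnsweringPreprocessor)

    # Placeholder model directory; must contain configuration.json plus tokenizer files.
    preprocessor = FaqQuestionAnsweringPreprocessor('/path/to/multilingual_faq_model')

    # With this patch, a "tokenizer": "XLMRoberta" entry yields an XLMRobertaTokenizer;
    # any other value (or a missing key) keeps the previous BertTokenizer behavior.
    print(type(preprocessor.tokenizer).__name__)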