diff --git a/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py b/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py
index 72c8ed99..873a8448 100644
--- a/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py
+++ b/modelscope/preprocessors/nlp/faq_question_answering_preprocessor.py
@@ -18,11 +18,19 @@ class FaqQuestionAnsweringPreprocessor(NLPBasePreprocessor):
     def __init__(self, model_dir: str, *args, **kwargs):
         super(FaqQuestionAnsweringPreprocessor, self).__init__(
             model_dir, mode=ModeKeys.INFERENCE, **kwargs)
+
         from transformers import BertTokenizer
-        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
+
         preprocessor_config = Config.from_file(
             os.path.join(model_dir, ModelFile.CONFIGURATION)).get(
                 ConfigFields.preprocessor, {})

+        if preprocessor_config.get('tokenizer',
+                                   'BertTokenizer') == 'XLMRoberta':
+            from transformers import XLMRobertaTokenizer
+            self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_dir)
+        else:
+            self.tokenizer = BertTokenizer.from_pretrained(model_dir)
+
         self.MAX_LEN = preprocessor_config.get('max_seq_length', 50)
         self.label_dict = None
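
Note on the change above: the tokenizer choice is driven entirely by the `preprocessor` section of the model's configuration file (`ModelFile.CONFIGURATION`, which in ModelScope resolves to `configuration.json`), so a multilingual FAQ model opts into the XLM-R vocabulary purely through its config, while models without a `tokenizer` key keep the old BERT behavior. A minimal Python sketch of the selection semantics follows; the config values are illustrative, not taken from any real model card:

# Illustrative 'preprocessor' sections, as returned by
# Config.from_file(...).get(ConfigFields.preprocessor, {}).
legacy_cfg = {'max_seq_length': 50}  # no 'tokenizer' key
xlmr_cfg = {'tokenizer': 'XLMRoberta', 'max_seq_length': 50}

# A missing or unrecognized value falls back to 'BertTokenizer',
# so existing model configs load exactly as before.
assert legacy_cfg.get('tokenizer', 'BertTokenizer') == 'BertTokenizer'
assert xlmr_cfg.get('tokenizer', 'BertTokenizer') == 'XLMRoberta'

Either way the tokenizer is loaded via `from_pretrained(model_dir)`, so the matching vocabulary files (e.g. `sentencepiece.bpe.model` for XLM-R) must ship inside the same model directory.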