diff --git a/modelscope/models/nlp/masked_language_model.py b/modelscope/models/nlp/masked_language_model.py index 4138da94..928bda7b 100644 --- a/modelscope/models/nlp/masked_language_model.py +++ b/modelscope/models/nlp/masked_language_model.py @@ -19,6 +19,12 @@ class MaskedLMModelBase(Model): def build_model(self): raise NotImplementedError() + def train(self): + return self.model.train() + + def eval(self): + return self.model.eval() + @property def config(self): if hasattr(self.model, "config"): diff --git a/modelscope/models/nlp/palm_for_text_generation.py b/modelscope/models/nlp/palm_for_text_generation.py index c0f66bad..f6c15387 100644 --- a/modelscope/models/nlp/palm_for_text_generation.py +++ b/modelscope/models/nlp/palm_for_text_generation.py @@ -26,6 +26,12 @@ class PalmForTextGeneration(Model): self.tokenizer = model.tokenizer self.generator = Translator(model) + def train(self): + return self.generator.train() + + def eval(self): + return self.generator.eval() + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: """return the result by the model diff --git a/modelscope/models/nlp/sbert_for_token_classification.py b/modelscope/models/nlp/sbert_for_token_classification.py index 50c2195b..80d99283 100644 --- a/modelscope/models/nlp/sbert_for_token_classification.py +++ b/modelscope/models/nlp/sbert_for_token_classification.py @@ -29,6 +29,12 @@ class SbertForTokenClassification(Model): self.model_dir) self.config = sofa.SbertConfig.from_pretrained(self.model_dir) + def train(self): + return self.model.train() + + def eval(self): + return self.model.eval() + def forward(self, input: Dict[str, Any]) -> Dict[str, Union[str, np.ndarray]]: """return the result by the model diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index d19b4f20..397d25eb 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -314,7 +314,7 @@ class TextGenerationPreprocessor(Preprocessor): rst['input_ids'].append(feature['input_ids']) rst['attention_mask'].append(feature['attention_mask']) - rst['token_type_ids'].append(feature['token_type_ids']) + # rst['token_type_ids'].append(feature['token_type_ids']) return {k: torch.tensor(v) for k, v in rst.items()}