all ut passed

This commit is contained in:
雨泓
2022-06-23 10:54:57 +08:00
parent 96e25be7d2
commit 1476e08b82
4 changed files with 19 additions and 1 deletions

View File

@@ -19,6 +19,12 @@ class MaskedLMModelBase(Model):
def build_model(self):
raise NotImplementedError()
def train(self):
return self.model.train()
def eval(self):
return self.model.eval()
@property
def config(self):
if hasattr(self.model, "config"):

View File

@@ -26,6 +26,12 @@ class PalmForTextGeneration(Model):
self.tokenizer = model.tokenizer
self.generator = Translator(model)
def train(self):
return self.generator.train()
def eval(self):
return self.generator.eval()
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model

View File

@@ -29,6 +29,12 @@ class SbertForTokenClassification(Model):
self.model_dir)
self.config = sofa.SbertConfig.from_pretrained(self.model_dir)
def train(self):
return self.model.train()
def eval(self):
return self.model.eval()
def forward(self, input: Dict[str,
Any]) -> Dict[str, Union[str, np.ndarray]]:
"""return the result by the model

View File

@@ -314,7 +314,7 @@ class TextGenerationPreprocessor(Preprocessor):
rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask'])
rst['token_type_ids'].append(feature['token_type_ids'])
# rst['token_type_ids'].append(feature['token_type_ids'])
return {k: torch.tensor(v) for k, v in rst.items()}