Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-24 03:59:23 +01:00)

Commit: all ut passed
@@ -19,6 +19,12 @@ class MaskedLMModelBase(Model):
     def build_model(self):
         raise NotImplementedError()
 
+    def train(self):
+        return self.model.train()
+
+    def eval(self):
+        return self.model.eval()
+
     @property
     def config(self):
         if hasattr(self.model, "config"):
@@ -26,6 +26,12 @@ class PalmForTextGeneration(Model):
         self.tokenizer = model.tokenizer
         self.generator = Translator(model)
 
+    def train(self):
+        return self.generator.train()
+
+    def eval(self):
+        return self.generator.eval()
+
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model
 
@@ -29,6 +29,12 @@ class SbertForTokenClassification(Model):
             self.model_dir)
         self.config = sofa.SbertConfig.from_pretrained(self.model_dir)
 
+    def train(self):
+        return self.model.train()
+
+    def eval(self):
+        return self.model.eval()
+
     def forward(self, input: Dict[str,
                                   Any]) -> Dict[str, Union[str, np.ndarray]]:
         """return the result by the model
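The three hunks above all add the same delegation: each wrapper class gains train()/eval() passthroughs to the torch module (or generator) it holds, so callers can toggle training/inference mode without reaching into the wrapped object. A minimal, self-contained sketch of that pattern, with a hypothetical Backbone class standing in for the real models (the names here are illustrative, not from the repo):

from torch import nn


class Backbone(nn.Module):
    """Hypothetical stand-in for the wrapped backbone (not from the diff)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class WrapperModel:
    """Sketch of the wrapper pattern in the diff: the wrapper is not an
    nn.Module itself, so train()/eval() are forwarded to the module it owns."""

    def __init__(self):
        self.model = Backbone()

    def train(self):
        # Switch the wrapped module to training mode (enables dropout etc.).
        return self.model.train()

    def eval(self):
        # Switch the wrapped module to inference mode.
        return self.model.eval()


wrapper = WrapperModel()
wrapper.eval()
print(wrapper.model.training)  # False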
@@ -314,7 +314,7 @@ class TextGenerationPreprocessor(Preprocessor):
 
             rst['input_ids'].append(feature['input_ids'])
             rst['attention_mask'].append(feature['attention_mask'])
-            rst['token_type_ids'].append(feature['token_type_ids'])
+            # rst['token_type_ids'].append(feature['token_type_ids'])
         return {k: torch.tensor(v) for k, v in rst.items()}
 
 
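The preprocessor hunk above touches the collation step where per-example feature lists are stacked into tensors; after the change, token_type_ids is no longer appended to the batch. A small stand-alone sketch of that collation, using made-up feature dicts instead of real tokenizer output:

import torch

# Made-up tokenizer features for two examples; the real preprocessor builds
# these with its tokenizer. Per the diff, token_type_ids is left out.
features = [
    {'input_ids': [101, 2023, 102], 'attention_mask': [1, 1, 1]},
    {'input_ids': [101, 7592, 102], 'attention_mask': [1, 1, 1]},
]

rst = {'input_ids': [], 'attention_mask': []}
for feature in features:
    rst['input_ids'].append(feature['input_ids'])
    rst['attention_mask'].append(feature['attention_mask'])

# Equal-length rows become a 2-D LongTensor per key.
batch = {k: torch.tensor(v) for k, v in rst.items()}
print(batch['input_ids'].shape)  # torch.Size([2, 3])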