add default args

This commit is contained in:
雨泓
2022-06-22 21:54:41 +08:00
parent 3987a3bf7d
commit 96e25be7d2
2 changed files with 2 additions and 2 deletions

View File

@@ -18,6 +18,6 @@ class SbertForSentenceSimilarity(SbertForSequenceClassificationBase):
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None.
"""
super().__init__(model_dir, *args, **kwargs)
super().__init__(model_dir, *args, model_args={"num_labels": 2}, **kwargs)
self.model_dir = model_dir
assert self.model.config.num_labels == 2

View File

@@ -19,5 +19,5 @@ class SbertForSentimentClassification(SbertForSequenceClassificationBase):
model_cls (Optional[Any], optional): model loader, if None, use the
default loader to load model weights, by default None.
"""
super().__init__(model_dir, *args, **kwargs)
super().__init__(model_dir, *args, model_args={"num_labels": 2}, **kwargs)
assert self.model.config.num_labels == 2