[to #46619305] add kwargs in init method to allow additional kwargs

This commit is contained in:
zhangzhicheng.zzc
2022-12-07 18:42:29 +08:00
parent f59f9146de
commit 92c5abb076
5 changed files with 5 additions and 5 deletions

View File

@@ -57,7 +57,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight',
]
def __init__(self, config: T5Config):
def __init__(self, config: T5Config, **kwargs):
super().__init__(config)
self.model_dim = config.d_model

View File

@@ -24,7 +24,7 @@ class BertForDocumentSegmentation(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r'pooler']
def __init__(self, config):
def __init__(self, config, **kwargs):
super().__init__(config)
self.num_labels = config.num_labels
self.sentence_pooler_type = None

View File

@@ -11,7 +11,7 @@ from .backbone import BertModel, BertPreTrainedModel
@MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert)
class BertForSentenceEmbedding(BertPreTrainedModel):
def __init__(self, config):
def __init__(self, config, **kwargs):
super().__init__(config)
self.config = config
setattr(self, self.base_model_prefix,

View File

@@ -66,7 +66,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
weights.
"""
def __init__(self, config):
def __init__(self, config, **kwargs):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config

View File

@@ -25,7 +25,7 @@ __all__ = ['PoNetForDocumentSegmentation']
class PoNetForDocumentSegmentation(PoNetPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r'pooler']
def __init__(self, config):
def __init__(self, config, **kwargs):
super().__init__(config)
self.num_labels = config.num_labels