QiDeBERTa-CSC / configuration.py
from transformers import PretrainedConfig


class QiDeBERTaConfig(PretrainedConfig):
    model_type = "QiDeBERTa"

    # Map the canonical Hugging Face attribute names onto this config's own
    # field names, so generic code that reads e.g. `hidden_size` still works.
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
        "intermediate_size": "d_ff",
    }
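
    # Example: with the mapping above, `PretrainedConfig` resolves the mapped
    # names transparently in both directions, e.g.
    #   cfg = QiDeBERTaConfig(d_model=1024)
    #   assert cfg.hidden_size == 1024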
    def __init__(
        self,
        vocab_size=25500,
        d_model=1024,
        num_layers=24,
        num_heads=16,
        d_ff=4096,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        layer_norm_eps=1e-7,
        relative_attention=True,
        max_relative_positions=-1,
        classifier_num_labels=-1,
        unk_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=3,
        mask_token_id=4,
        position_biased_input=False,
        position_buckets=256,
        pos_att_type="p2c|c2p",
        share_att_key=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.relative_attention = relative_attention
        # As in DeBERTa-v2, -1 means "fall back to max_position_embeddings".
        self.max_relative_positions = max_relative_positions
        self.classifier_num_labels = classifier_num_labels
        self.unk_token_id = unk_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.position_biased_input = position_biased_input
        self.share_att_key = share_att_key
        self.position_buckets = position_buckets
        # Backwards compatibility: older configs stored pos_att_type as a
        # "|"-separated string; normalize it to a lowercase list,
        # e.g. "p2c|c2p" -> ["p2c", "c2p"].
        if isinstance(pos_att_type, str):
            pos_att_type = [x.strip() for x in pos_att_type.lower().split("|")]
        self.pos_att_type = pos_att_type
        self.vocab_size = vocab_size
        self.layer_norm_eps = layer_norm_eps
        self.pooler_hidden_size = kwargs.get("pooler_hidden_size", d_model)
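

# Minimal usage sketch (illustrative, not part of the published file): the
# config can be instantiated directly with the defaults defined above.
if __name__ == "__main__":
    cfg = QiDeBERTaConfig()
    # Prints "1024 24": attribute_map routes hidden_size -> d_model and
    # num_hidden_layers -> num_layers.
    print(cfg.hidden_size, cfg.num_hidden_layers)
    # Hypothetically, the published config could instead be loaded with
    # trust_remote_code (repo id below is an assumption):
    #   from transformers import AutoConfig
    #   cfg = AutoConfig.from_pretrained("Morton-Li/QiDeBERTa-CSC", trust_remote_code=True)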