|
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    MaskedLMOutput,
    SequenceClassifierOutput
)
from transformers.models.roberta.modeling_roberta import RobertaClassificationHead
|
|
|
|
|
|
|
|
|
class StructformerConfig(PretrainedConfig):
    """Configuration object for the StructFormer models.

    Each named constructor argument is stored unchanged as an attribute of the
    same name. Any additional keyword arguments are handed to
    ``PretrainedConfig.__init__``.

    Args:
      hidden_size: dimension of inputs and hidden states
      n_context_layers: stored and forwarded to the model constructors
      nlayers: number of transformer layers
      ntokens: vocabulary size / number of output categories
      nhead: number of self-attention heads
      dropout: dropout rate
      dropatt: drop attention rate
      relative_bias: bool, whether to use a relative position based attention
        bias
      pos_emb: bool, whether to use a learnable positional embedding
      pad: pad token index
      n_parser_layers: stored and forwarded to the model constructors
      conv_size: stored and forwarded to the model constructors
      relations: stored and forwarded to the model constructors
      weight_act: stored and forwarded to the model constructors
      **kwargs: forwarded to ``PretrainedConfig``
    """

    model_type = "structformer"

    def __init__(
        self,
        hidden_size=768,
        n_context_layers=2,
        nlayers=6,
        ntokens=32000,
        nhead=8,
        dropout=0.1,
        dropatt=0.1,
        relative_bias=False,
        pos_emb=False,
        pad=0,
        n_parser_layers=4,
        conv_size=9,
        relations=('head', 'child'),
        weight_act='softmax',
        **kwargs,
    ):
        # Collected in declaration order so attributes are created in the same
        # sequence as the signature lists them.
        own_fields = (
            ("hidden_size", hidden_size),
            ("n_context_layers", n_context_layers),
            ("nlayers", nlayers),
            ("ntokens", ntokens),
            ("nhead", nhead),
            ("dropout", dropout),
            ("dropatt", dropatt),
            ("relative_bias", relative_bias),
            ("pos_emb", pos_emb),
            ("pad", pad),
            ("n_parser_layers", n_parser_layers),
            ("conv_size", conv_size),
            ("relations", relations),
            ("weight_act", weight_act),
        )
        for field_name, field_value in own_fields:
            setattr(self, field_name, field_value)

        # The base class consumes whatever generic options remain.
        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
def _get_activation_fn(activation):
    """Get specified activation function.

    Args:
      activation: name of the activation, one of "relu", "gelu" or
        "leakyrelu"
    Returns:
      a newly constructed activation module
    Raises:
      RuntimeError: if `activation` is not one of the supported names
    """
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "leakyrelu": nn.LeakyReLU,
    }
    # The isinstance guard keeps unhashable arguments on the RuntimeError path
    # instead of failing inside the dict lookup.
    if isinstance(activation, str) and activation in factories:
        return factories[activation]()

    # The message lists every accepted name, including "leakyrelu".
    raise RuntimeError(
        "activation should be relu/gelu/leakyrelu, not {}".format(activation))
|
|
|
class Conv1d(nn.Module):
    """1D convolution layer that preserves the sequence length."""

    def __init__(self, hidden_size, kernel_size, dilation=1):
        """Initialization.

        Args:
          hidden_size: dimension of input embeddings (used as both the input
            and the output channel count)
          kernel_size: convolution kernel size
          dilation: the spacing between the kernel points
        """
        super().__init__()

        # An even kernel cannot be centred on a position, so the symmetric
        # padding over-pads and the surplus is trimmed in forward().
        self.shift = kernel_size % 2 == 0

        # (kernel_size // 2) equals ((kernel_size - 1) // 2) for odd kernels,
        # so one expression covers both parities.
        padding = (kernel_size // 2) * dilation

        self.conv = nn.Conv1d(
            hidden_size,
            hidden_size,
            kernel_size,
            padding=padding,
            dilation=dilation)

    def forward(self, x):
        """Compute convolution.

        Args:
          x: input embeddings, laid out as (batch, length, hidden_size)
        Returns:
          conv_output: convolution results with the same shape as x
        """
        conv_output = self.conv(x.transpose(1, 2)).transpose(1, 2)

        if self.shift:
            # With an even kernel the padded convolution yields
            # length + dilation positions; drop the leading `dilation` ones
            # so the output length matches the input for any dilation.
            return conv_output[:, self.conv.dilation[0]:]
        return conv_output
|
|
|
class MultiheadAttention(nn.Module):
    """Multi-head self-attention layer."""

    # Relative distances are offset by this amount before indexing the
    # learned bias table, whose width is _REL_TABLE_SIZE.
    _REL_OFFSET = 256
    _REL_TABLE_SIZE = 512

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 v_proj=True,
                 out_proj=True,
                 relative_bias=True):
        """Initialization.

        Args:
          embed_dim: dimension of input embeddings
          num_heads: number of self-attention heads
          dropout: dropout rate applied to the attention weights
          bias: bool, indicate whether include bias for linear transformations
          v_proj: bool, indicate whether project inputs to new values
          out_proj: bool, indicate whether project outputs to new values
          relative_bias: bool, indicate whether use a relative position based
            attention bias
        """
        super().__init__()
        self.embed_dim = embed_dim

        self.num_heads = num_heads
        self.drop = nn.Dropout(dropout)
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, (
            "embed_dim must be divisible by num_heads")

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        if v_proj:
            self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        else:
            self.v_proj = nn.Identity()

        if out_proj:
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        else:
            self.out_proj = nn.Identity()

        if relative_bias:
            self.relative_bias = nn.Parameter(
                torch.zeros((self.num_heads, self._REL_TABLE_SIZE)))
        else:
            self.relative_bias = None

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize attention parameters.

        Linear projections get Xavier-uniform weights and zero biases.
        Identity projections and absent biases (bias=False) are skipped.
        """
        for proj in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
            if not isinstance(proj, nn.Linear):
                continue
            init.xavier_uniform_(proj.weight)
            # bias is None when the layer was built with bias=False.
            if proj.bias is not None:
                init.constant_(proj.bias, 0.)

    def forward(self, query, key_padding_mask=None, attn_mask=None):
        """Compute multi-head self-attention.

        Args:
          query: input embeddings of shape (length, bsz, embed_dim)
          key_padding_mask: additive term for the attention scores, added
            before the softmax/sigmoid; must broadcast against
            (bsz * num_heads, length, length)
          attn_mask: mask of shape (bsz * num_heads, length, length) that
            rescales the attention weight at each position; when given, the
            scores go through a sigmoid instead of a softmax
        Returns:
          attn_output: self-attention output of shape (length, bsz, embed_dim)
        Raises:
          ValueError: if the relative bias is enabled and the sequence is
            longer than the bias table supports
        """
        length, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        head_dim = self.head_dim
        scaling = float(head_dim)**-0.5

        q = self.q_proj(query) * scaling
        k = self.k_proj(query)
        v = self.v_proj(query)

        if attn_mask is not None:
            assert list(attn_mask.size()) == [bsz * self.num_heads,
                                              length, length]

        # Split the embedding into heads. The merged leading dimension is
        # batch-major: index == batch * num_heads + head.
        q = q.contiguous().view(length, bsz * self.num_heads,
                                head_dim).transpose(0, 1)
        k = k.contiguous().view(length, bsz * self.num_heads,
                                head_dim).transpose(0, 1)
        v = v.contiguous().view(length, bsz * self.num_heads,
                                head_dim).transpose(0, 1)

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(
            attn_output_weights.size()) == [bsz * self.num_heads, length,
                                            length]

        if self.relative_bias is not None:
            max_length = self.relative_bias.size(1) - self._REL_OFFSET
            if length > max_length:
                raise ValueError(
                    "sequence length {} exceeds the relative bias table "
                    "(max {})".format(length, max_length))

            pos = torch.arange(length, device=query.device)
            relative_pos = torch.abs(pos[:, None] - pos[None, :])
            relative_pos = relative_pos + self._REL_OFFSET

            # (num_heads, length, length): one bias map per head.
            relative_bias = self.relative_bias[:, relative_pos]

            # Un-merge batch and head so each head gets its own bias map for
            # every batch element, matching the batch-major layout above.
            attn_output_weights = (
                attn_output_weights.reshape(bsz, self.num_heads, length,
                                            length) + relative_bias[None]
            ).reshape(bsz * self.num_heads, length, length)

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights + key_padding_mask

        if attn_mask is None:
            attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
        else:
            attn_output_weights = torch.sigmoid(
                attn_output_weights) * attn_mask

        attn_output_weights = self.drop(attn_output_weights)

        attn_output = torch.bmm(attn_output_weights, v)

        assert list(attn_output.size()) == [bsz * self.num_heads, length,
                                            head_dim]
        # Merge the heads back into the embedding dimension.
        attn_output = attn_output.transpose(0, 1).contiguous().view(
            length, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output
|
|
|
class TransformerLayer(nn.Module):
    """Encoder block: self-attention then a feedforward network.

    Layer normalization is applied to the input of each sub-layer, and each
    sub-layer's (dropped-out) output is added back to its input.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 dropatt=0.1,
                 activation="leakyrelu",
                 relative_bias=True):
        """Initialization.

        Args:
          d_model: dimension of inputs
          nhead: number of self-attention heads
          dim_feedforward: dimension of hidden layer in feedforward layer
          dropout: dropout rate
          dropatt: drop attention rate
          activation: activation function name
          relative_bias: bool, indicate whether use a relative position based
            attention bias
        """
        super().__init__()

        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=dropatt, relative_bias=relative_bias)

        # The feedforward stack carries its own input LayerNorm.
        ffn_stages = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, dim_feedforward),
            _get_activation_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        ]
        self.feedforward = nn.Sequential(*ffn_stages)

        # Normalization applied to the attention input.
        self.norm = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.nhead = nhead

    def forward(self, src, attn_mask=None, key_padding_mask=None):
        """Pass the input through the encoder layer.

        Args:
          src: the sequence to the encoder layer (required).
          attn_mask: the mask for the src sequence (optional).
          key_padding_mask: the mask for the src keys per batch (optional).
        Returns:
          the output of transformer layer, share the same shape as src.
        """
        attended = self.self_attn(
            self.norm(src),
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)
        after_attention = src + self.dropout1(attended)

        transformed = self.feedforward(after_attention)
        return after_attention + self.dropout2(transformed)
|
|
|
|
|
|
|
|
|
def cumprod(x, reverse=False, exclusive=False):
    """Cumulative product along the last dimension of a 3-D tensor.

    Args:
      x: input tensor; the exclusive mode slices it with three indices
      reverse: bool, accumulate from the last position towards the first
      exclusive: bool, shift the sequence by one so each position excludes
        its own value (the vacated slot is filled with 1)
    Returns:
      tensor of running products with the same shape as x
    """
    seq = x.flip([-1]) if reverse else x

    if exclusive:
        # Drop the final element and prepend the multiplicative identity.
        seq = F.pad(seq[:, :, :-1], (1, 0), value=1)

    running = seq.cumprod(-1)

    return running.flip([-1]) if reverse else running
|
|
|
def cumsum(x, reverse=False, exclusive=False):
    """Cumulative sum along the last dimension.

    The sum is computed as a matrix product with a triangular 0/1 matrix.

    Args:
      x: input tensor; its last dimension is accumulated
      reverse: bool, accumulate from the last position towards the first
      exclusive: bool, exclude each position's own value from its sum
    Returns:
      cx: tensor of running sums with the same shape and dtype as x
    """
    length = x.size(-1)

    # w[i, j] == 1 when input position i contributes to output position j.
    # One (length, length) matrix is shared by every batch element, and it
    # is built in x's dtype so non-float32 inputs multiply without a dtype
    # mismatch.
    w = torch.ones([length, length], dtype=x.dtype, device=x.device)

    # Exclusive sums move the triangle one step off the main diagonal.
    offset = 1 if exclusive else 0
    if reverse:
        w = w.tril(-offset)
    else:
        w = w.triu(offset)

    cx = torch.matmul(x, w)
    return cx
|
|
|
def cummin(x, reverse=False, exclusive=False, max_value=1e9):
    """Cumulative minimum along the last dimension of a 3-D tensor.

    Args:
      x: input tensor; the exclusive mode slices it with three indices
      reverse: bool, scan from the last position towards the first
      exclusive: bool, shift the sequence by one so each position excludes
        its own value
      max_value: filler placed in the slot vacated by the exclusive shift
    Returns:
      tensor of running minima with the same shape as x
    """
    if not reverse:
        if exclusive:
            # Shift right: the first position has nothing before it.
            x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
        running, _ = x.cummin(-1)
        return running

    if exclusive:
        # Shift left: the last position has nothing after it.
        x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
    running, _ = x.flip([-1]).cummin(-1)
    return running.flip([-1])
|
|
|
class Transformer(nn.Module):
    """Transformer model."""

    def __init__(self,
                 hidden_size,
                 nlayers,
                 ntokens,
                 nhead=8,
                 dropout=0.1,
                 dropatt=0.1,
                 relative_bias=True,
                 pos_emb=False,
                 pad=0):
        """Initialization.

        Args:
          hidden_size: dimension of inputs and hidden states
          nlayers: number of layers
          ntokens: number of output categories
          nhead: number of self-attention heads
          dropout: dropout rate
          dropatt: drop attention rate
          relative_bias: bool, indicate whether use a relative position based
            attention bias
          pos_emb: bool, indicate whether use a learnable positional embedding
          pad: pad token index
        """
        super().__init__()

        self.drop = nn.Dropout(dropout)

        self.emb = nn.Embedding(ntokens, hidden_size)
        if pos_emb:
            # The table has 500 rows, so position ids must stay below 500.
            self.pos_emb = nn.Embedding(500, hidden_size)

        self.layers = nn.ModuleList([
            TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
                             dropatt=dropatt, relative_bias=relative_bias)
            for _ in range(nlayers)])

        self.norm = nn.LayerNorm(hidden_size)

        # The output projection shares its weight with the token embedding.
        self.output_layer = nn.Linear(hidden_size, ntokens)
        self.output_layer.weight = self.emb.weight

        self.init_weights()

        self.nlayers = nlayers
        self.nhead = nhead
        self.ntokens = ntokens
        self.hidden_size = hidden_size
        self.pad = pad

    def init_weights(self):
        """Initialize token embedding and output bias."""
        initrange = 0.1
        self.emb.weight.data.uniform_(-initrange, initrange)
        if hasattr(self, 'pos_emb'):
            self.pos_emb.weight.data.uniform_(-initrange, initrange)
        self.output_layer.bias.data.fill_(0)

    def visibility(self, x, device=None):
        """Build the additive attention term that hides pad tokens.

        Args:
          x: input tokens of shape (batch, length)
          device: unused; kept (now optional) for existing callers
        Returns:
          tensor of shape (batch * nhead, length, length) holding 0 where
          the key token is not pad and -inf where it is
        """
        visibility = (x != self.pad).float()
        visibility = visibility[:, None, :].expand(-1, x.size(1), -1)
        # One copy per head, batch-major, matching the attention layout.
        visibility = torch.repeat_interleave(visibility, self.nhead, dim=0)
        # log(1) == 0 keeps a score unchanged; log(0) == -inf removes it.
        return visibility.log()

    def encode(self, x, pos=None):
        """Standard transformer encode process.

        Args:
          x: input tokens of shape (batch, length)
          pos: position for each token; only read when the positional
            embedding is enabled, defaults to 0..length-1 for every row
        Returns:
          output: hidden states after the last layer
          h_array: the inputs of every layer stacked along dimension 2
        """
        h = self.emb(x)
        if hasattr(self, 'pos_emb'):
            if pos is None:
                batch_size, length = x.size()
                pos = torch.arange(length, device=x.device).expand(
                    batch_size, length)
            h = h + self.pos_emb(pos)
        h_list = []
        visibility = self.visibility(x, x.device)

        for i in range(self.nlayers):
            h_list.append(h)
            # Layers operate on (length, batch, hidden) tensors.
            h = self.layers[i](
                h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)

        output = h
        h_array = torch.stack(h_list, dim=2)

        return output, h_array

    def forward(self, x, pos=None):
        """Pass the input through the encoder layer.

        Args:
          x: input tokens of shape (batch, length) (required).
          pos: position for each token (optional).
        Returns:
          output: output-layer scores (no softmax applied), flattened to
            (batch * length, ntokens).
          state_dict: dict with the normalized hidden states under
            'raw_output'
        """
        batch_size, length = x.size()

        raw_output, _ = self.encode(x, pos)
        raw_output = self.norm(raw_output)
        raw_output = self.drop(raw_output)

        output = self.output_layer(raw_output)
        return output.view(batch_size * length, -1), {'raw_output': raw_output,}
|
|
|
class StructFormer(Transformer):
    """StructFormer model.

    In this file the class is a plain transformer encoder with a
    masked-language-modelling style loss. The parser-related constructor
    arguments are accepted but not used by any code in this class.
    """

    def __init__(self,
                 hidden_size,
                 n_context_layers,
                 nlayers,
                 ntokens,
                 nhead=8,
                 dropout=0.1,
                 dropatt=0.1,
                 relative_bias=False,
                 pos_emb=False,
                 pad=0,
                 n_parser_layers=4,
                 conv_size=9,
                 relations=('head', 'child'),
                 weight_act='softmax'):
        """Initialization.

        Args:
          hidden_size: dimension of inputs and hidden states
          n_context_layers: accepted for interface compatibility; unused here
          nlayers: number of layers
          ntokens: number of output categories
          nhead: number of self-attention heads
          dropout: dropout rate
          dropatt: drop attention rate
          relative_bias: bool, indicate whether use a relative position based
            attention bias
          pos_emb: bool, indicate whether use a learnable positional embedding
          pad: pad token index
          n_parser_layers: number of parsing layers (unused here)
          conv_size: convolution kernel size for parser (unused here)
          relations: relations that are used to compute self attention
            (unused here)
          weight_act: relations distribution activation function (unused here)
        """
        super().__init__(
            hidden_size,
            nlayers,
            ntokens,
            nhead=nhead,
            dropout=dropout,
            dropatt=dropatt,
            relative_bias=relative_bias,
            pos_emb=pos_emb,
            pad=pad)

    def encode(self, x, pos):
        """Run the token embeddings through every transformer layer.

        Args:
          x: input tokens of shape (batch, length)
          pos: position for each token; only read when the positional
            embedding is enabled
        Returns:
          hidden states after the last layer (a single tensor, unlike the
          base class which also returns the per-layer inputs)
        """
        h = self.emb(x)
        if hasattr(self, 'pos_emb'):
            h = h + self.pos_emb(pos)
        visibility = self.visibility(x, x.device)

        # Per-layer inputs are not collected: nothing here consumes them.
        for i in range(self.nlayers):
            h = self.layers[i](
                h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)

        return h

    def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
        """Compute token logits and, when labels are given, the loss.

        Args:
          input_ids: input tokens of shape (batch, length)
          labels: optional target token ids, flattened and compared against
            the flattened logits with cross entropy
          position_ids: optional position for each token; defaults to
            0..length-1 for every row
          **kwargs: accepted and ignored
        Returns:
          MaskedLMOutput carrying `loss` (None without labels) and `logits`
        """
        x = input_ids
        batch_size, length = x.size()

        # Use caller-supplied positions when present; previously `pos` was
        # left unbound in that case.
        if position_ids is None:
            pos = torch.arange(length, device=x.device).expand(
                batch_size, length)
        else:
            pos = position_ids

        raw_output = self.encode(x, pos)
        raw_output = self.norm(raw_output)
        raw_output = self.drop(raw_output)

        output = self.output_layer(raw_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                output.view(batch_size * length, -1), labels.reshape(-1))

        return MaskedLMOutput(
            loss=loss,
            logits=output,
            hidden_states=None,
            attentions=None,
        )
|
|
|
|
|
|
|
|
|
class StructformerModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around :class:`StructFormer`.

    The wrapped network is stored as ``self.model`` and is built from the
    fields of a :class:`StructformerConfig`.
    """

    config_class = StructformerConfig

    def __init__(self, config):
        """Build the wrapped StructFormer from `config`.

        Args:
          config: configuration providing every StructFormer constructor
            argument as an attribute of the same name
        """
        super().__init__(config)

        # Each listed config attribute maps to the constructor keyword of
        # the same name.
        option_names = (
            "hidden_size",
            "n_context_layers",
            "nlayers",
            "ntokens",
            "nhead",
            "dropout",
            "dropatt",
            "relative_bias",
            "pos_emb",
            "pad",
            "n_parser_layers",
            "conv_size",
            "relations",
            "weight_act",
        )
        options = {name: getattr(config, name) for name in option_names}
        self.model = StructFormer(**options)

    def forward(self, input_ids, labels=None, **kwargs):
        """Delegate to the wrapped model.

        Args:
          input_ids: input tokens
          labels: optional targets forwarded to the wrapped model
          **kwargs: forwarded unchanged to the wrapped model
        Returns:
          whatever the wrapped model's forward returns
        """
        return self.model(input_ids, labels=labels, **kwargs)
|
|
|
|
|
class StructFormerClassification(Transformer):
    """StructFormer encoder with a sequence classification head.

    The parser-related constructor arguments are accepted but not used by
    any code in this class.
    """

    def __init__(self,
                 hidden_size,
                 n_context_layers,
                 nlayers,
                 ntokens,
                 nhead=8,
                 dropout=0.1,
                 dropatt=0.1,
                 relative_bias=False,
                 pos_emb=False,
                 pad=0,
                 n_parser_layers=4,
                 conv_size=9,
                 relations=('head', 'child'),
                 weight_act='softmax',
                 config=None,
                 ):
        """Initialization.

        Args:
          hidden_size: dimension of inputs and hidden states
          n_context_layers: accepted for interface compatibility; unused here
          nlayers: number of layers
          ntokens: number of output categories of the inherited output layer
          nhead: number of self-attention heads
          dropout: dropout rate
          dropatt: drop attention rate
          relative_bias: bool, indicate whether use a relative position based
            attention bias
          pos_emb: bool, indicate whether use a learnable positional embedding
          pad: pad token index
          n_parser_layers: unused here
          conv_size: unused here
          relations: unused here
          weight_act: unused here
          config: model configuration; required despite the None default,
            since `num_labels` and `problem_type` are read from it and it is
            passed to the classification head
        Raises:
          ValueError: if `config` is None
        """
        if config is None:
            raise ValueError(
                "StructFormerClassification requires a config object")

        super().__init__(
            hidden_size,
            nlayers,
            ntokens,
            nhead=nhead,
            dropout=dropout,
            dropatt=dropatt,
            relative_bias=relative_bias,
            pos_emb=pos_emb,
            pad=pad)

        self.num_labels = config.num_labels
        self.config = config

        # NOTE(review): the head presumably reads its sizes and dropout
        # settings from `config` -- confirm the config provides every field
        # it needs.
        self.classifier = RobertaClassificationHead(config)

    def encode(self, x, pos):
        """Run the token embeddings through every transformer layer.

        Args:
          x: input tokens of shape (batch, length)
          pos: position for each token; only read when the positional
            embedding is enabled
        Returns:
          hidden states after the last layer
        """
        h = self.emb(x)
        if hasattr(self, 'pos_emb'):
            h = h + self.pos_emb(pos)
        visibility = self.visibility(x, x.device)

        # Per-layer inputs are not collected: nothing here consumes them.
        for i in range(self.nlayers):
            h = self.layers[i](
                h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)

        return h

    def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
        """Compute classification logits and, when labels are given, the loss.

        Args:
          input_ids: input tokens of shape (batch, length)
          labels: optional targets; the loss type follows
            `config.problem_type`, which is inferred from `num_labels` and
            the label dtype (and stored on the config) when unset
          position_ids: optional position for each token; defaults to
            0..length-1 for every row
          **kwargs: accepted and ignored
        Returns:
          SequenceClassifierOutput carrying `loss` (None without labels) and
          `logits`
        """
        x = input_ids
        batch_size, length = x.size()

        # Use caller-supplied positions when present; previously `pos` was
        # left unbound in that case.
        if position_ids is None:
            pos = torch.arange(length, device=x.device).expand(
                batch_size, length)
        else:
            pos = position_ids

        raw_output = self.encode(x, pos)
        raw_output = self.norm(raw_output)
        raw_output = self.drop(raw_output)

        logits = self.classifier(raw_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                        labels.dtype == torch.long
                        or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
|
|
|
|
|
class StructformerModelForSequenceClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper around :class:`StructFormerClassification`.

    The wrapped network is stored as ``self.model`` and is built from the
    fields of a :class:`StructformerConfig`.
    """

    config_class = StructformerConfig

    def __init__(self, config):
        """Build the wrapped classification model from `config`.

        Args:
          config: configuration providing every constructor argument as an
            attribute of the same name; it is also passed on whole
        """
        super().__init__(config)
        self.model = StructFormerClassification(
            hidden_size=config.hidden_size,
            n_context_layers=config.n_context_layers,
            nlayers=config.nlayers,
            ntokens=config.ntokens,
            nhead=config.nhead,
            dropout=config.dropout,
            dropatt=config.dropatt,
            relative_bias=config.relative_bias,
            pos_emb=config.pos_emb,
            pad=config.pad,
            n_parser_layers=config.n_parser_layers,
            conv_size=config.conv_size,
            relations=config.relations,
            weight_act=config.weight_act,
            config=config)

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module.

        Linear and embedding weights are drawn from a zero-mean normal
        distribution, biases and the padding embedding row are zeroed, and
        LayerNorm is reset to weight 1 / bias 0.

        Args:
          module: the sub-module to initialize; other module types are left
            untouched
        """
        # StructformerConfig declares no `initializer_range` of its own, so
        # fall back to 0.02 when the config does not carry one.
        std = getattr(self.config, "initializer_range", 0.02)

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, input_ids, labels=None, **kwargs):
        """Delegate to the wrapped model.

        Args:
          input_ids: input tokens
          labels: optional targets forwarded to the wrapped model
          **kwargs: forwarded unchanged to the wrapped model
        Returns:
          whatever the wrapped model's forward returns
        """
        return self.model(input_ids, labels=labels, **kwargs)