# transformer_base_final_2/structformer_as_hf_no_parser.py
# Added by omarmomen in commit 0fbf7fe ("add model"); 22.7 kB.
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers import PreTrainedModel
from transformers import PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    MaskedLMOutput,
    SequenceClassifierOutput
)
from transformers.models.roberta.modeling_roberta import RobertaClassificationHead
##########################################
# HuggingFace Config
##########################################
class StructformerConfig(PretrainedConfig):
    """Configuration shared by the StructFormer HuggingFace wrappers.

    Every named constructor argument is stored on the instance under the same
    name; any remaining keyword arguments are handed to ``PretrainedConfig``.
    """

    model_type = "structformer"

    def __init__(
        self,
        hidden_size=768,
        n_context_layers=2,
        nlayers=6,
        ntokens=32000,
        nhead=8,
        dropout=0.1,
        dropatt=0.1,
        relative_bias=False,
        pos_emb=False,
        pad=0,
        n_parser_layers=4,
        conv_size=9,
        relations=('head', 'child'),
        weight_act='softmax',
        **kwargs,
    ):
        """Store the StructFormer hyper-parameters.

        Args:
          hidden_size: width of embeddings and hidden states
          n_context_layers: parser context layers (stored only)
          nlayers: number of Transformer encoder layers
          ntokens: vocabulary size
          nhead: number of attention heads
          dropout: dropout rate
          dropatt: attention-weight dropout rate
          relative_bias: whether attention uses a relative-position bias
          pos_emb: whether a learnable absolute position embedding is used
          pad: id of the padding token
          n_parser_layers: parser layers (stored only)
          conv_size: parser convolution kernel size (stored only)
          relations: parser relations (stored only)
          weight_act: parser relation activation (stored only)
          **kwargs: forwarded to ``PretrainedConfig``
        """
        # Encoder size and vocabulary.
        self.hidden_size = hidden_size
        self.nlayers = nlayers
        self.nhead = nhead
        self.ntokens = ntokens
        self.pad = pad
        # Regularisation.
        self.dropout = dropout
        self.dropatt = dropatt
        # Positional information.
        self.relative_bias = relative_bias
        self.pos_emb = pos_emb
        # Parser hyper-parameters: accepted and stored, but the models in this
        # file do not consume them.
        self.n_context_layers = n_context_layers
        self.n_parser_layers = n_parser_layers
        self.conv_size = conv_size
        self.relations = relations
        self.weight_act = weight_act
        super().__init__(**kwargs)
##########################################
# Custom Layers
##########################################
def _get_activation_fn(activation):
"""Get specified activation function."""
if activation == "relu":
return nn.ReLU()
elif activation == "gelu":
return nn.GELU()
elif activation == "leakyrelu":
return nn.LeakyReLU()
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation))
class Conv1d(nn.Module):
    """1D convolution over the sequence axis that preserves sequence length."""

    def __init__(self, hidden_size, kernel_size, dilation=1):
        """Initialization.

        Args:
          hidden_size: dimension of input embeddings (used as both the input
            and the output channel count)
          kernel_size: convolution kernel size
          dilation: the spacing between the kernel points
        """
        super(Conv1d, self).__init__()
        if kernel_size % 2 == 0:
            # An even kernel has no centre tap. Padding by half the receptive
            # field on both sides yields `dilation` extra outputs; forward()
            # drops them from the front, so the taps sit at offsets
            # dilation * (1 - k/2) ... dilation * (k/2) around each position.
            padding = (kernel_size // 2) * dilation
            self.shift = True
        else:
            padding = ((kernel_size - 1) // 2) * dilation
            self.shift = False
        # Leading outputs forward() removes so that len(output) == len(input).
        # (Previously only one was removed, which is wrong for dilation > 1.)
        self.trim = dilation if self.shift else 0
        self.conv = nn.Conv1d(
            hidden_size,
            hidden_size,
            kernel_size,
            padding=padding,
            dilation=dilation)

    def forward(self, x):
        """Compute convolution.

        Args:
          x: input embeddings laid out as (batch, length, hidden_size); they
            are transposed to (batch, hidden_size, length) for nn.Conv1d
        Returns:
          conv_output: convolution results, same shape as ``x``
        """
        out = self.conv(x.transpose(1, 2)).transpose(1, 2)
        if self.shift:
            out = out[:, self.trim:]
        return out
class MultiheadAttention(nn.Module):
    """Multi-head self-attention layer.

    Inputs are laid out as (length, bsz, embed_dim). Internally the batch and
    head axes are flattened into one axis of size ``bsz * num_heads`` whose
    index is ``b * num_heads + h``; every bias or mask added to the attention
    logits must follow that layout.
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 v_proj=True,
                 out_proj=True,
                 relative_bias=True):
        """Initialization.

        Args:
          embed_dim: dimension of input embeddings
          num_heads: number of self-attention heads
          dropout: dropout rate applied to the attention weights
          bias: bool, indicate whether include bias for linear transformations
          v_proj: bool, indicate whether project inputs to new values
            (otherwise the inputs are used as values unchanged)
          out_proj: bool, indicate whether project outputs to new values
          relative_bias: bool, indicate whether use a relative position based
            attention bias
        """
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.drop = nn.Dropout(dropout)
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, ("embed_dim must be "
                                                             "divisible by "
                                                             "num_heads")

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        if v_proj:
            self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        else:
            self.v_proj = nn.Identity()
        if out_proj:
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        else:
            self.out_proj = nn.Identity()

        if relative_bias:
            # One learnable scalar per head for each relative-distance bucket.
            # forward() indexes column |i - j| + 256, so only columns 256..511
            # are used; sequences longer than 256 tokens would index past it.
            self.relative_bias = nn.Parameter(torch.zeros((self.num_heads, 512)))
        else:
            self.relative_bias = None

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialize projection weights and zero their biases.

        ``nn.Identity`` projections have no parameters and are skipped, as are
        biases when the layer was built with ``bias=False``.
        """
        for proj in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
            if not isinstance(proj, nn.Linear):
                continue
            init.xavier_uniform_(proj.weight)
            if proj.bias is not None:
                init.constant_(proj.bias, 0.)

    def forward(self, query, key_padding_mask=None, attn_mask=None):
        """Compute multi-head self-attention.

        Args:
          query: (length, bsz, embed_dim) input embeddings; also used as keys
            and values
          key_padding_mask: optional tensor added to the
            (bsz * num_heads, length, length) attention logits, e.g. 0 / -inf
            to hide positions
          attn_mask: optional (bsz * num_heads, length, length) tensor; when
            given, the logits go through a sigmoid and are multiplied by it
            instead of being softmax-normalized
        Returns:
          attn_output: (length, bsz, embed_dim) self-attention output
        """
        length, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        head_dim = embed_dim // self.num_heads
        assert head_dim * self.num_heads == embed_dim, ("embed_dim must be "
                                                        "divisible by num_heads")
        scaling = float(head_dim)**-0.5

        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
        q = q * scaling

        if attn_mask is not None:
            assert list(attn_mask.size()) == [bsz * self.num_heads,
                                              query.size(0), query.size(0)]

        # (length, bsz, embed_dim) -> (bsz * num_heads, length, head_dim).
        # Splitting embed_dim into (num_heads, head_dim) makes the flattened
        # batch/head index b * num_heads + h.
        q = q.contiguous().view(length, bsz * self.num_heads,
                                head_dim).transpose(0, 1)
        k = k.contiguous().view(length, bsz * self.num_heads,
                                head_dim).transpose(0, 1)
        v = v.contiguous().view(length, bsz * self.num_heads,
                                head_dim).transpose(0, 1)

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(
            attn_output_weights.size()) == [bsz * self.num_heads, length, length]

        if self.relative_bias is not None:
            pos = torch.arange(length, device=query.device)
            relative_pos = torch.abs(pos[:, None] - pos[None, :]) + 256
            relative_pos = relative_pos[None, :, :].expand(bsz * self.num_heads, -1,
                                                           -1)

            # Tile the (num_heads, 512) table batch-major so row
            # b * num_heads + h holds head h, matching q/k above.
            # repeat_interleave(bsz) would give row h * bsz + b instead and
            # pair one head's logits with another head's bias when bsz > 1.
            relative_bias = self.relative_bias.repeat(bsz, 1)
            relative_bias = relative_bias[:, None, :].expand(-1, length, -1)
            relative_bias = torch.gather(relative_bias, 2, relative_pos)
            attn_output_weights = attn_output_weights + relative_bias

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights + key_padding_mask

        if attn_mask is None:
            attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
        else:
            attn_output_weights = torch.sigmoid(attn_output_weights) * attn_mask

        attn_output_weights = self.drop(attn_output_weights)

        attn_output = torch.bmm(attn_output_weights, v)

        assert list(attn_output.size()) == [bsz * self.num_heads, length, head_dim]
        attn_output = attn_output.transpose(0, 1).contiguous().view(
            length, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output
class TransformerLayer(nn.Module):
    """Encoder layer: a self-attention sub-layer followed by a feed-forward one.

    Each sub-layer normalizes its input first and adds a dropped-out residual.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 dropatt=0.1,
                 activation="leakyrelu",
                 relative_bias=True):
        """Build the attention and feed-forward sub-layers.

        Args:
          d_model: width of the layer's inputs and outputs
          nhead: number of self-attention heads
          dim_feedforward: hidden width of the feed-forward sub-layer
          dropout: dropout rate for the residual branches and the feed-forward
          dropatt: dropout rate applied to the attention weights
          activation: feed-forward activation name, resolved by
            _get_activation_fn
          relative_bias: bool, whether attention adds a learned
            relative-position bias
        """
        super(TransformerLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=dropatt, relative_bias=relative_bias)
        # The feed-forward branch carries its own leading LayerNorm.
        ff_modules = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, dim_feedforward),
            _get_activation_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        ]
        self.feedforward = nn.Sequential(*ff_modules)
        # LayerNorm applied to the input of the attention branch.
        self.norm = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.nhead = nhead

    def forward(self, src, attn_mask=None, key_padding_mask=None):
        """Apply attention then feed-forward, each with a residual connection.

        Args:
          src: (length, batch, d_model) input, the layout MultiheadAttention
            expects
          attn_mask: optional mask forwarded to self-attention
          key_padding_mask: optional additive mask forwarded to self-attention
        Returns:
          Tensor with the same shape as ``src``.
        """
        attended = self.self_attn(
            self.norm(src), attn_mask=attn_mask, key_padding_mask=key_padding_mask)
        hidden = src + self.dropout1(attended)
        return hidden + self.dropout2(self.feedforward(hidden))
##########################################
# Custom Models
##########################################
def cumprod(x, reverse=False, exclusive=False):
    """Cumulative product along the last dimension.

    Args:
      x: 3D tensor (the exclusive shift slices ``x[:, :, :-1]``).
      reverse: accumulate from the end of the last dimension to the start.
      exclusive: leave each element out of its own product; the first
        accumulated position receives 1.
    Returns:
      Tensor of the same shape as ``x``.
    """
    src = x.flip([-1]) if reverse else x
    if exclusive:
        # Shift right by one, feeding in the multiplicative identity.
        src = F.pad(src[:, :, :-1], (1, 0), value=1)
    out = src.cumprod(-1)
    return out.flip([-1]) if reverse else out
def cumsum(x, reverse=False, exclusive=False):
    """Cumulative sum along the last dimension.

    Args:
      x: tensor accumulated over its last dimension; the original
        (bsz, *, length) 3D layout still works, as does any other rank.
      reverse: accumulate from the end of the last dimension to the start.
      exclusive: leave each element out of its own sum; the first accumulated
        position receives 0.
    Returns:
      Tensor of the same shape as ``x`` (floating-point inputs keep their
      dtype).
    """
    # Tensor.cumsum is O(length) and works in x's dtype. The previous version
    # multiplied by an explicit (bsz, length, length) float32 triangular
    # matrix: O(length^2) memory, and a bmm dtype error for non-float32 x.
    src = x.flip([-1]) if reverse else x
    out = src.cumsum(-1)
    if exclusive:
        # Shift right by one, feeding in the additive identity. Concatenation
        # (rather than padding) keeps a zero-length last dimension at length 0.
        out = torch.cat([torch.zeros_like(out[..., :1]), out[..., :-1]], dim=-1)
    return out.flip([-1]) if reverse else out
def cummin(x, reverse=False, exclusive=False, max_value=1e9):
    """Cumulative minimum along the last dimension.

    Args:
      x: 3D tensor (the exclusive shift slices ``x[:, :, ...]``).
      reverse: take the running minimum from the end towards the start.
      exclusive: leave each element out of its own minimum; the vacated
        boundary position is filled with ``max_value``.
      max_value: fill value used by the exclusive shift.
    Returns:
      Tensor of the same shape as ``x`` holding the running minima.
    """
    if not reverse:
        if exclusive:
            x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
        return x.cummin(-1).values
    if exclusive:
        x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
    return x.flip([-1]).cummin(-1).values.flip([-1])
class Transformer(nn.Module):
    """Transformer encoder whose output projection shares the embedding matrix."""

    def __init__(self,
                 hidden_size,
                 nlayers,
                 ntokens,
                 nhead=8,
                 dropout=0.1,
                 dropatt=0.1,
                 relative_bias=True,
                 pos_emb=False,
                 pad=0):
        """Initialization.

        Args:
          hidden_size: dimension of inputs and hidden states
          nlayers: number of layers
          ntokens: number of output categories
          nhead: number of self-attention heads
          dropout: dropout rate
          dropatt: drop attention rate
          relative_bias: bool, indicate whether use a relative position based
            attention bias
          pos_emb: bool, indicate whether use a learnable positional embedding
            (covering positions 0..499)
          pad: pad token index
        """
        super(Transformer, self).__init__()

        self.drop = nn.Dropout(dropout)

        self.emb = nn.Embedding(ntokens, hidden_size)
        if pos_emb:
            # Only created when requested; other methods test for it with
            # hasattr(self, 'pos_emb').
            self.pos_emb = nn.Embedding(500, hidden_size)

        self.layers = nn.ModuleList([
            TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
                             dropatt=dropatt, relative_bias=relative_bias)
            for _ in range(nlayers)])

        self.norm = nn.LayerNorm(hidden_size)

        self.output_layer = nn.Linear(hidden_size, ntokens)
        # Weight tying: the output projection reuses the embedding matrix.
        self.output_layer.weight = self.emb.weight

        self.init_weights()

        self.nlayers = nlayers
        self.nhead = nhead
        self.ntokens = ntokens
        self.hidden_size = hidden_size
        self.pad = pad

    def init_weights(self):
        """Initialize token embedding and output bias.

        Because the output weight is tied to ``emb.weight``, this also sets the
        output projection's weight.
        """
        initrange = 0.1
        self.emb.weight.data.uniform_(-initrange, initrange)
        if hasattr(self, 'pos_emb'):
            self.pos_emb.weight.data.uniform_(-initrange, initrange)
        self.output_layer.bias.data.fill_(0)

    def visibility(self, x, device=None):
        """Build the additive key-padding mask for self-attention.

        Args:
          x: (batch_size, length) token ids.
          device: unused; the mask is created on ``x``'s device. Kept for
            backward compatibility with existing callers.
        Returns:
          (batch_size * nhead, length, length) float tensor that is 0 where the
          key token is not padding and -inf where it is; row b * nhead + h
          belongs to batch item b.
        """
        visibility = (x != self.pad).float()
        visibility = visibility[:, None, :].expand(-1, x.size(1), -1)
        visibility = torch.repeat_interleave(visibility, self.nhead, dim=0)
        return visibility.log()

    def encode(self, x, pos=None):
        """Standard transformer encode process.

        Args:
          x: (batch_size, length) token ids.
          pos: optional (batch_size, length) positions. They are only used when
            the model has a positional embedding; if omitted, 0..length-1 is
            used for every row.
        Returns:
          output: (batch_size, length, hidden_size) final hidden states.
          h_array: (batch_size, length, nlayers, hidden_size) input of each layer.
        """
        h = self.emb(x)
        if hasattr(self, 'pos_emb'):
            if pos is None:
                batch_size, length = x.size()
                pos = torch.arange(length, device=x.device).expand(batch_size, length)
            h = h + self.pos_emb(pos)
        h_list = []
        visibility = self.visibility(x, x.device)

        for layer in self.layers:
            h_list.append(h)
            # Layers operate on (length, batch, hidden).
            h = layer(
                h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)

        output = h
        h_array = torch.stack(h_list, dim=2)

        return output, h_array

    def forward(self, x, pos=None):
        """Pass the input through the encoder layer.

        Args:
          x: input tokens (required), shape (batch_size, length).
          pos: position for each token (optional); see ``encode``.
        Returns:
          output: (batch_size * length, ntokens) unnormalized logits.
          state_dict: dict with the normalized, dropped-out ``raw_output``.
        """
        batch_size, length = x.size()

        raw_output, _ = self.encode(x, pos)
        raw_output = self.norm(raw_output)
        raw_output = self.drop(raw_output)

        output = self.output_layer(raw_output)
        return output.view(batch_size * length, -1), {'raw_output': raw_output,}
class StructFormer(Transformer):
    """StructFormer masked language model without the dependency parser.

    Only the Transformer encoder inherited from ``Transformer`` is built. The
    parser-related arguments (``n_context_layers``, ``n_parser_layers``,
    ``conv_size``, ``relations``, ``weight_act``) are accepted but not used.
    """

    def __init__(self,
                 hidden_size,
                 n_context_layers,
                 nlayers,
                 ntokens,
                 nhead=8,
                 dropout=0.1,
                 dropatt=0.1,
                 relative_bias=False,
                 pos_emb=False,
                 pad=0,
                 n_parser_layers=4,
                 conv_size=9,
                 relations=('head', 'child'),
                 weight_act='softmax'):
        """Initialization.

        Args:
          hidden_size: dimension of inputs and hidden states
          n_context_layers: parser context layers (unused here)
          nlayers: number of layers
          ntokens: number of output categories
          nhead: number of self-attention heads
          dropout: dropout rate
          dropatt: drop attention rate
          relative_bias: bool, indicate whether use a relative position based
            attention bias
          pos_emb: bool, indicate whether use a learnable positional embedding
          pad: pad token index
          n_parser_layers: number of parsing layers (unused here)
          conv_size: convolution kernel size for parser (unused here)
          relations: relations that are used to compute self attention
            (unused here)
          weight_act: relations distribution activation function (unused here)
        """
        super(StructFormer, self).__init__(
            hidden_size,
            nlayers,
            ntokens,
            nhead=nhead,
            dropout=dropout,
            dropatt=dropatt,
            relative_bias=relative_bias,
            pos_emb=pos_emb,
            pad=pad)

    def encode(self, x, pos):
        """Embed ``x`` and run it through every encoder layer.

        Args:
          x: (batch_size, length) token ids.
          pos: (batch_size, length) positions, used only with a positional
            embedding.
        Returns:
          (batch_size, length, hidden_size) final hidden states. Unlike
          ``Transformer.encode`` this returns no per-layer stack.
        """
        h = self.emb(x)
        if hasattr(self, 'pos_emb'):
            h = h + self.pos_emb(pos)
        visibility = self.visibility(x, x.device)
        for layer in self.layers:
            # Layers operate on (length, batch, hidden).
            h = layer(
                h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
        return h

    def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
        """Predict a token distribution for every position.

        Args:
          input_ids: (batch_size, length) token ids.
          labels: optional (batch_size, length) target ids for the
            cross-entropy loss.
          position_ids: optional (batch_size, length) positions; defaults to
            0..length-1 for every row.
          **kwargs: ignored. Padding is masked from ``input_ids == pad``, so
            an ``attention_mask`` passed here has no effect.
        Returns:
          MaskedLMOutput with ``loss`` (None without labels) and ``logits``.
        """
        x = input_ids
        batch_size, length = x.size()

        # Previously `pos` was left unbound when position_ids was supplied.
        if position_ids is None:
            pos = torch.arange(length, device=x.device).expand(batch_size, length)
        else:
            pos = position_ids

        raw_output = self.encode(x, pos)
        raw_output = self.norm(raw_output)
        raw_output = self.drop(raw_output)

        output = self.output_layer(raw_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(output.view(batch_size * length, -1), labels.reshape(-1))

        return MaskedLMOutput(
            loss=loss,  # scalar, or None without labels
            logits=output,  # shape: (batch_size, length, ntokens)
            hidden_states=None,
            attentions=None,
        )
##########################################
# HuggingFace Model
##########################################
class StructformerModel(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around the ``StructFormer`` MLM."""

    config_class = StructformerConfig

    def __init__(self, config):
        """Build the wrapped StructFormer from the matching fields of ``config``."""
        super().__init__(config)
        field_names = (
            "hidden_size",
            "n_context_layers",
            "nlayers",
            "ntokens",
            "nhead",
            "dropout",
            "dropatt",
            "relative_bias",
            "pos_emb",
            "pad",
            "n_parser_layers",
            "conv_size",
            "relations",
            "weight_act",
        )
        model_kwargs = {name: getattr(config, name) for name in field_names}
        self.model = StructFormer(**model_kwargs)

    def forward(self, input_ids, labels=None, **kwargs):
        """Run the wrapped StructFormer; extra keyword arguments pass through."""
        return self.model(input_ids, labels=labels, **kwargs)
class StructFormerClassification(Transformer):
    """StructFormer encoder topped with a ``RobertaClassificationHead``.

    The parser-related constructor arguments (``n_context_layers``,
    ``n_parser_layers``, ``conv_size``, ``relations``, ``weight_act``) are
    accepted but not used.
    """

    def __init__(self,
                 hidden_size,
                 n_context_layers,
                 nlayers,
                 ntokens,
                 nhead=8,
                 dropout=0.1,
                 dropatt=0.1,
                 relative_bias=False,
                 pos_emb=False,
                 pad=0,
                 n_parser_layers=4,
                 conv_size=9,
                 relations=('head', 'child'),
                 weight_act='softmax',
                 config=None,
                 ):
        """Initialization.

        Args:
          hidden_size: dimension of inputs and hidden states
          n_context_layers: parser context layers (unused here)
          nlayers: number of layers
          ntokens: vocabulary size
          nhead: number of self-attention heads
          dropout: dropout rate
          dropatt: drop attention rate
          relative_bias: bool, whether attention uses a relative-position bias
          pos_emb: bool, whether a learnable positional embedding is used
          pad: pad token index
          n_parser_layers, conv_size, relations, weight_act: unused here
          config: model config. Required despite the ``None`` default (kept
            for signature compatibility); it must provide ``num_labels`` and
            ``problem_type`` plus whatever ``RobertaClassificationHead`` reads
            from it.

        Raises:
          ValueError: if ``config`` is None.
        """
        if config is None:
            raise ValueError(
                "StructFormerClassification requires a config providing "
                "num_labels and the classification-head settings")
        super(StructFormerClassification, self).__init__(
            hidden_size,
            nlayers,
            ntokens,
            nhead=nhead,
            dropout=dropout,
            dropatt=dropatt,
            relative_bias=relative_bias,
            pos_emb=pos_emb,
            pad=pad)
        self.num_labels = config.num_labels
        self.config = config
        self.classifier = RobertaClassificationHead(config)

    def encode(self, x, pos):
        """Embed ``x`` and run it through every encoder layer.

        Args:
          x: (batch_size, length) token ids.
          pos: (batch_size, length) positions, used only with a positional
            embedding.
        Returns:
          (batch_size, length, hidden_size) final hidden states.
        """
        h = self.emb(x)
        if hasattr(self, 'pos_emb'):
            h = h + self.pos_emb(pos)
        visibility = self.visibility(x, x.device)
        for layer in self.layers:
            # Layers operate on (length, batch, hidden).
            h = layer(
                h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
        return h

    def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
        """Classify each sequence in ``input_ids``.

        Args:
          input_ids: (batch_size, length) token ids.
          labels: optional targets. The loss follows ``config.problem_type``;
            if that is unset, it is inferred (and stored on the config) from
            ``num_labels`` and the label dtype.
          position_ids: optional (batch_size, length) positions; defaults to
            0..length-1 for every row.
          **kwargs: ignored. Padding is masked from ``input_ids == pad``.
        Returns:
          SequenceClassifierOutput with ``loss`` (None without labels) and the
          classifier ``logits``.
        """
        x = input_ids
        batch_size, length = x.size()

        # Previously `pos` was left unbound when position_ids was supplied.
        if position_ids is None:
            pos = torch.arange(length, device=x.device).expand(batch_size, length)
        else:
            pos = position_ids

        raw_output = self.encode(x, pos)
        raw_output = self.norm(raw_output)
        raw_output = self.drop(raw_output)

        # The full (batch, length, hidden) sequence output goes to the head.
        logits = self.classifier(raw_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
class StructformerModelForSequenceClassification(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around StructFormerClassification."""

    config_class = StructformerConfig

    def __init__(self, config):
        """Build the wrapped classifier from the fields of ``config``."""
        super().__init__(config)
        self.model = StructFormerClassification(
            hidden_size=config.hidden_size,
            n_context_layers=config.n_context_layers,
            nlayers=config.nlayers,
            ntokens=config.ntokens,
            nhead=config.nhead,
            dropout=config.dropout,
            dropatt=config.dropatt,
            relative_bias=config.relative_bias,
            pos_emb=config.pos_emb,
            pad=config.pad,
            n_parser_layers=config.n_parser_layers,
            conv_size=config.conv_size,
            relations=config.relations,
            weight_act=config.weight_act,
            config=config)

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place.

        Linear and Embedding weights are drawn from N(0, initializer_range),
        and biases and padding rows are zeroed. LayerNorm is reset to the
        identity transform.

        ``StructformerConfig`` does not define ``initializer_range``, so it
        falls back to 0.02 (the BertConfig/RobertaConfig default) instead of
        raising AttributeError when the config lacks it.
        """
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Affine parameters are None when elementwise_affine=False.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def forward(self, input_ids, labels=None, **kwargs):
        """Run the wrapped classifier; extra keyword arguments pass through."""
        return self.model(input_ids, labels=labels, **kwargs)