|
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

from utils.nlinalg import logsumexp, logdet
|
from .attention import BiAAttention |
|
|
|
|
|
class ChainCRF(nn.Module): |
|
def __init__(self, input_size, num_labels, bigram=True): |
|
''' |
|
|
|
Args: |
|
input_size: int |
|
the dimension of the input. |
|
num_labels: int |
|
the number of labels of the crf layer |
|
bigram: bool |
|
                whether to use input-dependent bigram transition parameters.
|
''' |
|
super(ChainCRF, self).__init__() |
|
        self.input_size = input_size
        # reserve one extra label that serves as the padding / virtual start symbol
        self.num_labels = num_labels + 1
        self.pad_label_id = num_labels
        self.bigram = bigram
|
|
|
|
|
|
|
        # emission (state) scores for each label
        self.state_nn = nn.Linear(input_size, self.num_labels)
        if bigram:
            # input-dependent transitions: one [num_labels x num_labels] matrix per token
            self.trans_nn = nn.Linear(input_size, self.num_labels * self.num_labels)
            self.register_parameter('trans_matrix', None)
        else:
            self.trans_nn = None
            # a single shared transition matrix
            self.trans_matrix = Parameter(torch.Tensor(self.num_labels, self.num_labels))
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
nn.init.constant_(self.state_nn.bias, 0.) |
|
if self.bigram: |
|
nn.init.xavier_uniform_(self.trans_nn.weight) |
|
nn.init.constant_(self.trans_nn.bias, 0.) |
|
else: |
|
nn.init.normal_(self.trans_matrix) |
|
|
|
|
|
|
|
def forward(self, input, mask=None): |
|
''' |
|
|
|
Args: |
|
input: Tensor |
|
the input tensor with shape = [batch_size, length, input_size] |
|
mask: Tensor or None |
|
the mask tensor with shape = [batch_size, length] |
|
|
|
        Returns: tuple of Tensor
            out_s: the state (emission) scores with shape = [batch_size, length, 1, num_label]
            out_t: the transition scores with shape = [batch_size, length, num_label, num_label]
                if bigram, else [num_label, num_label]
|
|
|
''' |
|
batch_size, length, _ = input.size() |
|
|
|
|
|
|
|
        # emission scores with shape [batch_size, length, 1, num_labels], so that
        # after broadcasting energy[b, t, i, j] = trans[i, j] + state[j]
        out_s = self.state_nn(input).unsqueeze(2)

        if self.bigram:
            # input-dependent transitions: one matrix per token
            out_t = self.trans_nn(input).view(batch_size, length, self.num_labels, self.num_labels)
        else:
            # shared transition matrix, broadcast over batch and length
            out_t = self.trans_matrix
|
|
|
|
|
|
|
|
|
return (out_s, out_t) |
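    # A minimal sketch (hypothetical variable names) of how the two returned
    # score tensors are typically combined into the energy tensor that loss()
    # and decode() below expect:
    #
    #   out_s, out_t = crf(input)     # state and transition scores
    #   energy = out_s + out_t        # [batch_size, length, num_label, num_label]
    #   nll = crf.loss(energy, target, mask=mask)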
|
|
|
def loss(self, energy, target, mask=None, length=None): |
|
''' |
|
|
|
Args: |
|
energy: Tensor |
|
the input tensor with shape = [batch_size, length, num_label, num_label] |
|
target: Tensor |
|
the tensor of target labels with shape [batch_size, length] |
|
            mask: Tensor or None
                the mask tensor with shape = [batch_size, length]
            length: Tensor or None
                the length tensor with shape = [batch_size]
|
|
|
Returns: Tensor |
|
A 1D tensor for minus log likelihood loss |
|
''' |
|
        if length is not None:
            max_len = length.max()
            if energy.size(1) != max_len:
                # truncate padded targets to the batch's maximum length
                target = target[:, :max_len]
|
mask_transpose = None |
|
if mask is not None: |
|
energy = energy * mask.unsqueeze(2).unsqueeze(3) |
|
mask_transpose = mask.unsqueeze(2).transpose(0, 1) |
|
|
|
        batch_size, seq_len, _, _ = energy.size()
|
|
|
energy_transpose = energy.transpose(0, 1) |
|
|
|
target_transpose = target.transpose(0, 1) |
|
|
|
|
|
|
|
partition = None |
|
        device = energy.device
        batch_index = torch.arange(0, batch_size, dtype=torch.long, device=device)
        # every sequence starts from the virtual start / pad label
        prev_label = torch.full((batch_size,), self.num_labels - 1, dtype=torch.long, device=device)
        tgt_energy = torch.zeros(batch_size, device=device)
|
|
|
        for t in range(seq_len):
            curr_energy = energy_transpose[t]
            if t == 0:
                # initialize with transitions out of the virtual start label
                partition = curr_energy[:, -1, :]
            else:
                # forward algorithm: log-sum-exp over the previous label dimension
                partition_new = logsumexp(curr_energy + partition.unsqueeze(2), dim=1)
                if mask_transpose is None:
                    partition = partition_new
                else:
                    # keep the previous partition at padded positions
                    mask_t = mask_transpose[t]
                    partition = partition + (partition_new - partition) * mask_t
            # accumulate the energy of the gold label sequence
            tgt_energy += curr_energy[batch_index, prev_label, target_transpose[t].data]
            prev_label = target_transpose[t].data
        # negative log-likelihood: log partition minus gold-sequence score
        return logsumexp(partition, dim=1) - tgt_energy
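    # In log space the recursion above computes
    #   partition[t, j] = logsumexp_i( partition[t - 1, i] + energy[t, i, j] )
    # so logsumexp(partition, dim=1) at the last step is log Z(x), and the
    # value returned per example is log Z(x) - score(x, y). Callers typically
    # average or sum this vector over the batch before backpropagating.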
|
|
|
def decode(self, energy, mask=None, leading_symbolic=0): |
|
""" |
|
|
|
Args: |
|
            energy: Tensor
                the input tensor with shape = [batch_size, length, num_label, num_label]
            mask: Tensor or None
                the mask tensor with shape = [batch_size, length]
            leading_symbolic: int
                number of symbolic labels leading in type alphabets (set it to 0 if you are not sure)
|
|
|
Returns: Tensor |
|
decoding results in shape [batch_size, length] |
|
|
|
""" |
|
if mask is not None: |
|
energy = energy * mask.unsqueeze(2).unsqueeze(3) |
|
|
|
|
|
        energy_transpose = energy.transpose(0, 1)

        # drop the leading symbolic labels and the trailing pad label from both
        # label dimensions before running Viterbi
        energy_transpose = energy_transpose[:, :, leading_symbolic:-1, leading_symbolic:-1]
|
|
|
length, batch_size, num_label, _ = energy_transpose.size() |
|
        device = energy.device
        batch_index = torch.arange(0, batch_size, dtype=torch.long, device=device)
        # pi holds the running Viterbi scores, pointer the best previous label
        pi = torch.zeros([length, batch_size, num_label, 1], device=device)
        pointer = torch.zeros(length, batch_size, num_label, dtype=torch.long, device=device)
        back_pointer = torch.zeros(length, batch_size, 1, dtype=torch.long, device=device)
|
|
|
        # initialize with transitions out of the virtual start label
        pi[0] = energy[:, 0, -1, leading_symbolic:-1].unsqueeze(-1)
        pointer[0] = -1
        for t in range(1, length):
            pi_prev = pi[t - 1]
            # Viterbi recursion: maximize over the previous label dimension
            x, y = torch.max(energy_transpose[t] + pi_prev, dim=1)
            pi[t] = x.unsqueeze(-1)
            pointer[t] = y
        # trace back from the best final label
        _, back_pointer[-1] = torch.max(pi[-1], dim=1)
        back_pointer = back_pointer.squeeze(-1)
        for t in reversed(range(length - 1)):
            pointer_last = pointer[t + 1]
            back_pointer[t] = pointer_last[batch_index, back_pointer[t + 1]]
        return back_pointer.transpose(0, 1) + leading_symbolic
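# A minimal usage sketch for ChainCRF (hypothetical shapes and tensors; the
# caller combines the two returned score tensors into the energy tensor):
#
#   crf = ChainCRF(input_size=256, num_labels=10, bigram=True)
#   h = torch.randn(2, 7, 256)               # [batch, length, input_size]
#   out_s, out_t = crf(h)
#   energy = out_s + out_t                   # [batch, length, 11, 11]
#   target = torch.randint(0, 10, (2, 7))
#   nll = crf.loss(energy, target).mean()
#   best_paths = crf.decode(energy)          # [batch, length]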
|
|
|
class TreeCRF(nn.Module): |
|
''' |
|
Tree CRF layer. |
|
''' |
|
def __init__(self, input_size, num_labels, biaffine=True): |
|
''' |
|
|
|
Args: |
|
input_size: int |
|
the dimension of the input. |
|
num_labels: int |
|
the number of labels of the crf layer |
|
biaffine: bool |
|
                whether to use the bi-affine attention parameterization.
|
''' |
|
super(TreeCRF, self).__init__() |
|
self.input_size = input_size |
|
self.num_labels = num_labels |
|
self.attention = BiAAttention(input_size, input_size, num_labels, biaffine=biaffine) |
|
|
|
def forward(self, input_h, input_c, mask=None): |
|
''' |
|
|
|
Args: |
|
input_h: Tensor |
|
the head input tensor with shape = [batch_size, length, input_size] |
|
input_c: Tensor |
|
the child input tensor with shape = [batch_size, length, input_size] |
|
mask: Tensor or None |
|
the mask tensor with shape = [batch_size, length] |
|
|
|
|
Returns: Tensor |
|
the energy tensor with shape = [batch_size, num_label, length, length] |
|
|
|
''' |
|
_, length, _ = input_h.size() |
|
|
|
        output = self.attention(input_h, input_c, mask_d=mask, mask_e=mask)
        # forbid self-loops by setting the diagonal (head == child) to -inf
        output = output + torch.diag(output.data.new(length).fill_(-np.inf))
|
return output |
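    # Because exp(-inf) == 0, the -inf diagonal guarantees that self-loops get
    # zero weight in the adjacency matrix that loss() builds from exp(energy).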
|
|
|
def loss(self, input_h, input_c, heads, arc_tags, mask=None, lengths=None): |
|
''' |
|
|
|
Args: |
|
input_h: Tensor |
|
the head input tensor with shape = [batch_size, length, input_size] |
|
input_c: Tensor |
|
the child input tensor with shape = [batch_size, length, input_size] |
|
            heads: Tensor
                the tensor of gold head indices with shape [batch_size, length]
            arc_tags: Tensor
                the tensor of gold arc labels with shape [batch_size, length]
            mask: Tensor or None
                the mask tensor with shape = [batch_size, length]
            lengths: Tensor or list of int
                the length of each input, shape = [batch_size]
|
|
|
Returns: Tensor |
|
A 1D tensor for minus log likelihood loss |
|
''' |
|
batch_size, length, _ = input_h.size() |
|
energy = self.forward(input_h, input_c, mask=mask) |
|
|
|
        # Matrix-Tree theorem: adjacency weights are the exponentiated arc energies
        A = torch.exp(energy)
        # mask out arcs to or from padded positions
        if mask is not None:
            A = A * mask.unsqueeze(1).unsqueeze(3) * mask.unsqueeze(1).unsqueeze(2)
        # sum over labels to get unlabeled arc weights [batch_size, length, length]
        A = A.sum(dim=1)
        # node degrees (summed arc weights) for the Laplacian
        D = A.sum(dim=1, keepdim=True)
        # numerical stability: keep the Laplacian comfortably non-singular
        rtol = 1e-4
        atol = 1e-6
        D += D * rtol + atol
        # expand the degrees to a square matrix and keep only the diagonal
        D = A.data.new(A.size()).zero_() + D
        D = D * torch.eye(length).type_as(D)
        # graph Laplacian
        L = D - A
|
|
|
|
|
        if lengths is None:
            if mask is None:
                lengths = [length for _ in range(batch_size)]
            else:
                lengths = mask.data.sum(dim=1).long()

        # log partition: log-determinant of the Laplacian minor with the root
        # row and column removed (Matrix-Tree theorem)
        z = energy.data.new(batch_size)
        for b in range(batch_size):
            Lx = L[b, 1:lengths[b], 1:lengths[b]]
            z[b] = logdet(Lx)
|
|
|
|
|
|
|
        # gather the energies of the gold arcs, dropping the root position
        index = torch.arange(0, length).view(length, 1).expand(length, batch_size)
        index = index.type_as(energy.data).long()
        batch_index = torch.arange(0, batch_size).type_as(energy.data).long()
        tgt_energy = energy[batch_index, arc_tags.data.t(), heads.data.t(), index][1:]
        tgt_energy = tgt_energy.sum(dim=0)
        # negative log-likelihood: log partition minus gold-tree score
        return z - tgt_energy
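# A minimal usage sketch for TreeCRF (hypothetical shapes; position 0 is
# assumed to be the artificial root token, which is why loss() removes the
# first row and column of the Laplacian):
#
#   crf = TreeCRF(input_size=256, num_labels=5, biaffine=True)
#   h = torch.randn(2, 8, 256)              # head representations
#   c = torch.randn(2, 8, 256)              # child representations
#   heads = torch.randint(0, 8, (2, 8))     # gold head index per token
#   arc_tags = torch.randint(0, 5, (2, 8))  # gold arc label per token
#   nll = crf.loss(h, c, heads, arc_tags).mean()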
|
|
|
class ChainCRF_with_LE(nn.Module): |
|
def __init__(self, input_size, num_labels, bigram=True): |
|
''' |
|
|
|
Args: |
|
input_size: int |
|
the dimension of the input. |
|
num_labels: int |
|
the number of labels of the crf layer |
|
bigram: bool |
|
                whether to use bigram parameters (kept for interface
                compatibility; transitions are derived from the label
                embeddings passed to forward).
|
''' |
|
super(ChainCRF_with_LE, self).__init__() |
|
        self.input_size = input_size
        # reserve one extra label that serves as the padding / virtual start symbol
        self.num_labels = num_labels + 1
        self.pad_label_id = num_labels
        self.bigram = bigram

        # emission (state) scores for each label
        self.state_nn = nn.Linear(input_size, self.num_labels)
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
nn.init.constant_(self.state_nn.bias, 0.) |
|
|
|
def forward(self, input, LE, mask=None): |
|
''' |
|
|
|
Args: |
|
            input: Tensor
                the input tensor with shape = [batch_size, length, input_size]
            LE: Tensor
                the label embedding matrix with shape = [num_labels, embedding_dim]
            mask: Tensor or None
                the mask tensor with shape = [batch_size, length]

        Returns: tuple of Tensor
            out_s: the state (emission) scores with shape = [batch_size, length, 1, num_label]
            out_t: the transition scores with shape = [num_label, num_label]
|
|
|
''' |
|
batch_size, length, _ = input.size() |
|
|
|
|
|
|
|
        # emission scores with shape [batch_size, length, 1, num_labels], so that
        # after broadcasting energy[b, t, i, j] = trans[i, j] + state[j]
        out_s = self.state_nn(input).unsqueeze(2)
        # transition scores from label-embedding similarity
        out_t = torch.matmul(LE, torch.t(LE))

        # append a zero column and a zero row for the pad / virtual start label
        col_zeros = torch.zeros((out_t.shape[0], 1), device=out_t.device)
        row_zeros = torch.zeros((1, out_t.shape[1] + 1), device=out_t.device)
        out_t = torch.cat((out_t, col_zeros), dim=1)
        out_t = torch.cat((out_t, row_zeros))

        return out_s, out_t
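    # The transition score between labels i and j is the dot product of their
    # embeddings, out_t[i, j] = LE[i] . LE[j], and the appended zero row and
    # column give the pad / virtual start label zero-score transitions to and
    # from every label.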
|
|
|
def loss(self, energy, target, mask=None, length=None): |
|
''' |
|
|
|
Args: |
|
energy: Tensor |
|
the input tensor with shape = [batch_size, length, num_label, num_label] |
|
target: Tensor |
|
the tensor of target labels with shape [batch_size, length] |
|
            mask: Tensor or None
                the mask tensor with shape = [batch_size, length]
            length: Tensor or None
                the length tensor with shape = [batch_size]
|
|
|
Returns: Tensor |
|
A 1D tensor for minus log likelihood loss |
|
''' |
|
        if length is not None:
            max_len = length.max()
            if energy.size(1) != max_len:
                # truncate padded targets to the batch's maximum length
                target = target[:, :max_len]
|
mask_transpose = None |
|
if mask is not None: |
|
energy = energy * mask.unsqueeze(2).unsqueeze(3) |
|
mask_transpose = mask.unsqueeze(2).transpose(0, 1) |
|
|
|
        batch_size, seq_len, _, _ = energy.size()
|
|
|
energy_transpose = energy.transpose(0, 1) |
|
|
|
target_transpose = target.transpose(0, 1) |
|
|
|
|
|
|
|
partition = None |
|
        device = energy.device
        batch_index = torch.arange(0, batch_size, dtype=torch.long, device=device)
        # every sequence starts from the virtual start / pad label
        prev_label = torch.full((batch_size,), self.num_labels - 1, dtype=torch.long, device=device)
        tgt_energy = torch.zeros(batch_size, device=device)
|
|
|
        for t in range(seq_len):
            curr_energy = energy_transpose[t]
            if t == 0:
                # initialize with transitions out of the virtual start label
                partition = curr_energy[:, -1, :]
            else:
                # forward algorithm: log-sum-exp over the previous label dimension
                partition_new = logsumexp(curr_energy + partition.unsqueeze(2), dim=1)
                if mask_transpose is None:
                    partition = partition_new
                else:
                    # keep the previous partition at padded positions
                    mask_t = mask_transpose[t]
                    partition = partition + (partition_new - partition) * mask_t
            # accumulate the energy of the gold label sequence
            tgt_energy += curr_energy[batch_index, prev_label, target_transpose[t].data]
            prev_label = target_transpose[t].data
        # negative log-likelihood: log partition minus gold-sequence score
        return logsumexp(partition, dim=1) - tgt_energy
|
|
|
def decode(self, energy, mask=None, leading_symbolic=0): |
|
""" |
|
|
|
Args: |
|
            energy: Tensor
                the input tensor with shape = [batch_size, length, num_label, num_label]
            mask: Tensor or None
                the mask tensor with shape = [batch_size, length]
            leading_symbolic: int
                number of symbolic labels leading in type alphabets (set it to 0 if you are not sure)
|
|
|
Returns: Tensor |
|
decoding results in shape [batch_size, length] |
|
|
|
""" |
|
if mask is not None: |
|
energy = energy * mask.unsqueeze(2).unsqueeze(3) |
|
|
|
|
|
        energy_transpose = energy.transpose(0, 1)

        # drop the leading symbolic labels and the trailing pad label from both
        # label dimensions before running Viterbi
        energy_transpose = energy_transpose[:, :, leading_symbolic:-1, leading_symbolic:-1]
|
|
|
length, batch_size, num_label, _ = energy_transpose.size() |
|
        device = energy.device
        batch_index = torch.arange(0, batch_size, dtype=torch.long, device=device)
        # pi holds the running Viterbi scores, pointer the best previous label
        pi = torch.zeros([length, batch_size, num_label, 1], device=device)
        pointer = torch.zeros(length, batch_size, num_label, dtype=torch.long, device=device)
        back_pointer = torch.zeros(length, batch_size, 1, dtype=torch.long, device=device)
|
|
|
        # initialize with transitions out of the virtual start label
        pi[0] = energy[:, 0, -1, leading_symbolic:-1].unsqueeze(-1)
        pointer[0] = -1
        for t in range(1, length):
            pi_prev = pi[t - 1]
            # Viterbi recursion: maximize over the previous label dimension
            x, y = torch.max(energy_transpose[t] + pi_prev, dim=1)
            pi[t] = x.unsqueeze(-1)
            pointer[t] = y
        # trace back from the best final label
        _, back_pointer[-1] = torch.max(pi[-1], dim=1)
        back_pointer = back_pointer.squeeze(-1)
        for t in reversed(range(length - 1)):
            pointer_last = pointer[t + 1]
            back_pointer[t] = pointer_last[batch_index, back_pointer[t + 1]]
        return back_pointer.transpose(0, 1) + leading_symbolic
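# A minimal usage sketch for ChainCRF_with_LE (hypothetical shapes; LE is a
# label-embedding matrix supplied by the caller):
#
#   crf = ChainCRF_with_LE(input_size=256, num_labels=10)
#   LE = torch.randn(10, 32)            # [num_labels, embedding_dim]
#   h = torch.randn(2, 7, 256)          # [batch, length, input_size]
#   out_s, out_t = crf(h, LE)
#   energy = out_s + out_t
#   target = torch.randint(0, 10, (2, 7))
#   nll = crf.loss(energy, target).mean()
#   best_paths = crf.decode(energy)     # [batch, length]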
|
|