import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from utils.nlinalg import logsumexp, logdet
from utils.tasks import parse
from .attention import BiAAttention
class ChainCRF(nn.Module):
def __init__(self, input_size, num_labels, bigram=True):
'''
Args:
input_size: int
the dimension of the input.
            num_labels: int
                the number of labels of the CRF layer
            bigram: bool
                whether transition scores are computed from the input (bi-gram)
                or taken from a single learned transition matrix.
'''
super(ChainCRF, self).__init__()
self.input_size = input_size
self.num_labels = num_labels + 1
self.pad_label_id = num_labels
self.bigram = bigram
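        # One extra label serves as the initial/pad symbol: loss() starts the
        # forward recursion from it at t = 0, and decode() strips its row and
        # column from the energies before running Viterbi.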
# state weight tensor
self.state_nn = nn.Linear(input_size, self.num_labels)
if bigram:
# transition weight tensor
self.trans_nn = nn.Linear(input_size, self.num_labels * self.num_labels)
self.register_parameter('trans_matrix', None)
else:
self.trans_nn = None
self.trans_matrix = Parameter(torch.Tensor(self.num_labels, self.num_labels))
self.reset_parameters()
def reset_parameters(self):
nn.init.constant_(self.state_nn.bias, 0.)
if self.bigram:
nn.init.xavier_uniform_(self.trans_nn.weight)
nn.init.constant_(self.trans_nn.bias, 0.)
else:
nn.init.normal_(self.trans_matrix)
def forward(self, input, mask=None):
'''
Args:
input: Tensor
the input tensor with shape = [batch_size, length, input_size]
mask: Tensor or None
the mask tensor with shape = [batch_size, length]
Returns: Tensor
the energy tensor with shape = [batch_size, length, num_label, num_label]
'''
batch_size, length, _ = input.size()
        # state scores: [batch_size, length, input_size] * [input_size, num_label]
        # gives [batch_size, length, num_label] --> [batch_size, length, 1, num_label],
        # so the score attaches to the current-label axis, as loss() and decode()
        # index energies by [batch, prev_label, curr_label]
        out_s = self.state_nn(input).unsqueeze(2)
if self.bigram:
            # transition scores: [batch_size, length, input_size] * [input_size, num_label * num_label]
out_t = self.trans_nn(input).view(batch_size, length, self.num_labels, self.num_labels)
else:
out_t = self.trans_matrix
        # the combined energy out_s + out_t (with out_s broadcasting over the
        # previous-label axis) has shape [batch_size, length, num_label, num_label];
        # combining and masking are left to the caller
        return out_s, out_t
def loss(self, energy, target, mask=None, length=None):
'''
Args:
energy: Tensor
the input tensor with shape = [batch_size, length, num_label, num_label]
target: Tensor
the tensor of target labels with shape [batch_size, length]
            mask: Tensor or None
                the mask tensor with shape = [batch_size, length]
            length: Tensor or None
                per-example sequence lengths with shape = [batch_size]
        Returns: Tensor
            a 1D tensor holding the negative log-likelihood of each example
'''
if length is not None:
max_len = length.max()
if energy.size(1) != max_len:
target = target[:, :max_len]
mask_transpose = None
if mask is not None:
energy = energy * mask.unsqueeze(2).unsqueeze(3)
mask_transpose = mask.unsqueeze(2).transpose(0, 1)
        batch_size, length, _, _ = energy.size()
# shape = [length, batch_size, num_label, num_label]
energy_transpose = energy.transpose(0, 1)
# shape = [length, batch_size]
target_transpose = target.transpose(0, 1)
        # partition: running forward scores, shape = [batch_size, num_label]
        partition = None
if energy.is_cuda:
# shape = [batch_size]
batch_index = torch.arange(0, batch_size).long().cuda()
prev_label = torch.cuda.LongTensor(batch_size).fill_(self.num_labels - 1)
tgt_energy = torch.zeros(batch_size).cuda()
else:
# shape = [batch_size]
batch_index = torch.arange(0, batch_size).long()
prev_label = torch.LongTensor(batch_size).fill_(self.num_labels - 1)
tgt_energy = torch.zeros(batch_size)
        for t in range(length):
# shape = [batch_size, num_label, num_label]
curr_energy = energy_transpose[t]
            if t == 0:
                # initialize with transitions out of the initial (pad) label
                partition = curr_energy[:, -1, :]
else:
                # forward recurrence in log space, shape = [batch_size, num_label]:
                # partition_new[b, j] = logsumexp_i(partition[b, i] + curr_energy[b, i, j])
                partition_new = logsumexp(curr_energy + partition.unsqueeze(2), dim=1)
if mask_transpose is None:
partition = partition_new
else:
mask_t = mask_transpose[t]
partition = partition + (partition_new - partition) * mask_t
tgt_energy += curr_energy[batch_index, prev_label, target_transpose[t].data]
prev_label = target_transpose[t].data
        # negative log-likelihood: log Z(x) minus the score of the gold sequence
        return logsumexp(partition, dim=1) - tgt_energy
def decode(self, energy, mask=None, leading_symbolic=0):
"""
Args:
energy: Tensor
                the input tensor with shape = [batch_size, length, num_label, num_label]
            leading_symbolic: int
number of symbolic labels leading in type alphabets (set it to 0 if you are not sure)
Returns: Tensor
decoding results in shape [batch_size, length]
"""
if mask is not None:
energy = energy * mask.unsqueeze(2).unsqueeze(3)
        # energies arrive as (batch_size, length, num_label, num_label);
        # transpose to (length, batch_size, num_label, num_label) for the time loop
energy_transpose = energy.transpose(0, 1)
# the last row and column is the tag for pad symbol. reduce these two dimensions by 1 to remove that.
# also remove the first #symbolic rows and columns.
# now the shape of energies_shuffled is [n_time_steps, batch_size, t, t] where t = num_labels - #symbolic - 1.
energy_transpose = energy_transpose[:, :, leading_symbolic:-1, leading_symbolic:-1]
length, batch_size, num_label, _ = energy_transpose.size()
if energy.is_cuda:
batch_index = torch.arange(0, batch_size).long().cuda()
pi = torch.zeros([length, batch_size, num_label, 1]).cuda()
pointer = torch.cuda.LongTensor(length, batch_size, num_label).zero_()
back_pointer = torch.cuda.LongTensor(length, batch_size, 1).zero_()
else:
batch_index = torch.arange(0, batch_size).long()
pi = torch.zeros([length, batch_size, num_label, 1])
pointer = torch.LongTensor(length, batch_size, num_label).zero_()
back_pointer = torch.LongTensor(length, batch_size, 1).zero_()
pi[0] = energy[:, 0, -1, leading_symbolic:-1].unsqueeze(-1)
pointer[0] = -1
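        # Viterbi in log space: pi[t, :, j] is the best score of any label sequence
        # ending in label j at time t; pointer[t, :, j] stores the argmax over the
        # previous label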
for t in range(1, length):
pi_prev = pi[t - 1]
            pi_t, pointer_t = torch.max(energy_transpose[t] + pi_prev, dim=1)
            pi[t] = pi_t.unsqueeze(-1)
            pointer[t] = pointer_t
_, back_pointer[-1] = torch.max(pi[-1], dim=1)
back_pointer = back_pointer.squeeze(-1)
for t in reversed(range(length - 1)):
pointer_last = pointer[t + 1]
back_pointer[t] = pointer_last[batch_index, back_pointer[t + 1]]
return back_pointer.transpose(0, 1) + leading_symbolic
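# A minimal usage sketch for ChainCRF (the shapes and the out_s + out_t combination
# are assumptions read off forward()/loss()/decode() above, not a documented API):
#
#   crf = ChainCRF(input_size=64, num_labels=5, bigram=True)
#   feats = torch.randn(2, 10, 64)             # [batch, length, input_size]
#   out_s, out_t = crf(feats)                  # state and transition scores
#   energy = out_s + out_t                     # [batch, length, num_label, num_label]
#   target = torch.randint(0, 5, (2, 10))      # gold labels in [0, num_labels)
#   nll = crf.loss(energy, target).mean()      # negative log-likelihood
#   preds = crf.decode(energy)                 # Viterbi paths, [batch, length]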
class TreeCRF(nn.Module):
'''
Tree CRF layer.
'''
def __init__(self, input_size, num_labels, biaffine=True):
'''
Args:
input_size: int
the dimension of the input.
num_labels: int
the number of labels of the crf layer
            biaffine: bool
                whether to use the bi-affine attention parameterization.
'''
super(TreeCRF, self).__init__()
self.input_size = input_size
self.num_labels = num_labels
self.attention = BiAAttention(input_size, input_size, num_labels, biaffine=biaffine)
def forward(self, input_h, input_c, mask=None):
'''
Args:
input_h: Tensor
the head input tensor with shape = [batch_size, length, input_size]
input_c: Tensor
the child input tensor with shape = [batch_size, length, input_size]
mask: Tensor or None
the mask tensor with shape = [batch_size, length]
Returns: Tensor
the energy tensor with shape = [batch_size, num_label, length, length]
'''
_, length, _ = input_h.size()
# [batch_size, num_labels, length, length]
output = self.attention(input_h, input_c, mask_d=mask, mask_e=mask)
        # forbid self-arcs: set diagonal elements to -inf
output = output + torch.diag(output.data.new(length).fill_(-np.inf))
return output
def loss(self, input_h, input_c, heads, arc_tags, mask=None, lengths=None):
'''
Args:
input_h: Tensor
the head input tensor with shape = [batch_size, length, input_size]
input_c: Tensor
the child input tensor with shape = [batch_size, length, input_size]
            heads: Tensor
                the tensor of gold head indices with shape [batch_size, length]
            arc_tags: Tensor
                the tensor of gold arc labels with shape [batch_size, length]
            mask: Tensor or None
the mask tensor with shape = [batch_size, length]
lengths: tensor or list of int
the length of each input shape = [batch_size]
Returns: Tensor
A 1D tensor for minus log likelihood loss
'''
batch_size, length, _ = input_h.size()
energy = self.forward(input_h, input_c, mask=mask)
# [batch_size, num_labels, length, length]
A = torch.exp(energy)
# mask out invalid positions
if mask is not None:
A = A * mask.unsqueeze(1).unsqueeze(3) * mask.unsqueeze(1).unsqueeze(2)
# sum along the label axis [batch_size, length, length]
A = A.sum(dim=1)
# get D [batch_size, 1, length]
D = A.sum(dim=1, keepdim=True)
        # add a small ridge so the Laplacian stays positive definite
rtol = 1e-4
atol = 1e-6
D += D * rtol + atol
# [batch_size, length, length]
D = A.data.new(A.size()).zero_() + D
# zeros out all elements except diagonal.
D = D * torch.eye(length).type_as(D)
# compute laplacian matrix
# [batch_size, length, length]
L = D - A
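        # By the Matrix-Tree theorem (Koo et al., 2007), the partition function over
        # non-projective dependency trees is the determinant of a minor of the
        # Laplacian, so log Z(x) = logdet(L[1:n, 1:n]) below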
# compute lengths
if lengths is None:
if mask is None:
lengths = [length for _ in range(batch_size)]
else:
lengths = mask.data.sum(dim=1).long()
# compute partition Z(x) [batch_size]
z = energy.data.new(batch_size)
for b in range(batch_size):
Lx = L[b, 1:lengths[b], 1:lengths[b]]
z[b] = logdet(Lx)
# first create index matrix [length, batch_size]
index = torch.arange(0, length).view(length, 1).expand(length, batch_size)
index = index.type_as(energy.data).long()
batch_index = torch.arange(0, batch_size).type_as(energy.data).long()
# compute target energy [length-1, batch_size]
tgt_energy = energy[batch_index, arc_tags.data.t(), heads.data.t(), index][1:]
# sum over dim=0 shape = [batch_size]
tgt_energy = tgt_energy.sum(dim=0)
return z - tgt_energy
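# A minimal usage sketch for TreeCRF (hidden size and label count are illustrative
# assumptions; shapes follow the docstrings above):
#
#   crf = TreeCRF(input_size=128, num_labels=10)
#   h = torch.randn(2, 12, 128)                    # head representations
#   c = torch.randn(2, 12, 128)                    # child representations
#   heads = torch.zeros(2, 12).long()              # gold heads (toy: all attach to root)
#   arc_tags = torch.randint(0, 10, (2, 12))       # gold arc label per token
#   nll = crf.loss(h, c, heads, arc_tags).mean()   # negative log-likelihood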
class ChainCRF_with_LE(nn.Module):
def __init__(self, input_size, num_labels, bigram=True):
'''
Args:
input_size: int
the dimension of the input.
            num_labels: int
                the number of labels of the CRF layer
            bigram: bool
                kept for interface compatibility; transition scores come from the
                label embeddings passed to forward(), not from a learned module.
'''
super(ChainCRF_with_LE, self).__init__()
self.input_size = input_size
self.num_labels = num_labels + 1
self.pad_label_id = num_labels
self.bigram = bigram
# state weight tensor
self.state_nn = nn.Linear(input_size, self.num_labels)
self.reset_parameters()
def reset_parameters(self):
nn.init.constant_(self.state_nn.bias, 0.)
def forward(self, input, LE, mask=None):
'''
Args:
input: Tensor
the input tensor with shape = [batch_size, length, input_size]
            LE: Tensor
                the label embedding matrix, one row per (non-pad) label,
                with shape = [num_labels, embedding_size]
            mask: Tensor or None
                the mask tensor with shape = [batch_size, length]
Returns: Tensor
the energy tensor with shape = [batch_size, length, num_label, num_label]
'''
batch_size, length, _ = input.size()
        # state scores: [batch_size, length, num_label] --> [batch_size, length, 1, num_label],
        # attaching the score to the current-label axis as loss() and decode() expect
        out_s = self.state_nn(input).unsqueeze(2)
        # transition scores as label-embedding similarities: out_t[i, j] = <LE_i, LE_j>
        out_t = torch.matmul(LE, torch.t(LE))
# add column of zeros (for END token)
col_zeros = torch.zeros((out_t.shape[0],1))
row_zeros = torch.zeros((1, out_t.shape[1] + 1))
if out_t.is_cuda:
col_zeros = col_zeros.cuda()
row_zeros = row_zeros.cuda()
out_t = torch.cat((out_t, col_zeros), dim=1)
# add row of zeros (for END token)
out_t = torch.cat((out_t, row_zeros))
        # the combined energy out_s + out_t broadcasts to
        # [batch_size, length, num_label, num_label]; combining is left to the caller
        return out_s, out_t
def loss(self, energy, target, mask=None, length=None):
'''
Args:
energy: Tensor
the input tensor with shape = [batch_size, length, num_label, num_label]
target: Tensor
the tensor of target labels with shape [batch_size, length]
            mask: Tensor or None
                the mask tensor with shape = [batch_size, length]
            length: Tensor or None
                per-example sequence lengths with shape = [batch_size]
        Returns: Tensor
            a 1D tensor holding the negative log-likelihood of each example
'''
if length is not None:
max_len = length.max()
if energy.size(1) != max_len:
target = target[:, :max_len]
mask_transpose = None
if mask is not None:
energy = energy * mask.unsqueeze(2).unsqueeze(3)
mask_transpose = mask.unsqueeze(2).transpose(0, 1)
        batch_size, length, _, _ = energy.size()
# shape = [length, batch_size, num_label, num_label]
energy_transpose = energy.transpose(0, 1)
# shape = [length, batch_size]
target_transpose = target.transpose(0, 1)
        # partition: running forward scores, shape = [batch_size, num_label]
        partition = None
if energy.is_cuda:
# shape = [batch_size]
batch_index = torch.arange(0, batch_size).long().cuda()
prev_label = torch.cuda.LongTensor(batch_size).fill_(self.num_labels - 1)
tgt_energy = torch.zeros(batch_size).cuda()
else:
# shape = [batch_size]
batch_index = torch.arange(0, batch_size).long()
prev_label = torch.LongTensor(batch_size).fill_(self.num_labels - 1)
tgt_energy = torch.zeros(batch_size)
        for t in range(length):
# shape = [batch_size, num_label, num_label]
curr_energy = energy_transpose[t]
            if t == 0:
                # initialize with transitions out of the initial (pad) label
                partition = curr_energy[:, -1, :]
else:
                # forward recurrence in log space, shape = [batch_size, num_label]:
                # partition_new[b, j] = logsumexp_i(partition[b, i] + curr_energy[b, i, j])
                partition_new = logsumexp(curr_energy + partition.unsqueeze(2), dim=1)
if mask_transpose is None:
partition = partition_new
else:
mask_t = mask_transpose[t]
partition = partition + (partition_new - partition) * mask_t
tgt_energy += curr_energy[batch_index, prev_label, target_transpose[t].data]
prev_label = target_transpose[t].data
        # negative log-likelihood: log Z(x) minus the score of the gold sequence
        return logsumexp(partition, dim=1) - tgt_energy
def decode(self, energy, mask=None, leading_symbolic=0):
"""
Args:
energy: Tensor
                the input tensor with shape = [batch_size, length, num_label, num_label]
            leading_symbolic: int
number of symbolic labels leading in type alphabets (set it to 0 if you are not sure)
Returns: Tensor
decoding results in shape [batch_size, length]
"""
if mask is not None:
energy = energy * mask.unsqueeze(2).unsqueeze(3)
        # energies arrive as (batch_size, length, num_label, num_label);
        # transpose to (length, batch_size, num_label, num_label) for the time loop
energy_transpose = energy.transpose(0, 1)
# the last row and column is the tag for pad symbol. reduce these two dimensions by 1 to remove that.
# also remove the first #symbolic rows and columns.
# now the shape of energies_shuffled is [n_time_steps, batch_size, t, t] where t = num_labels - #symbolic - 1.
energy_transpose = energy_transpose[:, :, leading_symbolic:-1, leading_symbolic:-1]
length, batch_size, num_label, _ = energy_transpose.size()
if energy.is_cuda:
batch_index = torch.arange(0, batch_size).long().cuda()
pi = torch.zeros([length, batch_size, num_label, 1]).cuda()
pointer = torch.cuda.LongTensor(length, batch_size, num_label).zero_()
back_pointer = torch.cuda.LongTensor(length, batch_size, 1).zero_()
else:
batch_index = torch.arange(0, batch_size).long()
pi = torch.zeros([length, batch_size, num_label, 1])
pointer = torch.LongTensor(length, batch_size, num_label).zero_()
back_pointer = torch.LongTensor(length, batch_size, 1).zero_()
pi[0] = energy[:, 0, -1, leading_symbolic:-1].unsqueeze(-1)
pointer[0] = -1
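        # Viterbi in log space: pi[t, :, j] is the best score of any label sequence
        # ending in label j at time t; pointer[t, :, j] stores the argmax over the
        # previous label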
for t in range(1, length):
pi_prev = pi[t - 1]
            pi_t, pointer_t = torch.max(energy_transpose[t] + pi_prev, dim=1)
            pi[t] = pi_t.unsqueeze(-1)
            pointer[t] = pointer_t
_, back_pointer[-1] = torch.max(pi[-1], dim=1)
back_pointer = back_pointer.squeeze(-1)
for t in reversed(range(length - 1)):
pointer_last = pointer[t + 1]
back_pointer[t] = pointer_last[batch_index, back_pointer[t + 1]]
return back_pointer.transpose(0, 1) + leading_symbolic
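# A minimal usage sketch for ChainCRF_with_LE (the label-embedding matrix and all
# shapes are illustrative assumptions):
#
#   crf = ChainCRF_with_LE(input_size=64, num_labels=5)
#   LE = torch.randn(5, 32)                    # one embedding per (non-pad) label
#   feats = torch.randn(2, 10, 64)             # [batch, length, input_size]
#   out_s, out_t = crf(feats, LE)
#   energy = out_s + out_t                     # [batch, length, num_label, num_label]
#   target = torch.randint(0, 5, (2, 10))
#   nll = crf.loss(energy, target).mean()
#   preds = crf.decode(energy)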