# LCM/utils/models/sequence_tagger.py
# Hosting-page residue kept for provenance:
#   shivrajanand's picture -- Upload folder using huggingface_hub
#   e8f4897 verified
from ..nn import Embedding
from ..nn import utils
import torch
import torch.nn as nn
import torch.nn.functional as F
class Sequence_Tagger(nn.Module):
    """Token-level tagger: a ``BiRecurrentConv_Encoder`` feeding a ``Tagger_Decoder``.

    The constructor forwards its arguments to the two sub-modules; the
    remaining methods delegate to them.
    """

    def __init__(self, word_dim, num_words, char_dim, num_chars, use_pos, use_char, pos_dim, num_pos,
                 num_filters, kernel_size, rnn_mode, hidden_size, num_layers, tag_space, num_tags,
                 embedd_word=None, embedd_char=None, embedd_pos=None,
                 p_in=0.33, p_out=0.33, p_rnn=(0.33, 0.33),
                 initializer=None):
        """Build the encoder and the decoder.

        Every argument except ``tag_space`` and ``num_tags`` goes to the
        encoder; the decoder receives ``hidden_size``, ``tag_space``,
        ``num_tags``, ``p_out`` and ``initializer``.
        """
        super().__init__()
        # The encoder is registered first, then the decoder.
        self.rnn_encoder = BiRecurrentConv_Encoder(
            word_dim=word_dim,
            num_words=num_words,
            char_dim=char_dim,
            num_chars=num_chars,
            use_pos=use_pos,
            use_char=use_char,
            pos_dim=pos_dim,
            num_pos=num_pos,
            num_filters=num_filters,
            kernel_size=kernel_size,
            rnn_mode=rnn_mode,
            hidden_size=hidden_size,
            num_layers=num_layers,
            embedd_word=embedd_word,
            embedd_char=embedd_char,
            embedd_pos=embedd_pos,
            p_in=p_in,
            p_out=p_out,
            p_rnn=p_rnn,
            initializer=initializer,
        )
        self.sequence_tagger_decoder = Tagger_Decoder(
            hidden_size=hidden_size,
            tag_space=tag_space,
            num_tags=num_tags,
            p_out=p_out,
            initializer=initializer,
        )

    def forward(self, input_word, input_char, input_pos, mask=None, length=None, hx=None):
        """Encode the inputs and score every position.

        Returns:
            ``(scores, mask, length)`` where ``scores`` is the decoder output
            and ``mask`` / ``length`` are the values handed back by the
            encoder. The encoder's final hidden state is discarded.
        """
        encoded, _, mask, length = self.rnn_encoder(input_word, input_char, input_pos, mask, length, hx)
        scores = self.sequence_tagger_decoder(encoded, mask)
        return scores, mask, length

    def loss(self, input, target, mask=None, length=None):
        """Return the decoder's loss for the scores ``input`` against ``target``."""
        return self.sequence_tagger_decoder.loss(input, target, mask, length)

    def decode(self, input, mask=None, length=None, leading_symbolic=0):
        """Return the decoder's predictions for the scores ``input``.

        ``length`` is accepted for interface compatibility but is not passed
        on to the decoder.
        """
        return self.sequence_tagger_decoder.decode(input, mask, leading_symbolic)
class Tagger_Decoder(nn.Module):
    """Feed-forward tagging head applied to every position of the encoder output.

    Layer sizes are ``2 * hidden_size -> tag_space -> tag_space // 2 -> num_tags``.
    ELU followed by dropout is applied after the first two linear layers; the
    last layer produces unnormalised tag scores.
    """

    def __init__(self, hidden_size, tag_space, num_tags, p_out, initializer):
        """Create the layers and optionally re-initialise their parameters.

        Args:
            hidden_size: size of one direction of the encoder; the expected
                input feature size is ``2 * hidden_size``.
            tag_space: width of the first hidden layer.
            num_tags: number of output scores per position.
            p_out: dropout probability used after each hidden layer.
            initializer: callable applied in place to every weight tensor, or
                ``None`` to keep PyTorch's default initialisation.
        """
        super(Tagger_Decoder, self).__init__()
        self.criterion_obj = nn.CrossEntropyLoss()
        self.tag_space = tag_space
        self.num_tags = num_tags
        self.p_out = p_out
        self.initializer = initializer
        self.dropout_out = nn.Dropout(p_out)
        self.out_dim = 2 * hidden_size
        self.fc_1 = nn.Linear(self.out_dim, tag_space)
        self.fc_2 = nn.Linear(tag_space, tag_space // 2)
        self.fc_3 = nn.Linear(tag_space // 2, num_tags)
        self.reset_parameters()

    def reset_parameters(self):
        """Zero all 1-D parameters and apply ``initializer`` to the others.

        Does nothing when no initializer was supplied.
        """
        if self.initializer is None:
            return
        for parameter in self.parameters():
            if parameter.dim() == 1:
                parameter.data.zero_()
            else:
                self.initializer(parameter.data)

    def forward(self, input, mask):
        """Map encoder states to tag scores.

        Args:
            input: tensor whose last dimension is ``2 * hidden_size``
                (the original comments describe it as
                ``[batch_size, length, ...]``).
            mask: accepted for interface compatibility; not used here.

        Returns:
            Tensor with the same leading dimensions and ``num_tags`` scores
            in the last dimension.
        """
        output = self.dropout_out(F.elu(self.fc_1(input)))
        output = self.dropout_out(F.elu(self.fc_2(output)))
        return self.fc_3(output)

    def loss(self, input, target, mask=None, length=None):
        """Cross-entropy between the scores ``input`` and the gold ``target``.

        Args:
            input: scores with ``num_tags`` in the last dimension.
            target: integer tag ids, ``[batch_size, length]``.
            mask: accepted but NOT used. Every position, padding included,
                contributes to the mean unless its target equals the
                criterion's ``ignore_index``.
            length: optional per-sequence lengths; when given, ``target`` is
                truncated to ``length.max()`` columns.

        Returns:
            Scalar loss tensor.
        """
        if length is not None:
            max_len = int(length.max())
            if target.size(1) != max_len:
                target = target[:, :max_len]
        # reshape (rather than view) also accepts non-contiguous scores.
        input = input.reshape(-1, self.num_tags)
        target = target.contiguous().view(-1)
        return self.criterion_obj(input, target)

    def decode(self, input, mask=None, leading_symbolic=0):
        """Return the arg-max tag index for every position.

        Args:
            input: scores with the tag dimension last (indexed as 3-D).
            mask: optional ``[batch_size, length]`` tensor multiplied into
                the scores before the arg-max.
            leading_symbolic: number of columns to drop from the END of the
                tag dimension before the arg-max. ``0`` keeps all columns.

        Returns:
            Integer tensor of predicted indices, one per position.
        """
        if mask is not None:
            input = input * mask.unsqueeze(2)
        # ``input[:, :, :-0]`` would be an empty slice and make argmax raise,
        # so only slice when there is something to drop.
        if leading_symbolic != 0:
            input = input[:, :, :-leading_symbolic]
        return torch.argmax(input, -1)
class BiRecurrentConv_Encoder(nn.Module):
    """Bidirectional recurrent encoder over word, optional char-CNN and optional POS features.

    Word embeddings are concatenated with max-pooled character convolution
    features (when ``use_char``) and POS embeddings (when ``use_pos``), then
    fed to a bidirectional RNN/LSTM/GRU.
    """

    def __init__(self, word_dim, num_words, char_dim, num_chars, use_pos, use_char, pos_dim, num_pos, num_filters,
                 kernel_size, rnn_mode, hidden_size, num_layers, embedd_word=None, embedd_char=None, embedd_pos=None,
                 p_in=0.33, p_out=0.33, p_rnn=(0.33, 0.33), initializer=None):
        """Create embeddings, the character CNN, dropouts and the recurrent layer.

        Args:
            rnn_mode: one of ``'RNN'``, ``'LSTM'`` or ``'GRU'``.
            p_in: dropout probability for each input feature group.
            p_out: dropout probability for the recurrent output.
            p_rnn: pair ``(input_dropout, inter_layer_dropout)`` for the
                recurrent layer.
            initializer: callable applied in place to non-embedding weight
                tensors, or ``None`` to keep default initialisation.

        Raises:
            ValueError: if ``rnn_mode`` is not a supported mode.
        """
        super().__init__()
        self.word_embedd = Embedding(num_words, word_dim, init_embedding=embedd_word)
        self.char_embedd = Embedding(num_chars, char_dim, init_embedding=embedd_char) if use_char else None
        self.pos_embedd = Embedding(num_pos, pos_dim, init_embedding=embedd_pos) if use_pos else None
        # Padding of kernel_size - 1 on both sides of the character axis.
        self.conv1d = (
            nn.Conv1d(char_dim, num_filters, kernel_size, padding=kernel_size - 1) if use_char else None
        )

        # Per the original comments: dropout_in is "dropout word",
        # dropout_out is the "standard dropout" on the recurrent output.
        self.dropout_in = nn.Dropout2d(p_in)
        self.dropout_out = nn.Dropout2d(p_out)
        self.dropout_rnn_in = nn.Dropout(p_rnn[0])

        self.use_pos = use_pos
        self.use_char = use_char
        self.rnn_mode = rnn_mode

        # Width of the concatenated feature vector fed to the recurrent layer.
        self.dim_enc = word_dim + (pos_dim if use_pos else 0) + (num_filters if use_char else 0)

        rnn_classes = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}
        if rnn_mode not in rnn_classes:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)
        self.rnn = rnn_classes[rnn_mode](
            self.dim_enc,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=p_rnn[1],
        )

        self.initializer = initializer
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every parameter whose name does not contain ``'embedd'``.

        1-D parameters are zeroed; all others are passed to ``initializer``.
        Does nothing when no initializer was supplied.
        """
        if self.initializer is None:
            return
        for name, parameter in self.named_parameters():
            if 'embedd' in name:
                continue
            if parameter.dim() == 1:
                parameter.data.zero_()
            else:
                self.initializer(parameter.data)

    def forward(self, input_word, input_char, input_pos, mask=None, length=None, hx=None):
        """Encode a batch of sequences.

        Args:
            input_word: word ids, ``[batch_size, length]``.
            input_char: character ids, ``[batch_size, length, char_length]``
                (read only when ``use_char``).
            input_pos: POS ids, ``[batch_size, length]`` (read only when
                ``use_pos``).
            mask: optional mask; when ``length`` is omitted it is summed over
                dim 1 to derive the lengths. A mask is never derived from
                ``length``, so pass one whenever it is needed.
            length: optional per-sequence lengths.
            hx: optional initial hidden state for the recurrent layer.

        Returns:
            ``(output, hn, mask, length)``: the recurrent output after
            dropout, the final hidden state, and the (possibly updated)
            mask and length.
        """
        if length is None and mask is not None:
            length = mask.data.sum(dim=1).long()

        # [batch_size, length, word_dim], with input dropout applied.
        features = [self.dropout_in(self.word_embedd(input_word))]

        if self.use_char:
            # [batch_size, length, char_length, char_dim]
            char_emb = self.char_embedd(input_char)
            batch, seq_len, char_len, char_dim = char_emb.size()
            # Fold batch and length together and move channels first:
            # [batch * length, char_dim, char_length]
            conv_in = char_emb.view(batch * seq_len, char_len, char_dim).transpose(1, 2)
            # Convolve, then max-pool over the character axis:
            # [batch * length, char_filters]
            pooled, _ = self.conv1d(conv_in).max(dim=2)
            # Back to [batch_size, length, char_filters]
            char_feat = torch.tanh(pooled).view(batch, seq_len, -1)
            features.append(self.dropout_in(char_feat))

        if self.use_pos:
            # [batch_size, length, pos_dim]
            features.append(self.dropout_in(self.pos_embedd(input_pos)))

        # Concatenation order is word, char, pos.
        combined = features[0] if len(features) == 1 else torch.cat(features, dim=2)
        rnn_input = self.dropout_rnn_in(combined)

        if length is None:
            # Plain padded batch: [batch_size, length, 2 * hidden_size]
            self.rnn.flatten_parameters()
            output, hn = self.rnn(rnn_input, hx=hx)
        else:
            # Per the original comment this prepares a packed sequence; the
            # helper also returns the order needed to undo it afterwards.
            packed, hx, rev_order, mask = utils.prepare_rnn_seq(
                rnn_input, length, hx=hx, masks=mask, batch_first=True)
            self.rnn.flatten_parameters()
            packed_output, hn = self.rnn(packed, hx=hx)
            output, hn = utils.recover_rnn_seq(packed_output, rev_order, hx=hn, batch_first=True)

        # Dropout on the recurrent output.
        output = self.dropout_out(output)
        return output, hn, mask, length