# mhg-parsing/MHGTagger/CRFTagger.py
import sys
import torch
from torch import nn
from torch.nn import functional as F
from .RNNTagger import RNNTagger
### auxiliary functions ############################################

def logsumexp(x, dim):
    """ numerically stable log(sum(exp(x))) along the given dimension """
    offset, _ = torch.max(x, dim=dim)  # the maximum serves as the stabilizing offset
    offset_broadcasted = offset.unsqueeze(dim)
    safe_log_sum_exp = torch.log(torch.exp(x - offset_broadcasted).sum(dim=dim))
    return safe_log_sum_exp + offset
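
# Illustration (comment only): subtracting the maximum before exponentiation
# prevents overflow in exp(); torch.logsumexp computes the same quantity and
# could serve as a drop-in replacement.
#   x = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
#   logsumexp(x, dim=0)        # tensor([2.1269, 3.1269])
#   torch.logsumexp(x, dim=0)  # same result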

def lookup(T, indices):
    """ look up the scores of the given tags in a vector, matrix, or 3D tensor """
    if T.dim() == 3:
        return T.gather(2, indices.unsqueeze(2)).squeeze(2)
    elif T.dim() == 2:
        return T.gather(1, indices.unsqueeze(1)).squeeze(1)
    elif T.dim() == 1:
        return T[indices]
    else:
        raise Exception('unexpected tensor size in function "lookup"')
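
# Illustration (comment only): with a score matrix T of shape
# (sentence_length, num_tags), lookup selects one score per word.
#   T = torch.tensor([[0.1, 0.9],
#                     [0.8, 0.2]])
#   lookup(T, torch.tensor([1, 0]))  # tensor([0.9000, 0.8000])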
### tagger class ###############################################

class CRFTagger(nn.Module):
    """ implements a CRF tagger """

    def __init__(self, num_chars, num_tags, char_emb_size,
                 char_rec_size, word_rec_size, word_rnn_depth,
                 dropout_rate, word_emb_size, beam_size):
        super(CRFTagger, self).__init__()

        # base LSTM tagger which computes a score for each word-tag pair
        self.base_tagger = RNNTagger(num_chars, num_tags, char_emb_size,
                                     char_rec_size, word_rec_size,
                                     word_rnn_depth, dropout_rate, word_emb_size)
        # fall back to an exhaustive search if the beam size is out of range
        self.beam_size = beam_size if 0 < beam_size < num_tags else num_tags
        # trainable CRF tag-pair transition weights
        self.weights = nn.Parameter(torch.zeros(num_tags, num_tags))
        self.dropout = nn.Dropout(dropout_rate)
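
    # Illustrative instantiation (comment only; the hyper-parameter values are
    # made up, not taken from the MHG training configuration):
    #   tagger = CRFTagger(num_chars=60, num_tags=40, char_emb_size=100,
    #                      char_rec_size=200, word_rec_size=300, word_rnn_depth=1,
    #                      dropout_rate=0.5, word_emb_size=100, beam_size=10)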

    def forward(self, fwd_charIDs, bwd_charIDs, tags=None):
        annotation_mode = (tags is None)  # without gold tags we only annotate

        scores = self.base_tagger(fwd_charIDs, bwd_charIDs)

        # extract the highest-scoring tags for each word and their scores
        best_scores, best_tags = scores.topk(self.beam_size, dim=-1)

        if self.training:  # not done during dev evaluation
            # check whether the gold-standard tags are among the best tags
            gs_contained = (best_tags == tags.unsqueeze(1)).sum(dim=-1)
            # replace the lowest-scoring tag at each position by the gold tag
            # if the gold tag is not in the beam
            last_column = gs_contained * best_tags[:, -1] + (1 - gs_contained) * tags
            s = lookup(scores, last_column)
            best_tags = torch.cat((best_tags[:, :-1], last_column.unsqueeze(1)), dim=1)
            best_scores = torch.cat((best_scores[:, :-1], s.unsqueeze(1)), dim=1)

        best_previous = []  # stores the backpointers of the Viterbi algorithm
        viterbi_scores = best_scores[0]
        if not annotation_mode:
            forward_scores = best_scores[0]
        # dynamic-programming recurrences over the beam (j: previous tag slot,
        # k: current tag slot):
        #   viterbi[i][k] = max_j(viterbi[i-1][j] + score[i][k] + w[j][k])
        #   forward[i][k] = logsumexp_j(forward[i-1][j] + score[i][k] + w[j][k])
        for i in range(1, scores.size(0)):  # for all word positions except the first
            # look up the tag-pair weights for all pairs of beam entries
            w = self.weights[best_tags[i-1]][:, best_tags[i]]
            # Viterbi algorithm
            values = viterbi_scores.unsqueeze(1) + best_scores[i].unsqueeze(0) + w
            viterbi_scores, best_prev = torch.max(values, dim=0)
            best_previous.append(best_prev)
            # forward algorithm (computes the partition function)
            if not annotation_mode:
                values = forward_scores.unsqueeze(1) + best_scores[i].unsqueeze(0) + w
                forward_scores = logsumexp(values, dim=0)

        # Viterbi backtrace: follow the backpointers from the best final tag
        _, index = torch.max(viterbi_scores, dim=0)
        best_indices = [index]
        for i in range(len(best_previous) - 1, -1, -1):
            index = best_previous[i][index]
            best_indices.append(index)
        # reverse the indices and map them to tag IDs
        best_indices = torch.stack(best_indices[::-1])
        predicted_tags = lookup(best_tags, best_indices)

        if annotation_mode:
            return predicted_tags
        else:
            # loss computation: negative log-probability of the gold tag
            # sequence under the beam-approximated CRF
            basetagger_scores = lookup(scores, tags).sum()
            CRFweights = self.weights[tags[:-1], tags[1:]].sum() if tags.size(0) > 1 else 0
            logZ = logsumexp(forward_scores, dim=0)  # log partition function
            logprob = basetagger_scores + CRFweights - logZ
            return predicted_tags, -logprob
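
# Usage sketch (illustrative only; it assumes `tagger` and the character-ID
# tensors `fwd_charIDs`/`bwd_charIDs` plus gold `tags` were built elsewhere in
# the pipeline, and follows the two return conventions of forward() above):
#
#   tagger.train()                                        # training step
#   predicted, loss = tagger(fwd_charIDs, bwd_charIDs, tags)
#   loss.backward()
#
#   tagger.eval()                                         # annotation
#   with torch.no_grad():
#       predicted = tagger(fwd_charIDs, bwd_charIDs)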