import sys

import torch
from torch import nn
from torch.nn import functional as F

from .RNNTagger import RNNTagger

### auxiliary functions ############################################

def logsumexp(x, dim):
    """ numerically stable log-sum-exp of log-domain values along dimension `dim` """
    offset, _ = torch.max(x, dim=dim)
    offset_broadcasted = offset.unsqueeze(dim)
    safe_log_sum_exp = torch.log(torch.exp(x - offset_broadcasted).sum(dim=dim))
    return safe_log_sum_exp + offset
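
# Sanity check (illustrative, not part of the original module): this helper
# should agree with torch.logsumexp, e.g.
#   x = torch.randn(3, 5)
#   assert torch.allclose(logsumexp(x, dim=0), torch.logsumexp(x, dim=0))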

def lookup(T, indices):
    """ look up probabilities of tags in a vector, matrix, or 3D tensor """
    if T.dim() == 3:
        return T.gather(2, indices.unsqueeze(2)).squeeze(2)
    elif T.dim() == 2:
        return T.gather(1, indices.unsqueeze(1)).squeeze(1)
    elif T.dim() == 1:
        return T[indices]
    else:
        raise ValueError('unexpected tensor size in function "lookup"')
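
# Example (illustrative): for a matrix T of shape (seq_len, num_tags) and
# indices of shape (seq_len,), lookup(T, indices) returns T[i, indices[i]]
# for each position i, i.e. the score of the chosen tag at every position.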

### tagger class ###############################################

class CRFTagger(nn.Module):
    """ implements a CRF tagger """

    def __init__(self, num_chars, num_tags, char_emb_size,
                 char_rec_size, word_rec_size, word_rnn_depth,
                 dropout_rate, word_emb_size, beam_size):
        super(CRFTagger, self).__init__()

        # base RNNTagger which computes the per-position tag scores
        self.base_tagger = RNNTagger(num_chars, num_tags, char_emb_size,
                                     char_rec_size, word_rec_size,
                                     word_rnn_depth, dropout_rate, word_emb_size)
        # clamp the beam size to the valid range 1..num_tags
        self.beam_size = beam_size if 0 < beam_size < num_tags else num_tags
        # CRF transition weights: weights[s, t] scores tag s followed by tag t
        self.weights = nn.Parameter(torch.zeros(num_tags, num_tags))
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, fwd_charIDs, bwd_charIDs, tags=None):
        annotation_mode = (tags is None)

        # per-position tag scores of the base tagger, shape (seq_len, num_tags)
        scores = self.base_tagger(fwd_charIDs, bwd_charIDs)

        # extract the highest-scoring tags for each word and their scores
        best_scores, best_tags = scores.topk(self.beam_size, dim=-1)

        if self.training:  # not done during dev evaluation
            # check whether the gold-standard tags are among the best tags
            gs_contained = (best_tags == tags.unsqueeze(1)).sum(dim=-1)
            # replace the lowest-scoring tag at each position by the gold tag
            # if the gold tag is not in the beam
            last_column = gs_contained * best_tags[:, -1] + (1 - gs_contained) * tags
            s = lookup(scores, last_column)
            best_tags = torch.cat((best_tags[:, :-1], last_column.unsqueeze(1)), dim=1)
            best_scores = torch.cat((best_scores[:, :-1], s.unsqueeze(1)), dim=1)
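            # Worked example (illustrative): if the beam at some position is
            # (7, 2, 9) and the gold tag is 2, then gs_contained is 1 and the
            # beam is unchanged; if the gold tag is 4, gs_contained is 0 and
            # the lowest-scoring entry 9 is replaced by 4, so the gold path
            # is always representable during training.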

        best_previous = []  # stores the backpointers of the Viterbi algorithm
        viterbi_scores = best_scores[0]
        if not annotation_mode:
            forward_scores = best_scores[0]
        for i in range(1, scores.size(0)):  # for all word positions except the first
            # look up the transition weights of all beam tag pairs
            w = self.weights[best_tags[i-1]][:, best_tags[i]]
            # Viterbi recursion: best path score ending in each beam tag
            values = viterbi_scores.unsqueeze(1) + best_scores[i].unsqueeze(0) + w
            viterbi_scores, best_prev = torch.max(values, dim=0)
            best_previous.append(best_prev)
            # forward recursion: log-sum of all path scores (needed for the loss)
            if not annotation_mode:
                values = forward_scores.unsqueeze(1) + best_scores[i].unsqueeze(0) + w
                forward_scores = logsumexp(values, dim=0)
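
        # Recurrences computed above (notation added for clarity; s and t
        # range over the beam entries at positions i-1 and i):
        #   Viterbi: V_i(t) = max_s       [ V_{i-1}(s) + w(s,t) ] + score_i(t)
        #   forward: F_i(t) = logsumexp_s [ F_{i-1}(s) + w(s,t) ] + score_i(t)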

        # Viterbi backtracing: start at the best final tag and follow the
        # backpointers to recover the best tag sequence
        _, index = torch.max(viterbi_scores, dim=0)
        best_indices = [index]
        for i in range(len(best_previous)-1, -1, -1):
            index = best_previous[i][index]
            best_indices.append(index)
        # reverse the beam indices and map them to tag IDs
        best_indices = torch.stack(best_indices[::-1])
        predicted_tags = lookup(best_tags, best_indices)

        if annotation_mode:
            return predicted_tags
        else:
            # loss computation: negative log-probability of the gold tag sequence
            basetagger_scores = lookup(scores, tags).sum()
            CRFweights = self.weights[tags[:-1], tags[1:]].sum() if tags.size(0) > 1 else 0
            logZ = logsumexp(forward_scores, dim=0)  # log partition function (restricted to the beam)
            logprob = basetagger_scores + CRFweights - logZ
            return predicted_tags, -logprob
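
### usage sketch ###################################################
# Illustrative only; the hyperparameter values and the input tensors
# are assumptions, not part of this module:
#
#   tagger = CRFTagger(num_chars=100, num_tags=50, char_emb_size=64,
#                      char_rec_size=128, word_rec_size=256,
#                      word_rnn_depth=1, dropout_rate=0.5,
#                      word_emb_size=100, beam_size=10)
#   tagger.train()
#   tags_pred, loss = tagger(fwd_charIDs, bwd_charIDs, tags)  # training step
#   loss.backward()
#   tagger.eval()
#   tags_pred = tagger(fwd_charIDs, bwd_charIDs)              # annotation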