Spaces:
Runtime error
Runtime error
from typing import Tuple | |
import numpy as np | |
import torch | |
import torch.nn | |
from torch.nn.functional import softmax | |
from torch.nn.utils.rnn import pack_padded_sequence | |
import flair | |
from flair.data import Dictionary, Label, List, Sentence | |
# Special boundary tags: ViterbiLoss and ViterbiDecoder look up their indices in the tag dictionary
# and use them as the implicit tag before the first token (START) and after the last token (STOP).
START_TAG: str = "<START>"
STOP_TAG: str = "<STOP>"
class ViterbiLoss(torch.nn.Module):
    """
    Calculates the CRF (Viterbi) loss for each sequence up to its length t: the log-sum-exp over the scores
    of all tag paths minus the score of the gold tag path.
    """

    def __init__(self, tag_dictionary: Dictionary):
        """
        :param tag_dictionary: tag_dictionary of task; the indices of START_TAG and STOP_TAG are looked up in it
        """
        super(ViterbiLoss, self).__init__()
        self.tag_dictionary = tag_dictionary
        self.tagset_size = len(tag_dictionary)
        self.start_tag = tag_dictionary.get_idx_for_item(START_TAG)
        self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG)

    def forward(self, features_tuple: tuple, targets: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation of Viterbi Loss.

        :param features_tuple: tuple of (features, lengths, transitions):
            CRF scores from forward method in shape (batch size, seq len, tagset size, tagset size), whose last
            two axes are indexed [to_tag, from_tag]; lengths of sentences in batch, which must be sorted in
            descending order (required by pack_padded_sequence and by the effective-batch-size slicing below);
            transitions from CRF, indexed [to_tag, from_tag]
        :param targets: true tags of all sentences concatenated into one flat tensor (one entry per token),
            which will be converted to matrix indices.
        :return: Viterbi loss summed (not averaged) over the batch
        """
        features, lengths, transitions = features_tuple

        batch_size = features.size(0)
        seq_len = features.size(1)

        targets, targets_matrix_indices = self._format_targets(targets, lengths)
        targets_matrix_indices = torch.tensor(targets_matrix_indices, dtype=torch.long).unsqueeze(2).to(flair.device)

        # Squeeze crf scores matrices in 1-dim shape and gather scores at targets by matrix indices
        scores_at_targets = torch.gather(features.view(batch_size, seq_len, -1), 2, targets_matrix_indices)
        # Packing drops the scores gathered at padded positions
        scores_at_targets = pack_padded_sequence(scores_at_targets, lengths, batch_first=True)[0]
        # Transition from each sentence's last gold tag into STOP
        transitions_to_stop = transitions[
            np.repeat(self.stop_tag, features.shape[0]),
            [target[length - 1] for target, length in zip(targets, lengths)],
        ]
        gold_score = scores_at_targets.sum() + transitions_to_stop.sum()

        scores_upto_t = torch.zeros(batch_size, self.tagset_size, device=flair.device)

        for t in range(max(lengths)):
            # since batch is ordered, we can save computation time by reducing our effective batch_size
            batch_size_t = sum([length > t for length in lengths])

            if t == 0:
                # Initially, get scores from <start> tag to all other tags
                scores_upto_t[:batch_size_t] = (
                    scores_upto_t[:batch_size_t] + features[:batch_size_t, t, :, self.start_tag]
                )
            else:
                # We add scores at current timestep to scores accumulated up to previous timestep, and log-sum-exp
                # over the previous ("from") tag axis.
                # Remember, the cur_tag of the previous timestep is the prev_tag of this timestep
                scores_upto_t[:batch_size_t] = self._log_sum_exp(
                    features[:batch_size_t, t, :, :] + scores_upto_t[:batch_size_t].unsqueeze(1), dim=2
                )

        # Finished sentences keep their last accumulated scores; close every path with the transition to STOP
        all_paths_scores = self._log_sum_exp(scores_upto_t + transitions[self.stop_tag].unsqueeze(0), dim=1).sum()

        viterbi_loss = all_paths_scores - gold_score

        return viterbi_loss

    @staticmethod
    def _log_sum_exp(tensor: torch.Tensor, dim: int) -> torch.Tensor:
        """
        Calculates the log-sum-exponent of a tensor's dimension in a numerically stable way.

        Must be a staticmethod: callers invoke it as ``self._log_sum_exp(tensor, dim=...)``.

        :param tensor: tensor
        :param dim: dimension to calculate log-sum-exp of (it is reduced away)
        :return: log-sum-exp
        """
        return torch.logsumexp(tensor, dim)

    def _format_targets(self, targets: torch.Tensor, lengths: torch.IntTensor):
        """
        Formats targets into matrix indices.

        CRF scores contain per sentence, per token a (tagset_size x tagset_size) matrix, containing the emission
        score for the current token's tag plus the transition score from the previous token's tag. The rows are
        the "to tag" and the columns the "from tag". After flattening the matrix row-major to 1 dimension, cell
        [to, from] is addressed by the index (to * tagset_size + from). For example, with 12 tags, the score of
        tag 10 following tag 5 (cell [10, 5]) is at index 10 * 12 + 5 = 125.

        :param targets: targets as in tag dictionary, concatenated over all sentences of the batch
        :param lengths: lengths of sentences in batch (tensor or any iterable of ints)
        :return: tuple of (tag indices per sentence, padded with the STOP tag index up to the longest length;
            flat matrix indices per sentence, where the first one addresses the transition from the START tag)
        """
        targets_list = targets.tolist()
        # Pad every sentence to the longest one so the index lists form a rectangular matrix; the padded
        # positions are dropped again by pack_padded_sequence in forward().
        max_length = int(max(lengths))

        targets_per_sentence = []
        offset = 0
        for length in lengths:
            length = int(length)
            sentence_targets = targets_list[offset : offset + length]
            offset += length
            sentence_targets += [self.stop_tag] * (max_length - len(sentence_targets))
            targets_per_sentence.append(sentence_targets)

        matrix_indices = [
            # first token: transition from START (from) to the first gold tag (to)
            [self.start_tag + sentence[0] * self.tagset_size]
            # later tokens: transition from tag i (from) to tag i + 1 (to)
            + [sentence[i] + sentence[i + 1] * self.tagset_size for i in range(len(sentence) - 1)]
            for sentence in targets_per_sentence
        ]
        return targets_per_sentence, matrix_indices
class ViterbiDecoder:
    """
    Decodes a given sequence using the Viterbi algorithm.
    """

    def __init__(self, tag_dictionary: Dictionary):
        """
        :param tag_dictionary: Dictionary of tags for sequence labeling task; the indices of START_TAG and
            STOP_TAG are looked up in it
        """
        self.tag_dictionary = tag_dictionary
        self.tagset_size = len(tag_dictionary)
        self.start_tag = tag_dictionary.get_idx_for_item(START_TAG)
        self.stop_tag = tag_dictionary.get_idx_for_item(STOP_TAG)

    def decode(
        self, features_tuple: tuple, probabilities_for_all_classes: bool, sentences: List[Sentence]
    ) -> Tuple[List, List]:
        """
        Decoding function returning the most likely sequence of tags.

        :param features_tuple: CRF scores from forward method in shape (batch size, seq len, tagset size, tagset size),
            whose last two axes are indexed [to_tag, from_tag]; lengths of sentence in batch, which must be sorted in
            descending order (the effective-batch-size slicing below relies on it); transitions of CRF
        :param probabilities_for_all_classes: whether to return probabilities for all tags
        :param sentences: sentences of the batch, in the same order as the features; only used when
            probabilities_for_all_classes is set
        :return: tuple (tags, all_tags). tags holds, per sentence, one (tag item, confidence) pair per token, where
            confidence is the max of a softmax over the accumulated Viterbi scores at that position. all_tags holds,
            per sentence and token, one Label per tag if probabilities_for_all_classes is set, else it is empty.
        """
        features, lengths, transitions = features_tuple
        all_tags = []

        batch_size = features.size(0)
        seq_len = features.size(1)

        # Create a tensor to hold accumulated sequence scores at each current tag;
        # the extra last column receives the transition into the STOP tag
        scores_upto_t = torch.zeros(batch_size, seq_len + 1, self.tagset_size, device=flair.device)
        # Create a tensor to hold back-pointers
        # i.e., indices of the previous_tag that corresponds to maximum accumulated score at current tag
        # Let pads be the <end> tag index, since that was the last tag in the decoded sequence
        backpointers = (
            torch.ones((batch_size, seq_len + 1, self.tagset_size), dtype=torch.long, device=flair.device)
            * self.stop_tag
        )

        for t in range(seq_len):
            batch_size_t = sum([length > t for length in lengths])  # effective batch size (sans pads) at this timestep
            terminates = [i for i, length in enumerate(lengths) if length == t + 1]

            if t == 0:
                scores_upto_t[:batch_size_t, t] = features[:batch_size_t, t, :, self.start_tag]
                # Every first tag is preceded by START; fill in place on the device
                # instead of building a CPU tensor and copying it over
                backpointers[:batch_size_t, t, :] = self.start_tag
            else:
                # We add scores at current timestep to scores accumulated up to previous timestep, and
                # choose the previous timestep that corresponds to the max. accumulated score for each current timestep
                scores_upto_t[:batch_size_t, t], backpointers[:batch_size_t, t, :] = torch.max(
                    features[:batch_size_t, t, :, :] + scores_upto_t[:batch_size_t, t - 1].unsqueeze(1), dim=2
                )

            # If sentence is over, add transition to STOP-tag (the (k, 1) result is broadcast over all tags)
            if terminates:
                scores_upto_t[terminates, t + 1], backpointers[terminates, t + 1, :] = torch.max(
                    scores_upto_t[terminates, t].unsqueeze(1) + transitions[self.stop_tag].unsqueeze(0), dim=2
                )

        # Decode/trace best path backwards, starting from STOP in the last column
        decoded = torch.zeros((batch_size, backpointers.size(1)), dtype=torch.long, device=flair.device)
        pointer = torch.ones((batch_size, 1), dtype=torch.long, device=flair.device) * self.stop_tag

        for t in reversed(range(backpointers.size(1))):
            decoded[:, t] = torch.gather(backpointers[:, t, :], 1, pointer).squeeze(1)
            pointer = decoded[:, t].unsqueeze(1)

        # Sanity check
        assert torch.equal(
            decoded[:, 0], torch.ones((batch_size), dtype=torch.long, device=flair.device) * self.start_tag
        )

        # remove start-tag and backscore to stop-tag
        scores_upto_t = scores_upto_t[:, :-1, :]
        decoded = decoded[:, 1:]

        # Max + Softmax to get confidence score for predicted label and append label to each token
        scores = softmax(scores_upto_t, dim=2)
        confidences = torch.max(scores, dim=2)

        # Convert to Python lists once instead of calling .item() per token (a device sync each on GPU)
        decoded_list = decoded.tolist()
        confidence_list = confidences.values.tolist()

        tags = []
        for tag_seq, tag_seq_conf, length_seq in zip(decoded_list, confidence_list, lengths):
            length_seq = int(length_seq)
            tags.append(
                [
                    (self.tag_dictionary.get_item_for_index(tag), conf)
                    for tag, conf in zip(tag_seq[:length_seq], tag_seq_conf[:length_seq])
                ]
            )

        if probabilities_for_all_classes:
            # detach so .numpy() also works when decode() runs with autograd enabled
            all_tags = self._all_scores_for_token(scores.detach().cpu(), lengths, sentences)

        return tags, all_tags

    def _all_scores_for_token(self, scores: torch.Tensor, lengths: torch.IntTensor, sentences: List[Sentence]):
        """
        Returns all scores for each tag in tag dictionary.

        :param scores: per-token tag distributions in shape (batch size, seq len, tagset size); must be a CPU
            tensor that does not require grad, since it is converted with .numpy()
        :param lengths: lengths of sentences in batch
        :param sentences: sentences of the batch, in the same order as scores
        :return: per sentence, per token, a list with one Label per tag index of the tag dictionary
        """
        scores = scores.numpy()
        prob_tags_per_sentence = []
        for scores_sentence, length, sentence in zip(scores, lengths, sentences):
            scores_sentence = scores_sentence[: int(length)]
            prob_tags_per_sentence.append(
                [
                    [
                        Label(token, self.tag_dictionary.get_item_for_index(score_id), score)
                        for score_id, score in enumerate(score_dist)
                    ]
                    for score_dist, token in zip(scores_sentence, sentence)
                ]
            )
        return prob_tags_per_sentence