import logging
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from tqdm import tqdm

import flair.nn
from part import *
from flair.data import Dictionary, Label, Sentence, Span
from flair.datasets import DataLoader, FlairDatapointDataset
from flair.embeddings import TokenEmbeddings
from flair.file_utils import cached_path
from flair.training_utils import store_embeddings

from model.layer.bioes import get_spans_from_bio
from model.layer.lstm import LSTM
from model.layer.crf import CRF
from model.layer.viterbi import ViterbiDecoder, ViterbiLoss

log = logging.getLogger("flair")


class Bi_LSTM_CRF(flair.nn.Classifier[Sentence]):
    def __init__(
        self,
        embeddings: TokenEmbeddings,
        tag_dictionary: Dictionary,
        tag_type: str,
        rnn: Optional[torch.nn.RNN] = None,
        tag_format: str = "BIOES",
        hidden_size: int = 256,
        rnn_layers: int = 1,
        bidirectional: bool = True,
        use_crf: bool = True,
        ave_embeddings: bool = True,
        dropout: float = 0.0,
        word_dropout: float = 0.05,
        locked_dropout: float = 0.5,
        loss_weights: Optional[Dict[str, float]] = None,
        init_from_state_dict: bool = False,
        allow_unk_predictions: bool = False,
    ):
""" | |
BiLSTM Span CRF class for predicting labels for single tokens. Can be parameterized by several attributes. | |
Span prediction is utilized if there are nested entities such as Address and Organization. Since the researchers | |
observed that the token are have different length for a given dataset, we made the Span useful by incorporating it | |
only if the data needs it. | |
:param embeddings: Embeddings to use during training and prediction | |
:param tag_dictionary: Dictionary containing all tags from corpus which can be predicted | |
:param tag_type: type of tag which is going to be predicted in case a corpus has multiple annotations | |
:param rnn: (Optional) Takes a torch.nn.Module as parameter by which you can pass a shared RNN between | |
different tasks. | |
:param hidden_size: Hidden size of RNN layer | |
:param rnn_layers: number of RNN layers | |
:param bidirectional: If True, RNN becomes bidirectional | |
:param use_crf: If True, use a Conditional Random Field for prediction, else linear map to tag space. | |
:param ave_embeddings: If True, add a linear layer on top of embeddings, if you want to imitate | |
fine tune non-trainable embeddings. | |
:param dropout: If > 0, then use dropout. | |
:param word_dropout: If > 0, then use word dropout. | |
:param locked_dropout: If > 0, then use locked dropout. | |
:param loss_weights: Dictionary of weights for labels for the loss function | |
(if any label's weight is unspecified it will default to 1.0) | |
:param init_from_state_dict: Indicator whether we are loading a model from state dict | |
since we need to transform previous models' weights into CRF instance weights | |
""" | |
        super(Bi_LSTM_CRF, self).__init__()

        # ----- Create the internal tag dictionary -----
        self.tag_type = tag_type
        self.tag_format = tag_format.upper()
        if init_from_state_dict:
            self.label_dictionary = tag_dictionary
        else:
            # span-labels need special encoding (BIO or BIOES)
            if tag_dictionary.span_labels:
                # the big question is whether the label dictionary should contain an UNK or not
                # without UNK, we cannot evaluate on data that contains labels not seen in test
                # with UNK, the model learns less well if there are no UNK examples
                self.label_dictionary = Dictionary(add_unk=allow_unk_predictions)
                assert self.tag_format in ["BIOES", "BIO"]
                for label in tag_dictionary.get_items():
                    if label == "<unk>":
                        continue
                    self.label_dictionary.add_item("O")
                    if self.tag_format == "BIOES":
                        self.label_dictionary.add_item("S-" + label)
                        self.label_dictionary.add_item("B-" + label)
                        self.label_dictionary.add_item("E-" + label)
                        self.label_dictionary.add_item("I-" + label)
                    if self.tag_format == "BIO":
                        self.label_dictionary.add_item("B-" + label)
                        self.label_dictionary.add_item("I-" + label)
            else:
                self.label_dictionary = tag_dictionary

        # is this a span prediction problem?
        self.predict_spans = self._determine_if_span_prediction_problem(self.label_dictionary)

        self.tagset_size = len(self.label_dictionary)
        log.info(f"SequenceTagger predicts: {self.label_dictionary}")

        # ----- Embeddings -----
        # We set the initial embeddings gathered from Flair:
        # stacked and concatenated, then averaged using a Linear layer
        self.embeddings = embeddings
        embedding_dim: int = embeddings.embedding_length

        # ----- Initial loss weights parameters -----
        # This is for the reiteration process of training.
        # Initially we don't have any loss weights, but as we proceed with training,
        # we get loss computations from the evaluation stage.
        self.weight_dict = loss_weights
        self.loss_weights = self._init_loss_weights(loss_weights) if loss_weights else None

        # ----- RNN specific parameters -----
        # These parameters are for setting up self.rnn
        self.hidden_size = hidden_size if not rnn else rnn.hidden_size
        self.rnn_layers = rnn_layers if not rnn else rnn.num_layers
        self.bidirectional = bidirectional if not rnn else rnn.bidirectional

        # ----- Conditional Random Field parameters -----
        self.use_crf = use_crf
        # Previously trained models have been trained without an explicit CRF, thus it is required to check
        # whether we are loading a model from state dict in order to skip or add START and STOP token
        if use_crf and not init_from_state_dict and not self.label_dictionary.start_stop_tags_are_set():
            self.label_dictionary.set_start_stop_tags()
            self.tagset_size += 2

        # ----- Dropout parameters -----
        # dropouts
        self.use_dropout: float = dropout
        self.use_word_dropout: float = word_dropout
        self.use_locked_dropout: float = locked_dropout

        if dropout > 0.0:
            self.dropout = torch.nn.Dropout(dropout)
        if word_dropout > 0.0:
            self.word_dropout = flair.nn.WordDropout(word_dropout)
        if locked_dropout > 0.0:
            self.locked_dropout = flair.nn.LockedDropout(locked_dropout)

        # ----- Model layers -----
        # Initialize a linear layer over the embedding dimension for the purpose of averaging the embeddings
        self.ave_embeddings = ave_embeddings
        if self.ave_embeddings:
            self.embedding2nn = torch.nn.Linear(embedding_dim, embedding_dim)

        # ----- RNN layer -----
        # Use the shared RNN if one is provided, else create one for this model
        self.rnn: torch.nn.RNN = (
            rnn
            if rnn
            else LSTM(
                rnn_layers,
                hidden_size,
                bidirectional,
                rnn_input_dim=embedding_dim,
            )
        )
        num_directions = 2 if self.bidirectional else 1
        hidden_output_dim = self.rnn.hidden_size * num_directions

        # final linear map to tag space
        self.linear = torch.nn.Linear(hidden_output_dim, len(self.label_dictionary))

        # the loss function is the Viterbi loss if using a CRF, else regular Cross Entropy Loss
        self.loss_function = (
            ViterbiLoss(self.label_dictionary)
            if use_crf
            else torch.nn.CrossEntropyLoss(weight=self.loss_weights, reduction="sum")
        )

        # if using CRF, we also require a CRF and a Viterbi decoder
        if use_crf:
            self.crf = CRF(self.label_dictionary, self.tagset_size, init_from_state_dict)
            self.viterbi_decoder = ViterbiDecoder(self.label_dictionary)

        self.to(flair.device)

    @property
    def label_type(self):
        return self.tag_type

    def _init_loss_weights(self, loss_weights: Dict[str, float]) -> torch.Tensor:
        """
        Initializes the loss weights based on the given dictionary:
        :param loss_weights: dictionary containing loss weights
        """
        n_classes = len(self.label_dictionary)
        weight_list = [1.0 for _ in range(n_classes)]
        for i, tag in enumerate(self.label_dictionary.get_items()):
            if tag in loss_weights.keys():
                weight_list[i] = loss_weights[tag]
        return torch.tensor(weight_list).to(flair.device)

    def forward_loss(self, sentences: Union[List[Sentence], Sentence]) -> Tuple[torch.Tensor, int]:
        """
        Calculates the loss of the forward propagation of the model
        :param sentences: either a list of sentences or a single sentence
        """
        # if there are no sentences, there is no loss
        if len(sentences) == 0:
            return torch.tensor(0.0, dtype=torch.float, device=flair.device, requires_grad=True), 0

        # forward pass to get scores
        scores, gold_labels = self.forward(sentences)  # type: ignore

        # calculate loss given scores and labels
        return self._calculate_loss(scores, gold_labels)

    def forward(self, sentences: Union[List[Sentence], Sentence]):
        """
        Forward propagation through the network. Also returns the gold labels of the batch.
        :param sentences: Batch of current sentences
        """
        if not isinstance(sentences, list):
            sentences = [sentences]
        self.embeddings.embed(sentences)

        # make a zero-padded tensor for the whole batch
        lengths, sentence_tensor = self._make_padded_tensor_for_batch(sentences)

        # sort tensor in decreasing order based on lengths of sentences in batch
        sorted_lengths, length_indices = lengths.sort(dim=0, descending=True)
        sentences = [sentences[i] for i in length_indices]
        sentence_tensor = sentence_tensor[length_indices]

        # ----- Forward Propagation -----
        # apply the dropout layers initialized above to regularize the inputs
        if self.use_dropout:
            sentence_tensor = self.dropout(sentence_tensor)
        if self.use_word_dropout:
            sentence_tensor = self.word_dropout(sentence_tensor)
        if self.use_locked_dropout:
            sentence_tensor = self.locked_dropout(sentence_tensor)

        # average the embeddings using a linear transform
        if self.ave_embeddings:
            sentence_tensor = self.embedding2nn(sentence_tensor)

        # pack the sentence tensor and run it through the LSTM model
        sentence_tensor, output_lengths = self.rnn(sentence_tensor, sorted_lengths)

        # regularize the sentence tensor computed from the LSTM model
        if self.use_dropout:
            sentence_tensor = self.dropout(sentence_tensor)
        if self.use_locked_dropout:
            sentence_tensor = self.locked_dropout(sentence_tensor)

        # linear map to tag space
        features = self.linear(sentence_tensor)

        # Depending on whether we are using CRF or a linear layer, scores is either:
        # -- A tensor of shape (batch size, sequence length, tagset size, tagset size) for CRF
        # -- A tensor of shape (aggregated sequence length for all sentences in batch, tagset size) for linear layer
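        # Illustrative example (hypothetical sizes): with a batch of 2 sentences, a longest sentence of
        # 10 tokens and a tagset of 21 items, the CRF branch yields scores of shape (2, 10, 21, 21),
        # while the linear branch yields a flat (total number of tokens, 21) tensor.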
        if self.use_crf:
            features = self.crf(features)
            scores = (features, sorted_lengths, self.crf.transitions)
        else:
            scores = self._get_scores_from_features(features, sorted_lengths)

        # get the gold labels
        gold_labels = self._get_gold_labels(sentences)

        return scores, gold_labels

    def _calculate_loss(self, scores, labels) -> Tuple[torch.Tensor, int]:
        if not any(labels):
            return torch.tensor(0.0, requires_grad=True, device=flair.device), 1

        labels = torch.tensor(
            [
                self.label_dictionary.get_idx_for_item(label[0])
                if len(label) > 0
                else self.label_dictionary.get_idx_for_item("O")
                for label in labels
            ],
            dtype=torch.long,
            device=flair.device,
        )

        return self.loss_function(scores, labels), len(labels)

    def _make_padded_tensor_for_batch(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Makes a zero-padded tensor shaped by the longest sentence in the batch and the embedding_length,
        so that it matches the shape the LSTM model expects as input.
        :param sentences: Batch of current sentences
        """
        names = self.embeddings.get_names()

        tok_lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
        longest_token_sequence_in_batch: int = max(tok_lengths)

        zero_tensor = torch.zeros(
            self.embeddings.embedding_length * longest_token_sequence_in_batch,
            dtype=torch.float,
            device=flair.device,
        )
        all_embs = list()
        for sentence in sentences:
            all_embs += [emb for token in sentence for emb in token.get_each_embedding(names)]
            nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)

            if nb_padding_tokens > 0:
                t = zero_tensor[: self.embeddings.embedding_length * nb_padding_tokens]
                all_embs.append(t)

        sentence_tensor = torch.cat(all_embs).view(
            [
                len(sentences),
                longest_token_sequence_in_batch,
                self.embeddings.embedding_length,
            ]
        )

        return torch.tensor(tok_lengths, dtype=torch.long), sentence_tensor

    @staticmethod
    def _get_scores_from_features(features: torch.Tensor, lengths: torch.Tensor):
        """
        Trims the current batch tensor of shape (batch size, sequence length, tagset size) so that all
        padding positions are removed.
        :param features: torch.Tensor containing all features from forward propagation
        :param lengths: length of each sentence in the batch, used to trim the padding tokens
        """
        features_formatted = []
        for feat, lens in zip(features, lengths):
            features_formatted.append(feat[:lens])
        scores = torch.cat(features_formatted)

        return scores

    def _get_gold_labels(self, sentences: Union[List[Sentence], Sentence]):
        """
        Extracts gold labels from each sentence.
        :param sentences: List of sentences in batch
        """
        # spans need to be encoded as token-level predictions
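        # Illustrative example (hypothetical span): a three-token span "New York City" labeled LOC becomes
        # B-LOC I-LOC E-LOC under BIOES (or B-LOC I-LOC I-LOC under BIO); a single-token LOC span becomes S-LOC.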
        if self.predict_spans:
            all_sentence_labels = []
            for sentence in sentences:
                sentence_labels = ["O"] * len(sentence)
                for label in sentence.get_labels(self.label_type):
                    span: Span = label.data_point
                    if self.tag_format == "BIOES":
                        if len(span) == 1:
                            sentence_labels[span[0].idx - 1] = "S-" + label.value
                        else:
                            sentence_labels[span[0].idx - 1] = "B-" + label.value
                            sentence_labels[span[-1].idx - 1] = "E-" + label.value
                            for i in range(span[0].idx, span[-1].idx - 1):
                                sentence_labels[i] = "I-" + label.value
                    else:
                        sentence_labels[span[0].idx - 1] = "B-" + label.value
                        for i in range(span[0].idx, span[-1].idx):
                            sentence_labels[i] = "I-" + label.value
                all_sentence_labels.extend(sentence_labels)
            labels = [[label] for label in all_sentence_labels]

        # all others are regular labels for each token
        else:
            labels = [[token.get_label(self.label_type, "O").value] for sentence in sentences for token in sentence]

        return labels

    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size: int = 32,
        return_probabilities_for_all_classes: bool = False,
        verbose: bool = False,
        label_name: Optional[str] = None,
        return_loss=False,
        embedding_storage_mode="none",
        force_token_predictions: bool = False,
    ):  # type: ignore
        """
        Predicts labels for the current batch, using the CRF decoder if enabled.
        :param sentences: List of sentences in batch
        :param mini_batch_size: batch size for test data
        :param return_probabilities_for_all_classes: whether to return probabilities for all classes
        :param verbose: whether to use a progress bar
        :param label_name: which label to predict
        :param return_loss: whether to return the loss value
        :param embedding_storage_mode: determines where to store embeddings - can be "gpu", "cpu" or "none"
        :param force_token_predictions: if True, add token-level predictions even when the model predicts spans
        """
        if label_name is None:
            label_name = self.tag_type
        with torch.no_grad():
            if not sentences:
                return sentences

            # make sure it's a list
            if not isinstance(sentences, list) and not isinstance(sentences, flair.data.Dataset):
                sentences = [sentences]

            # filter empty sentences
            sentences = [sentence for sentence in sentences if len(sentence) > 0]

            # reverse sort all sequences by their length
            reordered_sentences = sorted(sentences, key=lambda s: len(s), reverse=True)

            if len(reordered_sentences) == 0:
                return sentences

            dataloader = DataLoader(
                dataset=FlairDatapointDataset(reordered_sentences),
                batch_size=mini_batch_size,
            )
            # progress bar for verbosity
            if verbose:
                dataloader = tqdm(dataloader, desc="Batch inference")

            overall_loss = torch.zeros(1, device=flair.device)
            batch_no = 0
            label_count = 0
            for batch in dataloader:
                batch_no += 1

                # stop if all sentences are empty
                if not batch:
                    continue

                # get features from forward propagation
                features, gold_labels = self.forward(batch)

                # remove previously predicted labels of this type
                for sentence in batch:
                    sentence.remove_labels(label_name)

                # if return_loss, get loss value
                if return_loss:
                    loss = self._calculate_loss(features, gold_labels)
                    overall_loss += loss[0]
                    label_count += loss[1]

                # Sort batch in same way as forward propagation
                lengths = torch.LongTensor([len(sentence) for sentence in batch])
                _, sort_indices = lengths.sort(dim=0, descending=True)
                batch = [batch[i] for i in sort_indices]

                # make predictions
                if self.use_crf:
                    predictions, all_tags = self.viterbi_decoder.decode(
                        features, return_probabilities_for_all_classes, batch
                    )
                else:
                    predictions, all_tags = self._standard_inference(
                        features, batch, return_probabilities_for_all_classes
                    )

                # add predictions to Sentence
                for sentence, sentence_predictions in zip(batch, predictions):

                    # BIOES-labels need to be converted to spans
                    if self.predict_spans and not force_token_predictions:
                        sentence_tags = [label[0] for label in sentence_predictions]
                        sentence_scores = [label[1] for label in sentence_predictions]
                        predicted_spans = get_spans_from_bio(sentence_tags, sentence_scores)
                        for predicted_span in predicted_spans:
                            span: Span = sentence[predicted_span[0][0] : predicted_span[0][-1] + 1]
                            span.add_label(label_name, value=predicted_span[2], score=predicted_span[1])

                    # token-labels can be added directly ("O" and legacy "_" predictions are skipped)
                    else:
                        for token, label in zip(sentence.tokens, sentence_predictions):
                            if label[0] in ["O", "_"]:
                                continue
                            token.add_label(typename=label_name, value=label[0], score=label[1])

                # all_tags will be empty if all_tag_prob is set to False, so the for loop will be avoided
                for (sentence, sent_all_tags) in zip(batch, all_tags):
                    for (token, token_all_tags) in zip(sentence.tokens, sent_all_tags):
                        token.add_tags_proba_dist(label_name, token_all_tags)

                store_embeddings(sentences, storage_mode=embedding_storage_mode)

            if return_loss:
                return overall_loss, label_count

    def _standard_inference(self, features: torch.Tensor, batch: List[Sentence], probabilities_for_all_classes: bool):
        """
        Softmax over emission scores from forward propagation.
        :param features: sentence tensor from forward propagation
        :param batch: list of sentences
        :param probabilities_for_all_classes: whether to return the score for each tag in the tag dictionary
        """
        softmax_batch = F.softmax(features, dim=1).cpu()
        scores_batch, prediction_batch = torch.max(softmax_batch, dim=1)
        predictions = []
        all_tags = []

        for sentence in batch:
            scores = scores_batch[: len(sentence)]
            predictions_for_sentence = prediction_batch[: len(sentence)]
            predictions.append(
                [
                    (self.label_dictionary.get_item_for_index(prediction), score.item())
                    for token, score, prediction in zip(sentence, scores, predictions_for_sentence)
                ]
            )
            scores_batch = scores_batch[len(sentence) :]
            prediction_batch = prediction_batch[len(sentence) :]

        if probabilities_for_all_classes:
            lengths = [len(sentence) for sentence in batch]
            all_tags = self._all_scores_for_token(batch, softmax_batch, lengths)

        return predictions, all_tags

    def _all_scores_for_token(self, sentences: List[Sentence], scores: torch.Tensor, lengths: List[int]):
        """
        Returns all scores for each tag in the tag dictionary.
        :param sentences: list of sentences in the batch
        :param scores: softmax scores for the current batch
        :param lengths: length of each sentence in the batch
        """
        scores = scores.numpy()
        tokens = [token for sentence in sentences for token in sentence]
        prob_all_tags = [
            [
                Label(token, self.label_dictionary.get_item_for_index(score_id), score)
                for score_id, score in enumerate(score_dist)
            ]
            for score_dist, token in zip(scores, tokens)
        ]

        prob_tags_per_sentence = []
        previous = 0
        for length in lengths:
            prob_tags_per_sentence.append(prob_all_tags[previous : previous + length])
            previous += length
        return prob_tags_per_sentence

    def _get_state_dict(self):
        """Returns the state dictionary for this model."""
        model_state = {
            **super()._get_state_dict(),
            "embeddings": self.embeddings,
            "hidden_size": self.hidden_size,
            "tag_dictionary": self.label_dictionary,
            "tag_format": self.tag_format,
            "tag_type": self.tag_type,
            "use_crf": self.use_crf,
            "rnn_layers": self.rnn_layers,
            "use_dropout": self.use_dropout,
            "use_word_dropout": self.use_word_dropout,
            "use_locked_dropout": self.use_locked_dropout,
            "ave_embeddings": self.ave_embeddings,
            "weight_dict": self.weight_dict,
        }
        return model_state

    @classmethod
    def _init_model_with_state_dict(cls, state, **kwargs):
        if state["use_crf"]:
            if "transitions" in state["state_dict"]:
                state["state_dict"]["crf.transitions"] = state["state_dict"]["transitions"]
                del state["state_dict"]["transitions"]

        return super()._init_model_with_state_dict(
            state,
            embeddings=state.get("embeddings"),
            tag_dictionary=state.get("tag_dictionary"),
            tag_format=state.get("tag_format", "BIOES"),
            tag_type=state.get("tag_type"),
            use_crf=state.get("use_crf"),
            rnn_layers=state.get("rnn_layers"),
            hidden_size=state.get("hidden_size"),
            dropout=state.get("use_dropout", 0.0),
            word_dropout=state.get("use_word_dropout", 0.0),
            locked_dropout=state.get("use_locked_dropout", 0.0),
            ave_embeddings=state.get("ave_embeddings", True),
            loss_weights=state.get("weight_dict"),
            init_from_state_dict=True,
            **kwargs,
        )

    @staticmethod
    def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
        filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
        if len(sentences) != len(filtered_sentences):
            log.warning(f"Ignoring {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens.")
        return filtered_sentences

    def _determine_if_span_prediction_problem(self, dictionary: Dictionary) -> bool:
        for item in dictionary.get_items():
            if item.startswith("B-") or item.startswith("S-") or item.startswith("I-"):
                return True
        return False

    def _print_predictions(self, batch, gold_label_type):
        lines = []
        if self.predict_spans:
            for datapoint in batch:
                # all labels default to "O"
                for token in datapoint:
                    token.set_label("gold_bio", "O")
                    token.set_label("predicted_bio", "O")

                # set gold token-level
                for gold_label in datapoint.get_labels(gold_label_type):
                    gold_span: Span = gold_label.data_point
                    prefix = "B-"
                    for token in gold_span:
                        token.set_label("gold_bio", prefix + gold_label.value)
                        prefix = "I-"

                # set predicted token-level
                for predicted_label in datapoint.get_labels("predicted"):
                    predicted_span: Span = predicted_label.data_point
                    prefix = "B-"
                    for token in predicted_span:
                        token.set_label("predicted_bio", prefix + predicted_label.value)
                        prefix = "I-"

                # now print labels in CoNLL format
                for token in datapoint:
                    eval_line = (
                        f"{token.text} "
                        f"{token.get_label('gold_bio').value} "
                        f"{token.get_label('predicted_bio').value}\n"
                    )
                    lines.append(eval_line)
                lines.append("\n")
        else:
            for datapoint in batch:
                # print labels in CoNLL format
                for token in datapoint:
                    eval_line = (
                        f"{token.text} "
                        f"{token.get_label(gold_label_type).value} "
                        f"{token.get_label('predicted').value}\n"
                    )
                    lines.append(eval_line)
                lines.append("\n")
        return lines
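

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): illustrates how this
# tagger is typically wired together with Flair. The corpus path, column map
# and embedding choice below are hypothetical assumptions, not values taken
# from this repository; training with Flair's ModelTrainer is omitted.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from flair.datasets import ColumnCorpus
    from flair.embeddings import WordEmbeddings

    # hypothetical CoNLL-style corpus with "text" and "ner" columns
    corpus = ColumnCorpus("data/", {0: "text", 1: "ner"})
    tag_dictionary = corpus.make_label_dictionary(label_type="ner")

    # build the tagger with standard GloVe word embeddings (any TokenEmbeddings would do)
    tagger = Bi_LSTM_CRF(
        embeddings=WordEmbeddings("glove"),
        tag_dictionary=tag_dictionary,
        tag_type="ner",
        hidden_size=256,
        use_crf=True,
    )

    # predict in-place on a single sentence; predicted spans are attached as "ner" labels
    sentence = Sentence("George Washington went to Washington.")
    tagger.predict(sentence)
    print(sentence.to_tagged_string())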