# spitzc32
# Added initial structure of the model
# 24d0437
import torch
import flair
# Sentinel tag strings. CRF.__init__ looks up their indices in the tag
# dictionary (via get_idx_for_item) to pin the corresponding row/column of the
# transition matrix to -10000 when training from scratch.
START_TAG: str = "<START>"
STOP_TAG: str = "<STOP>"
class CRF(torch.nn.Module):
    """
    Conditional Random Field layer (after sgrvinod's implementation), modified so that
    scores depend not only on the current word but also on the previously seen annotation.

    The layer owns a single learnable matrix ``transitions`` of shape
    (tagset size, tagset size), indexed as ``transitions[to, from]``.
    """

    def __init__(self, tag_dictionary, tagset_size: int, init_from_state_dict: bool):
        """
        Build the transition matrix and move the module to ``flair.device``.

        :param tag_dictionary: tag dictionary; queried with ``get_idx_for_item`` to find the
            IDs of the start and stop tags
        :param tagset_size: number of tags in the tag dictionary
        :param init_from_state_dict: whether a pretrained model will be loaded from a state
            dict; when True the random initialisation is left untouched
        """
        super().__init__()
        self.tagset_size = tagset_size

        # Indexing convention: transitions[to, from].
        self.transitions = torch.nn.Parameter(torch.randn(tagset_size, tagset_size))

        # A freshly trained model must never move *into* the START tag nor *out of* the
        # STOP tag, so those entries are pinned to -10000.
        if not init_from_state_dict:
            # detach() returns a view sharing storage with the parameter, so the in-place
            # writes below change the parameter's values without being tracked by autograd.
            raw_transitions = self.transitions.detach()
            start_idx = tag_dictionary.get_idx_for_item(START_TAG)
            raw_transitions[start_idx, :] = -10000
            stop_idx = tag_dictionary.get_idx_for_item(STOP_TAG)
            raw_transitions[:, stop_idx] = -10000

        self.to(flair.device)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Combine per-token emission scores with the transition matrix.

        :param features: emission scores per token, shape (batch size, seq len, tagset size).
            The ``expand`` call below only succeeds when the last dimension equals
            ``tagset_size`` (or is 1).
        :return: CRF scores of shape (batch size, seq len, tagset size, tagset size), where
            ``scores[b, t, i, j] = features[b, t, i] + transitions[i, j]``
        """
        batch_size, seq_len = features.shape[:2]
        num_tags = self.tagset_size

        # Repeat every emission score along a new trailing axis: (B, T, K) -> (B, T, K, K).
        emissions = features.unsqueeze(-1).expand(batch_size, seq_len, num_tags, num_tags)

        # Prepend two singleton axes so the transitions broadcast over batch and time.
        pairwise = self.transitions[None, None]

        return emissions + pairwise