# sociolome/sftp/modules/span_typing/mlp_span_typing.py
from typing import Dict, List, Optional
import torch
from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax
from .span_typing import SpanTyping


@SpanTyping.register('mlp')
class MLPSpanTyping(SpanTyping):
"""
An MLP implementation for Span Typing.
"""
def __init__(
self,
input_dim: int,
hidden_dims: List[int],
label_emb: torch.nn.Embedding,
n_category: int,
label_to_ignore: Optional[List[int]] = None
):
"""
:param input_dim: dim(parent_span) + dim(child_span) + dim(label_dim)
:param hidden_dims: The dim of hidden layers of MLP.
:param n_category: #labels
:param label_emb: Embeds labels to vectors.
"""
super().__init__(label_emb.num_embeddings, label_to_ignore, )
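        # Stack the hidden layers, then one final projection down to n_category logits.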
self.MLPs: List[torch.nn.Linear] = list()
for i_mlp, output_dim in enumerate(hidden_dims + [n_category]):
mlp = torch.nn.Linear(input_dim, output_dim, bias=True)
self.MLPs.append(mlp)
self.add_module(f'MLP-{i_mlp}', mlp)
input_dim = output_dim
# Embeds labels as features.
self.label_emb = label_emb

    def forward(
        self,
        span_vec: torch.Tensor,
        parent_at_span: torch.Tensor,
        span_labels: torch.Tensor,
        prediction_only: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Types each child span from its own features, its parent's span features,
        and its parent's label embedding; updates the accuracy metric unless
        `prediction_only` is set. Note that `span_labels` is required even for
        prediction, since parent label embeddings are part of the input features;
        labels may be hard (int64 indices) or soft (distributions over labels).
        :return:
            prediction: Predicted labels.
            label_confidence: Probability of each predicted label.
            distribution: Full distribution over labels for every span.
            loss: Loss for label prediction (absent if `prediction_only`).
        """
is_soft = span_labels.dtype != torch.int64
# Shape [batch, span, label_dim]
label_vec = span_labels @ self.label_emb.weight if is_soft else self.label_emb(span_labels)
n_batch, n_span, _ = label_vec.shape
n_label, _ = self.ontology.shape
# Shape [batch, span, label_dim]
parent_label_features = label_vec.gather(1, parent_at_span.unsqueeze(2).expand_as(label_vec))
# Shape [batch, span, token_dim]
parent_span_features = span_vec.gather(1, parent_at_span.unsqueeze(2).expand_as(span_vec))
# Shape [batch, span, token_dim]
child_span_features = span_vec
features = torch.cat([parent_label_features, parent_span_features, child_span_features], dim=2)
# Shape [batch, span, label]
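        # ReLU between hidden layers; the final layer emits unnormalized logits.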
for mlp in self.MLPs[:-1]:
features = torch.relu(mlp(features))
logits = self.MLPs[-1](features)
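        # Keep the raw logits for the loss; mask only the copy used for prediction.
        # For hard labels, the ontology restricts predictions to children allowed under the gold parent label.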
logits_for_prediction = logits.clone()
if not is_soft:
# Shape [batch, span]
parent_labels = span_labels.gather(1, parent_at_span)
onto_mask = self.ontology.unsqueeze(0).expand(n_batch, -1, -1).gather(
1, parent_labels.unsqueeze(2).expand(-1, -1, n_label)
)
logits_for_prediction[~onto_mask] = float('-inf')
label_dist = torch.softmax(logits_for_prediction, 2)
label_confidence, predictions = label_dist.max(2)
ret = {'prediction': predictions, 'label_confidence': label_confidence, 'distribution': label_dist}
if prediction_only:
return ret
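        # Clone before overwriting ignored labels so the caller's tensor stays intact.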
span_labels = span_labels.clone()
if is_soft:
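            # Spans whose soft label distribution sums to ~0 are padding; exclude them from the metric.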
self.acc_metric(logits_for_prediction, span_labels.max(2)[1], ~span_labels.sum(2).isclose(torch.tensor(0.)))
ret['loss'] = KLDivLoss(reduction='sum')(LogSoftmax(dim=2)(logits), span_labels)
else:
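            # -100 is CrossEntropyLoss's default ignore_index, so ignored positions contribute no loss.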
for label_idx in self.label_to_ignore:
span_labels[span_labels == label_idx] = -100
self.acc_metric(logits_for_prediction, span_labels, span_labels != -100)
ret['loss'] = CrossEntropyLoss(reduction='sum')(logits.flatten(0, 1), span_labels.flatten())
return ret
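

if __name__ == '__main__':
    # A minimal sketch (not part of the original module) of the gather/ontology-mask
    # mechanics used in forward() above. All shapes, indices, and the toy ontology
    # below are hypothetical. It depends only on torch; note that running this file
    # directly still requires the package on the path, because of the relative
    # import at the top.
    n_batch, n_span, n_label, label_dim = 2, 3, 4, 5
    label_vec = torch.randn(n_batch, n_span, label_dim)
    parent_at_span = torch.tensor([[0, 0, 1], [0, 2, 2]])
    # Row s of the output is the feature row of span s's parent.
    parent_label_features = label_vec.gather(
        1, parent_at_span.unsqueeze(2).expand_as(label_vec)
    )
    assert parent_label_features.shape == (n_batch, n_span, label_dim)
    # ontology[p, c] is True iff label c may appear under parent label p.
    ontology = torch.tensor([
        [1, 1, 0, 0],
        [0, 1, 1, 0],
        [0, 0, 1, 1],
        [1, 0, 0, 1],
    ], dtype=torch.bool)
    parent_labels = torch.tensor([[0, 0, 1], [0, 2, 3]])
    onto_mask = ontology.unsqueeze(0).expand(n_batch, -1, -1).gather(
        1, parent_labels.unsqueeze(2).expand(-1, -1, n_label)
    )
    logits = torch.randn(n_batch, n_span, n_label)
    logits[~onto_mask] = float('-inf')
    # Forbidden labels receive probability 0 after the softmax.
    print(torch.softmax(logits, 2))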