from typing import Dict, List, Optional

import torch
from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax

from .span_typing import SpanTyping


@SpanTyping.register('mlp')
class MLPSpanTyping(SpanTyping):
    """
    An MLP implementation for Span Typing: each child span is classified from
    its own vector, its parent's vector, and its parent's label embedding.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        label_emb: torch.nn.Embedding,
        n_category: int,
        label_to_ignore: Optional[List[int]] = None
    ):
        """
        :param input_dim: dim(parent_span) + dim(child_span) + dim(label_emb)
        :param hidden_dims: The dims of the hidden layers of the MLP.
        :param label_emb: Embeds labels to vectors.
        :param n_category: The number of labels.
        :param label_to_ignore: Label indices excluded from the loss and the metric.
        """
        super().__init__(label_emb.num_embeddings, label_to_ignore)
        # Build the MLP: input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1] -> n_category.
        # The layers live in a plain list, so each one is registered with `add_module`
        # to make its parameters visible to PyTorch.
        self.MLPs: List[torch.nn.Linear] = []
        for i_mlp, output_dim in enumerate(hidden_dims + [n_category]):
            mlp = torch.nn.Linear(input_dim, output_dim, bias=True)
            self.MLPs.append(mlp)
            self.add_module(f'MLP-{i_mlp}', mlp)
            input_dim = output_dim

        self.label_emb = label_emb

    def forward(
        self,
        span_vec: torch.Tensor,
        parent_at_span: torch.Tensor,
        span_labels: Optional[torch.Tensor],
        prediction_only: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Types each child span from its features, updating the accuracy metric
        along the way.

        :param span_vec: Span representations, [batch, #span, span_dim].
        :param parent_at_span: Index of each span's parent span, [batch, #span].
        :param span_labels: Gold labels: [batch, #span] int64 indices for hard labels,
            or [batch, #span, #label] floats for soft label distributions. Required even
            when `prediction_only` is True, since parent label embeddings are input features.
        :param prediction_only: If True, skip the loss and the metric update.
        :return:
            loss: Loss for label prediction (absent if `prediction_only`).
            prediction: Predicted labels.
            label_confidence: Probability of each predicted label.
            distribution: The full (ontology-masked) label distribution.
        """
        # Hard labels are int64 indices; anything else is treated as a soft
        # distribution over labels.
        is_soft = span_labels.dtype != torch.int64

        # Embed labels; soft labels mix embedding rows by their probabilities.
        label_vec = span_labels @ self.label_emb.weight if is_soft else self.label_emb(span_labels)
        n_batch, n_span, _ = label_vec.shape
        n_label, _ = self.ontology.shape
        # For every child span, gather the label embedding and the span vector
        # of its parent, then concatenate them with the child's own vector.
        parent_label_features = label_vec.gather(1, parent_at_span.unsqueeze(2).expand_as(label_vec))
        parent_span_features = span_vec.gather(1, parent_at_span.unsqueeze(2).expand_as(span_vec))
        child_span_features = span_vec
        features = torch.cat([parent_label_features, parent_span_features, child_span_features], dim=2)
        # ReLU between hidden layers; the last layer emits raw logits.
        for mlp in self.MLPs[:-1]:
            features = torch.relu(mlp(features))
        logits = self.MLPs[-1](features)
        # The loss is computed on the raw logits; only the predictions are
        # filtered by the ontology.
        logits_for_prediction = logits.clone()
        if not is_soft:
            # `self.ontology` (from the base class) is a bool matrix over
            # (parent label, child label) pairs. Select each span's row by its
            # parent's gold label; incompatible child labels get -inf logits.
            parent_labels = span_labels.gather(1, parent_at_span)
            onto_mask = self.ontology.unsqueeze(0).expand(n_batch, -1, -1).gather(
                1, parent_labels.unsqueeze(2).expand(-1, -1, n_label)
            )
            logits_for_prediction[~onto_mask] = float('-inf')
        label_dist = torch.softmax(logits_for_prediction, 2)
        label_confidence, predictions = label_dist.max(2)
        ret = {'prediction': predictions, 'label_confidence': label_confidence, 'distribution': label_dist}
        if prediction_only:
            return ret

        # Clone: the ignored-label rewrite below mutates the tensor in place.
        span_labels = span_labels.clone()
        if is_soft:
            # Exclude spans whose soft distribution sums to ~0 (no supervision)
            # from the metric.
            has_label = ~span_labels.sum(2).isclose(torch.tensor(0.))
            self.acc_metric(logits_for_prediction, span_labels.max(2)[1], has_label)
            # KLDivLoss expects log-probabilities as input and probabilities as target.
            ret['loss'] = KLDivLoss(reduction='sum')(LogSoftmax(dim=2)(logits), span_labels)
        else:
            # -100 is the default ignore_index of CrossEntropyLoss.
            for label_idx in self.label_to_ignore:
                span_labels[span_labels == label_idx] = -100
            self.acc_metric(logits_for_prediction, span_labels, span_labels != -100)
            ret['loss'] = CrossEntropyLoss(reduction='sum')(logits.flatten(0, 1), span_labels.flatten())

        return ret
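

# A minimal usage sketch (illustrative, not part of the module). It assumes the
# SpanTyping base class sets up `self.ontology` (a bool matrix over
# (parent label, child label) pairs), `self.acc_metric`, and `self.label_to_ignore`;
# all dimensions below are made up.
#
#     label_emb = torch.nn.Embedding(num_embeddings=20, embedding_dim=32)
#     typer = MLPSpanTyping(
#         input_dim=256 + 256 + 32,  # parent span + child span + parent label emb
#         hidden_dims=[512, 256],
#         label_emb=label_emb,
#         n_category=20,
#     )
#     out = typer(
#         span_vec=torch.randn(4, 10, 256),          # [batch=4, #span=10, span_dim=256]
#         parent_at_span=torch.zeros(4, 10).long(),  # every span's parent is span 0
#         span_labels=torch.randint(20, (4, 10)),    # hard labels -> cross-entropy
#     )
#     out['loss'].backward()  # out also has 'prediction', 'label_confidence', 'distribution'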