from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


def enumerate_spans(n: int):
    """Yield every (start, end) pair with 0 <= start <= end < n (end inclusive)."""
    for i in range(n):
        for j in range(i, n):
            yield (i, j)


@lru_cache  # type: ignore
def get_all_spans(n: int) -> torch.Tensor:
    """Enumerate all spans of a length-n sequence as an (n * (n + 1) / 2, 2) LongTensor."""
    return torch.tensor(list(enumerate_spans(n)), dtype=torch.long)
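
# For illustration: get_all_spans(3) enumerates the six inclusive spans of a
# length-3 sequence:
#   tensor([[0, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 2]])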


class SpanClassifier(nn.Module):
    """Scores every candidate span of each input sequence with a SpanScorer."""

    num_additional_labels = 1  # one extra label is reserved for "not a mention"

    def __init__(self, encoder: nn.Module, scorer: "SpanScorer"):
        super().__init__()
        self.encoder = encoder
        self.scorer = scorer

    def forward(
        self, *input_ids: Sequence[torch.Tensor]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        hs, lengths = self.encoder(*input_ids)
        spans = list(map(get_all_spans, lengths))
        scores = self.scorer(hs, spans)
        return spans, scores

    @torch.no_grad()
    def decode(
        self,
        spans: Sequence[torch.Tensor],
        scores: Sequence[torch.Tensor],
    ) -> List[List[Tuple[int, int, int]]]:
        spans_flatten = torch.cat(spans)
        scores_flatten = torch.cat(scores)
        assert len(spans_flatten) == len(scores_flatten)
        labels_flatten = scores_flatten.argmax(dim=1).cpu()
        # The last label index denotes "not a mention"; keep everything else.
        mask = labels_flatten < self.scorer.num_labels - 1
        mentions = torch.hstack((spans_flatten[mask], labels_flatten[mask, None]))
        output = []
        offset = 0
        sizes = [m.sum().item() for m in torch.split(mask, [len(idxs) for idxs in spans])]
        for size in sizes:
            output.append([tuple(m) for m in mentions[offset : offset + size].tolist()])
            offset += size
        return output  # type: ignore
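
    # For illustration: decode returns one list per sequence, e.g.
    # [[(0, 1, 0), (2, 4, 1)], []] means two mentions were predicted for the
    # first sequence and none for the second; each triple is (start, end, label).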

    def compute_metrics(
        self,
        spans: Sequence[torch.Tensor],
        scores: Sequence[torch.Tensor],
        true_mentions: Sequence[Sequence[Tuple[int, int, int]]],
        decode: bool = True,
    ) -> Dict[str, Any]:
        assert len(spans) == len(scores) == len(true_mentions)
        num_labels = self.scorer.num_labels
        true_labels = []
        for spans_i, scores_i, true_mentions_i in zip(spans, scores, true_mentions):
            assert len(spans_i) == len(scores_i)
            span2idx = {tuple(s): idx for idx, s in enumerate(spans_i.tolist())}
            # Every span defaults to the "not a mention" label.
            labels_i = torch.full((len(spans_i),), fill_value=num_labels - 1)
            for start, end, label in true_mentions_i:
                idx = span2idx.get((start, end))
                if idx is not None:
                    labels_i[idx] = label
            true_labels.append(labels_i)
        scores_flatten = torch.cat(scores)
        true_labels_flatten = torch.cat(true_labels).to(scores_flatten.device)
        assert len(scores_flatten) == len(true_labels_flatten)
        loss = F.cross_entropy(scores_flatten, true_labels_flatten)
        accuracy = categorical_accuracy(scores_flatten, true_labels_flatten)
        result: Dict[str, Any] = {"loss": loss, "accuracy": accuracy}
        if decode:
            pred_mentions = self.decode(spans, scores)
            tp, fn, fp = 0, 0, 0
            for pred_mentions_i, true_mentions_i in zip(pred_mentions, true_mentions):
                pred, gold = set(pred_mentions_i), set(true_mentions_i)
                tp += len(gold & pred)
                fn += len(gold - pred)
                fp += len(pred - gold)
            # (numerator, denominator) pairs, so counts can be aggregated
            # across batches before dividing.
            result["precision"] = (tp, tp + fp)
            result["recall"] = (tp, tp + fn)
            result["mentions"] = pred_mentions
        return result


@torch.no_grad()
def categorical_accuracy(
    y: torch.Tensor, t: torch.Tensor, ignore_index: Optional[int] = None
) -> Tuple[int, int]:
    pred = y.argmax(dim=1)
    if ignore_index is not None:
        mask = t == ignore_index
        ignore_cnt = mask.sum()
        pred.masked_fill_(mask, ignore_index)
        count = ((pred == t).sum() - ignore_cnt).item()
        total = (t.numel() - ignore_cnt).item()
    else:
        count = (pred == t).sum().item()
        total = t.numel()
    return count, total
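
# For illustration: with logits y = [[0.1, 2.0], [3.0, 0.5]] and targets
# t = [1, 1], categorical_accuracy(y, t) returns (1, 2): the first prediction
# (argmax 1) is correct and the second (argmax 0) is not.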


class SpanScorer(nn.Module):
    def __init__(self, num_labels: int):
        super().__init__()
        self.num_labels = num_labels

    def forward(self, xs: torch.Tensor, spans: Sequence[torch.Tensor]):
        raise NotImplementedError


class BaselineSpanScorer(SpanScorer):
    def __init__(
        self,
        input_size: int,
        num_labels: int,
        mlp_units: Union[int, Sequence[int]] = 150,
        mlp_dropout: float = 0.0,
        feature: str = "concat",
    ):
        super().__init__(num_labels)
        # "concat" stacks the begin and end vectors, doubling the input size.
        input_size *= 2 if feature == "concat" else 1
        self.mlp = MLP(input_size, num_labels, mlp_units, F.relu, mlp_dropout)
        self.feature = feature

    def forward(self, xs: torch.Tensor, spans: Sequence[torch.Tensor]):
        max_length = xs.size(1)
        xs_flatten = xs.reshape(-1, xs.size(-1))
        # Offset each sequence's span indices into the flattened batch.
        spans_flatten = torch.cat([idxs + max_length * i for i, idxs in enumerate(spans)])
        features = self._compute_feature(xs_flatten, spans_flatten)
        scores = self.mlp(features)
        return torch.split(scores, [len(idxs) for idxs in spans])

    def _compute_feature(self, xs: torch.Tensor, spans: torch.Tensor) -> torch.Tensor:
        if self.feature == "concat":
            return xs[spans.ravel()].view(len(spans), -1)
        elif self.feature == "minus":
            begins, ends = spans.T
            return xs[ends] - xs[begins]
        else:
            raise NotImplementedError
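
# For illustration: for a span (i, j) over token vectors x_0 .. x_{n-1},
# feature="concat" scores [x_i; x_j] while feature="minus" scores x_j - x_i.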


class MLP(nn.Sequential):
    def __init__(
        self,
        in_features: int,
        out_features: Optional[int],
        units: Optional[Union[int, Sequence[int]]] = None,
        activate: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        units = [units] if isinstance(units, int) else units
        if not units and out_features is None:
            raise ValueError("'out_features' or 'units' must be specified")
        layers = []
        for u in units or []:
            layers.append(MLP.Layer(in_features, u, activate, dropout, bias))
            in_features = u
        if out_features is not None:
            # Final projection: no activation or dropout.
            layers.append(MLP.Layer(in_features, out_features, None, 0.0, bias))
        super().__init__(*layers)

    class Layer(nn.Module):
        def __init__(
            self,
            in_features: int,
            out_features: int,
            activate: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
            dropout: float = 0.0,
            bias: bool = True,
        ):
            super().__init__()
            if activate is not None and not callable(activate):
                raise TypeError("activate must be callable: type={}".format(type(activate)))
            self.linear = nn.Linear(in_features, out_features, bias)
            self.activate = activate
            self.dropout = nn.Dropout(dropout)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            h = self.linear(x)
            if self.activate is not None:
                h = self.activate(h)
            return self.dropout(h)

        def extra_repr(self) -> str:
            return "{}, activate={}, dropout={}".format(
                self.linear.extra_repr(), self.activate, self.dropout.p
            )

        def __repr__(self):
            return "{}.{}({})".format(MLP.__name__, self._get_name(), self.extra_repr())