from functools import lru_cache
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


def enumerate_spans(n):
    # Enumerate every contiguous span (i, j) with 0 <= i <= j < n.
    for i in range(n):
        for j in range(i, n):
            yield (i, j)


@lru_cache(maxsize=None)  # type: ignore
def get_all_spans(n: int) -> torch.Tensor:
    # The candidate spans depend only on the sequence length, so memoize them.
    return torch.tensor(list(enumerate_spans(n)), dtype=torch.long)
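

# Illustrative usage sketch (an added, hypothetical helper; not part of the
# original module): a sequence of length 3 has six contiguous spans, enumerated
# start-first, and `get_all_spans` stacks them into a (6, 2) long tensor.
def _enumerate_spans_example() -> None:
    assert list(enumerate_spans(3)) == [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]
    assert get_all_spans(3).shape == (6, 2)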


class SpanClassifier(nn.Module):
    # One extra label index is reserved for the "not a mention" class.
    num_additional_labels = 1

    def __init__(self, encoder, scorer: "SpanScorer"):
        super().__init__()
        self.encoder = encoder
        self.scorer = scorer

    def forward(
        self, *input_ids: Sequence[torch.Tensor]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        # Encode the inputs, enumerate all candidate spans per sequence,
        # and score every span against every label.
        hs, lengths = self.encoder(*input_ids)
        spans = list(map(get_all_spans, lengths))
        scores = self.scorer(hs, spans)
        return spans, scores

    def decode(
        self,
        spans: Sequence[torch.Tensor],
        scores: Sequence[torch.Tensor],
    ) -> List[List[Tuple[int, int, int]]]:
        spans_flatten = torch.cat(spans)
        scores_flatten = torch.cat(scores)
        assert len(spans_flatten) == len(scores_flatten)
        labels_flatten = scores_flatten.argmax(dim=1).cpu()
        # Keep only spans whose best label is not the "not a mention" label
        # (the last index).
        mask = labels_flatten < self.scorer.num_labels - 1
        mentions = torch.hstack((spans_flatten[mask], labels_flatten[mask, None]))
        # Regroup the flattened (start, end, label) triples by input sequence.
        output = []
        offset = 0
        sizes = [m.sum() for m in torch.split(mask, [len(idxs) for idxs in spans])]
        for size in sizes:
            output.append([tuple(m) for m in mentions[offset : offset + size].tolist()])
            offset += size
        return output  # type: ignore

    def compute_metrics(
        self,
        spans: Sequence[torch.Tensor],
        scores: Sequence[torch.Tensor],
        true_mentions: Sequence[Sequence[Tuple[int, int, int]]],
        decode: bool = True,
    ) -> Dict[str, Any]:
        assert len(spans) == len(scores) == len(true_mentions)
        num_labels = self.scorer.num_labels
        # Build gold label vectors: every candidate span defaults to the
        # "not a mention" label and is overwritten if it appears in the gold mentions.
        true_labels = []
        for spans_i, scores_i, true_mentions_i in zip(spans, scores, true_mentions):
            assert len(spans_i) == len(scores_i)
            span2idx = {tuple(s): idx for idx, s in enumerate(spans_i.tolist())}
            labels_i = torch.full((len(spans_i),), fill_value=num_labels - 1)
            for (start, end, label) in true_mentions_i:
                idx = span2idx.get((start, end))
                if idx is not None:
                    labels_i[idx] = label
            true_labels.append(labels_i)
        scores_flatten = torch.cat(scores)
        true_labels_flatten = torch.cat(true_labels).to(scores_flatten.device)
        assert len(scores_flatten) == len(true_labels_flatten)
        loss = F.cross_entropy(scores_flatten, true_labels_flatten)
        accuracy = categorical_accuracy(scores_flatten, true_labels_flatten)
        result = {"loss": loss, "accuracy": accuracy}
        if decode:
            # Span-level precision/recall counts against the gold mentions.
            pred_mentions = self.decode(spans, scores)
            tp, fn, fp = 0, 0, 0
            for pred_mentions_i, true_mentions_i in zip(pred_mentions, true_mentions):
                pred, gold = set(pred_mentions_i), set(true_mentions_i)
                tp += len(gold & pred)
                fn += len(gold - pred)
                fp += len(pred - gold)
            result["precision"] = (tp, tp + fp)
            result["recall"] = (tp, tp + fn)
            result["mentions"] = pred_mentions
        return result
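

# Illustrative decoding sketch (an added, hypothetical helper; not part of the
# original module).  With num_labels=3 the last index (2) is the "not a mention"
# label, so only the first span below survives decoding.  The encoder is unused
# by `decode`, so it is left as None here purely for illustration.
def _span_classifier_decode_example() -> None:
    clf = SpanClassifier(encoder=None, scorer=SpanScorer(num_labels=3))
    spans = [torch.tensor([[0, 1], [1, 2]])]
    scores = [torch.tensor([[0.1, 2.0, 0.5], [0.0, 0.0, 3.0]])]
    assert clf.decode(spans, scores) == [[(0, 1, 1)]]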


def categorical_accuracy(
    y: torch.Tensor, t: torch.Tensor, ignore_index: Optional[int] = None
) -> Tuple[int, int]:
    # Return (number of correct predictions, number of counted targets).
    pred = y.argmax(dim=1)
    if ignore_index is not None:
        mask = t == ignore_index
        ignore_cnt = mask.sum()
        pred.masked_fill_(mask, ignore_index)
        count = ((pred == t).sum() - ignore_cnt).item()
        total = (t.numel() - ignore_cnt).item()
    else:
        count = (pred == t).sum().item()
        total = t.numel()
    return count, total
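

# Illustrative sketch (an added, hypothetical helper; not part of the original
# module): two of the three row-wise argmax predictions below match their
# targets, so the result is (2 correct, 3 counted).
def _categorical_accuracy_example() -> None:
    y = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
    t = torch.tensor([0, 1, 1])
    assert categorical_accuracy(y, t) == (2, 3)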


class SpanScorer(torch.nn.Module):
    # Base class: subclasses score each candidate span over `num_labels` classes.
    def __init__(self, num_labels: int):
        super().__init__()
        self.num_labels = num_labels

    def forward(
        self, xs: torch.Tensor, spans: Sequence[torch.Tensor]
    ):
        raise NotImplementedError


class BaselineSpanScorer(SpanScorer):
    def __init__(
        self,
        input_size: int,
        num_labels: int,
        mlp_units: Union[int, Sequence[int]] = 150,
        mlp_dropout: float = 0.0,
        feature="concat",
    ):
        super().__init__(num_labels)
        # "concat" concatenates the begin and end token vectors, doubling the input size.
        input_size *= 2 if feature == "concat" else 1
        self.mlp = MLP(input_size, num_labels, mlp_units, F.relu, mlp_dropout)
        self.feature = feature

    def forward(
        self, xs: torch.Tensor, spans: Sequence[torch.Tensor]
    ):
        # Flatten the batch so all spans can be gathered with a single index lookup.
        max_length = xs.size(1)
        xs_flatten = xs.reshape(-1, xs.size(-1))
        spans_flatten = torch.cat([idxs + max_length * i for i, idxs in enumerate(spans)])
        features = self._compute_feature(xs_flatten, spans_flatten)
        scores = self.mlp(features)
        # Split the scores back into per-sequence chunks.
        return torch.split(scores, [len(idxs) for idxs in spans])

    def _compute_feature(self, xs, spans):
        if self.feature == "concat":
            # Concatenation of the begin and end token states: [h_begin; h_end].
            return xs[spans.ravel()].view(len(spans), -1)
        elif self.feature == "minus":
            # Difference of the end and begin token states: h_end - h_begin.
            begins, ends = spans.T
            return xs[ends] - xs[begins]
        else:
            raise NotImplementedError
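

# Illustrative sketch (an added, hypothetical helper; not part of the original
# module): score every span of two length-3 sequences from random stand-in
# "encoder" states of width 8, yielding one (num_spans, num_labels) matrix per
# sequence.
def _baseline_span_scorer_example() -> None:
    scorer = BaselineSpanScorer(input_size=8, num_labels=4, feature="concat")
    xs = torch.randn(2, 3, 8)                     # (batch, max_length, hidden)
    spans = [get_all_spans(3), get_all_spans(3)]  # six candidate spans each
    scores = scorer(xs, spans)
    assert tuple(s.shape for s in scores) == ((6, 4), (6, 4))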


class MLP(nn.Sequential):
    def __init__(
        self,
        in_features: int,
        out_features: Optional[int],
        units: Optional[Union[int, Sequence[int]]] = None,
        activate: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        units = [units] if isinstance(units, int) else units
        if not units and out_features is None:
            raise ValueError("'out_features' or 'units' must be specified")
        layers = []
        for u in units or []:
            layers.append(MLP.Layer(in_features, u, activate, dropout, bias))
            in_features = u
        if out_features is not None:
            # Final projection layer without activation or dropout.
            layers.append(MLP.Layer(in_features, out_features, None, 0.0, bias))
        super().__init__(*layers)

    class Layer(nn.Module):
        def __init__(
            self,
            in_features: int,
            out_features: int,
            activate: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
            dropout: float = 0.0,
            bias: bool = True,
        ):
            super().__init__()
            if activate is not None and not callable(activate):
                raise TypeError("activate must be callable: type={}".format(type(activate)))
            self.linear = nn.Linear(in_features, out_features, bias)
            self.activate = activate
            self.dropout = nn.Dropout(dropout)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            h = self.linear(x)
            if self.activate is not None:
                h = self.activate(h)
            return self.dropout(h)

        def extra_repr(self) -> str:
            return "{}, activate={}, dropout={}".format(
                self.linear.extra_repr(), self.activate, self.dropout.p
            )

        def __repr__(self):
            return "{}.{}({})".format(MLP.__name__, self._get_name(), self.extra_repr())