from typing import Callable, List, NamedTuple, Tuple, Type, Union

import torch
from nltk import Tree
from torch import Tensor
from torch import nn
from torch.distributions.utils import lazy_property
from torchrua import C, segment_mean, L, Z
from transformers.models.roberta.modeling_roberta import PreTrainedModel, RobertaModel

from tmp.configuration_parserker import ParserkerConfig

Frames = Union[List[Tensor], Tuple[Tensor, ...]]


def diag(tensor: Tensor, offset: int) -> Tensor:
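    """Return the ``offset``-th diagonal of a batched square chart, i.e. the
    entries ``chart[b, i, i + offset]``, as a view with the diagonal moved to
    the last dimension.
    """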
    return tensor.diagonal(offset=offset, dim1=1, dim2=2)


def diag_scatter(chart: Tensor, score: Tensor, offset: int) -> None:
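    """Write ``score`` into the ``offset``-th diagonal of ``chart`` in place."""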
    chart.diagonal(offset=offset, dim1=1, dim2=2)[:] = score


def left(chart: Tensor, offset: int) -> Tensor:
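    """Stride-trick view of all left children at span width ``offset``.

    ``left(chart, offset)[b, i, j]`` aliases ``chart[b, i, i + j]``, the score
    of the left sub-span ``[i, i + j]`` of the span ``[i, i + offset]``,
    without copying.
    """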
    b, t, _, *size = chart.size()
    c, n, m, *stride = chart.stride()
    return chart.as_strided(
        size=(b, t - offset, offset, *size),
        stride=(c, n + m, m, *stride),
    )


def right(chart: Tensor, offset: int) -> Tensor:
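    """Stride-trick view of all right children at span width ``offset``.

    ``right(chart, offset)[b, i, j]`` aliases ``chart[b, i + j + 1, i + offset]``,
    the score of the right sub-span ``[i + j + 1, i + offset]`` of the span
    ``[i, i + offset]``, without copying.
    """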
    b, t, _, *size = chart.size()
    c, n, m, *stride = chart.stride()
    return chart[:, 1:, offset:].as_strided(
        size=(b, t - offset, offset, *size),
        stride=(c, n + m, n, *stride),
    )


def to_hex(x: int, num_bits: int) -> str:
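    """Format ``x`` as zero-padded uppercase hex, one digit per four bits."""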
    return f'{x:0{(num_bits + 3) // 4}X}'


def bits_to_long(tensor: Tensor) -> Tensor:
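    """Pack the trailing bit dimension into one integer, little-endian
    (index 0 is the least significant bit).
    """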
    *_, num_bits = tensor.size()
    index = torch.arange(num_bits, dtype=torch.long, device=tensor.device)
    return (tensor << index).sum(dim=-1)


def long_to_bits(tensor: Tensor, num_bits: int) -> Tensor:
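    """Unpack integers into ``num_bits`` little-endian bits; the inverse of
    ``bits_to_long``.
    """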
    index = torch.arange(num_bits, dtype=torch.long, device=tensor.device)
    return (tensor[..., None] >> index) & 1


def max(tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor:
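    """Values-only ``torch.max`` reduction, used as the ``sum`` of the Max
    semiring below; it intentionally shadows the builtin ``max`` within this
    module.
    """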
    return torch.max(tensor, dim=dim, keepdim=keepdim).values


class Semiring(NamedTuple):
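    """A semiring over chart scores: ``add``/``sum`` combine alternative
    derivations, ``mul``/``prod`` combine parts of one derivation, and
    ``zero``/``one`` are their identities.
    """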
    zero: float
    one: float
    add: Callable
    mul: Callable
    sum: Callable
    prod: Callable


# The Log semiring accumulates all trees (log-partition and marginals).
Log = Semiring(
    zero=-float('inf'),
    one=0.,
    add=torch.logaddexp,
    mul=torch.add,
    sum=torch.logsumexp,
    prod=torch.sum,
)

# The Max semiring keeps only the best tree (Viterbi decoding).
Max = Semiring(
    zero=-float('inf'),
    one=0.,
    add=torch.maximum,
    mul=torch.add,
    sum=max,
    prod=torch.sum,
)


def cumsum(tensor: Tensor) -> Tensor:
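    """2-D prefix sums over the upper triangle of a batched chart.

    For scores of shape ``(batch, t, t, k)``, the result at ``[b, i, j, k]``
    is the sum of all entries ``[b, i', j', k]`` with ``i <= i' <= j' <= j``,
    i.e. the total score of every sub-span contained in ``[i, j]``.
    """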
    b, t1, t2, k = tensor.size()
    assert t1 == t2, f'{t1} != {t2}'

    p1 = tensor.permute(0, 3, 1, 2).triu()
    c1 = p1.cumsum(dim=-1)
    c2 = c1.flip(dims=[-2]).cumsum(dim=-2).flip(dims=[-2])
    p2 = c2.permute(0, 2, 3, 1)
    return p2


def cky_partitions(logits: Tensor, token_sizes: Tensor, semiring: Semiring) -> Tuple[Tensor, List[Tensor]]:
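    """Run the inside (CKY) recursion over span logits under ``semiring``.

    Returns the root score of every sequence together with the per-width
    score frames; the frames stay in the autograd graph, so gradients with
    respect to them recover marginals (Log) or a best-tree indicator (Max).
    """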
    logits = cumsum(logits)
    logits = torch.stack([torch.zeros_like(logits), logits], dim=-1)
    b, t, _, k, _ = logits.size()

    chart = torch.full_like(logits[..., 0, 0], fill_value=semiring.zero, requires_grad=False)

    # Width-0 spans: each bit chooses between the zero channel (off) and its
    # logit (on).
    z = diag(logits, offset=0)[..., None].permute([0, 3, 4, 1, 2])

    frames = [z]
    z = semiring.sum(z, dim=-1)
    z = semiring.prod(z, dim=-1)

    diag_scatter(chart, z[..., 0], offset=0)
    index = torch.arange(t, dtype=chart.dtype, device=chart.device)

    for w in range(1, t):
        # Inclusion-exclusion on the prefix sums: keep the scores of the
        # sub-spans straddling split j, averaged over the (1 + j) * (w - j)
        # sub-spans that straddle it.
        z = diag(logits, offset=w)[..., None].permute([0, 3, 4, 1, 2])
        z = z - left(logits, offset=w) - right(logits, offset=w)
        z = z / ((1 + index[:w]) * (w - index[:w]))[:, None, None]

        frames.append(z)
        z = semiring.sum(z, dim=-1)
        z = semiring.prod(z, dim=-1)

        # Combine left- and right-child charts over every split point.
        xyz = semiring.mul(z, semiring.mul(left(chart, offset=w), right(chart, offset=w)))
        score = semiring.sum(xyz, dim=-1)

        diag_scatter(chart, score, offset=w)

    # Root score of the span [0, len - 1] for each sequence.
    index = torch.arange(b, dtype=torch.long, device=chart.device)
    return chart[index, 0, token_sizes - 1], frames


class Distribution(object):
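    """A CKY distribution over binary trees whose spans carry bit labels,
    parameterized by span logits of shape ``(batch, t, t, num_bits)``.
    """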
    def __init__(self, logits: Tensor, token_sizes: Tensor) -> None:
        super(Distribution, self).__init__()
        self.logits = logits
        self.token_sizes = token_sizes

    @lazy_property
    def log_partitions(self):
        partitions, frames = cky_partitions(
            logits=self.logits,
            token_sizes=self.token_sizes,
            semiring=Log,
        )
        return partitions, frames

    @lazy_property
    def max(self):
        partitions, frames = cky_partitions(
            logits=self.logits,
            token_sizes=self.token_sizes,
            semiring=Max,
        )
        return partitions, frames

    @lazy_property
    def marginals(self) -> Frames:
        # Gradients of the log-partition w.r.t. the frames give the span
        # (and bit) marginals.
        partitions, frames = self.log_partitions
        return torch.autograd.grad(
            partitions, frames, torch.ones_like(partitions),
            create_graph=True, retain_graph=True,
            only_inputs=True, allow_unused=True,
        )

    @lazy_property
    def grads(self) -> Frames:
        # Under the Max semiring, gradients are indicators of the spans and
        # bits chosen by the best tree.
        partitions, frames = self.max
        return torch.autograd.grad(
            partitions, frames, torch.ones_like(partitions),
            create_graph=False, retain_graph=False,
            only_inputs=True, allow_unused=True,
        )

    @staticmethod
    def gather(marginals: Frames, grads: Frames, spans: Tensor):
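        """For every span selected by the Max-semiring gradients, collect its
        bit marginals (``xs``), its hard bit assignments (``ys``), and the
        matching entries of ``spans`` (``zs``).
        """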
        b, _, _, k, _ = marginals[0].size()

        xs, ys, zs = [], [], []
        for w, (x, grad) in enumerate(zip(marginals, grads)):
            mask, y = grad.max(dim=-1, keepdim=True)
            mask = mask.sum(dim=-2, keepdim=True) > 0

            z = diag(spans, offset=w)[..., None, None, None]

            xs.append(torch.masked_select(x, mask))
            ys.append(torch.masked_select(y, mask))
            zs.append(torch.masked_select(z, mask))

        xs = torch.cat(xs, dim=0).view((-1, k, 2))
        ys = torch.cat(ys, dim=0).view((-1, k))
        zs = torch.cat(zs, dim=0)
        return xs, ys, zs

    @lazy_property
    def argmax(self) -> C:
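        """Decode the best tree as a packed ``C`` of ``(start, end, label)``
        rows, ``2 * n - 1`` spans per sequence of length ``n``.
        """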
        b, t, _, _, _ = self.grads[0].size()

        batch = torch.arange(b, dtype=torch.long, device=self.grads[0].device)
        x = torch.arange(t, dtype=torch.long, device=self.grads[0].device)
        y = torch.arange(t, dtype=torch.long, device=self.grads[0].device)
        batch, x, y = torch.broadcast_tensors(batch[:, None, None], x[None, :, None], y[None, None, :])

        data = []
        for w, grad in enumerate(self.grads):
            mask, z = grad.max(dim=-1, keepdim=False)
            mask = mask.sum(dim=-1, keepdim=False) > 0

            data.append(torch.stack([
                torch.masked_select(diag(batch, offset=w)[..., None], mask),
                torch.masked_select(diag(x, offset=w)[..., None], mask),
                torch.masked_select(diag(y, offset=w)[..., None], mask),
                torch.masked_select(bits_to_long(z), mask),
            ], dim=-1))

        data = torch.cat(data, dim=0)
        order = torch.argsort(data[..., 0], dim=0, descending=False)
        return C(data=data[order, 1:], token_sizes=self.token_sizes * 2 - 1)


class HashLayer(nn.Module):
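    """Scores each label bit with a low-rank bilinear form between query and
    key projections of the token representations.
    """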
    def __init__(self, config: ParserkerConfig) -> None:
        super(HashLayer, self).__init__()

        self.num_bits = config.num_bits
        self.bit_size = (config.hidden_size + config.num_bits - 1) // config.num_bits
        self.scale = self.bit_size ** -0.5

        self.q_proj = nn.Linear(config.hidden_size, self.num_bits * self.bit_size, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, self.num_bits * self.bit_size, bias=True)

    def forward(self, q: Tensor, k: Tensor) -> Tensor:
        q = self.q_proj(q).unflatten(dim=-1, sizes=(self.num_bits, 1, self.bit_size))
        k = self.k_proj(k).unflatten(dim=-1, sizes=(self.num_bits, self.bit_size, 1))

        # (batch, t, t, num_bits): one scaled dot-product score per bit.
        return (q[:, :, None] @ k[:, None, :]).flatten(start_dim=-3).transpose(1, 2) * self.scale


class ParserkerModel(PreTrainedModel):
    config_class = ParserkerConfig
    base_model_prefix = "backbone"
    _tied_weights_keys = {}

    def __init__(self, config: ParserkerConfig, **kwargs):
        super(ParserkerModel, self).__init__(config=config, **kwargs)

        self.pad_token_id = config.pad_token_id
        self.num_bits = config.num_bits

        self.backbone = RobertaModel(config, add_pooling_layer=False)
        self.hash_layer = HashLayer(config)

    @property
    def all_tied_weights_keys(self):
        return getattr(self, "_tied_weights_keys", [])

    def forward(self, input_ids: Z, duration: Z) -> L:
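        """Encode sub-words with the backbone, mean-pool them to word level
        via ``duration`` (trimming one boundary token on each side), and
        score every span bit with the hash layer.
        """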
        out = self.backbone.forward(
            input_ids=input_ids.left(self.pad_token_id).data,
            attention_mask=input_ids.bmask(),
            return_dict=True,
        )

        tensor = L(data=out.last_hidden_state, token_sizes=input_ids.cat().token_sizes)
        tensor, token_sizes = tensor.seg(duration, segment_mean).trunc((1, 1))

        logits = self.hash_layer(tensor, tensor)

        return L(data=logits, token_sizes=token_sizes)

    def parse(self, input_ids: Z, duration: C) -> C:
        logits, token_sizes = self(input_ids, duration)
        # Detach so the logits become a grad-enabled leaf: decoding gradients
        # stop here instead of flowing back into the encoder.
        logits = logits.detach().requires_grad_(True)

        dist = Distribution(logits=logits, token_sizes=token_sizes)
        return dist.argmax

    def to_tree(self, words, spans) -> Tree:
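        """Assemble an ``nltk.Tree`` from ``(start, end, label)`` spans and
        attach ``words`` at the leaves.
        """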
        stack = []

        for x, y, z in sorted(spans, key=lambda item: (item[0], -item[1]), reverse=True):
            children = []
            while len(stack) > 0:
                xx, yy, zz = stack.pop()
                if x <= xx and yy <= y:
                    children.append(zz)
                else:
                    stack.append((xx, yy, zz))
                    break

            if len(children) == 0:
                children = ['__tok']

            stack.append((x, y, Tree(to_hex(z, self.num_bits), children)))

        [(_, _, tree)] = stack

        for index in range(len(tree.leaves())):
            position = tree.leaf_treeposition(index)
            tree[position] = words[index]

        return tree
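

# A minimal smoke test (an illustrative sketch, not part of the original
# module; the shapes and sizes below are made up). The best-tree score under
# the Max semiring can never exceed the log-partition under the Log semiring.
if __name__ == '__main__':
    batch_size, num_tokens, example_bits = 2, 5, 4

    example_logits = torch.randn((batch_size, num_tokens, num_tokens, example_bits), requires_grad=True)
    example_sizes = torch.tensor([num_tokens, num_tokens - 2], dtype=torch.long)

    dist = Distribution(logits=example_logits, token_sizes=example_sizes)
    log_z, _ = dist.log_partitions
    best, _ = dist.max

    assert torch.all(best <= log_z + 1e-4), (best, log_z)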