import math
from typing import Optional, List
from functools import lru_cache
from itertools import chain, tee

import torch
import torch.nn.functional as F

# Hard-coded weight distributions over n-gram orders, keyed by n - 1
# (key k holds k + 1 weights that sum to 1).
n_dists = {
    0: [1],
    1: [0.4, 0.6],
    2: [0.2, 0.3, 0.5],
    3: [0.1, 0.2, 0.3, 0.4],
    4: [0.1, 0.15, 0.2, 0.25, 0.3],
}

# Unnormalized weighting functions over n-gram orders; normalized in n_dist below.
strats = {"linear": lambda x: x, "log": lambda x: math.log(x + 1), "exp": lambda x: x**2}


def pad_sequence(
    sequence,
    n,
    pad_left=False,
    pad_right=False,
    left_pad_symbol=None,
    right_pad_symbol=None,
):
    """Copied from NLTK"""
    sequence = iter(sequence)
    if pad_left:
        sequence = chain((left_pad_symbol,) * (n - 1), sequence)
    if pad_right:
        sequence = chain(sequence, (right_pad_symbol,) * (n - 1))
    return sequence


def ngrams(sequence, n, **kwargs):
    """Copied from NLTK"""
    sequence = pad_sequence(sequence, n, **kwargs)

    iterables = tee(sequence, n)

    for i, sub_iterable in enumerate(iterables):
        for _ in range(i):
            next(sub_iterable, None)
    return zip(*iterables)


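# Illustrative sketch (not exercised elsewhere in this module): bigrams over a
# short sequence, with and without left padding.
#
#   >>> list(ngrams([1, 2, 3], 2))
#   [(1, 2), (2, 3)]
#   >>> list(ngrams([1, 2, 3], 2, pad_left=True, left_pad_symbol=0))
#   [(0, 1), (1, 2), (2, 3)]

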
@lru_cache(maxsize=5)
def soft_dist(n):
    """Uniform distribution: each of the n entries gets weight 1 / n."""
    return [1 / n] * n


@lru_cache(maxsize=5)
def n_dist(n: int, strategy: str) -> List[float]:
    """Normalized weight distribution over the n-gram orders 1..n, shaped by `strategy`."""
    ns = list(range(1, n + 1))
    xs = list(map(strats[strategy], ns))
    total = sum(xs)
    return [x / total for x in xs]


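# Illustrative values (floats rounded here for readability):
#
#   >>> n_dist(3, "linear")
#   [0.167, 0.333, 0.5]       # approximately [1/6, 2/6, 3/6]
#   >>> n_dist(3, "exp")      # "exp" weights by x**2
#   [0.071, 0.286, 0.643]     # approximately [1/14, 4/14, 9/14]

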
def soft_n_hot(
    input,
    num_classes: int,
    strategy: Optional[str],
):
    """Encode stacked n-gram targets (first dim = n-gram order) as a soft
    n-hot tensor: the class indices of order i receive weight soft_labels[i]."""
    shape = list(input.size())[1:]
    shape.append(num_classes)

    ret = torch.zeros(shape).to(input.device)

    if strategy:
        soft_labels = n_dist(input.size(0), strategy)
    else:
        soft_labels = [1] * input.size(0)

    for i, t in enumerate(input):
        ret.scatter_(-1, t.unsqueeze(-1), soft_labels[i])

    return ret


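# Illustrative sketch: two n-gram orders (unigram and bigram ids) stacked along
# the first dimension, for a batch of one sequence of length two. With
# strategy="linear" the unigram indices get weight 1/3 and the bigram indices
# 2/3; with strategy=None every hit gets weight 1.
#
#   >>> targets = torch.tensor([[[1, 2]], [[3, 4]]])   # (orders, batch, seq)
#   >>> soft_n_hot(targets, num_classes=5, strategy="linear").shape
#   torch.Size([1, 2, 5])

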
def n_hot(t, num_classes, ngram_sequences: Optional[torch.Tensor] = None, unk_idx: Optional[int] = None):
    shape = list(t.size())

    if ngram_sequences is not None:
        shape.append(num_classes)
        ret = torch.zeros(shape).to(t.device)
        ret.scatter_(-1, t.unsqueeze(-1), 1)
        for seq in ngram_sequences:
            if unk_idx is not None:
                # Fall back to the unigram ids wherever the n-gram id is unknown
                # (note: this writes into `seq` in place).
                mask = torch.eq(seq, unk_idx)
                seq[mask] = t[mask]
            ret.scatter_(-1, seq.unsqueeze(-1), 1)
        return ret

    elif len(shape) == 2:
        return F.one_hot(t, num_classes=num_classes).float()
    else:
        shape = shape[1:]
        shape.append(num_classes)
        ret = torch.zeros(shape).to(t.device)

        for seq in t:
            ret.scatter_(-1, seq.unsqueeze(-1), 1)

        return ret


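# Illustrative sketch: token ids plus one stacked n-gram id sequence give a
# multi-hot encoding with ones at both the unigram and n-gram indices; without
# `ngram_sequences` the 2-D case reduces to a plain one-hot.
#
#   >>> tokens = torch.tensor([[1, 2]])        # (batch, seq)
#   >>> bigrams = torch.tensor([[[3, 4]]])     # (orders, batch, seq)
#   >>> n_hot(tokens, 5).shape
#   torch.Size([1, 2, 5])
#   >>> n_hot(tokens, 5, bigrams)[0, 0]
#   tensor([0., 1., 0., 1., 0.])

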
class NGramsEmbedding(torch.nn.Embedding):
    """N-Hot encoder"""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[torch.Tensor] = None,
        device=None,
        dtype=None,
        unk_idx: Optional[int] = None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
            device=device,
            dtype=dtype,
        )

        self.num_classes = num_embeddings
        self.unk_idx = unk_idx

    def forward(self, input: torch.Tensor, ngram_sequences: Optional[torch.Tensor] = None):
        return self._forward(
            n_hot(input, self.num_classes, ngram_sequences, self.unk_idx)
        )

    def _forward(self, n_hot: torch.Tensor) -> torch.Tensor:
        # n_hot @ weight sums the embedding rows of all active indices.
        return F.linear(n_hot, self.weight.t())


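# Illustrative usage sketch: the module accepts raw token ids plus, optionally,
# a stack of additional n-gram id sequences; because the n-hot rows select rows
# of the weight matrix, the output is the sum of the embeddings of all active
# indices.
#
#   >>> emb = NGramsEmbedding(num_embeddings=100, embedding_dim=16)
#   >>> tokens = torch.randint(0, 100, (4, 10))       # (batch, seq)
#   >>> bigrams = torch.randint(0, 100, (1, 4, 10))   # (orders, batch, seq)
#   >>> emb(tokens).shape
#   torch.Size([4, 10, 16])
#   >>> emb(tokens, ngram_sequences=bigrams).shape
#   torch.Size([4, 10, 16])

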
def collect_n_gram_sequences(**kwargs) -> List[torch.Tensor]:
    """Gather gram_2_sequence, gram_3_sequence, ... from kwargs, stopping at the first None."""
    sequences = []
    for n in range(2, len(kwargs) + 2):
        s = kwargs[f"gram_{n}_sequence"]
        if s is not None:
            sequences.append(s)
        else:
            break

    return sequences


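# Illustrative sketch: the bigram sequence is collected; the None trigram entry
# stops the scan.
#
#   >>> two = torch.zeros(1, 3, dtype=torch.long)
#   >>> collect_n_gram_sequences(gram_2_sequence=two, gram_3_sequence=None)
#   [tensor([[0, 0, 0]])]

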
def shift_with_pad(target_tensor, n, from_tensor):
    # Drop the first n positions along the sequence dimension ...
    shifted = target_tensor[:, n:]

    seq_size = target_tensor.size(1) - 1

    # ... and fill the tail with columns selected from the end of `from_tensor`.
    missing_idxs = torch.arange(seq_size - (n - 1), seq_size).to(target_tensor.device)

    shifted = torch.concat(
        (shifted, from_tensor.index_select(1, missing_idxs)), dim=1
    )

    return shifted