import math
import unittest
from functools import lru_cache
from typing import List, Optional, Union

import torch
import torch.nn.functional as F

n_dists = {
    0: [1],
    1: [0.4, 0.6],
    2: [0.2, 0.3, 0.5],
    3: [0.1, 0.2, 0.3, 0.4],
    4: [0.1, 0.15, 0.2, 0.25, 0.3],
}

strats = {
    "linear": lambda x: x,
    "log": lambda x: math.log(x + 1),
    "exp": lambda x: x**2,
}


@lru_cache(maxsize=5)
def soft_dist(n):
    """Uniform weight distribution over n n-gram orders."""
    return [1 / n] * n


@lru_cache(maxsize=5)
def n_dist(n: int, strategy: str) -> List[float]:
    """Weight distribution over n n-gram orders, normalized under the given strategy."""
    ns = list(range(1, n + 1))
    xs = list(map(strats[strategy], ns))
    total = sum(xs)
    return [x / total for x in xs]


def soft_n_hot(
    input,
    num_classes: int,
    strategy: Optional[str],
):
    # input is expected to be stacked n-gram sequences: [n, batch, seq]
    shape = list(input.size())[1:]
    shape.append(num_classes)

    ret = torch.zeros(shape).to(input.device)

    if strategy:
        soft_labels = n_dist(input.size(0), strategy)
    else:
        soft_labels = [1] * input.size(0)

    for i, t in enumerate(input):
        ret.scatter_(-1, t.unsqueeze(-1), soft_labels[i])

    return ret


def n_hot(
    t,
    num_classes: int,
    ngram_sequences: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    unk_idx: Optional[int] = None,
):
    shape = list(t.size())

    if ngram_sequences is not None:
        shape.append(num_classes)
        ret = torch.zeros(shape).to(t.device)
        ret.scatter_(-1, t.unsqueeze(-1), 1)

        for seq in ngram_sequences:
            if unk_idx is not None:
                # Fall back to the unigram token wherever the n-gram is unknown
                mask = torch.eq(seq, unk_idx)
                seq[mask] = t[mask]
            ret.scatter_(-1, seq.unsqueeze(-1), 1)

        return ret

    elif len(shape) == 2:
        return F.one_hot(t, num_classes=num_classes).float()

    else:
        shape = shape[1:]
        shape.append(num_classes)
        ret = torch.zeros(shape).to(t.device)

        # Expect that the first dimension indexes the n-gram orders
        for seq in t:
            ret.scatter_(-1, seq.unsqueeze(-1), 1)

        return ret


class NGramsEmbedding(torch.nn.Embedding):
    """N-hot encoder: sums the embeddings of the unigram ids and all provided n-gram ids."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[torch.Tensor] = None,
        device=None,
        dtype=None,
        unk_idx: Optional[int] = None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
            device=device,
            dtype=dtype,
        )

        self.num_classes = num_embeddings
        self.unk_idx = unk_idx

    def forward(
        self,
        input: torch.Tensor,
        ngram_sequences: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    ):
        return self._forward(
            n_hot(input, self.num_classes, ngram_sequences, self.unk_idx)
        )

    def _forward(self, n_hot: torch.Tensor) -> torch.Tensor:
        # n_hot @ weight: equivalent to summing the embeddings of all set indices
        return F.linear(n_hot, self.weight.t())


def collect_n_gram_sequences(**kwargs) -> List[torch.Tensor]:
    # Collect gram_2_sequence, gram_3_sequence, ... until the first missing entry
    sequences = []
    for n in range(2, len(kwargs) + 2):
        s = kwargs[f"gram_{n}_sequence"]
        if s is not None:
            sequences.append(s)
        else:
            break

    return sequences


def shift_with_pad(target_tensor, n, from_tensor):
    shifted = target_tensor[:, n:]

    seq_size = target_tensor.size(1) - 1
    missing_idxs = torch.arange(seq_size - (n - 1), seq_size).to(target_tensor.device)

    # Pad with missing idxs from unigram tensor
    shifted = torch.concat(
        (shifted, from_tensor.index_select(1, missing_idxs)), dim=1
    )

    return shifted


class TestNGME(unittest.TestCase):

    def test_one_hot(self):
        t = torch.tensor([[0, 1, 2]])
        ret = n_hot(t, 3)
        expected = torch.eye(3)
        assert torch.all(torch.eq(ret, expected))

    def test_multi_hot1(self):
        t = torch.tensor([[0, 1, 2]])
        # [batch, ngram, seq]
        two_grams = torch.tensor([[[0, 1, 2]]])
        ret = n_hot(t, 3, two_grams)
        expected = torch.eye(3)
        assert torch.all(torch.eq(ret, expected))

    def test_multi_hot2(self):
        t = torch.tensor([[0, 1, 2]])
        two_three_grams = torch.tensor([[[1, 2, 0]], [[2, 0, 1]]])
        ret = n_hot(t, 3, two_three_grams)
        expected = torch.ones(3, 3)
        assert torch.all(torch.eq(ret, expected))


class TestShifting(unittest.TestCase):

    def test_two_gram(self):
        two_gram_batch = torch.tensor([[0, 1, 2, 3, 4]])
        from_tensor = torch.tensor([[-4, -3, -2, -1]])
        shifted = shift_with_pad(two_gram_batch, 2, from_tensor)
        expected = torch.tensor([[2, 3, 4, -1]])
        assert torch.all(torch.eq(shifted, expected))

    def test_three_gram(self):
        three_gram_batch = torch.tensor([[0, 1, 2, 3, 4, 5, 6]])
        from_tensor = torch.tensor([[-6, -5, -4, -3, -2, -1]])
        shifted = shift_with_pad(three_gram_batch, 3, from_tensor)
        expected = torch.tensor([[3, 4, 5, 6, -2, -1]])
        assert torch.all(torch.eq(shifted, expected))

    def test_three_gram_2(self):
        three_gram_batch = torch.tensor([[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]])
        from_tensor = torch.tensor([[-6, -5, -4, -3, -2, -1], [-6, -5, -4, -3, -2, -1]])
        shifted = shift_with_pad(three_gram_batch, 3, from_tensor)
        expected = torch.tensor([[3, 4, 5, 6, -2, -1], [3, 4, 5, 6, -2, -1]])
        assert torch.all(torch.eq(shifted, expected))


class TestNGramEmbeddings(unittest.TestCase):

    def test(self):
        emb = NGramsEmbedding(10, 10)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        emb1 = emb(torch.tensor([[1, 2, 3]]))
        emb2 = emb(torch.tensor([[4, 5, 6]]))
        emb3 = emb(torch.tensor([[1, 2, 3]]), [torch.tensor([[4, 5, 6]])])

        assert torch.all(torch.eq(torch.add(emb1, emb2), emb3))

    def test_2(self):
        emb = NGramsEmbedding(10, 10)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        emb1 = emb(torch.tensor([[1, 2, 3]]))
        emb2 = emb(torch.tensor([[1, 2, 3]]))
        emb3 = emb(torch.tensor([[1, 2, 3]]), [torch.tensor([[1, 2, 3]])])

        assert torch.all(torch.eq(emb1, emb3))
        assert torch.all(torch.eq(emb2, emb3))

    def test_3_gram(self):
        emb = NGramsEmbedding(10, 10)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        emb1 = emb(torch.tensor([[1, 2, 3]]))
        emb2 = emb(torch.tensor([[4, 5, 6]]))
        emb3 = emb(torch.tensor([[7, 8, 9]]))
        emb4 = emb(
            torch.tensor([[1, 2, 3]]),
            [torch.tensor([[4, 5, 6]]), torch.tensor([[7, 8, 9]])],
        )

        assert torch.all(torch.eq(torch.add(torch.add(emb1, emb2), emb3), emb4))

    def test_ignore_indx(self):
        emb = NGramsEmbedding(10, 10, unk_idx=0)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        unigram = torch.tensor([[1, 2, 3]])
        bigram = torch.tensor([[0, 0, 0]])

        emb1 = emb(unigram, [bigram])
        emb2 = emb(unigram)

        assert torch.all(torch.eq(emb1, emb2))

    def test_ignore_indx_2(self):
        emb = NGramsEmbedding(10, 10, unk_idx=0)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        unigram = torch.tensor([[0, 2, 3]])
        bigram = torch.tensor([[0, 0, 0]])

        emb1 = emb(unigram, [bigram])
        emb2 = emb(unigram)

        assert torch.all(torch.eq(emb1, emb2))


if __name__ == "__main__":
    unittest.main()
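

# --- Usage sketch (illustrative only, not part of the module API) -----------
# A minimal, hedged example of how NGramsEmbedding is meant to be called,
# following the pattern used in the tests above. The vocabulary size,
# embedding dimension, and token ids are made-up values; only NGramsEmbedding,
# its forward signature, and the unk_idx fallback come from the code above.
def _embedding_usage_sketch() -> torch.Tensor:
    vocab_size, emb_dim = 10, 8
    emb = NGramsEmbedding(vocab_size, emb_dim, unk_idx=0)

    # Unigram ids for a batch of one sequence of length 4
    unigrams = torch.tensor([[1, 2, 3, 4]])
    # Bigram ids aligned position-wise with the unigrams; id 0 (= unk_idx)
    # falls back to the corresponding unigram id inside n_hot
    bigrams = torch.tensor([[5, 6, 0, 7]])

    # Output embedding is the sum of the unigram and bigram embeddings
    out = emb(unigrams, [bigrams])  # shape: (1, 4, emb_dim)
    return out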