import math
from typing import Iterable, Iterator, Optional, List
from functools import lru_cache
from itertools import chain, tee

import torch
import torch.nn.functional as F

# Hand-written weight tables; each list sums to 1 and the list stored under
# key ``k`` has ``k + 1`` entries.
# NOTE(review): presumably indexed by (n-gram order - 1) -- confirm with callers.
n_dists = {
    0: [1],
    1: [0.4, 0.6],
    2: [0.2, 0.3, 0.5],
    3: [0.1, 0.2, 0.3, 0.4],
    4: [0.1, 0.15, 0.2, 0.25, 0.3],
}

# Weighting strategies: map an n-gram order ``x`` (1-based) to an unnormalised
# weight.
strats = {
    "linear": lambda x: x,
    "log": lambda x: math.log(x + 1),
    # NOTE: despite the name this is quadratic (x ** 2), not exponential.
    "exp": lambda x: x**2,
}


def pad_sequence(
    sequence: Iterable,
    n: int,
    pad_left: bool = False,
    pad_right: bool = False,
    left_pad_symbol=None,
    right_pad_symbol=None,
) -> Iterator:
    """Return an iterator over ``sequence``, optionally padded (copied from NLTK).

    Args:
        sequence: Any iterable of items.
        n: N-gram order; ``n - 1`` pad symbols are added on each padded side.
        pad_left: Prepend ``n - 1`` copies of ``left_pad_symbol`` when true.
        pad_right: Append ``n - 1`` copies of ``right_pad_symbol`` when true.
        left_pad_symbol: Symbol used for left padding.
        right_pad_symbol: Symbol used for right padding.

    Returns:
        A lazy iterator over the (possibly padded) items.
    """
    sequence = iter(sequence)
    if pad_left:
        sequence = chain((left_pad_symbol,) * (n - 1), sequence)
    if pad_right:
        sequence = chain(sequence, (right_pad_symbol,) * (n - 1))
    return sequence


def ngrams(sequence: Iterable, n: int, **kwargs) -> Iterator[tuple]:
    """Return a lazy iterator over the n-grams of ``sequence`` (copied from NLTK).

    Args:
        sequence: Any iterable of items.
        n: Size of each n-gram.
        **kwargs: Forwarded unchanged to ``pad_sequence`` (``pad_left``,
            ``pad_right``, ``left_pad_symbol``, ``right_pad_symbol``).

    Returns:
        An iterator of ``n``-tuples; empty when the (padded) sequence has
        fewer than ``n`` items.
    """
    sequence = pad_sequence(sequence, n, **kwargs)

    # Create n independent iterators over the same stream; the i-th one is
    # advanced by i items so that zipping them yields a sliding window.
    iterables = tee(sequence, n)
    for offset, sub_iterable in enumerate(iterables):
        for _ in range(offset):
            # Default of None stops a short sequence from raising StopIteration.
            next(sub_iterable, None)

    # zip stops at the shortest iterator, i.e. at the last complete window.
    return zip(*iterables)
@lru_cache(maxsize=5)
def soft_dist(n):
    """Return a uniform distribution of ``n`` weights, each equal to ``1 / n``.

    The returned list is cached by ``lru_cache`` and therefore shared between
    callers; treat it as read-only. ``n == 0`` raises ``ZeroDivisionError``.
    """
    return [1 / n] * n


@lru_cache(maxsize=5)
def n_dist(n: int, strategy: str) -> List[float]:
    """Return normalised weights for n-gram orders ``1..n``.

    Args:
        n: Highest n-gram order; the result has ``n`` entries.
        strategy: Key into the module-level ``strats`` table selecting the
            function that maps an order to its unnormalised weight.

    Returns:
        A list of ``n`` floats that sums to 1 (empty when ``n == 0``). The
        list is cached and shared between callers; treat it as read-only.

    Raises:
        KeyError: If ``strategy`` is not a key of ``strats``.
    """
    weigh = strats[strategy]
    raw = [weigh(order) for order in range(1, n + 1)]
    # Normaliser computed once instead of once per element.
    total = sum(raw)
    return [w / total for w in raw]


def soft_n_hot(
    input,
    num_classes: int,
    strategy: Optional[str],
):
    """Build a weighted n-hot encoding from stacked index tensors.

    Args:
        input: Integer index tensor whose first dimension is iterated over;
            every slice ``input[i]`` is scattered into the same output.
            NOTE(review): presumably dim 0 enumerates n-gram orders -- confirm.
        num_classes: Size of the trailing (class) dimension of the output.
        strategy: Key of ``strats`` used to weight slice ``i`` via
            ``n_dist(input.size(0), strategy)``; when falsy every slice is
            written with weight 1.

    Returns:
        A float tensor of shape ``input.shape[1:] + (num_classes,)`` on
        ``input.device``. Where several slices hit the same class, the value
        written by the later slice wins (scatter overwrites, it does not add).
    """
    shape = list(input.size())[1:]
    shape.append(num_classes)

    # Allocate directly on the target device (no CPU allocation + copy).
    ret = torch.zeros(shape, device=input.device)

    if strategy:
        soft_labels = n_dist(input.size(0), strategy)
    else:
        soft_labels = [1] * input.size(0)

    for i, t in enumerate(input):
        ret.scatter_(-1, t.unsqueeze(-1), soft_labels[i])

    return ret


def n_hot(
    t,
    num_clases,
    ngram_sequences: Optional[torch.Tensor] = None,
    unk_idx: Optional[int] = None,
):
    """Build an n-hot (multi-hot) encoding of index tensors.

    Three modes, chosen from the arguments:

    * ``ngram_sequences`` given: ``t`` is one-hot encoded and every sequence
      in ``ngram_sequences`` is additionally scattered into the same output.
      If ``unk_idx`` is given, positions of a sequence equal to ``unk_idx``
      fall back to the index found in ``t`` at that position.
    * ``t`` is 2-D and no ``ngram_sequences``: plain float one-hot of ``t``.
    * otherwise: the first dimension of ``t`` is iterated and all slices are
      scattered into one output of shape ``t.shape[1:] + (num_clases,)``.

    Args:
        t: Integer index tensor.
        num_clases: Size of the class dimension (parameter name kept as-is
            for backward compatibility).
        ngram_sequences: Optional iterable of index tensors, each shaped
            like ``t``. They are NOT modified by this function.
        unk_idx: Optional index treated as "unknown" inside
            ``ngram_sequences``.

    Returns:
        A float tensor on ``t.device`` whose last dimension is ``num_clases``.
    """
    shape = list(t.size())

    if ngram_sequences is not None:
        shape.append(num_clases)
        ret = torch.zeros(shape, device=t.device)
        ret.scatter_(-1, t.unsqueeze(-1), 1)
        for seq in ngram_sequences:
            if unk_idx is not None:
                # Out-of-place replacement: the previous in-place masked
                # assignment silently overwrote the caller's tensors.
                seq = torch.where(torch.eq(seq, unk_idx), t, seq)
            ret.scatter_(-1, seq.unsqueeze(-1), 1)
        return ret

    if len(shape) == 2:
        return F.one_hot(t, num_classes=num_clases).float()

    # Expect that first dimension is for all n-grams
    shape = shape[1:]
    shape.append(num_clases)
    ret = torch.zeros(shape, device=t.device)
    for seq in t:
        ret.scatter_(-1, seq.unsqueeze(-1), 1)
    return ret


class NGramsEmbedding(torch.nn.Embedding):
    """Embedding layer that looks up via an n-hot encoding.

    The input indices (and optional extra n-gram index sequences) are turned
    into an n-hot tensor by ``n_hot`` and multiplied with the embedding
    weight, so a position that activates several classes receives the sum of
    the corresponding embedding rows.

    NOTE(review): ``forward`` does not go through ``F.embedding``; whether
    options such as ``max_norm`` still take effect should be confirmed.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[torch.Tensor] = None,
        device=None,
        dtype=None,
        unk_idx: Optional[int] = None
    ) -> None:
        """Create the layer.

        All arguments except ``unk_idx`` are forwarded to
        ``torch.nn.Embedding``. ``unk_idx`` is stored and passed to ``n_hot``
        on every forward call.
        """
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
            device=device,
            dtype=dtype,
        )

        # Width of the n-hot encoding equals the vocabulary size.
        self.num_classes = num_embeddings
        self.unk_idx = unk_idx

    def forward(self, input: torch.Tensor, ngram_sequences: Optional[torch.Tensor] = None):
        """Encode ``input`` (plus optional ``ngram_sequences``) and embed it."""
        return self._forward(
            n_hot(input, self.num_classes, ngram_sequences, self.unk_idx)
        )

    def _forward(self, n_hot: torch.Tensor) -> torch.Tensor:
        """Multiply an n-hot tensor with the embedding weight.

        ``F.linear(x, W)`` computes ``x @ W.T``; passing ``weight.t()`` thus
        yields ``n_hot @ weight``, i.e. the sum of the selected rows.
        """
        return F.linear(n_hot, self.weight.t())


def collect_n_gram_sequences(**kwargs) -> List[torch.Tensor]:
    """Collect ``gram_2_sequence``, ``gram_3_sequence``, ... from ``kwargs``.

    Orders are read consecutively starting at 2; collection stops at the
    first order that is missing or whose value is ``None``. Keyword arguments
    with other names are ignored (previously any such argument caused a
    ``KeyError``).

    Returns:
        The collected values, ordered by n-gram order.
    """
    sequences = []

    for n in range(2, len(kwargs) + 2):
        s = kwargs.get(f"gram_{n}_sequence")
        if s is None:
            break
        sequences.append(s)

    return sequences


def shift_with_pad(target_tensor, n, from_tensor):
    """Drop the first ``n`` columns of ``target_tensor`` and refill the tail.

    The result is ``target_tensor[:, n:]`` followed by ``n - 1`` columns taken
    from ``from_tensor`` at indices ``[L - n, L - 1)``, where
    ``L = target_tensor.size(1)``; its width is therefore ``L - 1``.

    Args:
        target_tensor: 2-D tensor (indexed as ``[:, n:]``).
        n: Number of leading columns to drop.
        from_tensor: Tensor supplying the padding columns along dim 1.
            NOTE(review): presumably the unigram target of width ``L - 1``
            -- confirm against the caller.

    Returns:
        The shifted and padded tensor.
    """
    shifted = target_tensor[:, n:]

    seq_size = target_tensor.size(1) - 1

    missing_idxs = torch.arange(
        seq_size - (n - 1), seq_size, device=target_tensor.device
    )

    # Pad with missing idxs from unigram tensor
    shifted = torch.cat(
        (shifted, from_tensor.index_select(1, missing_idxs)), dim=1
    )

    return shifted