Commit af87020
Parent(s): b117cb6
Upload ngme.py
ngme.py (ADDED)
@@ -0,0 +1,265 @@
import math
from typing import Optional, List
from functools import lru_cache
import unittest

import torch
import torch.nn.functional as F

n_dists = {
    0: [1],
    1: [0.4, 0.6],
    2: [0.2, 0.3, 0.5],
    3: [0.1, 0.2, 0.3, 0.4],
    4: [0.1, 0.15, 0.2, 0.25, 0.3],
}

strats = {"linear": lambda x: x, "log": lambda x: math.log(x + 1), "exp": lambda x: x**2}


@lru_cache(maxsize=5)
def soft_dist(n):
    return [1 / n] * n


@lru_cache(maxsize=5)
def n_dist(n: int, strategy: str) -> List[float]:
    """Weight distribution over the n-gram orders 1..n, normalized after applying the given growth strategy."""
    ns = list(range(1, n + 1))
    xs = list(map(strats[strategy], ns))
    result = list(map(lambda x: x / sum(xs), xs))
    return result

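# For example, n_dist(3, "linear") maps the orders [1, 2, 3] to [1/6, 2/6, 3/6],
# so higher n-gram orders receive proportionally more weight.
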
def soft_n_hot(
    input,
    num_classes: int,
    strategy: Optional[str],
):

    shape = list(input.size())[1:]

    shape.append(num_classes)

    ret = torch.zeros(shape).to(input.device)

    if strategy:
        soft_labels = n_dist(input.size(0), strategy)
    else:
        soft_labels = [1] * input.size(0)

    for i, t in enumerate(input):
        ret.scatter_(-1, t.unsqueeze(-1), soft_labels[i])

    return ret


def n_hot(t, num_classes, ngram_sequences: Optional[torch.Tensor] = None, unk_idx: Optional[int] = None):
    # ngram_sequences may be a stacked tensor of shape [n-gram, batch, seq] or a
    # list of [batch, seq] tensors, one per higher n-gram order.

    shape = list(t.size())

    if ngram_sequences is not None:
        shape.append(num_classes)
        ret = torch.zeros(shape).to(t.device)
        ret.scatter_(-1, t.unsqueeze(-1), 1)
        for seq in ngram_sequences:
            if unk_idx is not None:
                # Fall back to the unigram token wherever the n-gram id is unknown
                # (note: this overwrites `seq` in place).
                mask = torch.eq(seq, unk_idx)
                seq[mask] = t[mask]
            ret.scatter_(-1, seq.unsqueeze(-1), 1)
        return ret

    elif len(shape) == 2:
        return F.one_hot(t, num_classes=num_classes).float()
    else:
        shape = shape[1:]
        shape.append(num_classes)
        ret = torch.zeros(shape).to(t.device)
        # Expect that the first dimension indexes the n-gram orders
        for seq in t:
            ret.scatter_(-1, seq.unsqueeze(-1), 1)

    return ret

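# Example (this is what TestNGME checks below): n_hot(torch.tensor([[0, 1, 2]]), 3)
# yields the 3x3 identity; additional n-gram sequences switch on further positions,
# giving an n-hot (multi-hot) encoding per time step.
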
class NGramsEmbedding(torch.nn.Embedding):
    """N-hot encoder: sums the embeddings of the unigram and all supplied n-gram ids."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[torch.Tensor] = None,
        device=None,
        dtype=None,
        unk_idx: Optional[int] = None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
            device=device,
            dtype=dtype,
        )

        self.num_classes = num_embeddings
        self.unk_idx = unk_idx

    def forward(self, input: torch.Tensor, ngram_sequences: Optional[torch.Tensor] = None):
        # ngram_sequences: ids of the higher n-gram orders, either a list of
        # [batch, seq] tensors or a stacked [n-gram, batch, seq] tensor.
        return self._forward(
            n_hot(input, self.num_classes, ngram_sequences, self.unk_idx)
        )

    def _forward(self, n_hot: torch.Tensor) -> torch.Tensor:
        return F.linear(n_hot, self.weight.t())

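# Usage sketch (as exercised by TestNGramEmbeddings below): with an identity weight
# matrix and distinct ids, emb(unigrams, [bigrams]) equals emb(unigrams) + emb(bigrams),
# i.e. each position's output is the sum of the embeddings of its n-gram ids.
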
def collect_n_gram_sequences(**kwargs) -> List[torch.Tensor]:
    # Collect gram_2_sequence, gram_3_sequence, ... stopping at the first order
    # whose value is None.
    sequences = []
    for n in range(2, len(kwargs) + 2):
        s = kwargs[f"gram_{n}_sequence"]
        if s is not None:
            sequences.append(s)
        else:
            break

    return sequences


def shift_with_pad(target_tensor, n, from_tensor):
    # Shift the n-gram target sequence left by n and fill the tail from the
    # unigram tensor, so the result aligns with the unigram targets.
    shifted = target_tensor[:, n:]

    seq_size = target_tensor.size(1) - 1

    missing_idxs = torch.arange(seq_size - (n - 1), seq_size).to(target_tensor.device)

    # Pad with missing idxs from unigram tensor
    shifted = torch.concat(
        (shifted, from_tensor.index_select(1, missing_idxs)), dim=1
    )

    return shifted

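# Worked example for shift_with_pad (checked in TestShifting below):
# shift_with_pad(torch.tensor([[0, 1, 2, 3, 4]]), 2, torch.tensor([[-4, -3, -2, -1]]))
# -> tensor([[2, 3, 4, -1]])
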
class TestNGME(unittest.TestCase):

    def test_one_hot(self):
        t = torch.tensor([[0, 1, 2]])
        ret = n_hot(t, 3)
        expected = torch.eye(3)
        assert torch.all(torch.eq(ret, expected))

    def test_multi_hot1(self):
        t = torch.tensor([[0, 1, 2]])
        # [ngram, batch, seq]
        two_grams = torch.tensor([[[0, 1, 2]]])
        ret = n_hot(t, 3, two_grams)
        expected = torch.eye(3)
        assert torch.all(torch.eq(ret, expected))

    def test_multi_hot2(self):
        t = torch.tensor([[0, 1, 2]])
        two_three_grams = torch.tensor([[[1, 2, 0]], [[2, 0, 1]]])
        ret = n_hot(t, 3, two_three_grams)
        expected = torch.ones(3, 3)
        assert torch.all(torch.eq(ret, expected))


class TestShifting(unittest.TestCase):

    def test_two_gram(self):
        two_gram_batch = torch.tensor([[0, 1, 2, 3, 4]])
        from_tensor = torch.tensor([[-4, -3, -2, -1]])
        shifted = shift_with_pad(two_gram_batch, 2, from_tensor)
        expected = torch.tensor([[2, 3, 4, -1]])
        assert torch.all(torch.eq(shifted, expected))

    def test_three_gram(self):
        three_gram_batch = torch.tensor([[0, 1, 2, 3, 4, 5, 6]])
        from_tensor = torch.tensor([[-6, -5, -4, -3, -2, -1]])
        shifted = shift_with_pad(three_gram_batch, 3, from_tensor)
        expected = torch.tensor([[3, 4, 5, 6, -2, -1]])
        assert torch.all(torch.eq(shifted, expected))

    def test_three_gram_2(self):
        three_gram_batch = torch.tensor([[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]])
        from_tensor = torch.tensor([[-6, -5, -4, -3, -2, -1], [-6, -5, -4, -3, -2, -1]])
        shifted = shift_with_pad(three_gram_batch, 3, from_tensor)
        expected = torch.tensor([[3, 4, 5, 6, -2, -1], [3, 4, 5, 6, -2, -1]])
        assert torch.all(torch.eq(shifted, expected))

class TestNGramEmbeddings(unittest.TestCase):

    def test(self):
        emb = NGramsEmbedding(10, 10)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        emb1 = emb(torch.tensor([[1, 2, 3]]))
        emb2 = emb(torch.tensor([[4, 5, 6]]))
        emb3 = emb(torch.tensor([[1, 2, 3]]), [torch.tensor([[4, 5, 6]])])

        assert torch.all(torch.eq(torch.add(emb1, emb2), emb3))

    def test_2(self):
        emb = NGramsEmbedding(10, 10)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        emb1 = emb(torch.tensor([[1, 2, 3]]))

        emb2 = emb(torch.tensor([[1, 2, 3]]))

        emb3 = emb(torch.tensor([[1, 2, 3]]), [torch.tensor([[1, 2, 3]])])

        assert torch.all(torch.eq(emb1, emb3))
        assert torch.all(torch.eq(emb2, emb3))

    def test_3_gram(self):
        emb = NGramsEmbedding(10, 10)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        emb1 = emb(torch.tensor([[1, 2, 3]]))
        emb2 = emb(torch.tensor([[4, 5, 6]]))
        emb3 = emb(torch.tensor([[7, 8, 9]]))

        emb4 = emb(torch.tensor([[1, 2, 3]]), [torch.tensor([[4, 5, 6]]), torch.tensor([[7, 8, 9]])])

        assert torch.all(torch.eq(torch.add(torch.add(emb1, emb2), emb3), emb4))

    def test_ignore_indx(self):

        emb = NGramsEmbedding(10, 10, unk_idx=0)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        unigram = torch.tensor([[1, 2, 3]])
        bigram = torch.tensor([[0, 0, 0]])

        emb1 = emb(unigram, [bigram])
        emb2 = emb(unigram)

        assert torch.all(torch.eq(emb1, emb2))

    def test_ignore_indx_2(self):

        emb = NGramsEmbedding(10, 10, unk_idx=0)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        unigram = torch.tensor([[0, 2, 3]])
        bigram = torch.tensor([[0, 0, 0]])

        emb1 = emb(unigram, [bigram])
        emb2 = emb(unigram)

        assert torch.all(torch.eq(emb1, emb2))


if __name__ == '__main__':
    unittest.main()