PatrickHaller committed
Commit af87020
1 Parent(s): b117cb6

Upload ngme.py

Files changed (1): ngme.py +265 -0
ngme.py ADDED
@@ -0,0 +1,265 @@
import math
from typing import Optional, List
from functools import lru_cache
import unittest

import torch
import torch.nn.functional as F

n_dists = {
    0: [1],
    1: [0.4, 0.6],
    2: [0.2, 0.3, 0.5],
    3: [0.1, 0.2, 0.3, 0.4],
    4: [0.1, 0.15, 0.2, 0.25, 0.3],
}

strats = {"linear": lambda x: x, "log": lambda x: math.log(x + 1), "exp": lambda x: x**2}


@lru_cache(maxsize=5)
def soft_dist(n):
    """Uniform weight for each of the n n-gram orders."""
    return [1 / n] * n


@lru_cache(maxsize=5)
def n_dist(n: int, strategy: str) -> List[float]:
    """Normalized n-gram order weights according to the chosen strategy ("linear", "log" or "exp")."""
    ns = list(range(1, n + 1))
    xs = list(map(strats[strategy], ns))
    return [x / sum(xs) for x in xs]

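# Illustrative values for the helpers above (for reference):
#   soft_dist(4)        -> [0.25, 0.25, 0.25, 0.25]
#   n_dist(3, "linear") -> [1/6, 2/6, 3/6] ~= [0.167, 0.333, 0.500]
#   n_dist(3, "log")    ~= [0.218, 0.346, 0.436]
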

def soft_n_hot(
    input,
    num_classes: int,
    strategy: Optional[str],
):
    """Encode a stacked n-gram target tensor as a weighted ("soft") n-hot tensor.

    `input` is expected to have the n-gram order as its first dimension; the
    returned tensor has shape input.shape[1:] + [num_classes].
    """
    shape = list(input.size())[1:]
    shape.append(num_classes)

    ret = torch.zeros(shape).to(input.device)

    if strategy:
        soft_labels = n_dist(input.size(0), strategy)
    else:
        soft_labels = [1] * input.size(0)

    # Each n-gram order writes its weight at its target indices. Note that
    # scatter_ overwrites, so orders hitting the same class index do not sum.
    for i, t in enumerate(input):
        ret.scatter_(-1, t.unsqueeze(-1), soft_labels[i])

    return ret


def n_hot(t, num_classes, ngram_sequences: Optional[torch.Tensor] = None, unk_idx: Optional[int] = None):
    """Encode token ids as an n-hot tensor over `num_classes` classes."""
    shape = list(t.size())

    if ngram_sequences is not None:
        shape.append(num_classes)
        ret = torch.zeros(shape).to(t.device)
        ret.scatter_(-1, t.unsqueeze(-1), 1)
        for seq in ngram_sequences:
            if unk_idx is not None:
                # Fall back to the unigram ids wherever the n-gram sequence is <unk>.
                # Note: this writes into `seq` in place.
                mask = torch.eq(seq, unk_idx)
                seq[mask] = t[mask]
            ret.scatter_(-1, seq.unsqueeze(-1), 1)
        return ret

    elif len(shape) == 2:
        return F.one_hot(t, num_classes=num_classes).float()
    else:
        shape = shape[1:]
        shape.append(num_classes)
        ret = torch.zeros(shape).to(t.device)
        # Expect that the first dimension indexes the n-gram orders
        for seq in t:
            ret.scatter_(-1, seq.unsqueeze(-1), 1)

        return ret

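# Sketch of what n_hot produces (values mirror TestNGME below):
#   n_hot(torch.tensor([[0, 1, 2]]), 3) compares equal to torch.eye(3) (plain one-hot);
#   passing ngram_sequences = torch.tensor([[[1, 2, 0]], [[2, 0, 1]]]) in addition
#   sets those indices too, so the result compares equal to torch.ones(3, 3).
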

class NGramsEmbedding(torch.nn.Embedding):
    """Embedding layer that sums the embeddings of a token and its n-gram contexts (n-hot lookup)."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[torch.Tensor] = None,
        device=None,
        dtype=None,
        unk_idx: Optional[int] = None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
            device=device,
            dtype=dtype,
        )

        self.num_classes = num_embeddings
        self.unk_idx = unk_idx

    def forward(self, input: torch.Tensor, ngram_sequences: Optional[List[torch.Tensor]] = None):
        return self._forward(
            n_hot(input, self.num_classes, ngram_sequences, self.unk_idx)
        )

    def _forward(self, n_hot: torch.Tensor) -> torch.Tensor:
        # Equivalent to summing the embedding rows selected by the n-hot vector.
        return F.linear(n_hot, self.weight.t())

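# Minimal usage sketch for NGramsEmbedding (tensor values are hypothetical):
#   emb = NGramsEmbedding(num_embeddings=10, embedding_dim=8, unk_idx=0)
#   unigrams = torch.tensor([[1, 2, 3]])            # [batch, seq]
#   bigrams = torch.tensor([[4, 5, 6]])             # same shape
#   out = emb(unigrams, ngram_sequences=[bigrams])  # [batch, seq, embedding_dim]
# Since the lookup is linear in the n-hot vector, out equals
# emb(unigrams) + emb(bigrams) for distinct indices, as the tests below check.
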

def collect_n_gram_sequences(**kwargs) -> List[torch.Tensor]:
    """Collect gram_{n}_sequence keyword arguments (n = 2, 3, ...) until the first one that is None."""
    sequences = []
    for n in range(2, len(kwargs) + 2):
        s = kwargs[f"gram_{n}_sequence"]
        if s is not None:
            sequences.append(s)
        else:
            break

    return sequences

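# Example call (keyword names follow the f-string above; t2 and t3 are hypothetical tensors):
#   collect_n_gram_sequences(gram_2_sequence=t2, gram_3_sequence=t3, gram_4_sequence=None)
#   -> [t2, t3]
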

def shift_with_pad(target_tensor, n, from_tensor):
    """Drop the first n positions of `target_tensor` and pad the tail from the trailing positions of `from_tensor`."""
    shifted = target_tensor[:, n:]

    seq_size = target_tensor.size(1) - 1

    missing_idxs = torch.arange(seq_size - (n - 1), seq_size).to(target_tensor.device)

    # Pad with missing idxs from the unigram tensor
    shifted = torch.concat(
        (shifted, from_tensor.index_select(1, missing_idxs)), dim=1
    )

    return shifted

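# Illustrative behaviour (mirrors TestShifting below):
#   shift_with_pad(torch.tensor([[0, 1, 2, 3, 4]]), 2, torch.tensor([[-4, -3, -2, -1]]))
#   -> tensor([[2, 3, 4, -1]])
# The sequence is shifted left by n and the tail is refilled from the trailing
# positions of `from_tensor`.
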

class TestNGME(unittest.TestCase):

    def test_one_hot(self):
        t = torch.tensor([[0, 1, 2]])
        ret = n_hot(t, 3)
        expected = torch.eye(3)
        assert torch.all(torch.eq(ret, expected))

    def test_multi_hot1(self):
        t = torch.tensor([[0, 1, 2]])
        # [ngram, batch, seq]
        two_grams = torch.tensor([[[0, 1, 2]]])
        ret = n_hot(t, 3, two_grams)
        expected = torch.eye(3)
        assert torch.all(torch.eq(ret, expected))

    def test_multi_hot2(self):
        t = torch.tensor([[0, 1, 2]])
        two_three_grams = torch.tensor([[[1, 2, 0]], [[2, 0, 1]]])
        ret = n_hot(t, 3, two_three_grams)
        expected = torch.ones(3, 3)
        assert torch.all(torch.eq(ret, expected))


class TestShifting(unittest.TestCase):

    def test_two_gram(self):
        two_gram_batch = torch.tensor([[0, 1, 2, 3, 4]])
        from_tensor = torch.tensor([[-4, -3, -2, -1]])
        shifted = shift_with_pad(two_gram_batch, 2, from_tensor)
        expected = torch.tensor([[2, 3, 4, -1]])
        assert torch.all(torch.eq(shifted, expected))

    def test_three_gram(self):
        three_gram_batch = torch.tensor([[0, 1, 2, 3, 4, 5, 6]])
        from_tensor = torch.tensor([[-6, -5, -4, -3, -2, -1]])
        shifted = shift_with_pad(three_gram_batch, 3, from_tensor)
        expected = torch.tensor([[3, 4, 5, 6, -2, -1]])
        assert torch.all(torch.eq(shifted, expected))

    def test_three_gram_2(self):
        three_gram_batch = torch.tensor([[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]])
        from_tensor = torch.tensor([[-6, -5, -4, -3, -2, -1], [-6, -5, -4, -3, -2, -1]])
        shifted = shift_with_pad(three_gram_batch, 3, from_tensor)
        expected = torch.tensor([[3, 4, 5, 6, -2, -1], [3, 4, 5, 6, -2, -1]])
        assert torch.all(torch.eq(shifted, expected))


class TestNGramEmbeddings(unittest.TestCase):

    def test(self):
        emb = NGramsEmbedding(10, 10)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        emb1 = emb(torch.tensor([[1, 2, 3]]))
        emb2 = emb(torch.tensor([[4, 5, 6]]))
        emb3 = emb(torch.tensor([[1, 2, 3]]), [torch.tensor([[4, 5, 6]])])

        assert torch.all(torch.eq(torch.add(emb1, emb2), emb3))

    def test_2(self):
        emb = NGramsEmbedding(10, 10)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        emb1 = emb(torch.tensor([[1, 2, 3]]))
        emb2 = emb(torch.tensor([[1, 2, 3]]))
        emb3 = emb(torch.tensor([[1, 2, 3]]), [torch.tensor([[1, 2, 3]])])

        assert torch.all(torch.eq(emb1, emb3))
        assert torch.all(torch.eq(emb2, emb3))

    def test_3_gram(self):
        emb = NGramsEmbedding(10, 10)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        emb1 = emb(torch.tensor([[1, 2, 3]]))
        emb2 = emb(torch.tensor([[4, 5, 6]]))
        emb3 = emb(torch.tensor([[7, 8, 9]]))

        emb4 = emb(torch.tensor([[1, 2, 3]]), [torch.tensor([[4, 5, 6]]), torch.tensor([[7, 8, 9]])])

        assert torch.all(torch.eq(torch.add(torch.add(emb1, emb2), emb3), emb4))

    def test_ignore_indx(self):
        emb = NGramsEmbedding(10, 10, unk_idx=0)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        unigram = torch.tensor([[1, 2, 3]])
        bigram = torch.tensor([[0, 0, 0]])

        emb1 = emb(unigram, [bigram])
        emb2 = emb(unigram)

        assert torch.all(torch.eq(emb1, emb2))

    def test_ignore_indx_2(self):
        emb = NGramsEmbedding(10, 10, unk_idx=0)
        emb.weight = torch.nn.Parameter(torch.eye(10))

        unigram = torch.tensor([[0, 2, 3]])
        bigram = torch.tensor([[0, 0, 0]])

        emb1 = emb(unigram, [bigram])
        emb2 = emb(unigram)

        assert torch.all(torch.eq(emb1, emb2))


if __name__ == '__main__':
    unittest.main()