danurahul commited on
Commit
ae7ea4a
1 Parent(s): 1a21884

Upload modules.py

Browse files
Files changed (1) hide show
  1. modules.py +233 -0
modules.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+
3
+ def embedding_lookup(lookup_table, x):
4
+ return tf.compat.v1.nn.embedding_lookup(lookup_table, x)
5
+
6
+
7
+ def normal_embedding_lookup(x, n_token, d_embed, d_proj, initializer,
8
+ proj_initializer, scope='normal_embed', **kwargs):
9
+ emb_scale = d_proj ** 0.5
10
+ with tf.compat.v1.variable_scope(scope):
11
+ lookup_table = tf.compat.v1.get_variable('lookup_table', [n_token, d_embed], initializer=initializer)
12
+ y = embedding_lookup(lookup_table, x)
13
+ if d_proj != d_embed:
14
+ proj_W = tf.compat.v1.get_variable('proj_W', [d_embed, d_proj], initializer=proj_initializer)
15
+ y = tf.einsum('ibe,ed->ibd', y, proj_W)
16
+ else:
17
+ proj_W = None
18
+ ret_params = [lookup_table, proj_W]
19
+ y *= emb_scale
20
+ return y, ret_params
21
+
22
+
23
+ def normal_softmax(hidden, target, n_token, params, scope='normal_softmax', **kwargs):
24
+ def _logit(x, W, b, proj):
25
+ y = x
26
+ if proj is not None:
27
+ y = tf.einsum('ibd,ed->ibe', y, proj)
28
+ return tf.einsum('ibd,nd->ibn', y, W) + b
29
+
30
+ params_W, params_projs = params[0], params[1]
31
+
32
+ with tf.compat.v1.variable_scope(scope):
33
+ softmax_b = tf.compat.v1.get_variable('bias', [n_token], initializer=tf.zeros_initializer())
34
+ output = _logit(hidden, params_W, softmax_b, params_projs)
35
+ nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output)
36
+ return nll, output
37
+
38
+
39
+ def positional_embedding(pos_seq, inv_freq, bsz=None):
40
+ sinusoid_inp = tf.einsum('i,j->ij', pos_seq, inv_freq)
41
+ pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
42
+ if bsz is not None:
43
+ return tf.tile(pos_emb[:, None, :], [1, bsz, 1])
44
+ else:
45
+ return pos_emb[:, None, :]
46
+
47
+
48
+ def positionwise_FF(inp, d_model, d_inner, dropout, kernel_initializer,
49
+ scope='ff', is_training=True):
50
+ output = inp
51
+ with tf.compat.v1.variable_scope(scope):
52
+ output = tf.keras.layers.Dense(d_inner, activation=tf.nn.relu,
53
+ kernel_initializer=kernel_initializer, name='layer_1')(inp)
54
+ output = tf.keras.layers.Dropout(dropout, name='drop_1')(output, training=is_training)
55
+ output = tf.keras.layers.Dense(d_model, activation=tf.nn.relu,
56
+ kernel_initializer=kernel_initializer, name='layer_2')(output)
57
+ output = tf.keras.layers.Dropout(dropout, name='drop_2')(output, training=is_training)
58
+ output = tf.keras.layers.LayerNormalization(axis=-1)(output + inp)
59
+ return output
60
+
61
+
62
+ def _create_mask(qlen, mlen, same_length=False):
63
+ attn_mask = tf.ones([qlen, qlen])
64
+ mask_u = tf.linalg.band_part(attn_mask, 0, -1)
65
+ mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
66
+ attn_mask_pad = tf.zeros([qlen, mlen])
67
+ ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
68
+ if same_length:
69
+ mask_l = tf.matrix_band_part(attn_mask, -1, 0)
70
+ ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
71
+ return ret
72
+
73
+
74
+ def _cache_mem(curr_out, prev_mem, mem_len=None):
75
+ if mem_len is None or prev_mem is None:
76
+ new_mem = curr_out
77
+ elif mem_len == 0:
78
+ return prev_mem
79
+ else:
80
+ new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]
81
+ return tf.stop_gradient(new_mem)
82
+
83
+
84
+ def rel_shift(x):
85
+ x_size = tf.shape(x)
86
+ x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
87
+ x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]])
88
+ x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
89
+ x = tf.reshape(x, x_size)
90
+ return x
91
+
92
+
93
+ def rel_multihead_attn(w, r, r_w_bias, r_r_bias, attn_mask, mems, d_model,
94
+ n_head, d_head, dropout, dropatt, is_training,
95
+ kernel_initializer, scope='rel_attn'):
96
+ scale = 1 / (d_head ** 0.5)
97
+ with tf.compat.v1.variable_scope(scope):
98
+ qlen = tf.shape(w)[0]
99
+ rlen = tf.shape(r)[0]
100
+ bsz = tf.shape(w)[1]
101
+
102
+ cat = tf.concat([mems, w], 0) if mems is not None and mems.shape.ndims > 1 else w
103
+
104
+ w_heads = tf.keras.layers.Dense(3 * n_head * d_head, use_bias=False,
105
+ kernel_initializer=kernel_initializer, name='qkv')(cat)
106
+ r_head_k = tf.keras.layers.Dense(n_head * d_head, use_bias=False,
107
+ kernel_initializer=kernel_initializer, name='r')(r)
108
+
109
+ w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, -1)
110
+ w_head_q = w_head_q[-qlen:]
111
+
112
+ klen = tf.shape(w_head_k)[0]
113
+
114
+ w_head_q = tf.reshape(w_head_q, [qlen, bsz, n_head, d_head])
115
+ w_head_k = tf.reshape(w_head_k, [klen, bsz, n_head, d_head])
116
+ w_head_v = tf.reshape(w_head_v, [klen, bsz, n_head, d_head])
117
+
118
+ r_head_k = tf.reshape(r_head_k, [rlen, n_head, d_head])
119
+
120
+ rw_head_q = w_head_q + r_w_bias
121
+ rr_head_q = w_head_q + r_r_bias
122
+
123
+ AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
124
+ BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)
125
+ BD = rel_shift(BD)
126
+
127
+ attn_score = (AC + BD) * scale
128
+ attn_mask_t = attn_mask[:, :, None, None]
129
+ attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
130
+
131
+ attn_prob = tf.nn.softmax(attn_score, 1)
132
+ attn_prob = tf.keras.layers.Dropout(dropatt)(attn_prob, training=is_training)
133
+
134
+ attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
135
+ size_t = tf.shape(attn_vec)
136
+ attn_vec = tf.reshape(attn_vec, [size_t[0], size_t[1], n_head * d_head])
137
+
138
+ attn_out = tf.keras.layers.Dense(d_model, use_bias=False,
139
+ kernel_initializer=kernel_initializer, name='o')(attn_vec)
140
+ attn_out = tf.keras.layers.Dropout(dropout)(attn_out, training=is_training)
141
+ output = tf.keras.layers.LayerNormalization(axis=-1)(attn_out + w)
142
+ return output
143
+
144
+
145
+ def transformer(dec_inp, target, mems, n_token, n_layer, d_model, d_embed,
146
+ n_head, d_head, d_inner, dropout, dropatt,
147
+ initializer, is_training, proj_initializer=None,
148
+ mem_len=None, cutoffs=[], div_val=1, tie_projs=[],
149
+ same_length=False, clamp_len=-1,
150
+ input_perms=None, target_perms=None, head_target=None,
151
+ untie_r=False, proj_same_dim=True,
152
+ scope='transformer'):
153
+ """
154
+ cutoffs: a list of python int. Cutoffs for adaptive softmax.
155
+ tie_projs: a list of python bools. Whether to tie the projections.
156
+ perms: a list of tensors. Each tensor should of size [len, bsz, bin_size].
157
+ Only used in the adaptive setting.
158
+ """
159
+ new_mems = []
160
+ with tf.compat.v1.variable_scope(scope):
161
+ if untie_r:
162
+ r_w_bias = tf.compat.v1.get_variable('r_w_bias', [n_layer, n_head, d_head], initializer=initializer)
163
+ r_r_bias = tf.compat.v1.get_variable('r_r_bias', [n_layer, n_head, d_head], initializer=initializer)
164
+ else:
165
+ r_w_bias = tf.compat.v1.get_variable('r_w_bias', [n_head, d_head], initializer=initializer)
166
+ r_r_bias = tf.compat.v1.get_variable('r_r_bias', [n_head, d_head], initializer=initializer)
167
+
168
+ qlen = tf.shape(dec_inp)[0]
169
+ mlen = tf.shape(mems[0])[0] if mems is not None else 0
170
+ klen = qlen + mlen
171
+
172
+ if proj_initializer is None:
173
+ proj_initializer = initializer
174
+
175
+ embeddings, shared_params = normal_embedding_lookup(
176
+ x=dec_inp,
177
+ n_token=n_token,
178
+ d_embed=d_embed,
179
+ d_proj=d_model,
180
+ initializer=initializer,
181
+ proj_initializer=proj_initializer)
182
+
183
+ attn_mask = _create_mask(qlen, mlen, same_length)
184
+
185
+ pos_seq = tf.range(klen - 1, -1, -1.0)
186
+ if clamp_len > 0:
187
+ pos_seq = tf.minimum(pos_seq, clamp_len)
188
+ inv_freq = 1 / (10000 ** (tf.range(0, d_model, 2.0) / d_model))
189
+ pos_emb = positional_embedding(pos_seq, inv_freq)
190
+
191
+ output = tf.keras.layers.Dropout(rate=dropout)(embeddings, training=is_training)
192
+ pos_emb = tf.keras.layers.Dropout(rate=dropout)(pos_emb, training=is_training)
193
+
194
+ if mems is None:
195
+ mems = [None] * n_layer
196
+
197
+ for i in range(n_layer):
198
+ # cache new mems
199
+ new_mems.append(_cache_mem(output, mems[i], mem_len))
200
+
201
+ with tf.compat.v1.variable_scope('layer_{}'.format(i)):
202
+ output = rel_multihead_attn(
203
+ w=output,
204
+ r=pos_emb,
205
+ r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
206
+ r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
207
+ attn_mask=attn_mask,
208
+ mems=mems[i],
209
+ d_model=d_model,
210
+ n_head=n_head,
211
+ d_head=d_head,
212
+ dropout=dropout,
213
+ dropatt=dropatt,
214
+ is_training=is_training,
215
+ kernel_initializer=initializer)
216
+
217
+ output = positionwise_FF(
218
+ inp=output,
219
+ d_model=d_model,
220
+ d_inner=d_inner,
221
+ dropout=dropout,
222
+ kernel_initializer=initializer,
223
+ is_training=is_training)
224
+
225
+ output = tf.keras.layers.Dropout(dropout)(output, training=is_training)
226
+
227
+ loss, logits = normal_softmax(
228
+ hidden=output,
229
+ target=target,
230
+ n_token=n_token,
231
+ params=shared_params)
232
+
233
+ return loss, logits, new_mems