ntt123 commited on
Commit
6c6b947
1 Parent(s): f4d35b6

Create script.js

Browse files
Files changed (1) hide show
  1. script.js +243 -0
script.js ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ var log = console.log;
2
+ var ctx = null;
3
+ var canvas = null;
4
+ var RNN_SIZE = 400;
5
+ var VOCAB_SIZE = 165;
6
+ var NUM_ATT_HEADS=10;
7
+ var NUM_GMM_HEADS=20;
8
+ var cur_run = 0;
9
+ var scale_factor = 1.;
10
+
11
+ var randn = function() {
12
+ // Standard Normal random variable using Box-Muller transform.
13
+ var u = Math.random() * 0.999 + 1e-5;
14
+ var v = Math.random() * 0.999 + 1e-5;
15
+ return Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v);
16
+ }
17
+
18
+ var rand_truncated_normal = function(low, high) {
19
+ while (true) {
20
+ r = randn();
21
+ if (r >= low && r <= high)
22
+ break;
23
+ // rejection sampling.
24
+ }
25
+ return r;
26
+ }
27
+
28
+
29
+
30
+ var char2idx = {'\x00': 0, ' ': 1, '!': 2, '"': 3, '#': 4, '%': 5, '&': 6, "'": 7, '(': 8, ')': 9, '*': 10, ',': 11, '-': 12, '.': 13, '/': 14, '0': 15, '1': 16, '2': 17, '3': 18, '4': 19, '5': 20, '6': 21, '7': 22, '8': 23, '9': 24, ':': 25, ';': 26, '?': 27, 'A': 28, 'B': 29, 'C': 30, 'D': 31, 'E': 32, 'F': 33, 'G': 34, 'H': 35, 'I': 36, 'J': 37, 'K': 38, 'L': 39, 'M': 40, 'N': 41, 'O': 42, 'P': 43, 'Q': 44, 'R': 45, 'S': 46, 'T': 47, 'U': 48, 'V': 49, 'W': 50, 'X': 51, 'Y': 52, 'a': 53, 'b': 54, 'c': 55, 'd': 56, 'e': 57, 'f': 58, 'g': 59, 'h': 60, 'i': 61, 'j': 62, 'k': 63, 'l': 64, 'm': 65, 'n': 66, 'o': 67, 'p': 68, 'q': 69, 'r': 70, 's': 71, 't': 72, 'u': 73, 'v': 74, 'w': 75, 'x': 76, 'y': 77, 'z': 78, 'À': 79, 'Á': 80, 'Â': 81, 'Ô': 82, 'Ú': 83, 'Ý': 84, 'à': 85, 'á': 86, 'â': 87, 'ã': 88, 'è': 89, 'é': 90, 'ê': 91, 'ì': 92, 'í': 93, 'ò': 94, 'ó': 95, 'ô': 96, 'õ': 97, 'ù': 98, 'ú': 99, 'ý': 100, 'Ă': 101, 'ă': 102, 'Đ': 103, 'đ': 104, 'ĩ': 105, 'ũ': 106, 'Ơ': 107, 'ơ': 108, 'Ư': 109, 'ư': 110, 'ạ': 111, 'Ả': 112, 'ả': 113, 'Ấ': 114, 'ấ': 115, 'Ầ': 116, 'ầ': 117, 'ẩ': 118, 'ẫ': 119, 'ậ': 120, 'ắ': 121, 'ằ': 122, 'ẳ': 123, 'ẵ': 124, 'ặ': 125, 'ẹ': 126, 'ẻ': 127, 'ẽ': 128, 'ế': 129, 'Ề': 130, 'ề': 131, 'Ể': 132, 'ể': 133, 'ễ': 134, 'Ệ': 135, 'ệ': 136, 'ỉ': 137, 'ị': 138, 'ọ': 139, 'ỏ': 140, 'Ố': 141, 'ố': 142, 'Ồ': 143, 'ồ': 144, 'ổ': 145, 'ỗ': 146, 'ộ': 147, 'ớ': 148, 'ờ': 149, 'Ở': 150, 'ở': 151, 'ỡ': 152, 'ợ': 153, 'ụ': 154, 'Ủ': 155, 'ủ': 156, 'ứ': 157, 'ừ': 158, 'ử': 159, 'ữ': 160, 'ự': 161, 'ỳ': 162, 'ỷ': 163, 'ỹ': 164};
31
+
32
+ var gru_core = function(input, weights, state, hidden_size) {
33
+ var [w_h,w_i,b] = weights;
34
+ var [w_h_z,w_h_a] = tf.split(w_h, [2 * hidden_size, hidden_size], 1);
35
+ var [b_z,b_a] = tf.split(b, [2 * hidden_size, hidden_size], 0);
36
+ gates_x = tf.matMul(input, w_i);
37
+ [zr_x,a_x] = tf.split(gates_x, [2 * hidden_size, hidden_size], 1);
38
+ zr_h = tf.matMul(state, w_h_z);
39
+ zr = tf.add(tf.add(zr_x, zr_h), b_z);
40
+ // fix this
41
+ [z,r] = tf.split(tf.sigmoid(zr), 2, 1);
42
+ a_h = tf.matMul(tf.mul(r, state), w_h_a);
43
+ a = tf.tanh(tf.add(tf.add(a_x, a_h), b_a));
44
+ next_state = tf.add(tf.mul(tf.sub(1., z), state), tf.mul(z, a));
45
+ return [next_state, next_state];
46
+ };
47
+
48
+
49
+ var generate = function() {
50
+ cur_run = cur_run + 1;
51
+ setTimeout(function() {
52
+ var counter = 2000;
53
+ tf.disposeVariables();
54
+
55
+ tf.engine().startScope();
56
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
57
+ ctx.beginPath();
58
+ dojob(cur_run);
59
+ }, 200);
60
+
61
+ return false;
62
+ }
63
+
64
+ var dojob = function(run_id) {
65
+ var text = document.getElementById("user-input").value;
66
+ if (text.length == 0) {
67
+ text = "Tất cả mọi người đều sinh ra có quyền bình đẳng.";
68
+ }
69
+
70
+
71
+ log(text);
72
+ original_text = text;
73
+ text = '' + text + ' ';
74
+
75
+ text = Array.from(text).map(function(e) {
76
+ return char2idx[e]
77
+ })
78
+ var text_embed = WEIGHTS['rnn/~/embed_1__embeddings'];
79
+ indices = tf.tensor1d(text, 'int32');
80
+ text = text_embed.gather(indices);
81
+
82
+ var embed = text;
83
+
84
+ var writer_embed = WEIGHTS['rnn/~/embed__embeddings'];
85
+ var e = document.getElementById("writers");
86
+ var wid = parseInt(e.value);
87
+ log(wid);
88
+
89
+ wid = tf.tensor1d([wid], 'int32');
90
+ wid = writer_embed.gather(wid);
91
+ embed = tf.add(wid, embed);
92
+
93
+
94
+
95
+ filter = WEIGHTS['rnn/~/conv1_d__w'];
96
+ embed = tf.conv1d(embed, filter, 1, 'same');
97
+ bias = tf.expandDims(WEIGHTS['rnn/~/conv1_d__b'], 0);
98
+ embed = tf.add(embed, bias);
99
+
100
+
101
+ // initial state
102
+ var gru0_hx = tf.zeros([1, RNN_SIZE]);
103
+ var gru1_hx = tf.zeros([1, RNN_SIZE]);
104
+ var gru2_hx = tf.zeros([1, RNN_SIZE]);
105
+
106
+ var att_location = tf.zeros([1, NUM_ATT_HEADS]);
107
+ var att_context = tf.zeros([1, VOCAB_SIZE]);
108
+
109
+ var input = tf.tensor([[0., 0., 1.]]);
110
+
111
+ gru0_w_h = WEIGHTS['rnn/~/attention_core/~/gru__w_h'];
112
+ gru0_w_i = WEIGHTS['rnn/~/attention_core/~/gru__w_i'];
113
+ gru0_bias = WEIGHTS['rnn/~/attention_core/~/gru__b'];
114
+
115
+ gru1_w_h = WEIGHTS['rnn/~/attention_core/~/gru_1__w_h'];
116
+ gru1_w_i = WEIGHTS['rnn/~/attention_core/~/gru_1__w_i'];
117
+ gru1_bias = WEIGHTS['rnn/~/attention_core/~/gru_1__b'];
118
+
119
+ gru2_w_h = WEIGHTS['rnn/~/attention_core/~/gru_2__w_h'];
120
+ gru2_w_i = WEIGHTS['rnn/~/attention_core/~/gru_2__w_i'];
121
+ gru2_bias = WEIGHTS['rnn/~/attention_core/~/gru_2__b'];
122
+
123
+ att_w = WEIGHTS['rnn/~/attention_core/~/linear__w'];
124
+ att_b = WEIGHTS['rnn/~/attention_core/~/linear__b'];
125
+ gmm_w = WEIGHTS['rnn/~/linear__w'];
126
+ gmm_b = WEIGHTS['rnn/~/linear__b'];
127
+
128
+ var ruler = tf.tensor([...Array(text.shape[0]).keys()]);
129
+ ruler = tf.expandDims(ruler, 1);
130
+ var bias = parseInt(document.getElementById("bias").value) / 100 * 3;
131
+
132
+
133
+ var cur_x = 20;
134
+ var cur_y = innerHeight / 2 + 30;
135
+ var path = [];
136
+ var dx = 0.;
137
+ var dy = 0;
138
+ var eos = 1.;
139
+ var counter = 0;
140
+
141
+
142
+ function loop(my_run_id) {
143
+ if (my_run_id < cur_run) {
144
+ tf.disposeVariables();
145
+ tf.engine().endScope();
146
+ return;
147
+ }
148
+
149
+ counter++;
150
+ if (counter < 2000) {
151
+ [att_location,att_context,gru0_hx,gru1_hx, gru2_hx, input] = tf.tidy(function() {
152
+ // Attention
153
+ const inp_0 = tf.concat([att_context, input], 1);
154
+ gru0_hx_ = gru0_hx;
155
+ [out_0,gru0_hx] = gru_core(inp_0, [gru0_w_h, gru0_w_i, gru0_bias], gru0_hx, RNN_SIZE);
156
+ tf.dispose(gru0_hx_);
157
+ const att_inp = tf.concat([att_context, input, out_0], 1);
158
+ const att_params = tf.add(tf.matMul(att_inp, att_w), att_b);
159
+ [alpha,beta,kappa] = tf.split(tf.softplus(att_params), 3, 1);
160
+ att_location_ = att_location;
161
+ att_location = tf.add(att_location, tf.div(kappa, 25.));
162
+ tf.dispose(att_location_)
163
+
164
+ var phi = tf.sum(tf.mul(alpha, tf.exp(tf.div(tf.neg(tf.square(tf.sub(att_location, ruler))), beta))), 1);
165
+ phi = tf.expandDims(phi, 0);
166
+
167
+ att_context_ = att_context;
168
+ att_context = tf.sum(tf.mul(tf.expandDims(phi, 2), tf.expandDims(embed, 0)), 1)
169
+ tf.dispose(att_context_);
170
+
171
+ const inp_1 = tf.concat([input, out_0, att_context], 1);
172
+ // tf.dispose(input);
173
+ gru1_hx_ = gru1_hx;
174
+ [out_1,gru1_hx] = gru_core(inp_1, [gru1_w_h, gru1_w_i, gru1_bias], gru1_hx, RNN_SIZE);
175
+ tf.dispose(gru1_hx_);
176
+
177
+ const inp_2 = tf.concat([input, out_1, att_context], 1);
178
+ tf.dispose(input);
179
+ gru2_hx_ = gru2_hx;
180
+ [out_2, gru2_hx] = gru_core(inp_2, [gru2_w_h, gru2_w_i, gru2_bias], gru2_hx, RNN_SIZE);
181
+ tf.dispose(gru2_hx_);
182
+
183
+ // debugger;
184
+
185
+ // GMM
186
+ const gmm_params = tf.add(tf.matMul(out_2, gmm_w), gmm_b);
187
+ [x,y,logstdx,logstdy,angle,log_weight,eos_logit] = tf.split(gmm_params, [NUM_GMM_HEADS, NUM_GMM_HEADS, NUM_GMM_HEADS, NUM_GMM_HEADS, NUM_GMM_HEADS, NUM_GMM_HEADS, 1], 1);
188
+ // log_weight = tf.softmax(log_weight, 1);
189
+ // log_weight = tf.log(log_weight);
190
+ // log_weight = tf.mul(log_weight, 1. + bias);
191
+ const idx = tf.argMax(log_weight, 1).dataSync()[0];
192
+ // const idx = tf.multinomial(log_weight, 1).dataSync()[0];
193
+ x = x.dataSync()[idx];
194
+ y = y.dataSync()[idx];
195
+ const stdx = tf.exp(tf.sub(logstdx, bias)).dataSync()[idx];
196
+ const stdy = tf.exp(tf.sub(logstdy, bias)).dataSync()[idx];
197
+ angle = angle.dataSync()[idx];
198
+ e = tf.sigmoid(tf.mul(eos_logit, (1. + bias/5))).dataSync()[0];
199
+ const rx = rand_truncated_normal(-5, 5) * stdx;
200
+ const ry = rand_truncated_normal(-5, 5) * stdy;
201
+ x = x + Math.cos(-angle) * rx - Math.sin(-angle) * ry;
202
+ y = y + Math.sin(-angle) * rx + Math.cos(-angle) * ry;
203
+ if (Math.random() < e) {
204
+ e = 1.;
205
+ } else {
206
+ e = 0.;
207
+ }
208
+ input = tf.tensor([[x, y, e]]);
209
+ return [att_location, att_context, gru0_hx, gru1_hx, gru2_hx, input];
210
+ });
211
+
212
+ [dx,dy,eos_] = input.dataSync();
213
+ dy = -dy * 3. * scale_factor;
214
+ dx = dx * 3. * scale_factor;
215
+ if (eos == 0.) {
216
+ ctx.beginPath();
217
+ ctx.moveTo(cur_x, cur_y, 0, 0);
218
+ ctx.lineTo(cur_x + dx, cur_y + dy);
219
+ ctx.stroke();
220
+ }
221
+ eos = eos_;
222
+ cur_x = cur_x + dx;
223
+ cur_y = cur_y + dy;
224
+
225
+ if (att_location.dataSync()[0] < original_text.length + 1.5) {
226
+ setTimeout(function() {loop(my_run_id);}, 0);
227
+ }
228
+ }
229
+ }
230
+
231
+ loop(run_id);
232
+ }
233
+
234
+
235
+ window.onload = function(e) {
236
+ //Setting up canvas
237
+ canvas = document.getElementById("hw-canvas");
238
+ ctx = canvas.getContext("2d");
239
+ scale_factor = window.innerWidth / 1600;
240
+ ctx.canvas.width = window.innerWidth;
241
+ ctx.canvas.height = window.innerHeight;
242
+
243
+ }