ntt123 commited on
Commit
994c4fd
1 Parent(s): 0f03376

Create script.js

Browse files
Files changed (1) hide show
  1. script.js +222 -0
script.js ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ var log = console.log;
2
+ var ctx = null;
3
+ var canvas = null;
4
+ var RNN_SIZE = 512;
5
+ var cur_run = 0;
6
+
7
+ var randn = function() {
8
+ // Standard Normal random variable using Box-Muller transform.
9
+ var u = Math.random() * 0.999 + 1e-5;
10
+ var v = Math.random() * 0.999 + 1e-5;
11
+ return Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v);
12
+ }
13
+
14
+ var rand_truncated_normal = function(low, high) {
15
+ while (true) {
16
+ r = randn();
17
+ if (r >= low && r <= high)
18
+ break;
19
+ // rejection sampling.
20
+ }
21
+ return r;
22
+ }
23
+
24
+
25
+
26
+ var char2idx = {'\x00': 0, ' ': 1, '!': 2, '"': 3, '#': 4, "'": 5, '(': 6, ')': 7, ',': 8, '-': 9, '.': 10, '0': 11, '1': 12, '2': 13, '3': 14, '4': 15, '5': 16, '6': 17, '7': 18, '8': 19, '9': 20, ':': 21, ';': 22, '?': 23, 'A': 24, 'B': 25, 'C': 26, 'D': 27, 'E': 28, 'F': 29, 'G': 30, 'H': 31, 'I': 32, 'J': 33, 'K': 34, 'L': 35, 'M': 36, 'N': 37, 'O': 38, 'P': 39, 'R': 40, 'S': 41, 'T': 42, 'U': 43, 'V': 44, 'W': 45, 'Y': 46, 'a': 47, 'b': 48, 'c': 49, 'd': 50, 'e': 51, 'f': 52, 'g': 53, 'h': 54, 'i': 55, 'j': 56, 'k': 57, 'l': 58, 'm': 59, 'n': 60, 'o': 61, 'p': 62, 'q': 63, 'r': 64, 's': 65, 't': 66, 'u': 67, 'v': 68, 'w': 69, 'x': 70, 'y': 71, 'z': 72};
27
+
28
+ var gru_core = function(input, weights, state, hidden_size) {
29
+ var [w_h,w_i,b] = weights;
30
+ var [w_h_z,w_h_a] = tf.split(w_h, [2 * hidden_size, hidden_size], 1);
31
+ var [b_z,b_a] = tf.split(b, [2 * hidden_size, hidden_size], 0);
32
+ gates_x = tf.matMul(input, w_i);
33
+ [zr_x,a_x] = tf.split(gates_x, [2 * hidden_size, hidden_size], 1);
34
+ zr_h = tf.matMul(state, w_h_z);
35
+ zr = tf.add(tf.add(zr_x, zr_h), b_z);
36
+ // fix this
37
+ [z,r] = tf.split(tf.sigmoid(zr), 2, 1);
38
+ a_h = tf.matMul(tf.mul(r, state), w_h_a);
39
+ a = tf.tanh(tf.add(tf.add(a_x, a_h), b_a));
40
+ next_state = tf.add(tf.mul(tf.sub(1., z), state), tf.mul(z, a));
41
+ return [next_state, next_state];
42
+ };
43
+
44
+
45
+ var generate = function() {
46
+ cur_run = cur_run + 1;
47
+ setTimeout(function() {
48
+ var counter = 2000;
49
+ tf.disposeVariables();
50
+
51
+ tf.engine().startScope();
52
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
53
+ ctx.beginPath();
54
+ dojob(cur_run);
55
+ }, 200);
56
+
57
+ return false;
58
+ }
59
+
60
+ var dojob = function(run_id) {
61
+ var text = document.getElementById("user-input").value;
62
+ if (text.length == 0) {
63
+ text = "The quick brown fox jumps over the lazy dog";
64
+ }
65
+
66
+ var cur_x = 50.;
67
+ var cur_y = 300.;
68
+
69
+
70
+ log(text);
71
+ original_text = text;
72
+ text = '' + text + ' ' + text;
73
+
74
+ text = Array.from(text).map(function(e) {
75
+ return char2idx[e]
76
+ })
77
+ var text_embed = WEIGHTS['rnn/~/embed_1__embeddings'];
78
+ indices = tf.tensor1d(text, 'int32');
79
+ text = text_embed.gather(indices);
80
+
81
+ filter = WEIGHTS['rnn/~/conv1_d__w'];
82
+ embed = tf.conv1d(text, filter, 1, 'same');
83
+ bias = tf.expandDims(WEIGHTS['rnn/~/conv1_d__b'], 0);
84
+ embed = tf.add(embed, bias);
85
+
86
+ var writer_embed = WEIGHTS['rnn/~/embed__embeddings'];
87
+ var e = document.getElementById("writers");
88
+ var wid = parseInt(e.value);
89
+ // log(wid);
90
+
91
+ wid = tf.tensor1d([wid], 'int32');
92
+ wid = writer_embed.gather(wid);
93
+ embed = tf.add(wid, embed);
94
+
95
+ // initial state
96
+ var gru0_hx = tf.zeros([1, RNN_SIZE]);
97
+ var gru1_hx = tf.zeros([1, RNN_SIZE]);
98
+ // var gru2_hx = tf.zeros([1, RNN_SIZE]);
99
+
100
+ var att_location = tf.zeros([1, 1]);
101
+ var att_context = tf.zeros([1, 73]);
102
+
103
+ var input = tf.tensor([[0., 0., 1.]]);
104
+
105
+ gru0_w_h = WEIGHTS['rnn/~/lstm_attention_core/~/gru__w_h'];
106
+ gru0_w_i = WEIGHTS['rnn/~/lstm_attention_core/~/gru__w_i'];
107
+ gru0_bias = WEIGHTS['rnn/~/lstm_attention_core/~/gru__b'];
108
+
109
+ gru1_w_h = WEIGHTS['rnn/~/lstm_attention_core/~/gru_1__w_h'];
110
+ gru1_w_i = WEIGHTS['rnn/~/lstm_attention_core/~/gru_1__w_i'];
111
+ gru1_bias = WEIGHTS['rnn/~/lstm_attention_core/~/gru_1__b'];
112
+ att_w = WEIGHTS['rnn/~/lstm_attention_core/~/linear__w'];
113
+ att_b = WEIGHTS['rnn/~/lstm_attention_core/~/linear__b'];
114
+ gmm_w = WEIGHTS['rnn/~/linear__w'];
115
+ gmm_b = WEIGHTS['rnn/~/linear__b'];
116
+
117
+ ruler = tf.tensor([...Array(text.shape[0]).keys()]);
118
+ var bias = parseInt(document.getElementById("bias").value) / 100 * 3;
119
+
120
+ cur_x = 50.;
121
+ cur_y = 400.;
122
+ var path = [];
123
+ var dx = 0.;
124
+ var dy = 0;
125
+ var eos = 1.;
126
+ var counter = 0;
127
+
128
+
129
+ function loop(my_run_id) {
130
+ if (my_run_id < cur_run) {
131
+ tf.disposeVariables();
132
+ tf.engine().endScope();
133
+ return;
134
+ }
135
+
136
+ counter++;
137
+ if (counter < 2000) {
138
+ [att_location,att_context,gru0_hx,gru1_hx,input] = tf.tidy(function() {
139
+ // Attention
140
+ const inp_0 = tf.concat([att_context, input], 1);
141
+ gru0_hx_ = gru0_hx;
142
+ [out_0,gru0_hx] = gru_core(inp_0, [gru0_w_h, gru0_w_i, gru0_bias], gru0_hx, RNN_SIZE);
143
+ tf.dispose(gru0_hx_);
144
+ const att_inp = tf.concat([att_context, input, out_0], 1);
145
+ const att_params = tf.add(tf.matMul(att_inp, att_w), att_b);
146
+ [alpha,beta,kappa] = tf.split(tf.softplus(att_params), 3, 1);
147
+ att_location_ = att_location;
148
+ att_location = tf.add(att_location, tf.div(kappa, 25.));
149
+ tf.dispose(att_location_)
150
+
151
+ const phi = tf.mul(alpha, tf.exp(tf.div(tf.neg(tf.square(tf.sub(att_location, ruler))), beta)));
152
+ att_context_ = att_context;
153
+ att_context = tf.sum(tf.mul(tf.expandDims(phi, 2), tf.expandDims(embed, 0)), 1)
154
+ tf.dispose(att_context_);
155
+
156
+ const inp_1 = tf.concat([input, out_0, att_context], 1);
157
+ tf.dispose(input);
158
+ gru1_hx_ = gru1_hx;
159
+ [out_1,gru1_hx] = gru_core(inp_1, [gru1_w_h, gru1_w_i, gru1_bias], gru1_hx, RNN_SIZE);
160
+ tf.dispose(gru1_hx_);
161
+
162
+ // GMM
163
+ const gmm_params = tf.add(tf.matMul(out_1, gmm_w), gmm_b);
164
+ [x,y,logstdx,logstdy,angle,log_weight,eos_logit] = tf.split(gmm_params, [5, 5, 5, 5, 5, 5, 1], 1);
165
+ // log_weight = tf.softmax(log_weight, 1);
166
+ // log_weight = tf.log(log_weight);
167
+ // log_weight = tf.mul(log_weight, 1. + bias);
168
+ // const idx = tf.multinomial(log_weight, 1).dataSync()[0];
169
+ // log_weight = tf.softmax(log_weight, 1);
170
+ // log_weight = tf.log(log_weight);
171
+ // log_weight = tf.mul(log_weight, 1. + bias);
172
+ const idx = tf.argMax(log_weight, 1).dataSync()[0];
173
+
174
+ x = x.dataSync()[idx];
175
+ y = y.dataSync()[idx];
176
+ const stdx = tf.exp(tf.sub(logstdx, bias)).dataSync()[idx];
177
+ const stdy = tf.exp(tf.sub(logstdy, bias)).dataSync()[idx];
178
+ angle = angle.dataSync()[idx];
179
+ e = tf.sigmoid(tf.mul(eos_logit, (1. + 0.*bias))).dataSync()[0];
180
+ const rx = rand_truncated_normal(-5, 5) * stdx;
181
+ const ry = rand_truncated_normal(-5, 5) * stdy;
182
+ x = x + Math.cos(-angle) * rx - Math.sin(-angle) * ry;
183
+ y = y + Math.sin(-angle) * rx + Math.cos(-angle) * ry;
184
+ if (Math.random() < e) {
185
+ e = 1.;
186
+ } else {
187
+ e = 0.;
188
+ }
189
+ input = tf.tensor([[x, y, e]]);
190
+ return [att_location, att_context, gru0_hx, gru1_hx, input];
191
+ });
192
+
193
+ [dx,dy,eos_] = input.dataSync();
194
+ dy = -dy * 3;
195
+ dx = dx * 3;
196
+ if (eos == 0.) {
197
+ ctx.beginPath();
198
+ ctx.moveTo(cur_x, cur_y, 0, 0);
199
+ ctx.lineTo(cur_x + dx, cur_y + dy);
200
+ ctx.stroke();
201
+ }
202
+ eos = eos_;
203
+ cur_x = cur_x + dx;
204
+ cur_y = cur_y + dy;
205
+
206
+ if (att_location.dataSync()[0] < original_text.length + 2) {
207
+ setTimeout(function() {loop(my_run_id);}, 0);
208
+ }
209
+ }
210
+ }
211
+
212
+ loop(run_id);
213
+ }
214
+
215
+ window.onload = function(e) {
216
+ //Setting up canvas
217
+ canvas = document.getElementById("hw-canvas");
218
+ ctx = canvas.getContext("2d");
219
+ ctx.canvas.width = window.innerWidth;
220
+ ctx.canvas.height = window.innerHeight;
221
+
222
+ }