ydin0771 committed on
Commit be859e7 • 1 Parent(s): 94566d7

Upload mac_cell.py

Files changed (1):
  mac_cell.py +592 -0
mac_cell.py ADDED
@@ -0,0 +1,592 @@
import collections
import numpy as np
import tensorflow as tf

import ops
from config import config

MACCellTuple = collections.namedtuple("MACCellTuple", ("control", "memory"))

'''
The MAC cell.

A recurrent cell for multi-step reasoning, presented in https://arxiv.org/abs/1803.03067.
The cell has recurrent control and memory states that interact with the question
and the knowledge base (the image), respectively.

The hidden state structure is MACCellTuple(control, memory).

At each step the cell operates by calling three subunits: control, read and write.

1. The Control Unit computes the new control state by computing attention over the
question words. The control state represents the reasoning operation the cell
currently performs.

2. The Read Unit retrieves information from the knowledge base, given the control and
previous memory values, by computing a two-stage attention over the knowledge base.

3. The Write Unit integrates the retrieved information into the previous memory state,
given the value of the control state, to perform the current reasoning operation.
'''
class MACCell(tf.nn.rnn_cell.RNNCell):

    '''Initialize the MAC cell.
    (Note that in the current version the cell is stateful --
    it updates its own internals when called.)

    Args:
        vecQuestions: the vector representation of the questions.
        [batchSize, ctrlDim]

        questionWords: the question words embeddings.
        [batchSize, questionLength, ctrlDim]

        questionCntxWords: the encoder outputs -- the "contextual" question words.
        [batchSize, questionLength, ctrlDim]

        questionLengths: the length of each question.
        [batchSize]

        knowledgeBase: representation of the knowledge base (image).
        [batchSize, kbSize (Height * Width), memDim]

        memoryDropout: dropout on the memory state (Tensor scalar).
        readDropout: dropout inside the read unit (Tensor scalar).
        writeDropout: dropout on the new information that gets into the write unit (Tensor scalar).

        batchSize: batch size (Tensor scalar).
        train: train or test mode (Tensor boolean).
        reuse: whether to reuse the cell's variables.
    '''
    def __init__(self, vecQuestions, questionWords, questionCntxWords, questionLengths,
        knowledgeBase, memoryDropout, readDropout, writeDropout,
        batchSize, train, reuse = None):

        self.vecQuestions = vecQuestions
        self.questionWords = questionWords
        self.questionCntxWords = questionCntxWords
        self.questionLengths = questionLengths

        self.knowledgeBase = knowledgeBase

        self.dropouts = {}
        self.dropouts["memory"] = memoryDropout
        self.dropouts["read"] = readDropout
        self.dropouts["write"] = writeDropout

        self.none = tf.zeros((batchSize, 1), dtype = tf.float32)

        self.batchSize = batchSize
        self.train = train
        self.reuse = reuse

    '''
    Cell state size.
    '''
    @property
    def state_size(self):
        return MACCellTuple(config.ctrlDim, config.memDim)

    '''
    Cell output size. Currently it doesn't have any outputs.
    '''
    @property
    def output_size(self):
        return 1

    # pass encoder hidden states to control?
    '''
    The Control Unit: computes the new control state -- the reasoning operation --
    by summing up the word embeddings according to a computed attention distribution.

    The unit is recurrent: it receives the whole question and the previous control state,
    merges them together (resulting in the "continuous control"), and then uses the result
    to compute attention over the question words. Finally, it combines the words
    together according to the attention distribution to get the new control state.

    Args:
        controlInput: external inputs to the control unit (the question vector).
        [batchSize, ctrlDim]

        inWords: the representation of the words used to compute the attention.
        [batchSize, questionLength, ctrlDim]

        outWords: the representation of the words that are summed up.
        (by default inWords == outWords)
        [batchSize, questionLength, ctrlDim]

        questionLengths: the length of each question.
        [batchSize]

        control: the previous control hidden state value.
        [batchSize, ctrlDim]

        contControl: optional corresponding continuous control state
        (before casting the attention over the words).
        [batchSize, ctrlDim]

    Returns:
        the new control state
        [batchSize, ctrlDim]

        the continuous (pre-attention) control
        [batchSize, ctrlDim]
    '''
    def control(self, controlInput, inWords, outWords, questionLengths,
        control, contControl = None, name = "", reuse = None):

        with tf.variable_scope("control" + name, reuse = reuse):
            dim = config.ctrlDim

            ## Step 1: compute the "continuous" control state given the previous control and question.
            # control inputs: question and previous control
            newContControl = controlInput
            if config.controlFeedPrev:
                newContControl = control if config.controlFeedPrevAtt else contControl
                if config.controlFeedInputs:
                    newContControl = tf.concat([newContControl, controlInput], axis = -1)
                    dim += config.ctrlDim

                # merge inputs together
                newContControl = ops.linear(newContControl, dim, config.ctrlDim,
                    act = config.controlContAct, name = "contControl")
                dim = config.ctrlDim

            ## Step 2: compute attention distribution over words and sum them up accordingly.
            # compute interactions with question words
            interactions = tf.expand_dims(newContControl, axis = 1) * inWords

            # optionally concatenate words
            if config.controlConcatWords:
                interactions = tf.concat([interactions, inWords], axis = -1)
                dim += config.ctrlDim

            # optional projection
            if config.controlProj:
                interactions = ops.linear(interactions, dim, config.ctrlDim,
                    act = config.controlProjAct)
                dim = config.ctrlDim

            # compute attention distribution over words and summarize them accordingly
            logits = ops.inter2logits(interactions, dim)
            # self.interL = (interW, interb)

            # if config.controlCoverage:
            #     logits += coverageBias * coverage

            attention = tf.nn.softmax(ops.expMask(logits, questionLengths))
            self.attentions["question"].append(attention)

            # if config.controlCoverage:
            #     coverage += attention # Add logits instead?

            newControl = ops.att2Smry(attention, outWords)

            # ablation: use continuous control (pre-attention) instead
            if config.controlContinuous:
                newControl = newContControl

        return newControl, newContControl

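    # Shape walk-through for one control step (illustrative only; the sizes
    # below are hypothetical, assuming batchSize = 32, questionLength = 10,
    # config.ctrlDim = 512):
    #   controlInput           [32, 512]
    #   inWords / outWords     [32, 10, 512]
    #   interactions           [32, 10, 512]  (broadcast elementwise product)
    #   logits -> attention    [32, 10]       (masked softmax over the words)
    #   newControl             [32, 512]      (attention-weighted sum of outWords)
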
    '''
    The Read Unit extracts relevant information from the knowledge base given the
    cell's memory and control states. It computes an attention distribution over
    the knowledge base by comparing it first to the memory and then to the control.
    Finally, it uses the attention distribution to sum up the knowledge base accordingly,
    resulting in an extraction of relevant information.

    Args:
        knowledgeBase: representation of the knowledge base (image).
        [batchSize, kbSize (Height * Width), memDim]

        memory: the cell's memory state
        [batchSize, memDim]

        control: the cell's control state
        [batchSize, ctrlDim]

    Returns the extracted information.
    [batchSize, memDim]
    '''
    def read(self, knowledgeBase, memory, control, name = "", reuse = None):
        with tf.variable_scope("read" + name, reuse = reuse):
            dim = config.memDim

            ## memory dropout
            if config.memoryVariationalDropout:
                memory = ops.applyVarDpMask(memory, self.memDpMask, self.dropouts["memory"])
            else:
                memory = tf.nn.dropout(memory, self.dropouts["memory"])

            ## Step 1: knowledge base / memory interactions
            # parameters for knowledge base and memory projection
            proj = None
            if config.readProjInputs:
                proj = {"dim": config.attDim, "shared": config.readProjShared,
                    "dropout": self.dropouts["read"]}
                dim = config.attDim

            # parameters for concatenating knowledge base elements
            concat = {"x": config.readMemConcatKB, "proj": config.readMemConcatProj}

            # compute interactions between knowledge base and memory
            interactions, interDim = ops.mul(x = knowledgeBase, y = memory, dim = config.memDim,
                proj = proj, concat = concat, interMod = config.readMemAttType, name = "memInter")

            projectedKB = proj.get("x") if proj else None

            # project memory interactions back to hidden dimension
            if config.readMemProj:
                interactions = ops.linear(interactions, interDim, dim, act = config.readMemAct,
                    name = "memKbProj")
            else:
                dim = interDim

            ## Step 2: compute interactions with control
            if config.readCtrl:
                # compute interactions with control
                if config.ctrlDim != dim:
                    control = ops.linear(control, config.ctrlDim, dim, name = "ctrlProj")

                interactions, interDim = ops.mul(interactions, control, dim,
                    interMod = config.readCtrlAttType, concat = {"x": config.readCtrlConcatInter},
                    name = "ctrlInter")

                # optionally concatenate knowledge base elements
                if config.readCtrlConcatKB:
                    if config.readCtrlConcatProj:
                        addedInp, addedDim = projectedKB, config.attDim
                    else:
                        addedInp, addedDim = knowledgeBase, config.memDim
                    interactions = tf.concat([interactions, addedInp], axis = -1)
                    dim += addedDim

                # optional nonlinearity
                interactions = ops.activations[config.readCtrlAct](interactions)

            ## Step 3: sum attentions up over the knowledge base
            # transform vectors to attention distribution
            attention = ops.inter2att(interactions, dim, dropout = self.dropouts["read"])

            self.attentions["kb"].append(attention)

            # optionally use the projected knowledge base instead of the original
            if config.readSmryKBProj:
                knowledgeBase = projectedKB

            # sum up the knowledge base according to the distribution
            information = ops.att2Smry(attention, knowledgeBase)

        return information

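    # Shape walk-through for one read step (illustrative only; hypothetical
    # sizes, assuming batchSize = 32, kbSize = 14 * 14 = 196, config.memDim = 512):
    #   knowledgeBase              [32, 196, 512]
    #   stage 1: KB x memory       [32, 196, d]  (elementwise interactions)
    #   stage 2: ... x control     [32, 196, d]  (conditioned on the reasoning op)
    #   attention                  [32, 196]     (softmax over KB positions)
    #   information                [32, 512]     (attention-weighted sum of the KB)
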
    '''
    The Write Unit integrates the newly retrieved information (from the read unit)
    with the cell's previous memory hidden state, resulting in a new memory value.
    The unit optionally supports:
    1. Self-attention to previous control / memory states, in order to consider previous steps
    in the reasoning process.
    2. Gating between the new memory and previous memory states, to allow dynamic adjustment
    of the reasoning process length.

    Args:
        memory: the cell's memory state
        [batchSize, memDim]

        info: the information to integrate with the memory
        [batchSize, memDim]

        control: the cell's control state
        [batchSize, ctrlDim]

        contControl: optional corresponding continuous control state
        (before casting the attention over the words).
        [batchSize, ctrlDim]

    Returns the new memory.
    [batchSize, memDim]
    '''
    def write(self, memory, info, control, contControl = None, name = "", reuse = None):
        with tf.variable_scope("write" + name, reuse = reuse):

            # optionally project info
            if config.writeInfoProj:
                info = ops.linear(info, config.memDim, config.memDim, name = "info")

            # optional info nonlinearity
            info = ops.activations[config.writeInfoAct](info)

            # compute self-attention vector based on previous controls and memories
            if config.writeSelfAtt:
                selfControl = control
                if config.writeSelfAttMod == "CONT":
                    selfControl = contControl
                # elif config.writeSelfAttMod == "POST":
                #     selfControl = postControl
                selfControl = ops.linear(selfControl, config.ctrlDim, config.ctrlDim, name = "ctrlProj")

                interactions = self.controls * tf.expand_dims(selfControl, axis = 1)

                # if config.selfAttShareInter:
                #     selfAttlogits = self.linearP(selfAttInter, config.encDim, 1, self.interL[0], self.interL[1], name = "modSelfAttInter")
                attention = ops.inter2att(interactions, config.ctrlDim, name = "selfAttention")
                self.attentions["self"].append(attention)
                selfSmry = ops.att2Smry(attention, self.memories)

            # get write unit inputs: previous memory, the new info, optionally self-attention / control
            newMemory, dim = memory, config.memDim
            if config.writeInputs == "INFO":
                newMemory = info
            elif config.writeInputs == "SUM":
                newMemory += info
            elif config.writeInputs == "BOTH":
                newMemory, dim = ops.concat(newMemory, info, dim, mul = config.writeConcatMul)
            # else: MEM

            if config.writeSelfAtt:
                newMemory = tf.concat([newMemory, selfSmry], axis = -1)
                dim += config.memDim

            if config.writeMergeCtrl:
                newMemory = tf.concat([newMemory, control], axis = -1)
                dim += config.ctrlDim # control contributes ctrlDim features

            # project memory back to memory dimension
            if config.writeMemProj or (dim != config.memDim):
                newMemory = ops.linear(newMemory, dim, config.memDim, name = "newMemory")

            # optional memory nonlinearity
            newMemory = ops.activations[config.writeMemAct](newMemory)

            # write unit gate
            if config.writeGate:
                gateDim = config.memDim
                if config.writeGateShared:
                    gateDim = 1

                z = tf.sigmoid(ops.linear(control, config.ctrlDim, gateDim, name = "gate",
                    bias = config.writeGateBias))

                self.attentions["gate"].append(z)

                newMemory = newMemory * z + memory * (1 - z)

            # optional batch normalization
            if config.memoryBN:
                newMemory = tf.contrib.layers.batch_norm(newMemory, decay = config.bnDecay,
                    center = config.bnCenter, scale = config.bnScale,
                    is_training = self.train, updates_collections = None)

        return newMemory

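    # Note on the write gate above: with config.writeGate on, the update is
    # newMemory = z * newMemory + (1 - z) * memory, where z = sigmoid(W c + b)
    # is computed from the control state. A gate value near 0 passes the previous
    # memory through unchanged, effectively skipping the step -- this is what
    # lets the network adjust its effective number of reasoning steps per question.
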
    def memAutoEnc(self, newMemory, info, control, name = "", reuse = None):
        with tf.variable_scope("memAutoEnc" + name, reuse = reuse):
            # inputs to the auto encoder
            features = info if config.autoEncMemInputs == "INFO" else newMemory
            features = ops.linear(features, config.memDim, config.ctrlDim,
                act = config.autoEncMemAct, name = "aeMem")

            # reconstruct control
            if config.autoEncMemLoss == "CONT":
                loss = tf.reduce_mean(tf.squared_difference(control, features))
            else:
                interactions, dim = ops.mul(self.questionCntxWords, features, config.ctrlDim,
                    concat = {"x": config.autoEncMemCnct}, mulBias = config.mulBias, name = "aeMem")

                logits = ops.inter2logits(interactions, dim)
                logits = ops.expMask(logits, self.questionLengths)

                # reconstruct word attentions
                if config.autoEncMemLoss == "PROB":
                    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                        labels = self.attentions["question"][-1], logits = logits))

                # reconstruct control through word attentions
                else:
                    attention = tf.nn.softmax(logits)
                    summary = ops.att2Smry(attention, self.questionCntxWords)
                    loss = tf.reduce_mean(tf.squared_difference(control, summary))

        return loss

    '''
    Call the cell to get the new control and memory states.

    Args:
        inputs: in the current implementation the cell doesn't receive recurrent inputs
        at each iteration (the argument exists for compatibility with the RNN interface).

        state: the cell's current state (control, memory)
        MACCellTuple([batchSize, ctrlDim], [batchSize, memDim])

    Returns the new state -- the new memory and control values.
    MACCellTuple([batchSize, ctrlDim], [batchSize, memDim])
    '''
    def __call__(self, inputs, state, scope = None):
        scope = scope or type(self).__name__
        with tf.variable_scope(scope, reuse = self.reuse): # as tfscope
            control = state.control
            memory = state.memory

            # cell sharing (self.iteration is expected to be set by the caller
            # before each step, since the cell is stateful)
            inputName = "qInput"
            inputNameU = "qInputU"
            inputReuseU = inputReuse = (self.iteration > 0)
            if config.controlInputUnshared:
                inputNameU = "qInput%d" % self.iteration
                inputReuseU = None

            cellName = ""
            cellReuse = (self.iteration > 0)
            if config.unsharedCells:
                cellName = str(self.iteration)
                cellReuse = None

            ## control unit
            # prepare question input to control
            controlInput = ops.linear(self.vecQuestions, config.ctrlDim, config.ctrlDim,
                name = inputName, reuse = inputReuse)

            controlInput = ops.activations[config.controlInputAct](controlInput)

            controlInput = ops.linear(controlInput, config.ctrlDim, config.ctrlDim,
                name = inputNameU, reuse = inputReuseU)

            newControl, self.contControl = self.control(controlInput, self.inWords, self.outWords,
                self.questionLengths, control, self.contControl, name = cellName, reuse = cellReuse)

            # ablation: use the whole question as control
            if config.controlWholeQ:
                newControl = self.vecQuestions
                # ops.linear(self.vecQuestions, config.ctrlDim, projDim, name = "qMod")

            ## read unit
            info = self.read(self.knowledgeBase, memory, newControl, name = cellName, reuse = cellReuse)

            if config.writeDropout < 1.0:
                info = tf.nn.dropout(info, self.dropouts["write"])

            ## write unit
            newMemory = self.write(memory, info, newControl, self.contControl, name = cellName, reuse = cellReuse)

            # add auto encoder loss for memory
            # if config.autoEncMem:
            #     self.autoEncLosses["memory"] += self.memAutoEnc(newMemory, info, newControl)

            # append as standard list?
            self.controls = tf.concat([self.controls, tf.expand_dims(newControl, axis = 1)], axis = 1)
            self.memories = tf.concat([self.memories, tf.expand_dims(newMemory, axis = 1)], axis = 1)
            self.infos = tf.concat([self.infos, tf.expand_dims(info, axis = 1)], axis = 1)

            # self.contControls = tf.concat([self.contControls, tf.expand_dims(contControl, axis = 1)], axis = 1)
            # self.postControls = tf.concat([self.controls, tf.expand_dims(postControls, axis = 1)], axis = 1)

        newState = MACCellTuple(newControl, newMemory)
        return self.none, newState

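    # Per-step dataflow of __call__, in summary:
    #   control_i = ControlUnit(questionInput, words, control_{i-1})
    #   info_i    = ReadUnit(knowledgeBase, memory_{i-1}, control_i)
    #   memory_i  = WriteUnit(memory_{i-1}, info_i, control_i)
    # The returned output is a dummy tensor; the useful results are the states.
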
    '''
    Initializes a hidden state based on the value of initType:
    "PRM" for parametric initialization
    "ZERO" for zero initialization
    "Q" to initialize to the question vectors.

    Args:
        name: the state variable name.
        dim: the dimension of the state.
        initType: the type of the initialization.
        batchSize: the batch size.

    Returns the initialized hidden state.
    '''
    def initState(self, name, dim, initType, batchSize):
        if initType == "PRM":
            prm = tf.get_variable(name, shape = (dim, ),
                initializer = tf.random_normal_initializer())
            initState = tf.tile(tf.expand_dims(prm, axis = 0), [batchSize, 1])
        elif initType == "ZERO":
            initState = tf.zeros((batchSize, dim), dtype = tf.float32)
        else: # "Q"
            initState = self.vecQuestions
        return initState

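    # For example, config.initCtrl = "PRM" starts the control from a learned
    # vector shared across the batch, while "Q" starts it at the question
    # vector itself (the option strings name the branches handled above).
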
    '''
    Adds a parametric null word to the questions.

    Args:
        words: the words to add a null word to.
        [batchSize, questionLength]

        lengths: question lengths.
        [batchSize]

    Returns the updated word sequence and lengths.
    '''
    def addNullWord(self, words, lengths):
        nullWord = tf.get_variable("zeroWord", shape = (1, config.ctrlDim),
            initializer = tf.random_normal_initializer())
        nullWord = tf.tile(tf.expand_dims(nullWord, axis = 0), [self.batchSize, 1, 1])
        words = tf.concat([nullWord, words], axis = 1)
        lengths += 1
        return words, lengths

    '''
    Initializes the cell's internal state (the cell is currently stateful). In particular:
    1. Data structures (lists of attention maps and accumulated losses).
    2. The memory and control states.
    3. The knowledge base (optionally merging it with the question vectors).
    4. The question words used by the cell (either the original word embeddings, or the
    encoder outputs, with optional projection).

    Args:
        batchSize: the batch size

    Returns the initial cell state.
    '''
    def zero_state(self, batchSize, dtype = tf.float32):
        ## initialize data structures
        self.attentions = {"kb": [], "question": [], "self": [], "gate": []}
        self.autoEncLosses = {"control": tf.constant(0.0), "memory": tf.constant(0.0)}

        ## initialize state
        initialControl = self.initState("initCtrl", config.ctrlDim, config.initCtrl, batchSize)
        initialMemory = self.initState("initMem", config.memDim, config.initMem, batchSize)

        self.controls = tf.expand_dims(initialControl, axis = 1)
        self.memories = tf.expand_dims(initialMemory, axis = 1)
        self.infos = tf.expand_dims(initialMemory, axis = 1)

        self.contControl = initialControl
        # self.contControls = tf.expand_dims(initialControl, axis = 1)
        # self.postControls = tf.expand_dims(initialControl, axis = 1)

        ## initialize knowledge base
        # optionally merge the question into the knowledge base representation
        if config.initKBwithQ != "NON":
            iVecQuestions = ops.linear(self.vecQuestions, config.ctrlDim, config.memDim, name = "questions")

            concatMul = (config.initKBwithQ == "MUL")
            cnct, dim = ops.concat(self.knowledgeBase, iVecQuestions, config.memDim, mul = concatMul, expandY = True)
            self.knowledgeBase = ops.linear(cnct, dim, config.memDim, name = "initKB")

        ## initialize question words
        # choose the question words to work with (original embeddings or encoder outputs)
        words = self.questionCntxWords if config.controlContextual else self.questionWords

        # optionally add a parametric "null" word to all questions
        if config.addNullWord:
            words, self.questionLengths = self.addNullWord(words, self.questionLengths)

        # project words
        self.inWords = self.outWords = words
        if config.controlInWordsProj or config.controlOutWordsProj:
            pWords = ops.linear(words, config.ctrlDim, config.ctrlDim, name = "wordsProj")
            self.inWords = pWords if config.controlInWordsProj else words
            self.outWords = pWords if config.controlOutWordsProj else words

        # if config.controlCoverage:
        #     self.coverage = tf.zeros((batchSize, tf.shape(words)[1]), dtype = tf.float32)
        #     self.coverageBias = tf.get_variable("coverageBias", shape = (),
        #         initializer = config.controlCoverageBias)

        ## initialize the memory variational dropout mask
        if config.memoryVariationalDropout:
            self.memDpMask = ops.generateVarDpMask((batchSize, config.memDim), self.dropouts["memory"])

        return MACCellTuple(initialControl, initialMemory)
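
# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original upload): how a model might
# build and unroll the cell. Everything below is a hedged illustration -- the
# placeholder shapes are hypothetical, `config.netLength` (the number of
# reasoning steps) is assumed to come from the repo's config, and questionWords
# is passed for both the raw and contextual words purely for brevity. The real
# training pipeline lives elsewhere in the repository.
if __name__ == "__main__":
    batchSize = tf.placeholder(tf.int32, shape = ())
    train = tf.placeholder(tf.bool, shape = ())
    dropout = tf.placeholder(tf.float32, shape = ())

    # question / knowledge-base inputs (shapes as documented in __init__)
    vecQuestions = tf.placeholder(tf.float32, shape = (None, config.ctrlDim))
    questionWords = tf.placeholder(tf.float32, shape = (None, None, config.ctrlDim))
    questionLengths = tf.placeholder(tf.int32, shape = (None, ))
    knowledgeBase = tf.placeholder(tf.float32, shape = (None, None, config.memDim))

    cell = MACCell(vecQuestions, questionWords, questionWords, questionLengths,
        knowledgeBase, dropout, dropout, dropout, batchSize, train)

    # the cell is stateful: zero_state sets up its internals, and `iteration`
    # must be set before each step so variable sharing works as intended
    state = cell.zero_state(batchSize)
    for i in range(config.netLength):
        cell.iteration = i
        _, state = cell(None, state)

    # state.memory [batchSize, memDim] would then feed the output classifier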