Spaces:
Runtime error
Runtime error
Upload mac_cell.py
Browse files- mac_cell.py +592 -0
mac_cell.py
ADDED
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import numpy as np
|
3 |
+
import tensorflow as tf
|
4 |
+
|
5 |
+
import ops
|
6 |
+
from config import config
|
7 |
+
|
8 |
+
# Recurrent state of the MAC cell: a named pair of the control state and the memory state.
MACCellTuple = collections.namedtuple("MACCellTuple", ("control", "memory"))
|
9 |
+
|
10 |
+
'''
|
11 |
+
The MAC cell.
|
12 |
+
|
13 |
+
Recurrent cell for multi-step reasoning. Presented in https://arxiv.org/abs/1803.03067.
|
14 |
+
The cell has recurrent control and memory states that interact with the question
|
15 |
+
and knowledge base (image) respectively.
|
16 |
+
|
17 |
+
The hidden state structure is MACCellTuple(control, memory)
|
18 |
+
|
19 |
+
At each step the cell performs by calling to three subunits: control, read and write.
|
20 |
+
|
21 |
+
1. The Control Unit computes the control state by computing attention over the question words.
|
22 |
+
The control state represents the current reasoning operation the cell performs.
|
23 |
+
|
24 |
+
2. The Read Unit retrieves information from the knowledge base, given the control and previous
|
25 |
+
memory values, by computing 2-stages attention over the knowledge base.
|
26 |
+
|
27 |
+
3. The Write Unit integrates the retrieved information to the previous hidden memory state,
|
28 |
+
given the value of the control state, to perform the current reasoning operation.
|
29 |
+
'''
|
30 |
+
class MACCell(tf.nn.rnn_cell.RNNCell):
|
31 |
+
|
32 |
+
'''Initialize the MAC cell.
|
33 |
+
(Note that in the current version the cell is stateful --
|
34 |
+
updating its own internals when being called)
|
35 |
+
|
36 |
+
Args:
|
37 |
+
vecQuestions: the vector representation of the questions.
|
38 |
+
[batchSize, ctrlDim]
|
39 |
+
|
40 |
+
questionWords: the question words embeddings.
|
41 |
+
[batchSize, questionLength, ctrlDim]
|
42 |
+
|
43 |
+
questionCntxWords: the encoder outputs -- the "contextual" question words.
|
44 |
+
[batchSize, questionLength, ctrlDim]
|
45 |
+
|
46 |
+
questionLengths: the length of each question.
|
47 |
+
[batchSize]
|
48 |
+
|
49 |
+
memoryDropout: dropout on the memory state (Tensor scalar).
|
50 |
+
readDropout: dropout inside the read unit (Tensor scalar).
|
51 |
+
writeDropout: dropout on the new information that gets into the write unit (Tensor scalar).
|
52 |
+
|
53 |
+
batchSize: batch size (Tensor scalar).
|
54 |
+
train: train or test mode (Tensor boolean).
|
55 |
+
reuse: reuse cell
|
56 |
+
|
57 |
+
knowledgeBase:
|
58 |
+
'''
|
59 |
+
def __init__(self, vecQuestions, questionWords, questionCntxWords, questionLengths,
    knowledgeBase, memoryDropout, readDropout, writeDropout,
    batchSize, train, reuse = None):
    '''
    Store the per-batch inputs and settings the cell's units read later.

    Args:
        vecQuestions: question vectors (documented as [batchSize, ctrlDim]).
        questionWords: question word embeddings
            (documented as [batchSize, questionLength, ctrlDim]).
        questionCntxWords: encoder outputs for the question words
            (documented as [batchSize, questionLength, ctrlDim]).
        questionLengths: length of each question (documented as [batchSize]).
        knowledgeBase: the knowledge base representation the read unit attends over.
        memoryDropout, readDropout, writeDropout: values handed to the dropout
            calls of the memory state, the read unit and the write-unit input.
        batchSize: batch size, used to build the placeholder output.
        train: train / test flag, forwarded to batch normalization.
        reuse: reuse flag for the cell's top-level variable scope.
    '''
    # question-side inputs
    self.vecQuestions = vecQuestions
    self.questionWords = questionWords
    self.questionCntxWords = questionCntxWords
    self.questionLengths = questionLengths

    # knowledge-base-side input
    self.knowledgeBase = knowledgeBase

    # dropout values keyed by the unit that consumes them
    self.dropouts = {
        "memory": memoryDropout,
        "read": readDropout,
        "write": writeDropout,
    }

    # the cell has no real output; this zero tensor is what __call__ returns as "output"
    self.none = tf.zeros((batchSize, 1), dtype = tf.float32)

    self.batchSize = batchSize
    self.train = train
    self.reuse = reuse
|
80 |
+
|
81 |
+
'''
|
82 |
+
Cell state size.
|
83 |
+
'''
|
84 |
+
@property
def state_size(self):
    '''Sizes of the two recurrent states, as MACCellTuple(control, memory).'''
    return MACCellTuple(control = config.ctrlDim, memory = config.memDim)
|
87 |
+
|
88 |
+
'''
|
89 |
+
Cell output size. Currently it doesn't have any outputs.
|
90 |
+
'''
|
91 |
+
@property
def output_size(self):
    '''Size of the cell output: always 1, since the cell has no real outputs.'''
    placeholderWidth = 1
    return placeholderWidth
|
94 |
+
|
95 |
+
# pass encoder hidden states to control?
|
96 |
+
'''
|
97 |
+
The Control Unit: computes the new control state -- the reasoning operation,
|
98 |
+
by summing up the word embeddings according to a computed attention distribution.
|
99 |
+
|
100 |
+
The unit is recurrent: it receives the whole question and the previous control state,
|
101 |
+
merge them together (resulting in the "continuous control"), and then uses that
|
102 |
+
to compute attentions over the question words. Finally, it combines the words
|
103 |
+
together according to the attention distribution to get the new control state.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
controlInput: external inputs to control unit (the question vector).
|
107 |
+
[batchSize, ctrlDim]
|
108 |
+
|
109 |
+
inWords: the representation of the words used to compute the attention.
|
110 |
+
[batchSize, questionLength, ctrlDim]
|
111 |
+
|
112 |
+
outWords: the representation of the words that are summed up.
|
113 |
+
(by default inWords == outWords)
|
114 |
+
[batchSize, questionLength, ctrlDim]
|
115 |
+
|
116 |
+
questionLengths: the length of each question.
|
117 |
+
[batchSize]
|
118 |
+
|
119 |
+
control: the previous control hidden state value.
|
120 |
+
[batchSize, ctrlDim]
|
121 |
+
|
122 |
+
contControl: optional corresponding continuous control state
|
123 |
+
(before casting the attention over the words).
|
124 |
+
[batchSize, ctrlDim]
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
the new control state
|
128 |
+
[batchSize, ctrlDim]
|
129 |
+
|
130 |
+
the continuous (pre-attention) control
|
131 |
+
[batchSize, ctrlDim]
|
132 |
+
'''
|
133 |
+
def control(self, controlInput, inWords, outWords, questionLengths,
    control, contControl = None, name = "", reuse = None):
    '''
    Control unit: attend over the question words to produce the new control state.

    NOTE(review): block nesting was reconstructed from a listing that had lost
    its indentation -- confirm against the upstream file.

    Args:
        controlInput: the (projected) question vector fed to the unit.
        inWords: word representations used to score the attention.
        outWords: word representations that are summed with the attention.
        questionLengths: question lengths, used to mask the attention logits.
        control: previous (post-attention) control state.
        contControl: previous continuous (pre-attention) control state.
        name: suffix for the unit's variable scope.
        reuse: reuse flag for the unit's variable scope.

    Returns:
        (newControl, newContControl): the new control state and the continuous
        control it was derived from.
    '''
    ctrlDim = config.ctrlDim

    with tf.variable_scope("control" + name, reuse = reuse):
        ## Stage 1: the "continuous" control, built from the previous control and the question.
        if config.controlFeedPrev:
            # choose which previous control feeds the recurrence
            merged = control if config.controlFeedPrevAtt else contControl
            mergedDim = ctrlDim
            if config.controlFeedInputs:
                merged = tf.concat([merged, controlInput], axis = -1)
                mergedDim += ctrlDim

            # fold the inputs back to the control dimension
            contState = ops.linear(merged, mergedDim, ctrlDim,
                act = config.controlContAct, name = "contControl")
        else:
            contState = controlInput

        ## Stage 2: score each question word against the continuous control.
        wordInter = tf.expand_dims(contState, axis = 1) * inWords
        interDim = ctrlDim

        if config.controlConcatWords:
            wordInter = tf.concat([wordInter, inWords], axis = -1)
            interDim += ctrlDim

        if config.controlProj:
            wordInter = ops.linear(wordInter, interDim, ctrlDim,
                act = config.controlProjAct)
            interDim = ctrlDim

        # masked softmax over the words; the map is recorded for later inspection
        wordLogits = ops.inter2logits(wordInter, interDim)
        wordAttention = tf.nn.softmax(ops.expMask(wordLogits, questionLengths))
        self.attentions["question"].append(wordAttention)

        attended = ops.att2Smry(wordAttention, outWords)

        # ablation: hand back the pre-attention control instead of the attended summary
        newControl = contState if config.controlContinuous else attended

    return newControl, contState
|
188 |
+
|
189 |
+
'''
|
190 |
+
The read unit extracts relevant information from the knowledge base given the
|
191 |
+
cell's memory and control states. It computes attention distribution over
|
192 |
+
the knowledge base by comparing it first to the memory and then to the control.
|
193 |
+
Finally, it uses the attention distribution to sum up the knowledge base accordingly,
|
194 |
+
resulting in an extraction of relevant information.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
knowledge base: representation of the knowledge base (image).
|
198 |
+
[batchSize, kbSize (Height * Width), memDim]
|
199 |
+
|
200 |
+
memory: the cell's memory state
|
201 |
+
[batchSize, memDim]
|
202 |
+
|
203 |
+
control: the cell's control state
|
204 |
+
[batchSize, ctrlDim]
|
205 |
+
|
206 |
+
Returns the information extracted.
|
207 |
+
[batchSize, memDim]
|
208 |
+
'''
|
209 |
+
def read(self, knowledgeBase, memory, control, name = "", reuse = None):
    '''
    Read unit: extract information from the knowledge base given memory and control.

    The knowledge base is first compared to the (dropped-out) memory, then to the
    control; the result is turned into an attention distribution that is used to
    sum up the knowledge base.

    NOTE(review): block nesting was reconstructed from a listing that had lost
    its indentation -- confirm against the upstream file.

    Args:
        knowledgeBase: knowledge base representation
            (documented as [batchSize, kbSize, memDim]).
        memory: the cell's memory state (documented as [batchSize, memDim]).
        control: the cell's control state (documented as [batchSize, ctrlDim]).
        name: suffix for the unit's variable scope.
        reuse: reuse flag for the unit's variable scope.

    Returns:
        The attention-weighted summary of the knowledge base.
    '''
    with tf.variable_scope("read" + name, reuse = reuse):
        dim = config.memDim

        ## memory dropout
        if config.memoryVariationalDropout:
            memory = ops.applyVarDpMask(memory, self.memDpMask, self.dropouts["memory"])
        else:
            memory = tf.nn.dropout(memory, self.dropouts["memory"])

        ## Step 1: knowledge base / memory interactions
        # parameters for knowledge base and memory projection
        proj = None
        if config.readProjInputs:
            proj = {"dim": config.attDim, "shared": config.readProjShared, "dropout": self.dropouts["read"]}
            dim = config.attDim

        # parameters for concatenating knowledge base elements
        concat = {"x": config.readMemConcatKB, "proj": config.readMemConcatProj}

        # compute interactions between knowledge base and memory
        interactions, interDim = ops.mul(x = knowledgeBase, y = memory, dim = config.memDim,
            proj = proj, concat = concat, interMod = config.readMemAttType, name = "memInter")

        # presumably ops.mul stores the projected knowledge base under proj["x"] -- verify in ops
        projectedKB = proj.get("x") if proj else None

        # project memory interactions back to hidden dimension
        if config.readMemProj:
            interactions = ops.linear(interactions, interDim, dim, act = config.readMemAct,
                name = "memKbProj")
        else:
            dim = interDim

        ## Step 2: compute interactions with control
        if config.readCtrl:
            # bring the control to the interaction width when the two differ
            # (fix: the input size is config.ctrlDim; the bare name ctrlDim was undefined)
            if config.ctrlDim != dim:
                control = ops.linear(control, config.ctrlDim, dim, name = "ctrlProj")

            # NOTE(review): the returned interDim is not folded back into dim, so dim
            # may be stale if config.readCtrlConcatInter widens the result -- confirm in ops.mul
            interactions, interDim = ops.mul(interactions, control, dim,
                interMod = config.readCtrlAttType, concat = {"x": config.readCtrlConcatInter},
                name = "ctrlInter")

            # optionally concatenate knowledge base elements
            if config.readCtrlConcatKB:
                if config.readCtrlConcatProj:
                    addedInp, addedDim = projectedKB, config.attDim
                else:
                    addedInp, addedDim = knowledgeBase, config.memDim
                interactions = tf.concat([interactions, addedInp], axis = -1)
                dim += addedDim

            # optional nonlinearity
            interactions = ops.activations[config.readCtrlAct](interactions)

        ## Step 3: sum attentions up over the knowledge base
        # transform vectors to attention distribution
        attention = ops.inter2att(interactions, dim, dropout = self.dropouts["read"])

        # keep the attention map of this step for later inspection
        self.attentions["kb"].append(attention)

        # optionally use projected knowledge base instead of original
        if config.readSmryKBProj:
            knowledgeBase = projectedKB

        # sum up the knowledge base according to the distribution
        information = ops.att2Smry(attention, knowledgeBase)

    return information
|
278 |
+
|
279 |
+
'''
|
280 |
+
The write unit integrates newly retrieved information (from the read unit),
|
281 |
+
with the cell's previous memory hidden state, resulting in a new memory value.
|
282 |
+
The unit optionally supports:
|
283 |
+
1. Self-attention to previous control / memory states, in order to consider previous steps
|
284 |
+
in the reasoning process.
|
285 |
+
2. Gating between the new memory and previous memory states, to allow dynamic adjustment
|
286 |
+
of the reasoning process length.
|
287 |
+
|
288 |
+
Args:
|
289 |
+
memory: the cell's memory state
|
290 |
+
[batchSize, memDim]
|
291 |
+
|
292 |
+
info: the information to integrate with the memory
|
293 |
+
[batchSize, memDim]
|
294 |
+
|
295 |
+
control: the cell's control state
|
296 |
+
[batchSize, ctrlDim]
|
297 |
+
|
298 |
+
contControl: optional corresponding continuous control state
|
299 |
+
(before casting the attention over the words).
|
300 |
+
[batchSize, ctrlDim]
|
301 |
+
|
302 |
+
Return the new memory
|
303 |
+
[batchSize, memDim]
|
304 |
+
'''
|
305 |
+
def write(self, memory, info, control, contControl = None, name = "", reuse = None):
    '''
    Write unit: merge the newly read information into the memory state.

    Optionally attends over the previous steps (self-attention), merges the
    control state in, gates between old and new memory, and batch-normalizes.

    NOTE(review): block nesting was reconstructed from a listing that had lost
    its indentation -- confirm against the upstream file.

    Args:
        memory: the cell's previous memory state (documented as [batchSize, memDim]).
        info: the information returned by the read unit (documented as [batchSize, memDim]).
        control: the cell's control state (documented as [batchSize, ctrlDim]).
        contControl: optional continuous (pre-attention) control state.
        name: suffix for the unit's variable scope.
        reuse: reuse flag for the unit's variable scope.

    Returns:
        The new memory state.
    '''
    with tf.variable_scope("write" + name, reuse = reuse):

        # optionally project info
        if config.writeInfoProj:
            info = ops.linear(info, config.memDim, config.memDim, name = "info")

        # optional info nonlinearity
        info = ops.activations[config.writeInfoAct](info)

        # compute self-attention vector based on previous controls and memories
        if config.writeSelfAtt:
            selfControl = contControl if config.writeSelfAttMod == "CONT" else control
            selfControl = ops.linear(selfControl, config.ctrlDim, config.ctrlDim, name = "ctrlProj")

            # compare the current control with the controls of all previous steps
            interactions = self.controls * tf.expand_dims(selfControl, axis = 1)

            attention = ops.inter2att(interactions, config.ctrlDim, name = "selfAttention")
            self.attentions["self"].append(attention)
            selfSmry = ops.att2Smry(attention, self.memories)

        # get write unit inputs: previous memory, the new info, optionally self-attention / control
        newMemory, dim = memory, config.memDim
        if config.writeInputs == "INFO":
            newMemory = info
        elif config.writeInputs == "SUM":
            newMemory = newMemory + info
        elif config.writeInputs == "BOTH":
            newMemory, dim = ops.concat(newMemory, info, dim, mul = config.writeConcatMul)
        # any other value: keep the previous memory as the only input

        if config.writeSelfAtt:
            newMemory = tf.concat([newMemory, selfSmry], axis = -1)
            dim += config.memDim

        if config.writeMergeCtrl:
            newMemory = tf.concat([newMemory, control], axis = -1)
            # fix: the control is ctrlDim wide (the gate below also reads it as
            # config.ctrlDim), so the width grows by ctrlDim, not memDim
            dim += config.ctrlDim

        # project memory back to memory dimension
        if config.writeMemProj or (dim != config.memDim):
            newMemory = ops.linear(newMemory, dim, config.memDim, name = "newMemory")

        # optional memory nonlinearity
        newMemory = ops.activations[config.writeMemAct](newMemory)

        # write unit gate: interpolate between the candidate and the previous memory
        if config.writeGate:
            # a shared gate is a single scalar per example instead of one value per memory unit
            gateDim = 1 if config.writeGateShared else config.memDim

            z = tf.sigmoid(ops.linear(control, config.ctrlDim, gateDim, name = "gate", bias = config.writeGateBias))

            self.attentions["gate"].append(z)

            newMemory = newMemory * z + memory * (1 - z)

        # optional batch normalization
        if config.memoryBN:
            newMemory = tf.contrib.layers.batch_norm(newMemory, decay = config.bnDecay,
                center = config.bnCenter, scale = config.bnScale,
                is_training = self.train, updates_collections = None)

    return newMemory
|
376 |
+
|
377 |
+
def memAutoEnc(self, newMemory, info, control, name = "", reuse = None):
    '''
    Auto-encoder loss that reconstructs the control state from the memory (or info).

    Fixes over the previous version: the method now takes `self` (it reads the
    cell's question words, lengths and attention maps), and the logits are
    masked with ops.expMask, the same helper the control unit uses, instead of
    a non-existent self.expMask.

    NOTE(review): block nesting was reconstructed from a listing that had lost
    its indentation -- confirm against the upstream file.

    Args:
        newMemory: the new memory state.
        info: the information returned by the read unit.
        control: the control state to reconstruct.
        name: suffix for the variable scope.
        reuse: reuse flag for the variable scope.

    Returns:
        The scalar reconstruction loss selected by config.autoEncMemLoss.
    '''
    with tf.variable_scope("memAutoEnc" + name, reuse = reuse):
        # inputs to auto encoder
        features = info if config.autoEncMemInputs == "INFO" else newMemory
        features = ops.linear(features, config.memDim, config.ctrlDim,
            act = config.autoEncMemAct, name = "aeMem")

        # reconstruct control directly
        if config.autoEncMemLoss == "CONT":
            loss = tf.reduce_mean(tf.squared_difference(control, features))
        else:
            interactions, dim = ops.mul(self.questionCntxWords, features, config.ctrlDim,
                concat = {"x": config.autoEncMemCnct}, mulBias = config.mulBias, name = "aeMem")

            logits = ops.inter2logits(interactions, dim)
            logits = ops.expMask(logits, self.questionLengths)

            # reconstruct word attentions of the latest control step
            if config.autoEncMemLoss == "PROB":
                loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                    labels = self.attentions["question"][-1], logits = logits))

            # reconstruct control through words attentions
            else:
                attention = tf.nn.softmax(logits)
                summary = ops.att2Smry(attention, self.questionCntxWords)
                loss = tf.reduce_mean(tf.squared_difference(control, summary))

    return loss
|
406 |
+
|
407 |
+
'''
|
408 |
+
Call the cell to get new control and memory states.
|
409 |
+
|
410 |
+
Args:
|
411 |
+
inputs: in the current implementation the cell doesn't get recurrent inputs
|
412 |
+
every iteration (argument for comparability with rnn interface).
|
413 |
+
|
414 |
+
state: the cell current state (control, memory)
|
415 |
+
MACCellTuple([batchSize, ctrlDim],[batchSize, memDim])
|
416 |
+
|
417 |
+
Returns the new state -- the new memory and control values.
|
418 |
+
MACCellTuple([batchSize, ctrlDim],[batchSize, memDim])
|
419 |
+
'''
|
420 |
+
def __call__(self, inputs, state, scope = None):
    '''
    Run one reasoning step: control unit, then read unit, then write unit.

    The cell is stateful: besides returning the new state it appends the new
    control / memory / info to self.controls, self.memories and self.infos and
    overwrites self.contControl. It reads self.iteration, which is not assigned
    anywhere in this class as visible here -- presumably the caller sets it
    before every step; verify against the caller.

    NOTE(review): block nesting was reconstructed from a listing that had lost
    its indentation -- confirm against the upstream file.

    Args:
        inputs: unused; present for compatibility with the RNN cell interface.
        state: current MACCellTuple(control, memory).
        scope: optional variable scope name (defaults to the class name).

    Returns:
        (self.none, MACCellTuple(newControl, newMemory)).
    '''
    with tf.variable_scope(scope or type(self).__name__, reuse = self.reuse):
        prevControl = state.control
        prevMemory = state.memory

        # variables are created on the first step and reused afterwards
        pastFirstStep = self.iteration > 0

        # question-input projections: the first is always shared across steps,
        # the second may get its own variables per step
        sharedInName, sharedInReuse = "qInput", pastFirstStep
        if config.controlInputUnshared:
            stepInName, stepInReuse = "qInput%d" % self.iteration, None
        else:
            stepInName, stepInReuse = "qInputU", pastFirstStep

        # control / read / write units: shared across steps unless cells are unshared
        if config.unsharedCells:
            cellName, cellReuse = str(self.iteration), None
        else:
            cellName, cellReuse = "", pastFirstStep

        ## control unit
        # prepare question input to control
        questionInput = ops.linear(self.vecQuestions, config.ctrlDim, config.ctrlDim,
            name = sharedInName, reuse = sharedInReuse)
        questionInput = ops.activations[config.controlInputAct](questionInput)
        questionInput = ops.linear(questionInput, config.ctrlDim, config.ctrlDim,
            name = stepInName, reuse = stepInReuse)

        newControl, self.contControl = self.control(questionInput, self.inWords, self.outWords,
            self.questionLengths, prevControl, self.contControl, name = cellName, reuse = cellReuse)

        ## read unit
        # ablation: use whole question as control
        if config.controlWholeQ:
            newControl = self.vecQuestions

        info = self.read(self.knowledgeBase, prevMemory, newControl, name = cellName, reuse = cellReuse)

        ## write unit
        if config.writeDropout < 1.0:
            info = tf.nn.dropout(info, self.dropouts["write"])

        newMemory = self.write(prevMemory, info, newControl, self.contControl,
            name = cellName, reuse = cellReuse)

        # extend the per-step histories along axis 1
        self.controls = tf.concat([self.controls, tf.expand_dims(newControl, axis = 1)], axis = 1)
        self.memories = tf.concat([self.memories, tf.expand_dims(newMemory, axis = 1)], axis = 1)
        self.infos = tf.concat([self.infos, tf.expand_dims(info, axis = 1)], axis = 1)

        nextState = MACCellTuple(newControl, newMemory)

    return self.none, nextState
|
481 |
+
|
482 |
+
'''
|
483 |
+
Initializes a hidden state based on the value of initType:
|
484 |
+
"PRM" for parametric initialization
|
485 |
+
"ZERO" for zero initialization
|
486 |
+
"Q" to initialize to question vectors.
|
487 |
+
|
488 |
+
Args:
|
489 |
+
name: the state variable name.
|
490 |
+
dim: the dimension of the state.
|
491 |
+
initType: the type of the initialization
|
492 |
+
batchSize: the batch size
|
493 |
+
|
494 |
+
Returns the initialized hidden state.
|
495 |
+
'''
|
496 |
+
def initState(self, name, dim, initType, batchSize):
    '''
    Build an initial hidden state according to initType.

    Args:
        name: name of the variable created for the "PRM" initialization.
        dim: the dimension of the state.
        initType: "PRM" for a learned vector tiled over the batch, "ZERO" for
            zeros; any other value (documented as "Q") yields the question vectors.
        batchSize: the batch size.

    Returns:
        The initial hidden state.
    '''
    if initType == "ZERO":
        return tf.zeros((batchSize, dim), dtype = tf.float32)

    if initType == "PRM":
        # one learned vector, replicated for every example in the batch
        learned = tf.get_variable(name, shape = (dim, ),
            initializer = tf.random_normal_initializer())
        return tf.tile(tf.expand_dims(learned, axis = 0), [batchSize, 1])

    # fall-through case ("Q"): start from the question vectors themselves
    return self.vecQuestions
|
506 |
+
|
507 |
+
'''
|
508 |
+
Add a parametric null word to the questions.
|
509 |
+
|
510 |
+
Args:
|
511 |
+
words: the words to add a null word to.
|
512 |
+
[batchSize, questionLength, ctrlDim]
|
513 |
+
|
514 |
+
lengths: question lengths.
|
515 |
+
[batchSize]
|
516 |
+
|
517 |
+
Returns the updated word sequence and lengths.
|
518 |
+
'''
|
519 |
+
def addNullWord(self, words, lengths):
    '''
    Prepend a learned "null" word to every question.

    Fix over the previous version: the method now takes `self`; it reads
    self.batchSize and is invoked as self.addNullWord(words, lengths), which
    failed with a TypeError when `self` was missing from the signature.

    Args:
        words: the words to add a null word to; concatenated along axis 1 with a
            [batchSize, 1, ctrlDim] tensor.
        lengths: question lengths.

    Returns:
        (words, lengths): the extended word sequence and the lengths incremented by one.
    '''
    nullWord = tf.get_variable("zeroWord", shape = (1, config.ctrlDim),
        initializer = tf.random_normal_initializer())
    # (1, ctrlDim) -> (1, 1, ctrlDim) -> (batchSize, 1, ctrlDim)
    nullWord = tf.tile(tf.expand_dims(nullWord, axis = 0), [self.batchSize, 1, 1])

    words = tf.concat([nullWord, words], axis = 1)
    # build a new value rather than updating the caller's object in place
    lengths = lengths + 1
    return words, lengths
|
525 |
+
|
526 |
+
'''
|
527 |
+
Initializes the cell internal state (currently it's stateful). In particular,
|
528 |
+
1. Data-structures (lists of attention maps and accumulated losses).
|
529 |
+
2. The memory and control states.
|
530 |
+
3. The knowledge base (optionally merging it with the question vectors)
|
531 |
+
4. The question words used by the cell (either the original word embeddings, or the
|
532 |
+
encoder outputs, with optional projection).
|
533 |
+
|
534 |
+
Args:
|
535 |
+
batchSize: the batch size
|
536 |
+
|
537 |
+
Returns the initial cell state.
|
538 |
+
'''
|
539 |
+
def zero_state(self, batchSize, dtype = tf.float32):
    '''
    Initialize the cell's internal state and return the initial recurrent state.

    Sets up, in order: the attention-map lists and auto-encoder loss
    accumulators, the initial control / memory and their per-step histories,
    the knowledge base (optionally merged with the question), the question
    words used by the control unit, and the variational dropout mask.

    Fix over the previous version: when config.addNullWord is set, the question
    lengths are now read from and written back to self.questionLengths. The old
    code read an unassigned local (UnboundLocalError) and would have discarded
    the incremented lengths, which __call__ passes to the control unit's mask.

    The method overwrites self.knowledgeBase and self.questionLengths, so it is
    meant to be called once per cell instance.

    NOTE(review): block nesting was reconstructed from a listing that had lost
    its indentation -- confirm against the upstream file.

    Args:
        batchSize: the batch size.
        dtype: unused; present for compatibility with the RNN cell interface.

    Returns:
        MACCellTuple(initialControl, initialMemory).
    '''
    ## initialize data-structures
    self.attentions = {"kb": [], "question": [], "self": [], "gate": []}
    self.autoEncLosses = {"control": tf.constant(0.0), "memory": tf.constant(0.0)}

    ## initialize state
    initialControl = self.initState("initCtrl", config.ctrlDim, config.initCtrl, batchSize)
    initialMemory = self.initState("initMem", config.memDim, config.initMem, batchSize)

    # per-step histories start with the initial values (a length-1 axis is inserted at position 1)
    self.controls = tf.expand_dims(initialControl, axis = 1)
    self.memories = tf.expand_dims(initialMemory, axis = 1)
    self.infos = tf.expand_dims(initialMemory, axis = 1)

    self.contControl = initialControl

    ## initialize knowledge base
    # optionally merge question into knowledge base representation
    if config.initKBwithQ != "NON":
        iVecQuestions = ops.linear(self.vecQuestions, config.ctrlDim, config.memDim, name = "questions")

        concatMul = (config.initKBwithQ == "MUL")
        cnct, dim = ops.concat(self.knowledgeBase, iVecQuestions, config.memDim, mul = concatMul, expandY = True)
        self.knowledgeBase = ops.linear(cnct, dim, config.memDim, name = "initKB")

    ## initialize question words
    # choose question words to work with (original embeddings or encoder outputs)
    words = self.questionCntxWords if config.controlContextual else self.questionWords

    # optionally add a parametric "null" word to all questions; the lengths must
    # grow with it so the attention mask covers the extra word
    if config.addNullWord:
        words, self.questionLengths = self.addNullWord(words, self.questionLengths)

    # project words
    self.inWords = self.outWords = words
    if config.controlInWordsProj or config.controlOutWordsProj:
        pWords = ops.linear(words, config.ctrlDim, config.ctrlDim, name = "wordsProj")
        self.inWords = pWords if config.controlInWordsProj else words
        self.outWords = pWords if config.controlOutWordsProj else words

    ## initialize memory variational dropout mask
    if config.memoryVariationalDropout:
        self.memDpMask = ops.generateVarDpMask((batchSize, config.memDim), self.dropouts["memory"])

    return MACCellTuple(initialControl, initialMemory)
|