import collections
import numpy as np
import tensorflow as tf

import ops
from config import config

MACCellTuple = collections.namedtuple("MACCellTuple", ("control", "memory"))
'''
The MAC cell: a recurrent cell for multi-step reasoning, presented in
https://arxiv.org/abs/1803.03067.

The cell has recurrent control and memory states that interact with the question
and the knowledge base (image), respectively.

The hidden state structure is MACCellTuple(control, memory).

At each step the cell operates by calling three subunits: control, read and write.
1. The Control Unit computes the new control state by casting attention over the question words.
   The control state represents the reasoning operation the cell currently performs.
2. The Read Unit retrieves information from the knowledge base, given the control and previous
   memory values, by computing a two-stage attention over the knowledge base.
3. The Write Unit integrates the retrieved information into the previous memory hidden state,
   given the value of the control state, to perform the current reasoning operation.
'''
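# Usage sketch (illustrative only -- the driving loop lives outside this file, and
# config.netLength / the external "iteration" counter below are assumptions about how
# the model invokes the cell, not something defined here):
#
#   cell = MACCell(vecQuestions, questionWords, questionCntxWords, questionLengths,
#       knowledgeBase, memoryDropout, readDropout, writeDropout, batchSize, train)
#   state = cell.zero_state(batchSize)
#   for i in range(config.netLength):        # number of reasoning steps
#       cell.iteration = i                   # __call__ reads self.iteration for variable sharing
#       _, state = cell(None, state)         # the per-step output is a dummy tensor
#   finalMemory = state.memory               # consumed by the answer classifier elsewhere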
class MACCell(tf.nn.rnn_cell.RNNCell):
    '''Initialize the MAC cell.
    (Note that in the current version the cell is stateful --
    it updates its own internals when called.)

    Args:
        vecQuestions: the vector representation of the questions.
        [batchSize, ctrlDim]

        questionWords: the question words embeddings.
        [batchSize, questionLength, ctrlDim]

        questionCntxWords: the encoder outputs -- the "contextual" question words.
        [batchSize, questionLength, ctrlDim]

        questionLengths: the length of each question.
        [batchSize]

        knowledgeBase: representation of the knowledge base (image).
        [batchSize, kbSize (Height * Width), memDim]

        memoryDropout: dropout on the memory state (Tensor scalar).
        readDropout: dropout inside the read unit (Tensor scalar).
        writeDropout: dropout on the new information that gets into the write unit (Tensor scalar).

        batchSize: batch size (Tensor scalar).
        train: train or test mode (Tensor boolean).
        reuse: whether to reuse the cell's variables.
    '''
    def __init__(self, vecQuestions, questionWords, questionCntxWords, questionLengths,
        knowledgeBase, memoryDropout, readDropout, writeDropout,
        batchSize, train, reuse = None):

        self.vecQuestions = vecQuestions
        self.questionWords = questionWords
        self.questionCntxWords = questionCntxWords
        self.questionLengths = questionLengths
        self.knowledgeBase = knowledgeBase

        self.dropouts = {}
        self.dropouts["memory"] = memoryDropout
        self.dropouts["read"] = readDropout
        self.dropouts["write"] = writeDropout

        self.none = tf.zeros((batchSize, 1), dtype = tf.float32)

        self.batchSize = batchSize
        self.train = train
        self.reuse = reuse
    '''
    Cell state size.
    '''
    @property
    def state_size(self):
        return MACCellTuple(config.ctrlDim, config.memDim)
    '''
    Cell output size. Currently the cell has no real outputs.
    '''
    @property
    def output_size(self):
        return 1
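    # Note: __call__ returns self.none (a [batchSize, 1] zero tensor) as the per-step output;
    # only the final MACCellTuple state is meant to be consumed by the rest of the model.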
    # pass encoder hidden states to control?
    '''
    The Control Unit: computes the new control state -- the reasoning operation --
    by summing up the word embeddings according to a computed attention distribution.

    The unit is recurrent: it receives the whole question and the previous control state,
    merges them together (resulting in the "continuous control"), and then uses that
    to compute attention over the question words. Finally, it combines the words
    together according to the attention distribution to get the new control state.

    Args:
        controlInput: external input to the control unit (the question vector).
        [batchSize, ctrlDim]

        inWords: the representation of the words used to compute the attention.
        [batchSize, questionLength, ctrlDim]

        outWords: the representation of the words that are summed up
        (by default inWords == outWords).
        [batchSize, questionLength, ctrlDim]

        questionLengths: the length of each question.
        [batchSize]

        control: the previous control hidden state value.
        [batchSize, ctrlDim]

        contControl: optional corresponding continuous control state
        (before casting the attention over the words).
        [batchSize, ctrlDim]

    Returns:
        the new control state
        [batchSize, ctrlDim]

        the continuous (pre-attention) control
        [batchSize, ctrlDim]
    '''
    def control(self, controlInput, inWords, outWords, questionLengths,
        control, contControl = None, name = "", reuse = None):

        with tf.variable_scope("control" + name, reuse = reuse):
            dim = config.ctrlDim

            ## Step 1: compute "continuous" control state given previous control and question.
            # control inputs: question and previous control
            newContControl = controlInput
            if config.controlFeedPrev:
                newContControl = control if config.controlFeedPrevAtt else contControl
                if config.controlFeedInputs:
                    newContControl = tf.concat([newContControl, controlInput], axis = -1)
                    dim += config.ctrlDim

                # merge inputs together
                newContControl = ops.linear(newContControl, dim, config.ctrlDim,
                    act = config.controlContAct, name = "contControl")
                dim = config.ctrlDim

            ## Step 2: compute attention distribution over words and sum them up accordingly.
            # compute interactions with question words
            interactions = tf.expand_dims(newContControl, axis = 1) * inWords

            # optionally concatenate words
            if config.controlConcatWords:
                interactions = tf.concat([interactions, inWords], axis = -1)
                dim += config.ctrlDim

            # optional projection
            if config.controlProj:
                interactions = ops.linear(interactions, dim, config.ctrlDim,
                    act = config.controlProjAct)
                dim = config.ctrlDim

            # compute attention distribution over words and summarize them accordingly
            logits = ops.inter2logits(interactions, dim)
            # self.interL = (interW, interb)

            # if config.controlCoverage:
            #     logits += coverageBias * coverage

            attention = tf.nn.softmax(ops.expMask(logits, questionLengths))
            self.attentions["question"].append(attention)

            # if config.controlCoverage:
            #     coverage += attention # Add logits instead?

            newControl = ops.att2Smry(attention, outWords)

            # ablation: use continuous control (pre-attention) instead
            if config.controlContinuous:
                newControl = newContControl

        return newControl, newContControl
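    # With config.controlFeedPrev and config.controlFeedInputs enabled (and the other optional
    # flags off), the computation above corresponds to the paper's control unit:
    #   cq_i   = W [c_{i-1} ; q] + b          (continuous control)
    #   ca_i,s = W' (cq_i * w_s) + b'         (per-word logits)
    #   cv_i   = softmax(ca_i)                (attention over question words)
    #   c_i    = sum_s cv_i,s * w_s           (new control state)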
    '''
    The Read Unit: extracts relevant information from the knowledge base, given the
    cell's memory and control states. It computes an attention distribution over
    the knowledge base by comparing it first to the memory and then to the control.
    Finally, it uses the attention distribution to sum up the knowledge base accordingly,
    resulting in an extraction of relevant information.

    Args:
        knowledgeBase: representation of the knowledge base (image).
        [batchSize, kbSize (Height * Width), memDim]

        memory: the cell's memory state
        [batchSize, memDim]

        control: the cell's control state
        [batchSize, ctrlDim]

    Returns:
        the information extracted from the knowledge base.
        [batchSize, memDim]
    '''
    def read(self, knowledgeBase, memory, control, name = "", reuse = None):
        with tf.variable_scope("read" + name, reuse = reuse):
            dim = config.memDim

            ## memory dropout
            if config.memoryVariationalDropout:
                memory = ops.applyVarDpMask(memory, self.memDpMask, self.dropouts["memory"])
            else:
                memory = tf.nn.dropout(memory, self.dropouts["memory"])

            ## Step 1: knowledge base / memory interactions
            # parameters for knowledge base and memory projection
            proj = None
            if config.readProjInputs:
                proj = {"dim": config.attDim, "shared": config.readProjShared, "dropout": self.dropouts["read"]}
                dim = config.attDim

            # parameters for concatenating knowledge base elements
            concat = {"x": config.readMemConcatKB, "proj": config.readMemConcatProj}

            # compute interactions between knowledge base and memory
            interactions, interDim = ops.mul(x = knowledgeBase, y = memory, dim = config.memDim,
                proj = proj, concat = concat, interMod = config.readMemAttType, name = "memInter")

            projectedKB = proj.get("x") if proj else None

            # project memory interactions back to hidden dimension
            if config.readMemProj:
                interactions = ops.linear(interactions, interDim, dim, act = config.readMemAct,
                    name = "memKbProj")
            else:
                dim = interDim

            ## Step 2: compute interactions with control
            if config.readCtrl:
                # compute interactions with control
                if config.ctrlDim != dim:
                    control = ops.linear(control, config.ctrlDim, dim, name = "ctrlProj")

                interactions, interDim = ops.mul(interactions, control, dim,
                    interMod = config.readCtrlAttType, concat = {"x": config.readCtrlConcatInter},
                    name = "ctrlInter")

                # optionally concatenate knowledge base elements
                if config.readCtrlConcatKB:
                    if config.readCtrlConcatProj:
                        addedInp, addedDim = projectedKB, config.attDim
                    else:
                        addedInp, addedDim = knowledgeBase, config.memDim
                    interactions = tf.concat([interactions, addedInp], axis = -1)
                    dim += addedDim

                # optional nonlinearity
                interactions = ops.activations[config.readCtrlAct](interactions)

            ## Step 3: sum attentions up over the knowledge base
            # transform vectors to attention distribution
            attention = ops.inter2att(interactions, dim, dropout = self.dropouts["read"])
            self.attentions["kb"].append(attention)

            # optionally use projected knowledge base instead of original
            if config.readSmryKBProj:
                knowledgeBase = projectedKB

            # sum up the knowledge base according to the distribution
            information = ops.att2Smry(attention, knowledgeBase)

        return information
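    # With config.readProjInputs, config.readMemConcatKB and config.readCtrl enabled, the
    # computation above corresponds to the paper's read unit (KB elements k_1..k_N):
    #   I_i,n  = (W_m m_{i-1}) * (W_k k_n)    (memory / KB interaction)
    #   I'_i,n = W [I_i,n ; k_n] + b          (re-attach the KB element)
    #   ra_i,n = W' (c_i * I'_i,n) + b'       (modulate by control, reduce to a logit)
    #   rv_i   = softmax(ra_i)                (attention over the knowledge base)
    #   r_i    = sum_n rv_i,n * k_n           (retrieved information)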
    '''
    The Write Unit: integrates the newly retrieved information (from the read unit)
    with the cell's previous memory hidden state, resulting in a new memory value.
    The unit optionally supports:
    1. Self-attention to previous control / memory states, in order to consider previous steps
       in the reasoning process.
    2. Gating between the new memory and previous memory states, to allow dynamic adjustment
       of the reasoning process length.

    Args:
        memory: the cell's memory state
        [batchSize, memDim]

        info: the information to integrate with the memory
        [batchSize, memDim]

        control: the cell's control state
        [batchSize, ctrlDim]

        contControl: optional corresponding continuous control state
        (before casting the attention over the words).
        [batchSize, ctrlDim]

    Returns:
        the new memory value.
        [batchSize, memDim]
    '''
    def write(self, memory, info, control, contControl = None, name = "", reuse = None):
        with tf.variable_scope("write" + name, reuse = reuse):
            # optionally project info
            if config.writeInfoProj:
                info = ops.linear(info, config.memDim, config.memDim, name = "info")

            # optional info nonlinearity
            info = ops.activations[config.writeInfoAct](info)

            # compute self-attention vector based on previous controls and memories
            if config.writeSelfAtt:
                selfControl = control
                if config.writeSelfAttMod == "CONT":
                    selfControl = contControl
                # elif config.writeSelfAttMod == "POST":
                #     selfControl = postControl
                selfControl = ops.linear(selfControl, config.ctrlDim, config.ctrlDim, name = "ctrlProj")

                interactions = self.controls * tf.expand_dims(selfControl, axis = 1)

                # if config.selfAttShareInter:
                #     selfAttlogits = self.linearP(selfAttInter, config.encDim, 1, self.interL[0], self.interL[1], name = "modSelfAttInter")
                attention = ops.inter2att(interactions, config.ctrlDim, name = "selfAttention")
                self.attentions["self"].append(attention)
                selfSmry = ops.att2Smry(attention, self.memories)

            # get write unit inputs: previous memory, the new info, optionally self-attention / control
            newMemory, dim = memory, config.memDim
            if config.writeInputs == "INFO":
                newMemory = info
            elif config.writeInputs == "SUM":
                newMemory += info
            elif config.writeInputs == "BOTH":
                newMemory, dim = ops.concat(newMemory, info, dim, mul = config.writeConcatMul)
            # else: MEM

            if config.writeSelfAtt:
                newMemory = tf.concat([newMemory, selfSmry], axis = -1)
                dim += config.memDim

            if config.writeMergeCtrl:
                newMemory = tf.concat([newMemory, control], axis = -1)
                dim += config.ctrlDim # the concatenated control has dimension ctrlDim

            # project memory back to memory dimension
            if config.writeMemProj or (dim != config.memDim):
                newMemory = ops.linear(newMemory, dim, config.memDim, name = "newMemory")

            # optional memory nonlinearity
            newMemory = ops.activations[config.writeMemAct](newMemory)

            # write unit gate
            if config.writeGate:
                gateDim = config.memDim
                if config.writeGateShared:
                    gateDim = 1

                z = tf.sigmoid(ops.linear(control, config.ctrlDim, gateDim, name = "gate", bias = config.writeGateBias))
                self.attentions["gate"].append(z)

                newMemory = newMemory * z + memory * (1 - z)

            # optional batch normalization
            if config.memoryBN:
                newMemory = tf.contrib.layers.batch_norm(newMemory, decay = config.bnDecay,
                    center = config.bnCenter, scale = config.bnScale,
                    is_training = self.train, updates_collections = None)

        return newMemory
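    # With config.writeInputs == "BOTH", config.writeSelfAtt and config.writeGate enabled, the
    # computation above corresponds to the paper's write unit:
    #   m'_i   = W [r_i ; m_{i-1}] + b                     (merge retrieved info with memory)
    #   sa_i,j = softmax_j(W' (c_i * c_j)), j = 0..i-1     (self-attention over previous controls)
    #   m''_i  = W'' [m'_i ; sum_j sa_i,j * m_j]           (add self-attention summary)
    #   z_i    = sigmoid(W_g c_i + b_g)                    (memory gate)
    #   m_i    = z_i * m''_i + (1 - z_i) * m_{i-1}         (gated update)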
    def memAutoEnc(self, newMemory, info, control, name = "", reuse = None):
        with tf.variable_scope("memAutoEnc" + name, reuse = reuse):
            # inputs to auto encoder
            features = info if config.autoEncMemInputs == "INFO" else newMemory
            features = ops.linear(features, config.memDim, config.ctrlDim,
                act = config.autoEncMemAct, name = "aeMem")

            # reconstruct control
            if config.autoEncMemLoss == "CONT":
                loss = tf.reduce_mean(tf.squared_difference(control, features))
            else:
                interactions, dim = ops.mul(self.questionCntxWords, features, config.ctrlDim,
                    concat = {"x": config.autoEncMemCnct}, mulBias = config.mulBias, name = "aeMem")

                logits = ops.inter2logits(interactions, dim)
                logits = ops.expMask(logits, self.questionLengths)

                # reconstruct word attentions
                if config.autoEncMemLoss == "PROB":
                    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                        labels = self.attentions["question"][-1], logits = logits))

                # reconstruct control through word attentions
                else:
                    attention = tf.nn.softmax(logits)
                    summary = ops.att2Smry(attention, self.questionCntxWords)
                    loss = tf.reduce_mean(tf.squared_difference(control, summary))

        return loss
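    # This optional auxiliary loss tries to reconstruct the control state (or the question-word
    # attention that produced it) from the new memory / retrieved information, tying the memory
    # content to the reasoning operation that wrote it. It is only referenced from the
    # commented-out config.autoEncMem code in __call__ below.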
    '''
    Call the cell to get the new control and memory states.

    Args:
        inputs: in the current implementation the cell doesn't receive recurrent inputs
        at each iteration (the argument exists for compatibility with the RNN interface).

        state: the cell's current state (control, memory)
        MACCellTuple([batchSize, ctrlDim], [batchSize, memDim])

    Returns the new state -- the new memory and control values.
        MACCellTuple([batchSize, ctrlDim], [batchSize, memDim])
    '''
    def __call__(self, inputs, state, scope = None):
        scope = scope or type(self).__name__
        with tf.variable_scope(scope, reuse = self.reuse): # as tfscope
            control = state.control
            memory = state.memory

            # cell sharing
            inputName = "qInput"
            inputNameU = "qInputU"
            inputReuseU = inputReuse = (self.iteration > 0)
            if config.controlInputUnshared:
                inputNameU = "qInput%d" % self.iteration
                inputReuseU = None

            cellName = ""
            cellReuse = (self.iteration > 0)
            if config.unsharedCells:
                cellName = str(self.iteration)
                cellReuse = None

            ## control unit
            # prepare question input to control
            controlInput = ops.linear(self.vecQuestions, config.ctrlDim, config.ctrlDim,
                name = inputName, reuse = inputReuse)
            controlInput = ops.activations[config.controlInputAct](controlInput)
            controlInput = ops.linear(controlInput, config.ctrlDim, config.ctrlDim,
                name = inputNameU, reuse = inputReuseU)

            newControl, self.contControl = self.control(controlInput, self.inWords, self.outWords,
                self.questionLengths, control, self.contControl, name = cellName, reuse = cellReuse)

            # ablation: use whole question as control
            if config.controlWholeQ:
                newControl = self.vecQuestions
                # ops.linear(self.vecQuestions, config.ctrlDim, projDim, name = "qMod")

            ## read unit
            info = self.read(self.knowledgeBase, memory, newControl, name = cellName, reuse = cellReuse)

            if config.writeDropout < 1.0:
                info = tf.nn.dropout(info, self.dropouts["write"])

            ## write unit
            newMemory = self.write(memory, info, newControl, self.contControl, name = cellName, reuse = cellReuse)

            # add auto encoder loss for memory
            # if config.autoEncMem:
            #     self.autoEncLosses["memory"] += self.memAutoEnc(newMemory, info, newControl)

            # append as standard list?
            self.controls = tf.concat([self.controls, tf.expand_dims(newControl, axis = 1)], axis = 1)
            self.memories = tf.concat([self.memories, tf.expand_dims(newMemory, axis = 1)], axis = 1)
            self.infos = tf.concat([self.infos, tf.expand_dims(info, axis = 1)], axis = 1)

            # self.contControls = tf.concat([self.contControls, tf.expand_dims(contControl, axis = 1)], axis = 1)
            # self.postControls = tf.concat([self.controls, tf.expand_dims(postControls, axis = 1)], axis = 1)

        newState = MACCellTuple(newControl, newMemory)
        return self.none, newState
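    # Per-step dataflow: question -> control unit -> c_i; (c_i, m_{i-1}, KB) -> read unit -> r_i;
    # (r_i, m_{i-1}, c_i) -> write unit -> m_i. The new control / memory / info values are also
    # appended to the cell's internal histories so the write unit's optional self-attention can
    # look back over earlier reasoning steps.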
    '''
    Initializes a hidden state based on the value of initType:
    "PRM" for parametric initialization
    "ZERO" for zero initialization
    "Q" to initialize to the question vectors.

    Args:
        name: the state variable name.
        dim: the dimension of the state.
        initType: the type of the initialization.
        batchSize: the batch size.

    Returns the initialized hidden state.
    '''
    def initState(self, name, dim, initType, batchSize):
        if initType == "PRM":
            prm = tf.get_variable(name, shape = (dim, ),
                initializer = tf.random_normal_initializer())
            initState = tf.tile(tf.expand_dims(prm, axis = 0), [batchSize, 1])
        elif initType == "ZERO":
            initState = tf.zeros((batchSize, dim), dtype = tf.float32)
        else: # "Q"
            initState = self.vecQuestions
        return initState
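    # For example, with initType == "PRM" every example in the batch starts from the same learned
    # vector, whereas with "Q" each example starts from its own question representation.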
    '''
    Adds a parametric null word to the questions.

    Args:
        words: the word sequence to add a null word to.
        [batchSize, questionLength, ctrlDim]

        lengths: question lengths.
        [batchSize]

    Returns the updated word sequence and lengths.
    '''
    def addNullWord(self, words, lengths):
        nullWord = tf.get_variable("zeroWord", shape = (1, config.ctrlDim), initializer = tf.random_normal_initializer())
        nullWord = tf.tile(tf.expand_dims(nullWord, axis = 0), [self.batchSize, 1, 1])
        words = tf.concat([nullWord, words], axis = 1)
        lengths += 1
        return words, lengths
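    # The null word gives the control attention a position to attend to when none of the real
    # question words is relevant at the current step (an interpretation of the code above; the
    # flag config.addNullWord is what enables it in zero_state below).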
    '''
    Initializes the cell's internal state (the cell is currently stateful). In particular:
    1. Data-structures (lists of attention maps and accumulated losses).
    2. The memory and control states.
    3. The knowledge base (optionally merging it with the question vectors).
    4. The question words used by the cell (either the original word embeddings, or the
       encoder outputs, with optional projection).

    Args:
        batchSize: the batch size

    Returns the initial cell state.
    '''
    def zero_state(self, batchSize, dtype = tf.float32):
        ## initialize data-structures
        self.attentions = {"kb": [], "question": [], "self": [], "gate": []}
        self.autoEncLosses = {"control": tf.constant(0.0), "memory": tf.constant(0.0)}

        ## initialize state
        initialControl = self.initState("initCtrl", config.ctrlDim, config.initCtrl, batchSize)
        initialMemory = self.initState("initMem", config.memDim, config.initMem, batchSize)

        self.controls = tf.expand_dims(initialControl, axis = 1)
        self.memories = tf.expand_dims(initialMemory, axis = 1)
        self.infos = tf.expand_dims(initialMemory, axis = 1)

        self.contControl = initialControl
        # self.contControls = tf.expand_dims(initialControl, axis = 1)
        # self.postControls = tf.expand_dims(initialControl, axis = 1)

        ## initialize knowledge base
        # optionally merge question into knowledge base representation
        if config.initKBwithQ != "NON":
            iVecQuestions = ops.linear(self.vecQuestions, config.ctrlDim, config.memDim, name = "questions")

            concatMul = (config.initKBwithQ == "MUL")
            cnct, dim = ops.concat(self.knowledgeBase, iVecQuestions, config.memDim, mul = concatMul, expandY = True)
            self.knowledgeBase = ops.linear(cnct, dim, config.memDim, name = "initKB")

        ## initialize question words
        # choose question words to work with (original embeddings or encoder outputs)
        words = self.questionCntxWords if config.controlContextual else self.questionWords

        # optionally add a parametric "null" word to all questions
        if config.addNullWord:
            words, self.questionLengths = self.addNullWord(words, self.questionLengths)

        # project words
        self.inWords = self.outWords = words
        if config.controlInWordsProj or config.controlOutWordsProj:
            pWords = ops.linear(words, config.ctrlDim, config.ctrlDim, name = "wordsProj")
            self.inWords = pWords if config.controlInWordsProj else words
            self.outWords = pWords if config.controlOutWordsProj else words

        # if config.controlCoverage:
        #     self.coverage = tf.zeros((batchSize, tf.shape(words)[1]), dtype = tf.float32)
        #     self.coverageBias = tf.get_variable("coverageBias", shape = (),
        #         initializer = config.controlCoverageBias)

        ## initialize memory variational dropout mask
        if config.memoryVariationalDropout:
            self.memDpMask = ops.generateVarDpMask((batchSize, config.memDim), self.dropouts["memory"])

        return MACCellTuple(initialControl, initialMemory)