Spaces:
Runtime error
Runtime error
Upload model.py
Browse files
model.py
ADDED
@@ -0,0 +1,802 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import tensorflow as tf
|
5 |
+
|
6 |
+
import ops
|
7 |
+
from config import config
|
8 |
+
from mac_cell import MACCell
|
9 |
+
'''
|
10 |
+
The MAC network model. It performs reasoning processes to answer a question over
|
11 |
+
knowledge base (the image) by decomposing it into attention-based computational steps,
|
12 |
+
each perform by a recurrent MAC cell.
|
13 |
+
|
14 |
+
The network has three main components.
|
15 |
+
Input unit: processes the network inputs: raw question strings and image into
|
16 |
+
distributional representations.
|
17 |
+
|
18 |
+
The MAC network: calls the MACcells (mac_cell.py) config.netLength number of times,
|
19 |
+
to perform the reasoning process over the question and image.
|
20 |
+
|
21 |
+
The output unit: a classifier that receives the question and final state of the MAC
|
22 |
+
network and uses them to compute log-likelihood over the possible one-word answers.
|
23 |
+
'''
|
24 |
+
class MACnet(object):
|
25 |
+
|
26 |
+
'''Initialize the class.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
embeddingsInit: initialization for word embeddings (random / glove).
|
30 |
+
answerDict: answers dictionary (mapping between integer id and symbol).
|
31 |
+
'''
|
32 |
+
def __init__(self, embeddingsInit, answerDict):
|
33 |
+
self.embeddingsInit = embeddingsInit
|
34 |
+
self.answerDict = answerDict
|
35 |
+
self.build()
|
36 |
+
|
37 |
+
'''
|
38 |
+
Initializes placeholders.
|
39 |
+
questionsIndicesAll: integer ids of question words.
|
40 |
+
[batchSize, questionLength]
|
41 |
+
|
42 |
+
questionLengthsAll: length of each question.
|
43 |
+
[batchSize]
|
44 |
+
|
45 |
+
imagesPlaceholder: image features.
|
46 |
+
[batchSize, channels, height, width]
|
47 |
+
(converted internally to [batchSize, height, width, channels])
|
48 |
+
|
49 |
+
answersIndicesAll: integer ids of answer words.
|
50 |
+
[batchSize]
|
51 |
+
|
52 |
+
lr: learning rate (tensor scalar)
|
53 |
+
train: train / evaluation (tensor boolean)
|
54 |
+
|
55 |
+
dropout values dictionary (tensor scalars)
|
56 |
+
'''
|
57 |
+
# change to H x W x C?
|
58 |
+
def addPlaceholders(self):
|
59 |
+
with tf.variable_scope("Placeholders"):
|
60 |
+
## data
|
61 |
+
# questions
|
62 |
+
self.questionsIndicesAll = tf.placeholder(tf.int32, shape = (None, None))
|
63 |
+
self.questionLengthsAll = tf.placeholder(tf.int32, shape = (None, ))
|
64 |
+
|
65 |
+
# images
|
66 |
+
# put image known dimension as last dim?
|
67 |
+
self.imagesPlaceholder = tf.placeholder(tf.float32, shape = (None, None, None, None))
|
68 |
+
self.imagesAll = tf.transpose(self.imagesPlaceholder, (0, 2, 3, 1))
|
69 |
+
# self.imageH = tf.shape(self.imagesAll)[1]
|
70 |
+
# self.imageW = tf.shape(self.imagesAll)[2]
|
71 |
+
|
72 |
+
# answers
|
73 |
+
self.answersIndicesAll = tf.placeholder(tf.int32, shape = (None, ))
|
74 |
+
|
75 |
+
## optimization
|
76 |
+
self.lr = tf.placeholder(tf.float32, shape = ())
|
77 |
+
self.train = tf.placeholder(tf.bool, shape = ())
|
78 |
+
self.batchSizeAll = tf.shape(self.questionsIndicesAll)[0]
|
79 |
+
|
80 |
+
## dropouts
|
81 |
+
# TODO: change dropouts to be 1 - current
|
82 |
+
self.dropouts = {
|
83 |
+
"encInput": tf.placeholder(tf.float32, shape = ()),
|
84 |
+
"encState": tf.placeholder(tf.float32, shape = ()),
|
85 |
+
"stem": tf.placeholder(tf.float32, shape = ()),
|
86 |
+
"question": tf.placeholder(tf.float32, shape = ()),
|
87 |
+
# self.dropouts["question"]Out = tf.placeholder(tf.float32, shape = ())
|
88 |
+
# self.dropouts["question"]MAC = tf.placeholder(tf.float32, shape = ())
|
89 |
+
"read": tf.placeholder(tf.float32, shape = ()),
|
90 |
+
"write": tf.placeholder(tf.float32, shape = ()),
|
91 |
+
"memory": tf.placeholder(tf.float32, shape = ()),
|
92 |
+
"output": tf.placeholder(tf.float32, shape = ())
|
93 |
+
}
|
94 |
+
|
95 |
+
# batch norm params
|
96 |
+
self.batchNorm = {"decay": config.bnDecay, "train": self.train}
|
97 |
+
|
98 |
+
# if config.parametricDropout:
|
99 |
+
# self.dropouts["question"] = parametricDropout("qDropout", self.train)
|
100 |
+
# self.dropouts["read"] = parametricDropout("readDropout", self.train)
|
101 |
+
# else:
|
102 |
+
# self.dropouts["question"] = self.dropouts["_q"]
|
103 |
+
# self.dropouts["read"] = self.dropouts["_read"]
|
104 |
+
|
105 |
+
# if config.tempDynamic:
|
106 |
+
# self.tempAnnealRate = tf.placeholder(tf.float32, shape = ())
|
107 |
+
|
108 |
+
self.H, self.W, self.imageInDim = config.imageDims
|
109 |
+
|
110 |
+
# Feeds data into placeholders. See addPlaceholders method for further details.
|
111 |
+
def createFeedDict(self, data, images, train):
|
112 |
+
feedDict = {
|
113 |
+
self.questionsIndicesAll: np.array(data["question"]),
|
114 |
+
self.questionLengthsAll: np.array(data["questionLength"]),
|
115 |
+
self.imagesPlaceholder: images,
|
116 |
+
# self.answersIndicesAll: [0],
|
117 |
+
|
118 |
+
self.dropouts["encInput"]: config.encInputDropout if train else 1.0,
|
119 |
+
self.dropouts["encState"]: config.encStateDropout if train else 1.0,
|
120 |
+
self.dropouts["stem"]: config.stemDropout if train else 1.0,
|
121 |
+
self.dropouts["question"]: config.qDropout if train else 1.0, #_
|
122 |
+
self.dropouts["memory"]: config.memoryDropout if train else 1.0,
|
123 |
+
self.dropouts["read"]: config.readDropout if train else 1.0, #_
|
124 |
+
self.dropouts["write"]: config.writeDropout if train else 1.0,
|
125 |
+
self.dropouts["output"]: config.outputDropout if train else 1.0,
|
126 |
+
# self.dropouts["question"]Out: config.qDropoutOut if train else 1.0,
|
127 |
+
# self.dropouts["question"]MAC: config.qDropoutMAC if train else 1.0,
|
128 |
+
|
129 |
+
self.lr: config.lr,
|
130 |
+
self.train: train
|
131 |
+
}
|
132 |
+
|
133 |
+
# if config.tempDynamic:
|
134 |
+
# feedDict[self.tempAnnealRate] = tempAnnealRate
|
135 |
+
|
136 |
+
return feedDict
|
137 |
+
|
138 |
+
# Splits data to a specific GPU (tower) for parallelization
|
139 |
+
def initTowerBatch(self, towerI, towersNum, dataSize):
|
140 |
+
towerBatchSize = tf.floordiv(dataSize, towersNum)
|
141 |
+
start = towerI * towerBatchSize
|
142 |
+
end = (towerI + 1) * towerBatchSize if towerI < towersNum - 1 else dataSize
|
143 |
+
|
144 |
+
self.questionsIndices = self.questionsIndicesAll[start:end]
|
145 |
+
self.questionLengths = self.questionLengthsAll[start:end]
|
146 |
+
self.images = self.imagesAll[start:end]
|
147 |
+
self.answersIndices = self.answersIndicesAll[start:end]
|
148 |
+
|
149 |
+
self.batchSize = end - start
|
150 |
+
|
151 |
+
'''
|
152 |
+
The Image Input Unit (stem). Passes the image features through a CNN-network
|
153 |
+
Optionally adds position encoding (doesn't in the default behavior).
|
154 |
+
Flatten the image into Height * Width "Knowledge base" array.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
images: image input. [batchSize, height, width, inDim]
|
158 |
+
inDim: input image dimension
|
159 |
+
outDim: image out dimension
|
160 |
+
addLoc: if not None, adds positional encoding to the image
|
161 |
+
|
162 |
+
Returns preprocessed images.
|
163 |
+
[batchSize, height * width, outDim]
|
164 |
+
'''
|
165 |
+
def stem(self, images, inDim, outDim, addLoc = None):
|
166 |
+
|
167 |
+
with tf.variable_scope("stem"):
|
168 |
+
if addLoc is None:
|
169 |
+
addLoc = config.locationAware
|
170 |
+
|
171 |
+
if config.stemLinear:
|
172 |
+
features = ops.linear(images, inDim, outDim)
|
173 |
+
else:
|
174 |
+
dims = [inDim] + ([config.stemDim] * (config.stemNumLayers - 1)) + [outDim]
|
175 |
+
|
176 |
+
if addLoc:
|
177 |
+
images, inDim = ops.addLocation(images, inDim, config.locationDim,
|
178 |
+
h = self.H, w = self.W, locType = config.locationType)
|
179 |
+
dims[0] = inDim
|
180 |
+
|
181 |
+
# if config.locationType == "PE":
|
182 |
+
# dims[-1] /= 4
|
183 |
+
# dims[-1] *= 3
|
184 |
+
# else:
|
185 |
+
# dims[-1] -= 2
|
186 |
+
features = ops.CNNLayer(images, dims,
|
187 |
+
batchNorm = self.batchNorm if config.stemBN else None,
|
188 |
+
dropout = self.dropouts["stem"],
|
189 |
+
kernelSizes = config.stemKernelSizes,
|
190 |
+
strides = config.stemStrideSizes)
|
191 |
+
|
192 |
+
# if addLoc:
|
193 |
+
# lDim = outDim / 4
|
194 |
+
# lDim /= 4
|
195 |
+
# features, _ = addLocation(features, dims[-1], lDim, h = H, w = W,
|
196 |
+
# locType = config.locationType)
|
197 |
+
|
198 |
+
if config.stemGridRnn:
|
199 |
+
features = ops.multigridRNNLayer(features, H, W, outDim)
|
200 |
+
|
201 |
+
# flatten the 2d images into a 1d KB
|
202 |
+
features = tf.reshape(features, (self.batchSize, -1, outDim))
|
203 |
+
|
204 |
+
return features
|
205 |
+
|
206 |
+
# Embed question using parametrized word embeddings.
|
207 |
+
# The embedding are initialized to the values supported to the class initialization
|
208 |
+
def qEmbeddingsOp(self, qIndices, embInit):
|
209 |
+
with tf.variable_scope("qEmbeddings"):
|
210 |
+
# if config.useCPU:
|
211 |
+
# with tf.device('/cpu:0'):
|
212 |
+
# embeddingsVar = tf.Variable(self.embeddingsInit, name = "embeddings", dtype = tf.float32)
|
213 |
+
# else:
|
214 |
+
# embeddingsVar = tf.Variable(self.embeddingsInit, name = "embeddings", dtype = tf.float32)
|
215 |
+
embeddingsVar = tf.get_variable("emb", initializer = tf.to_float(embInit),
|
216 |
+
dtype = tf.float32, trainable = (not config.wrdEmbFixed))
|
217 |
+
embeddings = tf.concat([tf.zeros((1, config.wrdEmbDim)), embeddingsVar], axis = 0)
|
218 |
+
questions = tf.nn.embedding_lookup(embeddings, qIndices)
|
219 |
+
|
220 |
+
return questions, embeddings
|
221 |
+
|
222 |
+
# Embed answer words
|
223 |
+
def aEmbeddingsOp(self, embInit):
|
224 |
+
with tf.variable_scope("aEmbeddings"):
|
225 |
+
if embInit is None:
|
226 |
+
return None
|
227 |
+
answerEmbeddings = tf.get_variable("emb", initializer = tf.to_float(embInit),
|
228 |
+
dtype = tf.float32)
|
229 |
+
return answerEmbeddings
|
230 |
+
|
231 |
+
# Embed question and answer words with tied embeddings
|
232 |
+
def qaEmbeddingsOp(self, qIndices, embInit):
|
233 |
+
questions, qaEmbeddings = self.qEmbeddingsOp(qIndices, embInit["qa"])
|
234 |
+
aEmbeddings = tf.nn.embedding_lookup(qaEmbeddings, embInit["ansMap"])
|
235 |
+
|
236 |
+
return questions, qaEmbeddings, aEmbeddings
|
237 |
+
|
238 |
+
'''
|
239 |
+
Embed question (and optionally answer) using parametrized word embeddings.
|
240 |
+
The embedding are initialized to the values supported to the class initialization
|
241 |
+
'''
|
242 |
+
def embeddingsOp(self, qIndices, embInit):
|
243 |
+
if config.ansEmbMod == "SHARED":
|
244 |
+
questions, qEmb, aEmb = self.qaEmbeddingsOp(qIndices, embInit)
|
245 |
+
else:
|
246 |
+
questions, qEmb = self.qEmbeddingsOp(qIndices, embInit["q"])
|
247 |
+
aEmb = self.aEmbeddingsOp(embInit["a"])
|
248 |
+
|
249 |
+
return questions, qEmb, aEmb
|
250 |
+
|
251 |
+
'''
|
252 |
+
The Question Input Unit embeds the questions to randomly-initialized word vectors,
|
253 |
+
and runs a recurrent bidirectional encoder (RNN/LSTM etc.) that gives back
|
254 |
+
vector representations for each question (the RNN final hidden state), and
|
255 |
+
representations for each of the question words (the RNN outputs for each word).
|
256 |
+
|
257 |
+
The method uses bidirectional LSTM, by default.
|
258 |
+
Optionally projects the outputs of the LSTM (with linear projection /
|
259 |
+
optionally with some activation).
|
260 |
+
|
261 |
+
Args:
|
262 |
+
questions: question word embeddings
|
263 |
+
[batchSize, questionLength, wordEmbDim]
|
264 |
+
|
265 |
+
questionLengths: the question lengths.
|
266 |
+
[batchSize]
|
267 |
+
|
268 |
+
projWords: True to apply projection on RNN outputs.
|
269 |
+
projQuestion: True to apply projection on final RNN state.
|
270 |
+
projDim: projection dimension in case projection is applied.
|
271 |
+
|
272 |
+
Returns:
|
273 |
+
Contextual Words: RNN outputs for the words.
|
274 |
+
[batchSize, questionLength, ctrlDim]
|
275 |
+
|
276 |
+
Vectorized Question: Final hidden state representing the whole question.
|
277 |
+
[batchSize, ctrlDim]
|
278 |
+
'''
|
279 |
+
def encoder(self, questions, questionLengths, projWords = False,
|
280 |
+
projQuestion = False, projDim = None):
|
281 |
+
|
282 |
+
with tf.variable_scope("encoder"):
|
283 |
+
# variational dropout option
|
284 |
+
varDp = None
|
285 |
+
if config.encVariationalDropout:
|
286 |
+
varDp = {"stateDp": self.dropouts["stateInput"],
|
287 |
+
"inputDp": self.dropouts["encInput"],
|
288 |
+
"inputSize": config.wrdEmbDim}
|
289 |
+
|
290 |
+
# rnns
|
291 |
+
for i in range(config.encNumLayers):
|
292 |
+
questionCntxWords, vecQuestions = ops.RNNLayer(questions, questionLengths,
|
293 |
+
config.encDim, bi = config.encBi, cellType = config.encType,
|
294 |
+
dropout = self.dropouts["encInput"], varDp = varDp, name = "rnn%d" % i)
|
295 |
+
|
296 |
+
# dropout for the question vector
|
297 |
+
vecQuestions = tf.nn.dropout(vecQuestions, self.dropouts["question"])
|
298 |
+
|
299 |
+
# projection of encoder outputs
|
300 |
+
if projWords:
|
301 |
+
questionCntxWords = ops.linear(questionCntxWords, config.encDim, projDim,
|
302 |
+
name = "projCW")
|
303 |
+
if projQuestion:
|
304 |
+
vecQuestions = ops.linear(vecQuestions, config.encDim, projDim,
|
305 |
+
act = config.encProjQAct, name = "projQ")
|
306 |
+
|
307 |
+
return questionCntxWords, vecQuestions
|
308 |
+
|
309 |
+
'''
|
310 |
+
Stacked Attention Layer for baseline. Computes interaction between images
|
311 |
+
and the previous memory, and casts it back to compute attention over the
|
312 |
+
image, which in turn is summed up with the previous memory to result in the
|
313 |
+
new one.
|
314 |
+
|
315 |
+
Args:
|
316 |
+
images: input image.
|
317 |
+
[batchSize, H * W, inDim]
|
318 |
+
|
319 |
+
memory: previous memory value
|
320 |
+
[batchSize, inDim]
|
321 |
+
|
322 |
+
inDim: inputs dimension
|
323 |
+
hDim: hidden dimension to compute interactions between image and memory
|
324 |
+
|
325 |
+
Returns the new memory value.
|
326 |
+
'''
|
327 |
+
def baselineAttLayer(self, images, memory, inDim, hDim, name = "", reuse = None):
|
328 |
+
with tf.variable_scope("attLayer" + name, reuse = reuse):
|
329 |
+
# projImages = ops.linear(images, inDim, hDim, name = "projImage")
|
330 |
+
# projMemory = tf.expand_dims(ops.linear(memory, inDim, hDim, name = "projMemory"), axis = -2)
|
331 |
+
# if config.saMultiplicative:
|
332 |
+
# interactions = projImages * projMemory
|
333 |
+
# else:
|
334 |
+
# interactions = tf.tanh(projImages + projMemory)
|
335 |
+
interactions, _ = ops.mul(images, memory, inDim, proj = {"dim": hDim, "shared": False},
|
336 |
+
interMod = config.baselineAttType)
|
337 |
+
|
338 |
+
attention = ops.inter2att(interactions, hDim)
|
339 |
+
summary = ops.att2Smry(attention, images)
|
340 |
+
newMemory = memory + summary
|
341 |
+
|
342 |
+
return newMemory
|
343 |
+
|
344 |
+
'''
|
345 |
+
Baseline approach:
|
346 |
+
If baselineAtt is True, applies several layers (baselineAttNumLayers)
|
347 |
+
of stacked attention to image and memory, when memory is initialized
|
348 |
+
to the vector questions. See baselineAttLayer for further details.
|
349 |
+
|
350 |
+
Otherwise, computes result output features based on image representation
|
351 |
+
(baselineCNN), or question (baselineLSTM) or both.
|
352 |
+
|
353 |
+
Args:
|
354 |
+
vecQuestions: question vector representation
|
355 |
+
[batchSize, questionDim]
|
356 |
+
|
357 |
+
questionDim: dimension of question vectors
|
358 |
+
|
359 |
+
images: (flattened) image representation
|
360 |
+
[batchSize, imageDim]
|
361 |
+
|
362 |
+
imageDim: dimension of image representations.
|
363 |
+
|
364 |
+
hDim: hidden dimension to compute interactions between image and memory
|
365 |
+
(for attention-based baseline).
|
366 |
+
|
367 |
+
Returns final features to use in later classifier.
|
368 |
+
[batchSize, outDim] (out dimension depends on baseline method)
|
369 |
+
'''
|
370 |
+
def baseline(self, vecQuestions, questionDim, images, imageDim, hDim):
|
371 |
+
with tf.variable_scope("baseline"):
|
372 |
+
if config.baselineAtt:
|
373 |
+
memory = self.linear(vecQuestions, questionDim, hDim, name = "qProj")
|
374 |
+
images = self.linear(images, imageDim, hDim, name = "iProj")
|
375 |
+
|
376 |
+
for i in range(config.baselineAttNumLayers):
|
377 |
+
memory = self.baselineAttLayer(images, memory, hDim, hDim,
|
378 |
+
name = "baseline%d" % i)
|
379 |
+
memDim = hDim
|
380 |
+
else:
|
381 |
+
images, imagesDim = ops.linearizeFeatures(images, self.H, self.W,
|
382 |
+
imageDim, projDim = config.baselineProjDim)
|
383 |
+
if config.baselineLSTM and config.baselineCNN:
|
384 |
+
memory = tf.concat([vecQuestions, images], axis = -1)
|
385 |
+
memDim = questionDim + imageDim
|
386 |
+
elif config.baselineLSTM:
|
387 |
+
memory = vecQuestions
|
388 |
+
memDim = questionDim
|
389 |
+
else: # config.baselineCNN
|
390 |
+
memory = images
|
391 |
+
memDim = imageDim
|
392 |
+
|
393 |
+
return memory, memDim
|
394 |
+
|
395 |
+
'''
|
396 |
+
Runs the MAC recurrent network to perform the reasoning process.
|
397 |
+
Initializes a MAC cell and runs netLength iterations.
|
398 |
+
|
399 |
+
Currently it passes the question and knowledge base to the cell during
|
400 |
+
its creating, such that it doesn't need to interact with it through
|
401 |
+
inputs / outputs while running. The recurrent computation happens
|
402 |
+
by working iteratively over the hidden (control, memory) states.
|
403 |
+
|
404 |
+
Args:
|
405 |
+
images: flattened image features. Used as the "Knowledge Base".
|
406 |
+
(Received by default model behavior from the Image Input Units).
|
407 |
+
[batchSize, H * W, memDim]
|
408 |
+
|
409 |
+
vecQuestions: vector questions representations.
|
410 |
+
(Received by default model behavior from the Question Input Units
|
411 |
+
as the final RNN state).
|
412 |
+
[batchSize, ctrlDim]
|
413 |
+
|
414 |
+
questionWords: question word embeddings.
|
415 |
+
[batchSize, questionLength, ctrlDim]
|
416 |
+
|
417 |
+
questionCntxWords: question contextual words.
|
418 |
+
(Received by default model behavior from the Question Input Units
|
419 |
+
as the series of RNN output states).
|
420 |
+
[batchSize, questionLength, ctrlDim]
|
421 |
+
|
422 |
+
questionLengths: question lengths.
|
423 |
+
[batchSize]
|
424 |
+
|
425 |
+
Returns the final control state and memory state resulted from the network.
|
426 |
+
([batchSize, ctrlDim], [bathSize, memDim])
|
427 |
+
'''
|
428 |
+
def MACnetwork(self, images, vecQuestions, questionWords, questionCntxWords,
|
429 |
+
questionLengths, name = "", reuse = None):
|
430 |
+
|
431 |
+
with tf.variable_scope("MACnetwork" + name, reuse = reuse):
|
432 |
+
|
433 |
+
self.macCell = MACCell(
|
434 |
+
vecQuestions = vecQuestions,
|
435 |
+
questionWords = questionWords,
|
436 |
+
questionCntxWords = questionCntxWords,
|
437 |
+
questionLengths = questionLengths,
|
438 |
+
knowledgeBase = images,
|
439 |
+
memoryDropout = self.dropouts["memory"],
|
440 |
+
readDropout = self.dropouts["read"],
|
441 |
+
writeDropout = self.dropouts["write"],
|
442 |
+
# qDropoutMAC = self.qDropoutMAC,
|
443 |
+
batchSize = self.batchSize,
|
444 |
+
train = self.train,
|
445 |
+
reuse = reuse)
|
446 |
+
|
447 |
+
state = self.macCell.zero_state(self.batchSize, tf.float32)
|
448 |
+
|
449 |
+
# inSeq = tf.unstack(inSeq, axis = 1)
|
450 |
+
none = tf.zeros((self.batchSize, 1), dtype = tf.float32)
|
451 |
+
|
452 |
+
# for i, inp in enumerate(inSeq):
|
453 |
+
for i in range(config.netLength):
|
454 |
+
self.macCell.iteration = i
|
455 |
+
# if config.unsharedCells:
|
456 |
+
# with tf.variable_scope("iteration%d" % i):
|
457 |
+
# macCell.myNameScope = "iteration%d" % i
|
458 |
+
_, state = self.macCell(none, state)
|
459 |
+
# else:
|
460 |
+
# _, state = macCell(none, state)
|
461 |
+
# macCell.reuse = True
|
462 |
+
|
463 |
+
# self.autoEncMMLoss = macCell.autoEncMMLossI
|
464 |
+
# inputSeqL = None
|
465 |
+
# _, lastOutputs = tf.nn.dynamic_rnn(macCell, inputSeq, # / static
|
466 |
+
# sequence_length = inputSeqL,
|
467 |
+
# initial_state = initialState,
|
468 |
+
# swap_memory = True)
|
469 |
+
|
470 |
+
# self.postModules = None
|
471 |
+
# if (config.controlPostRNN or config.selfAttentionMod == "POST"): # may not work well with dlogits
|
472 |
+
# self.postModules, _ = self.RNNLayer(cLogits, None, config.encDim, bi = False,
|
473 |
+
# name = "decPostRNN", cellType = config.controlPostRNNmod)
|
474 |
+
# if config.controlPostRNN:
|
475 |
+
# logits = self.postModules
|
476 |
+
# self.postModules = tf.unstack(self.postModules, axis = 1)
|
477 |
+
|
478 |
+
# self.autoEncCtrlLoss = tf.constant(0.0)
|
479 |
+
# if config.autoEncCtrl:
|
480 |
+
# autoEncCtrlCellType = ("GRU" if config.autoEncCtrlGRU else "RNN")
|
481 |
+
# autoEncCtrlinp = logits
|
482 |
+
# _, autoEncHid = self.RNNLayer(autoEncCtrlinp, None, config.encDim,
|
483 |
+
# bi = True, name = "autoEncCtrl", cellType = autoEncCtrlCellType)
|
484 |
+
# self.autoEncCtrlLoss = (tf.nn.l2_loss(vecQuestions - autoEncHid)) / tf.to_float(self.batchSize)
|
485 |
+
|
486 |
+
finalControl = state.control
|
487 |
+
finalMemory = state.memory
|
488 |
+
|
489 |
+
return finalControl, finalMemory
|
490 |
+
|
491 |
+
'''
|
492 |
+
Output Unit (step 1): chooses the inputs to the output classifier.
|
493 |
+
|
494 |
+
By default the classifier input will be the the final memory state of the MAC network.
|
495 |
+
If outQuestion is True, concatenate the question representation to that.
|
496 |
+
If outImage is True, concatenate the image flattened representation.
|
497 |
+
|
498 |
+
Args:
|
499 |
+
memory: (final) memory state of the MAC network.
|
500 |
+
[batchSize, memDim]
|
501 |
+
|
502 |
+
vecQuestions: question vector representation.
|
503 |
+
[batchSize, ctrlDim]
|
504 |
+
|
505 |
+
images: image features.
|
506 |
+
[batchSize, H, W, imageInDim]
|
507 |
+
|
508 |
+
imageInDim: images dimension.
|
509 |
+
|
510 |
+
Returns the resulted features and their dimension.
|
511 |
+
'''
|
512 |
+
def outputOp(self, memory, vecQuestions, images, imageInDim):
|
513 |
+
with tf.variable_scope("outputUnit"):
|
514 |
+
features = memory
|
515 |
+
dim = config.memDim
|
516 |
+
|
517 |
+
if config.outQuestion:
|
518 |
+
eVecQuestions = ops.linear(vecQuestions, config.ctrlDim, config.memDim, name = "outQuestion")
|
519 |
+
features, dim = ops.concat(features, eVecQuestions, config.memDim, mul = config.outQuestionMul)
|
520 |
+
|
521 |
+
if config.outImage:
|
522 |
+
images, imagesDim = ops.linearizeFeatures(images, self.H, self.W, self.imageInDim,
|
523 |
+
outputDim = config.outImageDim)
|
524 |
+
images = ops.linear(images, config.memDim, config.outImageDim, name = "outImage")
|
525 |
+
features = tf.concat([features, images], axis = -1)
|
526 |
+
dim += config.outImageDim
|
527 |
+
|
528 |
+
return features, dim
|
529 |
+
|
530 |
+
'''
|
531 |
+
Output Unit (step 2): Computes the logits for the answers. Passes the features
|
532 |
+
through fully-connected network to get the logits over the possible answers.
|
533 |
+
Optionally uses answer word embeddings in computing the logits (by default, it doesn't).
|
534 |
+
|
535 |
+
Args:
|
536 |
+
features: features used to compute logits
|
537 |
+
[batchSize, inDim]
|
538 |
+
|
539 |
+
inDim: features dimension
|
540 |
+
|
541 |
+
aEmbedding: supported word embeddings for answer words in case answerMod is not NON.
|
542 |
+
Optionally computes logits by computing dot-product with answer embeddings.
|
543 |
+
|
544 |
+
Returns: the computed logits.
|
545 |
+
[batchSize, answerWordsNum]
|
546 |
+
'''
|
547 |
+
def classifier(self, features, inDim, aEmbeddings = None):
|
548 |
+
with tf.variable_scope("classifier"):
|
549 |
+
outDim = config.answerWordsNum
|
550 |
+
dims = [inDim] + config.outClassifierDims + [outDim]
|
551 |
+
if config.answerMod != "NON":
|
552 |
+
dims[-1] = config.wrdEmbDim
|
553 |
+
|
554 |
+
|
555 |
+
logits = ops.FCLayer(features, dims,
|
556 |
+
batchNorm = self.batchNorm if config.outputBN else None,
|
557 |
+
dropout = self.dropouts["output"])
|
558 |
+
|
559 |
+
if config.answerMod != "NON":
|
560 |
+
logits = tf.nn.dropout(logits, self.dropouts["output"])
|
561 |
+
interactions = ops.mul(aEmbeddings, logits, dims[-1], interMod = config.answerMod)
|
562 |
+
logits = ops.inter2logits(interactions, dims[-1], sumMod = "SUM")
|
563 |
+
logits += ops.getBias((outputDim, ), "ans")
|
564 |
+
|
565 |
+
# answersWeights = tf.transpose(aEmbeddings)
|
566 |
+
|
567 |
+
# if config.answerMod == "BL":
|
568 |
+
# Wans = ops.getWeight((dims[-1], config.wrdEmbDim), "ans")
|
569 |
+
# logits = tf.matmul(logits, Wans)
|
570 |
+
# elif config.answerMod == "DIAG":
|
571 |
+
# Wans = ops.getWeight((config.wrdEmbDim, ), "ans")
|
572 |
+
# logits = logits * Wans
|
573 |
+
|
574 |
+
# logits = tf.matmul(logits, answersWeights)
|
575 |
+
|
576 |
+
return logits
|
577 |
+
|
578 |
+
# def getTemp():
|
579 |
+
# with tf.variable_scope("temperature"):
|
580 |
+
# if config.tempParametric:
|
581 |
+
# self.temperatureVar = tf.get_variable("temperature", shape = (),
|
582 |
+
# initializer = tf.constant_initializer(5), dtype = tf.float32)
|
583 |
+
# temperature = tf.sigmoid(self.temperatureVar)
|
584 |
+
# else:
|
585 |
+
# temperature = config.temperature
|
586 |
+
|
587 |
+
# if config.tempDynamic:
|
588 |
+
# temperature *= self.tempAnnealRate
|
589 |
+
|
590 |
+
# return temperature
|
591 |
+
|
592 |
+
# Computes mean cross entropy loss between logits and answers.
|
593 |
+
def addAnswerLossOp(self, logits, answers):
|
594 |
+
with tf.variable_scope("answerLoss"):
|
595 |
+
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels = answers, logits = logits)
|
596 |
+
loss = tf.reduce_mean(losses)
|
597 |
+
self.answerLossList.append(loss)
|
598 |
+
|
599 |
+
return loss, losses
|
600 |
+
|
601 |
+
# Computes predictions (by finding maximal logit value, corresponding to highest probability)
|
602 |
+
# and mean accuracy between predictions and answers.
|
603 |
+
def addPredOp(self, logits, answers):
|
604 |
+
with tf.variable_scope("pred"):
|
605 |
+
preds = tf.to_int32(tf.argmax(logits, axis = -1)) # tf.nn.softmax(
|
606 |
+
corrects = tf.equal(preds, answers)
|
607 |
+
correctNum = tf.reduce_sum(tf.to_int32(corrects))
|
608 |
+
acc = tf.reduce_mean(tf.to_float(corrects))
|
609 |
+
self.correctNumList.append(correctNum)
|
610 |
+
self.answerAccList.append(acc)
|
611 |
+
|
612 |
+
return preds, corrects, correctNum
|
613 |
+
|
614 |
+
# Creates optimizer (adam)
|
615 |
+
def addOptimizerOp(self):
|
616 |
+
with tf.variable_scope("trainAddOptimizer"):
|
617 |
+
self.globalStep = tf.Variable(0, dtype = tf.int32, trainable = False, name = "globalStep") # init to 0 every run?
|
618 |
+
optimizer = tf.train.AdamOptimizer(learning_rate = self.lr)
|
619 |
+
|
620 |
+
return optimizer
|
621 |
+
|
622 |
+
'''
|
623 |
+
Computes gradients for all variables or subset of them, based on provided loss,
|
624 |
+
using optimizer.
|
625 |
+
'''
|
626 |
+
def computeGradients(self, optimizer, loss, trainableVars = None): # tf.trainable_variables()
|
627 |
+
with tf.variable_scope("computeGradients"):
|
628 |
+
if config.trainSubset:
|
629 |
+
trainableVars = []
|
630 |
+
allVars = tf.trainable_variables()
|
631 |
+
for var in allVars:
|
632 |
+
if any((s in var.name) for s in config.varSubset):
|
633 |
+
trainableVars.append(var)
|
634 |
+
|
635 |
+
gradients_vars = optimizer.compute_gradients(loss, trainableVars)
|
636 |
+
return gradients_vars
|
637 |
+
|
638 |
+
'''
|
639 |
+
Apply gradients. Optionally clip them, and update exponential moving averages
|
640 |
+
for parameters.
|
641 |
+
'''
|
642 |
+
def addTrainingOp(self, optimizer, gradients_vars):
|
643 |
+
with tf.variable_scope("train"):
|
644 |
+
gradients, variables = zip(*gradients_vars)
|
645 |
+
norm = tf.global_norm(gradients)
|
646 |
+
|
647 |
+
# gradient clipping
|
648 |
+
if config.clipGradients:
|
649 |
+
clippedGradients, _ = tf.clip_by_global_norm(gradients, config.gradMaxNorm, use_norm = norm)
|
650 |
+
gradients_vars = zip(clippedGradients, variables)
|
651 |
+
|
652 |
+
# updates ops (for batch norm) and train op
|
653 |
+
updateOps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
|
654 |
+
with tf.control_dependencies(updateOps):
|
655 |
+
train = optimizer.apply_gradients(gradients_vars, global_step = self.globalStep)
|
656 |
+
|
657 |
+
# exponential moving average
|
658 |
+
if config.useEMA:
|
659 |
+
ema = tf.train.ExponentialMovingAverage(decay = config.emaDecayRate)
|
660 |
+
maintainAveragesOp = ema.apply(tf.trainable_variables())
|
661 |
+
|
662 |
+
with tf.control_dependencies([train]):
|
663 |
+
trainAndUpdateOp = tf.group(maintainAveragesOp)
|
664 |
+
|
665 |
+
train = trainAndUpdateOp
|
666 |
+
|
667 |
+
self.emaDict = ema.variables_to_restore()
|
668 |
+
|
669 |
+
return train, norm
|
670 |
+
|
671 |
+
# TODO (add back support for multi-gpu..)
|
672 |
+
def averageAcrossTowers(self, gpusNum):
|
673 |
+
self.lossAll = self.lossList[0]
|
674 |
+
|
675 |
+
self.answerLossAll = self.answerLossList[0]
|
676 |
+
self.correctNumAll = self.correctNumList[0]
|
677 |
+
self.answerAccAll = self.answerAccList[0]
|
678 |
+
self.predsAll = self.predsList[0]
|
679 |
+
self.gradientVarsAll = self.gradientVarsList[0]
|
680 |
+
|
681 |
+
def trim2DVectors(self, vectors, vectorsLengths):
|
682 |
+
maxLength = np.max(vectorsLengths)
|
683 |
+
return vectors[:,:maxLength]
|
684 |
+
|
685 |
+
def trimData(self, data):
|
686 |
+
data["question"] = self.trim2DVectors(data["question"], data["questionLength"])
|
687 |
+
return data
|
688 |
+
|
689 |
+
'''
|
690 |
+
Builds predictions JSON, by adding the model's predictions and attention maps
|
691 |
+
back to the original data JSON.
|
692 |
+
'''
|
693 |
+
def buildPredsList(self, prediction):
|
694 |
+
|
695 |
+
return self.answerDict.decodeId(prediction)
|
696 |
+
|
697 |
+
'''
|
698 |
+
Processes a batch of data with the model.
|
699 |
+
|
700 |
+
Args:
|
701 |
+
sess: TF session
|
702 |
+
|
703 |
+
data: Data batch. Dictionary that contains numpy array for:
|
704 |
+
questions, questionLengths, answers.
|
705 |
+
See preprocess.py for further information of the batch structure.
|
706 |
+
|
707 |
+
images: batch of image features, as numpy array. images["images"] contains
|
708 |
+
[batchSize, channels, h, w]
|
709 |
+
|
710 |
+
train: True to run batch for training.
|
711 |
+
|
712 |
+
getAtt: True to return attention maps for question and image (and optionally
|
713 |
+
self-attention and gate values).
|
714 |
+
|
715 |
+
Returns results: e.g. loss, accuracy, running time.
|
716 |
+
'''
|
717 |
+
def runBatch(self, sess, data, images, train, getAtt = False):
|
718 |
+
data = self.trimData(data)
|
719 |
+
|
720 |
+
predsOp = self.predsAll
|
721 |
+
|
722 |
+
time0 = time.time()
|
723 |
+
feed = self.createFeedDict(data, images, train)
|
724 |
+
|
725 |
+
time1 = time.time()
|
726 |
+
predsInfo = sess.run(
|
727 |
+
predsOp,
|
728 |
+
feed_dict = feed)
|
729 |
+
time2 = time.time()
|
730 |
+
|
731 |
+
predsList = self.buildPredsList(predsInfo[0])
|
732 |
+
|
733 |
+
return predsList
|
734 |
+
|
735 |
+
def build(self):
|
736 |
+
self.addPlaceholders()
|
737 |
+
self.optimizer = self.addOptimizerOp()
|
738 |
+
|
739 |
+
self.gradientVarsList = []
|
740 |
+
self.lossList = []
|
741 |
+
|
742 |
+
self.answerLossList = []
|
743 |
+
self.correctNumList = []
|
744 |
+
self.answerAccList = []
|
745 |
+
self.predsList = []
|
746 |
+
|
747 |
+
with tf.variable_scope("macModel"):
|
748 |
+
for i in range(config.gpusNum):
|
749 |
+
with tf.device("/gpu:{}".format(i)):
|
750 |
+
with tf.name_scope("tower{}".format(i)) as scope:
|
751 |
+
self.initTowerBatch(i, config.gpusNum, self.batchSizeAll)
|
752 |
+
|
753 |
+
self.loss = tf.constant(0.0)
|
754 |
+
|
755 |
+
# embed questions words (and optionally answer words)
|
756 |
+
questionWords, qEmbeddings, aEmbeddings = \
|
757 |
+
self.embeddingsOp(self.questionsIndices, self.embeddingsInit)
|
758 |
+
|
759 |
+
projWords = projQuestion = ((config.encDim != config.ctrlDim) or config.encProj)
|
760 |
+
questionCntxWords, vecQuestions = self.encoder(questionWords,
|
761 |
+
self.questionLengths, projWords, projQuestion, config.ctrlDim)
|
762 |
+
|
763 |
+
# Image Input Unit (stem)
|
764 |
+
imageFeatures = self.stem(self.images, self.imageInDim, config.memDim)
|
765 |
+
|
766 |
+
# baseline model
|
767 |
+
if config.useBaseline:
|
768 |
+
output, dim = self.baseline(vecQuestions, config.ctrlDim,
|
769 |
+
self.images, self.imageInDim, config.attDim)
|
770 |
+
# MAC model
|
771 |
+
else:
|
772 |
+
# self.temperature = self.getTemp()
|
773 |
+
|
774 |
+
finalControl, finalMemory = self.MACnetwork(imageFeatures, vecQuestions,
|
775 |
+
questionWords, questionCntxWords, self.questionLengths)
|
776 |
+
|
777 |
+
# Output Unit - step 1 (preparing classifier inputs)
|
778 |
+
output, dim = self.outputOp(finalMemory, vecQuestions,
|
779 |
+
self.images, self.imageInDim)
|
780 |
+
|
781 |
+
# Output Unit - step 2 (classifier)
|
782 |
+
logits = self.classifier(output, dim, aEmbeddings)
|
783 |
+
|
784 |
+
# compute loss, predictions, accuracy
|
785 |
+
answerLoss, self.losses = self.addAnswerLossOp(logits, self.answersIndices)
|
786 |
+
self.preds, self.corrects, self.correctNum = self.addPredOp(logits, self.answersIndices)
|
787 |
+
self.loss += answerLoss
|
788 |
+
self.predsList.append(self.preds)
|
789 |
+
|
790 |
+
self.lossList.append(self.loss)
|
791 |
+
|
792 |
+
# compute gradients
|
793 |
+
gradient_vars = self.computeGradients(self.optimizer, self.loss, trainableVars = None)
|
794 |
+
self.gradientVarsList.append(gradient_vars)
|
795 |
+
|
796 |
+
# reuse variables in next towers
|
797 |
+
tf.get_variable_scope().reuse_variables()
|
798 |
+
|
799 |
+
self.averageAcrossTowers(config.gpusNum)
|
800 |
+
|
801 |
+
self.trainOp, self.gradNorm = self.addTrainingOp(self.optimizer, self.gradientVarsAll)
|
802 |
+
self.noOp = tf.no_op()
|