# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
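"""Generate synthetic spiking data from an RNN trained on an
integration-to-bound task.

This script restores a pretrained integration-to-bound RNN, drives it with
white-noise inputs, projects its hidden states through a fixed random matrix
to form underlying rates, spikifies each condition's rates nreplications
times, and writes train/validation splits to disk via write_datasets.
"""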
from __future__ import print_function

import numpy as np
import os
from six.moves import xrange
import tensorflow as tf

from utils import write_datasets
from synthetic_data_utils import normalize_rates
from synthetic_data_utils import get_train_n_valid_inds, nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds

DATA_DIR = "rnn_synth_data_v1.0"

flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
                    "Directory for saving data.")
flags.DEFINE_string("datafile_name", "itb_rnn",
                    "Name of data file for input case.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 800, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_float("train_percentage", 4.0/5.0,
                   "Percentage of train vs validation trials")
flags.DEFINE_integer("nreplications", 5,
                     "Number of spikifications of the same underlying rates.")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("max_firing_rate", 30.0,
                   "Map 1.0 of RNN to a spikes per second")
flags.DEFINE_float("u_std", 0.25,
                   "Std dev of input to integration to bound model")
flags.DEFINE_string("checkpoint_path", "SAMPLE_CHECKPOINT",
                    """Path to directory with checkpoints of model
                    trained on integration to bound task. Currently this
                    is a placeholder which tells the code to grab the
                    checkpoint that is provided with the code
                    (in /trained_itb/..). If you have your own checkpoint
                    you would like to restore, you would point it to
                    that path.""")
FLAGS = flags.FLAGS


class IntegrationToBoundModel(object):
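  """Vanilla tanh RNN with a linear readout.

  The variables defined here must match those of the pretrained
  integration-to-bound checkpoint restored below. Name suffixes encode
  shapes: n = number of units, b = batch size, o = output dimension.
  """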
  def __init__(self, N):
    scale = 0.8 / float(N**0.5)
    self.N = N
    self.Wh_nxn = tf.Variable(tf.random_normal([N, N], stddev=scale))
    self.b_1xn = tf.Variable(tf.zeros([1, N]))
    self.Bu_1xn = tf.Variable(tf.zeros([1, N]))
    self.Wro_nxo = tf.Variable(tf.random_normal([N, 1], stddev=scale))
    self.bro_o = tf.Variable(tf.zeros([1]))

  def call(self, h_tm1_bxn, u_bx1):
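    """Run one step of the RNN dynamics:

      h_t = tanh(h_{t-1} W_h + b + u_t * B_u)
      z_t = h_t W_ro + b_ro

    Args:
      h_tm1_bxn: previous hidden state, batch x N.
      u_bx1: scalar input for this time step, batch x 1.

    Returns:
      The readout z_t and the new hidden state h_t.
    """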
    act_t_bxn = (tf.matmul(h_tm1_bxn, self.Wh_nxn) + self.b_1xn +
                 u_bx1 * self.Bu_1xn)
    h_t_bxn = tf.nn.tanh(act_t_bxn)
    z_t = tf.nn.xw_plus_b(h_t_bxn, self.Wro_nxo, self.bro_o)
    return z_t, h_t_bxn

def get_data_batch(batch_size, T, rng, u_std):
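  """Draw a batch of white-noise inputs and their integration-to-bound
  targets, i.e. the running sum of the input clipped to [-1, 1].

  Args:
    batch_size: number of trials in the batch.
    T: number of time steps.
    rng: numpy RandomState used to draw the inputs.
    u_std: standard deviation of the white-noise input.

  Returns:
    u_bxt: batch_size x T inputs.
    labels_bxt: batch_size x T targets.
  """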
  u_bxt = rng.randn(batch_size, T) * u_std
  running_sum_b = np.zeros([batch_size])
  labels_bxt = np.zeros([batch_size, T])
  for t in xrange(T):
    running_sum_b += u_bxt[:, t]
    labels_bxt[:, t] += running_sum_b
  labels_bxt = np.clip(labels_bxt, -1, 1)
  return u_bxt, labels_bxt


rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N  # must be same N as in trained model (provided example is N = 50)
nreplications = FLAGS.nreplications
E = nreplications * C  # total number of trials
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)
batch_size = 1  # the RNN is run one trial at a time

model = IntegrationToBoundModel(N)
inputs_ph_t = [tf.placeholder(tf.float32,
                              shape=[None, 1]) for _ in range(ntimesteps)]
state = tf.zeros([batch_size, N])
saver = tf.train.Saver()

# Random projection used to mix the RNN's hidden states into observed rates.
P_nxn = rng.randn(N, N) / np.sqrt(N)

# Unroll the RNN for ntimesteps steps.
outputs_t = []
states_t = []

for inp in inputs_ph_t:
  output, state = model.call(state, inp)
  outputs_t.append(output)
  states_t.append(state)

with tf.Session() as sess:
  # restore the latest model ckpt
  if FLAGS.checkpoint_path == "SAMPLE_CHECKPOINT":
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_checkpoint_path = os.path.join(dir_path, "trained_itb/model-65000")
  else:
    model_checkpoint_path = FLAGS.checkpoint_path
  try:
    saver.restore(sess, model_checkpoint_path)
    print('Model restored from', model_checkpoint_path)
  except (tf.errors.OpError, ValueError):
    raise ValueError("No checkpoints to restore from, is the path %s correct?"
                     % model_checkpoint_path)

  # generate data for trials
  data_e = []
  u_e = []
  outs_e = []
  for _ in xrange(C):
    u_1xt, outs_1xt = get_data_batch(batch_size, ntimesteps, u_rng, FLAGS.u_std)

    feed_dict = {}
    for t in xrange(ntimesteps):
      feed_dict[inputs_ph_t[t]] = np.reshape(u_1xt[:, t], (batch_size, -1))

    states_t_bxn, outputs_t_bxn = sess.run([states_t, outputs_t],
                                           feed_dict=feed_dict)
    # Stack the ntimesteps states of shape (batch, N) into an N x T array.
    states_nxt = np.transpose(np.squeeze(np.asarray(states_t_bxn)))
    outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
    # Mix the hidden states through the random projection.
    r_sxt = np.dot(P_nxn, states_nxt)

    # Replicate each condition so it can be spikified nreplications times.
    for _ in xrange(nreplications):
      data_e.append(r_sxt)
      u_e.append(u_1xt)
      outs_e.append(outputs_t_bxn)

  truth_data_e = normalize_rates(data_e, E, N)

# Turn the underlying rates into spike trains; each replication of the same
# rates gets an independent spikification.
spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
                              max_firing_rate=FLAGS.max_firing_rate)
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nreplications)

data_train_truth, data_valid_truth = split_list_by_inds(truth_data_e,
                                                        train_inds,
                                                        valid_inds)
data_train_spiking, data_valid_spiking = split_list_by_inds(spiking_data_e,
                                                            train_inds,
                                                            valid_inds)

data_train_truth = nparray_and_transpose(data_train_truth)
data_valid_truth = nparray_and_transpose(data_valid_truth)
data_train_spiking = nparray_and_transpose(data_train_spiking)
data_valid_spiking = nparray_and_transpose(data_valid_spiking)

# Save the inputs used to generate this data.
train_inputs_u, valid_inputs_u = split_list_by_inds(u_e,
                                                    train_inds,
                                                    valid_inds)
train_inputs_u = nparray_and_transpose(train_inputs_u)
valid_inputs_u = nparray_and_transpose(valid_inputs_u)

# Save the network outputs (these may be useful later).
train_outputs_u, valid_outputs_u = split_list_by_inds(outs_e,
                                                      train_inds,
                                                      valid_inds)
train_outputs_u = np.array(train_outputs_u)
valid_outputs_u = np.array(valid_outputs_u)


data = {'train_truth': data_train_truth,
        'valid_truth': data_valid_truth,
        'train_data': data_train_spiking,
        'valid_data': data_valid_spiking,
        'train_percentage': train_percentage,
        'nreplications': nreplications,
        'dt': FLAGS.dt,
        'u_std': FLAGS.u_std,
        'max_firing_rate': FLAGS.max_firing_rate,
        'train_inputs_u': train_inputs_u,
        'valid_inputs_u': valid_inputs_u,
        'train_outputs_u': train_outputs_u,
        'valid_outputs_u': valid_outputs_u,
        # Multiplying a rate of 1.0 by this factor gives the expected
        # number of spikes per time bin.
        'conversion_factor': FLAGS.max_firing_rate / (1.0 / FLAGS.dt)}

# just one dataset here
datasets = {}
dataset_name = 'dataset_N' + str(N)
datasets[dataset_name] = data

# write out the dataset
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
print('Saved to', os.path.join(FLAGS.save_dir,
                               FLAGS.datafile_name + '_' + dataset_name))
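
# For reference, a sketch of reading the saved data back. This assumes
# write_datasets (from utils) stores each dataset as a single HDF5 file named
# datafile_name + '_' + dataset_name, with the dict keys above as HDF5
# datasets; check utils.write_datasets for the actual on-disk layout.
#
#   import h5py
#   fname = os.path.join(FLAGS.save_dir,
#                        FLAGS.datafile_name + '_' + dataset_name)
#   with h5py.File(fname, 'r') as hf:
#     train_spikes = np.array(hf['train_data'])  # spikified training trials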