# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
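"""Generate synthetic spiking data from an RNN trained on the
integration-to-bound (ITB) task.

The script restores a trained IntegrationToBoundModel checkpoint, drives it
with white-noise inputs, projects its hidden states through a random matrix
to get underlying rates, spikifies those rates, and saves the result with
write_datasets.

Example invocation (assuming this file is saved as generate_itb_data.py;
the flag values shown are the defaults):
  python generate_itb_data.py --save_dir=/tmp/rnn_synth_data_v1.0/ \
      --datafile_name=itb_rnn --C=800 --N=50
"""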
from __future__ import print_function

import h5py
import numpy as np
import os
from six.moves import xrange
import tensorflow as tf

from utils import write_datasets
from synthetic_data_utils import normalize_rates
from synthetic_data_utils import get_train_n_valid_inds, nparray_and_transpose
from synthetic_data_utils import spikify_data, split_list_by_inds

DATA_DIR = "rnn_synth_data_v1.0"

flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
                    "Directory for saving data.")
flags.DEFINE_string("datafile_name", "itb_rnn",
                    "Name of the data file for the input case.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 800, "Number of conditions.")
flags.DEFINE_integer("N", 50, "Number of units in the RNN.")
flags.DEFINE_float("train_percentage", 4.0/5.0,
                   "Fraction of trials to use for training vs validation.")
flags.DEFINE_integer("nreplications", 5,
                     "Number of spikifications of the same underlying rates.")
flags.DEFINE_float("tau", 0.025, "Time constant of the RNN, in seconds.")
flags.DEFINE_float("dt", 0.010, "Width of a time bin, in seconds.")
flags.DEFINE_float("max_firing_rate", 30.0,
                   "Firing rate (spikes/s) that an RNN rate of 1.0 maps to.")
flags.DEFINE_float("u_std", 0.25,
                   "Std dev of the input to the integration-to-bound model.")
flags.DEFINE_string("checkpoint_path", "SAMPLE_CHECKPOINT",
                    """Path to directory with checkpoints of the model
                    trained on the integration-to-bound task. Currently this
                    is a placeholder which tells the code to grab the
                    checkpoint that is provided with the code
                    (in /trained_itb/..). If you have your own checkpoint
                    you would like to restore, point this flag at that
                    path.""")
FLAGS = flags.FLAGS
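

# The data-generating network is a vanilla tanh RNN with a linear readout:
#   h_t = tanh(h_{t-1} Wh + b + u_t * Bu),   z_t = h_t Wro + bro
# Weights are initialized with std dev 0.8/sqrt(N), but the values actually
# used are the trained parameters restored from the checkpoint below.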
class IntegrationToBoundModel:
  def __init__(self, N):
    scale = 0.8 / float(N**0.5)
    self.N = N
    self.Wh_nxn = tf.Variable(tf.random_normal([N, N], stddev=scale))
    self.b_1xn = tf.Variable(tf.zeros([1, N]))
    self.Bu_1xn = tf.Variable(tf.zeros([1, N]))
    self.Wro_nxo = tf.Variable(tf.random_normal([N, 1], stddev=scale))
    self.bro_o = tf.Variable(tf.zeros([1]))

  def call(self, h_tm1_bxn, u_bx1):
    act_t_bxn = (tf.matmul(h_tm1_bxn, self.Wh_nxn) + self.b_1xn +
                 u_bx1 * self.Bu_1xn)
    h_t_bxn = tf.nn.tanh(act_t_bxn)
    z_t = tf.nn.xw_plus_b(h_t_bxn, self.Wro_nxo, self.bro_o)
    return z_t, h_t_bxn
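

# Integration-to-bound task: the input is white noise with std dev u_std and
# the target at each time step is the running sum of the inputs so far,
# clipped to the bounds at -1 and 1.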
def get_data_batch(batch_size, T, rng, u_std):
  u_bxt = rng.randn(batch_size, T) * u_std
  running_sum_b = np.zeros([batch_size])
  labels_bxt = np.zeros([batch_size, T])
  for t in xrange(T):
    running_sum_b += u_bxt[:, t]
    labels_bxt[:, t] += running_sum_b
  labels_bxt = np.clip(labels_bxt, -1, 1)
  return u_bxt, labels_bxt
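

# Two independent random streams: `rng` drives the random projection and the
# spikification, while `u_rng` (seeded with an offset) drives the white-noise
# inputs, so the two sources of randomness can be varied separately.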
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed + 1)
T = FLAGS.T
C = FLAGS.C
N = FLAGS.N  # must match the N of the trained model (provided example: N = 50)
nreplications = FLAGS.nreplications
E = nreplications * C  # total number of trials
train_percentage = FLAGS.train_percentage
ntimesteps = int(T / FLAGS.dt)
batch_size = 1  # gives one example per trial

model = IntegrationToBoundModel(N)
inputs_ph_t = [tf.placeholder(tf.float32,
                              shape=[None, 1]) for _ in range(ntimesteps)]
state = tf.zeros([batch_size, N])
saver = tf.train.Saver()

# Random projection of the N hidden units onto N observed dimensions.
P_nxn = rng.randn(N, N) / np.sqrt(N)

# Unroll the RNN for ntimesteps, one placeholder input per time step.
outputs_t = []
states_t = []
for inp in inputs_ph_t:
  output, state = model.call(state, inp)
  outputs_t.append(output)
  states_t.append(state)
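
# `states_t` and `outputs_t` are plain Python lists of per-timestep tensors;
# a single sess.run below evaluates the whole unrolled graph for one trial.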
with tf.Session() as sess:
  # Restore the latest model checkpoint.
  if FLAGS.checkpoint_path == "SAMPLE_CHECKPOINT":
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_checkpoint_path = os.path.join(dir_path, "trained_itb/model-65000")
  else:
    model_checkpoint_path = FLAGS.checkpoint_path

  try:
    saver.restore(sess, model_checkpoint_path)
    print('Model restored from', model_checkpoint_path)
  except Exception:
    assert False, ("No checkpoints to restore from, is the path %s correct?"
                   % model_checkpoint_path)

  # generate data for trials
  data_e = []
  u_e = []
  outs_e = []
  for c in range(C):
    u_1xt, outs_1xt = get_data_batch(batch_size, ntimesteps, u_rng,
                                     FLAGS.u_std)

    feed_dict = {}
    for t in xrange(ntimesteps):
      feed_dict[inputs_ph_t[t]] = np.reshape(u_1xt[:, t], (batch_size, -1))

    states_t_bxn, outputs_t_bxn = sess.run([states_t, outputs_t],
                                           feed_dict=feed_dict)
    states_nxt = np.transpose(np.squeeze(np.asarray(states_t_bxn)))
    outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
    r_sxt = np.dot(P_nxn, states_nxt)

    for s in xrange(nreplications):
      data_e.append(r_sxt)
      u_e.append(u_1xt)
      outs_e.append(outputs_t_bxn)
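
# Every condition's projected rates appear nreplications times in data_e, so
# spikify_data can draw an independent spike train for each replication of
# the same underlying rates (normalize_rates and spikify_data are defined in
# synthetic_data_utils).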
truth_data_e = normalize_rates(data_e, E, N)

spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
                              max_firing_rate=FLAGS.max_firing_rate)
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nreplications)

data_train_truth, data_valid_truth = split_list_by_inds(truth_data_e,
                                                        train_inds,
                                                        valid_inds)
data_train_spiking, data_valid_spiking = split_list_by_inds(spiking_data_e,
                                                            train_inds,
                                                            valid_inds)
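
# nparray_and_transpose (from synthetic_data_utils) stacks each list of
# per-trial matrices into one 3D array, with trials along the first axis.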
data_train_truth = nparray_and_transpose(data_train_truth)
data_valid_truth = nparray_and_transpose(data_valid_truth)
data_train_spiking = nparray_and_transpose(data_train_spiking)
data_valid_spiking = nparray_and_transpose(data_valid_spiking)

# save down the inputs used to generate this data
train_inputs_u, valid_inputs_u = split_list_by_inds(u_e,
                                                    train_inds,
                                                    valid_inds)
train_inputs_u = nparray_and_transpose(train_inputs_u)
valid_inputs_u = nparray_and_transpose(valid_inputs_u)

# save down the network outputs (may be useful later)
train_outputs_u, valid_outputs_u = split_list_by_inds(outs_e,
                                                      train_inds,
                                                      valid_inds)
train_outputs_u = np.array(train_outputs_u)
valid_outputs_u = np.array(valid_outputs_u)
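
# Bundle everything into a single dataset dict: spike trains ('train_data',
# 'valid_data'), the true underlying rates ('train_truth', 'valid_truth'),
# the generating inputs and outputs, and metadata. 'conversion_factor' equals
# max_firing_rate * dt, the spikes-per-bin scale at a normalized rate of 1.0.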
data = {'train_truth': data_train_truth,
        'valid_truth': data_valid_truth,
        'train_data': data_train_spiking,
        'valid_data': data_valid_spiking,
        'train_percentage': train_percentage,
        'nreplications': nreplications,
        'dt': FLAGS.dt,
        'u_std': FLAGS.u_std,
        'max_firing_rate': FLAGS.max_firing_rate,
        'train_inputs_u': train_inputs_u,
        'valid_inputs_u': valid_inputs_u,
        'train_outputs_u': train_outputs_u,
        'valid_outputs_u': valid_outputs_u,
        'conversion_factor': FLAGS.max_firing_rate / (1.0 / FLAGS.dt)}

# just one dataset here
datasets = {}
dataset_name = 'dataset_N' + str(N)
datasets[dataset_name] = data

# write out the dataset
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
print('Saved to ', os.path.join(FLAGS.save_dir,
                                FLAGS.datafile_name + '_' + dataset_name))