from __future__ import print_function
import h5py
import numpy as np
import os
from utils import write_datasets
import matplotlib
import matplotlib.pyplot as plt
import scipy.signal
def generate_rnn(rng, N, g, tau, dt, max_firing_rate):
    """Create a (vanilla) RNN with a bunch of hyper parameters for generating
    chaotic data.

    Args:
      rng: numpy random number generator (its ``randn`` method is used).
      N: number of hidden units.
      g: scaling of recurrent weight matrix in g W, with W ~ N(0, 1/N).
      tau: time scale of individual unit dynamics.
      dt: time step for equation updates.
      max_firing_rate: how to rescale the -1,1 firing rates (spikes / sec).

    Returns:
      The dictionary of these parameters, plus some others:
        'W': N x N recurrent weights drawn as randn / sqrt(N),
        'Bin', 'Bin2': length-N input weight vectors drawn as randn,
        'b': length-N zero bias,
        'conversion_factor': max_firing_rate / (1 / dt), i.e. spikes / bin,
          used for plotting in LFADS.
    """
    # The draw order (W, then Bin, then Bin2) must stay fixed so that a seeded
    # rng keeps reproducing the same network.
    W = rng.randn(N, N) / np.sqrt(N)
    Bin = rng.randn(N)
    Bin2 = rng.randn(N)

    rnn = {}
    rnn['N'] = N
    rnn['W'] = W
    rnn['Bin'] = Bin
    rnn['Bin2'] = Bin2
    rnn['b'] = np.zeros(N)
    rnn['g'] = g
    rnn['tau'] = tau
    rnn['dt'] = dt
    rnn['max_firing_rate'] = max_firing_rate

    mfr = rnn['max_firing_rate']  # spikes / sec
    nbins_per_sec = 1.0 / rnn['dt']  # bins / sec
    # Used for plotting in LFADS.
    rnn['conversion_factor'] = mfr / nbins_per_sec  # spikes / bin
    return rnn
def generate_data(rnn, T, E, x0s=None, P_sxn=None, input_magnitude=0.0,
                  input_times=None):
    """Generate data from a randomly initialized RNN.

    Args:
      rnn: the rnn dictionary; the keys 'N', 'W', 'g', 'Bin', 'b', 'tau' and
        'dt' are read.
      T: Time in seconds to run (divided by rnn['dt'] to get steps, rounded
        down).
      E: total number of examples.
      x0s: initial conditions; column e (``x0s[:, e]``) is the length-N
        initial state of example e. Required despite the ``None`` default.
      P_sxn: S x N matrix mapping the N hidden units to S observed channels
        (subsampling N). Defaults to the N x N identity.
      input_magnitude: value of the single input pulse delivered at an
        example's input time.
      input_times: optional sequence indexed by example; entry e is the time
        step at which example e receives the pulse. ``None`` means no input.

    Returns:
      A 3-tuple (data_e, x0s, inputs_e):
        data_e: list of length E of S x ntime_steps arrays, normalized per
          channel by normalize_rates,
        x0s: the initial conditions that were passed in,
        inputs_e: list of length E of 1 x ntime_steps input arrays.

    Raises:
      ValueError: if x0s is None.
    """
    N = rnn['N']

    def run_rnn(rnn, x0, ntime_steps, input_time=None):
        """Euler-integrate the RNN from x0; return (rates N x T, inputs 1 x T)."""
        tau = rnn['tau']
        dt = rnn['dt']
        alpha = 1.0 - dt / tau
        W = dt / tau * rnn['W'] * rnn['g']
        Bin = dt / tau * rnn['Bin']
        b = dt / tau * rnn['b']

        rs = np.zeros([N, ntime_steps])
        us = np.zeros([1, ntime_steps])
        x_tm1 = x0
        r_tm1 = np.tanh(x0)
        for t in range(ntime_steps):
            x_t = alpha * x_tm1 + np.dot(W, r_tm1) + b
            if input_time is not None and t == input_time:
                us[0, t] = input_magnitude
                # Single pulse, fed in through the Bin weights.
                x_t += Bin * us[0, t]
            r_t = np.tanh(x_t)
            x_tm1 = x_t
            r_tm1 = r_t
            rs[:, t] = r_t
        return rs, us

    if x0s is None:
        raise ValueError('generate_data requires x0s: initial conditions '
                         'with one column per example.')
    if P_sxn is None:
        P_sxn = np.eye(N)
    ntime_steps = int(T / rnn['dt'])

    data_e = []
    inputs_e = []
    for e in range(E):
        input_time = input_times[e] if input_times is not None else None
        r_nxt, u_uxt = run_rnn(rnn, x0s[:, e], ntime_steps, input_time)
        # Project hidden rates onto the observed channels.
        data_e.append(np.dot(P_sxn, r_nxt))
        inputs_e.append(u_uxt)

    S = P_sxn.shape[0]
    data_e = normalize_rates(data_e, E, S)
    return data_e, x0s, inputs_e
def normalize_rates(data_e, E, S):
    """Rescale the first S channels of each of the first E examples to [0, 1].

    Normalization is by the min and max of each channel within each example.
    This is made more complex by the P matrices. Every example is scaled by
    its own per-channel extremes, so otherwise identical RNN runs can end up
    with offset differences after normalization.

    Args:
      data_e: list of 2-D arrays indexed [channel, time]; the arrays are
        modified in place.
      E: number of examples (leading entries of data_e) to normalize.
      S: number of channels (leading rows) per example to normalize.

    Returns:
      data_e, the same list, whose arrays have been rescaled in place.

    Raises:
      ValueError: if a channel is constant (max == min) and so cannot be
        rescaled without dividing by zero. Unlike an ``assert``, this check
        survives ``python -O``.
    """
    for e in range(E):
        r_sxt = data_e[e]
        rows = r_sxt[:S, :]
        rmin = rows.min(axis=1, keepdims=True)
        rrange = rows.max(axis=1, keepdims=True) - rmin
        flat_idxs = np.flatnonzero(rrange == 0)
        if flat_idxs.size:
            raise ValueError('Channel %d of example %d is constant; cannot '
                             'normalize it.' % (flat_idxs[0], e))
        # Validate before writing, so a failing example is left untouched.
        r_sxt[:S, :] = (rows - rmin) / rrange
    return data_e
def spikify_data(data_e, rng, dt=1.0, max_firing_rate=100):
    """Apply spikes to a continuous dataset whose values are between 0.0 and 1.0.

    Args:
      data_e: list of N x T arrays, one per example.
      rng: numpy random number generator (its ``poisson`` method is used).
      dt: how often the data are sampled.
      max_firing_rate: the firing rate that is associated with a value of 1.0.

    Returns:
      spikes_e: a list of the same length as data_e holding integer N x T
        arrays of counts. Each entry is a Poisson draw with mean
        ``value * max_firing_rate * dt``. Values are assumed to be
        non-negative, since Poisson means must be.
    """
    spikes_e = []
    for data in data_e:
        N, T = data.shape
        data_s = np.zeros([N, T], dtype=int)
        # One poisson call per channel, so a seeded rng gives reproducible spikes.
        for n in range(N):
            rate = data[n, :] * max_firing_rate * dt
            data_s[n, :] = rng.poisson(rate, size=T)
        spikes_e.append(data_s)
    return spikes_e
def gaussify_data(data_e, rng, dt=1.0, max_firing_rate=100):
    """Apply gaussian noise to a continuous dataset whose values are between
    0.0 and 1.0.

    Args:
      data_e: list of N x T arrays, one per example.
      rng: numpy random number generator (its ``randn`` method is used).
        Using it instead of the global ``np.random`` makes the noise
        reproducible from the caller's seed.
      dt: how often the data are sampled.
      max_firing_rate: the firing rate that is associated with a value of 1.0.

    Returns:
      gauss_e: a list of the same length as data_e. Each entry is
        ``data * max_firing_rate`` plus zero-mean Gaussian noise with standard
        deviation ``5.0 * max_firing_rate * sqrt(dt)``.
    """
    mfr = max_firing_rate
    gauss_e = []
    for data in data_e:
        N, T = data.shape
        noisy_data = data * mfr + rng.randn(N, T) * (5.0 * mfr) * np.sqrt(dt)
        gauss_e.append(noisy_data)
    return gauss_e
def get_train_n_valid_inds(num_trials, train_fraction, nreplications):
    """Split the numbers between 0 and num_trials-1 into two portions for
    training and validation, based on the train fraction.

    Trials are assumed to be laid out in consecutive groups of nreplications
    (spiking trials of one initial condition). Within each group, the
    positions beyond ``train_fraction * nreplications`` go to validation. The
    randomness of spikifying a condition is therefore shared among both the
    training and validation splits.

    NOTE: with nreplications == 1 the test ``1 > train_fraction`` holds for
    any train_fraction < 1, so every index lands in the validation split.

    Args:
      num_trials: the number of trials.
      train_fraction: (e.g. .80)
      nreplications: the number of spiking trials per initial condition.

    Returns:
      A 2-tuple of two lists: the training indices and validation indices.
    """
    train_inds = []
    valid_inds = []
    threshold = train_fraction * nreplications
    for i in range(num_trials):
        position_in_group = (i % nreplications) + 1
        if position_in_group > threshold:
            valid_inds.append(i)
        else:
            train_inds.append(i)
    return train_inds, valid_inds
def split_list_by_inds(data, inds1, inds2):
    """Take the data, a list, and split it up based on the indices in inds1
    and inds2.

    Args:
      data: the list of data to split.
      inds1: the first list of indices.
      inds2: the second list of indices.

    Returns:
      A 2-tuple of two lists: the items of data at inds1, and the items of
      data at inds2, each in index order. ``([], [])`` when data is None or
      empty.
    """
    # Explicit None/len check rather than truthiness, so numpy arrays work too.
    if data is None or len(data) == 0:
        return [], []
    first = [data[i] for i in inds1]
    second = [data[i] for i in inds2]
    return first, second
def nparray_and_transpose(data_a_b_c):
    """Convert the list of items in data to a numpy array, and transpose it.

    Args:
      data_a_b_c: a nested, nested list (or other iterable) of length a, with
        sublist length b, with sublist length c.

    Returns:
      A numpy 3-tensor with dimensions a x c x b.
    """
    # list() materializes a generator; np.array on a generator would give a
    # 0-d object array instead of a 3-tensor.
    data_axbxc = np.array(list(data_a_b_c))
    data_axcxb = np.transpose(data_axbxc, axes=[0, 2, 1])
    return data_axcxb
def add_alignment_projections(datasets, npcs, ntime=None, nsamples=None,
                              do_debug_plot=False):
    """Create a matrix that aligns the datasets a bit, under the assumption
    that each dataset is observing the same underlying dynamical system.

    Each dataset's training data is condition-averaged, stacked across
    datasets and conditions, and lightly smoothed. It is then projected onto
    its top npcs principal components. Each dataset's channels are regressed
    onto those pcs, and the regression weights become its alignment matrix.

    Args:
      datasets: The dictionary of dataset structures. Each value is read for
        'P_sxn', 'condition_labels_train' and 'train_data' (trials x time x
        channels, judging by how it is indexed below).
        NOTE(review): assumes each P_sxn selects a contiguous block of columns
        and that those blocks tile 0..total channels -- confirm.
      npcs: The number of pcs for each, basically like lfads factors.
      ntime (optional): Number of time steps to take in each sample (the
        first ntime steps of each condition average). Defaults to
        train_data.shape[1] of the last dataset in iteration order.
      nsamples (optional): Number of samples to take for each dataset. Kept
        for backward compatibility; not used by this implementation.
      do_debug_plot (optional): If True, draw matplotlib diagnostic figures
        for each dataset. The caller is responsible for showing them.

    Returns:
      The dataset structures, with the field 'alignment_matrix_cxf'
      (# channels x npcs) and 'alignment_bias_c' (# channels) added to each.

    Raises:
      ValueError: if datasets is empty.
    """
    if not datasets:
        raise ValueError('add_alignment_projections needs at least one dataset.')

    # Per dataset: [start, stop) channel range and the conditions present.
    nchannels_all = 0
    channel_idxs = {}
    conditions_all = {}
    for name, dataset in datasets.items():
        cidxs = np.where(dataset['P_sxn'])[1]  # non-zero entries in columns
        channel_idxs[name] = [cidxs[0], cidxs[-1] + 1]
        nchannels_all += cidxs[-1] + 1 - cidxs[0]
        conditions_all[name] = np.unique(dataset['condition_labels_train'])

    # Sorted union of every dataset's conditions.
    all_conditions_list = np.unique(
        np.concatenate(list(conditions_all.values())))
    nconditions_all = all_conditions_list.shape[0]

    if ntime is None:
        # `dataset` is the last one visited by the loop above.
        ntime = dataset['train_data'].shape[1]

    # In the data workup in the paper, Chethan did intra condition
    # averaging, so let's do that here.
    avg_data_all = {}
    for name, conditions in conditions_all.items():
        dataset = datasets[name]
        labels = np.asarray(dataset['condition_labels_train'])
        avg_data_all[name] = {}
        for cname in conditions:
            trials = dataset['train_data'][labels == cname]
            avg_data_all[name][cname] = np.mean(trials, axis=0)

    # Stack into channels x (time * conditions): rows by dataset channel
    # range, column blocks of ntime per condition.
    all_data_nxtc = np.zeros([nchannels_all, ntime * nconditions_all])
    for name in datasets:
        cidx_s, cidx_f = channel_idxs[name]
        for cname in conditions_all[name]:
            cidxs = np.argwhere(all_conditions_list == cname)
            if cidxs.shape[0] > 0:
                t_start = cidxs[0][0] * ntime
                all_data_nxtc[cidx_s:cidx_f, t_start:t_start + ntime] = \
                    avg_data_all[name][cname][:ntime].T

    # A bit of filtering. We don't care about spectral properties, or
    # filtering artifacts, simply correlate time steps a bit.
    filt_len = 6
    bc_filt = np.ones([filt_len]) / float(filt_len)
    all_data_nxtc = scipy.signal.filtfilt(bc_filt, [1.0], all_data_nxtc, axis=1)

    # Compute the PCs.
    all_data_mean_nx1 = np.mean(all_data_nxtc, axis=1, keepdims=True)
    all_data_zm_nxtc = all_data_nxtc - all_data_mean_nx1
    corr_mat_nxn = np.dot(all_data_zm_nxtc, all_data_zm_nxtc.T)
    evals_n, evecs_nxn = np.linalg.eigh(corr_mat_nxn)
    sidxs = np.argsort(evals_n)[::-1]  # sort such that 0th is highest
    evals_n = evals_n[sidxs]
    evecs_nxn = evecs_nxn[:, sidxs]

    # Project all the channels data onto the low-D PCA basis, where
    # low-d is the npcs parameter.
    all_data_pca_pxtc = np.dot(evecs_nxn[:, 0:npcs].T, all_data_zm_nxtc)

    # Now for each dataset, we regress the channel data onto the top
    # pcs, and this will be our alignment matrix for that dataset.
    # |B - A*W|^2
    for name, dataset in datasets.items():
        cidx_s, cidx_f = channel_idxs[name]
        all_data_zm_chxtc = all_data_zm_nxtc[cidx_s:cidx_f, :]  # ch for channel
        W_chxp, _, _, _ = np.linalg.lstsq(all_data_zm_chxtc.T,
                                          all_data_pca_pxtc.T)
        dataset['alignment_matrix_cxf'] = W_chxp
        alignment_bias_cx1 = all_data_mean_nx1[cidx_s:cidx_f]
        dataset['alignment_bias_c'] = np.squeeze(alignment_bias_cx1, axis=1)

        if do_debug_plot:
            # Eigen-spectrum, channel data pushed through W, and the residual
            # of the regression against the target pcs.
            plt.figure()
            plt.plot(np.log10(evals_n), '-x')
            plt.figure()
            plt.imshow(np.dot(W_chxp.T, all_data_zm_chxtc))
            plt.figure()
            plt.imshow(np.dot(all_data_zm_chxtc.T, W_chxp).T - all_data_pca_pxtc)

    return datasets