# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""
LFADS - Latent Factor Analysis via Dynamical Systems.
LFADS is an unsupervised method to decompose time series data into
various factors, such as an initial condition, a generative
dynamical system, control inputs to that generator, and a low
dimensional description of the observed data, called the factors.
Additionally, the observations have a noise model (in this case
Poisson), so a denoised version of the observations is also created
(e.g. underlying rates of a Poisson distribution given the observed
event counts).
The main data structure being passed around is a dataset. This is a dictionary
of data dictionaries.
DATASET: The top level dictionary is simply name (string -> dictionary).
The nested dictionary is the DATA DICTIONARY, which has the following keys:
'train_data' and 'valid_data', whose values are the corresponding training
and validation data with shape
ExTxD, E - # examples, T - # time steps, D - # dimensions in data.
The data dictionary also has a few more keys:
'train_ext_input' and 'valid_ext_input', if there are known external inputs
to the system being modeled, these take on dimensions:
ExTxI, E - # examples, T - # time steps, I - # dimensions in input.
'alignment_matrix_cxf' - If you are using data from multiple days, it's possible
that one can align the channels (see manuscript). If so each dataset will
contain this matrix, which will be used for both the input adapter and the
output adapter for each dataset. These matrices, if provided, must be of
size [data_dim x factors] where data_dim is the number of neurons recorded
on that day, and factors is chosen and set through the '--factors' flag.
'alignment_bias_c' - See alignment_matrix_cxf. This bias will be used as
the offset for the alignment transformation. It will *subtract* off the
bias from the data, so PCA-style inits can align factors across sessions.
If one runs LFADS on data where the true rates are known for some trials
(say simulated test data, as in the example shipped with the paper), then
one can add three more fields for plotting purposes. These are 'train_truth'
and 'valid_truth', and 'conversion_factor'. These have the same dimensions as
'train_data', and 'valid_data' but represent the underlying rates of the
observations. Finally, if one needs to convert scale for plotting the true
underlying firing rates, there is the 'conversion_factor' key.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import tensorflow as tf
from distributions import LearnableDiagonalGaussian, DiagonalGaussianFromInput
from distributions import diag_gaussian_log_likelihood
from distributions import KLCost_GaussianGaussian, Poisson
from distributions import LearnableAutoRegressive1Prior
from distributions import KLCost_GaussianGaussianProcessSampled
from utils import init_linear, linear, list_t_bxn_to_tensor_bxtxn, write_data
from utils import log_sum_exp, flatten
from plot_lfads import plot_lfads
class GRU(object):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
"""
def __init__(self, num_units, forget_bias=1.0, weight_scale=1.0,
clip_value=np.inf, collections=None):
"""Create a GRU object.
Args:
num_units: Number of units in the GRU
forget_bias (optional): Hack to help learning.
weight_scale (optional): weights are scaled by ws/sqrt(#inputs), with
ws being the weight scale.
clip_value (optional): if the recurrent values grow above this value,
clip them.
collections (optional): List of additional collections the variables should
belong to.
"""
self._num_units = num_units
self._forget_bias = forget_bias
self._weight_scale = weight_scale
self._clip_value = clip_value
self._collections = collections
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
@property
def state_multiplier(self):
return 1
def output_from_state(self, state):
"""Return the output portion of the state."""
return state
def __call__(self, inputs, state, scope=None):
"""Gated recurrent unit (GRU) function.
Args:
inputs: A 2D batch x input_dim tensor of inputs.
state: The previous state from the last time step.
scope (optional): TF variable scope for defined GRU variables.
Returns:
A tuple (state, state), where state is the newly computed state at time t.
It is returned twice to respect an interface that works for LSTMs.
"""
x = inputs
h = state
if inputs is not None:
xh = tf.concat(axis=1, values=[x, h])
else:
xh = h
with tf.variable_scope(scope or type(self).__name__): # "GRU"
with tf.variable_scope("Gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not update.
r, u = tf.split(axis=1, num_or_size_splits=2, value=linear(xh,
2 * self._num_units,
alpha=self._weight_scale,
name="xh_2_ru",
collections=self._collections))
r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias)
with tf.variable_scope("Candidate"):
xrh = tf.concat(axis=1, values=[x, r * h])
c = tf.tanh(linear(xrh, self._num_units, name="xrh_2_c",
collections=self._collections))
new_h = u * h + (1 - u) * c
new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value)
return new_h, new_h
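# A minimal numpy sketch (in comments only) of one GRU step as implemented
# above. W_ru, b_ru, W_c, b_c, forget_bias and clip_value are hypothetical
# stand-ins for the variables created by the `linear` helper.
#
#   import numpy as np
#   def sigmoid(a): return 1.0 / (1.0 + np.exp(-a))
#   xh = np.concatenate([x, h], axis=1)                  # batch x (inputs + units)
#   r_pre, u_pre = np.split(xh.dot(W_ru) + b_ru, 2, axis=1)
#   r, u = sigmoid(r_pre), sigmoid(u_pre + forget_bias)  # reset and update gates
#   c = np.tanh(np.concatenate([x, r * h], axis=1).dot(W_c) + b_c)  # candidate
#   new_h = np.clip(u * h + (1 - u) * c, -clip_value, clip_value)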
class GenGRU(object):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
This version is specialized for the generator, but isn't as fast, so
we keep both. Note that this version allows for L2 regularization on the
recurrent weights alone, and also implicitly rescales the input weights
(via the 1/sqrt(#inputs) scaling in the linear helper routine) to larger
magnitude when there are fewer inputs than recurrent state units.
"""
def __init__(self, num_units, forget_bias=1.0,
input_weight_scale=1.0, rec_weight_scale=1.0, clip_value=np.inf,
input_collections=None, recurrent_collections=None):
"""Create a GRU object.
Args:
num_units: Number of units in the GRU
forget_bias (optional): Hack to help learning.
input_weight_scale (optional): weights are scaled ws/sqrt(#inputs), with
ws being the weight scale.
rec_weight_scale (optional): weights are scaled ws/sqrt(#inputs),
with ws being the weight scale.
clip_value (optional): if the recurrent values grow above this value,
clip them.
input_collections (optional): List of additional collections the
input->rec weights should belong to.
recurrent_collections (optional): List of additional collections the
rec->rec weights should belong to.
"""
self._num_units = num_units
self._forget_bias = forget_bias
self._input_weight_scale = input_weight_scale
self._rec_weight_scale = rec_weight_scale
self._clip_value = clip_value
self._input_collections = input_collections
self._rec_collections = recurrent_collections
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
@property
def state_multiplier(self):
return 1
def output_from_state(self, state):
"""Return the output portion of the state."""
return state
def __call__(self, inputs, state, scope=None):
"""Gated recurrent unit (GRU) function.
Args:
inputs: A 2D batch x input_dim tensor of inputs.
state: The previous state from the last time step.
scope (optional): TF variable scope for defined GRU variables.
Returns:
A tuple (state, state), where state is the newly computed state at time t.
It is returned twice to respect an interface that works for LSTMs.
"""
x = inputs
h = state
with tf.variable_scope(scope or type(self).__name__): # "GRU"
with tf.variable_scope("Gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not update.
r_x = u_x = 0.0
if x is not None:
r_x, u_x = tf.split(axis=1, num_or_size_splits=2, value=linear(x,
2 * self._num_units,
alpha=self._input_weight_scale,
do_bias=False,
name="x_2_ru",
normalized=False,
collections=self._input_collections))
r_h, u_h = tf.split(axis=1, num_or_size_splits=2, value=linear(h,
2 * self._num_units,
do_bias=True,
alpha=self._rec_weight_scale,
name="h_2_ru",
collections=self._rec_collections))
r = r_x + r_h
u = u_x + u_h
r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias)
with tf.variable_scope("Candidate"):
c_x = 0.0
if x is not None:
c_x = linear(x, self._num_units, name="x_2_c", do_bias=False,
alpha=self._input_weight_scale,
normalized=False,
collections=self._input_collections)
c_rh = linear(r*h, self._num_units, name="rh_2_c", do_bias=True,
alpha=self._rec_weight_scale,
collections=self._rec_collections)
c = tf.tanh(c_x + c_rh)
new_h = u * h + (1 - u) * c
new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value)
return new_h, new_h
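# Usage sketch (in comments only): the generator cell is built later in
# LFADS.__init__ roughly like this, so that its recurrent weights land in the
# 'l2_gen_reg' collection and get picked up by the L2 cost. Sizes are
# hypothetical.
#
#   gen_cell = GenGRU(num_units=200,
#                     input_weight_scale=1.0,
#                     rec_weight_scale=1.0,
#                     clip_value=5.0,
#                     recurrent_collections=['l2_gen_reg'])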
class LFADS(object):
"""LFADS - Latent Factor Analysis via Dynamical Systems.
LFADS is an unsupervised method to decompose time series data into
various factors, such as an initial condition, a generative
dynamical system, inferred inputs to that generator, and a low
dimensional description of the observed data, called the factors.
Additionally, the observations have a noise model (in this case
Poisson), so a denoised version of the observations is also created
(e.g. underlying rates of a Poisson distribution given the observed
event counts).
"""
def __init__(self, hps, kind="train", datasets=None):
"""Create an LFADS model.
train - a model for training, sampling of posteriors is used
posterior_sample_and_average - sample from the posterior, this is used
for evaluating the expected value of the outputs of LFADS, given a
specific input, by averaging over multiple samples from the approx
posterior. Also used for the lower bound on the negative
log-likelihood using the IWAE error (Importance Weighted Auto-encoder).
This is the denoising operation.
prior_sample - a model for generation - sampling from priors is used
Args:
hps: The dictionary of hyperparameters.
kind: the type of model to build (see above).
datasets: a dictionary of named data_dictionaries, see top of lfads.py
"""
print("Building graph...")
all_kinds = ['train', 'posterior_sample_and_average', 'posterior_push_mean',
'prior_sample']
assert kind in all_kinds, 'Wrong kind'
if hps.feedback_factors_or_rates == "rates":
assert len(hps.dataset_names) == 1, \
"Multiple datasets not supported for rate feedback."
num_steps = hps.num_steps
ic_dim = hps.ic_dim
co_dim = hps.co_dim
ext_input_dim = hps.ext_input_dim
cell_class = GRU
gen_cell_class = GenGRU
def makelambda(v): # Used with tf.case
return lambda: v
# Define the data placeholder, and deal with all parts of the graph
# that are dataset dependent.
self.dataName = tf.placeholder(tf.string, shape=())
# The batch_size will be inferred from the data, as usual.
# Additionally, the data_dim will be inferred as well, allowing for a
# single placeholder for all datasets, regardless of data dimension.
if hps.output_dist == 'poisson':
# Enforce correct dtype
assert np.issubdtype(
datasets[hps.dataset_names[0]]['train_data'].dtype, int), \
"Data dtype must be int for poisson output distribution"
data_dtype = tf.int32
elif hps.output_dist == 'gaussian':
assert np.issubdtype(
datasets[hps.dataset_names[0]]['train_data'].dtype, float), \
"Data dtype must be float for gaussian output dsitribution"
data_dtype = tf.float32
else:
assert False, "NIY"
self.dataset_ph = dataset_ph = tf.placeholder(data_dtype,
[None, num_steps, None],
name="data")
self.train_step = tf.get_variable("global_step", [], tf.int64,
tf.zeros_initializer(),
trainable=False)
self.hps = hps
ndatasets = hps.ndatasets
factors_dim = hps.factors_dim
self.preds = preds = [None] * ndatasets
self.fns_in_fac_Ws = fns_in_fac_Ws = [None] * ndatasets
self.fns_in_fac_bs = fns_in_fac_bs = [None] * ndatasets
self.fns_out_fac_Ws = fns_out_fac_Ws = [None] * ndatasets
self.fns_out_fac_bs = fns_out_fac_bs = [None] * ndatasets
self.datasetNames = dataset_names = hps.dataset_names
self.ext_inputs = ext_inputs = None
if len(dataset_names) == 1: # single session
if 'alignment_matrix_cxf' in datasets[dataset_names[0]].keys():
used_in_factors_dim = factors_dim
in_identity_if_poss = False
else:
used_in_factors_dim = hps.dataset_dims[dataset_names[0]]
in_identity_if_poss = True
else: # multisession
used_in_factors_dim = factors_dim
in_identity_if_poss = False
for d, name in enumerate(dataset_names):
data_dim = hps.dataset_dims[name]
in_mat_cxf = None
in_bias_1xf = None
align_bias_1xc = None
if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name]
if hps.do_train_readin:
print("Initializing trainable readin matrix with alignment matrix" \
" provided for dataset:", name)
else:
print("Setting non-trainable readin matrix to alignment matrix" \
" provided for dataset:", name)
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
if in_mat_cxf.shape != (data_dim, factors_dim):
raise ValueError("""Alignment matrix must have dimensions %d x %d
(data_dim x factors_dim), but currently has %d x %d."""%
(data_dim, factors_dim, in_mat_cxf.shape[0],
in_mat_cxf.shape[1]))
if datasets and 'alignment_bias_c' in datasets[name].keys():
dataset = datasets[name]
if hps.do_train_readin:
print("Initializing trainable readin bias with alignment bias " \
"provided for dataset:", name)
else:
print("Setting non-trainable readin bias to alignment bias " \
"provided for dataset:", name)
align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
if align_bias_1xc.shape[1] != data_dim:
raise ValueError("""Alignment bias must have dimensions %d
(data_dim), but currently has %d."""%
(data_dim, in_mat_cxf.shape[0]))
if in_mat_cxf is not None and align_bias_1xc is not None:
# (data - alignment_bias) * W_in
# data * W_in - alignment_bias * W_in
# So b = -alignment_bias * W_in to accommodate PCA style offset.
in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf)
if hps.do_train_readin:
# Only add to the IO_transformations collection if we want the readin to be
# learnable, because the IO_transformations collection is what gets trained
# when do_train_io_only is set.
collections_readin=['IO_transformations']
else:
collections_readin=None
in_fac_lin = init_linear(data_dim, used_in_factors_dim,
do_bias=True,
mat_init_value=in_mat_cxf,
bias_init_value=in_bias_1xf,
identity_if_possible=in_identity_if_poss,
normalized=False, name="x_2_infac_"+name,
collections=collections_readin,
trainable=hps.do_train_readin)
in_fac_W, in_fac_b = in_fac_lin
fns_in_fac_Ws[d] = makelambda(in_fac_W)
fns_in_fac_bs[d] = makelambda(in_fac_b)
with tf.variable_scope("glm"):
out_identity_if_poss = False
if len(dataset_names) == 1 and \
factors_dim == hps.dataset_dims[dataset_names[0]]:
out_identity_if_poss = True
for d, name in enumerate(dataset_names):
data_dim = hps.dataset_dims[name]
in_mat_cxf = None
if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name]
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
if datasets and 'alignment_bias_c' in datasets[name].keys():
dataset = datasets[name]
align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
out_mat_fxc = None
out_bias_1xc = None
if in_mat_cxf is not None:
out_mat_fxc = in_mat_cxf.T
if align_bias_1xc is not None:
out_bias_1xc = align_bias_1xc
if hps.output_dist == 'poisson':
out_fac_lin = init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=out_mat_fxc,
bias_init_value=out_bias_1xc,
identity_if_possible=out_identity_if_poss,
normalized=False,
name="fac_2_logrates_"+name,
collections=['IO_transformations'])
out_fac_W, out_fac_b = out_fac_lin
elif hps.output_dist == 'gaussian':
out_fac_lin_mean = \
init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=out_mat_fxc,
bias_init_value=out_bias_1xc,
normalized=False,
name="fac_2_means_"+name,
collections=['IO_transformations'])
out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean
mat_init_value = np.zeros([factors_dim, data_dim]).astype(np.float32)
bias_init_value = np.ones([1, data_dim]).astype(np.float32)
out_fac_lin_logvar = \
init_linear(factors_dim, data_dim, do_bias=True,
mat_init_value=mat_init_value,
bias_init_value=bias_init_value,
normalized=False,
name="fac_2_logvars_"+name,
collections=['IO_transformations'])
out_fac_W_logvar, out_fac_b_logvar = out_fac_lin_logvar
out_fac_W = tf.concat(
axis=1, values=[out_fac_W_mean, out_fac_W_logvar])
out_fac_b = tf.concat(
axis=1, values=[out_fac_b_mean, out_fac_b_logvar])
else:
assert False, "NIY"
preds[d] = tf.equal(tf.constant(name), self.dataName)
data_dim = hps.dataset_dims[name]
fns_out_fac_Ws[d] = makelambda(out_fac_W)
fns_out_fac_bs[d] = makelambda(out_fac_b)
pf_pairs_in_fac_Ws = zip(preds, fns_in_fac_Ws)
pf_pairs_in_fac_bs = zip(preds, fns_in_fac_bs)
pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws)
pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs)
this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, exclusive=True)
this_in_fac_b = tf.case(pf_pairs_in_fac_bs, exclusive=True)
this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, exclusive=True)
this_out_fac_b = tf.case(pf_pairs_out_fac_bs, exclusive=True)
# External inputs (not changing by dataset, by definition).
if hps.ext_input_dim > 0:
self.ext_input = tf.placeholder(tf.float32,
[None, num_steps, ext_input_dim],
name="ext_input")
else:
self.ext_input = None
ext_input_bxtxi = self.ext_input
self.keep_prob = keep_prob = tf.placeholder(tf.float32, [], "keep_prob")
self.batch_size = batch_size = int(hps.batch_size)
self.learning_rate = tf.Variable(float(hps.learning_rate_init),
trainable=False, name="learning_rate")
self.learning_rate_decay_op = self.learning_rate.assign(
self.learning_rate * hps.learning_rate_decay_factor)
# Dropout the data.
dataset_do_bxtxd = tf.nn.dropout(tf.to_float(dataset_ph), keep_prob)
if hps.ext_input_dim > 0:
ext_input_do_bxtxi = tf.nn.dropout(ext_input_bxtxi, keep_prob)
else:
ext_input_do_bxtxi = None
# ENCODERS
def encode_data(dataset_bxtxd, enc_cell, name, forward_or_reverse,
num_steps_to_encode):
"""Encode data for LFADS
Args:
dataset_bxtxd: the data to encode, as a 3D tensor, with dims
batch x time x data dims.
enc_cell: encoder cell
name: name of encoder
forward_or_reverse: string, encode in forward or reverse direction
num_steps_to_encode: number of steps to encode, 0:num_steps_to_encode
Returns:
encoded data as a list with num_steps_to_encode items, in order
"""
if forward_or_reverse == "forward":
dstr = "_fwd"
time_fwd_or_rev = range(num_steps_to_encode)
else:
dstr = "_rev"
time_fwd_or_rev = reversed(range(num_steps_to_encode))
with tf.variable_scope(name+"_enc"+dstr, reuse=False):
enc_state = tf.tile(
tf.Variable(tf.zeros([1, enc_cell.state_size]),
name=name+"_enc_t0"+dstr), tf.stack([batch_size, 1]))
enc_state.set_shape([None, enc_cell.state_size]) # tile loses shape
enc_outs = [None] * num_steps_to_encode
for i, t in enumerate(time_fwd_or_rev):
with tf.variable_scope(name+"_enc"+dstr, reuse=True if i > 0 else None):
dataset_t_bxd = dataset_bxtxd[:,t,:]
in_fac_t_bxf = tf.matmul(dataset_t_bxd, this_in_fac_W) + this_in_fac_b
in_fac_t_bxf.set_shape([None, used_in_factors_dim])
if ext_input_dim > 0 and not hps.inject_ext_input_to_gen:
ext_input_t_bxi = ext_input_do_bxtxi[:,t,:]
enc_input_t_bxfpe = tf.concat(
axis=1, values=[in_fac_t_bxf, ext_input_t_bxi])
else:
enc_input_t_bxfpe = in_fac_t_bxf
enc_out, enc_state = enc_cell(enc_input_t_bxfpe, enc_state)
enc_outs[t] = enc_out
return enc_outs
# Encode initial condition means and variances
# ([x_T, x_T-1, ... x_0] and [x_0, x_1, ... x_T] -> g0/c0)
self.ic_enc_fwd = [None] * num_steps
self.ic_enc_rev = [None] * num_steps
if ic_dim > 0:
enc_ic_cell = cell_class(hps.ic_enc_dim,
weight_scale=hps.cell_weight_scale,
clip_value=hps.cell_clip_value)
ic_enc_fwd = encode_data(dataset_do_bxtxd, enc_ic_cell,
"ic", "forward",
hps.num_steps_for_gen_ic)
ic_enc_rev = encode_data(dataset_do_bxtxd, enc_ic_cell,
"ic", "reverse",
hps.num_steps_for_gen_ic)
self.ic_enc_fwd = ic_enc_fwd
self.ic_enc_rev = ic_enc_rev
# Encode control input means and variances, bi-directional encoding so:
# ([x_T, x_T-1, ..., x_0] and [x_0, x_1 ... x_T] -> u_t)
self.ci_enc_fwd = [None] * num_steps
self.ci_enc_rev = [None] * num_steps
if co_dim > 0:
enc_ci_cell = cell_class(hps.ci_enc_dim,
weight_scale=hps.cell_weight_scale,
clip_value=hps.cell_clip_value)
ci_enc_fwd = encode_data(dataset_do_bxtxd, enc_ci_cell,
"ci", "forward",
hps.num_steps)
if hps.do_causal_controller:
ci_enc_rev = None
else:
ci_enc_rev = encode_data(dataset_do_bxtxd, enc_ci_cell,
"ci", "reverse",
hps.num_steps)
self.ci_enc_fwd = ci_enc_fwd
self.ci_enc_rev = ci_enc_rev
# STOCHASTIC LATENT VARIABLES, priors and posteriors
# (initial conditions g0, and control inputs, u_t)
# Note that zs represent all the stochastic latent variables.
with tf.variable_scope("z", reuse=False):
self.prior_zs_g0 = None
self.posterior_zs_g0 = None
self.g0s_val = None
if ic_dim > 0:
self.prior_zs_g0 = \
LearnableDiagonalGaussian(batch_size, ic_dim, name="prior_g0",
mean_init=0.0,
var_min=hps.ic_prior_var_min,
var_init=hps.ic_prior_var_scale,
var_max=hps.ic_prior_var_max)
ic_enc = tf.concat(axis=1, values=[ic_enc_fwd[-1], ic_enc_rev[0]])
ic_enc = tf.nn.dropout(ic_enc, keep_prob)
self.posterior_zs_g0 = \
DiagonalGaussianFromInput(ic_enc, ic_dim, "ic_enc_2_post_g0",
var_min=hps.ic_post_var_min)
if kind in ["train", "posterior_sample_and_average",
"posterior_push_mean"]:
zs_g0 = self.posterior_zs_g0
else:
zs_g0 = self.prior_zs_g0
if kind in ["train", "posterior_sample_and_average", "prior_sample"]:
self.g0s_val = zs_g0.sample
else:
self.g0s_val = zs_g0.mean
# Priors for controller, 'co' for controller output
self.prior_zs_co = prior_zs_co = [None] * num_steps
self.posterior_zs_co = posterior_zs_co = [None] * num_steps
self.zs_co = zs_co = [None] * num_steps
self.prior_zs_ar_con = None
if co_dim > 0:
# Controller outputs
autocorrelation_taus = [hps.prior_ar_atau for x in range(hps.co_dim)]
noise_variances = [hps.prior_ar_nvar for x in range(hps.co_dim)]
self.prior_zs_ar_con = prior_zs_ar_con = \
LearnableAutoRegressive1Prior(batch_size, hps.co_dim,
autocorrelation_taus,
noise_variances,
hps.do_train_prior_ar_atau,
hps.do_train_prior_ar_nvar,
num_steps, "u_prior_ar1")
# CONTROLLER -> GENERATOR -> RATES
# (u(t) -> gen(t) -> factors(t) -> rates(t) -> p(x_t|z_t) )
self.controller_outputs = u_t = [None] * num_steps
self.con_ics = con_state = None
self.con_states = con_states = [None] * num_steps
self.con_outs = con_outs = [None] * num_steps
self.gen_inputs = gen_inputs = [None] * num_steps
if co_dim > 0:
# gen_cell_class here for l2 penalty recurrent weights
# didn't split the cell_weight scale here, because I doubt it matters
con_cell = gen_cell_class(hps.con_dim,
input_weight_scale=hps.cell_weight_scale,
rec_weight_scale=hps.cell_weight_scale,
clip_value=hps.cell_clip_value,
recurrent_collections=['l2_con_reg'])
with tf.variable_scope("con", reuse=False):
self.con_ics = tf.tile(
tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]),
name="c0"),
tf.stack([batch_size, 1]))
self.con_ics.set_shape([None, con_cell.state_size]) # tile loses shape
con_states[-1] = self.con_ics
gen_cell = gen_cell_class(hps.gen_dim,
input_weight_scale=hps.gen_cell_input_weight_scale,
rec_weight_scale=hps.gen_cell_rec_weight_scale,
clip_value=hps.cell_clip_value,
recurrent_collections=['l2_gen_reg'])
with tf.variable_scope("gen", reuse=False):
if ic_dim == 0:
self.gen_ics = tf.tile(
tf.Variable(tf.zeros([1, gen_cell.state_size]), name="g0"),
tf.stack([batch_size, 1]))
else:
self.gen_ics = linear(self.g0s_val, gen_cell.state_size,
identity_if_possible=True,
name="g0_2_gen_ic")
self.gen_states = gen_states = [None] * num_steps
self.gen_outs = gen_outs = [None] * num_steps
gen_states[-1] = self.gen_ics
gen_outs[-1] = gen_cell.output_from_state(gen_states[-1])
self.factors = factors = [None] * num_steps
factors[-1] = linear(gen_outs[-1], factors_dim, do_bias=False,
normalized=True, name="gen_2_fac")
self.rates = rates = [None] * num_steps
# rates[-1] is collected to potentially feed back to controller
with tf.variable_scope("glm", reuse=False):
if hps.output_dist == 'poisson':
log_rates_t0 = tf.matmul(factors[-1], this_out_fac_W) + this_out_fac_b
log_rates_t0.set_shape([None, None])
rates[-1] = tf.exp(log_rates_t0) # rate
rates[-1].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
elif hps.output_dist == 'gaussian':
mean_n_logvars = tf.matmul(factors[-1],this_out_fac_W) + this_out_fac_b
mean_n_logvars.set_shape([None, None])
means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
value=mean_n_logvars)
rates[-1] = means_t_bxd
else:
assert False, "NIY"
# We support multiple output distributions, for example Poisson, and also
# Gaussian. In these two cases respectively, there are one and two
# parameters (rates vs. mean and variance). So the output_dist_params
tensor will have variable sizes via tf.concat and tf.split, along the 1st
# dimension. So in the case of gaussian, for example, it'll be
# batch x (D+D), where each D dims is the mean, and then variances,
# respectively. For a distribution with 3 parameters, it would be
# batch x (D+D+D).
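# Illustrative sketch (in comments only): with a hypothetical Gaussian output
# of dimension D=3, dist_params[t] (defined just below) has shape batch x 6,
# and the means and variances can be recovered with the same split pattern
# used elsewhere in this file:
#
#   means_bxd, vars_bxd = tf.split(axis=1, num_or_size_splits=2,
#                                  value=dist_params[t])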
self.output_dist_params = dist_params = [None] * num_steps
self.log_p_xgz_b = log_p_xgz_b = 0.0 # log P(x|z)
for t in range(num_steps):
# Controller
if co_dim > 0:
# Build inputs for controller
tlag = t - hps.controller_input_lag
if tlag < 0:
con_in_f_t = tf.zeros_like(ci_enc_fwd[0])
else:
con_in_f_t = ci_enc_fwd[tlag]
if hps.do_causal_controller:
# If the controller is causal (w.r.t. the data generation process), then it
# cannot see future data. Thus, excluding ci_enc_rev[t] is obvious.
# Less obvious is the need to exclude factors[t-1]. This arises
# because information flows from g0 through factors to the controller
# input. The g0 encoding is backwards, so we must necessarily exclude
# the factors in order to keep the controller input purely from a
# forward encoding (however unlikely it is that
# g0->factors->controller channel might actually be used in this way).
con_in_list_t = [con_in_f_t]
else:
tlag_rev = t + hps.controller_input_lag
if tlag_rev >= num_steps:
# better than zeros
con_in_r_t = tf.zeros_like(ci_enc_rev[0])
else:
con_in_r_t = ci_enc_rev[tlag_rev]
con_in_list_t = [con_in_f_t, con_in_r_t]
if hps.do_feed_factors_to_controller:
if hps.feedback_factors_or_rates == "factors":
con_in_list_t.append(factors[t-1])
elif hps.feedback_factors_or_rates == "rates":
con_in_list_t.append(rates[t-1])
else:
assert False, "NIY"
con_in_t = tf.concat(axis=1, values=con_in_list_t)
con_in_t = tf.nn.dropout(con_in_t, keep_prob)
with tf.variable_scope("con", reuse=True if t > 0 else None):
con_outs[t], con_states[t] = con_cell(con_in_t, con_states[t-1])
posterior_zs_co[t] = \
DiagonalGaussianFromInput(con_outs[t], co_dim,
name="con_to_post_co")
if kind == "train":
u_t[t] = posterior_zs_co[t].sample
elif kind == "posterior_sample_and_average":
u_t[t] = posterior_zs_co[t].sample
elif kind == "posterior_push_mean":
u_t[t] = posterior_zs_co[t].mean
else:
u_t[t] = prior_zs_ar_con.samples_t[t]
# Inputs to the generator (controller output + external input)
if ext_input_dim > 0 and hps.inject_ext_input_to_gen:
ext_input_t_bxi = ext_input_do_bxtxi[:,t,:]
if co_dim > 0:
gen_inputs[t] = tf.concat(axis=1, values=[u_t[t], ext_input_t_bxi])
else:
gen_inputs[t] = ext_input_t_bxi
else:
gen_inputs[t] = u_t[t]
# Generator
data_t_bxd = dataset_ph[:,t,:]
with tf.variable_scope("gen", reuse=True if t > 0 else None):
gen_outs[t], gen_states[t] = gen_cell(gen_inputs[t], gen_states[t-1])
gen_outs[t] = tf.nn.dropout(gen_outs[t], keep_prob)
with tf.variable_scope("gen", reuse=True): # ic defined it above
factors[t] = linear(gen_outs[t], factors_dim, do_bias=False,
normalized=True, name="gen_2_fac")
with tf.variable_scope("glm", reuse=True if t > 0 else None):
if hps.output_dist == 'poisson':
log_rates_t = tf.matmul(factors[t], this_out_fac_W) + this_out_fac_b
log_rates_t.set_shape([None, None])
rates[t] = dist_params[t] = tf.exp(tf.clip_by_value(log_rates_t, -hps.cell_clip_value, hps.cell_clip_value)) # rates feed back
rates[t].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
loglikelihood_t = Poisson(log_rates_t).logp(data_t_bxd)
elif hps.output_dist == 'gaussian':
mean_n_logvars = tf.matmul(factors[t],this_out_fac_W) + this_out_fac_b
mean_n_logvars.set_shape([None, None])
means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
value=mean_n_logvars)
rates[t] = means_t_bxd # rates feed back to controller
dist_params[t] = tf.concat(
axis=1, values=[means_t_bxd, tf.exp(tf.clip_by_value(logvars_t_bxd, -hps.cell_clip_value, hps.cell_clip_value))])
loglikelihood_t = \
diag_gaussian_log_likelihood(data_t_bxd,
means_t_bxd, logvars_t_bxd)
else:
assert False, "NIY"
log_p_xgz_b += tf.reduce_sum(loglikelihood_t, [1])
# Correlation of inferred inputs cost.
self.corr_cost = tf.constant(0.0)
if hps.co_mean_corr_scale > 0.0:
all_sum_corr = []
for i in range(hps.co_dim):
for j in range(i+1, hps.co_dim):
sum_corr_ij = tf.constant(0.0)
for t in range(num_steps):
u_mean_t = posterior_zs_co[t].mean
sum_corr_ij += u_mean_t[:,i]*u_mean_t[:,j]
all_sum_corr.append(0.5 * tf.square(sum_corr_ij))
self.corr_cost = tf.reduce_mean(all_sum_corr) # div by batch and by n*(n-1)/2 pairs
# Variational Lower Bound on posterior, p(z|x), plus reconstruction cost.
# KL and reconstruction costs are normalized only by batch size, not by
# dimension, or by time steps.
kl_cost_g0_b = tf.zeros_like(batch_size, dtype=tf.float32)
kl_cost_co_b = tf.zeros_like(batch_size, dtype=tf.float32)
self.kl_cost = tf.constant(0.0) # VAE KL cost
self.recon_cost = tf.constant(0.0) # VAE reconstruction cost
self.nll_bound_vae = tf.constant(0.0)
self.nll_bound_iwae = tf.constant(0.0) # for eval with IWAE cost.
if kind in ["train", "posterior_sample_and_average", "posterior_push_mean"]:
kl_cost_g0_b = 0.0
kl_cost_co_b = 0.0
if ic_dim > 0:
g0_priors = [self.prior_zs_g0]
g0_posts = [self.posterior_zs_g0]
kl_cost_g0_b = KLCost_GaussianGaussian(g0_posts, g0_priors).kl_cost_b
kl_cost_g0_b = hps.kl_ic_weight * kl_cost_g0_b
if co_dim > 0:
kl_cost_co_b = \
KLCost_GaussianGaussianProcessSampled(
posterior_zs_co, prior_zs_ar_con).kl_cost_b
kl_cost_co_b = hps.kl_co_weight * kl_cost_co_b
# L = -KL + log p(x|z), to maximize bound on likelihood
# -L = KL - log p(x|z), to minimize bound on NLL
# so 'reconstruction cost' is negative log likelihood
self.recon_cost = - tf.reduce_mean(log_p_xgz_b)
self.kl_cost = tf.reduce_mean(kl_cost_g0_b + kl_cost_co_b)
lb_on_ll_b = log_p_xgz_b - kl_cost_g0_b - kl_cost_co_b
# VAE error averages outside the log
self.nll_bound_vae = -tf.reduce_mean(lb_on_ll_b)
# IWAE error averages inside the log
k = tf.cast(tf.shape(log_p_xgz_b)[0], tf.float32)
iwae_lb_on_ll = -tf.log(k) + log_sum_exp(lb_on_ll_b)
self.nll_bound_iwae = -iwae_lb_on_ll
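# A minimal numpy sketch (in comments only) of the difference between the two
# bounds, assuming lb_on_ll is a length-k vector of per-sample lower bounds on
# the log-likelihood of one example:
#
#   import numpy as np
#   nll_bound_vae = -np.mean(lb_on_ll)                    # average outside the log
#   nll_bound_iwae = -np.log(np.mean(np.exp(lb_on_ll)))   # average inside the log
#
# The IWAE line is what -(-log(k) + log_sum_exp(lb_on_ll)) computes; the
# log_sum_exp form used above is just the numerically stable version.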
# L2 regularization on the generator, normalized by number of parameters.
self.l2_cost = tf.constant(0.0)
if self.hps.l2_gen_scale > 0.0 or self.hps.l2_con_scale > 0.0:
l2_costs = []
l2_numels = []
l2_reg_var_lists = [tf.get_collection('l2_gen_reg'),
tf.get_collection('l2_con_reg')]
l2_reg_scales = [self.hps.l2_gen_scale, self.hps.l2_con_scale]
for l2_reg_vars, l2_scale in zip(l2_reg_var_lists, l2_reg_scales):
for v in l2_reg_vars:
numel = tf.reduce_prod(tf.concat(axis=0, values=tf.shape(v)))
numel_f = tf.cast(numel, tf.float32)
l2_numels.append(numel_f)
v_l2 = tf.reduce_sum(v*v)
l2_costs.append(0.5 * l2_scale * v_l2)
self.l2_cost = tf.add_n(l2_costs) / tf.add_n(l2_numels)
# Compute the cost for training, part of the graph regardless.
# The KL cost can be problematic at the beginning of optimization,
# so we ramp the KL weight linearly from 0 to 1 over kl_increase_steps,
# starting at kl_start_step (and similarly for the L2 weight).
self.kl_decay_step = tf.maximum(self.train_step - hps.kl_start_step, 0)
self.l2_decay_step = tf.maximum(self.train_step - hps.l2_start_step, 0)
kl_decay_step_f = tf.cast(self.kl_decay_step, tf.float32)
l2_decay_step_f = tf.cast(self.l2_decay_step, tf.float32)
kl_increase_steps_f = tf.cast(hps.kl_increase_steps, tf.float32)
l2_increase_steps_f = tf.cast(hps.l2_increase_steps, tf.float32)
self.kl_weight = kl_weight = \
tf.minimum(kl_decay_step_f / kl_increase_steps_f, 1.0)
self.l2_weight = l2_weight = \
tf.minimum(l2_decay_step_f / l2_increase_steps_f, 1.0)
self.timed_kl_cost = kl_weight * self.kl_cost
self.timed_l2_cost = l2_weight * self.l2_cost
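# Worked example (in comments only, hypothetical numbers): with kl_start_step=0
# and kl_increase_steps=2000, the KL weight ramps linearly and then saturates:
#
#   train_step:  0     500    1000   2000   5000
#   kl_weight:   0.00  0.25   0.50   1.00   1.00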
self.weight_corr_cost = hps.co_mean_corr_scale * self.corr_cost
self.cost = self.recon_cost + self.timed_kl_cost + \
self.timed_l2_cost + self.weight_corr_cost
if kind != "train":
# save every so often
self.seso_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep)
# lowest validation error
self.lve_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep_lve)
return
# OPTIMIZATION
# train the io matrices only
if self.hps.do_train_io_only:
self.train_vars = tvars = \
tf.get_collection('IO_transformations',
scope=tf.get_variable_scope().name)
# train the encoder only
elif self.hps.do_train_encoder_only:
tvars1 = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope='LFADS/ic_enc_*')
tvars2 = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope='LFADS/z/ic_enc_*')
self.train_vars = tvars = tvars1 + tvars2
# train all variables
else:
self.train_vars = tvars = \
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope=tf.get_variable_scope().name)
print("done.")
print("Model Variables (to be optimized): ")
total_params = 0
for i in range(len(tvars)):
shape = tvars[i].get_shape().as_list()
print(" ", i, tvars[i].name, shape)
total_params += np.prod(shape)
print("Total model parameters: ", total_params)
grads = tf.gradients(self.cost, tvars)
grads, grad_global_norm = tf.clip_by_global_norm(grads, hps.max_grad_norm)
opt = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999,
epsilon=1e-01)
self.grads = grads
self.grad_global_norm = grad_global_norm
self.train_op = opt.apply_gradients(
zip(grads, tvars), global_step=self.train_step)
self.seso_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep)
# lowest validation error
self.lve_saver = tf.train.Saver(tf.global_variables(),
max_to_keep=hps.max_ckpt_to_keep_lve)
# SUMMARIES, used only during training.
# example summary
self.example_image = tf.placeholder(tf.float32, shape=[1,None,None,3],
name='image_tensor')
self.example_summ = tf.summary.image("LFADS example", self.example_image,
collections=["example_summaries"])
# general training summaries
self.lr_summ = tf.summary.scalar("Learning rate", self.learning_rate)
self.kl_weight_summ = tf.summary.scalar("KL weight", self.kl_weight)
self.l2_weight_summ = tf.summary.scalar("L2 weight", self.l2_weight)
self.corr_cost_summ = tf.summary.scalar("Corr cost", self.weight_corr_cost)
self.grad_global_norm_summ = tf.summary.scalar("Gradient global norm",
self.grad_global_norm)
if hps.co_dim > 0:
self.atau_summ = [None] * hps.co_dim
self.pvar_summ = [None] * hps.co_dim
for c in range(hps.co_dim):
self.atau_summ[c] = \
tf.summary.scalar("AR Autocorrelation taus " + str(c),
tf.exp(self.prior_zs_ar_con.logataus_1xu[0,c]))
self.pvar_summ[c] = \
tf.summary.scalar("AR Variances " + str(c),
tf.exp(self.prior_zs_ar_con.logpvars_1xu[0,c]))
# cost summaries, separated into different collections for
# training vs validation. We make placeholders for these, because
# even though the graph computes these costs on a per-batch basis,
# we want to report the more reliable metric of per-epoch cost.
kl_cost_ph = tf.placeholder(tf.float32, shape=[], name='kl_cost_ph')
self.kl_t_cost_summ = tf.summary.scalar("KL cost (train)", kl_cost_ph,
collections=["train_summaries"])
self.kl_v_cost_summ = tf.summary.scalar("KL cost (valid)", kl_cost_ph,
collections=["valid_summaries"])
l2_cost_ph = tf.placeholder(tf.float32, shape=[], name='l2_cost_ph')
self.l2_cost_summ = tf.summary.scalar("L2 cost", l2_cost_ph,
collections=["train_summaries"])
recon_cost_ph = tf.placeholder(tf.float32, shape=[], name='recon_cost_ph')
self.recon_t_cost_summ = tf.summary.scalar("Reconstruction cost (train)",
recon_cost_ph,
collections=["train_summaries"])
self.recon_v_cost_summ = tf.summary.scalar("Reconstruction cost (valid)",
recon_cost_ph,
collections=["valid_summaries"])
total_cost_ph = tf.placeholder(tf.float32, shape=[], name='total_cost_ph')
self.cost_t_summ = tf.summary.scalar("Total cost (train)", total_cost_ph,
collections=["train_summaries"])
self.cost_v_summ = tf.summary.scalar("Total cost (valid)", total_cost_ph,
collections=["valid_summaries"])
self.kl_cost_ph = kl_cost_ph
self.l2_cost_ph = l2_cost_ph
self.recon_cost_ph = recon_cost_ph
self.total_cost_ph = total_cost_ph
# Merged summaries, for easy coding later.
self.merged_examples = tf.summary.merge_all(key="example_summaries")
self.merged_generic = tf.summary.merge_all() # default key is 'summaries'
self.merged_train = tf.summary.merge_all(key="train_summaries")
self.merged_valid = tf.summary.merge_all(key="valid_summaries")
session = tf.get_default_session()
self.logfile = os.path.join(hps.lfads_save_dir, "lfads_log")
self.writer = tf.summary.FileWriter(self.logfile)
def build_feed_dict(self, train_name, data_bxtxd, ext_input_bxtxi=None,
keep_prob=None):
"""Build the feed dictionary, handles cases where there is no value defined.
Args:
train_name: The key into the datasets, to set the tf.case statement for
the proper readin / readout matrices.
data_bxtxd: The data tensor
ext_input_bxtxi (optional): The external input tensor
keep_prob: The drop out keep probability.
Returns:
The feed dictionary with TF tensors as keys and data as values, for use
with tf.Session.run()
"""
feed_dict = {}
B, T, _ = data_bxtxd.shape
feed_dict[self.dataName] = train_name
feed_dict[self.dataset_ph] = data_bxtxd
if self.ext_input is not None and ext_input_bxtxi is not None:
feed_dict[self.ext_input] = ext_input_bxtxi
if keep_prob is None:
feed_dict[self.keep_prob] = self.hps.keep_prob
else:
feed_dict[self.keep_prob] = keep_prob
return feed_dict
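# Usage sketch (in comments only, hypothetical dataset name): feed one batch
# through the graph with the readin/readout matrices selected for 'session_0'.
#
#   feed_dict = model.build_feed_dict('session_0', data_bxtxd,
#                                     ext_input_bxtxi=None, keep_prob=1.0)
#   costs = session.run([model.cost, model.recon_cost], feed_dict=feed_dict)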
@staticmethod
def get_batch(data_extxd, ext_input_extxi=None, batch_size=None,
example_idxs=None):
"""Get a batch of data, either randomly chosen, or specified directly.
Args:
data_extxd: The data to model, numpy tensors with shape:
# examples x # time steps x # dimensions
ext_input_extxi (optional): The external inputs, numpy tensor with shape:
# examples x # time steps x # external input dimensions
batch_size: The size of the batch to return
example_idxs (optional): The example indices used to select examples.
Returns:
A tuple with two parts:
1. Batched data numpy tensor with shape:
batch_size x # time steps x # dimensions
2. Batched external input numpy tensor with shape:
batch_size x # time steps x # external input dims
"""
assert batch_size is not None or example_idxs is not None, "Problems"
E, T, D = data_extxd.shape
if example_idxs is None:
example_idxs = np.random.choice(E, batch_size)
ext_input_bxtxi = None
if ext_input_extxi is not None:
ext_input_bxtxi = ext_input_extxi[example_idxs,:,:]
return data_extxd[example_idxs,:,:], ext_input_bxtxi
@staticmethod
def example_idxs_mod_batch_size(nexamples, batch_size):
"""Given a number of examples, E, and a batch_size, B, generate indices
[0, 1, 2, ... B-1;
[B, B+1, ... 2*B-1;
...
]
returning those indices as a 2-dim tensor shaped like E/B x B. Note that
the shape is only correct if E % B == 0. If not, then an extra row is
generated so that the remainder of the examples is included; the extra slots
in that row are filled with randomly chosen (sorted) example indices (see
randomize_example_idxs_mod_batch_size for fully randomized behavior).
Args:
nexamples: The number of examples to batch up.
batch_size: The size of the batch.
Returns:
2-dim tensor as described above.
"""
bmrem = batch_size - (nexamples % batch_size)
bmrem_examples = []
if bmrem < batch_size:
#bmrem_examples = np.zeros(bmrem, dtype=np.int32)
ridxs = np.random.permutation(nexamples)[0:bmrem].astype(np.int32)
bmrem_examples = np.sort(ridxs)
example_idxs = list(range(nexamples)) + list(bmrem_examples)
example_idxs_e_x_edivb = np.reshape(example_idxs, [-1, batch_size])
return example_idxs_e_x_edivb, bmrem
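# Worked example (in comments only): nexamples=10, batch_size=4 gives bmrem=2,
# so two randomly chosen (sorted) extra indices are appended and the result is
# reshaped to 3 x 4, e.g.
#
#   [[0, 1, 2, 3],
#    [4, 5, 6, 7],
#    [8, 9, 2, 6]]   # the last two entries are the random extras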
@staticmethod
def randomize_example_idxs_mod_batch_size(nexamples, batch_size):
"""Indices 1:nexamples, randomized, in 2D form of
shape = (nexamples / batch_size) x batch_size. The remainder
is managed by drawing randomly from 1:nexamples.
Args:
nexamples: number of examples to randomize
batch_size: number of elements in batch
Returns:
The randomized, properly shaped indices.
"""
assert nexamples > batch_size, "Problems"
bmrem = batch_size - nexamples % batch_size
bmrem_examples = []
if bmrem < batch_size:
bmrem_examples = np.random.choice(range(nexamples),
size=bmrem, replace=False)
example_idxs = list(range(nexamples)) + list(bmrem_examples)
mixed_example_idxs = np.random.permutation(example_idxs)
example_idxs_e_x_edivb = np.reshape(mixed_example_idxs, [-1, batch_size])
return example_idxs_e_x_edivb, bmrem
def shuffle_spikes_in_time(self, data_bxtxd):
"""Shuffle the spikes in the temporal dimension. This is useful to
help the LFADS system avoid overfitting to individual spikes or fast
oscillations found in the data that are irrelevant to behavior. A
pure 'tabula rasa' approach would avoid this, but LFADS is sensitive
enough to pick up dynamics that you may not want.
Args:
data_bxtxd: numpy array of spike count data to be shuffled.
Returns:
S_bxtxd, a numpy array with the same dimensions and contents as
data_bxtxd, but shuffled appropriately.
"""
B, T, N = data_bxtxd.shape
w = self.hps.temporal_spike_jitter_width
if w == 0:
return data_bxtxd
max_counts = np.max(data_bxtxd)
S_bxtxd = np.zeros([B,T,N])
# Intuitively, shuffle spike occurrences (0 or 1), but since we have counts,
# do it over and over again up to the max count.
for mc in range(1,max_counts+1):
idxs = np.nonzero(data_bxtxd >= mc)
data_ones = np.zeros_like(data_bxtxd)
data_ones[data_bxtxd >= mc] = 1
nfound = len(idxs[0])
shuffles_incrs_in_time = np.random.randint(-w, w, size=nfound)
shuffle_tidxs = idxs[1].copy()
shuffle_tidxs += shuffles_incrs_in_time
# Reflect on the boundaries to not lose mass.
shuffle_tidxs[shuffle_tidxs < 0] = -shuffle_tidxs[shuffle_tidxs < 0]
shuffle_tidxs[shuffle_tidxs > T-1] = \
(T-1)-(shuffle_tidxs[shuffle_tidxs > T-1] -(T-1))
for iii in zip(idxs[0], shuffle_tidxs, idxs[2]):
S_bxtxd[iii] += 1
return S_bxtxd
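# Illustrative sketch (in comments only) of the boundary reflection above, for
# a hypothetical trial length T and proposed (jittered) spike times t:
#
#   t = np.array([-1, 3, T])                    # one spike pushed off each end
#   t[t < 0] = -t[t < 0]                        # reflect at the start: -1 -> 1
#   t[t > T-1] = (T-1) - (t[t > T-1] - (T-1))   # reflect at the end:   T -> T-2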
def shuffle_and_flatten_datasets(self, datasets, kind='train'):
"""Since LFADS supports multiple datasets in the same dynamical model,
we have to be careful to use all the data in a single training epoch. But
since the datasets may have different data dimensionality, we cannot batch
examples from data dictionaries together. Instead, we generate random
batches within each data dictionary, and then randomize these batches
while holding onto the dataname, so that when it's time to feed
the graph, the correct in/out matrices can be selected, per batch.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
kind: 'train' or 'valid'
Returns:
A flat list, in which each element is a pair ('name', indices).
"""
batch_size = self.hps.batch_size
ndatasets = len(datasets)
random_example_idxs = {}
epoch_idxs = {}
all_name_example_idx_pairs = []
kind_data = kind + '_data'
for name, data_dict in datasets.items():
nexamples, ntime, data_dim = data_dict[kind_data].shape
epoch_idxs[name] = 0
random_example_idxs, _ = \
self.randomize_example_idxs_mod_batch_size(nexamples, batch_size)
epoch_size = random_example_idxs.shape[0]
names = [name] * epoch_size
all_name_example_idx_pairs += zip(names, random_example_idxs)
np.random.shuffle(all_name_example_idx_pairs) # shuffle in place
return all_name_example_idx_pairs
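# Illustrative sketch (in comments only) of the returned flat list for two
# hypothetical sessions with batch_size=4:
#
#   [('session_1', array([7, 0, 3, 9])),
#    ('session_0', array([2, 5, 1, 4])),
#    ('session_1', array([6, 2, 8, 1])),
#    ...]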
def train_epoch(self, datasets, batch_size=None, do_save_ckpt=True):
"""Train the model through the entire dataset once.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
batch_size (optional): The batch_size to use
do_save_ckpt (optional): Should the routine save a checkpoint on this
training epoch?
Returns:
A tuple with 6 float values:
(total cost of the epoch, epoch reconstruction cost,
epoch kl cost, KL weight used this training epoch,
total l2 cost on generator, and the corresponding weight).
"""
ops_to_eval = [self.cost, self.recon_cost,
self.kl_cost, self.kl_weight,
self.l2_cost, self.l2_weight,
self.train_op]
collected_op_values = self.run_epoch(datasets, ops_to_eval, kind="train")
total_cost = total_recon_cost = total_kl_cost = 0.0
# normalizing by batch done in distributions.py
epoch_size = len(collected_op_values)
for op_values in collected_op_values:
total_cost += op_values[0]
total_recon_cost += op_values[1]
total_kl_cost += op_values[2]
kl_weight = collected_op_values[-1][3]
l2_cost = collected_op_values[-1][4]
l2_weight = collected_op_values[-1][5]
epoch_total_cost = total_cost / epoch_size
epoch_recon_cost = total_recon_cost / epoch_size
epoch_kl_cost = total_kl_cost / epoch_size
if do_save_ckpt:
session = tf.get_default_session()
checkpoint_path = os.path.join(self.hps.lfads_save_dir,
self.hps.checkpoint_name + '.ckpt')
self.seso_saver.save(session, checkpoint_path,
global_step=self.train_step)
return epoch_total_cost, epoch_recon_cost, epoch_kl_cost, \
kl_weight, l2_cost, l2_weight
def run_epoch(self, datasets, ops_to_eval, kind="train", batch_size=None,
do_collect=True, keep_prob=None):
"""Run the model through the entire dataset once.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
ops_to_eval: A list of tensorflow operations that will be evaluated in
the tf.session.run() call.
batch_size (optional): The batch_size to use
do_collect (optional): Should the routine collect all session.run
output as a list, and return it?
keep_prob (optional): The dropout keep probability.
Returns:
A list of lists, the internal list is the return for the ops for each
session.run() call. The outer list collects over the epoch.
"""
hps = self.hps
all_name_example_idx_pairs = \
self.shuffle_and_flatten_datasets(datasets, kind)
kind_data = kind + '_data'
kind_ext_input = kind + '_ext_input'
total_cost = total_recon_cost = total_kl_cost = 0.0
session = tf.get_default_session()
epoch_size = len(all_name_example_idx_pairs)
evaled_ops_list = []
for name, example_idxs in all_name_example_idx_pairs:
data_dict = datasets[name]
data_extxd = data_dict[kind_data]
if hps.output_dist == 'poisson' and hps.temporal_spike_jitter_width > 0:
data_extxd = self.shuffle_spikes_in_time(data_extxd)
ext_input_extxi = data_dict[kind_ext_input]
data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd, ext_input_extxi,
example_idxs=example_idxs)
feed_dict = self.build_feed_dict(name, data_bxtxd, ext_input_bxtxi,
keep_prob=keep_prob)
evaled_ops_np = session.run(ops_to_eval, feed_dict=feed_dict)
if do_collect:
evaled_ops_list.append(evaled_ops_np)
return evaled_ops_list
def summarize_all(self, datasets, summary_values):
"""Plot and summarize stuff in tensorboard.
Note that everything done in the current function is otherwise done on
a single, randomly selected dataset (except for summary_values, which are
passed in).
Args:
  datasets: The dictionary of datasets used in the study.
summary_values: These summary values are created from the training loop,
and so summarize the entire set of datasets.
"""
hps = self.hps
tr_kl_cost = summary_values['tr_kl_cost']
tr_recon_cost = summary_values['tr_recon_cost']
tr_total_cost = summary_values['tr_total_cost']
kl_weight = summary_values['kl_weight']
l2_weight = summary_values['l2_weight']
l2_cost = summary_values['l2_cost']
has_any_valid_set = summary_values['has_any_valid_set']
i = summary_values['nepochs']
session = tf.get_default_session()
train_summ, train_step = session.run([self.merged_train,
self.train_step],
feed_dict={self.l2_cost_ph:l2_cost,
self.kl_cost_ph:tr_kl_cost,
self.recon_cost_ph:tr_recon_cost,
self.total_cost_ph:tr_total_cost})
self.writer.add_summary(train_summ, train_step)
if has_any_valid_set:
ev_kl_cost = summary_values['ev_kl_cost']
ev_recon_cost = summary_values['ev_recon_cost']
ev_total_cost = summary_values['ev_total_cost']
eval_summ = session.run(self.merged_valid,
feed_dict={self.kl_cost_ph:ev_kl_cost,
self.recon_cost_ph:ev_recon_cost,
self.total_cost_ph:ev_total_cost})
self.writer.add_summary(eval_summ, train_step)
print("Epoch:%d, step:%d (TRAIN, VALID): total: %.2f, %.2f\
recon: %.2f, %.2f, kl: %.2f, %.2f, l2: %.5f,\
kl weight: %.2f, l2 weight: %.2f" % \
(i, train_step, tr_total_cost, ev_total_cost,
tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost,
l2_cost, kl_weight, l2_weight))
csv_outstr = "epoch,%d, step,%d, total,%.2f,%.2f, \
recon,%.2f,%.2f, kl,%.2f,%.2f, l2,%.5f, \
klweight,%.2f, l2weight,%.2f\n"% \
(i, train_step, tr_total_cost, ev_total_cost,
tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost,
l2_cost, kl_weight, l2_weight)
else:
print("Epoch:%d, step:%d TRAIN: total: %.2f recon: %.2f, kl: %.2f,\
l2: %.5f, kl weight: %.2f, l2 weight: %.2f" % \
(i, train_step, tr_total_cost, tr_recon_cost, tr_kl_cost,
l2_cost, kl_weight, l2_weight))
csv_outstr = "epoch,%d, step,%d, total,%.2f, recon,%.2f, kl,%.2f, \
l2,%.5f, klweight,%.2f, l2weight,%.2f\n"% \
(i, train_step, tr_total_cost, tr_recon_cost,
tr_kl_cost, l2_cost, kl_weight, l2_weight)
if self.hps.csv_log:
csv_file = os.path.join(self.hps.lfads_save_dir, self.hps.csv_log+'.csv')
with open(csv_file, "a") as myfile:
myfile.write(csv_outstr)
def plot_single_example(self, datasets):
"""Plot an image relating to a randomly chosen, specific example. We use
posterior sample and average by taking one example, and filling a whole
batch with that example, sample from the posterior, and then average the
quantities.
"""
hps = self.hps
all_data_names = list(datasets.keys())
data_name = np.random.permutation(all_data_names)[0]
data_dict = datasets[data_name]
has_valid_set = data_dict['valid_data'] is not None
cf = 1.0  # conversion factor, used only for plotting (may be set below)
# posterior sample and average here
E, _, _ = data_dict['train_data'].shape
eidx = np.random.choice(E)
example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)
train_data_bxtxd, train_ext_input_bxtxi = \
self.get_batch(data_dict['train_data'], data_dict['train_ext_input'],
example_idxs=example_idxs)
truth_train_data_bxtxd = None
if 'train_truth' in data_dict and data_dict['train_truth'] is not None:
truth_train_data_bxtxd, _ = self.get_batch(data_dict['train_truth'],
example_idxs=example_idxs)
cf = data_dict['conversion_factor']
# plotter does averaging
train_model_values = self.eval_model_runs_batch(data_name,
train_data_bxtxd,
train_ext_input_bxtxi,
do_average_batch=False)
train_step = train_model_values['train_steps']
feed_dict = self.build_feed_dict(data_name, train_data_bxtxd,
train_ext_input_bxtxi, keep_prob=1.0)
session = tf.get_default_session()
generic_summ = session.run(self.merged_generic, feed_dict=feed_dict)
self.writer.add_summary(generic_summ, train_step)
valid_data_bxtxd = valid_model_values = valid_ext_input_bxtxi = None
truth_valid_data_bxtxd = None
if has_valid_set:
E, _, _ = data_dict['valid_data'].shape
eidx = np.random.choice(E)
example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)
valid_data_bxtxd, valid_ext_input_bxtxi = \
self.get_batch(data_dict['valid_data'],
data_dict['valid_ext_input'],
example_idxs=example_idxs)
if 'valid_truth' in data_dict and data_dict['valid_truth'] is not None:
truth_valid_data_bxtxd, _ = self.get_batch(data_dict['valid_truth'],
example_idxs=example_idxs)
else:
truth_valid_data_bxtxd = None
# plotter does averaging
valid_model_values = self.eval_model_runs_batch(data_name,
valid_data_bxtxd,
valid_ext_input_bxtxi,
do_average_batch=False)
example_image = plot_lfads(train_bxtxd=train_data_bxtxd,
train_model_vals=train_model_values,
train_ext_input_bxtxi=train_ext_input_bxtxi,
train_truth_bxtxd=truth_train_data_bxtxd,
valid_bxtxd=valid_data_bxtxd,
valid_model_vals=valid_model_values,
valid_ext_input_bxtxi=valid_ext_input_bxtxi,
valid_truth_bxtxd=truth_valid_data_bxtxd,
bidx=None, cf=cf, output_dist=hps.output_dist)
example_image = np.expand_dims(example_image, axis=0)
example_summ = session.run(self.merged_examples,
feed_dict={self.example_image : example_image})
self.writer.add_summary(example_summ)
def train_model(self, datasets):
"""Train the model, print per-epoch information, and save checkpoints.
Loop over training epochs. The function that actually does the
training is train_epoch. This function iterates over the training
data, one epoch at a time. The learning rate schedule is such
that it will stay the same until the cost goes up in comparison to
the last few values, then it will drop.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
name(string)-> data dictionary mapping (See top of lfads.py).
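Example (illustrative; `model` and `datasets` assumed to already be built as
described at the top of lfads.py, with a default TF session active):
  model.train_model(datasets)
  # Runs until the learning rate decays below hps.learning_rate_stop,
  # checkpointing every 10 epochs and saving a separate LVE checkpoint.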
"""
hps = self.hps
has_any_valid_set = False
for data_dict in datasets.values():
if data_dict['valid_data'] is not None:
has_any_valid_set = True
break
session = tf.get_default_session()
lr = session.run(self.learning_rate)
lr_stop = hps.learning_rate_stop
i = -1
train_costs = []
valid_costs = []
ev_total_cost = ev_recon_cost = ev_kl_cost = 0.0
lowest_ev_cost = np.Inf
while True:
i += 1
do_save_ckpt = (i % 10 == 0)  # save a checkpoint every 10 epochs
tr_total_cost, tr_recon_cost, tr_kl_cost, kl_weight, l2_cost, l2_weight = \
self.train_epoch(datasets, do_save_ckpt=do_save_ckpt)
# Evaluate the validation cost, and potentially save. Note that this
# routine will not save a validation checkpoint until the kl weight and
# l2 weights are equal to 1.0.
if has_any_valid_set:
ev_total_cost, ev_recon_cost, ev_kl_cost = \
self.eval_cost_epoch(datasets, kind='valid')
valid_costs.append(ev_total_cost)
# n_lve > 1 averages the last few validation costs, which may give more
# consistent results, but not the actual lowest validation error.
# n_lve == 1 tracks the lowest validation error (LVE) seen so far.
n_lve = 1
run_avg_lve = np.mean(valid_costs[-n_lve:])
# Conditions for saving a checkpoint:
#   the KL weight must have finished stepping (>= 1.0), AND
#   the L2 weight must have finished stepping OR L2 is not being used, AND
#   more than n_lve validation costs have been recorded (so the running
#   average is meaningful), AND the running average is lower than the
#   lowest validation error seen so far.
if kl_weight >= 1.0 and \
(l2_weight >= 1.0 or \
(self.hps.l2_gen_scale == 0.0 and self.hps.l2_con_scale == 0.0)) \
and (len(valid_costs) > n_lve and run_avg_lve < lowest_ev_cost):
lowest_ev_cost = run_avg_lve
checkpoint_path = os.path.join(self.hps.lfads_save_dir,
self.hps.checkpoint_name + '_lve.ckpt')
self.lve_saver.save(session, checkpoint_path,
global_step=self.train_step,
latest_filename='checkpoint_lve')
# Plot and summarize.
values = {'nepochs':i, 'has_any_valid_set': has_any_valid_set,
'tr_total_cost':tr_total_cost, 'ev_total_cost':ev_total_cost,
'tr_recon_cost':tr_recon_cost, 'ev_recon_cost':ev_recon_cost,
'tr_kl_cost':tr_kl_cost, 'ev_kl_cost':ev_kl_cost,
'l2_weight':l2_weight, 'kl_weight':kl_weight,
'l2_cost':l2_cost}
self.summarize_all(datasets, values)
self.plot_single_example(datasets)
# Manage learning rate.
train_res = tr_total_cost
n_lr = hps.learning_rate_n_to_compare
if len(train_costs) > n_lr and train_res > np.max(train_costs[-n_lr:]):
_ = session.run(self.learning_rate_decay_op)
lr = session.run(self.learning_rate)
print(" Decreasing learning rate to %f." % lr)
# Force the system to run n_lr times while at this lr.
train_costs.append(np.inf)
else:
train_costs.append(train_res)
if lr < lr_stop:
print("Stopping optimization based on learning rate criteria.")
break
def eval_cost_epoch(self, datasets, kind='train', ext_input_extxi=None,
batch_size=None):
"""Evaluate the cost of the epoch.
Args:
datasets: A dict of data dicts. The dataset dict is simply a
  name(string)-> data dictionary mapping (See top of lfads.py).
kind (optional): Which data split to evaluate, e.g. 'train' or 'valid'.
Returns:
a 3 tuple of costs:
(epoch total cost, epoch reconstruction cost, epoch KL cost)
"""
ops_to_eval = [self.cost, self.recon_cost, self.kl_cost]
collected_op_values = self.run_epoch(datasets, ops_to_eval, kind=kind,
keep_prob=1.0)
total_cost = total_recon_cost = total_kl_cost = 0.0
# normalizing by batch done in distributions.py
epoch_size = len(collected_op_values)
for op_values in collected_op_values:
total_cost += op_values[0]
total_recon_cost += op_values[1]
total_kl_cost += op_values[2]
epoch_total_cost = total_cost / epoch_size
epoch_recon_cost = total_recon_cost / epoch_size
epoch_kl_cost = total_kl_cost / epoch_size
return epoch_total_cost, epoch_recon_cost, epoch_kl_cost
def eval_model_runs_batch(self, data_name, data_bxtxd, ext_input_bxtxi=None,
do_eval_cost=False, do_average_batch=False):
"""Returns all the goodies for the entire model, per batch.
data_bxtxd and ext_input_bxtxi may have fewer than batch_size examples
along the batch dimension, in which case this routine handles the padding
and truncating automatically.
Args:
data_name: The name of the data dict, to select which in/out matrices
to use.
data_bxtxd: Numpy array training data with shape:
batch_size x # time steps x # dimensions
ext_input_bxtxi: Numpy array training external input with shape:
batch_size x # time steps x # external input dims
do_eval_cost (optional): If true, evaluate the IWAE (Importance Weighted
Autoencoder) log likelihood bound, instead of the VAE version.
do_average_batch (optional): average over the batch, useful for getting
good IWAE costs, and model outputs for a single data point.
Returns:
A dictionary with the outputs of the model decoder, namely:
prior g0 mean, prior g0 variance, approx. posterior mean, approx.
posterior variance, the generator initial conditions, the control inputs (if
enabled), the state of the generator, the factors, and the rates.
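Example (illustrative; names are placeholders, and a batch smaller than
batch_size is padded internally and truncated on return):
  one_trial_1xtxd = data_dict['train_data'][0:1]
  vals = model.eval_model_runs_batch('dataset_name', one_trial_1xtxd)
  rates_1xtxd = vals['output_dist_params']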
"""
session = tf.get_default_session()
# if fewer than batch_size provided, pad to batch_size
hps = self.hps
batch_size = hps.batch_size
E, _, _ = data_bxtxd.shape
if E < hps.batch_size:
data_bxtxd = np.pad(data_bxtxd, ((0, hps.batch_size-E), (0, 0), (0, 0)),
mode='constant', constant_values=0)
if ext_input_bxtxi is not None:
ext_input_bxtxi = np.pad(ext_input_bxtxi,
((0, hps.batch_size-E), (0, 0), (0, 0)),
mode='constant', constant_values=0)
feed_dict = self.build_feed_dict(data_name, data_bxtxd,
ext_input_bxtxi, keep_prob=1.0)
# Non-temporal signals will be batch x dim.
# Temporal signals are list length T with elements batch x dim.
tf_vals = [self.gen_ics, self.gen_states, self.factors,
self.output_dist_params]
tf_vals.append(self.cost)
tf_vals.append(self.nll_bound_vae)
tf_vals.append(self.nll_bound_iwae)
tf_vals.append(self.train_step) # not train_op!
if self.hps.ic_dim > 0:
tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar,
self.posterior_zs_g0.mean, self.posterior_zs_g0.logvar]
if self.hps.co_dim > 0:
tf_vals.append(self.controller_outputs)
tf_vals_flat, fidxs = flatten(tf_vals)
np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)
ff = 0
gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
out_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
nll_bound_vaes = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
nll_bound_iwaes = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
train_steps = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
if self.hps.ic_dim > 0:
prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
post_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
post_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
if self.hps.co_dim > 0:
controller_outputs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
# [0] are to take out the non-temporal items from lists
gen_ics = gen_ics[0]
costs = costs[0]
nll_bound_vaes = nll_bound_vaes[0]
nll_bound_iwaes = nll_bound_iwaes[0]
train_steps = train_steps[0]
# Convert to full tensors, not lists of tensors in time dim.
gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
factors = list_t_bxn_to_tensor_bxtxn(factors)
out_dist_params = list_t_bxn_to_tensor_bxtxn(out_dist_params)
if self.hps.ic_dim > 0:
# select first time point
prior_g0_mean = prior_g0_mean[0]
prior_g0_logvar = prior_g0_logvar[0]
post_g0_mean = post_g0_mean[0]
post_g0_logvar = post_g0_logvar[0]
if self.hps.co_dim > 0:
controller_outputs = list_t_bxn_to_tensor_bxtxn(controller_outputs)
# slice out the trials in case < batch_size provided
if E < hps.batch_size:
idx = np.arange(E)
gen_ics = gen_ics[idx, :]
gen_states = gen_states[idx, :]
factors = factors[idx, :, :]
out_dist_params = out_dist_params[idx, :, :]
if self.hps.ic_dim > 0:
prior_g0_mean = prior_g0_mean[idx, :]
prior_g0_logvar = prior_g0_logvar[idx, :]
post_g0_mean = post_g0_mean[idx, :]
post_g0_logvar = post_g0_logvar[idx, :]
if self.hps.co_dim > 0:
controller_outputs = controller_outputs[idx, :, :]
if do_average_batch:
gen_ics = np.mean(gen_ics, axis=0)
gen_states = np.mean(gen_states, axis=0)
factors = np.mean(factors, axis=0)
out_dist_params = np.mean(out_dist_params, axis=0)
if self.hps.ic_dim > 0:
prior_g0_mean = np.mean(prior_g0_mean, axis=0)
prior_g0_logvar = np.mean(prior_g0_logvar, axis=0)
post_g0_mean = np.mean(post_g0_mean, axis=0)
post_g0_logvar = np.mean(post_g0_logvar, axis=0)
if self.hps.co_dim > 0:
controller_outputs = np.mean(controller_outputs, axis=0)
model_vals = {}
model_vals['gen_ics'] = gen_ics
model_vals['gen_states'] = gen_states
model_vals['factors'] = factors
model_vals['output_dist_params'] = out_dist_params
model_vals['costs'] = costs
model_vals['nll_bound_vaes'] = nll_bound_vaes
model_vals['nll_bound_iwaes'] = nll_bound_iwaes
model_vals['train_steps'] = train_steps
if self.hps.ic_dim > 0:
model_vals['prior_g0_mean'] = prior_g0_mean
model_vals['prior_g0_logvar'] = prior_g0_logvar
model_vals['post_g0_mean'] = post_g0_mean
model_vals['post_g0_logvar'] = post_g0_logvar
if self.hps.co_dim > 0:
model_vals['controller_outputs'] = controller_outputs
return model_vals
def eval_model_runs_avg_epoch(self, data_name, data_extxd,
ext_input_extxi=None):
"""Returns all the expected value for goodies for the entire model.
The expected value is taken over hidden (z) variables, namely the initial
conditions and the control inputs. The expected value is approximate, and
accomplished via sampling (batch_size) samples for every examples.
Args:
data_name: The name of the data dict, to select which in/out matrices
to use.
data_extxd: Numpy array training data with shape:
# examples x # time steps x # dimensions
ext_input_extxi (optional): Numpy array training external input with
shape: # examples x # time steps x # external input dims
Returns:
A dictionary with the averaged outputs of the model decoder, namely:
prior g0 mean, prior g0 variance, approx. posterior mean, approx.
posterior variance, the generator initial conditions, the control inputs (if
enabled), the state of the generator, the factors, and the output
distribution parameters (e.g. rates, or means and variances).
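Example (illustrative; names are placeholders):
  runs = model.eval_model_runs_avg_epoch('dataset_name',
                                         data_dict['train_data'],
                                         data_dict['train_ext_input'])
  avg_rates = runs['output_dist_params']  # E_to_process x T x D for poisson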
"""
hps = self.hps
batch_size = hps.batch_size
E, T, D = data_extxd.shape
E_to_process = hps.ps_nexamples_to_process
if E_to_process > E:
E_to_process = E
if hps.ic_dim > 0:
prior_g0_mean = np.zeros([E_to_process, hps.ic_dim])
prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
post_g0_mean = np.zeros([E_to_process, hps.ic_dim])
post_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
if hps.co_dim > 0:
controller_outputs = np.zeros([E_to_process, T, hps.co_dim])
gen_ics = np.zeros([E_to_process, hps.gen_dim])
gen_states = np.zeros([E_to_process, T, hps.gen_dim])
factors = np.zeros([E_to_process, T, hps.factors_dim])
if hps.output_dist == 'poisson':
out_dist_params = np.zeros([E_to_process, T, D])
elif hps.output_dist == 'gaussian':
out_dist_params = np.zeros([E_to_process, T, D+D])
else:
assert False, "NIY"
costs = np.zeros(E_to_process)
nll_bound_vaes = np.zeros(E_to_process)
nll_bound_iwaes = np.zeros(E_to_process)
train_steps = np.zeros(E_to_process)
for es_idx in range(E_to_process):
print("Running %d of %d." % (es_idx+1, E_to_process))
example_idxs = es_idx * np.ones(batch_size, dtype=np.int32)
data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd,
ext_input_extxi,
batch_size=batch_size,
example_idxs=example_idxs)
model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
ext_input_bxtxi,
do_eval_cost=True,
do_average_batch=True)
if self.hps.ic_dim > 0:
prior_g0_mean[es_idx,:] = model_values['prior_g0_mean']
prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar']
post_g0_mean[es_idx,:] = model_values['post_g0_mean']
post_g0_logvar[es_idx,:] = model_values['post_g0_logvar']
gen_ics[es_idx,:] = model_values['gen_ics']
if self.hps.co_dim > 0:
controller_outputs[es_idx,:,:] = model_values['controller_outputs']
gen_states[es_idx,:,:] = model_values['gen_states']
factors[es_idx,:,:] = model_values['factors']
out_dist_params[es_idx,:,:] = model_values['output_dist_params']
costs[es_idx] = model_values['costs']
nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
train_steps[es_idx] = model_values['train_steps']
print('bound nll(vae): %.3f, bound nll(iwae): %.3f' \
% (nll_bound_vaes[es_idx], nll_bound_iwaes[es_idx]))
model_runs = {}
if self.hps.ic_dim > 0:
model_runs['prior_g0_mean'] = prior_g0_mean
model_runs['prior_g0_logvar'] = prior_g0_logvar
model_runs['post_g0_mean'] = post_g0_mean
model_runs['post_g0_logvar'] = post_g0_logvar
model_runs['gen_ics'] = gen_ics
if self.hps.co_dim > 0:
model_runs['controller_outputs'] = controller_outputs
model_runs['gen_states'] = gen_states
model_runs['factors'] = factors
model_runs['output_dist_params'] = out_dist_params
model_runs['costs'] = costs
model_runs['nll_bound_vaes'] = nll_bound_vaes
model_runs['nll_bound_iwaes'] = nll_bound_iwaes
model_runs['train_steps'] = train_steps
return model_runs
def eval_model_runs_push_mean(self, data_name, data_extxd,
ext_input_extxi=None):
"""Returns values of interest for the model by pushing the means through
The mean values for both initial conditions and the control inputs are
pushed through the model instead of sampling (as is done in
eval_model_runs_avg_epoch).
This is a quick and approximate version of estimating these values instead
of sampling from the posterior many times and then averaging those values of
interest.
Internally, a total of batch_size trials are run through the model at once.
Args:
data_name: The name of the data dict, to select which in/out matrices
to use.
data_extxd: Numpy array training data with shape:
# examples x # time steps x # dimensions
ext_input_extxi (optional): Numpy array training external input with
shape: # examples x # time steps x # external input dims
Returns:
A dictionary with the estimated outputs of the model decoder, namely:
prior g0 mean, prior g0 variance, approx. posterior mean, approx.
posterior variance, the generator initial conditions, the control inputs (if
enabled), the state of the generator, the factors, and the output
distribution parameters (e.g. rates, or means and variances).
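Example (illustrative; same calling convention as eval_model_runs_avg_epoch,
but each trial is run once with the posterior means instead of samples):
  runs = model.eval_model_runs_push_mean('dataset_name',
                                         data_dict['valid_data'],
                                         data_dict['valid_ext_input'])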
"""
hps = self.hps
batch_size = hps.batch_size
E, T, D = data_extxd.shape
E_to_process = hps.ps_nexamples_to_process
if E_to_process > E:
print("Setting number of posterior samples to process to : ", E)
E_to_process = E
if hps.ic_dim > 0:
prior_g0_mean = np.zeros([E_to_process, hps.ic_dim])
prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
post_g0_mean = np.zeros([E_to_process, hps.ic_dim])
post_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
if hps.co_dim > 0:
controller_outputs = np.zeros([E_to_process, T, hps.co_dim])
gen_ics = np.zeros([E_to_process, hps.gen_dim])
gen_states = np.zeros([E_to_process, T, hps.gen_dim])
factors = np.zeros([E_to_process, T, hps.factors_dim])
if hps.output_dist == 'poisson':
out_dist_params = np.zeros([E_to_process, T, D])
elif hps.output_dist == 'gaussian':
out_dist_params = np.zeros([E_to_process, T, D+D])
else:
assert False, "NIY"
costs = np.zeros(E_to_process)
nll_bound_vaes = np.zeros(E_to_process)
nll_bound_iwaes = np.zeros(E_to_process)
train_steps = np.zeros(E_to_process)
# generator that will yield 0:N in groups of per items, e.g.
# (0:per-1), (per:2*per-1), ..., with the last group containing <= per items
# this will be used to feed per=batch_size trials into the model at a time
def trial_batches(N, per):
for i in range(0, N, per):
yield np.arange(i, min(i+per, N), dtype=np.int32)
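# For example, trial_batches(10, 4) yields array([0, 1, 2, 3]),
# array([4, 5, 6, 7]) and then array([8, 9]).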
for batch_idx, es_idx in enumerate(trial_batches(E_to_process,
hps.batch_size)):
print("Running trial batch %d with %d trials" % (batch_idx+1,
len(es_idx)))
data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd,
ext_input_extxi,
batch_size=batch_size,
example_idxs=es_idx)
model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
ext_input_bxtxi,
do_eval_cost=True,
do_average_batch=False)
if self.hps.ic_dim > 0:
prior_g0_mean[es_idx,:] = model_values['prior_g0_mean']
prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar']
post_g0_mean[es_idx,:] = model_values['post_g0_mean']
post_g0_logvar[es_idx,:] = model_values['post_g0_logvar']
gen_ics[es_idx,:] = model_values['gen_ics']
if self.hps.co_dim > 0:
controller_outputs[es_idx,:,:] = model_values['controller_outputs']
gen_states[es_idx,:,:] = model_values['gen_states']
factors[es_idx,:,:] = model_values['factors']
out_dist_params[es_idx,:,:] = model_values['output_dist_params']
# TODO
# model_values['costs'] and other costs come out as scalars, summed over
# all the trials in the batch. what we want is the per-trial costs
costs[es_idx] = model_values['costs']
nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
train_steps[es_idx] = model_values['train_steps']
model_runs = {}
if self.hps.ic_dim > 0:
model_runs['prior_g0_mean'] = prior_g0_mean
model_runs['prior_g0_logvar'] = prior_g0_logvar
model_runs['post_g0_mean'] = post_g0_mean
model_runs['post_g0_logvar'] = post_g0_logvar
model_runs['gen_ics'] = gen_ics
if self.hps.co_dim > 0:
model_runs['controller_outputs'] = controller_outputs
model_runs['gen_states'] = gen_states
model_runs['factors'] = factors
model_runs['output_dist_params'] = out_dist_params
# You probably do not want the LL associated values when pushing the mean
# instead of sampling.
model_runs['costs'] = costs
model_runs['nll_bound_vaes'] = nll_bound_vaes
model_runs['nll_bound_iwaes'] = nll_bound_iwaes
model_runs['train_steps'] = train_steps
return model_runs
def write_model_runs(self, datasets, output_fname=None, push_mean=False):
"""Run the model on the data in data_dict, and save the computed values.
LFADS generates a number of outputs for each example, and these are all
saved. They are:
The mean and variance of the prior of g0.
The mean and variance of approximate posterior of g0.
The control inputs (if enabled)
The initial conditions, g0, for all examples.
The generator states for all time.
The factors for all time.
The output distribution parameters (e.g. rates) for all time.
Args:
datasets: a dictionary of named data_dictionaries, see top of lfads.py
output_fname: a file name stem for the output files.
push_mean: if False (default), generates batch_size samples for each trial
and averages the results. if True, runs each trial once without noise,
pushing the posterior mean initial conditions and control inputs through
the trained model. False is used for posterior_sample_and_average, True
is used for posterior_push_mean.
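Example (illustrative; file names follow the pattern built below):
  model.write_model_runs(datasets)
  # Writes model_runs_<dataset name>_<train|valid>_<hps.kind> files into
  # hps.lfads_save_dir via write_data (gzip compressed).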
"""
hps = self.hps
kind = hps.kind
for data_name, data_dict in datasets.items():
data_tuple = [('train', data_dict['train_data'],
data_dict['train_ext_input']),
('valid', data_dict['valid_data'],
data_dict['valid_ext_input'])]
for data_kind, data_extxd, ext_input_extxi in data_tuple:
if not output_fname:
fname = "model_runs_" + data_name + '_' + data_kind + '_' + kind
else:
fname = output_fname + data_name + '_' + data_kind + '_' + kind
print("Writing data for %s data and kind %s." % (data_name, data_kind))
if push_mean:
model_runs = self.eval_model_runs_push_mean(data_name, data_extxd,
ext_input_extxi)
else:
model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd,
ext_input_extxi)
full_fname = os.path.join(hps.lfads_save_dir, fname)
write_data(full_fname, model_runs, compression='gzip')
print("Done.")
def write_model_samples(self, dataset_name, output_fname=None):
"""Use the prior distribution to generate batch_size number of samples
from the model.
LFADS generates a number of outputs for each sample, and these are all
saved. They are:
The mean and variance of the prior of g0.
The control inputs (if enabled)
The initial conditions, g0, for all examples.
The generator states for all time.
The factors for all time.
The output distribution parameters (e.g. rates) for all time.
Args:
dataset_name: The name of the dataset to grab the factors -> rates
alignment matrices from.
output_fname: The name of the file in which to save the generated
samples.
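Example (illustrative; the output name is a placeholder):
  model.write_model_samples('dataset_name', output_fname='model_samples')
  # Samples batch_size trials from the prior and writes them to
  # hps.lfads_save_dir/model_samples via write_data.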
"""
hps = self.hps
batch_size = hps.batch_size
print("Generating %d samples" % (batch_size))
tf_vals = [self.factors, self.gen_states, self.gen_ics,
self.cost, self.output_dist_params]
if hps.ic_dim > 0:
tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar]
if hps.co_dim > 0:
tf_vals += [self.prior_zs_ar_con.samples_t]
tf_vals_flat, fidxs = flatten(tf_vals)
session = tf.get_default_session()
feed_dict = {}
feed_dict[self.dataName] = dataset_name
feed_dict[self.keep_prob] = 1.0
np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)
ff = 0
factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
output_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
if hps.ic_dim > 0:
prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
if hps.co_dim > 0:
prior_zs_ar_con = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
# [0] are to take out the non-temporal items from lists
gen_ics = gen_ics[0]
costs = costs[0]
# Convert to full tensors, not lists of tensors in time dim.
gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
factors = list_t_bxn_to_tensor_bxtxn(factors)
output_dist_params = list_t_bxn_to_tensor_bxtxn(output_dist_params)
if hps.ic_dim > 0:
prior_g0_mean = prior_g0_mean[0]
prior_g0_logvar = prior_g0_logvar[0]
if hps.co_dim > 0:
prior_zs_ar_con = list_t_bxn_to_tensor_bxtxn(prior_zs_ar_con)
model_vals = {}
model_vals['gen_ics'] = gen_ics
model_vals['gen_states'] = gen_states
model_vals['factors'] = factors
model_vals['output_dist_params'] = output_dist_params
model_vals['costs'] = costs.reshape(1)
if hps.ic_dim > 0:
model_vals['prior_g0_mean'] = prior_g0_mean
model_vals['prior_g0_logvar'] = prior_g0_logvar
if hps.co_dim > 0:
model_vals['prior_zs_ar_con'] = prior_zs_ar_con
full_fname = os.path.join(hps.lfads_save_dir, output_fname)
write_data(full_fname, model_vals, compression='gzip')
print("Done.")
@staticmethod
def eval_model_parameters(use_nested=True, include_strs=None):
"""Evaluate and return all of the TF variables in the model.
Args:
use_nested (optional): If true, return values in a nested dictionary based
  on variable scoping; otherwise return all variables in a flat dictionary.
include_strs (optional): A list of strings to use as a filter, to reduce the
number of variables returned. A variable name must contain at least one
string in include_strs as a sub-string in order to be returned.
Returns:
The parameters of the model. This can be in a flat
dictionary, or a nested dictionary, where the nesting is by variable
scope.
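Example (illustrative; the exact keys depend on the variable scopes in the
built graph):
  params = model.eval_model_parameters(use_nested=True)
  # params is a nested dict keyed by variable scope, with numpy arrays at
  # the leaves.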
"""
all_tf_vars = tf.global_variables()
session = tf.get_default_session()
all_tf_vars_eval = session.run(all_tf_vars)
vars_dict = {}
strs = ["LFADS"]
if include_strs:
strs += include_strs
for i, (var, var_eval) in enumerate(zip(all_tf_vars, all_tf_vars_eval)):
if any(s in var.name for s in strs):
if not isinstance(var_eval, np.ndarray): # for H5PY
print(var.name, """ is not numpy array, saving as numpy array
with value: """, var_eval, type(var_eval))
e = np.array(var_eval)
print(e, type(e))
else:
e = var_eval
vars_dict[var.name] = e
if not use_nested:
return vars_dict
var_names = vars_dict.keys()
nested_vars_dict = {}
current_dict = nested_vars_dict
for v, var_name in enumerate(var_names):
var_split_name_list = var_name.split('/')
split_name_list_len = len(var_split_name_list)
current_dict = nested_vars_dict
for p, part in enumerate(var_split_name_list):
if p < split_name_list_len - 1:
if part in current_dict:
current_dict = current_dict[part]
else:
current_dict[part] = {}
current_dict = current_dict[part]
else:
current_dict[part] = vars_dict[var_name]
return nested_vars_dict
@staticmethod
def spikify_rates(rates_bxtxd):
"""Randomly spikify underlying rates according a Poisson distribution
Args:
rates_bxtxd: a numpy tensor with shape:
Returns:
A numpy array with the same shape as rates_bxtxd, but with the event
counts.
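Example (illustrative):
  rates_bxtxd = 0.5 * np.ones([16, 100, 30])       # constant rate per time bin
  spikes_bxtxd = model.spikify_rates(rates_bxtxd)  # integer counts, same shape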
"""
B,T,N = rates_bxtxd.shape
assert all([B > 0, N > 0]), "rates_bxtxd must have nonzero batch and dimension sizes"
# The rate can differ at every (b, t, n) entry, hence the nested loops.
spikes_bxtxd = np.zeros([B,T,N], dtype=np.int32)
for b in range(B):
for t in range(T):
for n in range(N):
rate = rates_bxtxd[b,t,n]
count = np.random.poisson(rate)
spikes_bxtxd[b,t,n] = count
return spikes_bxtxd
|