MilesCranmer commited on
Commit
c2f964b
·
unverified ·
2 Parent(s): 8abc426 7de1c20

Merge pull request #531 from MilesCranmer/update-colab

Browse files
Files changed (1) hide show
  1. examples/pysr_demo.ipynb +210 -116
examples/pysr_demo.ipynb CHANGED
@@ -19,7 +19,7 @@
19
  "## Instructions\n",
20
  "1. Work on a copy of this notebook: _File_ > _Save a copy in Drive_ (you will need a Google account).\n",
21
  "2. (Optional) If you would like to do the deep learning component of this tutorial, turn on the GPU with Edit->Notebook settings->Hardware accelerator->GPU\n",
22
- "3. Execute the following cell (click on it and press Ctrl+Enter) to install Julia, IJulia and other packages (if needed, update `JULIA_VERSION` and the other parameters). This takes a couple of minutes.\n",
23
  "4. Continue to the next section.\n",
24
  "\n",
25
  "_Notes_:\n",
@@ -40,33 +40,34 @@
40
  "cell_type": "code",
41
  "execution_count": null,
42
  "metadata": {
43
- "id": "GIeFXS0F0zww"
 
 
 
 
44
  },
45
  "outputs": [],
46
  "source": [
47
- "%%shell\n",
48
- "set -e\n",
49
- "\n",
50
- "#---------------------------------------------------#\n",
51
- "JULIA_VERSION=\"1.8.5\"\n",
52
- "export JULIA_PKG_PRECOMPILE_AUTO=0\n",
53
- "#---------------------------------------------------#\n",
54
- "\n",
55
- "if [ -z `which julia` ]; then\n",
56
- " # Install Julia\n",
57
- " JULIA_VER=`cut -d '.' -f -2 <<< \"$JULIA_VERSION\"`\n",
58
- " echo \"Installing Julia $JULIA_VERSION on the current Colab Runtime...\"\n",
59
- " BASE_URL=\"https://julialang-s3.julialang.org/bin/linux/x64\"\n",
60
- " URL=\"$BASE_URL/$JULIA_VER/julia-$JULIA_VERSION-linux-x86_64.tar.gz\"\n",
61
- " wget -nv $URL -O /tmp/julia.tar.gz # -nv means \"not verbose\"\n",
62
- " tar -x -f /tmp/julia.tar.gz -C /usr/local --strip-components 1\n",
63
- " rm /tmp/julia.tar.gz\n",
64
- "\n",
65
- " echo \"Installing PyCall.jl...\"\n",
66
- " julia -e 'using Pkg; Pkg.add(\"PyCall\"); Pkg.build(\"PyCall\")'\n",
67
- " julia -e 'println(\"Success\")'\n",
68
  "\n",
69
- "fi"
 
70
  ]
71
  },
72
  {
@@ -75,18 +76,22 @@
75
  "id": "ORv1c6xvbDgV"
76
  },
77
  "source": [
78
- "Install PySR and PyTorch-Lightning:"
79
  ]
80
  },
81
  {
82
  "cell_type": "code",
83
  "execution_count": null,
84
  "metadata": {
85
- "id": "EhMRSZEYFPLz"
 
 
 
 
86
  },
87
  "outputs": [],
88
  "source": [
89
- "%pip install -Uq pysr pytorch_lightning"
90
  ]
91
  },
92
  {
@@ -95,7 +100,7 @@
95
  "id": "etTMEV0wDqld"
96
  },
97
  "source": [
98
- "The following step is not normally required, but colab's printing is non-standard and we need to manually set it up PyJulia:\n"
99
  ]
100
  },
101
  {
@@ -106,38 +111,25 @@
106
  },
107
  "outputs": [],
108
  "source": [
109
- "from julia import Julia\n",
 
 
110
  "\n",
111
- "julia = Julia(compiled_modules=False, threads=\"auto\")\n",
112
- "from julia import Main\n",
113
- "from julia.tools import redirect_output_streams\n",
114
  "\n",
115
- "redirect_output_streams()"
 
116
  ]
117
  },
118
  {
119
  "cell_type": "markdown",
120
  "metadata": {
121
- "id": "6u2WhbVhht-G"
122
  },
123
  "source": [
124
- "Let's install the backend of PySR, and all required libraries.\n",
125
- "\n",
126
- "**(This may take some time)**"
127
- ]
128
- },
129
- {
130
- "cell_type": "code",
131
- "execution_count": null,
132
- "metadata": {
133
- "id": "J-0QbxyK1_51"
134
- },
135
- "outputs": [],
136
- "source": [
137
- "import pysr\n",
138
- "\n",
139
- "# We don't precompile in colab because compiled modules are incompatible static Python libraries:\n",
140
- "pysr.install(precompile=False)"
141
  ]
142
  },
143
  {
@@ -227,14 +219,19 @@
227
  "cell_type": "code",
228
  "execution_count": null,
229
  "metadata": {
230
- "id": "p4PSrO-NK1Wa"
 
 
 
 
 
231
  },
232
  "outputs": [],
233
  "source": [
234
  "# Learn equations\n",
235
  "model = PySRRegressor(\n",
236
  " niterations=30,\n",
237
- " binary_operators=[\"plus\", \"mult\"],\n",
238
  " unary_operators=[\"cos\", \"exp\", \"sin\"],\n",
239
  " **default_pysr_params\n",
240
  ")\n",
@@ -255,7 +252,12 @@
255
  "cell_type": "code",
256
  "execution_count": null,
257
  "metadata": {
258
- "id": "4HR8gknlZz4W"
 
 
 
 
 
259
  },
260
  "outputs": [],
261
  "source": [
@@ -275,7 +277,12 @@
275
  "cell_type": "code",
276
  "execution_count": null,
277
  "metadata": {
278
- "id": "IQKOohdpztS7"
 
 
 
 
 
279
  },
280
  "outputs": [],
281
  "source": [
@@ -295,7 +302,12 @@
295
  "cell_type": "code",
296
  "execution_count": null,
297
  "metadata": {
298
- "id": "GRcxq-TTlpRX"
 
 
 
 
 
299
  },
300
  "outputs": [],
301
  "source": [
@@ -324,7 +336,12 @@
324
  "cell_type": "code",
325
  "execution_count": null,
326
  "metadata": {
327
- "id": "HFGaNL6tbDgi"
 
 
 
 
 
328
  },
329
  "outputs": [],
330
  "source": [
@@ -346,7 +363,11 @@
346
  "cell_type": "code",
347
  "execution_count": null,
348
  "metadata": {
349
- "id": "Vbz4IMsk2NYH"
 
 
 
 
350
  },
351
  "outputs": [],
352
  "source": [
@@ -406,14 +427,19 @@
406
  "cell_type": "code",
407
  "execution_count": null,
408
  "metadata": {
409
- "id": "PoEkpvYuGUdy"
 
 
 
 
 
410
  },
411
  "outputs": [],
412
  "source": [
413
  "model = PySRRegressor(\n",
414
  " niterations=5,\n",
415
  " populations=40,\n",
416
- " binary_operators=[\"plus\", \"mult\"],\n",
417
  " unary_operators=[\"cos\", \"exp\", \"sin\", \"quart(x) = x^4\"],\n",
418
  " extra_sympy_mappings={\"quart\": lambda x: x**4},\n",
419
  ")\n",
@@ -424,7 +450,12 @@
424
  "cell_type": "code",
425
  "execution_count": null,
426
  "metadata": {
427
- "id": "emn2IajKbDgy"
 
 
 
 
 
428
  },
429
  "outputs": [],
430
  "source": [
@@ -546,7 +577,12 @@
546
  "cell_type": "code",
547
  "execution_count": null,
548
  "metadata": {
549
- "id": "sqMqb4nJ5ZR5"
 
 
 
 
 
550
  },
551
  "outputs": [],
552
  "source": [
@@ -579,7 +615,11 @@
579
  "cell_type": "code",
580
  "execution_count": null,
581
  "metadata": {
582
- "id": "v8WBYtcZbDhC"
 
 
 
 
583
  },
584
  "outputs": [],
585
  "source": [
@@ -599,7 +639,11 @@
599
  "cell_type": "code",
600
  "execution_count": null,
601
  "metadata": {
602
- "id": "a07K3KUjOxcp"
 
 
 
 
603
  },
604
  "outputs": [],
605
  "source": [
@@ -607,7 +651,7 @@
607
  " loss=\"myloss(x, y, w) = w * abs(x - y)\", # Custom loss function with weights.\n",
608
  " niterations=20,\n",
609
  " populations=20, # Use more populations\n",
610
- " binary_operators=[\"plus\", \"mult\"],\n",
611
  " unary_operators=[\"cos\"],\n",
612
  ")\n",
613
  "model.fit(X, y, weights=weights)"
@@ -688,17 +732,19 @@
688
  ]
689
  },
690
  {
691
- "attachments": {},
692
  "cell_type": "markdown",
693
- "metadata": {},
 
 
694
  "source": [
695
  "# Multiple outputs"
696
  ]
697
  },
698
  {
699
- "attachments": {},
700
  "cell_type": "markdown",
701
- "metadata": {},
 
 
702
  "source": [
703
  "For multiple outputs, multiple equations are returned:"
704
  ]
@@ -706,7 +752,9 @@
706
  {
707
  "cell_type": "code",
708
  "execution_count": null,
709
- "metadata": {},
 
 
710
  "outputs": [],
711
  "source": [
712
  "X = 2 * np.random.randn(100, 5)\n",
@@ -716,7 +764,9 @@
716
  {
717
  "cell_type": "code",
718
  "execution_count": null,
719
- "metadata": {},
 
 
720
  "outputs": [],
721
  "source": [
722
  "model = PySRRegressor(\n",
@@ -730,24 +780,28 @@
730
  {
731
  "cell_type": "code",
732
  "execution_count": null,
733
- "metadata": {},
 
 
734
  "outputs": [],
735
  "source": [
736
  "model"
737
  ]
738
  },
739
  {
740
- "attachments": {},
741
  "cell_type": "markdown",
742
- "metadata": {},
 
 
743
  "source": [
744
  "# Julia packages and types"
745
  ]
746
  },
747
  {
748
- "attachments": {},
749
  "cell_type": "markdown",
750
- "metadata": {},
 
 
751
  "source": [
752
  "PySR uses [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl)\n",
753
  "as its search backend. This is a pure Julia package, and so can interface easily with any other\n",
@@ -771,20 +825,23 @@
771
  {
772
  "cell_type": "code",
773
  "execution_count": null,
774
- "metadata": {},
 
 
775
  "outputs": [],
776
  "source": [
777
  "import pysr\n",
778
  "\n",
779
  "jl = pysr.julia_helpers.init_julia(\n",
780
- " julia_kwargs={\"threads\": \"auto\", \"optimize\": 2, \"compiled_modules\": False}\n",
781
  ")"
782
  ]
783
  },
784
  {
785
- "attachments": {},
786
  "cell_type": "markdown",
787
- "metadata": {},
 
 
788
  "source": [
789
  "\n",
790
  "\n",
@@ -797,7 +854,9 @@
797
  {
798
  "cell_type": "code",
799
  "execution_count": null,
800
- "metadata": {},
 
 
801
  "outputs": [],
802
  "source": [
803
  "jl.eval(\n",
@@ -809,9 +868,10 @@
809
  ]
810
  },
811
  {
812
- "attachments": {},
813
  "cell_type": "markdown",
814
- "metadata": {},
 
 
815
  "source": [
816
  "This imports the Julia package manager, and uses it to install\n",
817
  "`Primes.jl`. Now let's import `Primes.jl`:"
@@ -820,16 +880,19 @@
820
  {
821
  "cell_type": "code",
822
  "execution_count": null,
823
- "metadata": {},
 
 
824
  "outputs": [],
825
  "source": [
826
  "jl.eval(\"import Primes\")"
827
  ]
828
  },
829
  {
830
- "attachments": {},
831
  "cell_type": "markdown",
832
- "metadata": {},
 
 
833
  "source": [
834
  "\n",
835
  "Now, we define a custom operator:\n"
@@ -838,7 +901,9 @@
838
  {
839
  "cell_type": "code",
840
  "execution_count": null,
841
- "metadata": {},
 
 
842
  "outputs": [],
843
  "source": [
844
  "jl.eval(\n",
@@ -855,9 +920,10 @@
855
  ]
856
  },
857
  {
858
- "attachments": {},
859
  "cell_type": "markdown",
860
- "metadata": {},
 
 
861
  "source": [
862
  "\n",
863
  "We have created a function `p`, which takes a number `i` of type `T` (e.g., `T=Float64`).\n",
@@ -887,16 +953,19 @@
887
  {
888
  "cell_type": "code",
889
  "execution_count": null,
890
- "metadata": {},
 
 
891
  "outputs": [],
892
  "source": [
893
  "primes = {i: jl.p(i * 1.0) for i in range(1, 999)}"
894
  ]
895
  },
896
  {
897
- "attachments": {},
898
  "cell_type": "markdown",
899
- "metadata": {},
 
 
900
  "source": [
901
  "Next, let's use this list of primes to create a dataset of $x, y$ pairs:"
902
  ]
@@ -904,7 +973,9 @@
904
  {
905
  "cell_type": "code",
906
  "execution_count": null,
907
- "metadata": {},
 
 
908
  "outputs": [],
909
  "source": [
910
  "import numpy as np\n",
@@ -914,9 +985,10 @@
914
  ]
915
  },
916
  {
917
- "attachments": {},
918
  "cell_type": "markdown",
919
- "metadata": {},
 
 
920
  "source": [
921
  "Note that we have also added a tiny bit of noise to the dataset.\n",
922
  "\n",
@@ -926,7 +998,9 @@
926
  {
927
  "cell_type": "code",
928
  "execution_count": null,
929
- "metadata": {},
 
 
930
  "outputs": [],
931
  "source": [
932
  "from pysr import PySRRegressor\n",
@@ -947,8 +1021,9 @@
947
  },
948
  {
949
  "cell_type": "markdown",
950
- "id": "ee30bd41",
951
- "metadata": {},
 
952
  "source": [
953
  "We are all set to go! Let's see if we can find the true relation:"
954
  ]
@@ -956,16 +1031,19 @@
956
  {
957
  "cell_type": "code",
958
  "execution_count": null,
959
- "metadata": {},
 
 
960
  "outputs": [],
961
  "source": [
962
  "model.fit(X, y)"
963
  ]
964
  },
965
  {
966
- "attachments": {},
967
  "cell_type": "markdown",
968
- "metadata": {},
 
 
969
  "source": [
970
  "if all works out, you should be able to see the true relation (note that the constant offset might not be exactly 1, since it is allowed to round to the nearest integer).\n",
971
  "\n",
@@ -975,7 +1053,9 @@
975
  {
976
  "cell_type": "code",
977
  "execution_count": null,
978
- "metadata": {},
 
 
979
  "outputs": [],
980
  "source": [
981
  "model.sympy()"
@@ -991,7 +1071,6 @@
991
  ]
992
  },
993
  {
994
- "attachments": {},
995
  "cell_type": "markdown",
996
  "metadata": {
997
  "id": "3hS2kTAbbDhL"
@@ -1068,6 +1147,17 @@
1068
  "> We import torch *after* already starting PyJulia. This is required due to interference between their C bindings. If you use torch, and then run PyJulia, you will likely hit a segfault. So keep this in mind for mixed deep learning + PyJulia/PySR workflows."
1069
  ]
1070
  },
 
 
 
 
 
 
 
 
 
 
 
1071
  {
1072
  "cell_type": "code",
1073
  "execution_count": null,
@@ -1083,7 +1173,7 @@
1083
  "import pytorch_lightning as pl\n",
1084
  "\n",
1085
  "hidden = 128\n",
1086
- "total_steps = 30_000\n",
1087
  "\n",
1088
  "\n",
1089
  "def mlp(size_in, size_out, act=nn.ReLU):\n",
@@ -1298,7 +1388,9 @@
1298
  {
1299
  "cell_type": "code",
1300
  "execution_count": null,
1301
- "metadata": {},
 
 
1302
  "outputs": [],
1303
  "source": [
1304
  "nnet_recordings = {\n",
@@ -1316,9 +1408,10 @@
1316
  ]
1317
  },
1318
  {
1319
- "attachments": {},
1320
  "cell_type": "markdown",
1321
- "metadata": {},
 
 
1322
  "source": [
1323
  "We can now load the data, including after a crash (be sure to re-run the import cells at the top of this notebook, including the one that starts PyJulia)."
1324
  ]
@@ -1326,7 +1419,9 @@
1326
  {
1327
  "cell_type": "code",
1328
  "execution_count": null,
1329
- "metadata": {},
 
 
1330
  "outputs": [],
1331
  "source": [
1332
  "import pickle as pkl\n",
@@ -1339,9 +1434,10 @@
1339
  ]
1340
  },
1341
  {
1342
- "attachments": {},
1343
  "cell_type": "markdown",
1344
- "metadata": {},
 
 
1345
  "source": [
1346
  "And now fit using a subsample of the data (symbolic regression only needs a small sample to find the best equation):"
1347
  ]
@@ -1358,9 +1454,9 @@
1358
  "f_sample_idx = rstate.choice(f_input.shape[0], size=500, replace=False)\n",
1359
  "\n",
1360
  "model = PySRRegressor(\n",
1361
- " niterations=20,\n",
1362
- " binary_operators=[\"plus\", \"sub\", \"mult\"],\n",
1363
- " unary_operators=[\"cos\", \"square\", \"neg\"],\n",
1364
  ")\n",
1365
  "model.fit(g_input[f_sample_idx], g_output[f_sample_idx])"
1366
  ]
@@ -1384,14 +1480,13 @@
1384
  ]
1385
  },
1386
  {
1387
- "attachments": {},
1388
  "cell_type": "markdown",
1389
  "metadata": {
1390
  "id": "6WuaeqyqbDhe"
1391
  },
1392
  "source": [
1393
  "Recall we are searching for $f$ and $g$ such that:\n",
1394
- "$$z=f(\\sum g(x_i))$$ \n",
1395
  "which approximates the true relation:\n",
1396
  "$$ z = y^2,\\quad y = \\frac{1}{10} \\sum(y_i),\\quad y_i = x_{i0}^2 + 6 \\cos(2 x_{i2})$$\n",
1397
  "\n",
@@ -1459,7 +1554,6 @@
1459
  "metadata": {
1460
  "accelerator": "GPU",
1461
  "colab": {
1462
- "name": "pysr_demo.ipynb",
1463
  "provenance": []
1464
  },
1465
  "gpuClass": "standard",
 
19
  "## Instructions\n",
20
  "1. Work on a copy of this notebook: _File_ > _Save a copy in Drive_ (you will need a Google account).\n",
21
  "2. (Optional) If you would like to do the deep learning component of this tutorial, turn on the GPU with Edit->Notebook settings->Hardware accelerator->GPU\n",
22
+ "3. Execute the following cell (click on it and press Ctrl+Enter) to install Julia. This may take a minute or so.\n",
23
  "4. Continue to the next section.\n",
24
  "\n",
25
  "_Notes_:\n",
 
40
  "cell_type": "code",
41
  "execution_count": null,
42
  "metadata": {
43
+ "colab": {
44
+ "base_uri": "https://localhost:8080/"
45
+ },
46
+ "id": "GIeFXS0F0zww",
47
+ "outputId": "5399ed75-f77f-47c5-e53b-4b2f231f2839"
48
  },
49
  "outputs": [],
50
  "source": [
51
+ "!curl -fsSL https://install.julialang.org | sh -s -- -y --default-channel 1.10"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {
58
+ "colab": {
59
+ "base_uri": "https://localhost:8080/"
60
+ },
61
+ "id": "Iu9X-Y-YNmwM",
62
+ "outputId": "ee14af65-043a-4ad6-efa0-3cdcc48a4eb8"
63
+ },
64
+ "outputs": [],
65
+ "source": [
66
+ "# Make julia available on PATH:\n",
67
+ "!ln -s $HOME/.juliaup/bin/julia /usr/local/bin/julia\n",
 
 
 
 
68
  "\n",
69
+ "# Test it works:\n",
70
+ "!julia --version"
71
  ]
72
  },
73
  {
 
76
  "id": "ORv1c6xvbDgV"
77
  },
78
  "source": [
79
+ "Install PySR"
80
  ]
81
  },
82
  {
83
  "cell_type": "code",
84
  "execution_count": null,
85
  "metadata": {
86
+ "colab": {
87
+ "base_uri": "https://localhost:8080/"
88
+ },
89
+ "id": "EhMRSZEYFPLz",
90
+ "outputId": "e3aad3cb-d921-473e-b77b-8fa6a3a9e2e8"
91
  },
92
  "outputs": [],
93
  "source": [
94
+ "!pip install pysr && python -m pysr install"
95
  ]
96
  },
97
  {
 
100
  "id": "etTMEV0wDqld"
101
  },
102
  "source": [
103
+ "Colab's printing is non-standard, so we need to manually initialize Julia and redirect its printing. Normally, however, this is not required, and PySR will automatically start Julia during the first call to `.fit`:"
104
  ]
105
  },
106
  {
 
111
  },
112
  "outputs": [],
113
  "source": [
114
+ "def init_colab_printing():\n",
115
+ " from pysr.julia_helpers import init_julia\n",
116
+ " from julia.tools import redirect_output_streams\n",
117
  "\n",
118
+ " julia_kwargs = dict(optimize=3, threads=\"auto\", compiled_modules=False)\n",
119
+ " init_julia(julia_kwargs=julia_kwargs)\n",
120
+ " redirect_output_streams()\n",
121
  "\n",
122
+ "\n",
123
+ "init_colab_printing()"
124
  ]
125
  },
126
  {
127
  "cell_type": "markdown",
128
  "metadata": {
129
+ "id": "qeCPKd9wldEK"
130
  },
131
  "source": [
132
+ "Now, let's import all of our libraries:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  ]
134
  },
135
  {
 
219
  "cell_type": "code",
220
  "execution_count": null,
221
  "metadata": {
222
+ "colab": {
223
+ "base_uri": "https://localhost:8080/",
224
+ "height": 1000
225
+ },
226
+ "id": "p4PSrO-NK1Wa",
227
+ "outputId": "55910ab3-895d-400b-e9ce-c75aef639c68"
228
  },
229
  "outputs": [],
230
  "source": [
231
  "# Learn equations\n",
232
  "model = PySRRegressor(\n",
233
  " niterations=30,\n",
234
+ " binary_operators=[\"+\", \"*\"],\n",
235
  " unary_operators=[\"cos\", \"exp\", \"sin\"],\n",
236
  " **default_pysr_params\n",
237
  ")\n",
 
252
  "cell_type": "code",
253
  "execution_count": null,
254
  "metadata": {
255
+ "colab": {
256
+ "base_uri": "https://localhost:8080/",
257
+ "height": 252
258
+ },
259
+ "id": "4HR8gknlZz4W",
260
+ "outputId": "496283bd-a743-4cc6-a2f9-9619ba91d870"
261
  },
262
  "outputs": [],
263
  "source": [
 
277
  "cell_type": "code",
278
  "execution_count": null,
279
  "metadata": {
280
+ "colab": {
281
+ "base_uri": "https://localhost:8080/",
282
+ "height": 38
283
+ },
284
+ "id": "IQKOohdpztS7",
285
+ "outputId": "0e7d058a-cce1-45ae-db94-6625f7e53a06"
286
  },
287
  "outputs": [],
288
  "source": [
 
302
  "cell_type": "code",
303
  "execution_count": null,
304
  "metadata": {
305
+ "colab": {
306
+ "base_uri": "https://localhost:8080/",
307
+ "height": 39
308
+ },
309
+ "id": "GRcxq-TTlpRX",
310
+ "outputId": "50bda367-1ed1-4860-8fcf-c940f2e4d935"
311
  },
312
  "outputs": [],
313
  "source": [
 
336
  "cell_type": "code",
337
  "execution_count": null,
338
  "metadata": {
339
+ "colab": {
340
+ "base_uri": "https://localhost:8080/",
341
+ "height": 35
342
+ },
343
+ "id": "HFGaNL6tbDgi",
344
+ "outputId": "0f364da5-e18d-4e31-cadf-087d641a3aed"
345
  },
346
  "outputs": [],
347
  "source": [
 
363
  "cell_type": "code",
364
  "execution_count": null,
365
  "metadata": {
366
+ "colab": {
367
+ "base_uri": "https://localhost:8080/"
368
+ },
369
+ "id": "Vbz4IMsk2NYH",
370
+ "outputId": "361d4b6e-ac23-479d-b511-5001af05ca43"
371
  },
372
  "outputs": [],
373
  "source": [
 
427
  "cell_type": "code",
428
  "execution_count": null,
429
  "metadata": {
430
+ "colab": {
431
+ "base_uri": "https://localhost:8080/",
432
+ "height": 339
433
+ },
434
+ "id": "PoEkpvYuGUdy",
435
+ "outputId": "02834373-a054-400b-8247-2bf33a5c5beb"
436
  },
437
  "outputs": [],
438
  "source": [
439
  "model = PySRRegressor(\n",
440
  " niterations=5,\n",
441
  " populations=40,\n",
442
+ " binary_operators=[\"+\", \"*\"],\n",
443
  " unary_operators=[\"cos\", \"exp\", \"sin\", \"quart(x) = x^4\"],\n",
444
  " extra_sympy_mappings={\"quart\": lambda x: x**4},\n",
445
  ")\n",
 
450
  "cell_type": "code",
451
  "execution_count": null,
452
  "metadata": {
453
+ "colab": {
454
+ "base_uri": "https://localhost:8080/",
455
+ "height": 38
456
+ },
457
+ "id": "emn2IajKbDgy",
458
+ "outputId": "11d5d3cf-de43-4f2b-f653-30016e09bdd0"
459
  },
460
  "outputs": [],
461
  "source": [
 
577
  "cell_type": "code",
578
  "execution_count": null,
579
  "metadata": {
580
+ "colab": {
581
+ "base_uri": "https://localhost:8080/",
582
+ "height": 467
583
+ },
584
+ "id": "sqMqb4nJ5ZR5",
585
+ "outputId": "aa24922b-2395-4e00-dce3-268fc8e603dc"
586
  },
587
  "outputs": [],
588
  "source": [
 
615
  "cell_type": "code",
616
  "execution_count": null,
617
  "metadata": {
618
+ "colab": {
619
+ "base_uri": "https://localhost:8080/"
620
+ },
621
+ "id": "v8WBYtcZbDhC",
622
+ "outputId": "37d4002f-e9d6-40c0-9a24-c671d9c384e6"
623
  },
624
  "outputs": [],
625
  "source": [
 
639
  "cell_type": "code",
640
  "execution_count": null,
641
  "metadata": {
642
+ "colab": {
643
+ "base_uri": "https://localhost:8080/"
644
+ },
645
+ "id": "a07K3KUjOxcp",
646
+ "outputId": "41d11915-78b7-4446-c153-b92a5e2abd4c"
647
  },
648
  "outputs": [],
649
  "source": [
 
651
  " loss=\"myloss(x, y, w) = w * abs(x - y)\", # Custom loss function with weights.\n",
652
  " niterations=20,\n",
653
  " populations=20, # Use more populations\n",
654
+ " binary_operators=[\"+\", \"*\"],\n",
655
  " unary_operators=[\"cos\"],\n",
656
  ")\n",
657
  "model.fit(X, y, weights=weights)"
 
732
  ]
733
  },
734
  {
 
735
  "cell_type": "markdown",
736
+ "metadata": {
737
+ "id": "2x-8M8W4G-KM"
738
+ },
739
  "source": [
740
  "# Multiple outputs"
741
  ]
742
  },
743
  {
 
744
  "cell_type": "markdown",
745
+ "metadata": {
746
+ "id": "LIJcWqBQG-KM"
747
+ },
748
  "source": [
749
  "For multiple outputs, multiple equations are returned:"
750
  ]
 
752
  {
753
  "cell_type": "code",
754
  "execution_count": null,
755
+ "metadata": {
756
+ "id": "_Aar1ZJwG-KM"
757
+ },
758
  "outputs": [],
759
  "source": [
760
  "X = 2 * np.random.randn(100, 5)\n",
 
764
  {
765
  "cell_type": "code",
766
  "execution_count": null,
767
+ "metadata": {
768
+ "id": "9Znwq40PG-KM"
769
+ },
770
  "outputs": [],
771
  "source": [
772
  "model = PySRRegressor(\n",
 
780
  {
781
  "cell_type": "code",
782
  "execution_count": null,
783
+ "metadata": {
784
+ "id": "0Y_vy0sqG-KM"
785
+ },
786
  "outputs": [],
787
  "source": [
788
  "model"
789
  ]
790
  },
791
  {
 
792
  "cell_type": "markdown",
793
+ "metadata": {
794
+ "id": "-UP49CsGG-KN"
795
+ },
796
  "source": [
797
  "# Julia packages and types"
798
  ]
799
  },
800
  {
 
801
  "cell_type": "markdown",
802
+ "metadata": {
803
+ "id": "tOdNHheUG-KN"
804
+ },
805
  "source": [
806
  "PySR uses [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl)\n",
807
  "as its search backend. This is a pure Julia package, and so can interface easily with any other\n",
 
825
  {
826
  "cell_type": "code",
827
  "execution_count": null,
828
+ "metadata": {
829
+ "id": "yUC4BMuHG-KN"
830
+ },
831
  "outputs": [],
832
  "source": [
833
  "import pysr\n",
834
  "\n",
835
  "jl = pysr.julia_helpers.init_julia(\n",
836
+ " julia_kwargs=dict(optimize=3, threads=\"auto\", compiled_modules=False)\n",
837
  ")"
838
  ]
839
  },
840
  {
 
841
  "cell_type": "markdown",
842
+ "metadata": {
843
+ "id": "af07m4uBG-KN"
844
+ },
845
  "source": [
846
  "\n",
847
  "\n",
 
854
  {
855
  "cell_type": "code",
856
  "execution_count": null,
857
+ "metadata": {
858
+ "id": "xBlMY-s4G-KN"
859
+ },
860
  "outputs": [],
861
  "source": [
862
  "jl.eval(\n",
 
868
  ]
869
  },
870
  {
 
871
  "cell_type": "markdown",
872
+ "metadata": {
873
+ "id": "1rJFukD6G-KN"
874
+ },
875
  "source": [
876
  "This imports the Julia package manager, and uses it to install\n",
877
  "`Primes.jl`. Now let's import `Primes.jl`:"
 
880
  {
881
  "cell_type": "code",
882
  "execution_count": null,
883
+ "metadata": {
884
+ "id": "1PQl1rIaG-KN"
885
+ },
886
  "outputs": [],
887
  "source": [
888
  "jl.eval(\"import Primes\")"
889
  ]
890
  },
891
  {
 
892
  "cell_type": "markdown",
893
+ "metadata": {
894
+ "id": "edGdMxKnG-KN"
895
+ },
896
  "source": [
897
  "\n",
898
  "Now, we define a custom operator:\n"
 
901
  {
902
  "cell_type": "code",
903
  "execution_count": null,
904
+ "metadata": {
905
+ "id": "9Ut3HcW3G-KN"
906
+ },
907
  "outputs": [],
908
  "source": [
909
  "jl.eval(\n",
 
920
  ]
921
  },
922
  {
 
923
  "cell_type": "markdown",
924
+ "metadata": {
925
+ "id": "_wcV8889G-KN"
926
+ },
927
  "source": [
928
  "\n",
929
  "We have created a function `p`, which takes a number `i` of type `T` (e.g., `T=Float64`).\n",
 
953
  {
954
  "cell_type": "code",
955
  "execution_count": null,
956
+ "metadata": {
957
+ "id": "giqwisEPG-KN"
958
+ },
959
  "outputs": [],
960
  "source": [
961
  "primes = {i: jl.p(i * 1.0) for i in range(1, 999)}"
962
  ]
963
  },
964
  {
 
965
  "cell_type": "markdown",
966
+ "metadata": {
967
+ "id": "MPAqARj6G-KO"
968
+ },
969
  "source": [
970
  "Next, let's use this list of primes to create a dataset of $x, y$ pairs:"
971
  ]
 
973
  {
974
  "cell_type": "code",
975
  "execution_count": null,
976
+ "metadata": {
977
+ "id": "jab4tRRRG-KO"
978
+ },
979
  "outputs": [],
980
  "source": [
981
  "import numpy as np\n",
 
985
  ]
986
  },
987
  {
 
988
  "cell_type": "markdown",
989
+ "metadata": {
990
+ "id": "3eFgWrjcG-KO"
991
+ },
992
  "source": [
993
  "Note that we have also added a tiny bit of noise to the dataset.\n",
994
  "\n",
 
998
  {
999
  "cell_type": "code",
1000
  "execution_count": null,
1001
+ "metadata": {
1002
+ "id": "pEYskM2_G-KO"
1003
+ },
1004
  "outputs": [],
1005
  "source": [
1006
  "from pysr import PySRRegressor\n",
 
1021
  },
1022
  {
1023
  "cell_type": "markdown",
1024
+ "metadata": {
1025
+ "id": "ee30bd41"
1026
+ },
1027
  "source": [
1028
  "We are all set to go! Let's see if we can find the true relation:"
1029
  ]
 
1031
  {
1032
  "cell_type": "code",
1033
  "execution_count": null,
1034
+ "metadata": {
1035
+ "id": "li-TB19iG-KO"
1036
+ },
1037
  "outputs": [],
1038
  "source": [
1039
  "model.fit(X, y)"
1040
  ]
1041
  },
1042
  {
 
1043
  "cell_type": "markdown",
1044
+ "metadata": {
1045
+ "id": "jwhTWZryG-KO"
1046
+ },
1047
  "source": [
1048
  "if all works out, you should be able to see the true relation (note that the constant offset might not be exactly 1, since it is allowed to round to the nearest integer).\n",
1049
  "\n",
 
1053
  {
1054
  "cell_type": "code",
1055
  "execution_count": null,
1056
+ "metadata": {
1057
+ "id": "bSlpX9xAG-KO"
1058
+ },
1059
  "outputs": [],
1060
  "source": [
1061
  "model.sympy()"
 
1071
  ]
1072
  },
1073
  {
 
1074
  "cell_type": "markdown",
1075
  "metadata": {
1076
  "id": "3hS2kTAbbDhL"
 
1147
  "> We import torch *after* already starting PyJulia. This is required due to interference between their C bindings. If you use torch, and then run PyJulia, you will likely hit a segfault. So keep this in mind for mixed deep learning + PyJulia/PySR workflows."
1148
  ]
1149
  },
1150
+ {
1151
+ "cell_type": "code",
1152
+ "execution_count": null,
1153
+ "metadata": {
1154
+ "id": "k-Od8b9DlkHK"
1155
+ },
1156
+ "outputs": [],
1157
+ "source": [
1158
+ "!pip install pytorch_lightning"
1159
+ ]
1160
+ },
1161
  {
1162
  "cell_type": "code",
1163
  "execution_count": null,
 
1173
  "import pytorch_lightning as pl\n",
1174
  "\n",
1175
  "hidden = 128\n",
1176
+ "total_steps = 50_000\n",
1177
  "\n",
1178
  "\n",
1179
  "def mlp(size_in, size_out, act=nn.ReLU):\n",
 
1388
  {
1389
  "cell_type": "code",
1390
  "execution_count": null,
1391
+ "metadata": {
1392
+ "id": "UX7Am6mZG-KT"
1393
+ },
1394
  "outputs": [],
1395
  "source": [
1396
  "nnet_recordings = {\n",
 
1408
  ]
1409
  },
1410
  {
 
1411
  "cell_type": "markdown",
1412
+ "metadata": {
1413
+ "id": "krhaNlwFG-KT"
1414
+ },
1415
  "source": [
1416
  "We can now load the data, including after a crash (be sure to re-run the import cells at the top of this notebook, including the one that starts PyJulia)."
1417
  ]
 
1419
  {
1420
  "cell_type": "code",
1421
  "execution_count": null,
1422
+ "metadata": {
1423
+ "id": "NF9aSFXHG-KT"
1424
+ },
1425
  "outputs": [],
1426
  "source": [
1427
  "import pickle as pkl\n",
 
1434
  ]
1435
  },
1436
  {
 
1437
  "cell_type": "markdown",
1438
+ "metadata": {
1439
+ "id": "_hTYHhDGG-KT"
1440
+ },
1441
  "source": [
1442
  "And now fit using a subsample of the data (symbolic regression only needs a small sample to find the best equation):"
1443
  ]
 
1454
  "f_sample_idx = rstate.choice(f_input.shape[0], size=500, replace=False)\n",
1455
  "\n",
1456
  "model = PySRRegressor(\n",
1457
+ " niterations=50,\n",
1458
+ " binary_operators=[\"+\", \"-\", \"*\"],\n",
1459
+ " unary_operators=[\"cos\", \"square\"],\n",
1460
  ")\n",
1461
  "model.fit(g_input[f_sample_idx], g_output[f_sample_idx])"
1462
  ]
 
1480
  ]
1481
  },
1482
  {
 
1483
  "cell_type": "markdown",
1484
  "metadata": {
1485
  "id": "6WuaeqyqbDhe"
1486
  },
1487
  "source": [
1488
  "Recall we are searching for $f$ and $g$ such that:\n",
1489
+ "$$z=f(\\sum g(x_i))$$\n",
1490
  "which approximates the true relation:\n",
1491
  "$$ z = y^2,\\quad y = \\frac{1}{10} \\sum(y_i),\\quad y_i = x_{i0}^2 + 6 \\cos(2 x_{i2})$$\n",
1492
  "\n",
 
1554
  "metadata": {
1555
  "accelerator": "GPU",
1556
  "colab": {
 
1557
  "provenance": []
1558
  },
1559
  "gpuClass": "standard",