Agarwal commited on
Commit
1a6266a
1 Parent(s): bf9ef4a

updated for paper reproducability

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .ipynb_checkpoints/README-checkpoint.md +11 -3
  2. .ipynb_checkpoints/calculate_profiles-checkpoint.py +3 -0
  3. .ipynb_checkpoints/utils-checkpoint.py +6 -3
  4. Paper/Dataset.pdf +0 -0
  5. Paper/Snapshot_SS_colorbar.svg +503 -0
  6. Paper/figures.ipynb +0 -0
  7. Paper/los.pkl +3 -0
  8. Paper/overall_stats.pkl +3 -0
  9. Paper/simulations.pkl +3 -0
  10. Paper/x_p.pkl +3 -0
  11. Paper/y_p.pkl +3 -0
  12. Paper/y_pred_krr.pkl +3 -0
  13. Paper/y_pred_linear.pkl +3 -0
  14. Paper/y_pred_nn_interpolation.pkl +3 -0
  15. Paper/y_pred_nn_nearestneighbor.pkl +3 -0
  16. Paper/y_pred_nn_pointwise.pkl +3 -0
  17. Paper/y_prof.pkl +3 -0
  18. README.md +11 -3
  19. __pycache__/utils.cpython-310.pyc +0 -0
  20. calculate_profiles.py +3 -0
  21. data/scaling_laws_data.pkl +3 -0
  22. data/scaling_laws_data_hot_start.pkl +3 -0
  23. data/sims.pt +3 -0
  24. data/snapshots.pkl +3 -0
  25. data/x_Tprev.pkl +3 -0
  26. data/x_modes.pkl +3 -0
  27. data/x_p.pkl +3 -0
  28. data/x_pointwise_orgres.pkl +3 -0
  29. data/y_Tprev.pkl +3 -0
  30. data/y_modes.pkl +3 -0
  31. data/y_p.pkl +3 -0
  32. data/y_pointwise_orgres.pkl +3 -0
  33. data/y_prof.pt +3 -0
  34. evaluate/load_profiles.ipynb +0 -0
  35. evaluate/scaling_laws.ipynb +0 -0
  36. my_simulation_parameters.txt +10 -0
  37. outputs/profile_raq_ra10.0_fkt10000000000.0_fkv1.0.png +0 -0
  38. outputs/profile_raq_ra10.0_fkt10000000000.0_fkv1.0.txt +128 -0
  39. outputs/profile_raq_ra5.0_fkt100000000.0_fkv50.0.png +0 -0
  40. outputs/profile_raq_ra5.0_fkt100000000.0_fkv50.0.txt +128 -0
  41. outputs/profile_raq_ra7.5_fkt1000000000.0_fkv25.0.png +0 -0
  42. outputs/profile_raq_ra7.5_fkt1000000000.0_fkv25.0.txt +128 -0
  43. preprocess/preprocess_profiles.ipynb +470 -0
  44. stats/MLP_stats.txt +23 -23
  45. {data → train}/mlp.py +13 -1
  46. {data → train}/train_profiles_mlp.py +9 -7
  47. train/trained_networks/mlp_profile_[128]_2_selu/mlp.pt +3 -0
  48. train/trained_networks/mlp_profile_[128]_2_selu/mlp.txt +0 -0
  49. train/trained_networks/mlp_profile_[128]_3_selu/mlp.pt +3 -0
  50. train/trained_networks/mlp_profile_[128]_3_selu/mlp.txt +0 -0
.ipynb_checkpoints/README-checkpoint.md CHANGED
@@ -12,7 +12,15 @@ license: mit
12
 
13
  This respository contains trained neural networks that can be used to predict the steady-state temperature profile.
14
 
15
- Step 1: Define the simulation parameters
 
16
  Step 2: The output is as follows
17
- - the depth profile (first column) and the temperature profile (second column)
18
- - corresonding plot of the temperature profile
 
 
 
 
 
 
 
 
12
 
13
  This respository contains trained neural networks that can be used to predict the steady-state temperature profile.
14
 
15
+ Here is how to predict several profiles at once:
16
+ Step 1: Define the simulation parameters in my_simulation_parameters.txt
17
  Step 2: The output is as follows
18
+ - outputs/.txt files: the depth profile (first column) and the temperature profile (second column)
19
+ - outputs/.pdf files: corresonding plot of the temperature profile
20
+
21
+
22
+ Here's how to train a neural network:
23
+ Step 1: Ensure you have a working installation of PyTorch (https://pytorch.org/get-started/locally/)
24
+ Step 2: cd train/
25
+ Step 3 with GPU : python train_profiles_mlp.py -l 5 -f 128 -a "selu" -gpu 0
26
+ Step 3 without GPU: python train_profiles_mlp.py -l 5 -f 128 -a "selu"
.ipynb_checkpoints/calculate_profiles-checkpoint.py CHANGED
@@ -7,9 +7,11 @@ write_file = True
7
  plot_profile = True
8
  #### Define outputs ####
9
 
 
10
  with open('numpy_networks/mlp_[128, 128, 128, 128, 128].pkl', 'rb') as file:
11
  mlp = pickle.load(file)
12
 
 
13
  f_nn = "my_simulation_parameters.txt"
14
  with open(f_nn) as fw:
15
  lines = fw.readlines()
@@ -32,6 +34,7 @@ for line in lines:
32
  if not len(r_list) == len(v_list) and len(r_list) == len(t_list):
33
  raise Exception("Ensure equal number of values for all parameters in " + f_nn)
34
 
 
35
  for i in range(len(r_list)):
36
  if r_list[i]<0 or r_list[i]>9.5:
37
  warnings.warn('RaQ/Ra is outside the range of the training dataset')
 
7
  plot_profile = True
8
  #### Define outputs ####
9
 
10
+ # load the saved mlp
11
  with open('numpy_networks/mlp_[128, 128, 128, 128, 128].pkl', 'rb') as file:
12
  mlp = pickle.load(file)
13
 
14
+ # read the simulation parameters and ensure correct foramtting
15
  f_nn = "my_simulation_parameters.txt"
16
  with open(f_nn) as fw:
17
  lines = fw.readlines()
 
34
  if not len(r_list) == len(v_list) and len(r_list) == len(t_list):
35
  raise Exception("Ensure equal number of values for all parameters in " + f_nn)
36
 
37
+ # Check parameter ranges and print appropriate warnings
38
  for i in range(len(r_list)):
39
  if r_list[i]<0 or r_list[i]>9.5:
40
  warnings.warn('RaQ/Ra is outside the range of the training dataset')
.ipynb_checkpoints/utils-checkpoint.py CHANGED
@@ -25,7 +25,7 @@ def dimensionalize_fkv(x):
25
  return 10**(x*(1.9927988938926755-0.005251646002323797)+0.005251646002323797)
26
 
27
  def get_input(raq_ra, fkt, fkp, y_prof):
28
-
29
  x = np.zeros((len(raq_ra)*len(y_prof), 4))
30
 
31
  cntr = 0
@@ -40,7 +40,9 @@ def get_input(raq_ra, fkt, fkp, y_prof):
40
  return x
41
 
42
  def get_profile(inp, mlp, num_sims=1, correction=True, prof_points=128):
 
43
 
 
44
  num_layers = len(mlp)-1
45
  y_pred = inp
46
  res = []
@@ -59,11 +61,12 @@ def get_profile(inp, mlp, num_sims=1, correction=True, prof_points=128):
59
  res.append(y_pred)
60
 
61
  y_pred = y_pred.reshape(num_sims, -1)
62
-
 
63
  y_pred[:,0] = 1.
64
  y_pred[:,-1] = 0.
65
 
66
- if correction:
67
  inp = inp.reshape(num_sims, -1, inp.shape[-1])
68
 
69
  for sim_ind in range(num_sims):
 
25
  return 10**(x*(1.9927988938926755-0.005251646002323797)+0.005251646002323797)
26
 
27
  def get_input(raq_ra, fkt, fkp, y_prof):
28
+ # define input as (batch, [raq/ra, fkt, fkp and y])
29
  x = np.zeros((len(raq_ra)*len(y_prof), 4))
30
 
31
  cntr = 0
 
40
  return x
41
 
42
  def get_profile(inp, mlp, num_sims=1, correction=True, prof_points=128):
43
+ # get predicted profile based on input
44
 
45
+ # forward network pass in pure python using saved weights
46
  num_layers = len(mlp)-1
47
  y_pred = inp
48
  res = []
 
61
  res.append(y_pred)
62
 
63
  y_pred = y_pred.reshape(num_sims, -1)
64
+
65
+ # overwrite points at the boundary
66
  y_pred[:,0] = 1.
67
  y_pred[:,-1] = 0.
68
 
69
+ if correction: # boundary layer corection
70
  inp = inp.reshape(num_sims, -1, inp.shape[-1])
71
 
72
  for sim_ind in range(num_sims):
Paper/Dataset.pdf ADDED
Binary file (53.3 kB). View file
 
Paper/Snapshot_SS_colorbar.svg ADDED
Paper/figures.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Paper/los.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64f4894aef6463779d0de2527dce81c839cc5ad6e21358bc57910b807b89a64c
3
+ size 195952322
Paper/overall_stats.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5422a53fa78d2ff9c7ae8ae0a064d972f2fd3fdbc58ccc1ef5ad7eb9dc647c3b
3
+ size 577
Paper/simulations.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df59faeb9a989fa7751ebe13b513b0a3cf23bf6face190697c90f04567425cf9
3
+ size 4628
Paper/x_p.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d76ada3a1d22cafcd67ac04a5acbc0798fe7e6626ffd367aed535df895da5924
3
+ size 3537
Paper/y_p.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60058bfd5fcf2253384ca64602147641bdd50f4dc41505a132e255b5fbf73e0a
3
+ size 131327
Paper/y_pred_krr.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d5946c988e5d38c6d18f435eea9b349aabddc254439e670d7fa9dd0183a6b69
3
+ size 131327
Paper/y_pred_linear.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7413ff2d58d66c626c103560403025dca21b1f8615568b942eec596f712236cc
3
+ size 131327
Paper/y_pred_nn_interpolation.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1702a46ed99540ef523bb67495b63f2f4eaa530538923e680c57f65832d87a73
3
+ size 131327
Paper/y_pred_nn_nearestneighbor.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3dda5a9c7d01651b0f00d547069aa4191bb24e39d8a2174b0cd82fb4d1cdc08b
3
+ size 131327
Paper/y_pred_nn_pointwise.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54555229f98cb33f134b472259bdb5fc20cdd2ea99bdff056dbba2d3f3a2a039
3
+ size 131327
Paper/y_prof.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74b064010ebed301b83d5a771864cddcf40972b92b7b93d9bad5be36b606fc6f
3
+ size 1174
README.md CHANGED
@@ -12,7 +12,15 @@ license: mit
12
 
13
  This respository contains trained neural networks that can be used to predict the steady-state temperature profile.
14
 
15
- Step 1: Define the simulation parameters
 
16
  Step 2: The output is as follows
17
- - the depth profile (first column) and the temperature profile (second column)
18
- - corresonding plot of the temperature profile
 
 
 
 
 
 
 
 
12
 
13
  This respository contains trained neural networks that can be used to predict the steady-state temperature profile.
14
 
15
+ Here is how to predict several profiles at once:
16
+ Step 1: Define the simulation parameters in my_simulation_parameters.txt
17
  Step 2: The output is as follows
18
+ - outputs/.txt files: the depth profile (first column) and the temperature profile (second column)
19
+ - outputs/.pdf files: corresonding plot of the temperature profile
20
+
21
+
22
+ Here's how to train a neural network:
23
+ Step 1: Ensure you have a working installation of PyTorch (https://pytorch.org/get-started/locally/)
24
+ Step 2: cd train/
25
+ Step 3 with GPU : python train_profiles_mlp.py -l 5 -f 128 -a "selu" -gpu 0
26
+ Step 3 without GPU: python train_profiles_mlp.py -l 5 -f 128 -a "selu"
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
calculate_profiles.py CHANGED
@@ -7,9 +7,11 @@ write_file = True
7
  plot_profile = True
8
  #### Define outputs ####
9
 
 
10
  with open('numpy_networks/mlp_[128, 128, 128, 128, 128].pkl', 'rb') as file:
11
  mlp = pickle.load(file)
12
 
 
13
  f_nn = "my_simulation_parameters.txt"
14
  with open(f_nn) as fw:
15
  lines = fw.readlines()
@@ -32,6 +34,7 @@ for line in lines:
32
  if not len(r_list) == len(v_list) and len(r_list) == len(t_list):
33
  raise Exception("Ensure equal number of values for all parameters in " + f_nn)
34
 
 
35
  for i in range(len(r_list)):
36
  if r_list[i]<0 or r_list[i]>9.5:
37
  warnings.warn('RaQ/Ra is outside the range of the training dataset')
 
7
  plot_profile = True
8
  #### Define outputs ####
9
 
10
+ # load the saved mlp
11
  with open('numpy_networks/mlp_[128, 128, 128, 128, 128].pkl', 'rb') as file:
12
  mlp = pickle.load(file)
13
 
14
+ # read the simulation parameters and ensure correct foramtting
15
  f_nn = "my_simulation_parameters.txt"
16
  with open(f_nn) as fw:
17
  lines = fw.readlines()
 
34
  if not len(r_list) == len(v_list) and len(r_list) == len(t_list):
35
  raise Exception("Ensure equal number of values for all parameters in " + f_nn)
36
 
37
+ # Check parameter ranges and print appropriate warnings
38
  for i in range(len(r_list)):
39
  if r_list[i]<0 or r_list[i]>9.5:
40
  warnings.warn('RaQ/Ra is outside the range of the training dataset')
data/scaling_laws_data.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7dcc41721a779493ba2d4f7517ef59794de61f79ac5c37a7944952b191cf3989
3
+ size 16783233
data/scaling_laws_data_hot_start.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31c2d899a964b4f641148841a67dea83d863d98860b707297ec30bdfbc5acd9b
3
+ size 22631031
data/sims.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f2c00c1fcc0e05799a37c8d657542d6cfcfd8b6f51ef192b9551e0b03dcbd97
3
+ size 22319
data/snapshots.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae8f8d8545c7ab78db38e3d9b6b98eae74d21d2f004811a2cb81acffb4ecdfdd
3
+ size 12437058
data/x_Tprev.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:351805bd60d4fd84c25112718702c24789545f3df5e18df65efe8b06e7bb77c0
3
+ size 6938
data/x_modes.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4afe1c0dbc408af7cb3900c7116851e02e84f560c95ed10d9627134551c28a7
3
+ size 3318
data/x_p.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4afe1c0dbc408af7cb3900c7116851e02e84f560c95ed10d9627134551c28a7
3
+ size 3318
data/x_pointwise_orgres.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24405442e95cae6608a507e044b0414ee2983da71e6e9cf61ed9343dafdb0471
3
+ size 524555
data/y_Tprev.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6eaeca58b64bb330ab8ad220f344f83a7773910f6dda2e98763e242bcccaa995
3
+ size 18584253
data/y_modes.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1acfa7a99ed3edda7ef750dbd30889d56c954be11e889173265a50321d72dbe5
3
+ size 33014
data/y_p.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60058bfd5fcf2253384ca64602147641bdd50f4dc41505a132e255b5fbf73e0a
3
+ size 131327
data/y_pointwise_orgres.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d19c505b85811df63cba87046342b1ae5930e259251757f50131c16357cc0866
3
+ size 131330
data/y_prof.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1be9d7e71bb22d9cddc3f6a4c2f363f4e7eb9879ed60cf2d8b8825f21f97f18
3
+ size 2135
evaluate/load_profiles.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
evaluate/scaling_laws.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
my_simulation_parameters.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #### Define parameters ####
2
+
3
+ # raq/ra betweel 0 and 10
4
+ r_list = 5,7.5,10,
5
+
6
+ # FKT between 1e+6 and 1e+10
7
+ t_list = 1e+8,1e+9,1e+10,
8
+
9
+ # FKV between 1e+0 and 1e+2
10
+ v_list = 50,25,1,
outputs/profile_raq_ra10.0_fkt10000000000.0_fkv1.0.png ADDED
outputs/profile_raq_ra10.0_fkt10000000000.0_fkv1.0.txt ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1.0 1.0
2
+ 0.99609375 0.9902098942725738
3
+ 0.9881562500000001 0.9703163994344438
4
+ 0.98021875 0.9698090412062678
5
+ 0.97228125 0.9691821542327205
6
+ 0.96434375 0.9684515420121397
7
+ 0.95640625 0.9677356340791369
8
+ 0.94846875 0.9671697830904316
9
+ 0.94053125 0.9667202236581387
10
+ 0.93259375 0.966311449336034
11
+ 0.92465625 0.9659481243331839
12
+ 0.91671875 0.9656773805469817
13
+ 0.90878125 0.9654725654458026
14
+ 0.90084375 0.9653150700204262
15
+ 0.89290625 0.9651916980635759
16
+ 0.88496875 0.965092961746218
17
+ 0.87703125 0.9650111687302013
18
+ 0.86909375 0.9649397652692542
19
+ 0.8611562500000001 0.9648770260555964
20
+ 0.85321875 0.9648247711847248
21
+ 0.84528125 0.9647767597243578
22
+ 0.83734375 0.9647278248222245
23
+ 0.82940625 0.9646797526182401
24
+ 0.82146875 0.9646316205912606
25
+ 0.81353125 0.9645885560525155
26
+ 0.80559375 0.9645405864163007
27
+ 0.79765625 0.9644910702591126
28
+ 0.78971875 0.9644372136707845
29
+ 0.78178125 0.9643808755612203
30
+ 0.77384375 0.964321872132771
31
+ 0.76590625 0.9642631270679868
32
+ 0.75796875 0.9642026502634682
33
+ 0.75003125 0.9641388550418537
34
+ 0.74209375 0.96407165722111
35
+ 0.7341562500000001 0.9640034450644884
36
+ 0.72621875 0.9639535615753994
37
+ 0.71828125 0.9638993714622504
38
+ 0.71034375 0.9638450955265502
39
+ 0.70240625 0.9638059849372029
40
+ 0.69446875 0.9637577866497685
41
+ 0.68653125 0.9636980049798857
42
+ 0.67859375 0.9636324429465546
43
+ 0.67065625 0.9635611509667682
44
+ 0.66271875 0.9634841701965359
45
+ 0.65478125 0.9634025293866947
46
+ 0.64684375 0.963330646310048
47
+ 0.63890625 0.9632707432324515
48
+ 0.63096875 0.9632034377171873
49
+ 0.62303125 0.9631275938242613
50
+ 0.61509375 0.963044092611668
51
+ 0.6071562500000001 0.9629610203510949
52
+ 0.59921875 0.9628686842131967
53
+ 0.59128125 0.9627672336670928
54
+ 0.58334375 0.9626563500660463
55
+ 0.57540625 0.962536023518184
56
+ 0.56746875 0.962420491607416
57
+ 0.55953125 0.9622991897704051
58
+ 0.55159375 0.9621680519091717
59
+ 0.54365625 0.9620267304304034
60
+ 0.53571875 0.9618756339504292
61
+ 0.52778125 0.961713923059125
62
+ 0.51984375 0.9615400932428618
63
+ 0.51190625 0.9613493372438985
64
+ 0.50396875 0.9611404267958391
65
+ 0.49603125000000003 0.9609165495281505
66
+ 0.48809375 0.9606885428813832
67
+ 0.48015625 0.9604565399427599
68
+ 0.47221875 0.9602097371188835
69
+ 0.46428125 0.9599469049699622
70
+ 0.45634375 0.9596665062472561
71
+ 0.44840625 0.959366632496811
72
+ 0.44046875 0.9590449217547321
73
+ 0.43253125000000003 0.9586984513785973
74
+ 0.42459375 0.958323597797175
75
+ 0.41665625 0.9579158517374129
76
+ 0.40871875 0.957513439179302
77
+ 0.40078125 0.9571143302567431
78
+ 0.39284375 0.9566899501927818
79
+ 0.38490625 0.9562374521378395
80
+ 0.37696875 0.955753466698022
81
+ 0.36903125000000003 0.9552340223880293
82
+ 0.36109375 0.9546744202766487
83
+ 0.35315625 0.9540690636607255
84
+ 0.34521875 0.9534112769062083
85
+ 0.33728125 0.9526930564699306
86
+ 0.32934375 0.9519047613605006
87
+ 0.32140625 0.9510122522654944
88
+ 0.31346875 0.9499836306702947
89
+ 0.30553125000000003 0.9488385276983663
90
+ 0.29759375 0.947556941139438
91
+ 0.28965625 0.9461514819098414
92
+ 0.28171875 0.9446168768590487
93
+ 0.27378125 0.9428482229388179
94
+ 0.26584375 0.9407955463080119
95
+ 0.25790625 0.9383916980923374
96
+ 0.24996875000000002 0.9355501727743388
97
+ 0.24203125 0.9321597717436331
98
+ 0.23409375 0.9280749815436454
99
+ 0.22615625 0.92409832729004
100
+ 0.21821875000000002 0.9199603312644494
101
+ 0.21028125 0.91529397107669
102
+ 0.20234375 0.9099618854041043
103
+ 0.19440625 0.9038215180307873
104
+ 0.18646875000000002 0.8967189006591215
105
+ 0.17853125 0.8884624483245174
106
+ 0.17059375 0.8788109964569801
107
+ 0.16265625 0.8674551794353741
108
+ 0.15471875000000002 0.8539938886343238
109
+ 0.14678125 0.8379068574130993
110
+ 0.13884375 0.8184845893551922
111
+ 0.13090625 0.8002555708629069
112
+ 0.12296875 0.7830705389169234
113
+ 0.11503125 0.7637560194052709
114
+ 0.10709375 0.7404082311824085
115
+ 0.09915625 0.7105735141347064
116
+ 0.09121875 0.6794464330052945
117
+ 0.08328125 0.6430006959339014
118
+ 0.07534375 0.6061079752629455
119
+ 0.06740625 0.5659042020226758
120
+ 0.05946875 0.5277315266129229
121
+ 0.05153125 0.4859035789367027
122
+ 0.04359375 0.4445176269133653
123
+ 0.03565625 0.39944274497494076
124
+ 0.02771875 0.31052209885431414
125
+ 0.01978125 0.22160145273368756
126
+ 0.01184375 0.13268080661306095
127
+ 0.00390625 0.04376016049243435
128
+ 0.0 0.0
outputs/profile_raq_ra5.0_fkt100000000.0_fkv50.0.png ADDED
outputs/profile_raq_ra5.0_fkt100000000.0_fkv50.0.txt ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1.0 1.0
2
+ 0.99609375 1.0004886307302538
3
+ 0.9881562500000001 1.0014815283741296
4
+ 0.98021875 1.000804723822155
5
+ 0.97228125 1.000350293388824
6
+ 0.96434375 1.0002550325461768
7
+ 0.95640625 1.0008654380922966
8
+ 0.94846875 1.0013652302040426
9
+ 0.94053125 1.001766314058726
10
+ 0.93259375 1.0020883822552678
11
+ 0.92465625 1.0023296302760354
12
+ 0.91671875 1.002461872518664
13
+ 0.90878125 1.0020987597468052
14
+ 0.90084375 1.0018112935862176
15
+ 0.89290625 1.0016318727877296
16
+ 0.88496875 1.0014263339050131
17
+ 0.87703125 1.0013376768427356
18
+ 0.86909375 1.0013275735204263
19
+ 0.8611562500000001 1.0013684925815804
20
+ 0.85321875 1.0014425840323204
21
+ 0.84528125 1.0016010617709805
22
+ 0.83734375 1.001813270476308
23
+ 0.82940625 1.0020246870888012
24
+ 0.82146875 1.0022308677315883
25
+ 0.81353125 1.0024308604725698
26
+ 0.80559375 1.0026193535512236
27
+ 0.79765625 1.0027957884153567
28
+ 0.78971875 1.0029601417827056
29
+ 0.78178125 1.0031127593519735
30
+ 0.77384375 1.0032561116829861
31
+ 0.76590625 1.0033889550368034
32
+ 0.75796875 1.0035116379965685
33
+ 0.75003125 1.0036248264745116
34
+ 0.74209375 1.0037291232564567
35
+ 0.7341562500000001 1.003820214014638
36
+ 0.72621875 1.0039010969261741
37
+ 0.71828125 1.0039746801239253
38
+ 0.71034375 1.0040414574335275
39
+ 0.70240625 1.0041018075512984
40
+ 0.69446875 1.0041560013675555
41
+ 0.68653125 1.0042044597671895
42
+ 0.67859375 1.0042474495055858
43
+ 0.67065625 1.0042852082893852
44
+ 0.66271875 1.0043164962792892
45
+ 0.65478125 1.0043387000734143
46
+ 0.64684375 1.0043610723879577
47
+ 0.63890625 1.004381003538
48
+ 0.63096875 1.004398923604277
49
+ 0.62303125 1.0044144376406412
50
+ 0.61509375 1.0044304460366351
51
+ 0.6071562500000001 1.0044636210329727
52
+ 0.59921875 1.0044923790380103
53
+ 0.59128125 1.004520233263561
54
+ 0.58334375 1.0045887493520003
55
+ 0.57540625 1.004679199743899
56
+ 0.56746875 1.004779877814393
57
+ 0.55953125 1.0048788693938613
58
+ 0.55159375 1.004977380138184
59
+ 0.54365625 1.0050756035254627
60
+ 0.53571875 1.0051737376253638
61
+ 0.52778125 1.005271987317915
62
+ 0.51984375 1.0053779926972155
63
+ 0.51190625 1.0054960240256412
64
+ 0.50396875 1.0055879712527709
65
+ 0.49603125000000003 1.0056687630323906
66
+ 0.48809375 1.0057530411581908
67
+ 0.48015625 1.0058562629680423
68
+ 0.47221875 1.005955859588927
69
+ 0.46428125 1.006049142126516
70
+ 0.45634375 1.0061392284693913
71
+ 0.44840625 1.0062227960462489
72
+ 0.44046875 1.0063014876598555
73
+ 0.43253125000000003 1.0063809537114619
74
+ 0.42459375 1.0064477385412693
75
+ 0.41665625 1.0065002233371656
76
+ 0.40871875 1.0065340278356054
77
+ 0.40078125 1.0065449910239257
78
+ 0.39284375 1.0065604716569034
79
+ 0.38490625 1.006545761987049
80
+ 0.37696875 1.006514220912114
81
+ 0.36903125000000003 1.006445797187465
82
+ 0.36109375 1.0063420660756848
83
+ 0.35315625 1.0061835811678699
84
+ 0.34521875 1.005955216668348
85
+ 0.33728125 1.0056493725988478
86
+ 0.32934375 1.0052531024438225
87
+ 0.32140625 1.0047476614201345
88
+ 0.31346875 1.0041090171355473
89
+ 0.30553125000000003 1.0033061411151785
90
+ 0.29759375 1.0022965449658188
91
+ 0.28965625 1.0010287716820465
92
+ 0.28171875 0.9994315147840812
93
+ 0.27378125 0.9971542852576364
94
+ 0.26584375 0.993085846984152
95
+ 0.25790625 0.9874956979334749
96
+ 0.24996875000000002 0.9803651503732914
97
+ 0.24203125 0.9712143155094474
98
+ 0.23409375 0.9593832352432593
99
+ 0.22615625 0.9448093523254151
100
+ 0.21821875000000002 0.9281858239988381
101
+ 0.21028125 0.9072652738070205
102
+ 0.20234375 0.8810591896616389
103
+ 0.19440625 0.8536958985978688
104
+ 0.18646875000000002 0.8242217024249135
105
+ 0.17853125 0.7970041482379427
106
+ 0.17059375 0.7680697876685222
107
+ 0.16265625 0.7341301061899966
108
+ 0.15471875000000002 0.7017719756017162
109
+ 0.14678125 0.6716856076715821
110
+ 0.13884375 0.6404048600148078
111
+ 0.13090625 0.6066679693481009
112
+ 0.12296875 0.573799169182174
113
+ 0.11503125 0.541222476564494
114
+ 0.10709375 0.5084939725792854
115
+ 0.09915625 0.472769769206234
116
+ 0.09121875 0.43567158985200666
117
+ 0.08328125 0.39883182158072766
118
+ 0.07534375 0.36379378245712773
119
+ 0.06740625 0.32625680593753936
120
+ 0.05946875 0.286219165783732
121
+ 0.05153125 0.2471830364120577
122
+ 0.04359375 0.20984231755040794
123
+ 0.03565625 0.1735799645976791
124
+ 0.02771875 0.13493902594052706
125
+ 0.01978125 0.096298087283375
126
+ 0.01184375 0.05765714862622295
127
+ 0.00390625 0.019016209969070892
128
+ 0.0 0.0
outputs/profile_raq_ra7.5_fkt1000000000.0_fkv25.0.png ADDED
outputs/profile_raq_ra7.5_fkt1000000000.0_fkv25.0.txt ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1.0 1.0
2
+ 0.99609375 0.9990062610422076
3
+ 0.9881562500000001 0.9969869834799735
4
+ 0.98021875 0.9971232601553087
5
+ 0.97228125 0.9971806693281854
6
+ 0.96434375 0.9971845844770447
7
+ 0.95640625 0.9971461156385902
8
+ 0.94846875 0.9970699944828866
9
+ 0.94053125 0.9969707623390985
10
+ 0.93259375 0.9967311364452428
11
+ 0.92465625 0.9962571288038881
12
+ 0.91671875 0.9957798887799661
13
+ 0.90878125 0.9953726290532506
14
+ 0.90084375 0.9949482003998925
15
+ 0.89290625 0.9946372293678932
16
+ 0.88496875 0.994350193789516
17
+ 0.87703125 0.994082398444334
18
+ 0.86909375 0.9938673459886528
19
+ 0.8611562500000001 0.9936926261460569
20
+ 0.85321875 0.993549103310679
21
+ 0.84528125 0.9934299598039326
22
+ 0.83734375 0.9933408877236217
23
+ 0.82940625 0.9932748167728084
24
+ 0.82146875 0.9932240593252849
25
+ 0.81353125 0.9931845996072571
26
+ 0.80559375 0.9931535642345825
27
+ 0.79765625 0.9931284177384515
28
+ 0.78971875 0.9930763882075305
29
+ 0.78178125 0.993027270660629
30
+ 0.77384375 0.9929824997446136
31
+ 0.76590625 0.9929414748254256
32
+ 0.75796875 0.9929041445027907
33
+ 0.75003125 0.9928710699506768
34
+ 0.74209375 0.9928408956409854
35
+ 0.7341562500000001 0.9928062295169527
36
+ 0.72621875 0.9927731526457385
37
+ 0.71828125 0.9927418727153093
38
+ 0.71034375 0.9927124057458755
39
+ 0.70240625 0.9926843616027358
40
+ 0.69446875 0.9926577214129114
41
+ 0.68653125 0.9926324444557905
42
+ 0.67859375 0.992608483120658
43
+ 0.67065625 0.9925858197309926
44
+ 0.66271875 0.9925644421671426
45
+ 0.65478125 0.9925443623714286
46
+ 0.64684375 0.9925255878718212
47
+ 0.63890625 0.9925081328548363
48
+ 0.63096875 0.9924920175575421
49
+ 0.62303125 0.9924772674977019
50
+ 0.61509375 0.9924639128627663
51
+ 0.6071562500000001 0.9924519880277204
52
+ 0.59921875 0.9924414619655849
53
+ 0.59128125 0.9924248754213431
54
+ 0.58334375 0.9924091294843468
55
+ 0.57540625 0.9923949509101252
56
+ 0.56746875 0.9924055373747572
57
+ 0.55953125 0.9924226953546874
58
+ 0.55159375 0.9924412279591394
59
+ 0.54365625 0.9924609181073719
60
+ 0.53571875 0.9924822080796133
61
+ 0.52778125 0.992505191275782
62
+ 0.51984375 0.9925299638594424
63
+ 0.51190625 0.9925564376035475
64
+ 0.50396875 0.9925860465898036
65
+ 0.49603125000000003 0.9926185812893703
66
+ 0.48809375 0.9926533574477424
67
+ 0.48015625 0.9926880569201405
68
+ 0.47221875 0.9927248793493549
69
+ 0.46428125 0.9927643825455995
70
+ 0.45634375 0.9928067052937776
71
+ 0.44840625 0.992851761327506
72
+ 0.44046875 0.9928997556605149
73
+ 0.43253125000000003 0.9929506220835563
74
+ 0.42459375 0.9930046883440068
75
+ 0.41665625 0.9930610831698342
76
+ 0.40871875 0.9931189143921962
77
+ 0.40078125 0.9931812242068735
78
+ 0.39284375 0.9932488769681517
79
+ 0.38490625 0.9933222601182629
80
+ 0.37696875 0.9934017804362069
81
+ 0.36903125000000003 0.9935076371764416
82
+ 0.36109375 0.9936212772387668
83
+ 0.35315625 0.9937408126620867
84
+ 0.34521875 0.9938671203998567
85
+ 0.33728125 0.9940008199749039
86
+ 0.32934375 0.9941429866018373
87
+ 0.32140625 0.9942950177741342
88
+ 0.31346875 0.9944506020368925
89
+ 0.30553125000000003 0.994619583716713
90
+ 0.29759375 0.9947970240990528
91
+ 0.28965625 0.9949834978202933
92
+ 0.28171875 0.9951895224341948
93
+ 0.27378125 0.9954192703762218
94
+ 0.26584375 0.9956510966996391
95
+ 0.25790625 0.9958843935290972
96
+ 0.24996875000000002 0.9961171663976424
97
+ 0.24203125 0.9963453467837434
98
+ 0.23409375 0.9965568266194059
99
+ 0.22615625 0.99674747819695
100
+ 0.21821875000000002 0.9969482057070408
101
+ 0.21028125 0.9971208387999916
102
+ 0.20234375 0.9967090731311209
103
+ 0.19440625 0.9950787777951174
104
+ 0.18646875000000002 0.9930116113720479
105
+ 0.17853125 0.9902464116431152
106
+ 0.17059375 0.986576923533956
107
+ 0.16265625 0.9810600067528333
108
+ 0.15471875000000002 0.9733191107513938
109
+ 0.14678125 0.9633249834854362
110
+ 0.13884375 0.9503717984851009
111
+ 0.13090625 0.933303531509786
112
+ 0.12296875 0.9103604184926574
113
+ 0.11503125 0.8787642810144055
114
+ 0.10709375 0.8387919694348168
115
+ 0.09915625 0.7853416959477463
116
+ 0.09121875 0.715816483333058
117
+ 0.08328125 0.6491644646132442
118
+ 0.07534375 0.5822779840859292
119
+ 0.06740625 0.5214566188913672
120
+ 0.05946875 0.46451158897336337
121
+ 0.05153125 0.40186241330496747
122
+ 0.04359375 0.3343255941714209
123
+ 0.03565625 0.2648353775338124
124
+ 0.02771875 0.20587991224582963
125
+ 0.01978125 0.14692444695784684
126
+ 0.01184375 0.08796898166986407
127
+ 0.00390625 0.02901351638188129
128
+ 0.0 0.0
preprocess/preprocess_profiles.ipynb ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 19,
6
+ "id": "8b63d618-1f46-49e8-b388-c4185624e58c",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os \n",
11
+ "import torch\n",
12
+ "import numpy as np\n",
13
+ "from tabulate import tabulate\n",
14
+ "import random\n",
15
+ "from matplotlib import pyplot as plt\n",
16
+ "import pickle\n",
17
+ "from scipy.signal import find_peaks"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": 20,
23
+ "id": "2129fb21-fcea-40b4-ba3c-473aa2e6f1e2",
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "pre = \"../data/\""
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 29,
33
+ "id": "ae2009a3-a989-4756-849a-1180bb7ba087",
34
+ "metadata": {
35
+ "scrolled": true
36
+ },
37
+ "outputs": [],
38
+ "source": [
39
+ "load_sims = False\n",
40
+ "\n",
41
+ "x = {}\n",
42
+ "y = {}\n",
43
+ "var_vec = [\"Tprev\"]\n",
44
+ "for var in var_vec:\n",
45
+ " x[var] = {}\n",
46
+ " y[var] = {}\n",
47
+ " \n",
48
+ "for var in var_vec:\n",
49
+ " with open(pre + 'x_' + var + '.pkl', 'rb') as file: \n",
50
+ " x[var] = pickle.load(file) \n",
51
+ " with open(pre + 'y_' + var + '.pkl', 'rb') as file: \n",
52
+ " y[var] = pickle.load(file) \n",
53
+ " "
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "markdown",
58
+ "id": "bff0409c-0adb-42df-96e3-2ace88051f6e",
59
+ "metadata": {},
60
+ "source": [
61
+ "### Identify train/cv/test"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": 30,
67
+ "id": "22cbeaae-7421-4ae9-8042-079591939e07",
68
+ "metadata": {},
69
+ "outputs": [
70
+ {
71
+ "name": "stdout",
72
+ "output_type": "stream",
73
+ "text": [
74
+ "97\n",
75
+ "15\n",
76
+ "16\n"
77
+ ]
78
+ }
79
+ ],
80
+ "source": [
81
+ "sims = torch.load(pre + \"/sims.pt\")\n",
82
+ "\n",
83
+ "extrapolation_sims = []\n",
84
+ "interpolation_sims = []\n",
85
+ "for si, sim in enumerate(sims):\n",
86
+ " if si !=39 and si!=8: # and (si==0 or si==100 or si==120):\n",
87
+ " #print(tabulate([[\"num\", \"dataset\", \"raq\", \"fkt\", \"fkp\", \"gr\", \"ar\"],\n",
88
+ " # sim[:-1]\n",
89
+ " # ]))\n",
90
+ " ignr, ignr, raq, fkt, fkp, gr, ar, ignr = sim\n",
91
+ "\n",
92
+ " #if (fkt < 5e+5 or fkt > 5e+8) and (fkp < 15 or fkp > 85) and (raq < 1.5 or raq > 8.5):\n",
93
+ " if (fkt > 5e+9) or (fkp > 95) or (raq > 9.5):\n",
94
+ " extrapolation_sims.append(si)\n",
95
+ " else:\n",
96
+ " interpolation_sims.append(si)\n",
97
+ "\n",
98
+ "random.seed(1992)\n",
99
+ "inds = {}\n",
100
+ "inds[\"test\"] = extrapolation_sims #+ random.choices(interpolation_sims, k=9)\n",
101
+ "\n",
102
+ "remain_inds = []\n",
103
+ "for inp in interpolation_sims:\n",
104
+ " if inp not in inds[\"test\"]:\n",
105
+ " remain_inds.append(inp)\n",
106
+ " \n",
107
+ "inds[\"cv\"] = random.choices(remain_inds, k=16)\n",
108
+ "\n",
109
+ "inds[\"train\"] = []\n",
110
+ "for inp in remain_inds:\n",
111
+ " if inp not in inds[\"test\"] and inp not in inds[\"cv\"]:\n",
112
+ " inds[\"train\"].append(inp)\n",
113
+ "\n",
114
+ "inds[\"train\"] = np.unique(inds[\"train\"])\n",
115
+ "inds[\"cv\"] = np.unique(inds[\"cv\"])\n",
116
+ "inds[\"test\"] = np.unique(inds[\"test\"])\n",
117
+ "\n",
118
+ "print(len(inds[\"train\"]))#, sorted(inds[\"train\"]))\n",
119
+ "print(len(inds[\"cv\"]))#, sorted(inds[\"cv\"]))\n",
120
+ "print(len(inds[\"test\"]))#, sorted(inds[\"test\"]))"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "markdown",
125
+ "id": "7ade165c-b896-4192-a4e6-d4ac96a6dd07",
126
+ "metadata": {},
127
+ "source": [
128
+ "### Write simulation parameters "
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": 33,
134
+ "id": "f3ec1278-58cb-445a-a8e4-112d0032c784",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "sims_table = [[\"Simulation\", \"Dataset\", \"RaQ/Ra\", \"FKT\", \"FKV\"]] \n",
139
+ "\n",
140
+ "for sim in sims:\n",
141
+ " if sim[0] in inds[\"train\"]:\n",
142
+ " an = \"train\"\n",
143
+ " elif sim[0] in inds[\"cv\"]:\n",
144
+ " an = \"cv\"\n",
145
+ " elif sim[0] in inds[\"test\"]:\n",
146
+ " an = \"test\"\n",
147
+ " sims_table.append([sim[0], an, sim[2], sim[3], sim[4]])\n",
148
+ "\n",
149
+ "with open('../inputs/simulations.txt', 'w') as f:\n",
150
+ " f.write(tabulate(sims_table))\n",
151
+ "\n",
152
+ "with open('../Paper/simulations.pkl', 'wb') as f:\n",
153
+ " pickle.dump(sims_table, f)"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "markdown",
158
+ "id": "4f93a6bb-3b87-4822-a58a-a105ca049c0c",
159
+ "metadata": {},
160
+ "source": [
161
+ "### Pointwise input preparation"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": 34,
167
+ "id": "97e5f984-6dc9-4ec9-b770-3ca1a721f67c",
168
+ "metadata": {},
169
+ "outputs": [
170
+ {
171
+ "name": "stdout",
172
+ "output_type": "stream",
173
+ "text": [
174
+ "(26966, 4) (26966, 1)\n",
175
+ "(4170, 4) (4170, 1)\n",
176
+ "(4448, 4) (4448, 1)\n",
177
+ "(12416, 4) (12416, 1)\n",
178
+ "(1920, 4) (1920, 1)\n",
179
+ "(2048, 4) (2048, 1)\n"
180
+ ]
181
+ }
182
+ ],
183
+ "source": [
184
+ "x_pointwise = {}\n",
185
+ "y_pointwise = {}\n",
186
+ "\n",
187
+ "y_prof = torch.load(pre + \"/y_prof.pt\").flatten().numpy()\n",
188
+ "\n",
189
+ "y_prof = y_prof[::-1]\n",
190
+ "y_new = np.sort(np.concatenate((np.linspace(1,y_prof[15],100), \n",
191
+ " y_prof, np.linspace(y_prof[-10],y_prof[-1],50)), axis=0))[::-1]\n",
192
+ "\n",
193
+ "\n",
194
+ "for an in [\"train\", \"cv\", \"test\"]:\n",
195
+ " x_pointwise[an] = np.zeros((len(inds[an])*y_new.shape[0], 4))\n",
196
+ " y_pointwise[an] = np.zeros((len(inds[an])*y_new.shape[0], 1))\n",
197
+ " \n",
198
+ "\n",
199
+ " cntr = 0\n",
200
+ " \n",
201
+ " for i in inds[an]:\n",
202
+ " #print(an, i)\n",
203
+ " #u = y[\"uprev\"][i]*20\n",
204
+ " #v = y[\"vprev\"][i]*20\n",
205
+ " #vmag = np.sqrt(u[-50:,:]**2 + v[-50:,:]**2)\n",
206
+ " #vmag = np.mean(vmag, axis=0)\n",
207
+ " T = np.mean(y[\"Tprev\"][i], axis=0)\n",
208
+ " T_new = np.interp(y_new, y_prof[::-1], T[::-1])\n",
209
+ "\n",
210
+ " #plt.figure()\n",
211
+ " #plt.plot(T, y_prof)\n",
212
+ " #plt.plot(T_new, y_new, 'kx')\n",
213
+ " #plt.ylim([1,0])\n",
214
+ " #plt.show()\n",
215
+ "\n",
216
+ " for j in range(y_new.shape[0]):\n",
217
+ " x_pointwise[an][cntr,:3] = x[\"Tprev\"][i]\n",
218
+ " x_pointwise[an][cntr,3:4] = y_new[j]\n",
219
+ " \n",
220
+ " y_pointwise[an][cntr,0] = T_new[j]\n",
221
+ " cntr += 1 \n",
222
+ " print(x_pointwise[an].shape, y_pointwise[an].shape)\n",
223
+ "\n",
224
+ "\n",
225
+ "\n",
226
+ "with open(pre + 'x_pointwise.pkl', 'wb') as file: \n",
227
+ " pickle.dump(x_pointwise, file) \n",
228
+ "with open(pre + 'y_pointwise.pkl', 'wb') as file: \n",
229
+ " pickle.dump(y_pointwise, file) \n",
230
+ "\n",
231
+ "\n",
232
+ "for an in [\"train\", \"cv\", \"test\"]:\n",
233
+ " x_pointwise[an] = np.zeros((len(inds[an])*y_prof.shape[0], 4))\n",
234
+ " y_pointwise[an] = np.zeros((len(inds[an])*y_prof.shape[0], 1))\n",
235
+ " \n",
236
+ " cntr = 0\n",
237
+ " \n",
238
+ " for i in inds[an]:\n",
239
+ " T = np.mean(y[\"Tprev\"][i], axis=0)\n",
240
+ "\n",
241
+ " for j in range(y_prof.shape[0]):\n",
242
+ " x_pointwise[an][cntr,:3] = x[\"Tprev\"][i]\n",
243
+ " x_pointwise[an][cntr,3:4] = y_prof[j]\n",
244
+ " \n",
245
+ " y_pointwise[an][cntr,0] = T[j]\n",
246
+ " cntr += 1 \n",
247
+ " print(x_pointwise[an].shape, y_pointwise[an].shape)\n",
248
+ "\n",
249
+ "\n",
250
+ "\n",
251
+ "with open(pre + 'x_pointwise_orgres.pkl', 'wb') as file: \n",
252
+ " pickle.dump(x_pointwise, file) \n",
253
+ "with open(pre + 'y_pointwise_orgres.pkl', 'wb') as file: \n",
254
+ " pickle.dump(y_pointwise, file) "
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": 35,
260
+ "id": "2875ca3d-079b-4bda-ad0d-29081aa82fee",
261
+ "metadata": {},
262
+ "outputs": [],
263
+ "source": [
264
+ "### Full profile input"
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": 36,
270
+ "id": "40d86b1d-8db5-405c-a305-f7fb23c297dc",
271
+ "metadata": {
272
+ "scrolled": true
273
+ },
274
+ "outputs": [
275
+ {
276
+ "name": "stdout",
277
+ "output_type": "stream",
278
+ "text": [
279
+ "train 0\n",
280
+ "train 3\n",
281
+ "train 4\n",
282
+ "train 5\n",
283
+ "train 6\n",
284
+ "train 7\n",
285
+ "train 9\n",
286
+ "train 10\n",
287
+ "train 11\n",
288
+ "train 12\n",
289
+ "train 13\n",
290
+ "train 14\n",
291
+ "train 16\n",
292
+ "train 18\n",
293
+ "train 19\n",
294
+ "train 20\n",
295
+ "train 21\n",
296
+ "train 22\n",
297
+ "train 23\n",
298
+ "train 24\n",
299
+ "train 25\n",
300
+ "train 26\n",
301
+ "train 27\n",
302
+ "train 28\n",
303
+ "train 29\n",
304
+ "train 30\n",
305
+ "train 31\n",
306
+ "train 33\n",
307
+ "train 34\n",
308
+ "train 35\n",
309
+ "train 36\n",
310
+ "train 37\n",
311
+ "train 41\n",
312
+ "train 43\n",
313
+ "train 44\n",
314
+ "train 45\n",
315
+ "train 46\n",
316
+ "train 47\n",
317
+ "train 48\n",
318
+ "train 49\n",
319
+ "train 50\n",
320
+ "train 51\n",
321
+ "train 52\n",
322
+ "train 53\n",
323
+ "train 54\n",
324
+ "train 56\n",
325
+ "train 61\n",
326
+ "train 62\n",
327
+ "train 63\n",
328
+ "train 64\n",
329
+ "train 65\n",
330
+ "train 66\n",
331
+ "train 67\n",
332
+ "train 70\n",
333
+ "train 71\n",
334
+ "train 72\n",
335
+ "train 73\n",
336
+ "train 74\n",
337
+ "train 75\n",
338
+ "train 78\n",
339
+ "train 79\n",
340
+ "train 80\n",
341
+ "train 81\n",
342
+ "train 82\n",
343
+ "train 84\n",
344
+ "train 88\n",
345
+ "train 89\n",
346
+ "train 90\n",
347
+ "train 91\n",
348
+ "train 96\n",
349
+ "train 97\n",
350
+ "train 99\n",
351
+ "train 100\n",
352
+ "train 101\n",
353
+ "train 102\n",
354
+ "train 103\n",
355
+ "train 104\n",
356
+ "train 106\n",
357
+ "train 107\n",
358
+ "train 108\n",
359
+ "train 109\n",
360
+ "train 110\n",
361
+ "train 111\n",
362
+ "train 113\n",
363
+ "train 114\n",
364
+ "train 115\n",
365
+ "train 116\n",
366
+ "train 117\n",
367
+ "train 119\n",
368
+ "train 120\n",
369
+ "train 121\n",
370
+ "train 123\n",
371
+ "train 124\n",
372
+ "train 125\n",
373
+ "train 126\n",
374
+ "train 128\n",
375
+ "train 129\n",
376
+ "(97, 3) (97, 128)\n",
377
+ "cv 2\n",
378
+ "cv 17\n",
379
+ "cv 32\n",
380
+ "cv 38\n",
381
+ "cv 40\n",
382
+ "cv 57\n",
383
+ "cv 59\n",
384
+ "cv 60\n",
385
+ "cv 76\n",
386
+ "cv 83\n",
387
+ "cv 92\n",
388
+ "cv 95\n",
389
+ "cv 98\n",
390
+ "cv 105\n",
391
+ "cv 122\n",
392
+ "(15, 3) (15, 128)\n",
393
+ "test 1\n",
394
+ "test 15\n",
395
+ "test 42\n",
396
+ "test 55\n",
397
+ "test 58\n",
398
+ "test 68\n",
399
+ "test 69\n",
400
+ "test 77\n",
401
+ "test 85\n",
402
+ "test 86\n",
403
+ "test 87\n",
404
+ "test 93\n",
405
+ "test 94\n",
406
+ "test 112\n",
407
+ "test 118\n",
408
+ "test 127\n",
409
+ "(16, 3) (16, 128)\n"
410
+ ]
411
+ }
412
+ ],
413
+ "source": [
414
+ "x_p = {}\n",
415
+ "y_p = {}\n",
416
+ "\n",
417
+ "for an in [\"train\", \"cv\", \"test\"]:\n",
418
+ " x_p[an] = np.zeros((len(inds[an]),3))\n",
419
+ " y_p[an] = np.zeros((len(inds[an]),128))\n",
420
+ "\n",
421
+ " cntr = 0\n",
422
+ " for i in inds[an]:\n",
423
+ " print(an, i)\n",
424
+ " T = np.mean(y[\"Tprev\"][i], axis=0)\n",
425
+ "\n",
426
+ " x_p[an][cntr,:] = x[\"Tprev\"][i]\n",
427
+ " y_p[an][cntr,:] = T\n",
428
+ " cntr += 1 \n",
429
+ " \n",
430
+ " print(x_p[an].shape, y_p[an].shape)\n",
431
+ "\n",
432
+ "\n",
433
+ "\n",
434
+ "with open(pre + 'x_p.pkl', 'wb') as file: \n",
435
+ " pickle.dump(x_p, file) \n",
436
+ "with open(pre + 'y_p.pkl', 'wb') as file: \n",
437
+ " pickle.dump(y_p, file) "
438
+ ]
439
+ },
440
+ {
441
+ "cell_type": "code",
442
+ "execution_count": null,
443
+ "id": "960cc6f5-de5e-4788-aaaa-65e06aa33a42",
444
+ "metadata": {},
445
+ "outputs": [],
446
+ "source": []
447
+ }
448
+ ],
449
+ "metadata": {
450
+ "kernelspec": {
451
+ "display_name": "Python 3 (ipykernel)",
452
+ "language": "python",
453
+ "name": "python3"
454
+ },
455
+ "language_info": {
456
+ "codemirror_mode": {
457
+ "name": "ipython",
458
+ "version": 3
459
+ },
460
+ "file_extension": ".py",
461
+ "mimetype": "text/x-python",
462
+ "name": "python",
463
+ "nbconvert_exporter": "python",
464
+ "pygments_lexer": "ipython3",
465
+ "version": "3.10.9"
466
+ }
467
+ },
468
+ "nbformat": 4,
469
+ "nbformat_minor": 5
470
+ }
stats/MLP_stats.txt CHANGED
@@ -1,23 +1,23 @@
1
- ------------ --------- ------- ------- ---------
2
- architecture mae train mae cv diff diff + cv
3
- [32, 2] 0.01158 0.01249 0.00091 0.01294
4
- [64, 2] 0.01077 0.01379 0.00302 0.0153
5
- [128, 2] 0.00997 0.01177 0.00179 0.01266
6
- [256, 2] 0.01 0.01226 0.00226 0.01339
7
- [32, 3] 0.00922 0.01167 0.00245 0.01289
8
- [64, 3] 0.00878 0.01149 0.00271 0.01284
9
- [128, 3] 0.00834 0.01009 0.00175 0.01097
10
- [256, 3] 0.0082 0.00927 0.00108 0.00981
11
- [32, 4] 0.00797 0.00915 0.00118 0.00974
12
- [64, 4] 0.00652 0.00821 0.00169 0.00905
13
- [128, 4] 0.00592 0.0083 0.00237 0.00948
14
- [256, 4] 0.00681 0.00794 0.00113 0.0085
15
- [32, 5] 0.00877 0.01209 0.00332 0.01375
16
- [64, 5] 0.00584 0.0084 0.00256 0.00968
17
- [128, 5] 0.00608 0.00857 0.00249 0.00981
18
- [256, 5] 0.00682 0.00814 0.00133 0.00881
19
- [32, 6] 0.00799 0.01453 0.00654 0.0178
20
- [64, 6] 0.0118 0.01206 0.00026 0.01219
21
- [128, 6] 0.00607 0.00949 0.00342 0.0112
22
- [256, 6] 0.00738 0.01015 0.00277 0.01153
23
- ------------ --------- ------- ------- ---------
 
1
+ ------------ --------- ------- -------- ---------
2
+ architecture mae train mae cv diff diff + cv
3
+ [32, 2] 0.01167 0.01348 0.00181 0.01438
4
+ [64, 2] 0.01003 0.01141 0.00138 0.01211
5
+ [128, 2] 0.00979 0.0115 0.00171 0.01236
6
+ [256, 2] 0.00975 0.01108 0.00133 0.01174
7
+ [32, 3] 0.00849 0.00934 0.00085 0.00976
8
+ [64, 3] 0.00601 0.00769 0.00167 0.00852
9
+ [128, 3] 0.00572 0.00727 0.00155 0.00805
10
+ [256, 3] 0.00624 0.00729 0.00105 0.00782
11
+ [32, 4] 0.00642 0.00773 0.00131 0.00838
12
+ [64, 4] 0.00513 0.00599 0.00086 0.00642
13
+ [128, 4] 0.00439 0.00553 0.00114 0.0061
14
+ [256, 4] 0.0045 0.00584 0.00133 0.0065
15
+ [32, 5] 0.00709 0.00935 0.00226 0.01048
16
+ [64, 5] 0.00741 0.00838 0.00097 0.00886
17
+ [128, 5] 0.00442 0.00541 0.00098 0.0059
18
+ [256, 5] 0.00469 0.0071 0.00241 0.00831
19
+ [32, 6] 0.01028 0.01264 0.00236 0.01383
20
+ [64, 6] 0.01263 0.00999 -0.00264 0.00999
21
+ [128, 6] 0.00277 0.00861 0.00584 0.01154
22
+ [256, 6] 0.00694 0.00751 0.00057 0.00779
23
+ ------------ --------- ------- -------- ---------
{data → train}/mlp.py RENAMED
@@ -8,13 +8,22 @@ from matplotlib.lines import Line2D
8
  import math
9
 
10
  def get_lr(optimizer):
 
11
  for param_group in optimizer.param_groups:
12
  return param_group['lr']
13
 
14
  class MLP(nn.Module):
15
  def __init__(self, f_i: int, f_o: int, act_fn: object = nn.SELU, f=[], insert_in=[4], freq_encoding=False):
16
-
17
  super().__init__()
 
 
 
 
 
 
 
 
 
18
 
19
  self.insert_in = insert_in
20
  self.layers = nn.ModuleList()
@@ -63,6 +72,7 @@ class MLP(nn.Module):
63
  return x
64
 
65
  def one_epoch_mlp(mlp, epoch, loader, optimizer, device, is_train=False):
 
66
  running_loss = 0.
67
  counter = 1
68
  loss_fn = torch.nn.L1Loss() #reduction="none")
@@ -94,6 +104,7 @@ def one_epoch_mlp(mlp, epoch, loader, optimizer, device, is_train=False):
94
  return running_loss/counter
95
 
96
  def one_epoch_mlp_lbfgs(mlp, epoch, loader, optimizer, device, is_train=False):
 
97
  running_loss = 0.
98
  counter = 1
99
  loss_fn = torch.nn.L1Loss() #reduction="none")
@@ -181,6 +192,7 @@ class Siren(nn.Module):
181
  out = self.dropout(out)
182
  return out
183
 
 
184
  class SirenMLP(nn.Module):
185
 
186
  def __init__(self,
 
8
  import math
9
 
10
  def get_lr(optimizer):
11
+ '''Function to get learning rate'''
12
  for param_group in optimizer.param_groups:
13
  return param_group['lr']
14
 
15
  class MLP(nn.Module):
16
  def __init__(self, f_i: int, f_o: int, act_fn: object = nn.SELU, f=[], insert_in=[4], freq_encoding=False):
 
17
  super().__init__()
18
+ '''
19
+ Feedforward neural network model in Pytorch.
20
+ f_i: input filters
21
+ f_o: output dimension
22
+ act_fn: activate function
23
+ f: list of filters in hidden layers
24
+ insert_in: if insert inputs into layer numbers
25
+ freq_encoding: if use frequency encoding
26
+ '''
27
 
28
  self.insert_in = insert_in
29
  self.layers = nn.ModuleList()
 
72
  return x
73
 
74
  def one_epoch_mlp(mlp, epoch, loader, optimizer, device, is_train=False):
75
+ '''Function to run an epoch in Pytorch with a standard optimizer'''
76
  running_loss = 0.
77
  counter = 1
78
  loss_fn = torch.nn.L1Loss() #reduction="none")
 
104
  return running_loss/counter
105
 
106
  def one_epoch_mlp_lbfgs(mlp, epoch, loader, optimizer, device, is_train=False):
107
+ '''Function to run an epoch in Pytorch with a lbfgs'''
108
  running_loss = 0.
109
  counter = 1
110
  loss_fn = torch.nn.L1Loss() #reduction="none")
 
192
  out = self.dropout(out)
193
  return out
194
 
195
+ # Siren network
196
  class SirenMLP(nn.Module):
197
 
198
  def __init__(self,
{data → train}/train_profiles_mlp.py RENAMED
@@ -6,8 +6,8 @@ import matplotlib.pyplot as plt
6
  import torch
7
  from mlp import *
8
  import argparse
9
- from datasetio import *
10
- from torch.utils.data import TensorDataset
11
 
12
  import copy
13
  import pickle
@@ -16,8 +16,7 @@ import time
16
  # In[ ]:
17
 
18
 
19
- data_dir = "/plp_scr1/agar_sh/data/TPH/"
20
- nn_dir = "/plp_user/agar_sh/PBML/pytorch/TPH/MLP/trained_networks/"
21
 
22
 
23
  # In[ ]:
@@ -26,7 +25,7 @@ nn_dir = "/plp_user/agar_sh/PBML/pytorch/TPH/MLP/trained_networks/"
26
  run_cell = True
27
  if run_cell:
28
  parser = argparse.ArgumentParser(description='Train mlp')
29
- parser.add_argument("-gpu", "--gpu_number", type=int, help="specify gpu number")
30
  parser.add_argument("-a", "--act_fn", type=str, help ="activation function")
31
  parser.add_argument("-l", "--num_layers", type=int, help ="activation function")
32
  parser.add_argument("-f", "--f_h", type=int, help ="filters")
@@ -53,7 +52,10 @@ f_nn = "mlp_profile_pointwise_" + str(f_h) + "_" + str(num_layers) + "_" + act_f
53
  if not os.path.isdir(nn_dir + f_nn):
54
  os.mkdir(nn_dir + f_nn)
55
 
56
- device = torch.device("cuda:" + str(gpu_number)) if torch.cuda.is_available() else torch.device("cpu")
 
 
 
57
 
58
  epoch = 0
59
  start_lr = 1e-3
@@ -80,7 +82,7 @@ with open(nn_dir + "mlp.txt", 'w') as writer:
80
  dataset = {}
81
  loader = {}
82
  batches = {}
83
- pre = "/plp_user/agar_sh/PBML/pytorch/TPH/MLP/profiles/"
84
  with open(pre + 'x_pointwise.pkl', 'rb') as file:
85
  x_pointwise = pickle.load(file)
86
  with open(pre + 'y_pointwise.pkl', 'rb') as file:
 
6
  import torch
7
  from mlp import *
8
  import argparse
9
+ #from datasetio import *
10
+ from torch.utils.data import TensorDataset, DataLoader
11
 
12
  import copy
13
  import pickle
 
16
  # In[ ]:
17
 
18
 
19
+ nn_dir = "./"
 
20
 
21
 
22
  # In[ ]:
 
25
  run_cell = True
26
  if run_cell:
27
  parser = argparse.ArgumentParser(description='Train mlp')
28
+ parser.add_argument("-gpu", "--gpu_number", type=int, help="specify gpu number", default=-1)
29
  parser.add_argument("-a", "--act_fn", type=str, help ="activation function")
30
  parser.add_argument("-l", "--num_layers", type=int, help ="activation function")
31
  parser.add_argument("-f", "--f_h", type=int, help ="filters")
 
52
  if not os.path.isdir(nn_dir + f_nn):
53
  os.mkdir(nn_dir + f_nn)
54
 
55
+ if gpu_number >=0 and torch.cuda.is_available():
56
+ device = torch.device("cuda:" + str(gpu_number))
57
+ else:
58
+ device = torch.device("cpu")
59
 
60
  epoch = 0
61
  start_lr = 1e-3
 
82
  dataset = {}
83
  loader = {}
84
  batches = {}
85
+ pre = "../data/"
86
  with open(pre + 'x_pointwise.pkl', 'rb') as file:
87
  x_pointwise = pickle.load(file)
88
  with open(pre + 'y_pointwise.pkl', 'rb') as file:
train/trained_networks/mlp_profile_[128]_2_selu/mlp.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47bf6d6c36d171d50ada7c24a10f740164f712c846a6028bf58f1ddd636dabae
3
+ size 267488
train/trained_networks/mlp_profile_[128]_2_selu/mlp.txt ADDED
The diff for this file is too large to render. See raw diff
 
train/trained_networks/mlp_profile_[128]_3_selu/mlp.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:35dbc140c6235cf3c1dcb2d820838d75eaa4cb67fde44069ec7fb9b22c299ce6
3
+ size 400016
train/trained_networks/mlp_profile_[128]_3_selu/mlp.txt ADDED
The diff for this file is too large to render. See raw diff