herilalaina committed on
Commit
b1beb2e
1 Parent(s): 8d12475
Files changed (46) hide show
  1. app.py +101 -0
  2. lcpfn/.ipynb_checkpoints/__init__-checkpoint.py +53 -0
  3. lcpfn/.ipynb_checkpoints/curves-checkpoint.py +277 -0
  4. lcpfn/.ipynb_checkpoints/domhan_prior-checkpoint.py +195 -0
  5. lcpfn/__init__.py +53 -0
  6. lcpfn/__pycache__/__init__.cpython-310.pyc +0 -0
  7. lcpfn/__pycache__/bar_distribution.cpython-310.pyc +0 -0
  8. lcpfn/__pycache__/curves.cpython-310.pyc +0 -0
  9. lcpfn/__pycache__/domhan_prior.cpython-310.pyc +0 -0
  10. lcpfn/__pycache__/encoders.cpython-310.pyc +0 -0
  11. lcpfn/__pycache__/layer.cpython-310.pyc +0 -0
  12. lcpfn/__pycache__/model.cpython-310.pyc +0 -0
  13. lcpfn/__pycache__/positional_encodings.cpython-310.pyc +0 -0
  14. lcpfn/__pycache__/train.cpython-310.pyc +0 -0
  15. lcpfn/__pycache__/train_lcpfn.cpython-310.pyc +0 -0
  16. lcpfn/__pycache__/transformer.cpython-310.pyc +0 -0
  17. lcpfn/__pycache__/utils.cpython-310.pyc +0 -0
  18. lcpfn/bar_distribution.py +269 -0
  19. lcpfn/curves.py +277 -0
  20. lcpfn/decoders.py +30 -0
  21. lcpfn/domhan_prior.py +195 -0
  22. lcpfn/encoders.py +161 -0
  23. lcpfn/initializers.py +9 -0
  24. lcpfn/layer.py +126 -0
  25. lcpfn/model.py +29 -0
  26. lcpfn/positional_encodings.py +70 -0
  27. lcpfn/priors/__init__.py +1 -0
  28. lcpfn/priors/__pycache__/__init__.cpython-310.pyc +0 -0
  29. lcpfn/priors/__pycache__/gp.cpython-310.pyc +0 -0
  30. lcpfn/priors/__pycache__/prior.cpython-310.pyc +0 -0
  31. lcpfn/priors/__pycache__/ridge.cpython-310.pyc +0 -0
  32. lcpfn/priors/__pycache__/utils.cpython-310.pyc +0 -0
  33. lcpfn/priors/binarized_regression.py +19 -0
  34. lcpfn/priors/fast_gp.py +143 -0
  35. lcpfn/priors/fast_gp_mix.py +394 -0
  36. lcpfn/priors/gp.py +69 -0
  37. lcpfn/priors/prior.py +25 -0
  38. lcpfn/priors/pyro.py +41 -0
  39. lcpfn/priors/ridge.py +37 -0
  40. lcpfn/priors/stroke.py +143 -0
  41. lcpfn/priors/utils.py +151 -0
  42. lcpfn/train.py +602 -0
  43. lcpfn/train_lcpfn.py +92 -0
  44. lcpfn/transformer.py +226 -0
  45. lcpfn/utils.py +258 -0
  46. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import lcpfn
6
+ import torch
7
+
8
+ model = lcpfn.LCPFN()
9
+
10
def line_plot_fn(data, cutoff, ci_form):
    """Extrapolate a learning curve with the module-level model and plot it.

    Args:
        data: DataFrame with a ``y`` column. Blank (``""``) cells are only
            accepted as one contiguous run at the end of the table.
        cutoff: Number of leading points passed to the model as context.
        ci_form: Confidence level in percent (e.g. ``90``).

    Returns:
        The matplotlib figure showing the observed curve, the median
        extrapolation and the confidence band.

    Raises:
        gr.Error: If the curve has blanks that are not a trailing run, has
            non-numeric values, or if ``cutoff`` is out of range.
    """
    cutoff = int(cutoff)
    ci = int(ci_form)

    # NOTE: like the original check against index 49, this assumes the default
    # 0..n-1 row index.
    n_rows = len(data)
    empty_values = list(data[data.y == ""].index)
    if empty_values:
        # Blanks in the middle would silently shift later points to earlier
        # time steps, so only a trailing run is a valid "unfinished" curve.
        trailing_run = list(range(n_rows - len(empty_values), n_rows))
        if empty_values != trailing_run:
            raise gr.Error("Please enter a valid learning curve.")
        data = data[data.y != ""]

    if cutoff < 1:
        raise gr.Error(f"Cutoff ({cutoff}) must be at least 1.")
    if len(data) < cutoff:
        raise gr.Error(f"Cutoff ({cutoff}) cannot be greater than the number of data points ({len(data)}).")

    # Work on a private copy: avoids pandas' chained-assignment warning on the
    # filtered slice and never mutates the frame handed in by the caller.
    data = data.copy()
    try:
        data["y"] = data["y"].astype(float)
    except (ValueError, TypeError) as err:
        raise gr.Error("Please enter a valid learning curve.") from err

    n_obs = len(data)
    x = torch.arange(1, 51).unsqueeze(1)
    y = torch.from_numpy(data.y.values).float().unsqueeze(1)

    rest_prob = (1 - (ci / 100)) / 2
    predictions = model.predict_quantiles(
        x_train=x[:cutoff],
        y_train=y[:cutoff],
        x_test=x[(cutoff - 1):],
        qs=[rest_prob, 0.5, 1 - rest_prob],
    )

    fig, ax = plt.subplots()

    # Only as many x positions as there are observed values; the curve may be
    # shorter than 50 points once trailing blanks are dropped.
    ax.plot(x[:n_obs].flatten(), data.y, "black", label="target")

    # plot extrapolation
    x_future = x[(cutoff - 1):].flatten()
    ax.plot(x_future, predictions[:, 1], "blue", label="Extrapolation by PFN")
    ax.fill_between(
        x_future,
        predictions[:, 0],
        predictions[:, 2],
        color="blue",
        alpha=0.2,
        label=f"CI of {ci}%",  # reflect the level actually requested
    )

    # plot cutoff
    ax.vlines(cutoff, 0, 1, linewidth=0.5, color="k", label="cutoff", linestyles="dashed")
    ax.set_ylim(0, 1)
    ax.set_xlim(0, 50)
    ax.legend(loc="lower right")
    ax.set_xlabel("t")
    ax.set_ylabel("y")

    return fig
55
+
56
# Build ten synthetic example curves for the demo. Each example is
# [DataFrame with columns t and y, default cutoff].
examples = []
for _ in range(10):
    prior = lcpfn.sample_from_prior(np.random)
    curve, noisy_curve = prior()
    # Show the noisy variant of the curve about half of the time.
    if np.random.rand() < 0.5:
        curve = noisy_curve
    df = pd.DataFrame({"t": list(range(1, 50 + 1)), "y": curve[:50]})
    examples.append([df[["t", "y"]], 10])
68
+
69
# Stand-alone column holding two numeric inputs (created outside the Blocks app).
with gr.Column() as components:
    gr.Number(value=10)
    gr.Number(value=10)

# Main demo: editable curve table + controls on the left, plot on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            dataform = gr.Dataframe(
                value=examples[0][0],
                headers=["t", "y"],
                datatype=["number", "number"],
                row_count=(50, "fixed"),
                col_count=(2, "fixed"),
                type="pandas",
            )
            with gr.Row():
                cutoffform = gr.Number(label="cutoff", value=10)
                ci_form = gr.Dropdown(
                    label="Confidence Interval",
                    choices=[("90%", 90), ("95%", 95), ("99%", 99)],
                    value=90,
                )
            btn = gr.Button("Run")
        outputform = gr.Plot()
    # Re-run the extrapolation whenever the button is pressed.
    btn.click(
        fn=line_plot_fn,
        inputs=[dataform, cutoffform, ci_form],
        outputs=outputform,
    )
    gr.Examples(examples, inputs=[dataform], label="Examples of synthetic learning curves")


if __name__ == "__main__":
    demo.launch()
101
+
lcpfn/.ipynb_checkpoints/__init__-checkpoint.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ sys.path.insert(0, os.path.dirname(__file__))
3
+
4
+
5
+ model_path = 'trained_models'
6
+
7
def prepare_models():
    """Make sure the pretrained weight files exist next to this package.

    For every known model the function looks for the ``.pt`` file; if it is
    missing it looks for (or downloads) the ``.pt.gz`` archive and
    decompresses it, keeping the archive. Progress and failures are reported
    with ``print``; download and decompression errors do not raise.
    """
    import gzip
    import shutil

    pfns4bo_dir = os.path.dirname(__file__)
    model_names = ['pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt',
                   'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt']

    for name in model_names:
        weights_path = os.path.join(pfns4bo_dir, model_path, name)
        compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + '.gz')
        if not os.path.exists(weights_path):
            if not os.path.exists(compressed_weights_path):
                print("Downloading", os.path.abspath(compressed_weights_path))
                import requests
                url = f'https://github.com/automl/lcpfn/raw/main/lcpfn/trained_models/{name + ".gz"}'
                try:
                    r = requests.get(url, allow_redirects=True, timeout=60)
                    # Without this an HTTP error page would be saved as the archive.
                    r.raise_for_status()
                except requests.RequestException as err:
                    print("Download failed:", err)
                else:
                    os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
                    with open(compressed_weights_path, 'wb') as f:
                        f.write(r.content)
            if os.path.exists(compressed_weights_path):
                print("Unzipping", name)
                # Decompress in-process instead of shelling out to `gzip -dk`:
                # works for paths with spaces/shell metacharacters and on
                # systems without a gzip binary. The archive is kept.
                try:
                    with gzip.open(compressed_weights_path, 'rb') as src, \
                            open(weights_path, 'wb') as dst:
                        shutil.copyfileobj(src, dst)
                except (OSError, EOFError) as err:
                    print("Failed to unzip", compressed_weights_path, "-", err)
                    # Do not leave a truncated weights file behind.
                    if os.path.exists(weights_path):
                        os.remove(weights_path)
            else:
                print("Failed to find", compressed_weights_path)
                print("Make sure you have an internet connection to download the model automatically..")
        if os.path.exists(weights_path):
            print("Successfully located model at", weights_path)
32
+
33
+
34
# Public model name -> absolute path of its weights file inside the package.
model_dict = dict(
    (
        ('EMSIZE512_NLAYERS12_NBUCKETS1000',
         os.path.join(os.path.dirname(__file__), model_path,
                      'pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt')),
        ('EMSIZE512_NLAYERS6_NBUCKETS1000',
         os.path.join(os.path.dirname(__file__), model_path,
                      'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt')),
    )
)


def __getattr__(name):
    """Module-level attribute hook resolving model names to weight paths.

    Names listed in ``model_dict`` return the path of the weights file,
    calling ``prepare_models()`` first when that file does not exist yet.
    Any other name raises ``AttributeError``.
    """
    if name not in model_dict:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    weights_file = model_dict[name]
    if not os.path.exists(weights_file):
        print("Can't find", os.path.abspath(weights_file), "thus unzipping/downloading models now.")
        print("This might take a while..")
        prepare_models()
    return weights_file
50
+
51
+ from lcpfn.model import LCPFN
52
+ from lcpfn.train_lcpfn import train_lcpfn
53
+ from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func
lcpfn/.ipynb_checkpoints/curves-checkpoint.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from collections import OrderedDict
3
+
4
def _uniform(param1, param2):
    """Build the spec dict of a ``"uniform"`` prior entry."""
    return {"type": "uniform", "param1": param1, "param2": param2}


def _log_normal(param1, param2):
    """Build the spec dict of a ``"log_normal"`` prior entry."""
    return {"type": "log_normal", "param1": param1, "param2": param2}


# Per curve family, two named parameter priors ("uniform" and "peaked").
# Each entry maps a parameter name to its prior spec.
prior = {
    "pow3": {
        "uniform": OrderedDict([
            ("a", _uniform(-1, 1)),
            ("c", _uniform(0, 1)),
            ("alpha", _uniform(0, 1)),
        ]),
        "peaked": OrderedDict([
            ("a", _uniform(-0.6, 0.6)),
            ("c", _uniform(0, 1.25)),
            ("alpha", _log_normal(0, 2)),
        ]),
    },
    "ilog2": {
        "uniform": OrderedDict([
            ("c", _uniform(0, 1)),
            ("a", _uniform(-1, 1)),
        ]),
        "peaked": OrderedDict([
            ("c", _uniform(0, 1)),
            ("a", _uniform(-0.5, 0.5)),
        ]),
    },
    "janoschek": {
        "uniform": OrderedDict([
            ("a", _uniform(0, 1)),
            ("beta", _uniform(0, 2)),
            ("k", _uniform(0, 1)),
            ("delta", _uniform(-5, 5)),
        ]),
        "peaked": OrderedDict([
            ("a", _uniform(0, 1)),
            ("beta", _uniform(0, 2)),
            ("k", _log_normal(-2, 1)),
            ("delta", _log_normal(0, 0.5)),
        ]),
    },
}
42
+
43
+
44
def prior_sampler(rng, type, param1, param2):
    """Draw one value from the prior described by ``type``.

    Args:
        rng: Object providing ``uniform`` and ``lognormal`` methods.
        type: ``"uniform"`` or ``"log_normal"``. The name shadows the builtin
            but is kept so existing callers keep working.
        param1: First argument forwarded to the sampling method.
        param2: Second argument forwarded to the sampling method.

    Returns:
        The sampled value.

    Raises:
        ValueError: If ``type`` is not a known prior type.
    """
    if type == "uniform":
        return rng.uniform(param1, param2)
    if type == "log_normal":
        return rng.lognormal(param1, param2)
    # ValueError is a subclass of Exception, so existing handlers still match.
    raise ValueError("Unknown prior type: {}".format(type))
50
+
51
+
52
def pow3(x, c, a, alpha):
    """Evaluate ``c - a * x ** (-alpha)``."""
    decay = x ** (-alpha)
    return c - a * decay
54
+
55
+
56
def prior_pow3(rng):
    """Sample the pow3 parameters ``a``, ``c``, ``alpha`` (in that order) from the "peaked" prior."""
    spec = prior["pow3"]["peaked"]
    sampled = {}
    for name in ("a", "c", "alpha"):
        entry = spec[name]
        sampled[name] = prior_sampler(
            rng, entry["type"], param1=entry["param1"], param2=entry["param2"]
        )
    return sampled
66
+
67
+
68
def uniform_prior_pow3(rng):
    """Sample the pow3 parameters ``a``, ``c``, ``alpha`` (in that order) from the "uniform" prior."""
    spec = prior["pow3"]["uniform"]
    sampled = {}
    for name in ("a", "c", "alpha"):
        entry = spec[name]
        sampled[name] = prior_sampler(
            rng, entry["type"], param1=entry["param1"], param2=entry["param2"]
        )
    return sampled
78
+
79
+
80
def ilog2(x, c, a):
    """Evaluate ``c - a / log(x + 1)``."""
    log_term = np.log(x + 1)
    return c - a / log_term
82
+
83
+
84
def prior_ilog2(rng):
    """Sample the ilog2 parameters ``a`` then ``c`` from the "peaked" prior."""
    spec = prior["ilog2"]["peaked"]
    sampled = {}
    for name in ("a", "c"):
        entry = spec[name]
        sampled[name] = prior_sampler(
            rng, entry["type"], param1=entry["param1"], param2=entry["param2"]
        )
    return sampled
94
+
95
+
96
def uniform_prior_ilog2(rng):
    """Sample the ilog2 parameters ``a`` then ``c`` from the "uniform" prior."""
    spec = prior["ilog2"]["uniform"]
    sampled = {}
    for name in ("a", "c"):
        entry = spec[name]
        sampled[name] = prior_sampler(
            rng, entry["type"], param1=entry["param1"], param2=entry["param2"]
        )
    return sampled
106
+
107
+
108
def janoschek(x, a, beta, k, delta):
    """
    Janoschek curve: ``a - (a - beta) * exp(-k * x**delta)``.

    http://www.pisces-conservation.com/growthhelp/janoschek.htm
    """
    damping = np.exp(-k * x**delta)
    return a - (a - beta) * damping
113
+
114
+
115
def prior_janoschek(rng):
    """Sample the janoschek parameters ``a``, ``beta``, ``k``, ``delta`` (in that order) from the "peaked" prior."""
    spec = prior["janoschek"]["peaked"]
    sampled = {}
    for name in ("a", "beta", "k", "delta"):
        entry = spec[name]
        sampled[name] = prior_sampler(
            rng, entry["type"], param1=entry["param1"], param2=entry["param2"]
        )
    return sampled
125
+
126
+
127
def uniform_prior_janoschek(rng):
    """Sample the janoschek parameters ``a``, ``beta``, ``k``, ``delta`` (in that order) from the "uniform" prior."""
    spec = prior["janoschek"]["uniform"]
    sampled = {}
    for name in ("a", "beta", "k", "delta"):
        entry = spec[name]
        sampled[name] = prior_sampler(
            rng, entry["type"], param1=entry["param1"], param2=entry["param2"]
        )
    return sampled
137
+
138
+
139
def log_power(x, a, b, c):
    """Evaluate ``a / (1 + (x / e**b) ** c)``.

    a: upper bound
    c: growth rate
    value at x = 1: a / (1 + (1 / e**b) ** c)
    """
    scaled = x / np.exp(b)
    return a / (1.0 + scaled ** c)
144
+
145
+
146
def prior_log_power(rng):
    """Sample log_power parameters: a ~ N(0.8, 0.1), b ~ N(1, 1), c ~ U(-3, 0)."""
    upper = rng.normal(0.8, 0.1)
    shift = rng.normal(1.0, 1.0)
    rate = rng.uniform(-3.0, 0.0)
    return dict(a=upper, b=shift, c=rate)
154
+
155
+
156
def weibull(x, alpha, beta, kappa, delta):
    """
    Weibull model: ``alpha - (alpha - beta) * exp(-(kappa * x) ** delta)``.

    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote
    beta: lower asymptote
    kappa: growth rate
    delta: controls the x-ordinate for the point of inflection
    """
    exponent = -((kappa * x) ** delta)
    return alpha - (alpha - beta) * np.exp(exponent)
166
+
167
+
168
def prior_weibull(rng):
    """Sample weibull parameters.

    alpha ~ U(0, 1.5), beta ~ U(0, 1), ln(kappa) ~ N(-2, 1), ln(delta) ~ N(0, 0.5).
    """
    alpha = rng.uniform(0.0, 1.5)
    beta = rng.uniform(0.0, 1)
    log_kappa = rng.normal(-2.0, 1.0)
    log_delta = rng.normal(0, 0.5)
    return dict(alpha=alpha, beta=beta, kappa=np.exp(log_kappa), delta=np.exp(log_delta))
174
+
175
+
176
def mmf(x, alpha, beta, kappa, delta):
    """
    Morgan-Mercer-Flodin curve:
    ``alpha - (alpha - beta) / (1 + (kappa * x) ** delta)``.

    description:
    Nonlinear Regression page 342
    http://bit.ly/1jodG17
    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote
    kappa: growth rate
    beta: initial value
    delta: controls the point of inflection
    """
    growth = (kappa * x) ** delta
    return alpha - (alpha - beta) / (1.0 + growth)
189
+
190
+
191
def prior_mmf(rng):
    """Sample mmf parameters.

    alpha ~ N(0.8, 0.1), beta ~ N(0.2, 0.1), ln(kappa) ~ N(0, 2), ln(delta) ~ N(0, 1).
    """
    alpha = rng.normal(0.8, 0.1)
    beta = rng.normal(0.2, 0.1)
    log_kappa = rng.normal(0, 2)
    log_delta = rng.normal(0, 1)
    return dict(alpha=alpha, beta=beta, kappa=np.exp(log_kappa), delta=np.exp(log_delta))
201
+
202
+
203
def vap(x, a, b, c):
    """Vapor pressure model: ``exp(a + b / x + c * log(x))``.

    no upper bound if c > 0
    a = ln(upper bound) for c=0
    a+b = ln(initial)
    """
    log_value = a + b / x + c * np.log(x)
    return np.exp(log_value)
209
+
210
+
211
def prior_vap(rng):
    """Sample vap parameters: a ~ U(-2, 0), b ~ U(-4, 0), ln(c) ~ U(-8, 0)."""
    a = rng.uniform(-2.0, 0.0)  # @heri: range check
    b = rng.uniform(-4.0, 0.0)  # @heri: range check
    log_c = rng.uniform(-8.0, 0.0)  # @heri: same as weights
    return dict(a=a, b=b, c=np.exp(log_c))
216
+
217
+
218
def loglog_linear(x, a, b):
    """Evaluate ``log(a * log(x) + b)``."""
    log_x = np.log(x)
    return np.log(a * log_x + b)
221
+
222
+
223
def prior_loglog_linear(rng):
    """Sample loglog_linear parameters: ln(a) ~ N(-2, 1), ln(b) ~ U(0, 1)."""
    log_a = rng.normal(-2.0, 1.0)
    log_b = rng.uniform(0.0, 1.0)
    return dict(a=np.exp(log_a), b=np.exp(log_b))
229
+
230
+
231
def exp4(x, c, a, b, alpha):
    """Evaluate ``c - exp(-a * x**alpha + b)``."""
    exponent = -a * (x**alpha) + b
    return c - np.exp(exponent)
233
+
234
+
235
def prior_exp4(rng):
    """Sample exp4 parameters, drawn in the order c, a, alpha, b.

    c ~ N(0.8, 0.1), ln(a) ~ N(-2, 1), ln(alpha) ~ N(0, 1), ln(b) ~ N(0, 0.5).
    """
    c = rng.normal(0.8, 0.1)
    log_a = rng.normal(-2, 1)
    log_alpha = rng.normal(0, 1)
    log_b = rng.normal(0, 0.5)
    return dict(a=np.exp(log_a), b=np.exp(log_b), c=c, alpha=np.exp(log_alpha))
245
+
246
+
247
def pow4(x, c, a, b, alpha):
    """Evaluate ``c - (a * x + b) ** -alpha``."""
    base = a * x + b
    return c - base ** -alpha
249
+
250
+
251
def prior_pow4(rng):
    """Sample pow4 parameters, drawn in the order c, a, alpha, b.

    ln(1 - c) ~ U(-5, 0), ln(a) ~ N(-3, 2), ln(alpha) ~ N(0, 1), ln(b) ~ U(0, 1).
    """
    log_one_minus_c = rng.uniform(-5.0, 0)
    log_a = rng.normal(-3.0, 2)
    log_alpha = rng.normal(0, 1)
    log_b = rng.uniform(0, 1)
    return dict(
        a=np.exp(log_a),
        b=np.exp(log_b),
        c=1 - np.exp(log_one_minus_c),
        alpha=np.exp(log_alpha),
    )
261
+
262
+
263
def dr_hill_zero_background(x, theta, eta, kappa):
    """Evaluate ``theta * x**eta / (kappa**eta + x**eta)``.

    theta: upper bound
    eta: growth rate
    value at x = 1: theta / (kappa**eta + 1)
    """
    x_pow = x**eta
    return (theta * x_pow) / (kappa**eta + x_pow)
268
+
269
+
270
def prior_dr_hill_zero_background(rng):
    """Sample dr_hill_zero_background parameters.

    theta ~ N(0.8, 0.1), ln(eta) ~ N(1, 1), ln(kappa) ~ N(1, 2).
    """
    theta = rng.normal(0.8, 0.1)
    log_eta = rng.normal(1.0, 1.0)
    log_kappa = rng.normal(1.0, 2.0)
    return dict(theta=theta, eta=np.exp(log_eta), kappa=np.exp(log_kappa))
lcpfn/.ipynb_checkpoints/domhan_prior-checkpoint.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import torch
3
+ import numpy as np
4
+ from lcpfn.curves import (
5
+ pow3,
6
+ ilog2,
7
+ janoschek,
8
+ log_power,
9
+ prior_ilog2,
10
+ uniform_prior_pow3,
11
+ weibull,
12
+ mmf,
13
+ vap,
14
+ loglog_linear,
15
+ exp4,
16
+ pow4,
17
+ dr_hill_zero_background,
18
+ )
19
+ from lcpfn.curves import (
20
+ prior_pow3,
21
+ prior_janoschek,
22
+ prior_log_power,
23
+ prior_weibull,
24
+ prior_mmf,
25
+ prior_vap,
26
+ prior_loglog_linear,
27
+ prior_exp4,
28
+ prior_pow4,
29
+ prior_dr_hill_zero_background,
30
+ )
31
+ from lcpfn.curves import (
32
+ uniform_prior_pow3,
33
+ uniform_prior_ilog2,
34
+ uniform_prior_janoschek,
35
+ )
36
+
37
+
38
def prior_weights(
    rng,
    components=(
        "pow3",
        "ilog2",
        "janoschek",
        "log_power",
        "weibull",
        "mmf",
        "vap",
        "loglog_linear",
        "exp4",
        "pow4",
        "dr_hill_zero_background",
    ),
):
    """Draw one weight in U(0, 1) for every component name.

    Args:
        rng: Object providing ``uniform(low, high, size=...)``.
        components: Sequence of component names. The default is an immutable
            tuple so it cannot be modified between calls.

    Returns:
        Dict mapping each component name to its sampled weight, in the order
        of ``components``.
    """
    K = len(components)
    weights = rng.uniform(0.0, 1, size=(K,))
    return {f: weights[i] for i, f in enumerate(components)}
57
+
58
+
59
def sample_from_prior(rng, seq_len=100):
    """Sample a curve generator combining pow3, ilog2 and janoschek with the "peaked" priors."""
    basis = ["pow3", "ilog2", "janoschek"]
    return sample_prior_comb(
        rng=rng,
        seq_len=seq_len,
        components=basis,
        distribution="peaked",
    )
63
+
64
+
65
def sample_prior_comb(
    rng,
    components,
    distribution,
    var_lnloc=-4,
    var_lnscale=1,
    range_constraint=True,
    seq_len=100,
):
    """Sample a weighted combination of basis curves plus a noise level.

    Candidates are drawn until one is accepted; a candidate is rejected when
    its last value is not above its first, when it contains NaN, or (with
    ``range_constraint``) when it leaves ``[0, 1]``.

    Args:
        rng: Random source providing ``uniform`` and ``normal`` (plus
            whatever the selected parameter priors call).
        components: Names of the basis curves to combine.
        distribution: ``"peaked"`` or ``"uniform"``; selects the parameter
            priors. ``"uniform"`` only provides pow3, ilog2 and janoschek.
        var_lnloc: Mean of the normal from which the log noise level is drawn.
        var_lnscale: Standard deviation of that normal.
        range_constraint: Whether to reject curves leaving ``[0, 1]``.
        seq_len: Number of points; curves are evaluated at ``1..seq_len``.

    Returns:
        A zero-argument function returning ``(y, y_noisy)``: the noiseless
        curve and a copy with fresh independent normal noise.

    Raises:
        NotImplementedError: If ``distribution`` is not recognised.
        KeyError: If a component has no prior for the chosen distribution.
    """
    f_components = {
        "pow3": pow3,
        "ilog2": ilog2,
        "janoschek": janoschek,
        "log_power": log_power,
        "weibull": weibull,
        "mmf": mmf,
        "vap": vap,
        "loglog_linear": loglog_linear,
        "exp4": exp4,
        "pow4": pow4,
        "dr_hill_zero_background": dr_hill_zero_background,
    }

    if distribution == "peaked":
        f_priors = {
            "pow3": prior_pow3,
            "ilog2": prior_ilog2,
            "janoschek": prior_janoschek,
            "log_power": prior_log_power,
            "weibull": prior_weibull,
            "mmf": prior_mmf,
            "vap": prior_vap,
            "loglog_linear": prior_loglog_linear,
            "exp4": prior_exp4,
            "pow4": prior_pow4,
            "dr_hill_zero_background": prior_dr_hill_zero_background,
        }
    elif distribution == "uniform":
        f_priors = {
            "pow3": uniform_prior_pow3,
            "ilog2": uniform_prior_ilog2,
            "janoschek": uniform_prior_janoschek,
        }
    else:
        # `NotImplemented` is a non-callable constant; the exception class is
        # `NotImplementedError`.
        raise NotImplementedError(f"Unknown distribution: {distribution!r}")

    # Fail early with a readable message instead of a bare KeyError mid-loop.
    unsupported = [f for f in components if f not in f_priors]
    if unsupported:
        raise KeyError(
            f"No '{distribution}' prior available for components: {unsupported}"
        )

    x = np.arange(1, seq_len + 1)

    while True:
        # sample the noiseless curve
        weights = prior_weights(rng, components=components)
        y = np.zeros(x.shape, dtype="float")
        for f, w in weights.items():
            kwargs = f_priors[f](rng)
            y += w * f_components[f](x, **kwargs)
        # add noise (can exceed [0,1], but afaik no way to implement this prior in Tobis work)
        # NOTE: despite its name, `var` is passed to rng.normal as the scale.
        var = np.exp(
            rng.normal(var_lnloc, var_lnscale)
        )  # @heri: ln_prob =+ log(normal.pdf(log(var), loc=var_lnloc, scale=var_lnscale))

        # reject any curves that are non-increasing, exceed the [0,1] range
        rejected = (
            y[-1] <= y[0]
            or (range_constraint and (np.any(y < 0) or np.any(y > 1)))
            or np.isnan(y).any()
        )
        if not rejected:
            break

    def curve():  # generates a sample from the same model, but with independent noise
        y_noisy = y + rng.normal(np.zeros_like(y), var)
        return y, y_noisy

    return curve
142
+
143
+
144
def generate_prior_dataset(n, prior=None, seed=42):
    """
    Returns a fixed sample from the prior (with fixed seq_len) as an n x seq_len np.ndarray

    Args:
        n: Number of curves to draw.
        prior: Callable ``prior(rng)`` returning a zero-argument curve
            function whose second return value is the noisy curve. Defaults
            to ``sample_from_prior``; the previous default,
            ``sample_prior_comb``, cannot be called with only ``rng``.
        seed: Seed of the ``np.random.RandomState`` used for all draws.

    Returns:
        Array stacking the ``n`` noisy curves row-wise.
    """
    if prior is None:
        prior = sample_from_prior
    rng = np.random.RandomState(seed)
    prior_data = np.stack([prior(rng)()[1] for _ in range(n)])
    return prior_data
151
+
152
+
153
def create_get_batch_func(prior):
    """Return ``get_batch_domhan`` with its ``prior`` argument bound to ``prior``."""
    batch_fn = partial(get_batch_domhan, prior=prior)
    return batch_fn
155
+
156
+ # function producing batches for PFN training
157
def get_batch_domhan(
    batch_size,
    seq_len,
    num_features,
    prior,
    device="cpu",
    noisy_target=True,
    **_,
):
    """Produce one training batch of synthetic learning curves.

    Args:
        batch_size: Number of curves in the batch.
        seq_len: Number of points per curve.
        num_features: Must be 1.
        prior: Callable ``prior(rng, seq_len=...)`` returning a zero-argument
            function that yields ``(y, y_noisy)``; it is called with the
            global ``np.random`` module as ``rng``.
        device: Torch device the returned tensors are moved to.
        noisy_target: If True the target equals the noisy curve, otherwise
            the noiseless curve is used as target.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        Tuple ``(x, y_noisy, y_target)`` of float tensors with shapes
        ``(seq_len, batch_size, num_features)``, ``(seq_len, batch_size)``
        and ``(seq_len, batch_size)``.
    """
    assert num_features == 1, "get_batch_domhan only supports num_features == 1"

    y_target = np.empty((batch_size, seq_len), dtype=float)
    y_noisy = np.empty((batch_size, seq_len), dtype=float)

    for i in range(batch_size):
        curve_func = prior(np.random, seq_len=seq_len)  # uses numpy rng
        clean, noisy = curve_func()
        y_noisy[i] = noisy
        y_target[i] = noisy if noisy_target else clean

    # turn numpy arrays into correctly shaped torch tensors & move them to device
    x = (
        torch.arange(1, seq_len + 1)
        .repeat((num_features, batch_size, 1))
        .transpose(2, 0)
        .to(device)
    )
    y_target = torch.from_numpy(y_target).transpose(1, 0).to(device)
    y_noisy = torch.from_numpy(y_noisy).transpose(1, 0).to(device)

    # cast everything to float32
    x = x.float()
    y_target = y_target.float()
    y_noisy = y_noisy.float()

    return x, y_noisy, y_target
lcpfn/__init__.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ sys.path.insert(0, os.path.dirname(__file__))
3
+
4
+
5
+ model_path = 'trained_models'
6
+
7
def prepare_models():
    """Make sure the pretrained weight files exist next to this package.

    For every known model the function looks for the ``.pt`` file; if it is
    missing it looks for (or downloads) the ``.pt.gz`` archive and
    decompresses it, keeping the archive. Progress and failures are reported
    with ``print``; download and decompression errors do not raise.
    """
    import gzip
    import shutil

    pfns4bo_dir = os.path.dirname(__file__)
    model_names = ['pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt',
                   'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt']

    for name in model_names:
        weights_path = os.path.join(pfns4bo_dir, model_path, name)
        compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + '.gz')
        if not os.path.exists(weights_path):
            if not os.path.exists(compressed_weights_path):
                print("Downloading", os.path.abspath(compressed_weights_path))
                import requests
                url = f'https://github.com/automl/lcpfn/raw/main/lcpfn/trained_models/{name + ".gz"}'
                try:
                    r = requests.get(url, allow_redirects=True, timeout=60)
                    # Without this an HTTP error page would be saved as the archive.
                    r.raise_for_status()
                except requests.RequestException as err:
                    print("Download failed:", err)
                else:
                    os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
                    with open(compressed_weights_path, 'wb') as f:
                        f.write(r.content)
            if os.path.exists(compressed_weights_path):
                print("Unzipping", name)
                # Decompress in-process instead of shelling out to `gzip -dk`:
                # works for paths with spaces/shell metacharacters and on
                # systems without a gzip binary. The archive is kept.
                try:
                    with gzip.open(compressed_weights_path, 'rb') as src, \
                            open(weights_path, 'wb') as dst:
                        shutil.copyfileobj(src, dst)
                except (OSError, EOFError) as err:
                    print("Failed to unzip", compressed_weights_path, "-", err)
                    # Do not leave a truncated weights file behind.
                    if os.path.exists(weights_path):
                        os.remove(weights_path)
            else:
                print("Failed to find", compressed_weights_path)
                print("Make sure you have an internet connection to download the model automatically..")
        if os.path.exists(weights_path):
            print("Successfully located model at", weights_path)
32
+
33
+
34
# Public model name -> absolute path of its weights file inside the package.
model_dict = dict(
    (
        ('EMSIZE512_NLAYERS12_NBUCKETS1000',
         os.path.join(os.path.dirname(__file__), model_path,
                      'pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt')),
        ('EMSIZE512_NLAYERS6_NBUCKETS1000',
         os.path.join(os.path.dirname(__file__), model_path,
                      'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt')),
    )
)


def __getattr__(name):
    """Module-level attribute hook resolving model names to weight paths.

    Names listed in ``model_dict`` return the path of the weights file,
    calling ``prepare_models()`` first when that file does not exist yet.
    Any other name raises ``AttributeError``.
    """
    if name not in model_dict:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    weights_file = model_dict[name]
    if not os.path.exists(weights_file):
        print("Can't find", os.path.abspath(weights_file), "thus unzipping/downloading models now.")
        print("This might take a while..")
        prepare_models()
    return weights_file
50
+
51
+ from lcpfn.model import LCPFN
52
+ from lcpfn.train_lcpfn import train_lcpfn
53
+ from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func
lcpfn/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.03 kB). View file
 
lcpfn/__pycache__/bar_distribution.cpython-310.pyc ADDED
Binary file (9.96 kB). View file
 
lcpfn/__pycache__/curves.cpython-310.pyc ADDED
Binary file (6.81 kB). View file
 
lcpfn/__pycache__/domhan_prior.cpython-310.pyc ADDED
Binary file (3.92 kB). View file
 
lcpfn/__pycache__/encoders.cpython-310.pyc ADDED
Binary file (8.02 kB). View file
 
lcpfn/__pycache__/layer.cpython-310.pyc ADDED
Binary file (4.64 kB). View file
 
lcpfn/__pycache__/model.cpython-310.pyc ADDED
Binary file (1.8 kB). View file
 
lcpfn/__pycache__/positional_encodings.cpython-310.pyc ADDED
Binary file (2.86 kB). View file
 
lcpfn/__pycache__/train.cpython-310.pyc ADDED
Binary file (13.5 kB). View file
 
lcpfn/__pycache__/train_lcpfn.cpython-310.pyc ADDED
Binary file (2.82 kB). View file
 
lcpfn/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (8.04 kB). View file
 
lcpfn/__pycache__/utils.cpython-310.pyc ADDED
Binary file (10.7 kB). View file
 
lcpfn/bar_distribution.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class BarDistribution(nn.Module):
6
+ def __init__(self, borders: torch.Tensor, smoothing=.0): # here borders should start with min and end with max, where all values lie in (min,max) and are sorted
7
+ # sorted list of borders
8
+ super().__init__()
9
+ assert len(borders.shape) == 1
10
+ #self.borders = borders
11
+ self.register_buffer('borders', borders)
12
+ self.register_buffer('smoothing', torch.tensor(smoothing))
13
+ #self.bucket_widths = self.borders[1:] - self.borders[:-1]
14
+ self.register_buffer('bucket_widths', self.borders[1:] - self.borders[:-1])
15
+ full_width = self.bucket_widths.sum()
16
+ border_order = torch.argsort(borders)
17
+ assert (full_width - (self.borders[-1] - self.borders[0])).abs() < 1e-4, f'diff: {full_width - (self.borders[-1] - self.borders[0])}'
18
+ assert (border_order == torch.arange(len(borders)).to(border_order.device)).all(), "Please provide sorted borders!"
19
+ self.num_bars = len(borders) - 1
20
+
21
+ def map_to_bucket_idx(self, y):
22
+ target_sample = torch.searchsorted(self.borders, y) - 1
23
+ target_sample[y == self.borders[0]] = 0
24
+ target_sample[y == self.borders[-1]] = self.num_bars - 1
25
+ return target_sample
26
+
27
+ def forward(self, logits, y): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
28
+ target_sample = self.map_to_bucket_idx(y)
29
+ assert (target_sample >= 0).all() and (target_sample < self.num_bars).all(), f'y {y} not in support set for borders (min_y, max_y) {self.borders}'
30
+ assert logits.shape[-1] == self.num_bars, f'{logits.shape[-1]} vs {self.num_bars}'
31
+
32
+ bucket_log_probs = torch.log_softmax(logits, -1)
33
+ scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
34
+ #print(bucket_log_probs, logits.shape)
35
+
36
+ nll_loss = -scaled_bucket_log_probs.gather(-1,target_sample.unsqueeze(-1)).squeeze(-1)
37
+
38
+ smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
39
+ smoothing = self.smoothing if self.training else 0.
40
+ loss = (1. - smoothing) * nll_loss + smoothing * smooth_loss
41
+ return loss
42
+
43
+ def mean(self, logits):
44
+ bucket_means = self.borders[:-1] + self.bucket_widths/2
45
+ p = torch.softmax(logits, -1)
46
+ return p @ bucket_means
47
+
48
+
49
+ def icdf(self, logits, left_prob):
50
+ """
51
+ Implementation of the quantile function
52
+ :param logits: Tensor of any shape, with the last dimension being logits
53
+ :param left_prob: float: The probability mass to the left of the result.
54
+ :return: Position with `left_prob` probability weight to the left.
55
+ """
56
+ probs = logits.softmax(-1)
57
+ cumprobs = torch.cumsum(probs, -1)
58
+ idx = torch.searchsorted(cumprobs, left_prob * torch.ones(*cumprobs.shape[:-1], 1, device = probs.device))\
59
+ .squeeze(-1).clamp(0, cumprobs.shape[-1] - 1) # this might not do the right for outliers
60
+ cumprobs = torch.cat(
61
+ [torch.zeros(*cumprobs.shape[:-1], 1, device=logits.device), cumprobs], -1
62
+ )
63
+
64
+ rest_prob = left_prob - cumprobs.gather(-1, idx[..., None]).squeeze(-1)
65
+ left_border = self.borders[idx]
66
+ right_border = self.borders[idx+1]
67
+ return left_border + (right_border - left_border) * rest_prob / probs.gather(-1, idx[..., None]).squeeze(-1)
68
+
69
+ def quantile(self, logits, center_prob=.682):
70
+ side_probs = (1.-center_prob)/2
71
+ return torch.stack((self.icdf(logits, side_probs), self.icdf(logits, 1.-side_probs)),-1)
72
+
73
+ def ucb(self, logits, best_f, rest_prob=(1-.682)/2, maximize=True):
74
+ """
75
+ UCB utility. Rest Prob is the amount of utility above (below) the confidence interval that is ignored.
76
+ Higher rest_prob is equivalent to lower beta in the standard GP-UCB formulation.
77
+ :param logits: Logits, as returned by the Transformer.
78
+ :param best_f: Only here, since the other utilities have it.
79
+ :param rest_prob: The amount of utility above (below) the confidence interval that is ignored.
80
+ The default is equivalent to using GP-UCB with `beta=1`.
81
+ To get the corresponding `beta`, where `beta` is from
82
+ the standard GP definition of UCB `ucb_utility = mean + beta * std`,
83
+ you can use this computation: `beta = math.sqrt(2)*torch.erfinv(torch.tensor(2*rest_prob-1))`.
84
+ :param maximize:
85
+ :return: utility
86
+ """
87
+ if maximize:
88
+ rest_prob = 1 - rest_prob
89
+ return self.icdf(logits, rest_prob)
90
+
91
+ def mode(self, logits):
92
+ mode_inds = logits.argmax(-1)
93
+ bucket_means = self.borders[:-1] + self.bucket_widths/2
94
+ return bucket_means[mode_inds]
95
+
96
+ def ei(self, logits, best_f, maximize=True): # logits: evaluation_points x batch x feature_dim
97
+ bucket_means = self.borders[:-1] + self.bucket_widths/2
98
+ if maximize:
99
+ bucket_contributions = torch.tensor(
100
+ [max((bucket_max + max(bucket_min, best_f)) / 2 - best_f,0) for
101
+ bucket_min, bucket_max, bucket_mean in zip(self.borders[:-1], self.borders[1:], bucket_means)], dtype=logits.dtype, device=logits.device)
102
+ else:
103
+ bucket_contributions = torch.tensor(
104
+ [-min((min(bucket_max,best_f) + bucket_min) / 2 - best_f,0) for # min on max instead of max on min, and compare min < instead of max >
105
+ bucket_min, bucket_max, bucket_mean in zip(self.borders[:-1], self.borders[1:], bucket_means)], dtype=logits.dtype, device=logits.device)
106
+ p = torch.softmax(logits, -1)
107
+ return p @ bucket_contributions
108
+
109
+ def pi(self, logits, best_f, maximize=True):# logits: evaluation_points x batch x feature_dim
110
+ """
111
+ Acquisition Function: Probability of Improvement
112
+ :param logits: as returned by Transformer
113
+ :param best_f: best evaluation so far (the incumbent)
114
+ :param maximize: whether to maximize
115
+ :return: utility
116
+ """
117
+ assert maximize is True
118
+ p = torch.softmax(logits, -1)
119
+ border_widths = self.borders[1:] - self.borders[:-1]
120
+ factor = 1. - ((best_f - self.borders[:-1]) / border_widths).clamp(0., 1.)
121
+ return (p * factor).sum(-1)
122
+
123
+
124
def mean_of_square(self, logits):
    """
    Computes E[x^2].
    :param logits: Output of the model.
    """
    lo = self.borders[:-1]
    hi = self.borders[1:]
    # Per-bucket second moment: (lo^2 + hi^2 + lo*hi) / 3.
    second_moments = (lo.square() + hi.square() + lo*hi)/3.
    bucket_probs = torch.softmax(logits, -1)
    return bucket_probs @ second_moments
134
+
135
def variance(self, logits):
    """Return E[x^2] - E[x]^2 computed from ``mean_of_square`` and ``mean``."""
    second_moment = self.mean_of_square(logits)
    first_moment = self.mean(logits)
    return second_moment - first_moment.square()
137
+
138
+
139
class FullSupportBarDistribution(BarDistribution):
    """Bar distribution whose first and last buckets get half-normal tails.

    The interior buckets are handled through the scaled bucket log
    probabilities; for targets in the first/last bucket a HalfNormal
    log-density, anchored at the inner border of that bucket, is added.
    """

    @staticmethod
    def halfnormal_with_p_weight_before(range_max,p=.5):
        """Return a HalfNormal whose CDF equals ``p`` at ``range_max``."""
        unit_quantile = torch.distributions.HalfNormal(torch.tensor(1.)).icdf(torch.tensor(p))
        return torch.distributions.HalfNormal(range_max / unit_quantile)

    def forward(self, logits, y):  # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
        assert self.num_bars > 1
        target_sample = self.map_to_bucket_idx(y)
        target_sample.clamp_(0,self.num_bars-1)
        assert logits.shape[-1] == self.num_bars

        bucket_log_probs = torch.log_softmax(logits, -1)
        # Divide the bucket mass by the bucket width to get a density.
        scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
        log_probs = scaled_bucket_log_probs.gather(-1,target_sample.unsqueeze(-1)).squeeze(-1)

        left_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[0])
        right_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[-1])

        in_first = target_sample == 0
        in_last = target_sample == self.num_bars-1

        # TODO look over it again
        # Adding log(width) undoes the width scaling applied above for the two
        # outer buckets, where the half-normal density is used instead.
        log_probs[in_first] += left_tail.log_prob((self.borders[1]-y[in_first]).clamp(min=.00000001)) + torch.log(self.bucket_widths[0])
        log_probs[in_last] += right_tail.log_prob(y[in_last]-self.borders[-2]) + torch.log(self.bucket_widths[-1])

        nll_loss = -log_probs

        # Label smoothing is only active in training mode.
        smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
        smoothing = self.smoothing if self.training else 0.
        return (1. - smoothing) * nll_loss + smoothing * smooth_loss

    def mean(self, logits):
        """Return the expected value, using half-normal means for the outer buckets."""
        bucket_means = self.borders[:-1] + self.bucket_widths / 2
        left_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[0])
        right_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[-1])
        bucket_means[0] = self.borders[1] - left_tail.mean
        bucket_means[-1] = self.borders[-2] + right_tail.mean
        bucket_probs = torch.softmax(logits, -1)
        return bucket_probs @ bucket_means
180
+
181
+
182
+
183
def get_bucket_limits_(num_outputs:int, full_range:tuple=None, ys:torch.Tensor=None, verbose:bool=False):
    """Compute ``num_outputs + 1`` bucket borders.

    With ``ys`` the borders are placed so every bucket holds the same number
    of sorted samples (surplus samples at the end are dropped first); without
    ``ys`` the interval ``full_range`` is split into equal-width buckets.

    :param num_outputs: number of buckets.
    :param full_range: ``(min, max)`` of the support; optional when ``ys`` is given.
    :param ys: samples used to estimate the borders.
    :param verbose: additionally print the message and the resolved range.
    :return: 1-D tensor with ``num_outputs + 1`` borders.
    """
    assert (ys is not None) or (full_range is not None)
    if ys is not None:
        ys = ys.flatten()
        # Count the surplus BEFORE truncating; computing it afterwards (as the
        # message used to) always reports 0.
        num_cut = len(ys) % num_outputs
        if num_cut:
            ys = ys[:-num_cut]
        message = f'Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {num_cut} ys.'
        print(message)
        ys_per_bucket = len(ys) // num_outputs
        if full_range is None:
            full_range = (ys.min(), ys.max())
        else:
            assert full_range[0] <= ys.min() and full_range[1] >= ys.max()
        full_range = torch.tensor(full_range)
        ys_sorted, _ = ys.sort(0)
        # Each inner border is the midpoint between the last sample of one
        # bucket and the first sample of the next.
        bucket_limits = (ys_sorted[ys_per_bucket-1::ys_per_bucket][:-1]+ys_sorted[ys_per_bucket::ys_per_bucket])/2
        if verbose:
            print(message)
            print(full_range)
        bucket_limits = torch.cat([full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)],0)

    else:
        class_width = (full_range[1] - full_range[0]) / num_outputs
        bucket_limits = torch.cat([full_range[0] + torch.arange(num_outputs).float()*class_width, torch.tensor(full_range[1]).unsqueeze(0)], 0)

    assert len(bucket_limits) - 1 == num_outputs and full_range[0] == bucket_limits[0] and full_range[-1] == bucket_limits[-1]
    return bucket_limits
208
+
209
+
210
def get_bucket_limits(
    num_outputs: int,
    full_range: tuple = None,
    ys: torch.Tensor = None,
    verbose: bool = False,
):
    """Compute ``num_outputs + 1`` bucket borders from samples or from a range.

    Exactly one of ``ys`` and ``full_range`` must be given. With ``ys`` the
    NaN entries are removed, surplus samples at the end are dropped, and the
    borders are placed so every bucket holds the same number of sorted
    samples. With ``full_range`` the interval is split into equal-width buckets.

    :param num_outputs: number of buckets.
    :param full_range: ``(min, max)`` of the support.
    :param ys: samples used to estimate the borders.
    :param verbose: additionally print the message and the resolved range.
    :return: 1-D tensor with ``num_outputs + 1`` borders.
    """
    assert (ys is None) != (
        full_range is None
    ), "Either full_range or ys must be passed."

    if ys is not None:
        ys = ys.flatten()
        ys = ys[~torch.isnan(ys)]
        # Count the surplus BEFORE truncating; computing it afterwards (as the
        # message used to) always reports 0.
        num_cut = len(ys) % num_outputs
        if num_cut:
            ys = ys[:-num_cut]
        message = f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {num_cut} ys."
        print(message)
        ys_per_bucket = len(ys) // num_outputs
        if full_range is None:
            full_range = (ys.min(), ys.max())
        else:
            # Not reachable while the assertion above demands exactly one of
            # ys / full_range; kept for parity with get_bucket_limits_.
            assert (
                full_range[0] <= ys.min() and full_range[1] >= ys.max()
            ), f"full_range {full_range} not in range of ys {ys.min(), ys.max()}"
        full_range = torch.tensor(full_range)
        ys_sorted, _ = ys.sort(0)
        # Each inner border is the midpoint between the last sample of one
        # bucket and the first sample of the next.
        bucket_limits = (
            ys_sorted[ys_per_bucket - 1 :: ys_per_bucket][:-1]
            + ys_sorted[ys_per_bucket::ys_per_bucket]
        ) / 2
        if verbose:
            print(message)
            print(full_range)
        bucket_limits = torch.cat(
            [full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0
        )

    else:
        class_width = (full_range[1] - full_range[0]) / num_outputs
        bucket_limits = torch.cat(
            [
                full_range[0] + torch.arange(num_outputs).float() * class_width,
                torch.tensor(full_range[1]).unsqueeze(0),
            ],
            0,
        )

    assert (
        len(bucket_limits) - 1 == num_outputs
    ), f"len(bucket_limits) - 1 == {len(bucket_limits) - 1} != {num_outputs} == num_outputs"
    assert full_range[0] == bucket_limits[0], f"{full_range[0]} != {bucket_limits[0]}"
    assert (
        full_range[-1] == bucket_limits[-1]
    ), f"{full_range[-1]} != {bucket_limits[-1]}"

    return bucket_limits
269
+
lcpfn/curves.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from collections import OrderedDict
3
+
4
def _uniform_spec(low, high):
    # Spec entry consumed by prior_sampler: uniform on [param1, param2].
    return {"type": "uniform", "param1": low, "param2": high}


def _log_normal_spec(loc, scale):
    # Spec entry consumed by prior_sampler: log-normal with (param1, param2).
    return {"type": "log_normal", "param1": loc, "param2": scale}


# Parameter priors per curve family, in a "uniform" and a "peaked" variant.
prior = {
    "pow3": {
        "uniform": OrderedDict(
            a=_uniform_spec(-1, 1),
            c=_uniform_spec(0, 1),
            alpha=_uniform_spec(0, 1),
        ),
        "peaked": OrderedDict(
            a=_uniform_spec(-0.6, 0.6),
            c=_uniform_spec(0, 1.25),
            alpha=_log_normal_spec(0, 2),
        ),
    },
    "ilog2": {
        "uniform": OrderedDict(
            c=_uniform_spec(0, 1),
            a=_uniform_spec(-1, 1),
        ),
        "peaked": OrderedDict(
            c=_uniform_spec(0, 1),
            a=_uniform_spec(-0.5, 0.5),
        ),
    },
    "janoschek": {
        "uniform": OrderedDict(
            a=_uniform_spec(0, 1),
            beta=_uniform_spec(0, 2),
            k=_uniform_spec(0, 1),
            delta=_uniform_spec(-5, 5),
        ),
        "peaked": OrderedDict(
            a=_uniform_spec(0, 1),
            beta=_uniform_spec(0, 2),
            k=_log_normal_spec(-2, 1),
            delta=_log_normal_spec(0, 0.5),
        ),
    },
}
42
+
43
+
44
def prior_sampler(rng, type, param1, param2):
    """Draw one value from the distribution described by a prior spec.

    :param rng: object providing ``uniform`` and ``lognormal`` (e.g. a NumPy random state).
    :param type: ``"uniform"`` or ``"log_normal"``.
    :param param1: first argument forwarded to the sampler.
    :param param2: second argument forwarded to the sampler.
    :raises ValueError: if ``type`` is not one of the supported names.
    """
    if type == "uniform":
        return rng.uniform(param1, param2)
    if type == "log_normal":
        return rng.lognormal(param1, param2)
    # ValueError is a subclass of Exception, so existing handlers still match.
    raise ValueError("Unknown prior type: {}".format(type))
50
+
51
+
52
def pow3(x, c, a, alpha):
    """Three-parameter power law: ``c - a * x ** (-alpha)``."""
    decay = x ** (-alpha)
    return c - a * decay
54
+
55
+
56
def prior_pow3(rng):
    """Sample the pow3 parameters ``a``, ``c``, ``alpha`` from the "peaked" prior."""
    sample = {}
    # The draw order a -> c -> alpha determines how the rng stream is consumed.
    for name in ["a", "c", "alpha"]:
        spec = prior["pow3"]["peaked"][name]
        sample[name] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return sample
66
+
67
+
68
def uniform_prior_pow3(rng):
    """Sample the pow3 parameters ``a``, ``c``, ``alpha`` from the "uniform" prior."""
    sample = {}
    # The draw order a -> c -> alpha determines how the rng stream is consumed.
    for name in ["a", "c", "alpha"]:
        spec = prior["pow3"]["uniform"][name]
        sample[name] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return sample
78
+
79
+
80
def ilog2(x, c, a):
    """Inverse-log curve: ``c - a / log(x + 1)``."""
    log_term = np.log(x + 1)
    return c - a / log_term
82
+
83
+
84
def prior_ilog2(rng):
    """Sample the ilog2 parameters ``a`` and ``c`` from the "peaked" prior."""
    sample = {}
    # Drawn in the order a -> c (not the c -> a order of the prior table).
    for name in ["a", "c"]:
        spec = prior["ilog2"]["peaked"][name]
        sample[name] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return sample
94
+
95
+
96
def uniform_prior_ilog2(rng):
    """Sample the ilog2 parameters ``a`` and ``c`` from the "uniform" prior."""
    sample = {}
    # Drawn in the order a -> c (not the c -> a order of the prior table).
    for name in ["a", "c"]:
        spec = prior["ilog2"]["uniform"][name]
        sample[name] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return sample
106
+
107
+
108
def janoschek(x, a, beta, k, delta):
    """
    Janoschek growth curve: ``a - (a - beta) * exp(-k * x ** delta)``.

    http://www.pisces-conservation.com/growthhelp/janoschek.htm
    """
    decay = np.exp(-k * x**delta)
    return a - (a - beta) * decay
113
+
114
+
115
def prior_janoschek(rng):
    """Sample the janoschek parameters from the "peaked" prior."""
    sample = {}
    # The draw order a -> beta -> k -> delta fixes the rng stream usage.
    for name in ["a", "beta", "k", "delta"]:
        spec = prior["janoschek"]["peaked"][name]
        sample[name] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return sample
125
+
126
+
127
def uniform_prior_janoschek(rng):
    """Sample the janoschek parameters from the "uniform" prior."""
    sample = {}
    # The draw order a -> beta -> k -> delta fixes the rng stream usage.
    for name in ["a", "beta", "k", "delta"]:
        spec = prior["janoschek"]["uniform"][name]
        sample[name] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return sample
137
+
138
+
139
def log_power(x, a, b, c):
    """Log-power curve: ``a / (1 + (x / e^b) ** c)``."""
    # a: upper bound
    # c: growth rate
    # initial = a/ (1 + (1/e^b)^c
    scaled = x / np.exp(b)
    return a / (1.0 + scaled ** c)
144
+
145
+
146
def prior_log_power(rng):
    """Sample log_power parameters: a ~ N(0.8, 0.1), b ~ N(1, 1), c ~ U(-3, 0)."""
    params = {}
    params["a"] = rng.normal(0.8, 0.1)
    params["b"] = rng.normal(1.0, 1.0)
    params["c"] = rng.uniform(-3.0, 0.0)
    return params
154
+
155
+
156
def weibull(x, alpha, beta, kappa, delta):
    """
    Weibull model: ``alpha - (alpha - beta) * exp(-(kappa * x) ** delta)``.

    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote
    beta: lower asymptote
    kappa: growth rate
    delta: controls the x-ordinate for the point of inflection
    """
    exponent = (kappa * x) ** delta
    return alpha - (alpha - beta) * np.exp(-exponent)
166
+
167
+
168
def prior_weibull(rng):
    """Sample weibull parameters.

    alpha ~ U(0, 1.5), beta ~ U(0, 1), ln(kappa) ~ N(-2, 1), ln(delta) ~ N(0, 0.5).
    """
    params = {}
    params["alpha"] = rng.uniform(0.0, 1.5)
    params["beta"] = rng.uniform(0.0, 1)
    params["kappa"] = np.exp(rng.normal(-2.0, 1.0))
    params["delta"] = np.exp(rng.normal(0, 0.5))
    return params
174
+
175
+
176
def mmf(x, alpha, beta, kappa, delta):
    """
    Morgan-Mercer-Flodin: ``alpha - (alpha - beta) / (1 + (kappa * x) ** delta)``.

    description:
    Nonlinear Regression page 342
    http://bit.ly/1jodG17
    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote
    kappa: growth rate
    beta: initial value
    delta: controls the point of inflection
    """
    denominator = 1.0 + (kappa * x) ** delta
    return alpha - (alpha - beta) / denominator
189
+
190
+
191
def prior_mmf(rng):
    """Sample mmf parameters.

    alpha ~ N(0.8, 0.1), beta ~ N(0.2, 0.1), ln(kappa) ~ N(0, 2), ln(delta) ~ N(0, 1).
    """
    params = {}
    params["alpha"] = rng.normal(0.8, 0.1)
    params["beta"] = rng.normal(0.2, 0.1)
    params["kappa"] = np.exp(rng.normal(0, 2))
    params["delta"] = np.exp(rng.normal(0, 1))
    return params
201
+
202
+
203
def vap(x, a, b, c):
    """Vapor pressure model: ``exp(a + b / x + c * log(x))``."""
    # no upper bound if c > 0
    # a = ln(upper bound) for c=0
    # a+b = ln(initial)
    log_value = a + b / x + c * np.log(x)
    return np.exp(log_value)
209
+
210
+
211
def prior_vap(rng):
    """Sample vap parameters: a ~ U(-2, 0), b ~ U(-4, 0), ln(c) ~ U(-8, 0)."""
    params = {}
    params["a"] = rng.uniform(-2.0, 0.0)  # @heri: range check
    params["b"] = rng.uniform(-4.0, 0.0)  # @heri: range check
    params["c"] = np.exp(rng.uniform(-8.0, 0.0))  # @heri: same as weights
    return params
216
+
217
+
218
def loglog_linear(x, a, b):
    """Log of a line in log-x: ``log(a * log(x) + b)``."""
    log_x = np.log(x)
    return np.log(a * log_x + b)
221
+
222
+
223
def prior_loglog_linear(rng):
    """Sample loglog_linear parameters: ln(a) ~ N(-2, 1), ln(b) ~ U(0, 1)."""
    params = {}
    params["a"] = np.exp(rng.normal(-2.0, 1.0))
    params["b"] = np.exp(rng.uniform(0.0, 1.0))
    return params
229
+
230
+
231
def exp4(x, c, a, b, alpha):
    """Four-parameter exponential: ``c - exp(-a * x ** alpha + b)``."""
    exponent = -a * (x**alpha) + b
    return c - np.exp(exponent)
233
+
234
+
235
def prior_exp4(rng):
    """Sample exp4 parameters.

    c ~ N(0.8, 0.1), ln(a) ~ N(-2, 1), ln(alpha) ~ N(0, 1), ln(b) ~ N(0, 0.5),
    drawn in the order c, a, alpha, b.
    """
    offset = rng.normal(0.8, 0.1)
    rate = np.exp(rng.normal(-2, 1))
    exponent = np.exp(rng.normal(0, 1))
    shift = np.exp(rng.normal(0, 0.5))
    return {"a": rate, "b": shift, "c": offset, "alpha": exponent}
245
+
246
+
247
def pow4(x, c, a, b, alpha):
    """Four-parameter power law: ``c - (a * x + b) ** -alpha``."""
    base = a * x + b
    return c - base ** -alpha
249
+
250
+
251
def prior_pow4(rng):
    """Sample pow4 parameters.

    ln(1 - c) ~ U(-5, 0), ln(a) ~ N(-3, 2), ln(alpha) ~ N(0, 1), ln(b) ~ U(0, 1),
    drawn in the order c, a, alpha, b.
    """
    offset = 1 - np.exp(rng.uniform(-5.0, 0))
    slope = np.exp(rng.normal(-3.0, 2))
    exponent = np.exp(rng.normal(0, 1))
    intercept = np.exp(rng.uniform(0, 1))
    return {"a": slope, "b": intercept, "c": offset, "alpha": exponent}
261
+
262
+
263
def dr_hill_zero_background(x, theta, eta, kappa):
    """Hill curve with zero background: ``theta * x^eta / (kappa^eta + x^eta)``."""
    # theta: upper bound
    # eta: growth rate
    # initial = theta/(kappa^eta + 1)
    numerator = theta * x**eta
    denominator = kappa**eta + x**eta
    return numerator / denominator
268
+
269
+
270
def prior_dr_hill_zero_background(rng):
    """Sample Hill-curve parameters: theta ~ N(0.8, 0.1), ln(eta) ~ N(1, 1), ln(kappa) ~ N(1, 2)."""
    params = {}
    params["theta"] = rng.normal(0.8, 0.1)
    params["eta"] = np.exp(rng.normal(1.0, 1.0))
    params["kappa"] = np.exp(rng.normal(1.0, 2.0))
    return params
lcpfn/decoders.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import random
4
+
5
+
6
class ScaledDecoder(nn.Module):
    """MLP decoder whose outputs are divided by a learned, input-dependent temperature.

    The temperature is a softmax-weighted mix over a fixed grid of ten values
    (1 ... 160), predicted from the hidden activation by ``linear2``.
    """

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        self.linear = nn.Linear(ninp, nhid)
        self.linear1 = nn.Linear(nhid, nout)
        self.linear2 = nn.Linear(nhid, 10)

    def forward(self, x):
        """Return ``linear1(h) / temperature`` with ``h = GELU(linear(x))``."""
        x = self.linear(x)
        x = nn.GELU()(x)
        temp_grid = torch.tensor([1.,1.4,1.7,2.,5.,10.,20.,40.,80.,160.], device=x.device)
        temps = self.linear2(x).softmax(-1) @ temp_grid
        # The former random debug print (`temps[:,:2]`) was removed: it raised
        # IndexError whenever `temps` was 1-D, i.e. for 2-D inputs.
        return self.linear1(x) / temps.unsqueeze(-1)
21
+
22
class FixedScaledDecoder(nn.Module):
    """MLP decoder divided by one learned global temperature.

    The temperature is the sum of a 10000-element parameter vector that is
    initialised so the sum is 1.
    """

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        hidden = nn.Linear(ninp, nhid)
        head = nn.Linear(nhid, nout)
        self.mapper = nn.Sequential(hidden, nn.GELU(), head)
        self.T = nn.Parameter(torch.ones(10000)/10000)

    def forward(self, x):
        temperature = self.T.sum()
        return self.mapper(x)/temperature
30
+
lcpfn/domhan_prior.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import torch
3
+ import numpy as np
4
+ from lcpfn.curves import (
5
+ pow3,
6
+ ilog2,
7
+ janoschek,
8
+ log_power,
9
+ prior_ilog2,
10
+ uniform_prior_pow3,
11
+ weibull,
12
+ mmf,
13
+ vap,
14
+ loglog_linear,
15
+ exp4,
16
+ pow4,
17
+ dr_hill_zero_background,
18
+ )
19
+ from lcpfn.curves import (
20
+ prior_pow3,
21
+ prior_janoschek,
22
+ prior_log_power,
23
+ prior_weibull,
24
+ prior_mmf,
25
+ prior_vap,
26
+ prior_loglog_linear,
27
+ prior_exp4,
28
+ prior_pow4,
29
+ prior_dr_hill_zero_background,
30
+ )
31
+ from lcpfn.curves import (
32
+ uniform_prior_pow3,
33
+ uniform_prior_ilog2,
34
+ uniform_prior_janoschek,
35
+ )
36
+
37
+
38
def prior_weights(
    rng,
    components=(
        "pow3",
        "ilog2",
        "janoschek",
        "log_power",
        "weibull",
        "mmf",
        "vap",
        "loglog_linear",
        "exp4",
        "pow4",
        "dr_hill_zero_background",
    ),
):
    """Draw one independent U(0, 1) weight per curve component.

    :param rng: random source providing ``uniform(low, high, size=...)``.
    :param components: names of the components to weight (immutable default
        instead of a shared mutable list).
    :return: dict mapping each component name to its weight.
    """
    K = len(components)
    weights = rng.uniform(0.0, 1, size=(K,))
    return dict(zip(components, weights))
57
+
58
+
59
def sample_from_prior(rng, seq_len=100):
    """Sample a curve generator from the peaked pow3 + ilog2 + janoschek combination."""
    basis = ["pow3", "ilog2", "janoschek"]
    return sample_prior_comb(
        rng=rng,
        seq_len=seq_len,
        components=basis,
        distribution="peaked",
    )
63
+
64
+
65
def sample_prior_comb(
    rng,
    components,
    distribution,
    var_lnloc=-4,
    var_lnscale=1,
    range_constraint=True,
    seq_len=100,
):
    """Sample a weighted combination of basis curves and return a noisy-curve generator.

    Curves are re-sampled until one is increasing from first to last point,
    free of NaNs and (if ``range_constraint``) inside [0, 1].

    :param rng: NumPy-style random source.
    :param components: names of the basis curves to combine.
    :param distribution: ``"peaked"`` or ``"uniform"``; the ``"uniform"`` table
        only covers pow3, ilog2 and janoschek.
    :param var_lnloc: mean of the normal from which the log noise scale is drawn.
    :param var_lnscale: std of the normal from which the log noise scale is drawn.
    :param range_constraint: reject curves leaving [0, 1].
    :param seq_len: number of points; x runs over 1..seq_len.
    :return: zero-argument callable returning ``(y, y_noisy)``.
    :raises NotImplementedError: for an unknown ``distribution``.
    """
    f_components = {
        "pow3": pow3,
        "ilog2": ilog2,
        "janoschek": janoschek,
        "log_power": log_power,
        "weibull": weibull,
        "mmf": mmf,
        "vap": vap,
        "loglog_linear": loglog_linear,
        "exp4": exp4,
        "pow4": pow4,
        "dr_hill_zero_background": dr_hill_zero_background,
    }

    if distribution == "peaked":
        f_priors = {
            "pow3": prior_pow3,
            "ilog2": prior_ilog2,
            "janoschek": prior_janoschek,
            "log_power": prior_log_power,
            "weibull": prior_weibull,
            "mmf": prior_mmf,
            "vap": prior_vap,
            "loglog_linear": prior_loglog_linear,
            "exp4": prior_exp4,
            "pow4": prior_pow4,
            "dr_hill_zero_background": prior_dr_hill_zero_background,
        }
    elif distribution == "uniform":
        f_priors = {
            "pow3": uniform_prior_pow3,
            "ilog2": uniform_prior_ilog2,
            "janoschek": uniform_prior_janoschek,
        }
    else:
        # `raise NotImplemented()` raised TypeError: NotImplemented is a
        # constant, not an exception class.
        raise NotImplementedError(f"Unknown distribution: {distribution}")

    x = np.arange(1, seq_len + 1)

    while True:
        # sample the noiseless curve
        weights = prior_weights(rng, components=components)
        y = np.zeros(x.shape, dtype="float")
        for f, w in weights.items():
            kwargs = f_priors[f](rng)
            y += w * f_components[f](x, **kwargs)
        # add noise (can exceed [0,1], but afaik no way to implement this prior in Tobis work)
        # Drawn before the rejection test so the rng stream is consumed the
        # same way for accepted and rejected curves.
        var = np.exp(
            rng.normal(var_lnloc, var_lnscale)
        )  # @heri: ln_prob =+ log(normal.pdf(log(var), loc=var_lnloc, scale=var_lnscale))

        # reject any curves that are non-increasing, exceed the [0,1] range
        rejected = (
            y[-1] <= y[0]
            or (range_constraint and (np.any(y < 0) or np.any(y > 1)))
            or np.isnan(y).any()
        )
        if not rejected:
            break

    def curve():  # generates a sample from the same model, but with independent noise
        # `var` is passed as the scale (standard deviation) of the normal.
        y_noisy = y + rng.normal(np.zeros_like(y), var)
        return y, y_noisy

    return curve
142
+
143
+
144
def generate_prior_dataset(n, prior=None, seed=42):
    """
    Returns a fixed sample from the prior (with fixed seq_len) as an n x seq_len np.ndarray

    :param n: number of curves to draw.
    :param prior: callable ``prior(rng)`` returning a zero-argument curve
        generator that yields ``(y, y_noisy)``. Defaults to ``sample_from_prior``;
        the former default ``sample_prior_comb`` could not be called with
        ``rng`` alone (it requires ``components`` and ``distribution``).
    :param seed: seed of the ``np.random.RandomState`` used for all draws.
    """
    if prior is None:
        prior = sample_from_prior
    rng = np.random.RandomState(seed)
    # Keep only the noisy curve (index 1) of each sample.
    prior_data = np.stack([prior(rng)()[1] for _ in range(n)])
    return prior_data
151
+
152
+
153
def create_get_batch_func(prior):
    """Return ``get_batch_domhan`` with its ``prior`` argument pre-bound."""
    bound_get_batch = partial(get_batch_domhan, prior=prior)
    return bound_get_batch
155
+
156
+ # function producing batches for PFN training
157
def get_batch_domhan(
    batch_size,
    seq_len,
    num_features,
    prior,
    device="cpu",
    noisy_target=True,
    **_,
):
    """Produce one batch of learning curves for PFN training.

    :param batch_size: number of curves in the batch.
    :param seq_len: number of points per curve; x runs over 1..seq_len.
    :param num_features: must be 1.
    :param prior: callable ``prior(rng, seq_len=...)`` returning a
        zero-argument curve generator yielding ``(y, y_noisy)``.
    :param device: device the returned tensors are moved to.
    :param noisy_target: if true the target equals the noisy curve, otherwise
        the noiseless one.
    :return: ``(x, y_noisy, y_target)`` as float tensors of shape
        ``(seq_len, batch_size, num_features)``, ``(seq_len, batch_size)`` and
        ``(seq_len, batch_size)``.
    """
    assert num_features == 1

    y_target = np.empty((batch_size, seq_len), dtype=float)
    y_noisy = np.empty((batch_size, seq_len), dtype=float)

    for i in range(batch_size):
        curve_func = prior(np.random, seq_len=seq_len)  # uses numpy rng
        if noisy_target:
            _, y_noisy[i] = curve_func()
            y_target[i] = y_noisy[i]
        else:
            y_target[i], y_noisy[i] = curve_func()

    # turn numpy arrays into correctly shaped torch tensors & move them to device
    x = (
        torch.arange(1, seq_len + 1)
        .repeat((num_features, batch_size, 1))
        .transpose(2, 0)
        .to(device)
    )
    y_target = torch.from_numpy(y_target).transpose(1, 0).to(device)
    y_noisy = torch.from_numpy(y_noisy).transpose(1, 0).to(device)

    # changes
    x = x.float()
    y_target = y_target.float()
    y_noisy = y_noisy.float()

    return x, y_noisy, y_target
lcpfn/encoders.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from lcpfn.utils import normalize_data
6
+ import torch.nn.functional as F
7
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
8
+
9
+
10
class StyleEncoder(nn.Module):
    """Linear embedding of a hyperparameter vector into ``em_size`` dimensions."""

    def __init__(self, em_size, hyperparameter_definitions):
        super().__init__()
        self.em_size = em_size
        # One input per row of the hyperparameter definitions.
        num_hps = hyperparameter_definitions.shape[0]
        self.embedding = nn.Linear(num_hps, self.em_size)

    def forward(self, hyperparameters):  # T x B x num_hps
        embedded = self.embedding(hyperparameters)
        return embedded
18
+
19
+
20
+ class _PositionalEncoding(nn.Module):
21
+ def __init__(self, d_model, dropout=0.):
22
+ super().__init__()
23
+ self.dropout = nn.Dropout(p=dropout)
24
+ self.d_model = d_model
25
+ self.device_test_tensor = nn.Parameter(torch.tensor(1.))
26
+
27
+ def forward(self, x):# T x B x num_features
28
+ assert self.d_model % x.shape[-1]*2 == 0
29
+ d_per_feature = self.d_model // x.shape[-1]
30
+ pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
31
+ #position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
32
+ interval_size = 10
33
+ div_term = (1./interval_size) * 2*math.pi*torch.exp(torch.arange(0, d_per_feature, 2, device=self.device_test_tensor.device).float()*math.log(math.sqrt(2)))
34
+ #print(div_term/2/math.pi)
35
+ pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
36
+ pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
37
+ return self.dropout(pe).view(x.shape[0],x.shape[1],self.d_model)
38
+
39
+
40
def Positional(_, emsize):
    """Encoder factory: ignore the first argument and build a ``_PositionalEncoding`` of width ``emsize``."""
    return _PositionalEncoding(d_model=emsize)
41
+
42
class EmbeddingEncoder(nn.Module):
    """Encode features by discretizing them and averaging learned embeddings.

    Each feature is binned into ``num_embs`` equal-width bins over ``min_max``
    and owns its own slice of the embedding table.
    """

    def __init__(self, num_features, em_size, num_embs=100):
        super().__init__()
        self.num_embs = num_embs
        # NOTE(review): max_norm=True is passed through unchanged; presumably a
        # norm bound of 1.0 was meant -- confirm.
        self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
        self.init_weights(.1)
        self.min_max = (-2,+2)

    @property
    def width(self):
        """Width of the discretized value range."""
        return self.min_max[1] - self.min_max[0]

    def init_weights(self, initrange):
        self.embeddings.weight.data.uniform_(-initrange, initrange)

    def discretize(self, x):
        """Map values to bin indices in ``[0, num_embs - 1]``; out-of-range values are clamped."""
        split_size = self.width / self.num_embs
        # Shift by the lower bound first, then divide. The former
        # `x - self.min_max[0] // split_size` floor-divided the bound alone.
        bins = torch.div(x - self.min_max[0], split_size, rounding_mode="floor")
        return bins.int().clamp(0, self.num_embs - 1)

    def forward(self, x):  # T x B x num_features
        # Offset feature i into its own block of num_embs rows.
        offsets = torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
        x_idxs = self.discretize(x) + offsets
        return self.embeddings(x_idxs).mean(-2)
66
+
67
+
68
class Normalize(nn.Module):
    """Shift by a fixed ``mean`` and scale by a fixed ``std``."""

    def __init__(self, mean, std):
        super().__init__()
        self.mean, self.std = mean, std

    def forward(self, x):
        centred = x-self.mean
        return centred/self.std
76
+
77
+
78
def get_normalized_uniform_encoder(encoder_creator):
    """
    Wrap an encoder that is fed uniform samples in [0,1] so that its inputs are
    first normalized to 0 mean and 1 std.
    For example, `encoder_creator = get_normalized_uniform_encoder(encoders.Linear)`
    can afterwards be initialized with `encoder_creator(feature_dim, in_dim)`.
    :param encoder_creator: callable ``(in_dim, out_dim) -> module``.
    :return: callable ``(in_dim, out_dim) -> nn.Sequential``.
    """
    def create(in_dim, out_dim):
        # U(0, 1) has mean 1/2 and variance 1/12.
        normalizer = Normalize(.5, math.sqrt(1/12))
        return nn.Sequential(normalizer, encoder_creator(in_dim, out_dim))

    return create
87
+
88
+
89
+ Linear = nn.Linear
90
def MLP(num_features, emsize):
    """Two-layer ReLU encoder taking ``num_features + 1`` inputs and producing ``emsize`` outputs."""
    hidden = emsize*2
    return nn.Sequential(
        nn.Linear(num_features+1, hidden),
        nn.ReLU(),
        nn.Linear(hidden, emsize),
    )
93
+
94
class NanHandlingEncoder(nn.Module):
    """Linear encoder that zeroes NaNs and can append a NaN/inf indicator per feature.

    With ``keep_nans`` the input width doubles: the second half carries a
    normalized code (-1 for NaN, 1 for +inf, 2 for -inf, 0 otherwise).
    """

    def __init__(self, num_features, emsize, keep_nans=True):
        super().__init__()
        self.num_features = 2 * num_features if keep_nans else num_features
        self.emsize = emsize
        self.keep_nans = keep_nans
        self.layer = nn.Linear(self.num_features, self.emsize)

    def forward(self, x):
        cleaned = torch.nan_to_num(x, nan=0.0)
        if not self.keep_nans:
            return self.layer(cleaned)
        nan_code = torch.isnan(x) * -1
        pos_inf_code = torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1
        neg_inf_code = torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2
        indicator = normalize_data(nan_code + pos_inf_code + neg_inf_code)
        return self.layer(torch.cat([cleaned, indicator], -1))
111
+
112
+
113
class Linear(nn.Linear):
    """``nn.Linear`` that replaces NaN inputs with 0 before the projection."""

    def __init__(self, num_features, emsize):
        super().__init__(num_features, emsize)
        self.num_features, self.emsize = num_features, emsize

    def forward(self, x):
        cleaned = torch.nan_to_num(x, nan=0.0)
        return super().forward(cleaned)
122
+
123
+
124
class Conv(nn.Module):
    """Encode flattened square images with a small CNN followed by a linear layer.

    The last input dimension must be a perfect square; it is reshaped to a
    one-channel ``size x size`` image. Any number of leading dimensions is
    accepted and restored on the output.
    """

    def __init__(self, input_size, emsize):
        super().__init__()
        self.convs = torch.nn.ModuleList([nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)])
        self.linear = nn.Linear(64,emsize)

    def forward(self, x):
        size = math.isqrt(x.shape[-1])
        assert size*size == x.shape[-1]
        lead_shape = x.shape[:-1]
        # Fold all leading dimensions into one batch dimension: Conv2d only
        # accepts 3-D/4-D input, so e.g. a T x B x F tensor used to fail.
        x = x.reshape(-1, 1, size, size)
        for conv in self.convs:
            # Stop once the feature map is narrower than 4 pixels.
            if x.shape[-1] < 4:
                break
            x = conv(x)
            x.relu_()
        x = nn.AdaptiveAvgPool2d((1,1))(x).squeeze(-1).squeeze(-1)
        x = x.reshape(*lead_shape, x.shape[-1])
        return self.linear(x)
141
+
142
+
143
class CanEmb(nn.Embedding):
    """Embedding for whole-number features; per-feature embeddings are concatenated.

    ``embedding_dim`` is split evenly across ``num_features``, so the
    concatenated output has ``embedding_dim`` entries again.
    """

    def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
        assert embedding_dim % num_features == 0
        per_feature_dim = embedding_dim // num_features
        super().__init__(num_embeddings, per_feature_dim, *args, **kwargs)

    def forward(self, x):
        indices = x.long()
        assert (indices == x).all(), "CanEmb only works with tensors of whole numbers"
        embedded = super().forward(indices)
        # Merge the feature and per-feature embedding dimensions.
        return embedded.view(*embedded.shape[:-2], -1)
154
+
155
+
156
def get_Canonical(num_classes):
    """Return an encoder factory ``(num_features, emsize) -> CanEmb`` with ``num_classes`` embeddings."""
    def create(num_features, emsize):
        return CanEmb(num_features, num_classes, emsize)

    return create
158
+
159
+
160
def get_Embedding(num_embs_per_feature=100):
    """Return an encoder factory ``(num_features, emsize) -> EmbeddingEncoder``."""
    def create(num_features, emsize):
        return EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)

    return create
lcpfn/initializers.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
def get_NormalInitializer(std):
    """Return an initializer for ``module.apply`` that draws ``nn.Linear`` weights and biases from N(0, std).

    Modules other than ``nn.Linear`` are left untouched.
    """
    def initializer(m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, std)
            # Linear layers built with bias=False have no bias tensor.
            if m.bias is not None:
                nn.init.normal_(m.bias, 0, std)
    return initializer
lcpfn/layer.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Optional
3
+ from torch import Tensor
4
+ from torch import nn
5
+ from torch.nn.modules.transformer import *
6
+ from torch.nn.modules.transformer import _get_activation_fn
7
+
8
+ from torch.utils.checkpoint import checkpoint
9
+
10
+
11
class TransformerEncoderLayer(nn.Module):
    r"""Transformer encoder layer (self-attention + feedforward network).

    Follows the standard layer from "Attention Is All You Need" (Vaswani et
    al., 2017) with three additions visible in this class: optional pre-norm,
    optional recomputation (checkpointing) of the attention, and a "global
    attention" mode that is selected by passing a tuple of three masks.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.
        pre_norm: apply the layer norms before the sub-blocks instead of after
            the residual additions. Default: ``False``.
        recompute_attn: run the attention through ``torch.utils.checkpoint``.
            Default: ``False``.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ['batch_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False, pre_norm=False,
                 device=None, dtype=None, recompute_attn=False) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        # Layers are referenced through `nn` explicitly rather than through the
        # wildcard import of torch.nn.modules.transformer, which only exports
        # the names listed in that module's __all__.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                               **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.pre_norm = pre_norm
        self.recompute_attn = recompute_attn

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # States without a stored activation fall back to relu.
        if 'activation' not in state:
            state['activation'] = nn.functional.relu
        super().__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional). A tuple
                ``(global_mask, trainset_mask, valset_mask)`` selects the
                global attention setup.
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src_ = self.norm1(src) if self.pre_norm else src
        if isinstance(src_mask, tuple):
            import torch  # local import: `torch` itself is not imported at the top of this module

            # global attention setup
            assert not self.self_attn.batch_first
            assert src_key_padding_mask is None

            global_src_mask, trainset_src_mask, valset_src_mask = src_mask

            # The sequence is laid out as [global | train | eval] tokens; the
            # first two segment lengths are read off the masks.
            num_global_tokens = global_src_mask.shape[0]
            num_train_tokens = trainset_src_mask.shape[0]
            num_context_tokens = num_global_tokens + num_train_tokens

            global_tokens_src = src_[:num_global_tokens]
            train_tokens_src = src_[num_global_tokens:num_context_tokens]
            global_and_train_tokens_src = src_[:num_context_tokens]
            eval_tokens_src = src_[num_context_tokens:]

            attn = partial(checkpoint, self.self_attn) if self.recompute_attn else self.self_attn

            # Positional arguments: query, key, value, key_padding_mask,
            # need_weights, attn_mask.
            # global tokens attend to global + train tokens
            global_tokens_src2 = attn(global_tokens_src, global_and_train_tokens_src, global_and_train_tokens_src, None, True, global_src_mask)[0]
            # train tokens attend to the global tokens only
            train_tokens_src2 = attn(train_tokens_src, global_tokens_src, global_tokens_src, None, True, trainset_src_mask)[0]
            # eval tokens attend to the whole sequence
            eval_tokens_src2 = attn(eval_tokens_src, src_, src_,
                                    None, True, valset_src_mask)[0]

            src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0)
        elif self.recompute_attn:
            src2 = checkpoint(self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask)[0]
        else:
            src2 = self.self_attn(src_, src_, src_, attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        if not self.pre_norm:
            src = self.norm1(src)

        src_ = self.norm2(src) if self.pre_norm else src
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src_))))
        src = src + self.dropout2(src2)

        if not self.pre_norm:
            src = self.norm2(src)
        return src
lcpfn/model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import lcpfn
3
+
4
class LCPFN(torch.nn.Module):
    """Inference wrapper around a serialized LC-PFN model.

    The wrapped model is called as ``model((x, y), single_eval_pos=...)`` and
    must expose a ``criterion`` object offering ``mean``, ``icdf`` and
    ``__call__`` (used for the negative log-likelihood).
    """

    def __init__(self, model_name="EMSIZE512_NLAYERS12_NBUCKETS1000"):
        """Load the model and switch it to evaluation mode.

        :param model_name: either a key of ``lcpfn.model_dict`` (resolved to the
            attribute of the same name on the ``lcpfn`` package) or anything
            ``torch.load`` accepts directly, e.g. a checkpoint path.
        """
        super().__init__()
        if model_name in lcpfn.model_dict:
            checkpoint = getattr(lcpfn, model_name)
        else:
            checkpoint = model_name
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        self.model = torch.load(checkpoint)
        self.model.eval()

    @torch.no_grad()
    def predict_mean(self, x_train, y_train, x_test):
        """Return ``criterion.mean`` of the logits predicted for ``x_test``."""
        predicted_logits = self(x_train=x_train, y_train=y_train, x_test=x_test)
        return self.model.criterion.mean(predicted_logits)

    @torch.no_grad()
    def predict_quantiles(self, x_train, y_train, x_test, qs):
        """Return ``criterion.icdf`` for every ``q`` in ``qs``, concatenated along dim 1."""
        predicted_logits = self(x_train=x_train, y_train=y_train, x_test=x_test)
        icdf = self.model.criterion.icdf
        per_quantile = [icdf(predicted_logits, q) for q in qs]
        return torch.cat(per_quantile, dim=1)

    @torch.no_grad()
    def nll_loss(self, x_train, y_train, x_test, y_test):
        """Return the criterion evaluated on the predicted logits and ``y_test``."""
        predicted_logits = self(x_train=x_train, y_train=y_train, x_test=x_test)
        return self.model.criterion(predicted_logits, y_test)

    def forward(self, x_train, y_train, x_test):
        """Run the wrapped model on the training context followed by the test inputs.

        Train and test inputs are concatenated along dim 0 and a singleton dim
        is inserted at position 1 of both ``x`` and ``y``; the number of
        training points is passed as ``single_eval_pos``.
        """
        num_train = x_train.shape[0]
        stacked_x = torch.cat([x_train, x_test], dim=0).unsqueeze(1)
        context_y = y_train.unsqueeze(1)
        return self.model((stacked_x, context_y), single_eval_pos=num_train)
lcpfn/positional_encodings.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ # Protocol for positional encodings.
8
+ # __init__(d_model, max_len=..[, more optionals])
9
+ # forward(x: (seq_len, bs, d_model)) -> Tensor of shape (*x.shape[:2],d_model) containing pos. embeddings
10
+
11
+
12
+ class NoPositionalEncoding(nn.Module):
13
+ def __init__(self, d_model, max_len=None):
14
+ super(NoPositionalEncoding, self).__init__()
15
+ pass
16
+
17
+ def forward(self, x):
18
+ return x #* math.sqrt(x.shape[-1])
19
+
20
+
21
+ class PositionalEncoding(nn.Module):
22
+ def __init__(self, d_model, max_len=5000):
23
+ super(PositionalEncoding, self).__init__()
24
+ pe = torch.zeros(max_len, d_model)
25
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
26
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
27
+ pe[:, 0::2] = torch.sin(position * div_term)
28
+ pe[:, 1::2] = torch.cos(position * div_term)
29
+ pe = pe.unsqueeze(0).transpose(0, 1)
30
+ self.register_buffer('pe', pe)
31
+
32
+ def forward(self, x):
33
+ x = self.pe[:x.size(0), :] + x # * math.sqrt(x.shape[-1])
34
+ return x
35
+
36
+
37
+ class LearnedPositionalEncoding(nn.Module):
38
+ def __init__(self, d_model, max_len=5000):
39
+ super(LearnedPositionalEncoding, self).__init__()
40
+ self.max_seq_len = max_len
41
+ #self.positional_embeddings = nn.Embedding(max_len, d_model)
42
+ self.positional_embeddings = nn.Parameter(torch.empty(max_len, d_model))
43
+ nn.init.normal_(self.positional_embeddings, mean=0, std=d_model ** -0.5)
44
+
45
+ def forward(self, x):
46
+ seq_len, bs, d_model = x.shape
47
+ assert seq_len <= len(self.positional_embeddings), 'seq_len can be at most max_len.'
48
+ pos_emb = self.positional_embeddings[:seq_len]
49
+ return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x #* math.sqrt(x.shape[-1])
50
+
51
+
52
+ class PairedScrambledPositionalEncodings(LearnedPositionalEncoding):
53
+ # TODO check whether it is a problem to use the same perm. for full batch
54
+ def forward(self, x):
55
+ seq_len, bs, d_model = x.shape
56
+ assert seq_len <= len(self.positional_embeddings), 'seq_len can be at most max_len.'
57
+ assert len(self.positional_embeddings) % 2 == 0, 'Please specify an even max_len.'
58
+
59
+ paired_embs = self.positional_embeddings.view(len(self.positional_embeddings), -1, 2)
60
+ pos_emb = paired_embs[torch.randperm(len(paired_embs))].view(*self.positional_embeddings.shape)[:seq_len]
61
+
62
+ return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x #* math.sqrt(x.shape[-1])
63
+
64
+
65
+
66
+
67
+
68
+
69
+
70
+
lcpfn/priors/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import gp, ridge
lcpfn/priors/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (216 Bytes). View file
 
lcpfn/priors/__pycache__/gp.cpython-310.pyc ADDED
Binary file (2.17 kB). View file
 
lcpfn/priors/__pycache__/prior.cpython-310.pyc ADDED
Binary file (1.11 kB). View file
 
lcpfn/priors/__pycache__/ridge.cpython-310.pyc ADDED
Binary file (1.44 kB). View file
 
lcpfn/priors/__pycache__/utils.cpython-310.pyc ADDED
Binary file (6.26 kB). View file
 
lcpfn/priors/binarized_regression.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import fast_gp, fast_gp_mix
2
+ from .utils import get_batch_to_dataloader
3
+
4
def regression_prior_to_binary(get_batch_function):
    """Wrap a regression batch sampler so that it yields binary targets.

    The returned function forwards all arguments to ``get_batch_function``,
    squashes the sampled ``y`` with a sigmoid and draws Bernoulli samples from
    the resulting probabilities.  The binarized tensor is returned both as the
    input ``y`` and as the target.

    :param get_batch_function: callable returning ``(x, y, target_y)``.
    :return: callable with the same arguments plus the keyword-only flag
        ``assert_on``; when set, it asserts that ``y`` and ``target_y`` are the
        same object before binarizing.
    """

    def binarized_get_batch_function(*args, assert_on=False, **kwargs):
        x, y, target_y = get_batch_function(*args, **kwargs)
        if assert_on:
            assert y is target_y, "y == target_y is assumed by this function"
        binary_y = y.sigmoid().bernoulli()
        return x, binary_y, binary_y

    return binarized_get_batch_function
14
+
15
+
16
# Data loader built by get_batch_to_dataloader (from .utils) around the
# binarized version of fast_gp.get_batch.
Binarized_fast_gp_dataloader = get_batch_to_dataloader(regression_prior_to_binary(fast_gp.get_batch))
17
+
18
+
19
# Data loader built by get_batch_to_dataloader (from .utils) around the
# binarized version of fast_gp_mix.get_batch.
Binarized_fast_gp_mix_dataloader = get_batch_to_dataloader(regression_prior_to_binary(fast_gp_mix.get_batch))
lcpfn/priors/fast_gp.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import torch
4
+ from torch import nn
5
+ import gpytorch
6
+
7
+ from .utils import get_batch_to_dataloader
8
+ from utils import default_device
9
+
10
+
11
+ # We will use the simplest form of GP model, exact inference
12
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact-inference GP with a constant mean and a scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)

    def forward(self, x):
        """Return the multivariate normal given by the mean and covariance modules at ``x``."""
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))
22
+
23
+
24
def get_model(x, y, hyperparameters):
    """Build an ``ExactGPModel`` whose hyperparameters are set to fixed values.

    :param x: training inputs handed to the GP.
    :param y: training targets handed to the GP.
    :param hyperparameters: mapping with the keys ``"noise"``, ``"outputscale"``
        and ``"lengthscale"``; each value is written into every entry of the
        corresponding model parameter.
    :return: ``(model, likelihood)``.
    """
    # The likelihood noise is constrained to stay above 1e-9.
    noise_floor = gpytorch.constraints.GreaterThan(1.e-9)
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=noise_floor)
    model = ExactGPModel(x, y, likelihood)

    kernel = model.covar_module
    model.likelihood.noise = torch.ones_like(model.likelihood.noise) * hyperparameters["noise"]
    kernel.outputscale = torch.ones_like(kernel.outputscale) * hyperparameters["outputscale"]
    kernel.base_kernel.lengthscale = (
        torch.ones_like(kernel.base_kernel.lengthscale) * hyperparameters["lengthscale"]
    )
    return model, likelihood
32
+
33
+
34
@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
              equidistant_x=False, fix_x=None, **kwargs):
    """Sample a batch of functions from a GP prior with fixed hyperparameters.

    :param batch_size: number of functions to draw.
    :param seq_len: number of input points per function.
    :param num_features: dimensionality of each input point.
    :param device: device on which the inputs and the GP model live.
    :param hyperparameters: one of
        * a dict with the keys ``"noise"``, ``"outputscale"`` and ``"lengthscale"``
          and the optional keys ``"sampling"`` (``'uniform'`` draws inputs with
          ``torch.rand``, anything else with ``torch.randn``), ``"verbose"``,
          ``"fast_computations"`` and ``"observation_noise"``;
        * a tuple/list with at least 8 entries in the positional layout decoded
          below (index 4 is skipped);
        * ``None`` to use ``.1`` for noise, outputscale and lengthscale.
    :param equidistant_x: use an evenly spaced grid on [0, 1] as inputs
        (requires ``num_features == 1``); mutually exclusive with ``fix_x``.
    :param fix_x: optional tensor of shape ``(seq_len, num_features)`` that is
        used as the inputs of every function in the batch.
    :param kwargs: accepted and ignored.
    :return: ``(x, y, target_y)`` where ``x`` has shape
        ``(seq_len, batch_size, num_features)``.  ``y`` is the noisy sample;
        ``target_y`` is the same noisy sample unless ``"observation_noise"`` is
        set to a falsy value, in which case it is the noise-free sample.
    """
    if isinstance(hyperparameters, (tuple, list)):
        hyperparameters = {"noise": hyperparameters[0]
            , "outputscale": hyperparameters[1]
            , "lengthscale": hyperparameters[2]
            , "is_binary_classification": hyperparameters[3]
            # , "num_features_used": hyperparameters[4]
            , "normalize_by_used_features": hyperparameters[5]
            , "order_y": hyperparameters[6]
            , "sampling": hyperparameters[7]
                           }
    elif hyperparameters is None:
        hyperparameters = {"noise": .1, "outputscale": .1, "lengthscale": .1}

    if hyperparameters.get('verbose'):
        # 'sampling' is optional, so read it with the same default the sampler uses.
        print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
                  , "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size,
               'sampling': hyperparameters.get('sampling', 'uniform')})

    assert not (equidistant_x and (fix_x is not None))

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
        if equidistant_x:
            assert num_features == 1
            # Create the grid on `device` like the other branches, so that the
            # inputs and the model (moved to `device` below) always agree.
            x = torch.linspace(0, 1., seq_len, device=device).unsqueeze(0).repeat(batch_size, 1).unsqueeze(-1)
        elif fix_x is not None:
            assert fix_x.shape == (seq_len, num_features)
            x = fix_x.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
        else:
            if hyperparameters.get('sampling', 'uniform') == 'uniform':
                x = torch.rand(batch_size, seq_len, num_features, device=device)
            else:
                x = torch.randn(batch_size, seq_len, num_features, device=device)

        successful_sample = False
        while not successful_sample:
            try:
                with gpytorch.settings.prior_mode(True):
                    # The model is rebuilt on every attempt, so a failed attempt
                    # restarts from a fresh initialisation.
                    model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
                    model.to(device)

                    d = model(x)
                    sample_wo_noise = d.sample().transpose(0, 1)  # this will be the target for the loss
                    sample = likelihood(sample_wo_noise).sample()  # this will be the input to the Transformer
                    successful_sample = True
            except RuntimeError:  # This can happen when torch.linalg.eigh fails. Restart with new init resolves this.
                print('GP Sampling unsuccessful, retrying.. ')
                print(x)
                print(hyperparameters)

    if torch.isnan(x).any():
        print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
                  , "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size})

    # TODO: Multi output
    target = sample if hyperparameters.get("observation_noise", True) else sample_wo_noise
    return x.transpose(0, 1), sample, target
96
+
97
# Data loader class created by get_batch_to_dataloader (from .utils) from this module's get_batch.
DataLoader = get_batch_to_dataloader(get_batch)
98
+
99
def get_model_on_device(x,y,hyperparameters,device):
    """Build the GP via ``get_model`` and move the model to ``device``.

    :return: the ``(model, likelihood)`` pair produced by ``get_model``.
    """
    model_and_likelihood = get_model(x, y, hyperparameters)
    model_and_likelihood[0].to(device)
    return model_and_likelihood
103
+
104
+
105
@torch.no_grad()
def evaluate(x, y, y_non_noisy, use_mse=False, hyperparameters=None, get_model_on_device=get_model_on_device, device=default_device, step_size=1, start_pos=0):
    """Evaluate one-step-ahead GP predictions along the sequence dimension.

    For every evaluated position ``t`` a GP is built from the first ``t``
    points (``x[:t]``, ``y[:t]``) and scored on point ``t``.

    :param x: inputs, sequence dimension first and batch dimension second.
    :param y: targets aligned with ``x``.
    :param y_non_noisy: unused; kept so the function can be called with the
        full output of ``get_batch``.
    :param use_mse: score with the squared error of the predictive mean instead
        of the negative log-probability.
    :param hyperparameters: dict passed on to ``get_model_on_device``; the
        optional key ``"fast_computations"`` configures gpytorch.  ``None`` is
        treated as an empty dict.
    :param get_model_on_device: factory ``(x, y, hyperparameters, device) ->
        (model, likelihood)``.
    :param device: device handed to the factory.
    :param step_size: stride between evaluated positions.
    :param start_pos: first position to evaluate (position 0 is never fitted;
        when ``start_pos == 0`` a placeholder loss of ``0.`` is recorded for it).
    :return: ``(all_losses, mean_losses, seconds)``: the stacked per-element
        losses, the mean loss per evaluated position, and the elapsed
        wall-clock time.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    hyperparameters = {} if hyperparameters is None else hyperparameters
    start_time = time.time()
    losses_after_t = [.0] if start_pos == 0 else []
    all_losses_after_t = []
    squared_error = nn.MSELoss(reduction='none')

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))), gpytorch.settings.fast_pred_var(False):
        for t in range(max(start_pos, 1), len(x), step_size):
            model, likelihood = get_model_on_device(x[:t].transpose(0, 1), y[:t].transpose(0, 1), hyperparameters, device)

            model.eval()
            f = model(x[t].unsqueeze(1))
            l = likelihood(f)
            means = l.mean.squeeze()
            varis = l.covariance_matrix.squeeze()

            # One predictive mean and variance per batch element is expected.
            assert len(means.shape) == len(varis.shape) == 1
            assert len(means) == len(varis) == x.shape[1]

            if use_mse:
                ls = squared_error(means, y[t])
            else:
                ls = -l.log_prob(y[t].unsqueeze(1))

            losses_after_t.append(ls.mean())
            all_losses_after_t.append(ls.flatten())
        return torch.stack(all_losses_after_t).to('cpu'), torch.tensor(losses_after_t).to('cpu'), time.time() - start_time
138
+
139
if __name__ == '__main__':
    # Smoke test: sample a batch from the prior and score the matching GP on it.
    # The hyperparameters must be a dict: get_batch only decodes tuples with at
    # least 8 entries, and evaluate reads its hyperparameters via `.get`.
    hps = {"noise": .1, "outputscale": .1, "lengthscale": .1}
    for redo_idx in range(1):
        print(
            evaluate(*get_batch(1000, 10, hyperparameters=hps, num_features=10), use_mse=False, hyperparameters=hps))
lcpfn/priors/fast_gp_mix.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import functools
3
+ import random
4
+ import math
5
+ import traceback
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+ import gpytorch
11
+ from botorch.models import SingleTaskGP
12
+ from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
13
+ from botorch.fit import fit_gpytorch_model
14
+ from gpytorch.mlls import ExactMarginalLogLikelihood
15
+ from gpytorch.likelihoods import GaussianLikelihood
16
+ from gpytorch.priors.torch_priors import GammaPrior, UniformPrior
17
+ from gpytorch.constraints import GreaterThan
18
+
19
+
20
+ from bar_distribution import BarDistribution
21
+ from utils import default_device
22
+ from .utils import get_batch_to_dataloader
23
+ from . import fast_gp
24
+
25
def get_model(x, y, hyperparameters: dict, sample=True):
    """Build a GP with priors over its hyperparameters, optionally sampled from those priors.

    :param x: training inputs.
    :param y: training targets (a trailing dim is added for the botorch model).
    :param hyperparameters: dict of options.  ``'handmade'`` selects a plain
        gpytorch ExactGP with uniform priors; otherwise a botorch
        ``SingleTaskGP`` with a Matern kernel and Gamma priors is built, whose
        prior parameters can be overridden through ``'noise_concentration'``,
        ``'noise_rate'``, ``'nu'``, ``'lengthscale_concentration'``,
        ``'lengthscale_rate'``, ``'outputscale_concentration'`` and
        ``'outputscale_rate'``.
    :param sample: if true, return a model drawn via ``pyro_sample_from_prior``;
        otherwise return the un-sampled model (``'sigmoid'`` and
        ``'y_minmax_norm'`` must then be unset).
    :return: ``(model, likelihood)``.
    """
    if hyperparameters.get('handmade', False):
        # We will use the simplest form of GP model, exact inference
        class ExactGPModel(gpytorch.models.ExactGP):
            def __init__(self, train_x, train_y, likelihood):
                super().__init__(train_x, train_y, likelihood)
                self.mean_module = gpytorch.means.ConstantMean()
                self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel())
                self.mean_module.register_prior("mean_prior", UniformPrior(-1, 1), "constant")
                self.covar_module.base_kernel.register_prior(
                    "lengthscale_prior", UniformPrior(0.01, 0.5), "lengthscale")
                self.covar_module.register_prior("outputscale_prior", UniformPrior(1, 2), "outputscale")
                likelihood.register_prior("noise_prior", UniformPrior(0.001, 0.01), "noise")
                self.to(x)

            def forward(self, x):
                return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

        likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())
        model = ExactGPModel(x, y, likelihood)
    else:
        targets = y.unsqueeze(-1)
        # A throw-away model is built only to read off the augmented batch shape.
        aug_batch_shape = SingleTaskGP(x, targets)._aug_batch_shape

        noise_prior = GammaPrior(hyperparameters.get('noise_concentration', 1.1),
                                 hyperparameters.get('noise_rate', 0.05))
        noise_prior_mode = (noise_prior.concentration - 1) / noise_prior.rate
        noise_constraint = GreaterThan(
            MIN_INFERRED_NOISE_LEVEL,
            transform=None,
            initial_value=noise_prior_mode,
        )
        likelihood = GaussianLikelihood(
            noise_prior=noise_prior,
            batch_shape=aug_batch_shape,
            noise_constraint=noise_constraint,
        )

        lengthscale_prior = gpytorch.priors.GammaPrior(
            hyperparameters.get('lengthscale_concentration', 3.0),
            hyperparameters.get('lengthscale_rate', 6.0))
        base_kernel = gpytorch.kernels.MaternKernel(
            nu=hyperparameters.get('nu', 2.5),
            ard_num_dims=x.shape[-1],
            batch_shape=aug_batch_shape,
            lengthscale_prior=lengthscale_prior,
        )
        outputscale_prior = gpytorch.priors.GammaPrior(
            hyperparameters.get('outputscale_concentration', .5),
            hyperparameters.get('outputscale_rate', 0.15))
        covar_module = gpytorch.kernels.ScaleKernel(
            base_kernel,
            batch_shape=aug_batch_shape,
            outputscale_prior=outputscale_prior,
        )
        model = SingleTaskGP(x, targets, covar_module=covar_module, likelihood=likelihood)

    likelihood = model.likelihood
    model.to(x.device)
    if not sample:
        assert not(hyperparameters.get('sigmoid', False)) and not(hyperparameters.get('y_minmax_norm', False)), "Sigmoid and y_minmax_norm can only be used to sample models..."
        return model, likelihood

    sampled_model = model.pyro_sample_from_prior()
    return sampled_model, sampled_model.likelihood
84
+
85
+
86
def _postprocess_sample(sample, hyperparameters):
    """Apply the optional per-row min-max scaling and sigmoid squashing to ``sample``.

    ``sample`` is expected to be laid out as (functions, sequence positions);
    the min/max are taken per function, i.e. along dim 1.
    """
    if hyperparameters.get('y_minmax_norm'):
        # keepdim so the per-row extrema broadcast against their own row.
        lowest = sample.min(1, keepdim=True)[0]
        highest = sample.max(1, keepdim=True)[0]
        sample = (sample - lowest) / (highest - lowest)
    if hyperparameters.get('sigmoid'):
        sample = sample.sigmoid()
    return sample


@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
              batch_size_per_gp_sample=None,
              fix_to_range=None, equidistant_x=False, **kwargs):
    '''
    This function is very similar to the equivalent in .fast_gp. The only difference is that this function operates over
    a mixture of GP priors.

    :param batch_size: number of functions to return.
    :param seq_len: number of input points per function.
    :param num_features: dimensionality of each input point.
    :param device: device for the inputs and the GP models.
    :param hyperparameters: dict forwarded to `get_model`; additionally read here are 'fast_computations',
        'observation_noise', 'handmade', 'y_minmax_norm' and 'sigmoid'. `None` means an empty dict.
    :param batch_size_per_gp_sample: number of functions that share one sampled GP
        (default: `max(batch_size // 10, 1)`); must divide `batch_size`.
    :param fix_to_range: optional `(low, high)`; functions with any value outside `[low, high)` are rejected.
        Twice as many candidates as needed are drawn per GP and sampling is repeated until enough survive.
    :param equidistant_x: use an evenly spaced grid on [0, 1] as inputs (requires `num_features == 1`).
    :param kwargs: accepted and ignored.
    :return: `(x, y, target_y)` with `x` of shape (seq_len, batch_size, num_features) and `y`, `target_y` of shape
        (seq_len, batch_size). `target_y` is `y` itself unless 'observation_noise' is falsy, in which case it is
        the noise-free sample.
    '''
    hyperparameters = hyperparameters or {}
    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))):
        batch_size_per_gp_sample = (batch_size_per_gp_sample or max(batch_size // 10,1))
        assert batch_size % batch_size_per_gp_sample == 0

        # With rejection sampling, draw twice as many candidates as are needed.
        oversampling = 2 if fix_to_range is not None else 1
        total_num_candidates = batch_size * oversampling
        num_candidates = batch_size_per_gp_sample * oversampling
        if equidistant_x:
            assert num_features == 1
            # Created on `device`, like the random inputs below.
            x = torch.linspace(0, 1., seq_len, device=device).unsqueeze(0).repeat(total_num_candidates, 1).unsqueeze(-1)
        else:
            x = torch.rand(total_num_candidates, seq_len, num_features, device=device)

        sampling_with_observation_noise = hyperparameters.get("observation_noise", True)
        xs = []
        samples = []
        samples_wo_noise = []
        # Accumulated over all GP samples; averaged for the debug print below.
        throwaway_share = 0.
        for i in range(0, total_num_candidates, num_candidates):
            x_candidates = x[i:i + num_candidates]
            model, likelihood = get_model(x_candidates, torch.zeros(num_candidates, x.shape[1]).to(device), hyperparameters)
            model.to(device)
            likelihood.to(device)
            if hyperparameters.get('handmade', False):
                model.covar_module.base_kernel.lengthscale = model.covar_module.base_kernel.lengthscale.to(device)
                model.covar_module.outputscale = model.covar_module.outputscale.to(device)
                likelihood.noise = likelihood.noise.to(device)
                model.mean_module.constant = model.mean_module.constant.to(device)

            failed_attempts = 0
            while True:
                with gpytorch.settings.prior_mode(True):
                    d = model(x_candidates)
                    if sampling_with_observation_noise:
                        sample_wo_noise = None
                        sample = likelihood(d).sample()  # bs_per_gp_s x T
                    else:
                        sample_wo_noise = d.sample()
                        sample = likelihood(sample_wo_noise).sample()

                sample = _postprocess_sample(sample, hyperparameters)
                if sample_wo_noise is not None:
                    sample_wo_noise = _postprocess_sample(sample_wo_noise, hyperparameters)

                x_kept = x_candidates
                if fix_to_range is not None:
                    smaller_mask = sample < fix_to_range[0]
                    larger_mask = sample >= fix_to_range[1]
                    in_range_mask = ~ (smaller_mask | larger_mask).any(1)
                    throwaway_share += (~in_range_mask[:batch_size_per_gp_sample]).sum()/batch_size_per_gp_sample
                    if in_range_mask.sum() < batch_size_per_gp_sample:
                        failed_attempts += 1
                        # Only warn once rejection has failed many times in a row.
                        if failed_attempts > 100:
                            print("Please change hyper-parameters (e.g. decrease outputscale_mean) it"
                                  "seems like the range is set to tight for your hyper-parameters.")
                        continue

                    # Keep inputs, noisy samples and noise-free samples of the
                    # same surviving functions so that all three stay aligned.
                    x_kept = x_candidates[in_range_mask][:batch_size_per_gp_sample]
                    sample = sample[in_range_mask][:batch_size_per_gp_sample]
                    if sample_wo_noise is not None:
                        sample_wo_noise = sample_wo_noise[in_range_mask][:batch_size_per_gp_sample]

                xs.append(x_kept)
                samples.append(sample.transpose(0, 1))
                if sample_wo_noise is not None:
                    samples_wo_noise.append(sample_wo_noise.transpose(0, 1))
                break

        if random.random() < .01:
            print('throwaway share', throwaway_share/(batch_size//batch_size_per_gp_sample))

        # TODO think about enabling the line below
        #sample = sample - sample[0, :].unsqueeze(0).expand(*sample.shape)
        x = torch.cat(xs, 0).transpose(0, 1)
        sample = torch.cat(samples, 1)

        if sampling_with_observation_noise:
            target_sample = sample
        else:
            target_sample = torch.cat(samples_wo_noise, 1)

        assert x.shape[:2] == sample.shape[:2]

    return x, sample, target_sample # x.shape = (T,B,H)
198
+
199
+
200
class DataLoader(get_batch_to_dataloader(get_batch)):
    """Data loader for the GP-mixture prior with an MSE-based validation routine."""

    @torch.no_grad()
    def validate(self, model, step_size=1, start_pos=0):
        """Compute the MSE of the model's predicted means on one freshly drawn batch.

        :param model: model exposing ``criterion``; called as
            ``model((x, y), single_eval_pos=...)``.
        :param step_size: stride between evaluated positions.
        :param start_pos: first evaluated position.
        :return: tensor with one MSE per evaluated position when the criterion
            is a ``BarDistribution``; the constant ``123.`` otherwise.
        """
        if not isinstance(model.criterion, BarDistribution):
            return 123.

        (_, x, y), target_y, _ = self.gbm(**self.get_batch_kwargs)
        model.eval()
        mse = nn.MSELoss()
        losses = []
        for eval_pos in range(start_pos, len(x), step_size):
            logits = model((x, y), single_eval_pos=eval_pos)
            means = model.criterion.mean(logits)  # num_evals x batch_size
            # Only the first predicted position is scored against its target.
            losses.append(mse(means[0], target_y[eval_pos]))
        model.train()
        return torch.stack(losses)
216
+
217
+
218
@torch.enable_grad()
def get_fitted_model(x, y, hyperparameters, device):
    """Build the un-sampled GP for ``(x, y)`` and fit it with ``fit_gpytorch_model``.

    Gradients are enabled explicitly by the decorator, so the fit also works
    when the caller runs under ``torch.no_grad()``.

    :return: ``(model, likelihood)`` after fitting.
    """
    gp_model, gp_likelihood = get_model(x, y, hyperparameters, sample=False)
    gp_model.to(device)
    marginal_log_likelihood = ExactMarginalLogLikelihood(gp_likelihood, gp_model)
    gp_model.train()
    fit_gpytorch_model(marginal_log_likelihood)
    return gp_model, gp_likelihood
229
+
230
+
231
# fast_gp.evaluate with the model factory replaced by get_fitted_model from this module.
evaluate = functools.partial(fast_gp.evaluate, get_model_on_device=get_fitted_model)
232
+
233
def get_mcmc_model(x, y, hyperparameters, device, num_samples, warmup_steps, obs=True):
    """Run NUTS over the GP hyperparameters and load the posterior samples into the model.

    :param x: training inputs (moved to ``device``).
    :param y: training targets (moved to ``device``).
    :param hyperparameters: dict forwarded to ``get_model``.
    :param device: device for the data and the model.
    :param num_samples: number of MCMC samples to draw.
    :param warmup_steps: number of MCMC warm-up steps.
    :param obs: if true, the pyro model conditions on ``y`` through an ``"obs"``
        sample site; otherwise the pyro model only samples from the prior.
    :return: ``(model, likelihood)`` with the model in eval mode.
    """
    import pyro
    from pyro.infer.mcmc import MCMC, NUTS

    x, y = x.to(device), y.to(device)
    model, likelihood = get_model(x, y, hyperparameters, sample=False)
    model.to(device)

    def pyro_model(x, y):
        sampled_model = model.pyro_sample_from_prior()
        output = sampled_model.likelihood(sampled_model(x))
        if obs:
            return pyro.sample("obs", output, obs=y)

    mcmc_run = MCMC(NUTS(pyro_model), num_samples=num_samples, warmup_steps=warmup_steps, num_chains=1)
    mcmc_run.run(x, y)
    model.pyro_load_from_samples(mcmc_run.get_samples())
    model.eval()
    return model, likelihood
264
+
265
+
266
+
267
+
268
def get_mean_logdensity(dists, x: torch.Tensor, full_range=None):
    """Log-density of ``x`` under an equally weighted mixture of normals.

    :param dists: objects exposing ``mean`` and ``variance`` tensors; after
        squeezing and concatenation these form one 1-d vector of component
        means and one of component variances.
    :param x: points to evaluate, one per mixture component.
    :param full_range: optional ``(low, high)``; when given, each component is
        renormalised by the probability mass it places inside that range.
    :return: ``logsumexp`` of the component log-densities minus the log of the
        number of components.
    """
    component_means = torch.cat([d.mean.squeeze() for d in dists], 0)
    component_vars = torch.cat([d.variance.squeeze() for d in dists], 0)
    assert len(component_means.shape) == 1 and len(component_vars.shape) == 1

    mixture = torch.distributions.Normal(component_means, component_vars.sqrt())
    logprobs = mixture.log_prob(x)
    if full_range is not None:
        mass_below = mixture.cdf(torch.tensor(full_range[0]))
        mass_above = 1. - mixture.cdf(torch.tensor(full_range[1]))
        used_weight = 1. - (mass_below + mass_above)
        log_weight = torch.log(used_weight)
        if torch.isinf(log_weight).any():
            print('factor is inf', -torch.log(used_weight))
        logprobs = logprobs - log_weight
    assert len(logprobs.shape) == 1
    return torch.logsumexp(logprobs, 0) - math.log(len(logprobs))
283
+
284
+
285
def evaluate_(x, y, y_non_noisy, hyperparameters=None, device=default_device, num_samples=100, warmup_steps=300,
              full_range=None, min_seq_len=0, use_likelihood=False, obs=True):
    """Score a fully Bayesian GP (hyperparameters inferred with NUTS) one step ahead.

    For every position ``t`` and every batch element a GP is built on the first
    ``t`` points, its hyperparameters are sampled with MCMC, and the negative
    mean log-density of point ``t`` under the resulting mixture is recorded.

    :param x: inputs, sequence dimension first and batch dimension second.
    :param y: targets aligned with ``x``.
    :param y_non_noisy: unused; kept so the function can be called with the
        full output of ``get_batch``.
    :param hyperparameters: dict forwarded to ``get_model``; the optional key
        ``'fast_computations'`` configures gpytorch.  ``None`` is treated as an
        empty dict.
    :param device: device for the data and the models.
    :param num_samples: number of MCMC samples per fit.
    :param warmup_steps: number of MCMC warm-up steps per fit.
    :param full_range: optional ``(low, high)`` passed to ``get_mean_logdensity``.
    :param min_seq_len: first evaluated position (position 0 is never fitted;
        when ``min_seq_len == 0`` a placeholder loss of ``0.`` is recorded).
    :param use_likelihood: score under the likelihood of the GP output instead
        of the GP output itself.
    :param obs: if true, the MCMC model conditions on the training targets.
    :return: ``(losses_after_t, seconds, all_losses)``: mean loss per position,
        elapsed wall-clock time, and the per-element losses of every position.
    """
    # Imported once up front instead of on every inner-loop iteration.
    import pyro
    from pyro.infer.mcmc import MCMC, NUTS

    # The declared default of None used to fail on `.get` below.
    hyperparameters = {} if hyperparameters is None else hyperparameters
    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))), gpytorch.settings.fast_pred_var(False):
        x = x.to(device).double()
        y = y.to(device).double()
        start_time = time.time()
        losses_after_t = [.0] if min_seq_len == 0 else []
        all_losses = []

        for t in range(max(min_seq_len,1), len(x)):
            loss_sum = 0.
            step_losses = []
            start_step = time.time()
            print(x.shape, y.shape)
            for b_i in range(x.shape[1]):
                # x and y already live on `device`, so the slices do too.
                x_train = x[:t,b_i]
                y_train = y[:t,b_i]
                print(x_train.shape, y_train.shape)
                model, likelihood = get_model(x_train, y_train, hyperparameters, sample=False)
                model.to(device)

                def pyro_model(x, y):
                    sampled_model = model.pyro_sample_from_prior()
                    output = sampled_model.likelihood(sampled_model(x))
                    if obs:
                        return pyro.sample("obs", output, obs=y)

                nuts_kernel = NUTS(pyro_model)
                mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, num_chains=1, disable_progbar=True)
                mcmc_run.run(x_train, y_train)
                model.pyro_load_from_samples(mcmc_run.get_samples())
                model.eval()

                with torch.no_grad():
                    # One copy of the query point per loaded MCMC sample.
                    dists = model(x[t, b_i, :].unsqueeze(
                        0).repeat(num_samples, 1, 1))
                    if use_likelihood:
                        dists = likelihood(dists)
                    l = -get_mean_logdensity([dists], y[t, b_i].repeat(num_samples), full_range)
                    print(l)

                step_losses.append(l.item())
                print(f'current average loss at step {t} is {sum(step_losses)/len(step_losses)} with {(time.time()-start_step)/len(step_losses)} s per eval.')
                loss_sum += l

            loss_sum /= x.shape[1]
            all_losses.append(step_losses)
            print(f'loss after step {t} is {loss_sum}')
            losses_after_t.append(loss_sum)
            print(f'losses so far {torch.tensor(losses_after_t)}')
        return torch.tensor(losses_after_t), time.time() - start_time, all_losses
344
+
345
+
346
+
347
+
348
+
349
if __name__ == '__main__':
    # Command-line entry point: sample one batch from the prior and evaluate
    # the MCMC-based GP baseline on it.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--seq_len', type=int)
    parser.add_argument('--min_seq_len', type=int, default=0)
    parser.add_argument('--warmup_steps', type=int)
    parser.add_argument('--num_samples', type=int)
    parser.add_argument('--min_y', type=int)
    parser.add_argument('--max_y', type=int)
    parser.add_argument('--dim', type=int, default=1)
    parser.add_argument('--use_likelihood', action='store_true')
    parser.add_argument('--device', default='cpu')
    # The correctly spelled flag is the primary name; the historic misspelling
    # is kept as an alias so existing command lines keep working.
    parser.add_argument('--outputscale_concentration', '--outputscale_concentraion', default=2., type=float)
    parser.add_argument('--noise_concentration', default=1.1, type=float)
    parser.add_argument('--noise_rate', default=.05, type=float)
    parser.add_argument('--handmade', action='store_true')
    parser.add_argument('--no_obs', action='store_true')
    parser.add_argument('--seed', type=int, default=0)

    args = parser.parse_args()
    if args.min_y is not None and args.max_y is None:
        # Without an upper bound the range filter would fail deep inside get_batch.
        parser.error('--max_y is required when --min_y is given.')

    import pyro
    print(pyro.__version__)
    print(gpytorch.__version__)

    print('min_y:', args.min_y)
    full_range = (None if args.min_y is None else (args.min_y,args.max_y))

    hps = {'handmade': args.handmade, 'outputscale_concentration': args.outputscale_concentration, 'noise_concentration': args.noise_concentration,
           'noise_rate': args.noise_rate, 'fast_computations': (False,False,False)}
    # A seed of 0 (the default) leaves all random number generators unseeded.
    if args.seed:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)
    x, y, _ = get_batch(args.batch_size, args.seq_len, args.dim, fix_to_range=full_range, hyperparameters=hps)
    print('RESULT:', evaluate_(x, y, y, device=args.device, warmup_steps=args.warmup_steps,
                               num_samples=args.num_samples, full_range=full_range, min_seq_len=args.min_seq_len,
                               hyperparameters=hps, use_likelihood=args.use_likelihood, obs=not args.no_obs))
393
+
394
+
lcpfn/priors/gp.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ from sklearn.gaussian_process import GaussianProcessRegressor
8
+ from sklearn.gaussian_process.kernels import RBF, DotProduct, WhiteKernel
9
+ from .utils import get_batch_to_dataloader
10
+
11
+
12
# Default RBF length scale: used by get_gp when no length scale is given and as the default of evaluate.
length_scale_sampling_gp = .6
13
+
14
def get_gp(length_scale=None):
    """Create a scikit-learn GP regressor with a fixed RBF kernel.

    :param length_scale: RBF length scale; a falsy value (``None`` or ``0``)
        falls back to ``length_scale_sampling_gp``.
    :return: ``GaussianProcessRegressor`` with ``optimizer=None`` and
        ``random_state=0``; the kernel's length-scale bounds are ``'fixed'``.
    """
    effective_length_scale = length_scale or length_scale_sampling_gp
    kernel = RBF(length_scale=effective_length_scale, length_scale_bounds='fixed')
    return GaussianProcessRegressor(kernel=kernel, random_state=0, optimizer=None)
18
+
19
+
20
def get_batch(batch_size, seq_len, num_features, noisy_std=None):
    """Sample a batch of functions from the scikit-learn GP prior.

    :param batch_size: number of functions to draw.
    :param seq_len: number of input points per function.
    :param num_features: dimensionality of each input point.
    :param noisy_std: despite its name, this value is passed to ``get_gp`` as
        the RBF length scale (``None`` selects the module default).
    :return: ``(x, y, y)`` with ``x`` of shape ``(seq_len, batch_size, num_features)``
        and ``y`` of shape ``(seq_len, batch_size)``; inputs are drawn with
        ``torch.rand`` and the targets double as the training labels.
    """
    x_t = torch.rand(batch_size, seq_len, num_features)

    gpr = get_gp(noisy_std)
    y_t = torch.zeros(batch_size, seq_len)

    for i in range(batch_size):
        # random.randint is inclusive on both ends, while integer seeds must
        # stay within [0, 2**32 - 1].
        seed = random.randint(0, 2 ** 32 - 1)
        drawn = gpr.sample_y(x_t[i], random_state=seed).squeeze()
        # Convert the NumPy draw explicitly before adding it to the tensor row.
        y_t[i] += torch.as_tensor(drawn, dtype=y_t.dtype)
    x, y = x_t.transpose(0, 1), y_t.transpose(0, 1)
    return x, y, y
36
+
37
+
38
# Data loader class created by get_batch_to_dataloader (from .utils) from this module's get_batch.
DataLoader = get_batch_to_dataloader(get_batch)
39
+
40
def evaluate(x, y, y_non_noisy, use_mse=False, length_scale=length_scale_sampling_gp):
    """Score one-step-ahead predictions of the scikit-learn GP.

    For every position ``t >= 1`` and every batch element, a GP is fitted on
    the first ``t`` points and scored on point ``t``.

    :param x: inputs, sequence dimension first and batch dimension second.
    :param y: targets aligned with ``x``.
    :param y_non_noisy: unused; kept so the function can be called with the
        full output of ``get_batch``.
    :param use_mse: score with the squared error of the predictive mean instead
        of the Gaussian negative log-likelihood.
    :param length_scale: RBF length scale of the fitted GPs.
    :return: ``(losses, seconds)``: the batch-averaged loss per position (with a
        leading ``0.`` for position 0) and the elapsed wall-clock time.
    """
    start_time = time.time()
    num_datasets = x.shape[1]
    losses_after_t = [.0]
    for t in range(1, len(x)):
        loss_sum = 0.
        for b_i in range(num_datasets):
            gpr = get_gp(length_scale).fit(x[:t, b_i], y[:t, b_i])
            means, stds = gpr.predict(x[t, b_i].unsqueeze(0), return_std=True)
            assert len(means) == 1 == len(stds)

            prediction = torch.tensor(means)
            target = y[t, b_i].unsqueeze(-1)
            if use_mse:
                loss_sum += nn.MSELoss()(prediction, target)
            else:
                loss_sum += nn.GaussianNLLLoss(full=True)(prediction, target, var=torch.tensor(stds) ** 2)

        losses_after_t.append(loss_sum / num_datasets)

    return torch.tensor(losses_after_t), time.time()-start_time
62
+
63
if __name__ == '__main__':
    # Smoke test: sample with one length scale and evaluate GPs whose length
    # scale is equal to it, 10% larger and 10% smaller.
    ls = .1
    for alpha in {ls, ls * 1.1, ls * .9}:
        print(alpha)
        for redo_idx in range(1):
            batch = get_batch(1000, 10, noisy_std=ls, num_features=10)
            print(evaluate(*batch, use_mse=False, length_scale=alpha))
lcpfn/priors/prior.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABCMeta, abstractmethod
2
+ from torch.utils.data import DataLoader
3
+
4
+
5
class PriorDataLoader(DataLoader, metaclass=ABCMeta):
    """Abstract base class describing the constructor every prior data loader must offer."""

    @abstractmethod
    def __init__(self, num_steps, batch_size, eval_pos_seq_len_sampler, seq_len_maximum, device, **kwargs):
        """
        Subclasses have to provide their own constructor with this signature.

        :param num_steps: int, first argument, the number of steps to take per epoch, i.e. iteration of the DataLoader
        :param batch_size: int, number of datasets per batch
        :param eval_pos_seq_len_sampler: callable, it takes no arguments and returns a tuple (single eval pos, bptt)
        :param seq_len_maximum: int, upper bound on the sequence length (`train.py` passes bptt plus any extra samples)
        :param device: the device handed over by `train.py`
        :param kwargs: for future compatibility it is good to have a final all catch, as new kwargs might be introduced
        """

    # A class or object variable `num_features`: int
    # Optional: `validate` function that accepts a transformer model

    # The DataLoader iter should return batches of the form ([style], x, y), target_y, single_eval_pos
    # We follow sequence len (s) first, batch size (b) second. So x: (s,b,num_features), y,target_y: (s,b)
    # and style: Optional[(b,num_style_params)], style can be omitted or set to None, if it is not intended to be used.

    # For more references, see `priors/utils.py` for a pretty general implementation of a DataLoader
    # and `train.py` for the only call of it.
lcpfn/priors/pyro.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from utils import default_device
7
+ from .utils import get_batch_to_dataloader
8
+
9
+
10
def get_batch(batch_size, seq_len, batch_size_per_gp_sample=None, **config):
    """Sample a batch of datasets from freshly instantiated models.

    `config['model']` is called without arguments to build each model; every model is
    then called as `model(seq_len=seq_len)` and must return an `(x, y)` pair of tensors.

    :param batch_size: total number of datasets in the batch
    :param seq_len: number of points per dataset, forwarded to each model
    :param batch_size_per_gp_sample: how many datasets are drawn from the same model;
        defaults to a sixteenth of the batch (at least 1)
    :param config: must contain `model`; optional flags `normalize_y` (standardize x and y)
        and `normalize` (standardize x only)
    :return: `(x, y, y)`, where x and y are stacked along dimension 1 (the batch dimension)
    """
    # `batch_size // 16` is 0 for small batches, which would make the modulo below
    # fail with ZeroDivisionError; fall back to one dataset per model instead.
    batch_size_per_gp_sample = batch_size_per_gp_sample or max(1, batch_size // 16)
    assert batch_size % batch_size_per_gp_sample == 0, 'Please choose a batch_size divisible by batch_size_per_gp_sample.'
    num_models = batch_size // batch_size_per_gp_sample
    # standard kaiming uniform init currently...

    models = [config['model']() for _ in range(num_models)]

    # Consecutive datasets share a model: model 0 fills the first
    # `batch_size_per_gp_sample` slots, model 1 the next ones, and so on.
    sample = [model(seq_len=seq_len) for model in models for _ in range(batch_size_per_gp_sample)]

    def normalize_data(data):
        """Standardize along dimension 0; the epsilon guards against zero variance."""
        mean = data.mean(0)
        std = data.std(0) + .000001
        return (data - mean) / std

    x, y = zip(*sample)

    y = torch.stack(y, 1).squeeze(-1).detach()
    x = torch.stack(x, 1).detach()

    # Each flag is looked up under its own key. The previous `elif` tested for
    # 'normalize_y' but then read config['normalize'], so it could only be
    # reached with a falsy 'normalize_y' and raised KeyError without 'normalize'.
    if config.get('normalize_y'):
        x, y = normalize_data(x), normalize_data(y)
    elif config.get('normalize'):
        x = normalize_data(x)

    return x, y, y
38
+
39
+
40
+ DataLoader = get_batch_to_dataloader(get_batch)
41
+
lcpfn/priors/ridge.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import time
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ from sklearn.linear_model import Ridge
8
+ from .utils import get_batch_to_dataloader
9
+
10
def get_batch(batch_size, seq_len, num_features, noisy_std = .1):
    """Sample a batch of noisy linear regression datasets.

    Each dataset gets its own weight vector drawn from N(0, 0.1^2); inputs are
    uniform in [0, 1) and the targets are the dot product of weights and inputs.

    :param batch_size: number of datasets
    :param seq_len: number of points per dataset
    :param num_features: dimensionality of each input point
    :param noisy_std: standard deviation of the Gaussian noise added to the targets
    :return: `(x, y, y_non_noisy)` with x of shape (seq_len, batch_size, num_features)
        and both target tensors of shape (seq_len, batch_size)
    """
    m = torch.normal(0., .1, size=(batch_size, num_features))
    x = torch.rand(seq_len, batch_size, num_features)
    # No bias term: every sampled function passes through the origin.
    y_non_noisy = torch.einsum('bf,tbf->tb', m, x)
    y = y_non_noisy + torch.normal(torch.zeros_like(y_non_noisy), noisy_std)  # noisy_std is alpha
    return x, y, y_non_noisy
17
+
18
+ DataLoader = get_batch_to_dataloader(get_batch)
19
+
20
+
21
def evaluate(x,y,y_non_noisy, alpha=0.):
    """Score ridge regression on every next-point prediction of every dataset.

    For each position `t >= 1` and each dataset `b_i`, a `Ridge(alpha=alpha)` model is
    fitted on the first `t` points and its prediction for point `t` is compared,
    via mean squared error, against the noise-free target.

    :param x: inputs, indexed as `x[t, b_i]`
    :param y: noisy targets used for fitting
    :param y_non_noisy: noise-free targets used for scoring
    :param alpha: regularization strength handed to `Ridge`
    :return: `(losses, seconds)` where `losses[t]` is the mean loss over datasets for
        position `t` (`losses[0]` is a 0.0 placeholder) and `seconds` is the elapsed time
    """
    started = time.time()
    mse = nn.MSELoss()
    num_datasets = x.shape[1]

    losses_after_t = [.0]
    for t in range(1, len(x)):
        running = 0.
        for b_i in range(num_datasets):
            regressor = Ridge(alpha=alpha)
            regressor.fit(x[:t, b_i], y[:t, b_i])
            prediction = regressor.predict(x[t, b_i].unsqueeze(0))
            running += mse(y_non_noisy[t, b_i].unsqueeze(0), torch.tensor(prediction))
        losses_after_t.append(running / num_datasets)
    return torch.tensor(losses_after_t), time.time() - started
34
+
35
if __name__ == '__main__':
    # Smoke test over a few regularization strengths. `get_batch` requires
    # `num_features`, which the previous call omitted (TypeError); 10 is an
    # arbitrary choice for this demo.
    for alpha in [.001, .01, .5, 1.]:
        print(alpha, evaluate(*get_batch(1000, 10, num_features=10, noisy_std=.01), alpha=alpha))
lcpfn/priors/stroke.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageDraw, ImageFilter
2
+ import random
3
+ import math
4
+
5
+ import torch
6
+ import numpy as np
7
+ from .utils import get_batch_to_dataloader
8
+
9
def mnist_prior(num_classes=2, size=28, min_max_strokes=(1,3), min_max_len=(5/28,20/28), min_max_start=(2/28,25/28),
                min_max_width=(1/28,4/28), max_offset=4/28, max_target_offset=2/28):
    """Build one random stroke-image generator per class.

    Every class is a fixed set of strokes (start point, length, direction). Calling a
    returned generator draws these strokes with a random line width, a random global
    offset and a random perturbation of each stroke's end point.

    All `min_max_*` / `max_*` arguments are fractions of `size` and are converted to
    pixel values with `int(size * fraction)`.

    :param num_classes: number of generators to create
    :param size: side length of the square images in pixels
    :param min_max_strokes: inclusive range for the number of strokes per class
    :param min_max_len: range of the stroke length
    :param min_max_start: range of both coordinates of a stroke's start point
    :param min_max_width: range of the line width chosen per generated image
    :param max_offset: maximum absolute shift applied to all strokes of an image
    :param max_target_offset: maximum absolute perturbation of a stroke's end point
    :return: list of `num_classes` zero-argument callables, each returning a PIL image
    """
    def random_length():
        # Stroke length in pixels.
        return random.randint(int(size * min_max_len[0]), int(size * min_max_len[1]))

    def random_start():
        # Start point in pixels; x is drawn before y.
        low, high = int(size * min_max_start[0]), int(size * min_max_start[1])
        return random.randint(low, high), random.randint(low, high)

    classes = []
    for _ in range(num_classes):
        num_strokes = random.randint(*min_max_strokes)
        len_strokes = [random_length() for _ in range(num_strokes)]
        stroke_start_points = [random_start() for _ in range(num_strokes)]
        stroke_directions = []
        for stroke_idx in range(num_strokes):
            sp, length = stroke_start_points[stroke_idx], len_strokes[stroke_idx]
            counter = 0
            # Rejection sampling: retry directions until the stroke ends inside the
            # image; on the first and then every third attempt also redraw the
            # start point and the length.
            while True:
                if counter % 3 == 0:
                    length = random_length()
                    sp = random_start()
                    stroke_start_points[stroke_idx], len_strokes[stroke_idx] = sp, length
                radians = random.random() * 2 * math.pi
                x_vel = math.cos(radians) * length
                y_vel = math.sin(radians) * length
                new_p = (sp[0] + x_vel, sp[1] + y_vel)
                if not any(n > size - 1 or n < 0 for n in new_p):
                    break
                counter += 1
            stroke_directions.append(radians)
        classes.append((len_strokes, stroke_start_points, stroke_directions))

    generator_functions = []
    for c in classes:
        # `c=c` binds the current class as a default value, so each generator keeps
        # its own strokes instead of sharing the loop variable.
        def g(c=c):
            len_strokes, stroke_start_points, stroke_directions = c
            canvas = Image.fromarray(np.zeros((size, size), dtype=np.uint8))
            draw = ImageDraw.Draw(canvas)
            width = random.randint(int(size * min_max_width[0]), int(size * min_max_width[1]))
            offset = random.randint(int(-size * max_offset), int(size * max_offset)), random.randint(int(- size * max_offset), int(size * max_offset))
            for sp, length, radians in zip(stroke_start_points, len_strokes, stroke_directions):
                sp = (sp[0] + offset[0], sp[1] + offset[1])
                x_vel = math.cos(radians) * length + random.randint(int(-size * max_target_offset), int(size * max_target_offset))
                y_vel = math.sin(radians) * length + random.randint(int(-size * max_target_offset), int(size * max_target_offset))
                new_p = (sp[0] + x_vel, sp[1] + y_vel)
                # The class definition is read-only here; it used to be extended with
                # `radians` on every call, growing the shared list without bound.
                draw.line([round(v) for v in sp + new_p], fill=128, width=width)
            pixels = np.array(canvas)
            # Replace the flat stroke color by random intensities in [200, 255).
            pixels[pixels == 128] = np.random.randint(200, 255, size=pixels.shape)[pixels == 128]
            return Image.fromarray(pixels).filter(ImageFilter.GaussianBlur(.2))

        generator_functions.append(g)
    return generator_functions
64
+
65
+
66
+ # g1,g2 = mnist_prior(2)
67
+
68
+ # for i in [g1() for _ in range(10)]:
69
+ # display(i.resize((200,200)))
70
+
71
+ from torchvision.transforms import ToTensor, ToPILImage
72
+
73
+
74
def normalize(x):
    """Standardize `x` with its overall mean and standard deviation."""
    centered = x - x.mean()
    # The small constant avoids a division by zero for constant inputs.
    return centered / (x.std() + .000001)
76
+
77
+ from os import path, listdir
78
+ import random
79
+
80
def get_batch(batch_size, seq_len, num_features=None, noisy_std=None, only_train_for_last_idx=False, normalize_x=False, num_outputs=2, use_saved_from=None, **kwargs):  # num_features = 28*28=784
    """Sample a batch of stroke-image classification sequences.

    :param batch_size: number of sequences; each one uses freshly sampled classes
    :param seq_len: number of images per sequence
    :param num_features: number of pixels per image, has to be a square number
    :param noisy_std: accepted for signature compatibility, not used
    :param only_train_for_last_idx: if set, every class appears equally often among the
        first `seq_len - 1` images and only the last position gets a real target
        (all other targets are -100)
    :param normalize_x: standardize every image with `normalize`
    :param num_outputs: number of classes
    :param use_saved_from: if given, a random file of the matching sub-directory is
        loaded with `torch.load` and returned instead of sampling
    :param kwargs: forwarded to `mnist_prior`
    :return: `(x, y, target_y)` with x of shape (seq_len, batch_size, num_features) and
        label tensors of shape (seq_len, batch_size)
    """
    if use_saved_from is not None:
        # NOTE(review): `torch.load` unpickles the file - only use trusted directories.
        folder = path.join(use_saved_from, f'len_{seq_len}_out_{num_outputs}_features_{num_features}_bs_{batch_size}')
        chosen = random.choice(listdir(folder))
        return torch.load(path.join(folder, chosen))

    size = math.isqrt(num_features)
    assert size * size == num_features, 'num_features needs to be the square of an integer.'
    if only_train_for_last_idx:
        assert (seq_len-1) % num_outputs == 0

    to_tensor = ToTensor()
    images, labels, targets = [], [], []
    for _ in range(batch_size):
        gs = mnist_prior(num_outputs, size, **kwargs)
        num_generators = len(gs)
        if only_train_for_last_idx:
            # Balanced, shuffled context followed by one random query class.
            repeats = (seq_len - 1) // num_outputs
            generators = []
            for class_idx in range(num_generators):
                generators.extend([class_idx] * repeats)
            random.shuffle(generators)
            generators.append(random.randint(0, num_generators - 1))
            target = [-100] * (len(generators) - 1) + [generators[-1]]
        else:
            generators = [random.randint(0, num_generators - 1) for _ in range(seq_len)]
            target = generators

        drawn = []
        for f_i in generators:
            image = to_tensor(gs[f_i]())
            drawn.append(normalize(image) if normalize_x else image)
        images.append(torch.cat(drawn, 0))
        labels.append(torch.tensor(generators))
        targets.append(torch.tensor(target))

    x = torch.stack(images, 1).view(seq_len, batch_size, -1)
    y = torch.stack(labels, 1)
    target_y = torch.stack(targets, 1)
    return x, y, target_y
115
+
116
+ DataLoader = get_batch_to_dataloader(get_batch)
117
+ DataLoader.num_outputs = 2
118
+
119
if __name__ == '__main__':
    # Smoke test: build two generators on a tiny canvas ...
    g1, g2 = mnist_prior(2, size=3)

    # ... and sample one sequence of ten 10x10 images.
    size = 10
    # `get_batch` returns three values (x, y, target_y); unpacking into two names
    # raised a ValueError.
    x, y, target_y = get_batch(1, 10, num_features=size * size)

    # Drop the batch dimension of size 1. x holds pixels only, so every row can be
    # viewed as a full `size x size` image.
    x = x.squeeze(1)
    y = y.squeeze(1)
    target_y = target_y.squeeze(1)

    for flat_image, label in zip(x, y):
        img = ToPILImage()(flat_image.view(size, size))
        # display(img.resize((200,200)))

    print(y, target_y)
lcpfn/priors/utils.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import pandas as pd
4
+ import torch
5
+
6
+ from lcpfn.utils import set_locals_in_self
7
+ from itertools import repeat
8
+ from .prior import PriorDataLoader
9
+ from torch import nn
10
+ import numpy as np
11
+ import matplotlib.pyplot as plt
12
+ import matplotlib.gridspec as gridspec
13
+ import scipy.stats as stats
14
+ import math
15
+
16
def get_batch_to_dataloader(get_batch_method_):
    """Wrap a `get_batch` function into a `PriorDataLoader` subclass.

    :param get_batch_method_: function returning `(x, y, target_y)` or
        `(x, y, target_y, style)`; it is called with the keyword arguments given to the
        loader's constructor plus `single_eval_pos` and `seq_len`
    :return: the generated loader class
    """
    class DL(PriorDataLoader):
        get_batch_method = get_batch_method_

        # Caution, you might need to set self.num_features manually if it is not part of the args.
        def __init__(self, num_steps, **get_batch_kwargs):
            # Has to stay the first statement: it receives exactly the constructor's
            # locals (self, num_steps, get_batch_kwargs).
            set_locals_in_self(locals())

            # The stuff outside the or is set as class attribute before instantiation.
            self.num_features = get_batch_kwargs.get('num_features') or self.num_features
            print('DataLoader.__dict__', self.__dict__)

        @staticmethod
        def gbm(*args, eval_pos_seq_len_sampler, **kwargs):
            """Sample one batch and bring it into the `(style, x, y), target_y, single_eval_pos` form."""
            single_eval_pos, seq_len = eval_pos_seq_len_sampler()
            kwargs['single_eval_pos'] = single_eval_pos
            kwargs['seq_len'] = seq_len
            # Scales the batch size dynamically with the power of 'dynamic_batch_size'.
            # A transformer with quadratic memory usage in the seq len would need a power of 2 to keep memory constant.
            if 'dynamic_batch_size' in kwargs and kwargs['dynamic_batch_size'] > 0:
                power = kwargs['dynamic_batch_size']
                kwargs['batch_size'] = kwargs['batch_size'] * math.floor(
                    math.pow(kwargs['seq_len_maximum'], power) / math.pow(seq_len, power)
                )
            batch = get_batch_method_(*args, **kwargs)
            if len(batch) == 4:
                x, y, target_y, style = batch
            else:
                # No style returned by the prior.
                x, y, target_y = batch[0], batch[1], batch[2]
                style = None
            return (style, x, y), target_y, single_eval_pos

        def __len__(self):
            return self.num_steps

        def __iter__(self):
            # A generator expression already is an iterator; every step samples lazily.
            steps = range(self.num_steps)
            return (self.gbm(**self.get_batch_kwargs) for _ in steps)

    return DL
46
+
47
+ """
48
+ import seaborn as sns
49
+ def plot_features(data, targets, fig=None):
50
+ if torch.is_tensor(data):
51
+ data = data.detach().cpu().numpy()
52
+ targets = targets.detach().cpu().numpy()
53
+ fig2 = plt.figure(figsize=(8, 8))
54
+ spec2 = gridspec.GridSpec(ncols=data.shape[1], nrows=data.shape[1], figure=fig2)
55
+ for d in range(0, data.shape[1]):
56
+ for d2 in range(0, data.shape[1]):
57
+ sub_ax = fig2.add_subplot(spec2[d, d2])
58
+ if d == d2:
59
+ sns.kdeplot(data[:, d],hue=targets[:],ax=sub_ax,legend=False, palette="deep")
60
+ sub_ax.set(ylabel=None)
61
+ else:
62
+ sns.scatterplot(data[:, d], data[:, d2],
63
+ hue=targets[:],legend=False, palette="deep")
64
+ #plt.scatter(data[:, d], data[:, d2],
65
+ # c=targets[:])
66
+ sub_ax.get_xaxis().set_ticks([])
67
+ sub_ax.get_yaxis().set_ticks([])
68
+ plt.subplots_adjust(wspace=0.05, hspace=0.05)
69
+ fig2.show()
70
+
71
+
72
+ def plot_prior(prior):
73
+ s = np.array([prior() for _ in range(0, 1000)])
74
+ count, bins, ignored = plt.hist(s, 50, density=True)
75
+ print(s.min())
76
+ plt.show()
77
+ """
78
+
79
def trunc_norm_sampler_f(mu, sigma):
    """Return a sampler for a normal(mu, sigma) truncated to [0, 1000000]."""
    def sample():
        # scipy expects the bounds in units of standard deviations around the mean.
        lower = (0 - mu) / sigma
        upper = (1000000 - mu) / sigma
        return stats.truncnorm(lower, upper, loc=mu, scale=sigma).rvs(1)[0]
    return sample


def beta_sampler_f(a, b):
    """Return a sampler drawing from `np.random.beta(a, b)`."""
    def sample():
        return np.random.beta(a, b)
    return sample


def gamma_sampler_f(a, b):
    """Return a sampler drawing from `np.random.gamma(a, b)`."""
    def sample():
        return np.random.gamma(a, b)
    return sample


def uniform_sampler_f(a, b):
    """Return a sampler drawing from `np.random.uniform(a, b)`."""
    def sample():
        return np.random.uniform(a, b)
    return sample


def uniform_int_sampler_f(a, b):
    """Return a sampler drawing from `np.random.uniform(a, b)` and rounding the result."""
    def sample():
        return round(np.random.uniform(a, b))
    return sample
84
def zipf_sampler_f(a, b, c):
    """Return a sampler for a Zipf-like distribution over the integers `b, ..., c - 1`.

    The integer `k` is weighted proportionally to `k ** (-a)`.

    :param a: exponent of the power law (int or float)
    :param b: smallest integer of the support
    :param c: exclusive upper end of the support
    :return: zero-argument callable returning `rvs(1)` of the distribution
    """
    x = np.arange(b, c)
    # Cast to float first: NumPy refuses to raise integers to negative integer
    # powers, which made integer exponents `a` fail.
    weights = x.astype(float) ** (-a)
    weights /= weights.sum()
    # The distribution is the same for every draw, so it is built only once.
    distribution = stats.rv_discrete(name='bounded_zipf', values=(x, weights))
    # NOTE(review): for a distribution without shape parameters SciPy may interpret the
    # positional 1 as `loc` rather than `size`; kept as is because callers rely on the
    # current return value - confirm before changing.
    return lambda : distribution.rvs(1)
scaled_beta_sampler_f = lambda a, b, scale, minimum : lambda : minimum + round(beta_sampler_f(a, b)() * (scale - minimum))
90
+
91
+
92
def normalize_by_used_features_f(x, num_features_used, num_features, normalize_with_sqrt=False):
    """Divide `x` by the fraction of used features, or by its square root.

    :param x: value(s) to rescale
    :param num_features_used: number of features in use
    :param num_features: total number of features
    :param normalize_with_sqrt: divide by the square root of the fraction instead
    :return: the rescaled `x`
    """
    used_fraction = num_features_used / num_features
    divisor = used_fraction ** (1 / 2) if normalize_with_sqrt else used_fraction
    return x / divisor
96
+
97
+
98
def order_by_y(x, y):
    """Reorder `x` and `y` along dimension 0 based on a sort of `y`.

    With equal probability `y` or `-y` is sorted; only the entries at `[:, 0, 0]`
    determine the order. The sorted order is then split into two halves that are
    interleaved (first of half one, first of half two, second of half one, ...),
    which requires an even number of rows.

    :return: the reordered `(x, y)`
    """
    sort_key = y if random.randint(0, 1) else -y
    ranking = torch.argsort(sort_key, dim=0)[:, 0, 0]
    # (2, n/2) -> (n/2, 2) -> flat: alternates between the two halves of the ranking.
    interleaved = ranking.reshape(2, -1).transpose(0, 1).reshape(-1)
    return x[interleaved], y[interleaved]
105
+
106
def randomize_classes(x, num_classes):
    """Relabel class indices in `x` with a random permutation of `0..num_classes-1`.

    Entries of `x` that equal none of the class indices are mapped to 0.

    :param x: tensor of class labels
    :param num_classes: number of distinct classes
    :return: tensor shaped like `x` holding the permuted labels
    """
    labels = torch.arange(0, num_classes, device=x.device)
    new_labels = torch.randperm(num_classes, device=x.device).type(x.type())
    # One-hot comparison against every label selects that label's replacement.
    matches = x.unsqueeze(-1) == labels
    return (matches * new_labels).sum(-1)
111
+
112
+
113
class CategoricalActivation(nn.Module):
    """Activation that turns a random subset of hidden dimensions into categorical values.

    The input is squashed with `nn.Softsign`. For a random subset of (batch, hidden)
    positions the activations are then replaced by the number of randomly drawn
    boundaries they exceed; a random subset of those is additionally passed through
    `randomize_classes`.
    """

    def __init__(self, categorical_p=0.1, ordered_p=0.7
                 , keep_activation_size=False
                 , num_classes_sampler=zipf_sampler_f(0.8, 1, 10)):
        """
        :param categorical_p: probability of a (batch, hidden) position becoming categorical
        :param ordered_p: probability of a categorical position being passed through `randomize_classes`
        :param keep_activation_size: rescale the output with the mean absolute input activation
        :param num_classes_sampler: zero-argument callable returning the number of classes per forward pass
        """
        # nn.Module has to be initialized before attributes are assigned to the instance.
        super().__init__()
        self.categorical_p = categorical_p
        self.ordered_p = ordered_p
        self.keep_activation_size = keep_activation_size
        self.num_classes_sampler = num_classes_sampler

    def forward(self, x):
        # x shape: T, B, H

        x = nn.Softsign()(x)

        num_classes = self.num_classes_sampler()
        hid_strength = torch.abs(x).mean(0).unsqueeze(0) if self.keep_activation_size else None

        # Boolean (B, H) mask of the positions that are discretized.
        categorical_classes = torch.rand((x.shape[1], x.shape[2])) < self.categorical_p
        class_boundaries = torch.zeros((num_classes - 1, x.shape[1], x.shape[2]), device=x.device, dtype=x.dtype)
        # For every batch entry and hidden dimension, use the activations at
        # `num_classes - 1` random time steps as class boundaries.
        for b in range(x.shape[1]):
            for h in range(x.shape[2]):
                ind = torch.randint(0, x.shape[0], (num_classes - 1,))
                class_boundaries[:, b, h] = x[ind, b, h]

        for b in range(x.shape[1]):
            x_rel = x[:, b, categorical_classes[b]]
            boundaries_rel = class_boundaries[:, b, categorical_classes[b]].unsqueeze(1)
            # Class = number of exceeded boundaries, shifted by half the number of classes.
            x[:, b, categorical_classes[b]] = (x_rel > boundaries_rel).sum(dim=0).float() - num_classes / 2

        ordered_classes = torch.rand((x.shape[1],x.shape[2])) < self.ordered_p
        ordered_classes = torch.logical_and(ordered_classes, categorical_classes)
        # NOTE(review): the values handed over here are shifted by `num_classes / 2`, while
        # `randomize_classes` matches against 0..num_classes-1 - confirm this is intended.
        x[:, ordered_classes] = randomize_classes(x[:, ordered_classes], num_classes)

        x = x * hid_strength if self.keep_activation_size else x

        return x
lcpfn/train.py ADDED
@@ -0,0 +1,602 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import itertools
3
+ import argparse
4
+ import time
5
+ import datetime
6
+ import yaml
7
+ from contextlib import nullcontext
8
+
9
+ import pickle
10
+ import torch
11
+ from torch import nn
12
+
13
+ from lcpfn import utils
14
+ from lcpfn.transformer import TransformerModel
15
+ from lcpfn.bar_distribution import (
16
+ BarDistribution,
17
+ FullSupportBarDistribution,
18
+ get_bucket_limits,
19
+ )
20
+ from lcpfn.utils import (
21
+ get_cosine_schedule_with_warmup,
22
+ get_openai_lr,
23
+ StoreDictKeyPair,
24
+ get_weighted_single_eval_pos_sampler,
25
+ get_uniform_single_eval_pos_sampler,
26
+ )
27
+ from lcpfn import priors
28
+ from lcpfn import encoders
29
+ from lcpfn import positional_encodings
30
+ from lcpfn.utils import init_dist
31
+ from torch.cuda.amp import autocast, GradScaler
32
+
33
+
34
+ class Losses:
35
+ gaussian = nn.GaussianNLLLoss(full=True, reduction="none")
36
+ mse = nn.MSELoss(reduction="none")
37
+ ce = lambda num_classes: nn.CrossEntropyLoss(
38
+ reduction="none", weight=torch.ones(num_classes)
39
+ )
40
+ bce = nn.BCEWithLogitsLoss(reduction="none")
41
+ get_BarDistribution = BarDistribution
42
+
43
+
44
+ def train(
45
+ priordataloader_class,
46
+ criterion,
47
+ encoder_generator,
48
+ emsize=200,
49
+ nhid=200,
50
+ nlayers=6,
51
+ nhead=2,
52
+ dropout=0.2,
53
+ epochs=10,
54
+ steps_per_epoch=100,
55
+ batch_size=200,
56
+ bptt=10,
57
+ lr=None,
58
+ weight_decay=0.0,
59
+ warmup_epochs=10,
60
+ input_normalization=False,
61
+ y_encoder_generator=None,
62
+ pos_encoder_generator=None,
63
+ decoder=None,
64
+ extra_prior_kwargs_dict={},
65
+ scheduler=get_cosine_schedule_with_warmup,
66
+ load_weights_from_this_state_dict=None,
67
+ validation_period=10,
68
+ single_eval_pos_gen=None,
69
+ bptt_extra_samples=None,
70
+ gpu_device="cuda:0",
71
+ aggregate_k_gradients=1,
72
+ verbose=True,
73
+ style_encoder_generator=None,
74
+ epoch_callback=None,
75
+ initializer=None,
76
+ initialize_with_model=None,
77
+ train_mixed_precision=False,
78
+ saving_period=10,
79
+ checkpoint_file=None,
80
+ load_optimizer_from_this_state_dict=None,
81
+ output_path=None,
82
+ **model_extra_args,
83
+ ):
84
+ device = gpu_device if torch.cuda.is_available() else "cpu:0"
85
+ print(f"Using {device} device")
86
+ using_dist, rank, device = init_dist(device)
87
+ single_eval_pos_gen = (
88
+ single_eval_pos_gen
89
+ if callable(single_eval_pos_gen)
90
+ else lambda: single_eval_pos_gen
91
+ )
92
+
93
+ def eval_pos_seq_len_sampler():
94
+ single_eval_pos = single_eval_pos_gen()
95
+ if bptt_extra_samples:
96
+ return single_eval_pos, single_eval_pos + bptt_extra_samples
97
+ else:
98
+ return single_eval_pos, bptt
99
+
100
+ dl = priordataloader_class(
101
+ num_steps=steps_per_epoch,
102
+ batch_size=batch_size,
103
+ eval_pos_seq_len_sampler=eval_pos_seq_len_sampler,
104
+ seq_len_maximum=bptt + (bptt_extra_samples if bptt_extra_samples else 0),
105
+ device=device,
106
+ **extra_prior_kwargs_dict,
107
+ )
108
+
109
+ encoder = encoder_generator(dl.num_features, emsize)
110
+ style_def = next(iter(dl))[0][
111
+ 0
112
+ ] # This is (style, x, y), target with x and y with batch size
113
+ print(f"Style definition: {style_def}")
114
+ style_encoder = (
115
+ style_encoder_generator(hyperparameter_definitions=style_def[0], em_size=emsize)
116
+ if (style_def is not None)
117
+ else None
118
+ )
119
+ if isinstance(criterion, nn.GaussianNLLLoss):
120
+ n_out = 2
121
+ elif (
122
+ isinstance(criterion, BarDistribution)
123
+ or "BarDistribution" in criterion.__class__.__name__
124
+ ): # TODO remove this fix (only for dev)
125
+ n_out = criterion.num_bars
126
+ elif isinstance(criterion, nn.CrossEntropyLoss):
127
+ n_out = criterion.weight.shape[0]
128
+ else:
129
+ n_out = 1
130
+ model = TransformerModel(
131
+ encoder,
132
+ n_out,
133
+ emsize,
134
+ nhead,
135
+ nhid,
136
+ nlayers,
137
+ dropout,
138
+ style_encoder=style_encoder,
139
+ y_encoder=y_encoder_generator(1, emsize),
140
+ input_normalization=input_normalization,
141
+ pos_encoder=(
142
+ pos_encoder_generator or positional_encodings.NoPositionalEncoding
143
+ )(emsize, bptt * 2),
144
+ decoder=decoder,
145
+ init_method=initializer,
146
+ **model_extra_args,
147
+ )
148
+ model.criterion = criterion
149
+ if load_weights_from_this_state_dict is not None:
150
+ model.load_state_dict(load_weights_from_this_state_dict)
151
+ if initialize_with_model is not None:
152
+ model.init_from_small_model(initialize_with_model)
153
+
154
+ print(
155
+ f"Using a Transformer with {sum(p.numel() for p in model.parameters())/1000/1000:.{2}f} M parameters"
156
+ )
157
+
158
+ try:
159
+ for (k, v), (k2, v2) in zip(
160
+ model.state_dict().items(), initialize_with_model.state_dict().items()
161
+ ):
162
+ print(k, ((v - v2) / v).abs().mean(), v.shape)
163
+ except Exception:
164
+ pass
165
+
166
+ model.to(device)
167
+ if using_dist:
168
+ print("Distributed training")
169
+ model = torch.nn.parallel.DistributedDataParallel(
170
+ model, device_ids=[rank], output_device=rank, broadcast_buffers=False
171
+ )
172
+
173
+ # learning rate
174
+ if lr is None:
175
+ lr = get_openai_lr(model)
176
+ print(f"Using OpenAI max lr of {lr}.")
177
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
178
+ scheduler = scheduler(
179
+ optimizer, warmup_epochs, epochs if epochs is not None else 100
180
+ ) # when training for fixed time lr schedule takes 100 steps
181
+
182
+ if load_optimizer_from_this_state_dict is not None:
183
+ optimizer.load_state_dict(load_optimizer_from_this_state_dict)
184
+ scaler = GradScaler() if train_mixed_precision else None
185
+
186
+ # check that everything uses up-to-date APIs
187
+ utils.check_compatibility(dl)
188
+
189
+ def train_epoch():
190
+ model.train() # Turn on the train mode
191
+ total_loss = 0.0
192
+ total_positional_losses = 0.0
193
+ total_positional_losses_recorded = 0
194
+ before_get_batch = time.time()
195
+ assert (
196
+ len(dl) % aggregate_k_gradients == 0
197
+ ), "Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it."
198
+ for batch, (data, targets, single_eval_pos) in enumerate(dl):
199
+ if using_dist and not (
200
+ batch % aggregate_k_gradients == aggregate_k_gradients - 1
201
+ ):
202
+ cm = model.no_sync()
203
+ else:
204
+ cm = nullcontext()
205
+ with cm:
206
+ time_to_get_batch = time.time() - before_get_batch
207
+ before_forward = time.time()
208
+
209
+ with autocast(enabled=scaler is not None):
210
+ # If style is set to None, it should not be transferred to device
211
+ output = model(
212
+ tuple(e.to(device) if torch.is_tensor(e) else e for e in data)
213
+ if isinstance(data, tuple)
214
+ else data.to(device),
215
+ single_eval_pos=single_eval_pos,
216
+ )
217
+
218
+ forward_time = time.time() - before_forward
219
+
220
+ if single_eval_pos is not None:
221
+ targets = targets[single_eval_pos:]
222
+ if isinstance(criterion, nn.GaussianNLLLoss):
223
+ assert (
224
+ output.shape[-1] == 2
225
+ ), "need to write a little bit of code to handle multiple regression targets at once"
226
+
227
+ mean_pred = output[..., 0]
228
+ var_pred = output[..., 1].abs()
229
+ losses = criterion(
230
+ mean_pred.flatten(),
231
+ targets.to(device).flatten(),
232
+ var=var_pred.flatten(),
233
+ )
234
+ elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)):
235
+ losses = criterion(
236
+ output.flatten(), targets.to(device).flatten()
237
+ )
238
+ elif isinstance(criterion, nn.CrossEntropyLoss):
239
+ losses = criterion(
240
+ output.reshape(-1, n_out),
241
+ targets.to(device).long().flatten(),
242
+ )
243
+ else:
244
+ losses = criterion(output, targets)
245
+ losses = losses.view(*output.shape[0:2])
246
+ loss = losses.mean() / aggregate_k_gradients
247
+
248
+ if scaler:
249
+ loss = scaler.scale(loss)
250
+ loss.backward()
251
+
252
+ if batch % aggregate_k_gradients == aggregate_k_gradients - 1:
253
+ if scaler:
254
+ scaler.unscale_(optimizer)
255
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
256
+ try:
257
+ if scaler:
258
+ scaler.step(optimizer)
259
+ scaler.update()
260
+ else:
261
+ optimizer.step()
262
+ except:
263
+ print("Invalid optimization step encountered")
264
+ optimizer.zero_grad()
265
+
266
+ step_time = time.time() - before_forward
267
+
268
+ if not torch.isnan(loss):
269
+ total_loss += losses.mean().cpu().detach()
270
+ total_positional_losses += (
271
+ losses.mean(1).cpu().detach()
272
+ if single_eval_pos is None
273
+ else nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
274
+ * losses[: bptt - single_eval_pos].mean().cpu().detach()
275
+ )
276
+
277
+ total_positional_losses_recorded += (
278
+ torch.ones(bptt)
279
+ if single_eval_pos is None
280
+ else nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
281
+ )
282
+
283
+ before_get_batch = time.time()
284
+ return (
285
+ total_loss / steps_per_epoch,
286
+ (total_positional_losses / total_positional_losses_recorded).tolist(),
287
+ time_to_get_batch,
288
+ forward_time,
289
+ step_time,
290
+ )
291
+
292
+ total_loss = float("inf")
293
+ total_positional_losses = float("inf")
294
+ list_losses = []
295
+ try:
296
+ for epoch in range(1, epochs + 1) if epochs is not None else itertools.count(1):
297
+
298
+ epoch_start_time = time.time()
299
+ (
300
+ total_loss,
301
+ total_positional_losses,
302
+ time_to_get_batch,
303
+ forward_time,
304
+ step_time,
305
+ ) = train_epoch()
306
+ list_losses.append(total_loss.item())
307
+ if hasattr(dl, "validate") and epoch % validation_period == 0:
308
+ with torch.no_grad():
309
+ val_score = dl.validate(model)
310
+
311
+ else:
312
+ val_score = None
313
+
314
+ if epoch % saving_period == 0 and checkpoint_file is not None:
315
+ checkpoint = {
316
+ "model_state_dict": model.state_dict(),
317
+ "optimizer_state_dict": optimizer.state_dict(),
318
+ "epoch": epoch,
319
+ }
320
+ torch.save(checkpoint, checkpoint_file)
321
+ full_model_path = checkpoint_file.split(".")[0] + "_full_model.pt"
322
+ torch.save(model, full_model_path)
323
+
324
+ if verbose:
325
+ print("-" * 89)
326
+ print(
327
+ f"| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | "
328
+ f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}"
329
+ f" data time {time_to_get_batch:5.2f} step time {step_time:5.2f}"
330
+ f" forward time {forward_time:5.2f}"
331
+ + (f"val score {val_score}" if val_score is not None else "")
332
+ )
333
+ print("-" * 89)
334
+
335
+ # stepping with wallclock time based scheduler
336
+ if epoch_callback is not None and rank == 0:
337
+ epoch_callback(model, epoch / epochs)
338
+ scheduler.step()
339
+ except KeyboardInterrupt:
340
+ pass
341
+
342
+ if rank == 0: # trivially true for non-parallel training
343
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
344
+ model = model.module
345
+ dl = None
346
+ if output_path is not None:
347
+ torch.save(model.to("cpu"), output_path)
348
+ print("Checkpoint stored at ", output_path)
349
+ return total_loss, total_positional_losses, model.to("cpu"), dl
350
+
351
+
352
+ def _parse_args(config_parser, parser):
353
+ # Do we have a config file to parse?
354
+ args_config, remaining = config_parser.parse_known_args()
355
+ if args_config.config:
356
+ with open(args_config.config, "r") as f:
357
+ cfg = yaml.safe_load(f)
358
+ parser.set_defaults(**cfg)
359
+
360
+ # The main arg parser parses the rest of the args, the usual
361
+ # defaults will have been overridden if config file specified.
362
+ args = parser.parse_args(remaining)
363
+
364
+ # Cache the args as a text string to save them in the output dir later
365
+ args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
366
+ return args, args_text
367
+
368
+
369
if __name__ == "__main__":
    # Command-line entry point: build the prior, the loss and the encoders from
    # CLI/YAML options and forward everything else to `train`.
    #
    # Parsing is two-staged: `config_parser` only extracts `--config`;
    # `_parse_args` then loads that YAML file (if given) to override the
    # defaults of `parser` before parsing the remaining arguments.
    config_parser = argparse.ArgumentParser(
        description="Only used as a first parser for the config file path."
    )
    config_parser.add_argument("--config")
    parser = argparse.ArgumentParser()
    parser.add_argument("prior")
    parser.add_argument("--loss_function", default="barnll")
    # Optional Arg's for `--loss_function barnll`
    parser.add_argument(
        "--min_y",
        type=float,
        help="barnll can only model y in strict ranges, this is the minimum y can take.",
    )
    parser.add_argument(
        "--max_y",
        type=float,
        help="barnll can only model y in strict ranges, this is the maximum y can take.",
    )
    parser.add_argument("--num_buckets", default=100, type=int)
    parser.add_argument(
        "--extra_prior_kwargs_dict",
        default={},
        dest="extra_prior_kwargs_dict",
        action=StoreDictKeyPair,
        nargs="+",
        metavar="KEY=VAL",
        help="Specify depending on the prior.",
    )
    parser.add_argument(
        "--encoder", default="linear", type=str, help="Specify depending on the prior."
    )
    parser.add_argument(
        "--y_encoder",
        default="linear",
        type=str,
        help="Specify depending on the prior. You should specify this if you do not fuse x and y.",
    )
    parser.add_argument(
        "--pos_encoder",
        default="none",
        type=str,
        help="Specify depending on the prior.",
    )
    parser.add_argument("--bptt", default=10, type=int)
    parser.add_argument("--epochs", default=200, type=int)
    parser.add_argument("--warmup_epochs", default=50, type=int)
    parser.add_argument("--validation_period", default=10, type=int)
    parser.add_argument(
        "--permutation_invariant_max_eval_pos",
        default=None,
        type=int,
        help="Set this to an int to sample `single_eval_pos` from positions below "
        "this maximum; see --permutation_invariant_sampling.",
    )
    parser.add_argument(
        "--permutation_invariant_sampling",
        default="weighted",
        help="Only relevant if --permutation_invariant_max_eval_pos is set.",
    )
    parser.add_argument("--train_mixed_precision", action="store_true")

    # these can likely be mostly left at defaults
    parser.add_argument(
        "--emsize", default=512, type=int
    )  # sometimes even larger is better e.g. 1024
    parser.add_argument("--nlayers", default=6, type=int)
    parser.add_argument("--nhid", default=None, type=int)  # 2*emsize is the default
    parser.add_argument(
        "--nhead", default=4, type=int
    )  # nhead = emsize / 64 in the original paper
    parser.add_argument("--dropout", default=0.0, type=float)
    parser.add_argument("--steps_per_epoch", default=10, type=int)
    parser.add_argument("--batch_size", default=1000, type=int)
    parser.add_argument(
        "--lr", "--learning_rate", default=0.001, type=float
    )  # try also .0003, .0001, go lower with lower batch size
    parser.add_argument("--gpu_device", default="cuda", type=str)

    # for model checkpointing
    parser.add_argument(
        "--checkpoint_file",
        help="absolute or relative-to-the-project-rootdir path to the file storing the state dicts.",
        default=None,
        type=str,
    )
    # The default is an int, so values given on the command line must be parsed
    # as int too (previously `type=str` turned `--saving_period 5` into "5").
    parser.add_argument("--saving_period", default=10, type=int)

    args, _ = _parse_args(config_parser, parser)

    if args.nhid is None:
        args.nhid = 2 * args.emsize

    # Options consumed here are popped from `args.__dict__` so that only the
    # keyword arguments meant for `train` remain in it.
    prior = args.__dict__.pop("prior")

    if prior == "gp":
        prior = priors.fast_gp.DataLoader
    elif prior == "ridge":
        prior = priors.ridge.DataLoader
    elif prior == "stroke":
        prior = priors.stroke.DataLoader
    elif prior == "mix_gp":
        prior = priors.fast_gp_mix.DataLoader
    else:
        raise NotImplementedError(f"Prior == {prior}.")

    loss_function = args.__dict__.pop("loss_function")

    num_buckets = args.__dict__.pop("num_buckets")
    max_y = args.__dict__.pop("max_y")
    min_y = args.__dict__.pop("min_y")

    device = args.gpu_device if torch.cuda.is_available() else "cpu:0"

    def get_y_sample():
        """Draw one large batch from the prior and return its targets.

        The sample is used to place the bucket borders of the adaptive bar
        distributions.
        """
        # Temporarily force full-length sequences for this single draw; the
        # sampler is removed again so it does not leak into `train`.
        args.__dict__["extra_prior_kwargs_dict"]["eval_pos_seq_len_sampler"] = lambda: (
            args.bptt,
            args.bptt,
        )
        dl = prior(
            num_steps=1,
            batch_size=args.batch_size * args.steps_per_epoch,
            seq_len=args.bptt,
            device=device,
            **args.extra_prior_kwargs_dict,
        )
        args.__dict__["extra_prior_kwargs_dict"].pop("eval_pos_seq_len_sampler")

        # NOTE(review): the second-to-last batch element is taken to be the
        # targets -- confirm against the prior's batch layout.
        y_sample = next(iter(dl))[-2]
        print(
            f"Creating Bar distribution with borders from y sample of size {y_sample.numel()}"
        )
        return y_sample

    if loss_function == "ce":
        criterion = nn.CrossEntropyLoss(reduction="none")
    elif loss_function == "gaussnll":
        criterion = nn.GaussianNLLLoss(reduction="none", full=True)
    elif loss_function == "mse":
        criterion = nn.MSELoss(reduction="none")
    elif loss_function == "barnll":
        criterion = BarDistribution(
            borders=get_bucket_limits(num_buckets, full_range=(min_y, max_y))
        )
    elif loss_function == "adaptivebarnll":
        borders = get_bucket_limits(
            num_buckets, ys=get_y_sample(), full_range=(min_y, max_y)
        )
        criterion = BarDistribution(borders=borders)
    elif loss_function == "adaptivefullsupportbarnll":
        assert (
            min_y is None and max_y is None
        ), "Please do not specify `min_y` and `max_y` with `adaptivefullsupportbarnll`."
        borders = get_bucket_limits(num_buckets, ys=get_y_sample())
        criterion = FullSupportBarDistribution(borders=borders)
    else:
        raise NotImplementedError(f"loss_function == {loss_function}.")

    encoder = args.__dict__.pop("encoder")
    y_encoder = args.__dict__.pop("y_encoder")

    def get_encoder_generator(encoder):
        """Map an encoder name from the CLI to the matching encoder class."""
        if encoder == "linear":
            encoder_generator = encoders.Linear
        elif encoder == "mlp":
            encoder_generator = encoders.MLP
        elif encoder == "positional":
            encoder_generator = encoders.Positional
        else:
            raise NotImplementedError(f"A {encoder} encoder is not valid.")
        return encoder_generator

    encoder_generator = get_encoder_generator(encoder)
    y_encoder_generator = get_encoder_generator(y_encoder)

    pos_encoder = args.__dict__.pop("pos_encoder")

    if pos_encoder == "none":
        pos_encoder_generator = None
    elif pos_encoder == "sinus":
        pos_encoder_generator = positional_encodings.PositionalEncoding
    elif pos_encoder == "learned":
        pos_encoder_generator = positional_encodings.LearnedPositionalEncoding
    elif pos_encoder == "paired_scrambled_learned":
        pos_encoder_generator = positional_encodings.PairedScrambledPositionalEncodings
    else:
        raise NotImplementedError(f"pos_encoder == {pos_encoder} is not valid.")

    permutation_invariant_max_eval_pos = args.__dict__.pop(
        "permutation_invariant_max_eval_pos"
    )
    permutation_invariant_sampling = args.__dict__.pop("permutation_invariant_sampling")
    if permutation_invariant_max_eval_pos is not None:
        if permutation_invariant_sampling == "weighted":
            get_sampler = get_weighted_single_eval_pos_sampler
        elif permutation_invariant_sampling == "uniform":
            get_sampler = get_uniform_single_eval_pos_sampler
        else:
            raise ValueError(
                "permutation_invariant_sampling must be 'weighted' or 'uniform', "
                f"got {permutation_invariant_sampling!r}."
            )
        args.__dict__["single_eval_pos_gen"] = get_sampler(
            permutation_invariant_max_eval_pos
        )

    print("ARGS for `train`:", args.__dict__)

    if args.__dict__["checkpoint_file"] is not None:
        # Relative paths are resolved against the directory of this file; an
        # absolute path is left untouched by `os.path.join`.
        rootdir = os.path.dirname(os.path.realpath(__file__))
        args.__dict__["checkpoint_file"] = os.path.join(
            rootdir, args.__dict__["checkpoint_file"]
        )

        if os.path.exists(args.__dict__["checkpoint_file"]):
            # Resume: hand the stored model/optimizer state to `train`.
            state_dicts = torch.load(args.__dict__["checkpoint_file"])
            args.__dict__["load_weights_from_this_state_dict"] = state_dicts[
                "model_state_dict"
            ]
            args.__dict__["load_optimizer_from_this_state_dict"] = state_dicts[
                "optimizer_state_dict"
            ]
        else:
            args.__dict__["load_weights_from_this_state_dict"] = None
            args.__dict__["load_optimizer_from_this_state_dict"] = None

    train(
        prior,
        criterion,
        encoder_generator,
        y_encoder_generator=y_encoder_generator,
        pos_encoder_generator=pos_encoder_generator,
        **args.__dict__,
    )
lcpfn/train_lcpfn.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ from torch import nn
4
+
5
+ from lcpfn import bar_distribution, encoders, priors, train
6
+ from lcpfn import utils
7
+
8
+
9
def train_lcpfn(
    get_batch_func,
    seq_len: int = 100,
    emsize: int = 512,
    nlayers: int = 12,
    num_borders: int = 1000,
    lr: float = 0.001,
    batch_size: int = 100,
    epochs: int = 1000,
):
    """
    Train a LCPFN model using the specified hyperparameters.

    Args:
        get_batch_func (callable): A function that returns a batch of learning curves.
        seq_len (int, optional): The length of the input sequence. Defaults to 100.
        emsize (int, optional): The size of the embedding layer. Defaults to 512.
        nlayers (int, optional): The number of layers in the model. Defaults to 12.
        num_borders (int, optional): The number of borders to use. Defaults to 1000.
        lr (float, optional): The learning rate for the optimizer. Defaults to 0.001.
        batch_size (int, optional): The batch size for training. Defaults to 100.
        epochs (int, optional): The number of epochs to train for. Defaults to 1000.

    Returns:
        Whatever ``train.train`` returns for this configuration.
        NOTE(review): previously documented as "the trained model"; confirm
        against ``train.train``.
    """

    hps = {}

    # PFN training hyperparameters
    dataloader = priors.utils.get_batch_to_dataloader(get_batch_func)  # type: ignore

    num_features = 1

    # Draw one large batch so the bucket borders can be fitted to the prior's
    # targets (element 2 of the returned batch is passed on as `ys`).
    ys = get_batch_func(
        10_000,
        seq_len,
        num_features,
        hyperparameters=hps,
        single_eval_pos=seq_len,
    )

    bucket_limits = bar_distribution.get_bucket_limits(num_borders, ys=ys[2])

    # Discretization of the predictive distributions
    criterion = bar_distribution.FullSupportBarDistribution(bucket_limits)

    config = dict(
        nlayers=nlayers,
        priordataloader_class=dataloader,
        criterion=criterion,
        encoder_generator=lambda in_dim, out_dim: nn.Sequential(
            encoders.Normalize(0.0, 101.0),
            encoders.Normalize(0.5, math.sqrt(1 / 12)),
            encoders.Linear(in_dim, out_dim),
        ),
        emsize=emsize,
        # One head per 128 embedding dims, but never zero heads: for
        # emsize < 128 the integer division alone would yield 0.
        nhead=max(1, emsize // 128),
        warmup_epochs=(epochs // 4),
        y_encoder_generator=encoders.get_normalized_uniform_encoder(encoders.Linear),
        batch_size=batch_size,
        scheduler=utils.get_cosine_schedule_with_warmup,
        extra_prior_kwargs_dict={
            "num_features": num_features,
            "hyperparameters": {
                **hps,
            },
        },
        epochs=epochs,
        lr=lr,
        bptt=seq_len,
        single_eval_pos_gen=utils.get_uniform_single_eval_pos_sampler(seq_len, min_len=1),
        aggregate_k_gradients=1,
        nhid=(emsize * 2),
        steps_per_epoch=100,
        train_mixed_precision=False,
    )

    return train.train(**config)
lcpfn/transformer.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+ from torch.nn import Module, TransformerEncoder
8
+
9
+ from lcpfn.layer import TransformerEncoderLayer, _get_activation_fn
10
+ from lcpfn.utils import SeqBN, bool_mask_to_att_mask
11
+
12
+
13
+
14
class TransformerModel(nn.Module):
    """Transformer encoder over a sequence split into a training and a query part.

    Training positions (before ``single_eval_pos``) are embedded as the sum of
    the x- and y-encodings; query positions are embedded from x alone. The
    decoder output is returned for the query positions only.
    """

    def __init__(self, encoder, n_out, ninp, nhead, nhid, nlayers, dropout=0.0, style_encoder=None, y_encoder=None,
                 pos_encoder=None, decoder=None, input_normalization=False, init_method=None, pre_norm=False,
                 activation='gelu', recompute_attn=False, num_global_att_tokens=0, full_attention=False,
                 all_layers_same_init=True):
        """Build the model.

        Args:
            encoder: module applied to the x inputs.
            n_out: output size of the decoder.
            ninp: embedding size used throughout the encoder stack.
            nhead: number of attention heads per layer.
            nhid: hidden size of the feed-forward parts (and of the default decoder).
            nlayers: number of encoder layers.
            dropout: dropout passed to every encoder layer.
            style_encoder: optional module applied to the style input.
            y_encoder: module applied to the y inputs.
            pos_encoder: optional module applied to the assembled sequence.
            decoder: optional factory called as ``decoder(ninp, nhid, n_out)``;
                a Linear-GELU-Linear stack is used when omitted.
            input_normalization: if true, apply ``SeqBN`` to the assembled sequence.
            init_method: optional callable passed to ``self.apply``.
            pre_norm, activation, recompute_attn: forwarded to the encoder layers.
            num_global_att_tokens: number of learned global tokens; ``0`` or
                ``None`` disables them.
            full_attention: attend between all positions; incompatible with
                global attention tokens.
            all_layers_same_init: use ``TransformerEncoder`` built from one
                template layer when true, otherwise ``TransformerEncoderDiffInit``
                with a freshly created layer per position.
        """
        super().__init__()
        self.model_type = 'Transformer'
        encoder_layer_creator = lambda: TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation=activation,
                                                                pre_norm=pre_norm, recompute_attn=recompute_attn)
        self.transformer_encoder = TransformerEncoder(encoder_layer_creator(), nlayers)\
            if all_layers_same_init else TransformerEncoderDiffInit(encoder_layer_creator, nlayers)
        self.ninp = ninp
        self.encoder = encoder
        self.y_encoder = y_encoder
        self.pos_encoder = pos_encoder
        self.decoder = decoder(ninp, nhid, n_out) if decoder is not None else nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out))
        self.input_ln = SeqBN(ninp) if input_normalization else None
        self.style_encoder = style_encoder
        self.init_method = init_method
        # Only reject full attention when global tokens are actually in use.
        # The former `is not None` check also fired for the default value 0,
        # which made `full_attention=True` unusable without passing None.
        if num_global_att_tokens:
            assert not full_attention
        self.global_att_embeddings = nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
        self.full_attention = full_attention

        self.n_out = n_out
        self.nhid = nhid

        self.init_weights()

    @staticmethod
    def generate_square_subsequent_mask(sz):
        """Return a (sz, sz) additive mask where position i may attend to j <= i."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_D_q_matrix(sz, query_size):
        """Return a (sz, sz) additive mask for a train/query split.

        Every position may attend to the first ``sz - query_size`` (training)
        positions and to itself; query columns are otherwise blocked.
        """
        train_size = sz-query_size
        mask = torch.zeros(sz,sz) == 0
        mask[:,train_size:].zero_()
        mask |= torch.eye(sz) == 1
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_query_matrix(num_global_att_tokens, seq_len, num_query_tokens):
        """Mask for the query rows: attend to the leading columns and to itself."""
        train_size = seq_len + num_global_att_tokens - num_query_tokens
        sz = seq_len + num_global_att_tokens
        mask = torch.zeros(num_query_tokens, sz) == 0
        mask[:,train_size:].zero_()
        mask[:,train_size:] |= torch.eye(num_query_tokens) == 1
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_trainset_matrix(num_global_att_tokens, seq_len, num_query_tokens):
        """Mask for the training rows: all-allowed over the global-token columns."""
        trainset_size = seq_len - num_query_tokens
        mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_globaltokens_matrix(num_global_att_tokens, seq_len, num_query_tokens):
        """Mask for the global-token rows: all-allowed over global and training columns."""
        mask = torch.zeros(num_global_att_tokens, num_global_att_tokens+seq_len-num_query_tokens) == 0
        return bool_mask_to_att_mask(mask)

    def init_weights(self):
        """Apply ``init_method`` (if any), then zero each layer's output projections.

        Zeroing ``linear2`` and the attention ``out_proj`` is done after the
        custom init so it always takes effect.
        """
        if self.init_method is not None:
            self.apply(self.init_method)
        for layer in self.transformer_encoder.layers:
            nn.init.zeros_(layer.linear2.weight)
            nn.init.zeros_(layer.linear2.bias)
            attns = layer.self_attn if isinstance(layer.self_attn, nn.ModuleList) else [layer.self_attn]
            for attn in attns:
                nn.init.zeros_(attn.out_proj.weight)
                nn.init.zeros_(attn.out_proj.bias)

    def forward(self, src, src_mask=None, single_eval_pos=None):
        """Run the model on an ``(x, y)`` or ``(style, x, y)`` tuple.

        Args:
            src: tuple ``(x, y)`` or ``(style, x, y)``; the sequence runs along
                dim 0 of x and y.
            src_mask: optional attention mask; when omitted it is built from
                ``single_eval_pos``. With global attention tokens a given mask
                must be a tuple.
            single_eval_pos: index separating training positions (before it)
                from query positions. It is used in arithmetic and slicing, so
                it is effectively required despite the ``None`` default.

        Returns:
            The decoder output for the positions from ``single_eval_pos`` on
            (after any global tokens).
        """
        assert isinstance(src, tuple), 'inputs (src) have to be given as (x,y) or (style,x,y) tuple'

        if len(src) == 2:  # (x,y) and no style
            src = (None,) + src

        style_src, style_src_size = (src[0], (0 if (src[0] is None) else 1))
        if src_mask is not None: assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
        if src_mask is None:
            x_src = src[1]
            if self.global_att_embeddings is None:
                full_len = len(x_src) + style_src_size
                if self.full_attention:
                    src_mask = bool_mask_to_att_mask(torch.ones((full_len, full_len), dtype=torch.bool)).to(x_src.device)
                else:
                    src_mask = self.generate_D_q_matrix(len(x_src) + style_src_size, len(x_src) + style_src_size -single_eval_pos).to(x_src.device)
            else:
                src_mask_args = (self.global_att_embeddings.num_embeddings,
                                 len(x_src) + style_src_size,
                                 len(x_src) + style_src_size - single_eval_pos)
                src_mask = (self.generate_global_att_globaltokens_matrix(*src_mask_args).to(x_src.device),
                            self.generate_global_att_trainset_matrix(*src_mask_args).to(x_src.device),
                            self.generate_global_att_query_matrix(*src_mask_args).to(x_src.device))

        style_src, x_src, y_src = src
        x_src = self.encoder(x_src)
        # Give y a trailing feature dim when it has fewer dims than the encoded x.
        y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
        style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else torch.tensor([], device=x_src.device)
        global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
            self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)
        # Training positions carry x and y information; query positions only x.
        train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
        src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)

        if self.input_ln is not None:
            src = self.input_ln(src)

        if self.pos_encoder is not None:
            src = self.pos_encoder(src)

        # If we have style input, drop its output
        output = self.transformer_encoder(src, src_mask)[style_src_size:]
        output = self.decoder(output)
        return output[single_eval_pos+(self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]

    @torch.no_grad()
    def init_from_small_model(self, small_model):
        """Copy the weights of a smaller model into the leading slices of this one.

        Requires a plain ``nn.Linear`` decoder and Linear/Sequential encoders
        (asserted below).
        """
        assert isinstance(self.decoder, nn.Linear) and isinstance(self.encoder, (nn.Linear, nn.Sequential)) \
               and isinstance(self.y_encoder, (nn.Linear, nn.Sequential))

        def set_encoder_weights(my_encoder, small_model_encoder):
            # For Sequential encoders only the last module is treated as the linear layer.
            my_encoder_linear, small_encoder_linear = (my_encoder, small_model_encoder) \
                if isinstance(my_encoder, nn.Linear) else (my_encoder[-1], small_model_encoder[-1])
            small_in_dim = small_encoder_linear.out_features
            my_encoder_linear.weight.zero_()
            my_encoder_linear.bias.zero_()
            my_encoder_linear.weight[:small_in_dim] = small_encoder_linear.weight
            my_encoder_linear.bias[:small_in_dim] = small_encoder_linear.bias

        set_encoder_weights(self.encoder, small_model.encoder)
        set_encoder_weights(self.y_encoder, small_model.y_encoder)

        small_in_dim = small_model.decoder.in_features

        self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
        self.decoder.bias = small_model.decoder.bias

        for my_layer, small_layer in zip(self.transformer_encoder.layers, small_model.transformer_encoder.layers):
            small_hid_dim = small_layer.linear1.out_features
            my_in_dim = my_layer.linear1.in_features

            # packed along q,k,v order in first dim
            my_in_proj_w = my_layer.self_attn.in_proj_weight
            small_in_proj_w = small_layer.self_attn.in_proj_weight

            my_in_proj_w.view(3, my_in_dim, my_in_dim)[:, :small_in_dim, :small_in_dim] = small_in_proj_w.view(3,
                                                                                                               small_in_dim,
                                                                                                               small_in_dim)
            my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:,
            :small_in_dim] = small_layer.self_attn.in_proj_bias.view(3, small_in_dim)

            my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = small_layer.self_attn.out_proj.weight
            my_layer.self_attn.out_proj.bias[:small_in_dim] = small_layer.self_attn.out_proj.bias

            my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = small_layer.linear1.weight
            my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias

            my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = small_layer.linear2.weight
            my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias

            # Norm weights are rescaled by sqrt(small_in_dim / my_in_dim) for the wider model.
            my_layer.norm1.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight
            my_layer.norm2.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight

            my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
            my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias
189
+
190
+
191
class TransformerEncoderDiffInit(Module):
    r"""Stack of N encoder layers, each created by its own factory call.

    Args:
        encoder_layer_creator: zero-argument callable returning a new encoder layer (required).
        num_layers: how many layers to create and chain (required).
        norm: optional callable applied to the output of the last layer.
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer_creator, num_layers, norm=None):
        super().__init__()
        # One factory call per layer, so every layer gets its own initialisation.
        created = []
        for _ in range(num_layers):
            created.append(encoder_layer_creator())
        self.layers = nn.ModuleList(created)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Feed ``src`` through every layer in order, then through ``norm`` if set.

        Args:
            src: input sequence for the first layer (required).
            mask: passed to each layer as ``src_mask`` (optional).
            src_key_padding_mask: passed unchanged to each layer (optional).
        """
        hidden = src
        for layer in self.layers:
            hidden = layer(hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is None:
            return hidden
        return self.norm(hidden)
lcpfn/utils.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import argparse
4
+ import random
5
+ import datetime
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ import numpy as np
11
+
12
+ # copied from huggingface
13
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
    """Build a LambdaLR whose factor warms up linearly, then follows a cosine.

    The factor rises linearly from 0 to 1 over `num_warmup_steps`; afterwards
    it follows 0.5 * (1 + cos(2 * pi * num_cycles * progress)), clipped at 0,
    where progress runs from 0 to 1 over the remaining training steps.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_span
            cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
            return max(0.0, 0.5 * (1.0 + cosine))
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda, last_epoch)
26
+
27
+ # copied from huggingface
28
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Build a schedule that warms up linearly and then decays linearly to 0.

    During the first `num_warmup_steps` steps the factor on the optimizer's
    initial lr grows linearly from 0 to 1; afterwards it shrinks linearly and
    reaches 0 at `num_training_steps` (it never goes negative).

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (:obj:`int`):
            Length of the warmup phase in steps.
        num_training_steps (:obj:`int`):
            Total number of training steps.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            Index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step: int):
        if current_step >= num_warmup_steps:
            remaining = float(num_training_steps - current_step)
            return max(0.0, remaining / decay_span)
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda, last_epoch)
55
+
56
+
57
def get_openai_lr(transformer_model):
    """Return a learning rate that decreases with the log of the parameter count."""
    num_params = 0
    for param in transformer_model.parameters():
        num_params += param.numel()
    log_size = math.log(num_params)
    return 0.003239 - 0.0001395 * log_size
60
+
61
+
62
def get_weighted_single_eval_pos_sampler(max_len):
    """
    Build a sampler for `single_eval_pos` over the positions 0 .. `max_len` - 1.

    Position i is drawn with weight 1 / (`max_len` - i), so later positions are
    more likely. At most `max_len` - 1 examples are shown to the Transformer.
    :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
    """
    positions = range(max_len)
    weights = [1 / (max_len - i) for i in positions]

    def sample():
        return random.choices(positions, weights)[0]

    return sample
69
+
70
+
71
def get_uniform_single_eval_pos_sampler(max_len, min_len=0):
    """
    Build a sampler that draws every position in [`min_len`, `max_len`) with equal weight.
    :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
    """
    positions = range(min_len, max_len)

    def sample():
        return random.choices(positions)[0]

    return sample
77
+
78
+
79
class SeqBN(nn.Module):
    """BatchNorm1d applied over the last dimension of an arbitrarily shaped input.

    All leading dimensions are flattened into the batch dimension for the
    normalisation and restored afterwards.
    """

    def __init__(self, d_model):
        super().__init__()
        self.bn = nn.BatchNorm1d(d_model)
        self.d_model = d_model

    def forward(self, x):
        original_shape = x.shape
        assert self.d_model == original_shape[-1]
        normalized = self.bn(x.view(-1, self.d_model))
        return normalized.view(*original_shape)
90
+
91
+
92
def set_locals_in_self(locals):
    """
    Copy every entry of `locals` except 'self' onto the object stored under 'self'.
    Intended to be called as `set_locals_in_self(locals())`, typically as the
    first statement of `__init__`.
    :param locals: `locals()`
    """
    target = locals['self']
    for name in locals:
        if name == 'self':
            continue
        setattr(target, name, locals[name])
101
+
102
+
103
# Device string used by default: the first CUDA device when CUDA is available, otherwise the CPU.
default_device = 'cuda:0' if torch.cuda.is_available() else 'cpu:0'
104
+
105
+
106
+ # Copied from StackOverflow, but we do an eval on the values additionally
107
class StoreDictKeyPair(argparse.Action):
    """argparse action that collects ``KEY=VAL`` tokens into a dict.

    Each value is evaluated as a Python expression; when that is not possible
    the raw string is stored instead.
    """

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        self._nargs = nargs
        super(StoreDictKeyPair, self).__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse ``values`` into a dict and store it on ``namespace`` under ``self.dest``.

        Raises:
            argparse.ArgumentError: if a token contains no ``=``.
        """
        my_dict = {}
        for kv in values:
            # Split on the first '=' only, so values may themselves contain '='.
            k, sep, v = kv.partition("=")
            if not sep:
                raise argparse.ArgumentError(self, f"expected KEY=VAL, got {kv!r}")
            try:
                # SECURITY NOTE: `eval` executes arbitrary code taken from the
                # command line. It is kept for compatibility; never feed this
                # action untrusted input.
                my_dict[k] = eval(v)
            except (NameError, SyntaxError):
                # Not a valid Python expression (e.g. a bare word or "1.2.3"):
                # keep the raw string.
                my_dict[k] = v
        setattr(namespace, self.dest, my_dict)
        print("dict values: {}".format(my_dict))
122
+
123
def get_nan_value(v, set_value_to_nan=0.0):
    """Return `v` with probability `set_value_to_nan`, otherwise a random pick from [-999, 0, 1, 999]."""
    keep_v = random.random() < set_value_to_nan
    return v if keep_v else random.choice([-999, 0, 1, 999])
128
+
129
def to_ranking(data):
    """Replace each entry by the number of entries along dim 0 that are <= it.

    Builds the full pairwise comparison tensor, so memory grows quadratically
    with the length of dim 0.
    """
    pairwise_ge = torch.ge(data, data.unsqueeze(-3))
    return pairwise_ge.sum(0)
133
+ # TODO: Is there a better way to do this?
134
+ # 1. Comparing to unique elements: When all values are different we still get quadratic blowup
135
+ # 2. Argsort(Argsort()) returns ranking, but with duplicate values there is an ordering which is problematic
136
+ # 3. Argsort(Argsort(Unique))->Scatter seems a bit complicated, doesn't have quadratic blowup, but how fast?
137
def to_ranking_low_mem(data):
    """Column-by-column variant of the ranking: counts entries along dim 0 that are <= each entry.

    Indexes `data` with three indices and handles one slice of the last
    dimension at a time to keep the pairwise comparison tensor small. The
    result has the dtype of `data`.
    """
    ranks = torch.zeros_like(data)
    num_cols = data.shape[-1]
    for col in range(num_cols):
        column = data[:, :, col]
        ranks[:, :, col] = (column >= column.unsqueeze(-2)).sum(0)
    return ranks
144
+
145
def nan_handling_missing_for_unknown_reason_value(set_value_to_nan=0.0):
    """Return NaN with probability `set_value_to_nan`, else whatever `get_nan_value` substitutes."""
    marker = float('nan')
    return get_nan_value(marker, set_value_to_nan)
147
+
148
def nan_handling_missing_for_no_reason_value(set_value_to_nan=0.0):
    """Return -inf with probability `set_value_to_nan`, else whatever `get_nan_value` substitutes."""
    marker = float('-inf')
    return get_nan_value(marker, set_value_to_nan)
150
+
151
def nan_handling_missing_for_a_reason_value(set_value_to_nan=0.0):
    """Return +inf with probability `set_value_to_nan`, else whatever `get_nan_value` substitutes."""
    marker = float('inf')
    return get_nan_value(marker, set_value_to_nan)
153
+
154
def torch_nanmean(x, axis=0):
    """Mean of `x` along `axis`, ignoring NaN entries."""
    nan_mask = torch.isnan(x)
    zeros = torch.zeros_like(x)
    valid_count = torch.where(nan_mask, zeros, torch.ones_like(x)).sum(axis=axis)
    valid_sum = torch.where(nan_mask, zeros, x).sum(axis=axis)
    return valid_sum / valid_count
158
+
159
def torch_nanstd(x, axis=0):
    """Standard deviation of `x` along `axis`, ignoring NaN entries.

    Divides the summed squared deviations by (number of non-NaN entries - 1).
    """
    nan_mask = torch.isnan(x)
    zeros = torch.zeros_like(x)
    valid_count = torch.where(nan_mask, zeros, torch.ones_like(x)).sum(axis=axis)
    valid_sum = torch.where(nan_mask, zeros, x).sum(axis=axis)
    mean = valid_sum / valid_count
    # Expand the mean back to the shape of x along the reduced axis.
    mean_broadcast = torch.repeat_interleave(mean.unsqueeze(axis), x.shape[axis], dim=axis)
    squared_dev = torch.square(mean_broadcast - x)
    variance = torch.nansum(squared_dev, axis=axis) / (valid_count - 1)
    return torch.sqrt(variance)
165
+
166
def normalize_data(data, normalize_positions=-1):
    """Standardise `data` along dim 0 (NaN-aware) and clip the result to [-100, 100].

    When `normalize_positions` is positive, the statistics come from the first
    `normalize_positions` entries only; otherwise from all of `data`.
    """
    reference = data[:normalize_positions] if normalize_positions > 0 else data
    mean = torch_nanmean(reference, axis=0)
    # Small constant keeps the division finite for constant columns.
    std = torch_nanstd(reference, axis=0) + .000001
    standardized = (data - mean) / std
    return torch.clip(standardized, min=-100, max=100)
177
+
178
def remove_outliers(X, n_sigma=4):
    """Softly clamp values that lie more than `n_sigma` standard deviations from the mean.

    Bounds are estimated along dim 0 in two passes: first on the raw data, then
    again after masking everything outside the first bounds as NaN. Values
    beyond the final bounds are pulled towards them logarithmically.
    """
    # Expects T, B, H
    assert len(X.shape) == 3, "X must be T,B,H"

    def _bounds(values):
        # NaN-aware mean +/- n_sigma * std along dim 0.
        center, spread = torch_nanmean(values, axis=0), torch_nanstd(values, axis=0)
        margin = spread * n_sigma
        return center - margin, center + margin

    lower, upper = _bounds(X)

    # Second pass: recompute the bounds without the first-pass outliers.
    data_clean = X[:].clone()
    data_clean[torch.logical_or(X > upper, X < lower)] = np.nan
    lower, upper = _bounds(data_clean)

    X = torch.maximum(-torch.log(1+torch.abs(X)) + lower, X)
    X = torch.minimum(torch.log(1+torch.abs(X)) + upper, X)
    return X
198
+
199
def bool_mask_to_att_mask(mask):
    """Turn a boolean mask into an additive float mask: 0.0 where true, -inf where false."""
    as_float = mask.float()
    blocked = as_float.masked_fill(mask == 0, float('-inf'))
    return blocked.masked_fill(mask == 1, float(0.0))
201
+
202
def print_on_master_only(is_master):
    """Replace the builtin `print` by a version that is silent unless allowed.

    The replacement prints when `is_master` is true or when the call passes
    `force=True`; the `force` keyword is removed before delegating to the
    previously installed `print`.
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        forced = kwargs.pop("force", False)
        if not (is_master or forced):
            return
        builtin_print(*args, **kwargs)

    __builtin__.print = print
213
+
214
+
215
def init_dist(device):
    """Initialise torch.distributed when a distributed launch is detected.

    Two launch modes are recognised through environment variables:
    `LOCAL_RANK` (torch.distributed.launch) and `SLURM_PROCID` together with
    more than one visible CUDA device (submitit). In both cases an NCCL process
    group is created, all ranks synchronise on a barrier, and `print` is
    silenced on every rank except rank 0.

    :param device: device string returned unchanged when not running distributed.
    :return: tuple `(using_dist, rank, device)`; the device is `cuda:<rank>` in
        the distributed cases.
    """
    print('init dist')
    if 'LOCAL_RANK' in os.environ:
        # launched with torch.distributed.launch
        rank = int(os.environ["LOCAL_RANK"])
        print('torch.distributed.launch and my rank is', rank)
        torch.cuda.set_device(rank)
        os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
        # The world size is taken from the number of CUDA devices visible to this process.
        torch.distributed.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=20),
                                             world_size=torch.cuda.device_count(), rank=rank)
        torch.distributed.barrier()
        print_on_master_only(rank == 0)
        print(f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
              "only I can print, but when using print(..., force=True) it will print on all ranks.")
        return True, rank, f'cuda:{rank}'
    elif 'SLURM_PROCID' in os.environ and torch.cuda.device_count() > 1:
        # this is for multi gpu when starting with submitit
        assert device != 'cpu:0'
        rank = int(os.environ['SLURM_PROCID'])
        # Rendezvous address and port are fixed to localhost:12355 for the env:// init.
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        torch.cuda.set_device(rank)
        os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
        print('distributed submitit launch and my rank is', rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=20),
                                             world_size=torch.cuda.device_count(), rank=rank)
        torch.distributed.barrier()
        print_on_master_only(rank == 0)
        print(f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
              "only I can print, but when using print(..., force=True) it will print on all ranks.")

        return True, rank, f'cuda:{rank}'
    else:
        print('Not using distributed')
        # will not change any of the behavior of print, but allows putting the force=True in the print calls
        print_on_master_only(True)
        return False, 0, device
252
+
253
+
254
def check_compatibility(dl):
    """Verify that a DataLoader still carrying the deprecated `num_outputs` sets it to 1.

    Loaders without a `num_outputs` attribute pass silently. Loaders that have
    it trigger a deprecation notice and must have `num_outputs == 1`.

    :param dl: the DataLoader (or any object) to inspect.
    :raises AssertionError: if `dl.num_outputs` exists and is not 1.
    """
    if hasattr(dl, 'num_outputs'):
        print('`num_outputs` for the DataLoader is deprecated. It is assumed to be 1 from now on.')
        # The check used to be inverted (`!= 1`), rejecting exactly the one
        # value the message declares as the supported assumption.
        assert dl.num_outputs == 1, "We assume num_outputs to be 1. Instead of the num_outputs change your loss. " \
                                    "We specify the number of classes in the CE loss."
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==1.11.0
2
+ numpy>=1.21.2
3
+ # lcpfn @ git+https://github.com/automl/lcpfn.git