herilalaina commited on
Commit
328c052
1 Parent(s): b62776c
lcpfn/.ipynb_checkpoints/__init__-checkpoint.py DELETED
@@ -1,53 +0,0 @@
1
- import os, sys
2
- sys.path.insert(0, os.path.dirname(__file__))
3
-
4
-
5
- model_path = 'trained_models'
6
-
7
- def prepare_models():
8
- pfns4bo_dir = os.path.dirname(__file__)
9
- model_names = ['pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt',
10
- 'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt']
11
-
12
- for name in model_names:
13
- weights_path = os.path.join(pfns4bo_dir, model_path, name)
14
- compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + '.gz')
15
- if not os.path.exists(weights_path):
16
- if not os.path.exists(compressed_weights_path):
17
- print("Downloading", os.path.abspath(compressed_weights_path))
18
- import requests
19
- url = f'https://github.com/automl/lcpfn/raw/main/lcpfn/trained_models/{name + ".gz"}'
20
- r = requests.get(url, allow_redirects=True)
21
- os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
22
- with open(compressed_weights_path, 'wb') as f:
23
- f.write(r.content)
24
- if os.path.exists(compressed_weights_path):
25
- print("Unzipping", name)
26
- os.system(f"gzip -dk {compressed_weights_path}")
27
- else:
28
- print("Failed to find", compressed_weights_path)
29
- print("Make sure you have an internet connection to download the model automatically..")
30
- if os.path.exists(weights_path):
31
- print("Successfully located model at", weights_path)
32
-
33
-
34
- model_dict = {
35
- 'EMSIZE512_NLAYERS12_NBUCKETS1000': os.path.join(os.path.dirname(__file__),model_path,
36
- 'pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt'),
37
- 'EMSIZE512_NLAYERS6_NBUCKETS1000': os.path.join(os.path.dirname(__file__),model_path,
38
- 'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt'),
39
- }
40
-
41
-
42
- def __getattr__(name):
43
- if name in model_dict:
44
- if not os.path.exists(model_dict[name]):
45
- print("Can't find", os.path.abspath(model_dict[name]), "thus unzipping/downloading models now.")
46
- print("This might take a while..")
47
- prepare_models()
48
- return model_dict[name]
49
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
50
-
51
- from lcpfn.model import LCPFN
52
- from lcpfn.train_lcpfn import train_lcpfn
53
- from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/.ipynb_checkpoints/curves-checkpoint.py DELETED
@@ -1,277 +0,0 @@
1
- import numpy as np
2
- from collections import OrderedDict
3
-
4
- prior = {
5
- "pow3": {
6
- "uniform": OrderedDict(
7
- a={"type": "uniform", "param1": -1, "param2": 1},
8
- c={"type": "uniform", "param1": 0, "param2": 1},
9
- alpha={"type": "uniform", "param1": 0, "param2": 1},
10
- ),
11
- "peaked": OrderedDict(
12
- a={"type": "uniform", "param1": -0.6, "param2": 0.6},
13
- c={"type": "uniform", "param1": 0, "param2": 1.25},
14
- alpha={"type": "log_normal", "param1": 0, "param2": 2},
15
- ),
16
- },
17
- "ilog2": {
18
- "uniform": OrderedDict(
19
- c={"type": "uniform", "param1": 0, "param2": 1},
20
- a={"type": "uniform", "param1": -1, "param2": 1},
21
- ),
22
- "peaked": OrderedDict(
23
- c={"type": "uniform", "param1": 0, "param2": 1},
24
- a={"type": "uniform", "param1": -0.5, "param2": 0.5},
25
- ),
26
- },
27
- "janoschek": {
28
- "uniform": OrderedDict(
29
- a={"type": "uniform", "param1": 0, "param2": 1},
30
- beta={"type": "uniform", "param1": 0, "param2": 2},
31
- k={"type": "uniform", "param1": 0, "param2": 1},
32
- delta={"type": "uniform", "param1": -5, "param2": 5},
33
- ),
34
- "peaked": OrderedDict(
35
- a={"type": "uniform", "param1": 0, "param2": 1},
36
- beta={"type": "uniform", "param1": 0, "param2": 2},
37
- k={"type": "log_normal", "param1": -2, "param2": 1},
38
- delta={"type": "log_normal", "param1": 0, "param2": 0.5},
39
- ),
40
- },
41
- }
42
-
43
-
44
- def prior_sampler(rng, type, param1, param2):
45
- if type == "uniform":
46
- return rng.uniform(param1, param2)
47
- elif type == "log_normal":
48
- return rng.lognormal(param1, param2)
49
- raise Exception("Unknown prior type: {}".format(type))
50
-
51
-
52
- def pow3(x, c, a, alpha):
53
- return c - a * (x) ** (-alpha)
54
-
55
-
56
- def prior_pow3(rng):
57
- return {
58
- p: prior_sampler(
59
- rng,
60
- prior["pow3"]["peaked"][p]["type"],
61
- param1=prior["pow3"]["peaked"][p]["param1"],
62
- param2=prior["pow3"]["peaked"][p]["param2"],
63
- )
64
- for p in ["a", "c", "alpha"]
65
- }
66
-
67
-
68
- def uniform_prior_pow3(rng):
69
- return {
70
- p: prior_sampler(
71
- rng,
72
- prior["pow3"]["uniform"][p]["type"],
73
- param1=prior["pow3"]["uniform"][p]["param1"],
74
- param2=prior["pow3"]["uniform"][p]["param2"],
75
- )
76
- for p in ["a", "c", "alpha"]
77
- }
78
-
79
-
80
- def ilog2(x, c, a):
81
- return c - a / (np.log(x + 1))
82
-
83
-
84
- def prior_ilog2(rng):
85
- return {
86
- p: prior_sampler(
87
- rng,
88
- prior["ilog2"]["peaked"][p]["type"],
89
- param1=prior["ilog2"]["peaked"][p]["param1"],
90
- param2=prior["ilog2"]["peaked"][p]["param2"],
91
- )
92
- for p in ["a", "c"]
93
- }
94
-
95
-
96
- def uniform_prior_ilog2(rng):
97
- return {
98
- p: prior_sampler(
99
- rng,
100
- prior["ilog2"]["uniform"][p]["type"],
101
- param1=prior["ilog2"]["uniform"][p]["param1"],
102
- param2=prior["ilog2"]["uniform"][p]["param2"],
103
- )
104
- for p in ["a", "c"]
105
- }
106
-
107
-
108
- def janoschek(x, a, beta, k, delta):
109
- """
110
- http://www.pisces-conservation.com/growthhelp/janoschek.htm
111
- """
112
- return a - (a - beta) * np.exp(-k * x**delta)
113
-
114
-
115
- def prior_janoschek(rng):
116
- return {
117
- p: prior_sampler(
118
- rng,
119
- prior["janoschek"]["peaked"][p]["type"],
120
- param1=prior["janoschek"]["peaked"][p]["param1"],
121
- param2=prior["janoschek"]["peaked"][p]["param2"],
122
- )
123
- for p in ["a", "beta", "k", "delta"]
124
- }
125
-
126
-
127
- def uniform_prior_janoschek(rng):
128
- return {
129
- p: prior_sampler(
130
- rng,
131
- prior["janoschek"]["uniform"][p]["type"],
132
- param1=prior["janoschek"]["uniform"][p]["param1"],
133
- param2=prior["janoschek"]["uniform"][p]["param2"],
134
- )
135
- for p in ["a", "beta", "k", "delta"]
136
- }
137
-
138
-
139
- def log_power(x, a, b, c):
140
- # a: upper bound
141
- # c: growth rate
142
- # initial = a/ (1 + (1/e^b)^c
143
- return a / (1.0 + (x / np.exp(b)) ** c)
144
-
145
-
146
- def prior_log_power(rng):
147
- # a ~ N(0.8,0.1)
148
- # b ~ N(1,1)
149
- # c ~ U(-3,0)
150
- a = rng.normal(0.8, 0.1)
151
- b = rng.normal(1.0, 1.0)
152
- c = rng.uniform(-3.0, 0.0)
153
- return {"a": a, "b": b, "c": c}
154
-
155
-
156
- def weibull(x, alpha, beta, kappa, delta):
157
- """
158
- Weibull modell
159
- http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
160
- alpha: upper asymptote
161
- beta: lower asymptote
162
- k: growth rate
163
- delta: controls the x-ordinate for the point of inflection
164
- """
165
- return alpha - (alpha - beta) * np.exp(-((kappa * x) ** delta))
166
-
167
-
168
- def prior_weibull(rng):
169
- alpha = rng.uniform(0.0, 1.5)
170
- beta = rng.uniform(0.0, 1)
171
- kappa = np.exp(rng.normal(-2.0, 1.0))
172
- delta = np.exp(rng.normal(0, 0.5))
173
- return {"alpha": alpha, "beta": beta, "kappa": kappa, "delta": delta}
174
-
175
-
176
- def mmf(x, alpha, beta, kappa, delta):
177
- """
178
- Morgan-Mercer-Flodin
179
- description:
180
- Nonlinear Regression page 342
181
- http://bit.ly/1jodG17
182
- http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
183
- alpha: upper asymptote
184
- kappa: growth rate
185
- beta: initial value
186
- delta: controls the point of inflection
187
- """
188
- return alpha - (alpha - beta) / (1.0 + (kappa * x) ** delta)
189
-
190
-
191
- def prior_mmf(rng):
192
- # alpha ~ N(0.8,0.1)
193
- # beta ~ N(0.2,0.1)
194
- # ln(kappa) ~ N(0,2)
195
- # ln(delta) ~ N(0,1)
196
- alpha = rng.normal(0.8, 0.1)
197
- beta = rng.normal(0.2, 0.1)
198
- kappa = np.exp(rng.normal(0, 2))
199
- delta = np.exp(rng.normal(0, 1))
200
- return {"alpha": alpha, "beta": beta, "kappa": kappa, "delta": delta}
201
-
202
-
203
- def vap(x, a, b, c):
204
- """Vapor pressure model"""
205
- # no upper bound if c > 0
206
- # a = ln(upper bound) for c=0
207
- # a+b = ln(initial)
208
- return np.exp(a + b / x + c * np.log(x))
209
-
210
-
211
- def prior_vap(rng):
212
- a = rng.uniform(-2.0, 0.0) # @heri: range check
213
- b = rng.uniform(-4.0, 0.0) # @heri: range check
214
- c = np.exp(rng.uniform(-8.0, 0.0)) # @heri: same as weights
215
- return {"a": a, "b": b, "c": c}
216
-
217
-
218
- def loglog_linear(x, a, b):
219
- x = np.log(x)
220
- return np.log(a * x + b)
221
-
222
-
223
- def prior_loglog_linear(rng):
224
- # ln(a) ~ N(-2, 1)
225
- # ln(b) ~ U(0, 1)
226
- a = np.exp(rng.normal(-2.0, 1.0))
227
- b = np.exp(rng.uniform(0.0, 1.0))
228
- return {"a": a, "b": b}
229
-
230
-
231
- def exp4(x, c, a, b, alpha):
232
- return c - np.exp(-a * (x**alpha) + b)
233
-
234
-
235
- def prior_exp4(rng):
236
- # c ~ N(0.8,0.1)
237
- c = rng.normal(0.8, 0.1)
238
- # ln(a) ~ N(-2,1)
239
- a = np.exp(rng.normal(-2, 1))
240
- # ln(alpha) ~ N(0,1)
241
- alpha = np.exp(rng.normal(0, 1))
242
- # ln(b) ~ N(0,0.5)
243
- b = np.exp(rng.normal(0, 0.5))
244
- return {"a": a, "b": b, "c": c, "alpha": alpha}
245
-
246
-
247
- def pow4(x, c, a, b, alpha):
248
- return c - (a * x + b) ** -alpha
249
-
250
-
251
- def prior_pow4(rng):
252
- # ln(1 - c) ~ U(-5, 0)
253
- c = 1 - np.exp(rng.uniform(-5.0, 0))
254
- # ln(a) ~ N(-3, 2)
255
- a = np.exp(rng.normal(-3.0, 2))
256
- # ln(alpha) ~ N(0,1)
257
- alpha = np.exp(rng.normal(0, 1))
258
- # ln(b) ~ U(0, 1)
259
- b = np.exp(rng.uniform(0, 1))
260
- return {"a": a, "b": b, "c": c, "alpha": alpha}
261
-
262
-
263
- def dr_hill_zero_background(x, theta, eta, kappa):
264
- # theta: upper bound
265
- # eta: growth rate
266
- # initial = theta/(kappa^eta + 1)
267
- return (theta * x**eta) / (kappa**eta + x**eta)
268
-
269
-
270
- def prior_dr_hill_zero_background(rng):
271
- # theta ~ U(1,0) N(0.8,0.1)
272
- # ln(eta) ~ N(1,1)
273
- # ln(kappa) ~ N(1,2)
274
- theta = rng.normal(0.8, 0.1)
275
- eta = np.exp(rng.normal(1.0, 1.0))
276
- kappa = np.exp(rng.normal(1.0, 2.0))
277
- return {"theta": theta, "eta": eta, "kappa": kappa}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/.ipynb_checkpoints/domhan_prior-checkpoint.py DELETED
@@ -1,195 +0,0 @@
1
- from functools import partial
2
- import torch
3
- import numpy as np
4
- from lcpfn.curves import (
5
- pow3,
6
- ilog2,
7
- janoschek,
8
- log_power,
9
- prior_ilog2,
10
- uniform_prior_pow3,
11
- weibull,
12
- mmf,
13
- vap,
14
- loglog_linear,
15
- exp4,
16
- pow4,
17
- dr_hill_zero_background,
18
- )
19
- from lcpfn.curves import (
20
- prior_pow3,
21
- prior_janoschek,
22
- prior_log_power,
23
- prior_weibull,
24
- prior_mmf,
25
- prior_vap,
26
- prior_loglog_linear,
27
- prior_exp4,
28
- prior_pow4,
29
- prior_dr_hill_zero_background,
30
- )
31
- from lcpfn.curves import (
32
- uniform_prior_pow3,
33
- uniform_prior_ilog2,
34
- uniform_prior_janoschek,
35
- )
36
-
37
-
38
- def prior_weights(
39
- rng,
40
- components=[
41
- "pow3",
42
- "ilog2",
43
- "janoschek",
44
- "log_power",
45
- "weibull",
46
- "mmf",
47
- "vap",
48
- "loglog_linear",
49
- "exp4",
50
- "pow4",
51
- "dr_hill_zero_background",
52
- ],
53
- ):
54
- K = len(components)
55
- weights = rng.uniform(0.0, 1, size=(K,))
56
- return {f: weights[i] for i, f in enumerate(components)}
57
-
58
-
59
- def sample_from_prior(rng, seq_len=100):
60
- return sample_prior_comb(
61
- rng=rng, seq_len=seq_len, components=["pow3", "ilog2", "janoschek"], distribution="peaked"
62
- )
63
-
64
-
65
- def sample_prior_comb(
66
- rng,
67
- components,
68
- distribution,
69
- var_lnloc=-4,
70
- var_lnscale=1,
71
- range_constraint=True,
72
- seq_len=100,
73
- ):
74
- f_components = {
75
- "pow3": pow3,
76
- "ilog2": ilog2,
77
- "janoschek": janoschek,
78
- "log_power": log_power,
79
- "weibull": weibull,
80
- "mmf": mmf,
81
- "vap": vap,
82
- "loglog_linear": loglog_linear,
83
- "exp4": exp4,
84
- "pow4": pow4,
85
- "dr_hill_zero_background": dr_hill_zero_background,
86
- }
87
-
88
- if distribution == "peaked":
89
- f_priors = {
90
- "pow3": prior_pow3,
91
- "ilog2": prior_ilog2,
92
- "janoschek": prior_janoschek,
93
- "log_power": prior_log_power,
94
- "weibull": prior_weibull,
95
- "mmf": prior_mmf,
96
- "vap": prior_vap,
97
- "loglog_linear": prior_loglog_linear,
98
- "exp4": prior_exp4,
99
- "pow4": prior_pow4,
100
- "dr_hill_zero_background": prior_dr_hill_zero_background,
101
- }
102
- elif distribution == "uniform":
103
- f_priors = {
104
- "pow3": uniform_prior_pow3,
105
- "ilog2": uniform_prior_ilog2,
106
- "janoschek": uniform_prior_janoschek
107
- }
108
- else:
109
- raise NotImplemented()
110
-
111
- x = np.arange(1, seq_len + 1)
112
-
113
- while True:
114
- # sample the noiseless curve
115
- weights = prior_weights(rng, components=components)
116
- y = np.zeros(x.shape, dtype="float")
117
- kwargs = 0
118
- for f, w in weights.items():
119
- kwargs = f_priors[f](rng)
120
- # print(f_components[f](x, **kwargs))
121
- y += w * f_components[f](x, **kwargs)
122
- # add noise (can exceed [0,1], but afaik no way to implement this prior in Tobis work)
123
- var = np.exp(
124
- rng.normal(var_lnloc, var_lnscale)
125
- ) # @heri: ln_prob =+ log(normal.pdf(log(var), loc=var_lnloc, scale=var_lnscale))
126
-
127
- # reject any curves that are non-increasing, exceed the [0,1] range
128
- if (
129
- y[-1] <= y[0]
130
- or (range_constraint and (np.any(y < 0) or np.any(y > 1)))
131
- or np.isnan(y).any()
132
- ):
133
- continue
134
- else:
135
- break
136
-
137
- def curve(): # generates a sample from the same model, but with independent noise
138
- y_noisy = y + rng.normal(np.zeros_like(y), var)
139
- return y, y_noisy
140
-
141
- return curve
142
-
143
-
144
- def generate_prior_dataset(n, prior=sample_prior_comb, seed=42):
145
- """
146
- Returns a fixed sample from the prior (with fixed seq_len) as an n x seq_len np.ndarray
147
- """
148
- rng = np.random.RandomState(seed)
149
- prior_data = np.stack([prior(rng)()[1] for _ in range(n)])
150
- return prior_data
151
-
152
-
153
- def create_get_batch_func(prior):
154
- return partial(get_batch_domhan, prior=prior)
155
-
156
- # function producing batches for PFN training
157
- def get_batch_domhan(
158
- batch_size,
159
- seq_len,
160
- num_features,
161
- prior,
162
- device="cpu",
163
- noisy_target=True,
164
- **_,
165
- ):
166
- assert num_features == 1
167
-
168
- x = np.arange(1, seq_len + 1)
169
- y_target = np.empty((batch_size, seq_len), dtype=float)
170
- y_noisy = np.empty((batch_size, seq_len), dtype=float)
171
-
172
- for i in range(batch_size):
173
- curve_func = prior(np.random, seq_len=seq_len) # uses numpy rng
174
- if noisy_target:
175
- _, y_noisy[i] = curve_func()
176
- y_target[i] = y_noisy[i]
177
- else:
178
- y_target[i], y_noisy[i] = curve_func()
179
-
180
- # turn numpy arrays into correctly shaped torch tensors & move them to device
181
- x = (
182
- torch.arange(1, seq_len + 1)
183
- .repeat((num_features, batch_size, 1))
184
- .transpose(2, 0)
185
- .to(device)
186
- )
187
- y_target = torch.from_numpy(y_target).transpose(1, 0).to(device)
188
- y_noisy = torch.from_numpy(y_noisy).transpose(1, 0).to(device)
189
-
190
- # changes
191
- x = x.float()
192
- y_target = y_target.float()
193
- y_noisy = y_noisy.float()
194
-
195
- return x, y_noisy, y_target
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/__init__.py DELETED
@@ -1,80 +0,0 @@
1
- import os, sys
2
-
3
- sys.path.insert(0, os.path.dirname(__file__))
4
-
5
-
6
- model_path = "trained_models"
7
-
8
-
9
- def prepare_models():
10
- pfns4bo_dir = os.path.dirname(__file__)
11
- model_names = [
12
- "pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt",
13
- "pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt",
14
- ]
15
-
16
- for name in model_names:
17
- weights_path = os.path.join(pfns4bo_dir, model_path, name)
18
- compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + ".gz")
19
- if not os.path.exists(weights_path):
20
- if not os.path.exists(compressed_weights_path):
21
- print("Downloading", os.path.abspath(compressed_weights_path))
22
- import requests
23
-
24
- url = f'https://ml.informatik.uni-freiburg.de/research-artifacts/lcpfn/{name + ".gz"}'
25
- r = requests.get(url, allow_redirects=True)
26
- os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
27
- with open(compressed_weights_path, "wb") as f:
28
- f.write(r.content)
29
- if os.path.exists(compressed_weights_path):
30
- print("Unzipping", name)
31
- os.system(f"gzip -dk {compressed_weights_path}")
32
- else:
33
- print("Failed to find", compressed_weights_path)
34
- print(
35
- "Make sure you have an internet connection to download the model automatically.."
36
- )
37
- if os.path.exists(weights_path):
38
- print("Successfully located model at", weights_path)
39
-
40
-
41
- model_dict = {
42
- "EMSIZE512_NLAYERS12_NBUCKETS1000": os.path.join(
43
- os.path.dirname(__file__),
44
- model_path,
45
- "pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt",
46
- ),
47
- "EMSIZE512_NLAYERS6_NBUCKETS1000": os.path.join(
48
- os.path.dirname(__file__),
49
- model_path,
50
- "pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt",
51
- ),
52
- }
53
-
54
-
55
- def __getattr__(name):
56
- if name in model_dict:
57
- if not os.path.exists(model_dict[name]):
58
- print(
59
- "Can't find",
60
- os.path.abspath(model_dict[name]),
61
- "thus unzipping/downloading models now.",
62
- )
63
- print("This might take a while..")
64
- prepare_models()
65
- return model_dict[name]
66
- raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
67
-
68
-
69
- from .version import __version__
70
- from lcpfn.model import LCPFN
71
- from lcpfn.train_lcpfn import train_lcpfn
72
- from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func
73
-
74
- __all__ = [
75
- "LCPFN",
76
- "train_lcpfn",
77
- "sample_from_prior",
78
- "create_get_batch_func",
79
- "__version__",
80
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/bar_distribution.py DELETED
@@ -1,349 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
-
5
- class BarDistribution(nn.Module):
6
- def __init__(
7
- self, borders: torch.Tensor, smoothing=0.0
8
- ): # here borders should start with min and end with max, where all values lie in (min,max) and are sorted
9
- # sorted list of borders
10
- super().__init__()
11
- assert len(borders.shape) == 1
12
- # self.borders = borders
13
- self.register_buffer("borders", borders)
14
- self.register_buffer("smoothing", torch.tensor(smoothing))
15
- # self.bucket_widths = self.borders[1:] - self.borders[:-1]
16
- self.register_buffer("bucket_widths", self.borders[1:] - self.borders[:-1])
17
- full_width = self.bucket_widths.sum()
18
- border_order = torch.argsort(borders)
19
- assert (
20
- full_width - (self.borders[-1] - self.borders[0])
21
- ).abs() < 1e-4, f"diff: {full_width - (self.borders[-1] - self.borders[0])}"
22
- assert (
23
- border_order == torch.arange(len(borders)).to(border_order.device)
24
- ).all(), "Please provide sorted borders!"
25
- self.num_bars = len(borders) - 1
26
-
27
- def map_to_bucket_idx(self, y):
28
- target_sample = torch.searchsorted(self.borders, y) - 1
29
- target_sample[y == self.borders[0]] = 0
30
- target_sample[y == self.borders[-1]] = self.num_bars - 1
31
- return target_sample
32
-
33
- def forward(
34
- self, logits, y
35
- ): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
36
- target_sample = self.map_to_bucket_idx(y)
37
- assert (target_sample >= 0).all() and (
38
- target_sample < self.num_bars
39
- ).all(), f"y {y} not in support set for borders (min_y, max_y) {self.borders}"
40
- assert (
41
- logits.shape[-1] == self.num_bars
42
- ), f"{logits.shape[-1]} vs {self.num_bars}"
43
-
44
- bucket_log_probs = torch.log_softmax(logits, -1)
45
- scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
46
- # print(bucket_log_probs, logits.shape)
47
-
48
- nll_loss = -scaled_bucket_log_probs.gather(
49
- -1, target_sample.unsqueeze(-1)
50
- ).squeeze(-1)
51
-
52
- smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
53
- smoothing = self.smoothing if self.training else 0.0
54
- loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss
55
- return loss
56
-
57
- def mean(self, logits):
58
- bucket_means = self.borders[:-1] + self.bucket_widths / 2
59
- p = torch.softmax(logits, -1)
60
- return p @ bucket_means
61
-
62
- def icdf(self, logits, left_prob):
63
- """
64
- Implementation of the quantile function
65
- :param logits: Tensor of any shape, with the last dimension being logits
66
- :param left_prob: float: The probability mass to the left of the result.
67
- :return: Position with `left_prob` probability weight to the left.
68
- """
69
- probs = logits.softmax(-1)
70
- cumprobs = torch.cumsum(probs, -1)
71
- idx = (
72
- torch.searchsorted(
73
- cumprobs,
74
- left_prob * torch.ones(*cumprobs.shape[:-1], 1, device=probs.device),
75
- )
76
- .squeeze(-1)
77
- .clamp(0, cumprobs.shape[-1] - 1)
78
- ) # this might not do the right for outliers
79
- cumprobs = torch.cat(
80
- [torch.zeros(*cumprobs.shape[:-1], 1, device=logits.device), cumprobs], -1
81
- )
82
-
83
- rest_prob = left_prob - cumprobs.gather(-1, idx[..., None]).squeeze(-1)
84
- left_border = self.borders[idx]
85
- right_border = self.borders[idx + 1]
86
- return left_border + (right_border - left_border) * rest_prob / probs.gather(
87
- -1, idx[..., None]
88
- ).squeeze(-1)
89
-
90
- def quantile(self, logits, center_prob=0.682):
91
- side_probs = (1.0 - center_prob) / 2
92
- return torch.stack(
93
- (self.icdf(logits, side_probs), self.icdf(logits, 1.0 - side_probs)), -1
94
- )
95
-
96
- def ucb(self, logits, best_f, rest_prob=(1 - 0.682) / 2, maximize=True):
97
- """
98
- UCB utility. Rest Prob is the amount of utility above (below) the confidence interval that is ignored.
99
- Higher rest_prob is equivalent to lower beta in the standard GP-UCB formulation.
100
- :param logits: Logits, as returned by the Transformer.
101
- :param best_f: Only here, since the other utilities have it.
102
- :param rest_prob: The amount of utility above (below) the confidence interval that is ignored.
103
- The default is equivalent to using GP-UCB with `beta=1`.
104
- To get the corresponding `beta`, where `beta` is from
105
- the standard GP definition of UCB `ucb_utility = mean + beta * std`,
106
- you can use this computation: `beta = math.sqrt(2)*torch.erfinv(torch.tensor(2*rest_prob-1))`.
107
- :param maximize:
108
- :return: utility
109
- """
110
- if maximize:
111
- rest_prob = 1 - rest_prob
112
- return self.icdf(logits, rest_prob)
113
-
114
- def mode(self, logits):
115
- mode_inds = logits.argmax(-1)
116
- bucket_means = self.borders[:-1] + self.bucket_widths / 2
117
- return bucket_means[mode_inds]
118
-
119
- def ei(
120
- self, logits, best_f, maximize=True
121
- ): # logits: evaluation_points x batch x feature_dim
122
- bucket_means = self.borders[:-1] + self.bucket_widths / 2
123
- if maximize:
124
- bucket_contributions = torch.tensor(
125
- [
126
- max((bucket_max + max(bucket_min, best_f)) / 2 - best_f, 0)
127
- for bucket_min, bucket_max, bucket_mean in zip(
128
- self.borders[:-1], self.borders[1:], bucket_means
129
- )
130
- ],
131
- dtype=logits.dtype,
132
- device=logits.device,
133
- )
134
- else:
135
- bucket_contributions = torch.tensor(
136
- [
137
- -min((min(bucket_max, best_f) + bucket_min) / 2 - best_f, 0)
138
- for bucket_min, bucket_max, bucket_mean in zip( # min on max instead of max on min, and compare min < instead of max >
139
- self.borders[:-1], self.borders[1:], bucket_means
140
- )
141
- ],
142
- dtype=logits.dtype,
143
- device=logits.device,
144
- )
145
- p = torch.softmax(logits, -1)
146
- return p @ bucket_contributions
147
-
148
- def pi(
149
- self, logits, best_f, maximize=True
150
- ): # logits: evaluation_points x batch x feature_dim
151
- """
152
- Acquisition Function: Probability of Improvement
153
- :param logits: as returned by Transformer
154
- :param best_f: best evaluation so far (the incumbent)
155
- :param maximize: whether to maximize
156
- :return: utility
157
- """
158
- assert maximize is True
159
- p = torch.softmax(logits, -1)
160
- border_widths = self.borders[1:] - self.borders[:-1]
161
- factor = 1.0 - ((best_f - self.borders[:-1]) / border_widths).clamp(0.0, 1.0)
162
- return (p * factor).sum(-1)
163
-
164
- def mean_of_square(self, logits):
165
- """
166
- Computes E[x^2].
167
- :param logits: Output of the model.
168
- """
169
- left_borders = self.borders[:-1]
170
- right_borders = self.borders[1:]
171
- bucket_mean_of_square = (
172
- left_borders.square()
173
- + right_borders.square()
174
- + left_borders * right_borders
175
- ) / 3.0
176
- p = torch.softmax(logits, -1)
177
- return p @ bucket_mean_of_square
178
-
179
- def variance(self, logits):
180
- return self.mean_of_square(logits) - self.mean(logits).square()
181
-
182
-
183
- class FullSupportBarDistribution(BarDistribution):
184
- @staticmethod
185
- def halfnormal_with_p_weight_before(range_max, p=0.5):
186
- s = range_max / torch.distributions.HalfNormal(torch.tensor(1.0)).icdf(
187
- torch.tensor(p)
188
- )
189
- return torch.distributions.HalfNormal(s)
190
-
191
- def forward(
192
- self, logits, y
193
- ): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
194
- assert self.num_bars > 1
195
- target_sample = self.map_to_bucket_idx(y)
196
- target_sample.clamp_(0, self.num_bars - 1)
197
- assert logits.shape[-1] == self.num_bars
198
-
199
- bucket_log_probs = torch.log_softmax(logits, -1)
200
- scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
201
- # print(bucket_log_probs, logits.shape)
202
- log_probs = scaled_bucket_log_probs.gather(
203
- -1, target_sample.unsqueeze(-1)
204
- ).squeeze(-1)
205
-
206
- side_normals = (
207
- self.halfnormal_with_p_weight_before(self.bucket_widths[0]),
208
- self.halfnormal_with_p_weight_before(self.bucket_widths[-1]),
209
- )
210
-
211
- # TODO look over it again
212
- log_probs[target_sample == 0] += side_normals[0].log_prob(
213
- (self.borders[1] - y[target_sample == 0]).clamp(min=0.00000001)
214
- ) + torch.log(self.bucket_widths[0])
215
- log_probs[target_sample == self.num_bars - 1] += side_normals[1].log_prob(
216
- y[target_sample == self.num_bars - 1] - self.borders[-2]
217
- ) + torch.log(self.bucket_widths[-1])
218
-
219
- nll_loss = -log_probs
220
-
221
- smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
222
- smoothing = self.smoothing if self.training else 0.0
223
- loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss
224
-
225
- return loss
226
-
227
- def mean(self, logits):
228
- bucket_means = self.borders[:-1] + self.bucket_widths / 2
229
- p = torch.softmax(logits, -1)
230
- side_normals = (
231
- self.halfnormal_with_p_weight_before(self.bucket_widths[0]),
232
- self.halfnormal_with_p_weight_before(self.bucket_widths[-1]),
233
- )
234
- bucket_means[0] = -side_normals[0].mean + self.borders[1]
235
- bucket_means[-1] = side_normals[1].mean + self.borders[-2]
236
- return p @ bucket_means
237
-
238
-
239
- def get_bucket_limits_(
240
- num_outputs: int,
241
- full_range: tuple = None,
242
- ys: torch.Tensor = None,
243
- verbose: bool = False,
244
- ):
245
- assert (ys is not None) or (full_range is not None)
246
- if ys is not None:
247
- ys = ys.flatten()
248
- if len(ys) % num_outputs:
249
- ys = ys[: -(len(ys) % num_outputs)]
250
- print(
251
- f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys."
252
- )
253
- ys_per_bucket = len(ys) // num_outputs
254
- if full_range is None:
255
- full_range = (ys.min(), ys.max())
256
- else:
257
- assert full_range[0] <= ys.min() and full_range[1] >= ys.max()
258
- full_range = torch.tensor(full_range)
259
- ys_sorted, ys_order = ys.sort(0)
260
- bucket_limits = (
261
- ys_sorted[ys_per_bucket - 1 :: ys_per_bucket][:-1]
262
- + ys_sorted[ys_per_bucket::ys_per_bucket]
263
- ) / 2
264
- if verbose:
265
- print(
266
- f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys."
267
- )
268
- print(full_range)
269
- bucket_limits = torch.cat(
270
- [full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0
271
- )
272
-
273
- else:
274
- class_width = (full_range[1] - full_range[0]) / num_outputs
275
- bucket_limits = torch.cat(
276
- [
277
- full_range[0] + torch.arange(num_outputs).float() * class_width,
278
- torch.tensor(full_range[1]).unsqueeze(0),
279
- ],
280
- 0,
281
- )
282
-
283
- assert (
284
- len(bucket_limits) - 1 == num_outputs
285
- and full_range[0] == bucket_limits[0]
286
- and full_range[-1] == bucket_limits[-1]
287
- )
288
- return bucket_limits
289
-
290
-
291
- def get_bucket_limits(
292
- num_outputs: int,
293
- full_range: tuple = None,
294
- ys: torch.Tensor = None,
295
- verbose: bool = False,
296
- ):
297
- assert (ys is None) != (
298
- full_range is None
299
- ), "Either full_range or ys must be passed."
300
-
301
- if ys is not None:
302
- ys = ys.flatten()
303
- ys = ys[~torch.isnan(ys)]
304
- if len(ys) % num_outputs:
305
- ys = ys[: -(len(ys) % num_outputs)]
306
- print(
307
- f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys."
308
- )
309
- ys_per_bucket = len(ys) // num_outputs
310
- if full_range is None:
311
- full_range = (ys.min(), ys.max())
312
- else:
313
- assert (
314
- full_range[0] <= ys.min() and full_range[1] >= ys.max()
315
- ), f"full_range {full_range} not in range of ys {ys.min(), ys.max()}"
316
- full_range = torch.tensor(full_range)
317
- ys_sorted, ys_order = ys.sort(0)
318
- bucket_limits = (
319
- ys_sorted[ys_per_bucket - 1 :: ys_per_bucket][:-1]
320
- + ys_sorted[ys_per_bucket::ys_per_bucket]
321
- ) / 2
322
- if verbose:
323
- print(
324
- f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys."
325
- )
326
- print(full_range)
327
- bucket_limits = torch.cat(
328
- [full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0
329
- )
330
-
331
- else:
332
- class_width = (full_range[1] - full_range[0]) / num_outputs
333
- bucket_limits = torch.cat(
334
- [
335
- full_range[0] + torch.arange(num_outputs).float() * class_width,
336
- torch.tensor(full_range[1]).unsqueeze(0),
337
- ],
338
- 0,
339
- )
340
-
341
- assert (
342
- len(bucket_limits) - 1 == num_outputs
343
- ), f"len(bucket_limits) - 1 == {len(bucket_limits) - 1} != {num_outputs} == num_outputs"
344
- assert full_range[0] == bucket_limits[0], f"{full_range[0]} != {bucket_limits[0]}"
345
- assert (
346
- full_range[-1] == bucket_limits[-1]
347
- ), f"{full_range[-1]} != {bucket_limits[-1]}"
348
-
349
- return bucket_limits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/curves.py DELETED
@@ -1,277 +0,0 @@
1
import numpy as np
from collections import OrderedDict

# Hyper-priors over the parameters of three parametric learning-curve
# families.  For each family two variants are given: "uniform" (broad flat
# ranges) and "peaked" (concentrated around typical values).  Each parameter
# maps to a spec consumed by ``prior_sampler``: ``type`` selects the
# distribution, and ``param1``/``param2`` are (low, high) for "uniform" or
# (mean, sigma) of the underlying normal for "log_normal".
prior = {
    "pow3": {
        "uniform": OrderedDict(
            a={"type": "uniform", "param1": -1, "param2": 1},
            c={"type": "uniform", "param1": 0, "param2": 1},
            alpha={"type": "uniform", "param1": 0, "param2": 1},
        ),
        "peaked": OrderedDict(
            a={"type": "uniform", "param1": -0.6, "param2": 0.6},
            c={"type": "uniform", "param1": 0, "param2": 1.25},
            alpha={"type": "log_normal", "param1": 0, "param2": 2},
        ),
    },
    "ilog2": {
        "uniform": OrderedDict(
            c={"type": "uniform", "param1": 0, "param2": 1},
            a={"type": "uniform", "param1": -1, "param2": 1},
        ),
        "peaked": OrderedDict(
            c={"type": "uniform", "param1": 0, "param2": 1},
            a={"type": "uniform", "param1": -0.5, "param2": 0.5},
        ),
    },
    "janoschek": {
        "uniform": OrderedDict(
            a={"type": "uniform", "param1": 0, "param2": 1},
            beta={"type": "uniform", "param1": 0, "param2": 2},
            k={"type": "uniform", "param1": 0, "param2": 1},
            delta={"type": "uniform", "param1": -5, "param2": 5},
        ),
        "peaked": OrderedDict(
            a={"type": "uniform", "param1": 0, "param2": 1},
            beta={"type": "uniform", "param1": 0, "param2": 2},
            k={"type": "log_normal", "param1": -2, "param2": 1},
            delta={"type": "log_normal", "param1": 0, "param2": 0.5},
        ),
    },
}
42
-
43
-
44
def prior_sampler(rng, type, param1, param2):
    """Draw one sample from a primitive prior distribution.

    Args:
        rng: a ``numpy.random.RandomState``-like generator.
        type: ``"uniform"`` or ``"log_normal"``.  (Parameter name kept for
            backward compatibility although it shadows the builtin.)
        param1, param2: (low, high) for "uniform"; (mean, sigma) of the
            underlying normal for "log_normal".

    Returns:
        A single float sample.

    Raises:
        ValueError: if ``type`` names no known distribution.
    """
    if type == "uniform":
        return rng.uniform(param1, param2)
    elif type == "log_normal":
        return rng.lognormal(param1, param2)
    # ValueError (a subclass of Exception) is more precise than the bare
    # Exception raised previously and is still caught by existing callers.
    raise ValueError("Unknown prior type: {}".format(type))
50
-
51
-
52
def pow3(x, c, a, alpha):
    """Power-law curve ``c - a * x**(-alpha)`` (saturates at ``c``)."""
    return c - a * (x) ** (-alpha)


def _sample_params(family, distribution, names, rng):
    """Draw one value per parameter in *names* from ``prior[family][distribution]``.

    Deduplicates the six nearly identical sampler functions below; the
    iteration order of *names* preserves the original rng draw order.
    """
    spec = prior[family][distribution]
    return {
        p: prior_sampler(
            rng,
            spec[p]["type"],
            param1=spec[p]["param1"],
            param2=spec[p]["param2"],
        )
        for p in names
    }


def prior_pow3(rng):
    """Sample pow3 parameters from the peaked prior."""
    return _sample_params("pow3", "peaked", ["a", "c", "alpha"], rng)


def uniform_prior_pow3(rng):
    """Sample pow3 parameters from the uniform prior."""
    return _sample_params("pow3", "uniform", ["a", "c", "alpha"], rng)


def ilog2(x, c, a):
    """Inverse-log curve ``c - a / log(x + 1)``."""
    return c - a / (np.log(x + 1))


def prior_ilog2(rng):
    """Sample ilog2 parameters from the peaked prior."""
    return _sample_params("ilog2", "peaked", ["a", "c"], rng)


def uniform_prior_ilog2(rng):
    """Sample ilog2 parameters from the uniform prior."""
    return _sample_params("ilog2", "uniform", ["a", "c"], rng)


def janoschek(x, a, beta, k, delta):
    """Janoschek growth curve.

    http://www.pisces-conservation.com/growthhelp/janoschek.htm
    """
    return a - (a - beta) * np.exp(-k * x**delta)


def prior_janoschek(rng):
    """Sample janoschek parameters from the peaked prior."""
    return _sample_params("janoschek", "peaked", ["a", "beta", "k", "delta"], rng)


def uniform_prior_janoschek(rng):
    """Sample janoschek parameters from the uniform prior."""
    return _sample_params("janoschek", "uniform", ["a", "beta", "k", "delta"], rng)
137
-
138
-
139
def log_power(x, a, b, c):
    """Log-power curve ``a / (1 + (x / e**b) ** c)``.

    ``a`` is the upper bound and ``c`` the growth rate; the initial value
    is ``a / (1 + (1 / e**b) ** c)``.
    """
    return a / (1.0 + (x / np.exp(b)) ** c)


def prior_log_power(rng):
    """a ~ N(0.8, 0.1), b ~ N(1, 1), c ~ U(-3, 0)."""
    return {
        "a": rng.normal(0.8, 0.1),
        "b": rng.normal(1.0, 1.0),
        "c": rng.uniform(-3.0, 0.0),
    }


def weibull(x, alpha, beta, kappa, delta):
    """Weibull model.

    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote; beta: lower asymptote; kappa: growth rate;
    delta: controls the x-ordinate for the point of inflection.
    """
    return alpha - (alpha - beta) * np.exp(-((kappa * x) ** delta))


def prior_weibull(rng):
    """alpha ~ U(0, 1.5), beta ~ U(0, 1), ln(kappa) ~ N(-2, 1), ln(delta) ~ N(0, 0.5)."""
    params = {}
    params["alpha"] = rng.uniform(0.0, 1.5)
    params["beta"] = rng.uniform(0.0, 1)
    params["kappa"] = np.exp(rng.normal(-2.0, 1.0))
    params["delta"] = np.exp(rng.normal(0, 0.5))
    return params


def mmf(x, alpha, beta, kappa, delta):
    """Morgan-Mercer-Flodin model.

    Nonlinear Regression page 342, http://bit.ly/1jodG17,
    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote; beta: initial value; kappa: growth rate;
    delta: controls the point of inflection.
    """
    return alpha - (alpha - beta) / (1.0 + (kappa * x) ** delta)


def prior_mmf(rng):
    """alpha ~ N(0.8, 0.1), beta ~ N(0.2, 0.1), ln(kappa) ~ N(0, 2), ln(delta) ~ N(0, 1)."""
    params = {}
    params["alpha"] = rng.normal(0.8, 0.1)
    params["beta"] = rng.normal(0.2, 0.1)
    params["kappa"] = np.exp(rng.normal(0, 2))
    params["delta"] = np.exp(rng.normal(0, 1))
    return params


def vap(x, a, b, c):
    """Vapor pressure model ``exp(a + b/x + c*log(x))``.

    No upper bound when c > 0; for c = 0, ``a = ln(upper bound)`` and
    ``a + b = ln(initial)``.
    """
    return np.exp(a + b / x + c * np.log(x))


def prior_vap(rng):
    """a ~ U(-2, 0), b ~ U(-4, 0), ln(c) ~ U(-8, 0).  # @heri: range check"""
    return {
        "a": rng.uniform(-2.0, 0.0),
        "b": rng.uniform(-4.0, 0.0),
        "c": np.exp(rng.uniform(-8.0, 0.0)),
    }


def loglog_linear(x, a, b):
    """Linear in log-log space: ``log(a * log(x) + b)``."""
    return np.log(a * np.log(x) + b)


def prior_loglog_linear(rng):
    """ln(a) ~ N(-2, 1), ln(b) ~ U(0, 1)."""
    return {
        "a": np.exp(rng.normal(-2.0, 1.0)),
        "b": np.exp(rng.uniform(0.0, 1.0)),
    }


def exp4(x, c, a, b, alpha):
    """Four-parameter exponential ``c - exp(-a * x**alpha + b)``."""
    return c - np.exp(-a * (x**alpha) + b)


def prior_exp4(rng):
    """c ~ N(0.8, 0.1), ln(a) ~ N(-2, 1), ln(alpha) ~ N(0, 1), ln(b) ~ N(0, 0.5)."""
    c = rng.normal(0.8, 0.1)
    a = np.exp(rng.normal(-2, 1))
    alpha = np.exp(rng.normal(0, 1))
    b = np.exp(rng.normal(0, 0.5))
    return {"a": a, "b": b, "c": c, "alpha": alpha}


def pow4(x, c, a, b, alpha):
    """Four-parameter power law ``c - (a*x + b) ** (-alpha)``."""
    return c - (a * x + b) ** -alpha


def prior_pow4(rng):
    """ln(1-c) ~ U(-5, 0), ln(a) ~ N(-3, 2), ln(alpha) ~ N(0, 1), ln(b) ~ U(0, 1)."""
    c = 1 - np.exp(rng.uniform(-5.0, 0))
    a = np.exp(rng.normal(-3.0, 2))
    alpha = np.exp(rng.normal(0, 1))
    b = np.exp(rng.uniform(0, 1))
    return {"a": a, "b": b, "c": c, "alpha": alpha}


def dr_hill_zero_background(x, theta, eta, kappa):
    """Dose-response Hill model with zero background:
    ``theta * x**eta / (kappa**eta + x**eta)``.

    theta: upper bound; eta: growth rate; initial value is
    ``theta / (kappa**eta + 1)``.
    """
    return (theta * x**eta) / (kappa**eta + x**eta)


def prior_dr_hill_zero_background(rng):
    """theta ~ N(0.8, 0.1), ln(eta) ~ N(1, 1), ln(kappa) ~ N(1, 2)."""
    return {
        "theta": rng.normal(0.8, 0.1),
        "eta": np.exp(rng.normal(1.0, 1.0)),
        "kappa": np.exp(rng.normal(1.0, 2.0)),
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/decoders.py DELETED
@@ -1,42 +0,0 @@
1
- import torch
2
- from torch import nn
3
- import random
4
-
5
- from torch import Tensor
6
- import torch.nn.functional as F
7
-
8
-
9
class GELU(nn.Module):
    # Module wrapper around the functional GELU activation, usable inside
    # nn.Sequential.
    def forward(self, input: Tensor) -> Tensor:
        return F.gelu(input)
12
-
13
-
14
class ScaledDecoder(nn.Module):
    # Decoder head that predicts logits together with a per-position softmax
    # temperature: a second linear head forms a soft mixture over ten fixed
    # temperatures and the logits are divided by the resulting scalar.
    def __init__(self, ninp, nhid, nout):
        super().__init__()
        self.linear = nn.Linear(ninp, nhid)
        self.linear1 = nn.Linear(nhid, nout)
        # 10-way head producing mixture weights over the temperature grid.
        self.linear2 = nn.Linear(nhid, 10)

    def forward(self, x):
        # return torch.cat([self.linear1(x), self.linear2(x)], -1)
        x = self.linear(x)
        x = GELU()(x)
        # Expected temperature under the softmax over a fixed 1.0..160.0 grid.
        temps = self.linear2(x).softmax(-1) @ torch.tensor(
            [1.0, 1.4, 1.7, 2.0, 5.0, 10.0, 20.0, 40.0, 80.0, 160.0], device=x.device
        )
        # Occasional debug print (~1% of calls).  NOTE(review): the
        # ``temps[:, :2]`` index assumes x is at least 2-D — confirm the
        # expected input shape.
        if random.random() > 0.99:
            print(temps.shape, temps[:, :2])
        return self.linear1(x) / temps.unsqueeze(-1)
31
-
32
-
33
class FixedScaledDecoder(nn.Module):
    """Two-layer GELU MLP decoder whose logits are divided by a learnable
    temperature (stored as 10000 tied parameters whose sum is the scale)."""

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        self.mapper = nn.Sequential(
            nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout)
        )
        self.T = nn.Parameter(torch.ones(10000) / 10000)

    def forward(self, x):
        logits = self.mapper(x)
        temperature = self.T.sum()
        return logits / temperature
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/domhan_prior.py DELETED
@@ -1,199 +0,0 @@
1
- from functools import partial
2
- import torch
3
- import numpy as np
4
- from lcpfn.curves import (
5
- pow3,
6
- ilog2,
7
- janoschek,
8
- log_power,
9
- prior_ilog2,
10
- uniform_prior_pow3,
11
- weibull,
12
- mmf,
13
- vap,
14
- loglog_linear,
15
- exp4,
16
- pow4,
17
- dr_hill_zero_background,
18
- )
19
- from lcpfn.curves import (
20
- prior_pow3,
21
- prior_janoschek,
22
- prior_log_power,
23
- prior_weibull,
24
- prior_mmf,
25
- prior_vap,
26
- prior_loglog_linear,
27
- prior_exp4,
28
- prior_pow4,
29
- prior_dr_hill_zero_background,
30
- )
31
- from lcpfn.curves import (
32
- uniform_prior_pow3,
33
- uniform_prior_ilog2,
34
- uniform_prior_janoschek,
35
- )
36
-
37
-
38
def prior_weights(
    rng,
    components=[
        "pow3",
        "ilog2",
        "janoschek",
        "log_power",
        "weibull",
        "mmf",
        "vap",
        "loglog_linear",
        "exp4",
        "pow4",
        "dr_hill_zero_background",
    ],
):
    """Draw one uniform [0, 1) weight per curve family.

    Returns a dict mapping each name in *components* to its weight.
    The mutable default list is never modified here.
    """
    draws = rng.uniform(0.0, 1, size=(len(components),))
    return dict(zip(components, draws))
57
-
58
-
59
def sample_from_prior(rng, seq_len=100):
    # Default curve prior: a weighted combination of pow3, ilog2 and
    # janoschek with the "peaked" hyper-priors.  Returns the curve closure
    # produced by sample_prior_comb.
    return sample_prior_comb(
        rng=rng,
        seq_len=seq_len,
        components=["pow3", "ilog2", "janoschek"],
        distribution="peaked",
    )
66
-
67
-
68
def sample_prior_comb(
    rng,
    components,
    distribution,
    var_lnloc=-4,
    var_lnscale=1,
    range_constraint=True,
    seq_len=100,
):
    """Rejection-sample a noiseless learning curve that is a random positive
    combination of parametric curve families, returning a closure that adds
    fresh Gaussian observation noise on each call.

    Args:
        rng: numpy RandomState used for all draws.
        components: names of the curve families to combine.
        distribution: "peaked" or "uniform" hyper-prior for the parameters
            ("uniform" only supports pow3/ilog2/janoschek).
        var_lnloc, var_lnscale: the noise std is drawn as
            ``exp(N(var_lnloc, var_lnscale))``.
        range_constraint: reject curves that leave [0, 1].
        seq_len: number of epochs; the curve is evaluated at x = 1..seq_len.

    Returns:
        A zero-argument callable returning ``(y, y_noisy)``; ``y`` is fixed,
        the noise is re-drawn per call.

    Raises:
        NotImplementedError: for an unknown *distribution*.
    """
    f_components = {
        "pow3": pow3,
        "ilog2": ilog2,
        "janoschek": janoschek,
        "log_power": log_power,
        "weibull": weibull,
        "mmf": mmf,
        "vap": vap,
        "loglog_linear": loglog_linear,
        "exp4": exp4,
        "pow4": pow4,
        "dr_hill_zero_background": dr_hill_zero_background,
    }

    if distribution == "peaked":
        f_priors = {
            "pow3": prior_pow3,
            "ilog2": prior_ilog2,
            "janoschek": prior_janoschek,
            "log_power": prior_log_power,
            "weibull": prior_weibull,
            "mmf": prior_mmf,
            "vap": prior_vap,
            "loglog_linear": prior_loglog_linear,
            "exp4": prior_exp4,
            "pow4": prior_pow4,
            "dr_hill_zero_background": prior_dr_hill_zero_background,
        }
    elif distribution == "uniform":
        f_priors = {
            "pow3": uniform_prior_pow3,
            "ilog2": uniform_prior_ilog2,
            "janoschek": uniform_prior_janoschek,
        }
    else:
        # BUG FIX: the original did ``raise NotImplemented()``.
        # ``NotImplemented`` is a sentinel value, not an exception class, so
        # that line itself failed with a TypeError instead of signalling the
        # unsupported distribution.
        raise NotImplementedError()

    x = np.arange(1, seq_len + 1)

    while True:
        # Sample the noiseless combined curve.
        weights = prior_weights(rng, components=components)
        y = np.zeros(x.shape, dtype="float")
        for f, w in weights.items():
            kwargs = f_priors[f](rng)
            y += w * f_components[f](x, **kwargs)
        # Observation-noise std (noise can exceed [0,1]; afaik no way to
        # implement this prior in Tobi's work).
        var = np.exp(
            rng.normal(var_lnloc, var_lnscale)
        )  # @heri: ln_prob =+ log(normal.pdf(log(var), loc=var_lnloc, scale=var_lnscale))

        # Reject any curves that are non-increasing, exceed the [0,1] range,
        # or contain NaNs.
        if (
            y[-1] <= y[0]
            or (range_constraint and (np.any(y < 0) or np.any(y > 1)))
            or np.isnan(y).any()
        ):
            continue
        else:
            break

    def curve():  # generates a sample from the same model, but with independent noise
        y_noisy = y + rng.normal(np.zeros_like(y), var)
        return y, y_noisy

    return curve
145
-
146
-
147
def generate_prior_dataset(n, prior=sample_prior_comb, seed=42):
    """
    Returns a fixed sample from the prior (with fixed seq_len) as an n x seq_len np.ndarray

    NOTE(review): the default ``prior=sample_prior_comb`` requires
    ``components`` and ``distribution`` arguments that the ``prior(rng)``
    call below does not supply, so callers appear to be expected to pass a
    ``functools.partial`` (or ``sample_from_prior``) — confirm intended usage.
    """
    rng = np.random.RandomState(seed)
    # Index [1] keeps only the noisy observation of each sampled curve.
    prior_data = np.stack([prior(rng)()[1] for _ in range(n)])
    return prior_data
154
-
155
-
156
def create_get_batch_func(prior):
    # Bind *prior* into get_batch_domhan to obtain a PFN-style get_batch
    # function with the standard (batch_size, seq_len, num_features) signature.
    return partial(get_batch_domhan, prior=prior)
158
-
159
-
160
# function producing batches for PFN training
def get_batch_domhan(
    batch_size,
    seq_len,
    num_features,
    prior,
    device="cpu",
    noisy_target=True,
    **_,
):
    """Produce one PFN training batch from a curve prior.

    Returns ``(x, y_noisy, y_target)`` with ``x`` of shape
    ``(seq_len, batch_size, 1)`` (epochs 1..seq_len) and both targets of
    shape ``(seq_len, batch_size)``.  When *noisy_target* is true, the clean
    curve is discarded and the noisy observation doubles as the target.
    """
    assert num_features == 1

    clean = np.empty((batch_size, seq_len), dtype=float)
    observed = np.empty((batch_size, seq_len), dtype=float)

    for row in range(batch_size):
        sampler = prior(np.random, seq_len=seq_len)  # uses the global numpy rng
        y_clean, y_obs = sampler()
        observed[row] = y_obs
        clean[row] = y_obs if noisy_target else y_clean

    # Shape the tensors as (seq_len, batch_size, ...) and move to device.
    x = (
        torch.arange(1, seq_len + 1)
        .repeat((num_features, batch_size, 1))
        .transpose(2, 0)
        .to(device)
        .float()
    )
    y_target = torch.from_numpy(clean).transpose(1, 0).to(device).float()
    y_noisy = torch.from_numpy(observed).transpose(1, 0).to(device).float()

    return x, y_noisy, y_target
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/encoders.py DELETED
@@ -1,190 +0,0 @@
1
- import math
2
-
3
- import torch
4
- import torch.nn as nn
5
- from lcpfn.utils import normalize_data
6
- import torch.nn.functional as F
7
- from torch.nn import TransformerEncoder, TransformerEncoderLayer
8
-
9
-
10
class StyleEncoder(nn.Module):
    # Linear embedding of a "style" vector (dataset-level hyperparameters)
    # into the transformer embedding space.
    def __init__(self, em_size, hyperparameter_definitions):
        super().__init__()
        self.em_size = em_size
        # hyperparameter_definitions.shape[0] = number of hyperparameters.
        self.embedding = nn.Linear(hyperparameter_definitions.shape[0], self.em_size)

    def forward(self, hyperparameters):  # T x B x num_hps
        return self.embedding(hyperparameters)
18
-
19
-
20
class _PositionalEncoding(nn.Module):
    # Sinusoidal encoding of continuous feature VALUES (not sequence
    # positions): each scalar feature is expanded into
    # d_model // num_features sin/cos channels at geometrically spaced
    # frequencies, then flattened back to d_model.
    def __init__(self, d_model, dropout=0.0):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model
        # Dummy parameter used only to discover the module's current device.
        self.device_test_tensor = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):  # T x B x num_features
        # NOTE(review): precedence makes this
        # ``(self.d_model % x.shape[-1]) * 2 == 0`` — i.e. it only checks
        # divisibility by num_features, not by ``x.shape[-1] * 2`` which may
        # have been intended.  Confirm.
        assert self.d_model % x.shape[-1] * 2 == 0
        d_per_feature = self.d_model // x.shape[-1]
        pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
        # position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        interval_size = 10
        # Frequencies: (2*pi / interval_size) * sqrt(2)**k for k = 0, 2, 4, ...
        div_term = (
            (1.0 / interval_size)
            * 2
            * math.pi
            * torch.exp(
                torch.arange(
                    0, d_per_feature, 2, device=self.device_test_tensor.device
                ).float()
                * math.log(math.sqrt(2))
            )
        )
        # print(div_term/2/math.pi)
        pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
        pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
        return self.dropout(pe).view(x.shape[0], x.shape[1], self.d_model)


# Factory matching the (num_features, emsize) encoder-creator signature;
# the first argument is ignored.
Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)
51
-
52
-
53
class EmbeddingEncoder(nn.Module):
    """Encode continuous features by discretizing each into ``num_embs``
    equal-width bins over [-2, 2] and averaging one learned embedding per
    feature."""

    def __init__(self, num_features, em_size, num_embs=100):
        super().__init__()
        self.num_embs = num_embs
        # One table slice of num_embs rows per feature.
        self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
        self.init_weights(0.1)
        self.min_max = (-2, +2)  # clipping range for the inputs

    @property
    def width(self):
        return self.min_max[1] - self.min_max[0]

    def init_weights(self, initrange):
        self.embeddings.weight.data.uniform_(-initrange, initrange)

    def discretize(self, x):
        split_size = self.width / self.num_embs
        # BUG FIX: the original computed ``x - self.min_max[0] // split_size``,
        # which floor-divides the lower bound by the bin width instead of
        # shifting x by it; inputs in [-2, 2] then collapsed into a handful
        # of adjacent buckets.  Shift first, then divide by the bin width.
        return ((x - self.min_max[0]) / split_size).int().clamp(0, self.num_embs - 1)

    def forward(self, x):  # T x B x num_features
        x_idxs = self.discretize(x)
        # Offset each feature into its own slice of the embedding table.
        x_idxs += (
            torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
        )
        # Average the per-feature embeddings into a single em_size vector.
        return self.embeddings(x_idxs).mean(-2)
79
-
80
-
81
class Normalize(nn.Module):
    """Affine normalization module: ``(x - mean) / std``."""

    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, x):
        centered = x - self.mean
        return centered / self.std
89
-
90
-
91
def get_normalized_uniform_encoder(encoder_creator):
    """
    This can be used to wrap an encoder that is fed uniform samples in [0,1] and normalizes these to 0 mean and 1 std.
    For example, it can be used as `encoder_creator = get_normalized_uniform_encoder(encoders.Linear)`, now this can
    be initialized with `encoder_creator(feature_dim, in_dim)`.
    :param encoder:
    :return:
    """

    def make(in_dim, out_dim):
        # U(0, 1) has mean 1/2 and variance 1/12.
        return nn.Sequential(
            Normalize(0.5, math.sqrt(1 / 12)), encoder_creator(in_dim, out_dim)
        )

    return make
102
-
103
-
104
# Simple encoder factories with the (num_features, emsize) signature.
# NOTE(review): this Linear alias is shadowed by the NaN-handling Linear
# class defined later in this module.
Linear = nn.Linear
MLP = lambda num_features, emsize: nn.Sequential(
    nn.Linear(num_features + 1, emsize * 2), nn.ReLU(), nn.Linear(emsize * 2, emsize)
)
108
-
109
-
110
class NanHandlingEncoder(nn.Module):
    # Linear encoder that zeroes out NaNs and (optionally) appends
    # per-feature indicator channels marking which entries were NaN (-1),
    # +inf (1) or -inf (2), normalized via lcpfn.utils.normalize_data.
    def __init__(self, num_features, emsize, keep_nans=True):
        super().__init__()
        # With keep_nans, the indicator channels double the input width.
        self.num_features = 2 * num_features if keep_nans else num_features
        self.emsize = emsize
        self.keep_nans = keep_nans
        self.layer = nn.Linear(self.num_features, self.emsize)

    def forward(self, x):
        if self.keep_nans:
            x = torch.cat(
                [
                    torch.nan_to_num(x, nan=0.0),
                    normalize_data(
                        torch.isnan(x) * -1
                        + torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1
                        + torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2
                    ),
                ],
                -1,
            )
        else:
            x = torch.nan_to_num(x, nan=0.0)
        return self.layer(x)
134
-
135
-
136
class Linear(nn.Linear):
    """``nn.Linear`` encoder that first replaces NaNs in the input with 0.0."""

    def __init__(self, num_features, emsize):
        super().__init__(num_features, emsize)
        self.num_features = num_features
        self.emsize = emsize

    def forward(self, x):
        cleaned = torch.nan_to_num(x, nan=0.0)
        return super().forward(cleaned)
145
-
146
-
147
class Conv(nn.Module):
    # Encoder for square-image inputs flattened into the feature dimension:
    # reshapes to (..., 1, size, size), applies up to five 3x3 convolutions
    # (stopping once the spatial side drops below 4), global-average-pools,
    # and projects the 64 channels to emsize.
    def __init__(self, input_size, emsize):
        super().__init__()
        # NOTE(review): input_size is unused — the spatial size is inferred
        # from the input at forward time.  Confirm this is intentional.
        self.convs = torch.nn.ModuleList(
            [nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)]
        )
        self.linear = nn.Linear(64, emsize)

    def forward(self, x):
        size = math.isqrt(x.shape[-1])
        # The flattened feature dimension must be a perfect square.
        assert size * size == x.shape[-1]
        x = x.reshape(*x.shape[:-1], 1, size, size)
        for conv in self.convs:
            if x.shape[-1] < 4:
                break
            x = conv(x)
            x.relu_()
        x = nn.AdaptiveAvgPool2d((1, 1))(x).squeeze(-1).squeeze(-1)
        return self.linear(x)
166
-
167
-
168
class CanEmb(nn.Embedding):
    """Embedding for categorical ("canonical") inputs: each of the
    ``num_features`` values gets an ``embedding_dim // num_features``-sized
    embedding, and the per-feature embeddings are concatenated."""

    def __init__(
        self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs
    ):
        assert embedding_dim % num_features == 0
        super().__init__(num_embeddings, embedding_dim // num_features, *args, **kwargs)

    def forward(self, x):
        indices = x.long()
        assert (indices == x).all(), "CanEmb only works with tensors of whole numbers"
        emb = super().forward(indices)
        # Fold the last two dims (features, per-feature embedding) into one.
        return emb.view(*emb.shape[:-2], -1)
181
-
182
-
183
def get_Canonical(num_classes):
    # Factory for a CanEmb encoder with a fixed number of classes, matching
    # the (num_features, emsize) encoder-creator signature.
    return lambda num_features, emsize: CanEmb(num_features, num_classes, emsize)


def get_Embedding(num_embs_per_feature=100):
    # Factory for an EmbeddingEncoder with a fixed per-feature bucket count.
    return lambda num_features, emsize: EmbeddingEncoder(
        num_features, emsize, num_embs=num_embs_per_feature
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/initializers.py DELETED
@@ -1,11 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
-
5
def get_NormalInitializer(std):
    """Return an ``nn.Module.apply``-compatible initializer that fills the
    weight and bias of every ``nn.Linear`` with N(0, std) samples."""

    def initializer(module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, 0, std)
            nn.init.normal_(module.bias, 0, std)

    return initializer
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/layer.py DELETED
@@ -1,179 +0,0 @@
1
- from functools import partial
2
- from typing import Optional
3
- from torch import Tensor
4
- from torch import nn
5
- from torch.nn.modules.transformer import *
6
- from torch.nn.modules.transformer import _get_activation_fn
7
-
8
- from torch.utils.checkpoint import checkpoint
9
-
10
-
11
class TransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.
        pre_norm: if ``True``, apply layer norm before attention/FFN
            (pre-LN transformer) instead of after each residual branch.
        recompute_attn: if ``True``, wrap attention in
            ``torch.utils.checkpoint.checkpoint`` to trade compute for memory.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """

    __constants__ = ["batch_first"]

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        layer_norm_eps=1e-5,
        batch_first=False,
        pre_norm=False,
        device=None,
        dtype=None,
        recompute_attn=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs
        )
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.pre_norm = pre_norm
        self.recompute_attn = recompute_attn

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Backward compatibility with checkpoints pickled before the
        # ``activation`` attribute existed.
        if "activation" not in state:
            state["activation"] = F.relu
        super().__setstate__(state)

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).  May also be
                a 3-tuple ``(global_src_mask, trainset_src_mask,
                valset_src_mask)``, which activates the global-attention
                setup below.
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        if self.pre_norm:
            src_ = self.norm1(src)
        else:
            src_ = src
        if isinstance(src_mask, tuple):
            # global attention setup: the sequence is laid out as
            # [global tokens | train tokens | eval tokens], each group
            # attending through its own mask.
            assert not self.self_attn.batch_first
            assert src_key_padding_mask is None

            global_src_mask, trainset_src_mask, valset_src_mask = src_mask

            num_global_tokens = global_src_mask.shape[0]
            num_train_tokens = trainset_src_mask.shape[0]

            global_tokens_src = src_[:num_global_tokens]
            train_tokens_src = src_[
                num_global_tokens : num_global_tokens + num_train_tokens
            ]
            global_and_train_tokens_src = src_[: num_global_tokens + num_train_tokens]
            eval_tokens_src = src_[num_global_tokens + num_train_tokens :]

            attn = (
                partial(checkpoint, self.self_attn)
                if self.recompute_attn
                else self.self_attn
            )

            # Global tokens attend over the global + train tokens.
            global_tokens_src2 = attn(
                global_tokens_src,
                global_and_train_tokens_src,
                global_and_train_tokens_src,
                None,
                True,
                global_src_mask,
            )[0]
            # Train tokens attend over the global tokens only.
            train_tokens_src2 = attn(
                train_tokens_src,
                global_tokens_src,
                global_tokens_src,
                None,
                True,
                trainset_src_mask,
            )[0]
            # Eval tokens attend over the full sequence.
            eval_tokens_src2 = attn(
                eval_tokens_src, src_, src_, None, True, valset_src_mask
            )[0]

            src2 = torch.cat(
                [global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0
            )

        else:
            if self.recompute_attn:
                # Positional args mirror MultiheadAttention.forward:
                # (query, key, value, key_padding_mask, need_weights, attn_mask).
                src2 = checkpoint(
                    self.self_attn,
                    src_,
                    src_,
                    src_,
                    src_key_padding_mask,
                    True,
                    src_mask,
                )[0]
            else:
                src2 = self.self_attn(
                    src_,
                    src_,
                    src_,
                    attn_mask=src_mask,
                    key_padding_mask=src_key_padding_mask,
                )[0]
        # Residual connection around attention (post-LN applies norm1 after).
        src = src + self.dropout1(src2)
        if not self.pre_norm:
            src = self.norm1(src)

        if self.pre_norm:
            src_ = self.norm2(src)
        else:
            src_ = src
        # Position-wise feedforward with residual connection.
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src_))))
        src = src + self.dropout2(src2)

        if not self.pre_norm:
            src = self.norm2(src)
        return src
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/model.py DELETED
@@ -1,56 +0,0 @@
1
- import torch
2
- import lcpfn
3
- import warnings
4
- from lcpfn import utils
5
-
6
-
7
class LCPFN(torch.nn.Module):
    """Inference wrapper around a pretrained learning-curve PFN checkpoint."""

    def __init__(self, model_name="EMSIZE512_NLAYERS12_NBUCKETS1000"):
        """Load a trained model; *model_name* is either a key of
        ``lcpfn.model_dict`` or a direct checkpoint path."""
        super(LCPFN, self).__init__()
        self.model = torch.load(
            getattr(lcpfn, model_name) if model_name in lcpfn.model_dict else model_name
        )
        self.model.eval()

    def check_input(self, x_train, x_test, y_train, y_test=None):
        """Validate inputs: x values must be non-negative and y values must
        lie in [0, 1].

        Raises:
            Exception: on any violation.
        """
        if torch.any(x_train < 0) or torch.any(x_test < 0):
            # raise warning if input has negative values
            raise Exception("x values should be non-negative")
        # BUG FIX: the original tested ``torch.any(0 < y_test < 1)``.
        # Chained comparisons call bool() on a tensor (RuntimeError for more
        # than one element), and the condition described the VALID range
        # rather than the invalid one.  Mirror the y_train check instead.
        if torch.any((0 > y_train) | (y_train > 1)) or (
            y_test is not None and torch.any((0 > y_test) | (y_test > 1))
        ):
            # raise warning if input has values outside [0,1]
            raise Exception(
                "y values should be in the range [0,1]. Please set normalizer_kwargs accordingly."
            )

    @torch.no_grad()
    def predict_mean(
        self, x_train, y_train, x_test, normalizer=utils.identity_normalizer()
    ):
        """Posterior-mean prediction at *x_test*; y is mapped through
        ``normalizer[0]`` before the model and back through ``normalizer[1]``."""
        y_train_norm = normalizer[0](y_train)
        logits = self(x_train=x_train, y_train=y_train_norm, x_test=x_test)
        return normalizer[1](self.model.criterion.mean(logits))

    @torch.no_grad()
    def predict_quantiles(
        self, x_train, y_train, x_test, qs, normalizer=utils.identity_normalizer()
    ):
        """Predict the posterior quantiles *qs* at *x_test*, concatenated
        along dim 1."""
        y_train_norm = normalizer[0](y_train)
        logits = self(x_train=x_train, y_train=y_train_norm, x_test=x_test)
        return normalizer[1](
            torch.cat([self.model.criterion.icdf(logits, q) for q in qs], dim=1)
        )

    @torch.no_grad()
    def nll_loss(self, x_train, y_train, x_test, y_test):
        """Negative log-likelihood of *y_test* under the model's posterior."""
        # TODO add normalizer_kwargs
        logits = self(x_train=x_train, y_train=y_train, x_test=x_test)
        return self.model.criterion(logits, y_test)

    def forward(self, x_train, y_train, x_test):
        """Run the underlying PFN on the concatenated train/test epochs;
        returns per-position logits."""
        self.check_input(x_train, x_test, y_train)
        single_eval_pos = x_train.shape[0]
        x = torch.cat([x_train, x_test], dim=0).unsqueeze(1)
        y = y_train.unsqueeze(1)
        return self.model((x, y), single_eval_pos=single_eval_pos)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/positional_encodings.py DELETED
@@ -1,78 +0,0 @@
1
- import math
2
-
3
- import torch
4
- from torch import nn
5
-
6
-
7
- # Protocol for positonal encodings.
8
- # __init__(d_model, max_len=..[, more optionals])
9
- # forward(x: (seq_len, bs, d_model)) -> Tensor of shape (*x.shape[:2],d_model) containing pos. embeddings
10
-
11
-
12
- class NoPositionalEncoding(nn.Module):
13
- def __init__(self, d_model, max_len=None):
14
- super(NoPositionalEncoding, self).__init__()
15
- pass
16
-
17
- def forward(self, x):
18
- return x # * math.sqrt(x.shape[-1])
19
-
20
-
21
- class PositionalEncoding(nn.Module):
22
- def __init__(self, d_model, max_len=5000):
23
- super(PositionalEncoding, self).__init__()
24
- pe = torch.zeros(max_len, d_model)
25
- position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
26
- div_term = torch.exp(
27
- torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
28
- )
29
- pe[:, 0::2] = torch.sin(position * div_term)
30
- pe[:, 1::2] = torch.cos(position * div_term)
31
- pe = pe.unsqueeze(0).transpose(0, 1)
32
- self.register_buffer("pe", pe)
33
-
34
- def forward(self, x):
35
- x = self.pe[: x.size(0), :] + x # * math.sqrt(x.shape[-1])
36
- return x
37
-
38
-
39
- class LearnedPositionalEncoding(nn.Module):
40
- def __init__(self, d_model, max_len=5000):
41
- super(LearnedPositionalEncoding, self).__init__()
42
- self.max_seq_len = max_len
43
- # self.positional_embeddings = nn.Embedding(max_len, d_model)
44
- self.positional_embeddings = nn.Parameter(torch.empty(max_len, d_model))
45
- nn.init.normal_(self.positional_embeddings, mean=0, std=d_model**-0.5)
46
-
47
- def forward(self, x):
48
- seq_len, bs, d_model = x.shape
49
- assert seq_len <= len(
50
- self.positional_embeddings
51
- ), "seq_len can be at most max_len."
52
- pos_emb = self.positional_embeddings[:seq_len]
53
- return (
54
- pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x
55
- ) # * math.sqrt(x.shape[-1])
56
-
57
-
58
- class PairedScrambledPositionalEncodings(LearnedPositionalEncoding):
59
- # TODO check whether it is a problem to use the same perm. for full batch
60
- def forward(self, x):
61
- seq_len, bs, d_model = x.shape
62
- assert seq_len <= len(
63
- self.positional_embeddings
64
- ), "seq_len can be at most max_len."
65
- assert (
66
- len(self.positional_embeddings) % 2 == 0
67
- ), "Please specify an even max_len."
68
-
69
- paired_embs = self.positional_embeddings.view(
70
- len(self.positional_embeddings), -1, 2
71
- )
72
- pos_emb = paired_embs[torch.randperm(len(paired_embs))].view(
73
- *self.positional_embeddings.shape
74
- )[:seq_len]
75
-
76
- return (
77
- pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x
78
- ) # * math.sqrt(x.shape[-1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/priors/__init__.py DELETED
@@ -1 +0,0 @@
1
- from . import gp, ridge
 
 
lcpfn/priors/binarized_regression.py DELETED
@@ -1,19 +0,0 @@
1
- from . import fast_gp, fast_gp_mix
2
- from .utils import get_batch_to_dataloader
3
-
4
- def regression_prior_to_binary(get_batch_function):
5
-
6
- def binarized_get_batch_function(*args, assert_on=False, **kwargs):
7
- x, y, target_y = get_batch_function(*args, **kwargs)
8
- if assert_on:
9
- assert y is target_y, "y == target_y is assumed by this function"
10
- y = y.sigmoid().bernoulli()
11
- return x, y, y
12
-
13
- return binarized_get_batch_function
14
-
15
-
16
- Binarized_fast_gp_dataloader = get_batch_to_dataloader(regression_prior_to_binary(fast_gp.get_batch))
17
-
18
-
19
- Binarized_fast_gp_mix_dataloader = get_batch_to_dataloader(regression_prior_to_binary(fast_gp_mix.get_batch))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/priors/fast_gp.py DELETED
@@ -1,143 +0,0 @@
1
- import time
2
-
3
- import torch
4
- from torch import nn
5
- import gpytorch
6
-
7
- from .utils import get_batch_to_dataloader
8
- from utils import default_device
9
-
10
-
11
- # We will use the simplest form of GP model, exact inference
12
- class ExactGPModel(gpytorch.models.ExactGP):
13
- def __init__(self, train_x, train_y, likelihood):
14
- super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
15
- self.mean_module = gpytorch.means.ConstantMean()
16
- self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
17
-
18
- def forward(self, x):
19
- mean_x = self.mean_module(x)
20
- covar_x = self.covar_module(x)
21
- return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
22
-
23
-
24
- def get_model(x, y, hyperparameters):
25
- likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.GreaterThan(1.e-9))
26
- model = ExactGPModel(x, y, likelihood)
27
- model.likelihood.noise = torch.ones_like(model.likelihood.noise) * hyperparameters["noise"]
28
- model.covar_module.outputscale = torch.ones_like(model.covar_module.outputscale) * hyperparameters["outputscale"]
29
- model.covar_module.base_kernel.lengthscale = torch.ones_like(model.covar_module.base_kernel.lengthscale) * \
30
- hyperparameters["lengthscale"]
31
- return model, likelihood
32
-
33
-
34
- @torch.no_grad()
35
- def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
36
- equidistant_x=False, fix_x=None, **kwargs):
37
- if isinstance(hyperparameters, (tuple, list)):
38
- hyperparameters = {"noise": hyperparameters[0]
39
- , "outputscale": hyperparameters[1]
40
- , "lengthscale": hyperparameters[2]
41
- , "is_binary_classification": hyperparameters[3]
42
- # , "num_features_used": hyperparameters[4]
43
- , "normalize_by_used_features": hyperparameters[5]
44
- , "order_y": hyperparameters[6]
45
- , "sampling": hyperparameters[7]
46
- }
47
- elif hyperparameters is None:
48
- hyperparameters = {"noise": .1, "outputscale": .1, "lengthscale": .1}
49
-
50
- if 'verbose' in hyperparameters and hyperparameters['verbose']:
51
- print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
52
- , "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size, 'sampling': hyperparameters['sampling']})
53
-
54
- # hyperparameters = {k: hyperparameters[k]() if callable(hyperparameters[k]) else hyperparameters[k] for k in
55
- # hyperparameters.keys()}
56
- assert not (equidistant_x and (fix_x is not None))
57
-
58
- with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
59
- if equidistant_x:
60
- assert num_features == 1
61
- x = torch.linspace(0, 1., seq_len).unsqueeze(0).repeat(batch_size, 1).unsqueeze(-1)
62
- elif fix_x is not None:
63
- assert fix_x.shape == (seq_len, num_features)
64
- x = fix_x.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
65
- else:
66
- if hyperparameters.get('sampling','uniform') == 'uniform':
67
- x = torch.rand(batch_size, seq_len, num_features, device=device)
68
- else:
69
- x = torch.randn(batch_size, seq_len, num_features, device=device)
70
- model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
71
- model.to(device)
72
- # trained_model = ExactGPModel(train_x, train_y, likelihood).cuda()
73
- # trained_model.eval()
74
- successful_sample = False
75
- while not successful_sample:
76
- try:
77
- with gpytorch.settings.prior_mode(True):
78
- model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
79
- model.to(device)
80
-
81
- d = model(x)
82
- sample_wo_noise = d.sample().transpose(0, 1) # this will be the target for the loss
83
- sample = likelihood(sample_wo_noise).sample() # this will be the input to the Transformer
84
- successful_sample = True
85
- except RuntimeError: # This can happen when torch.linalg.eigh fails. Restart with new init resolves this.
86
- print('GP Sampling unsuccessful, retrying.. ')
87
- print(x)
88
- print(hyperparameters)
89
-
90
- if bool(torch.any(torch.isnan(x)).detach().cpu().numpy()):
91
- print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
92
- , "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size})
93
-
94
- # TODO: Multi output
95
- return x.transpose(0, 1), sample, sample if hyperparameters.get("observation_noise", True) else sample_wo_noise
96
-
97
- DataLoader = get_batch_to_dataloader(get_batch)
98
-
99
- def get_model_on_device(x,y,hyperparameters,device):
100
- model, likelihood = get_model(x, y, hyperparameters)
101
- model.to(device)
102
- return model, likelihood
103
-
104
-
105
- @torch.no_grad()
106
- def evaluate(x, y, y_non_noisy, use_mse=False, hyperparameters={}, get_model_on_device=get_model_on_device, device=default_device, step_size=1, start_pos=0):
107
- start_time = time.time()
108
- losses_after_t = [.0] if start_pos == 0 else []
109
- all_losses_after_t = []
110
-
111
- with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))), gpytorch.settings.fast_pred_var(False):
112
- for t in range(max(start_pos, 1), len(x), step_size):
113
- loss_sum = 0.
114
- model, likelihood = get_model_on_device(x[:t].transpose(0, 1), y[:t].transpose(0, 1), hyperparameters, device)
115
-
116
-
117
- model.eval()
118
- # print([t.shape for t in model.train_inputs])
119
- # print(x[:t].transpose(0,1).shape, x[t].unsqueeze(1).shape, y[:t].transpose(0,1).shape)
120
- f = model(x[t].unsqueeze(1))
121
- l = likelihood(f)
122
- means = l.mean.squeeze()
123
- varis = l.covariance_matrix.squeeze()
124
- # print(l.variance.squeeze(), l.mean.squeeze(), y[t])
125
-
126
- assert len(means.shape) == len(varis.shape) == 1
127
- assert len(means) == len(varis) == x.shape[1]
128
-
129
- if use_mse:
130
- c = nn.MSELoss(reduction='none')
131
- ls = c(means, y[t])
132
- else:
133
- ls = -l.log_prob(y[t].unsqueeze(1))
134
-
135
- losses_after_t.append(ls.mean())
136
- all_losses_after_t.append(ls.flatten())
137
- return torch.stack(all_losses_after_t).to('cpu'), torch.tensor(losses_after_t).to('cpu'), time.time() - start_time
138
-
139
- if __name__ == '__main__':
140
- hps = (.1,.1,.1)
141
- for redo_idx in range(1):
142
- print(
143
- evaluate(*get_batch(1000, 10, hyperparameters=hps, num_features=10), use_mse=False, hyperparameters=hps))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/priors/fast_gp_mix.py DELETED
@@ -1,394 +0,0 @@
1
- import time
2
- import functools
3
- import random
4
- import math
5
- import traceback
6
-
7
- import numpy as np
8
- import torch
9
- from torch import nn
10
- import gpytorch
11
- from botorch.models import SingleTaskGP
12
- from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
13
- from botorch.fit import fit_gpytorch_model
14
- from gpytorch.mlls import ExactMarginalLogLikelihood
15
- from gpytorch.likelihoods import GaussianLikelihood
16
- from gpytorch.priors.torch_priors import GammaPrior, UniformPrior
17
- from gpytorch.constraints import GreaterThan
18
-
19
-
20
- from bar_distribution import BarDistribution
21
- from utils import default_device
22
- from .utils import get_batch_to_dataloader
23
- from . import fast_gp
24
-
25
- def get_model(x, y, hyperparameters: dict, sample=True):
26
- if hyperparameters.get('handmade', False):
27
- # We will use the simplest form of GP model, exact inference
28
- class ExactGPModel(gpytorch.models.ExactGP):
29
- def __init__(self, train_x, train_y, likelihood):
30
- super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
31
- self.mean_module = gpytorch.means.ConstantMean()
32
- self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel())
33
- self.mean_module.register_prior("mean_prior", UniformPrior(-1, 1), "constant")
34
- self.covar_module.base_kernel.register_prior("lengthscale_prior", UniformPrior(0.01, 0.5),
35
- "lengthscale")
36
- # model.covar_module.base_kernel.register_prior("period_length_prior", UniformPrior(0.05, 2.5), "period_length")
37
- self.covar_module.register_prior("outputscale_prior", UniformPrior(1, 2), "outputscale")
38
- likelihood.register_prior("noise_prior", UniformPrior(0.001, 0.01), "noise")
39
- self.to(x)
40
-
41
- def forward(self, x):
42
- mean_x = self.mean_module(x)
43
- covar_x = self.covar_module(x)
44
- return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
45
-
46
- likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())
47
- model = ExactGPModel(x, y, likelihood)
48
-
49
-
50
-
51
- else:
52
- aug_batch_shape = SingleTaskGP(x,y.unsqueeze(-1))._aug_batch_shape
53
- noise_prior = GammaPrior(hyperparameters.get('noise_concentration',1.1), hyperparameters.get('noise_rate',0.05))
54
- noise_prior_mode = (noise_prior.concentration - 1) / noise_prior.rate
55
- likelihood = GaussianLikelihood(
56
- noise_prior=noise_prior,
57
- batch_shape=aug_batch_shape,
58
- noise_constraint=GreaterThan(
59
- MIN_INFERRED_NOISE_LEVEL,
60
- transform=None,
61
- initial_value=noise_prior_mode,
62
- ),
63
- )
64
- model = SingleTaskGP(x, y.unsqueeze(-1),
65
- covar_module=gpytorch.kernels.ScaleKernel(
66
- gpytorch.kernels.MaternKernel(
67
- nu=hyperparameters.get('nu',2.5),
68
- ard_num_dims=x.shape[-1],
69
- batch_shape=aug_batch_shape,
70
- lengthscale_prior=gpytorch.priors.GammaPrior(hyperparameters.get('lengthscale_concentration',3.0), hyperparameters.get('lengthscale_rate',6.0)),
71
- ),
72
- batch_shape=aug_batch_shape,
73
- outputscale_prior=gpytorch.priors.GammaPrior(hyperparameters.get('outputscale_concentration',.5), hyperparameters.get('outputscale_rate',0.15)),
74
- ), likelihood=likelihood)
75
-
76
- likelihood = model.likelihood
77
- model.to(x.device)
78
- if sample:
79
- sampled_model = model.pyro_sample_from_prior()
80
- return sampled_model, sampled_model.likelihood
81
- else:
82
- assert not(hyperparameters.get('sigmoid', False)) and not(hyperparameters.get('y_minmax_norm', False)), "Sigmoid and y_minmax_norm can only be used to sample models..."
83
- return model, likelihood
84
-
85
-
86
- @torch.no_grad()
87
- def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
88
- batch_size_per_gp_sample=None,
89
- fix_to_range=None, equidistant_x=False, **kwargs):
90
- '''
91
- This function is very similar to the equivalent in .fast_gp. The only difference is that this function operates over
92
- a mixture of GP priors.
93
- :param batch_size:
94
- :param seq_len:
95
- :param num_features:
96
- :param device:
97
- :param hyperparameters:
98
- :param for_regression:
99
- :return:
100
- '''
101
- hyperparameters = hyperparameters or {}
102
- with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))):
103
- batch_size_per_gp_sample = (batch_size_per_gp_sample or max(batch_size // 10,1))
104
- assert batch_size % batch_size_per_gp_sample == 0
105
-
106
- total_num_candidates = batch_size*(2**(fix_to_range is not None))
107
- num_candidates = batch_size_per_gp_sample * (2**(fix_to_range is not None))
108
- if equidistant_x:
109
- assert num_features == 1
110
- x = torch.linspace(0,1.,seq_len).unsqueeze(0).repeat(total_num_candidates,1).unsqueeze(-1)
111
- else:
112
- x = torch.rand(total_num_candidates, seq_len, num_features, device=device)
113
- samples = []
114
- samples_wo_noise = []
115
- for i in range(0,total_num_candidates,num_candidates):
116
- model, likelihood = get_model(x[i:i+num_candidates], torch.zeros(num_candidates,x.shape[1]).to(device), hyperparameters)
117
- model.to(device)
118
- likelihood.to(device)
119
- if hyperparameters.get('handmade', False):
120
- model.covar_module.base_kernel.lengthscale = model.covar_module.base_kernel.lengthscale.to(device)
121
- model.covar_module.outputscale = model.covar_module.outputscale.to(device)
122
- likelihood.noise = likelihood.noise.to(device)
123
- model.mean_module.constant = model.mean_module.constant.to(device)
124
-
125
- # trained_model = ExactGPModel(train_x, train_y, likelihood).cuda()
126
- # trained_model.eval()
127
- successful_sample = 0
128
- throwaway_share = 0.
129
- sampling_with_observation_noise = hyperparameters.get("observation_noise", True)
130
- while successful_sample < 1:
131
- with gpytorch.settings.prior_mode(True):
132
- #print(x.device, device, f'{model.covar_module.base_kernel.lengthscale=}, {model.covar_module.base_kernel.lengthscale.device=}')
133
-
134
-
135
- if sampling_with_observation_noise :
136
- d = model(x[i:i+num_candidates])
137
- d = likelihood(d)
138
- sample = d.sample() # bs_per_gp_s x T
139
-
140
- else:
141
- d = model(x[i:i+num_candidates])
142
- sample_wo_noise = d.sample()
143
- sample = likelihood(sample_wo_noise).sample()
144
-
145
- if hyperparameters.get('y_minmax_norm'):
146
- sample = ((sample - sample.min(1)[0]) / (sample.max(1)[0] - sample.min(1)[0]))
147
- if hyperparameters.get('sigmoid'):
148
- sample = sample.sigmoid()
149
-
150
- if not sampling_with_observation_noise:
151
- if hyperparameters.get('y_minmax_norm'):
152
- sample_wo_noise = ((sample_wo_noise - sample_wo_noise.min(1)[0]) / (sample_wo_noise.max(1)[0] - sample_wo_noise.min(1)[0]))
153
- if hyperparameters.get('sigmoid'):
154
- sample_wo_noise = sample_wo_noise.sigmoid()
155
-
156
- if fix_to_range is None:
157
- samples.append(sample.transpose(0, 1))
158
- if not sampling_with_observation_noise: samples_wo_noise.append(sample_wo_noise.transpose(0,1))
159
- successful_sample = True
160
- continue
161
-
162
- smaller_mask = sample < fix_to_range[0]
163
- larger_mask = sample >= fix_to_range[1]
164
- in_range_mask = ~ (smaller_mask | larger_mask).any(1)
165
- throwaway_share += (~in_range_mask[:batch_size_per_gp_sample]).sum()/batch_size_per_gp_sample
166
- if in_range_mask.sum() < batch_size_per_gp_sample:
167
- successful_sample -= 1
168
- if successful_sample < 100:
169
- print("Please change hyper-parameters (e.g. decrease outputscale_mean) it"
170
- "seems like the range is set to tight for your hyper-parameters.")
171
- continue
172
-
173
- x[i:i+batch_size_per_gp_sample] = x[i:i+num_candidates][in_range_mask][:batch_size_per_gp_sample]
174
- sample = sample[in_range_mask][:batch_size_per_gp_sample]
175
- samples.append(sample.transpose(0,1))
176
- if not sampling_with_observation_noise: samples_wo_noise.append(sample_wo_noise.transpose(0,1))
177
- successful_sample = True
178
-
179
- if random.random() < .01:
180
- print('throwaway share', throwaway_share/(batch_size//batch_size_per_gp_sample))
181
-
182
- #print(f'took {time.time() - start}')
183
-
184
- x = x.view(-1,batch_size,seq_len,num_features)[0]
185
- # TODO think about enabling the line below
186
- #sample = sample - sample[0, :].unsqueeze(0).expand(*sample.shape)
187
- x = x.transpose(0,1)
188
- sample = torch.cat(samples, 1)
189
-
190
- if sampling_with_observation_noise:
191
- target_sample = sample
192
- else:
193
- target_sample = torch.cat(samples_wo_noise, 1)
194
-
195
- assert x.shape[:2] == sample.shape[:2]
196
-
197
- return x, sample, target_sample # x.shape = (T,B,H)
198
-
199
-
200
- class DataLoader(get_batch_to_dataloader(get_batch)):
201
- @torch.no_grad()
202
- def validate(self, model, step_size=1, start_pos=0):
203
- if isinstance(model.criterion, BarDistribution):
204
- (_, x,y), target_y, eval_pos = self.gbm(**self.get_batch_kwargs)
205
- model.eval()
206
- losses = []
207
- for eval_pos in range(start_pos, len(x), step_size):
208
- logits = model((x,y), single_eval_pos=eval_pos)
209
- means = model.criterion.mean(logits) # num_evals x batch_size
210
- mse = nn.MSELoss()
211
- losses.append(mse(means[0], target_y[eval_pos]))
212
- model.train()
213
- return torch.stack(losses)
214
- else:
215
- return 123.
216
-
217
-
218
- @torch.enable_grad()
219
- def get_fitted_model(x, y, hyperparameters, device):
220
- # fit the gaussian process
221
- model, likelihood = get_model(x,y,hyperparameters,sample=False)
222
- #print(model.covar_module.base_kernel.lengthscale)
223
- model.to(device)
224
- mll = ExactMarginalLogLikelihood(likelihood, model)
225
- model.train()
226
- fit_gpytorch_model(mll)
227
- #print(model.covar_module.base_kernel.lengthscale)
228
- return model, likelihood
229
-
230
-
231
- evaluate = functools.partial(fast_gp.evaluate, get_model_on_device=get_fitted_model)
232
-
233
- def get_mcmc_model(x, y, hyperparameters, device, num_samples, warmup_steps, obs=True):
234
- from pyro.infer.mcmc import NUTS, MCMC, HMC
235
- import pyro
236
- x = x.to(device)
237
- y = y.to(device)
238
- model, likelihood = get_model(x, y, hyperparameters, sample=False)
239
- model.to(device)
240
-
241
-
242
- def pyro_model(x, y):
243
- sampled_model = model.pyro_sample_from_prior()
244
- output = sampled_model.likelihood(sampled_model(x))
245
- if obs:
246
- return pyro.sample("obs", output, obs=y)
247
-
248
- nuts_kernel = NUTS(pyro_model)
249
- mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, num_chains=1)
250
- #print(x.shape)
251
- mcmc_run.run(x, y)
252
- #print(mcmc_run.get_samples())
253
- model.pyro_load_from_samples(mcmc_run.get_samples()) # pyro.infer wie noah?
254
- model.eval()
255
- #print(mcmc_run.diagnostics())
256
- # test_x = torch.linspace(0, 1, 101).unsqueeze(-1)
257
- # test_y = torch.sin(test_x * (2 * math.pi))
258
- # expanded_test_x = test_x.unsqueeze(0).repeat(num_samples, 1, 1)
259
- # output = model(expanded_test_x)
260
- #print(x.shape)
261
- return model, likelihood
262
- # output = model(x[-1].unsqueeze(1).repeat(1, num_samples 1))
263
- # return output.mean
264
-
265
-
266
-
267
-
268
- def get_mean_logdensity(dists, x: torch.Tensor, full_range=None):
269
- means = torch.cat([d.mean.squeeze() for d in dists], 0)
270
- vars = torch.cat([d.variance.squeeze() for d in dists], 0)
271
- assert len(means.shape) == 1 and len(vars.shape) == 1
272
- dist = torch.distributions.Normal(means, vars.sqrt())
273
- #logprobs = torch.cat([d.log_prob(x) for d in dists], 0)
274
- logprobs = dist.log_prob(x)
275
- if full_range is not None:
276
- used_weight = 1. - (dist.cdf(torch.tensor(full_range[0])) + (1.-dist.cdf(torch.tensor(full_range[1]))))
277
- if torch.isinf(-torch.log(used_weight)).any() or torch.isinf(torch.log(used_weight)).any():
278
- print('factor is inf', -torch.log(used_weight))
279
- logprobs -= torch.log(used_weight)
280
- assert len(logprobs.shape) == 1
281
- #print(logprobs)
282
- return torch.logsumexp(logprobs, 0) - math.log(len(logprobs))
283
-
284
-
285
- def evaluate_(x, y, y_non_noisy, hyperparameters=None, device=default_device, num_samples=100, warmup_steps=300,
286
- full_range=None, min_seq_len=0, use_likelihood=False, obs=True):
287
- with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))), gpytorch.settings.fast_pred_var(False):
288
- x = x.to(device).double()
289
- y = y.to(device).double()
290
- start_time = time.time()
291
- losses_after_t = [.0] if min_seq_len == 0 else []
292
- all_losses = []
293
-
294
- for t in range(max(min_seq_len,1), len(x)):
295
- #print('Timestep', t)
296
- loss_sum = 0.
297
- step_losses = []
298
- start_step = time.time()
299
- print(x.shape, y.shape)
300
- for b_i in range(x.shape[1]):
301
- x_train = x[:t,b_i]
302
- y_train = y[:t,b_i]
303
- from pyro.infer.mcmc import NUTS, MCMC, HMC
304
- import pyro
305
- x_train = x_train.to(device)
306
- y_train = y_train.to(device)
307
- print(x_train.shape, y_train.shape)
308
- model, likelihood = get_model(x_train, y_train, hyperparameters, sample=False)
309
- model.to(device)
310
-
311
- def pyro_model(x, y):
312
- sampled_model = model.pyro_sample_from_prior()
313
- output = sampled_model.likelihood(sampled_model(x))
314
- if obs:
315
- return pyro.sample("obs", output, obs=y)
316
-
317
- nuts_kernel = NUTS(pyro_model)
318
- mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, num_chains=1, disable_progbar=True)
319
- # print(x.shape)
320
- mcmc_run.run(x_train, y_train)
321
- # print(mcmc_run.get_samples())
322
- model.pyro_load_from_samples(mcmc_run.get_samples())
323
- model.eval()
324
-
325
- with torch.no_grad():
326
- dists = model(x[t, b_i, :].unsqueeze(
327
- 0).repeat(num_samples, 1, 1))
328
- if use_likelihood:
329
- dists = likelihood(dists)
330
- l = -get_mean_logdensity([dists], y[t, b_i].repeat(num_samples), full_range)
331
- print(l)
332
-
333
- step_losses.append(l.item())
334
- #print('loss',l.item())
335
- print(f'current average loss at step {t} is {sum(step_losses)/len(step_losses)} with {(time.time()-start_step)/len(step_losses)} s per eval.')
336
- loss_sum += l
337
-
338
- loss_sum /= x.shape[1]
339
- all_losses.append(step_losses)
340
- print(f'loss after step {t} is {loss_sum}')
341
- losses_after_t.append(loss_sum)
342
- print(f'losses so far {torch.tensor(losses_after_t)}')
343
- return torch.tensor(losses_after_t), time.time() - start_time, all_losses
344
-
345
-
346
-
347
-
348
-
349
- if __name__ == '__main__':
350
- import argparse
351
-
352
- parser = argparse.ArgumentParser()
353
- parser.add_argument('--batch_size', type=int)
354
- parser.add_argument('--seq_len', type=int)
355
- parser.add_argument('--min_seq_len', type=int, default=0)
356
- parser.add_argument('--warmup_steps', type=int)
357
- parser.add_argument('--num_samples', type=int)
358
- parser.add_argument('--min_y', type=int)
359
- parser.add_argument('--max_y', type=int)
360
- parser.add_argument('--dim', type=int, default=1)
361
- parser.add_argument('--use_likelihood', action='store_true')
362
- parser.add_argument('--device', default='cpu')
363
- parser.add_argument('--outputscale_concentraion', default=2., type=float)
364
- parser.add_argument('--noise_concentration', default=1.1, type=float)
365
- parser.add_argument('--noise_rate', default=.05, type=float)
366
- parser.add_argument('--handmade', action='store_true')
367
- parser.add_argument('--no_obs', action='store_true')
368
- parser.add_argument('--seed', type=int, default=0)
369
-
370
- args = parser.parse_args()
371
- import pyro
372
- import gpytorch
373
- print(pyro.__version__)
374
- print(gpytorch.__version__)
375
-
376
-
377
- print('min_y:', args.min_y)
378
- full_range = (None if args.min_y is None else (args.min_y,args.max_y))
379
-
380
- hps = {'handmade': args.handmade, 'outputscale_concentration': args.outputscale_concentraion, 'noise_concentration': args.noise_concentration,
381
- 'noise_rate': args.noise_rate, 'fast_computations': (False,False,False)}
382
- if args.seed:
383
- torch.manual_seed(args.seed)
384
- np.random.seed(args.seed)
385
- random.seed(args.seed)
386
- x, y, _ = get_batch(args.batch_size, args.seq_len, args.dim, fix_to_range=full_range, hyperparameters=hps)
387
- #assert args.seq_len == 7 and args.min_seq_len == 6
388
- #x = torch.cat([torch.linspace(0, 1, 6), torch.tensor([.33])]).unsqueeze(1).repeat(1,args.batch_size).unsqueeze(-1)
389
- #y = torch.sin(x * (2 * math.pi)).squeeze(-1)
390
- print('RESULT:', evaluate_(x, y, y, device=args.device, warmup_steps=args.warmup_steps,
391
- num_samples=args.num_samples, full_range=full_range, min_seq_len=args.min_seq_len,
392
- hyperparameters=hps, use_likelihood=args.use_likelihood, obs=not args.no_obs))
393
-
394
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/priors/gp.py DELETED
@@ -1,69 +0,0 @@
1
- import time
2
- import random
3
-
4
- import numpy as np
5
- import torch
6
- from torch import nn
7
- from sklearn.gaussian_process import GaussianProcessRegressor
8
- from sklearn.gaussian_process.kernels import RBF, DotProduct, WhiteKernel
9
- from .utils import get_batch_to_dataloader
10
-
11
-
12
- length_scale_sampling_gp = .6
13
-
14
- def get_gp(length_scale=None):
15
- return GaussianProcessRegressor(
16
- kernel=RBF(length_scale=length_scale or length_scale_sampling_gp, length_scale_bounds='fixed'),
17
- random_state=0, optimizer=None)
18
-
19
-
20
- def get_batch(batch_size, seq_len, num_features, noisy_std=None):
21
- # m = torch.normal(0.,.1,size=(batch_size,num_features))
22
- # m2 = torch.rand(batch_size,num_features)
23
- # b = 0 # torch.rand(batch_size)
24
- x_t = torch.rand(batch_size, seq_len, num_features)
25
- # gp_b = TensorGP(kernel=TensorRBF(noisy_std))
26
- # y_t = gp_b.sample_from_GP_prior(x_t).detach()
27
-
28
- gpr = get_gp(noisy_std)
29
- y_t = torch.zeros(batch_size, seq_len)
30
-
31
- for i in range(len(y_t)):
32
- y_t[i] += gpr.sample_y(x_t[i], random_state=random.randint(0, 2 ** 32)).squeeze()
33
- x, y = x_t.transpose(0, 1), y_t.transpose(0, 1)
34
- # x, _ = torch.sort(x,dim=0)
35
- return x, y, y
36
-
37
-
38
- DataLoader = get_batch_to_dataloader(get_batch)
39
-
40
- def evaluate(x, y, y_non_noisy, use_mse=False, length_scale=length_scale_sampling_gp):
41
- start_time = time.time()
42
- losses_after_t = [.0]
43
- for t in range(1, len(x)):
44
- loss_sum = 0.
45
- for b_i in range(x.shape[1]):
46
- gpr = get_gp(length_scale).fit(x[:t, b_i], y[:t, b_i])
47
- means, stds = gpr.predict(x[t, b_i].unsqueeze(0), return_std=True)
48
- assert len(means) == 1 == len(stds)
49
- if use_mse:
50
- c = nn.MSELoss()
51
- l = c(torch.tensor(means), y[t, b_i].unsqueeze(-1))
52
- else:
53
- c = nn.GaussianNLLLoss(full=True)
54
- l = c(torch.tensor(means), y[t, b_i].unsqueeze(-1),
55
- var=torch.tensor(stds) ** 2)
56
- loss_sum += l
57
-
58
-
59
- losses_after_t.append(loss_sum / x.shape[1])
60
-
61
- return torch.tensor(losses_after_t), time.time()-start_time
62
-
63
- if __name__ == '__main__':
64
- ls = .1
65
- for alpha in set([ls, ls * 1.1, ls * .9]):
66
- print(alpha)
67
- for redo_idx in range(1):
68
- print(
69
- evaluate(*get_batch(1000, 10, noisy_std=ls, num_features=10), use_mse=False, length_scale=alpha))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/priors/prior.py DELETED
@@ -1,25 +0,0 @@
1
- from abc import ABCMeta, abstractmethod
2
- from torch.utils.data import DataLoader
3
-
4
-
5
- class PriorDataLoader(DataLoader, metaclass=ABCMeta):
6
- @abstractmethod
7
- def __init__(self, num_steps, batch_size, eval_pos_seq_len_sampler, seq_len_maximum, device, **kwargs):
8
- """
9
-
10
- :param num_steps: int, first argument, the number of steps to take per epoch, i.e. iteration of the DataLoader
11
- :param batch_size: int, number of datasets per batch
12
- :param eval_pos_seq_len_sampler: callable, it takes no arguments and returns a tuple (single eval pos, bptt)
13
- :param kwargs: for future compatibility it is good to have a final all catch, as new kwargs might be introduced
14
- """
15
- pass
16
-
17
- # A class or object variable `num_features`: int
18
- # Optional: `validate` function that accepts a transformer model
19
-
20
- # The DataLoader iter should return batches of the form ([style], x, y), target_y, single_eval_pos
21
- # We follow sequence len (s) first, batch size (b) second. So x: (s,b,num_features), y,target_y: (s,b)
22
- # and style: Optional[(b,num_style_params)], style can be omitted or set to None, if it is not intended to be used.
23
-
24
- # For more references, see `priors/utils.py` for a pretty general implementation of a DataLoader
25
- # and `train.py` for the only call of it.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/priors/pyro.py DELETED
@@ -1,41 +0,0 @@
1
- import random
2
-
3
- import torch
4
- from torch import nn
5
-
6
- from utils import default_device
7
- from .utils import get_batch_to_dataloader
8
-
9
-
10
- def get_batch(batch_size, seq_len, batch_size_per_gp_sample=None, **config):
11
- batch_size_per_gp_sample = batch_size_per_gp_sample or batch_size // 16
12
- assert batch_size % batch_size_per_gp_sample == 0, 'Please choose a batch_size divisible by batch_size_per_gp_sample.'
13
- num_models = batch_size // batch_size_per_gp_sample
14
- # standard kaiming uniform init currently...
15
-
16
- models = [config['model']() for _ in range(num_models)]
17
-
18
- sample = sum([[model(seq_len=seq_len) for _ in range(0,batch_size_per_gp_sample)] for model in models],[])
19
-
20
- def normalize_data(data):
21
- mean = data.mean(0)
22
- std = data.std(0) + .000001
23
- eval_xs = (data - mean) / std
24
-
25
- return eval_xs
26
-
27
- x, y = zip(*sample)
28
-
29
- y = torch.stack(y, 1).squeeze(-1).detach()
30
- x = torch.stack(x, 1).detach()
31
-
32
- if 'normalize_y' in config and config['normalize_y']:
33
- x, y = normalize_data(x), normalize_data(y)
34
- elif 'normalize_y' in config and config['normalize']:
35
- x, y = normalize_data(x), y
36
-
37
- return x, y, y
38
-
39
-
40
- DataLoader = get_batch_to_dataloader(get_batch)
41
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/priors/ridge.py DELETED
@@ -1,37 +0,0 @@
1
- import random
2
- import time
3
-
4
- import numpy as np
5
- import torch
6
- from torch import nn
7
- from sklearn.linear_model import Ridge
8
- from .utils import get_batch_to_dataloader
9
-
10
- def get_batch(batch_size, seq_len, num_features, noisy_std = .1):
11
- m = torch.normal(0., .1, size=(batch_size,num_features))
12
- b = 0 # torch.rand(batch_size)
13
- x = torch.rand(seq_len, batch_size,num_features)
14
- y_non_noisy = torch.einsum('bf,tbf->tb',m,x)
15
- y = y_non_noisy + torch.normal(torch.zeros_like(y_non_noisy),noisy_std) # noisy_std is alpha
16
- return x, y, y_non_noisy
17
-
18
- DataLoader = get_batch_to_dataloader(get_batch)
19
-
20
-
21
- def evaluate(x,y,y_non_noisy, alpha=0.):
22
- start_time = time.time()
23
- losses_after_t = [.0]
24
- for t in range(1,len(x)):
25
- loss_sum = 0.
26
- for b_i in range(x.shape[1]):
27
- clf = Ridge(alpha=alpha)
28
- clf.fit(x[:t,b_i],y[:t,b_i])
29
- y_ = clf.predict(x[t,b_i].unsqueeze(0))
30
- l = nn.MSELoss()(y_non_noisy[t,b_i].unsqueeze(0),torch.tensor(y_))
31
- loss_sum += l
32
- losses_after_t.append(loss_sum/x.shape[1])
33
- return torch.tensor(losses_after_t), time.time()-start_time
34
-
35
- if __name__ == '__main__':
36
- for alpha in [.001,.01,.5,1.]:
37
- print(alpha, evaluate(*get_batch(1000,10,noisy_std=.01),alpha=alpha))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/priors/stroke.py DELETED
@@ -1,143 +0,0 @@
1
- from PIL import Image, ImageDraw, ImageFilter
2
- import random
3
- import math
4
-
5
- import torch
6
- import numpy as np
7
- from .utils import get_batch_to_dataloader
8
-
9
- def mnist_prior(num_classes=2, size=28, min_max_strokes=(1,3), min_max_len=(5/28,20/28), min_max_start=(2/28,25/28),
10
- min_max_width=(1/28,4/28), max_offset=4/28, max_target_offset=2/28):
11
- classes = []
12
- for i in range(num_classes):
13
- num_strokes = random.randint(*min_max_strokes)
14
- len_strokes = [random.randint(int(size * min_max_len[0]), int(size * min_max_len[1])) for i in range(num_strokes)]
15
- stroke_start_points = [
16
- (random.randint(int(size * min_max_start[0]), int(size * min_max_start[1])), random.randint(int(size * min_max_start[0]), int(size * min_max_start[1]))) for i in
17
- range(num_strokes)]
18
- stroke_directions = []
19
- # i = Image.fromarray(np.zeros((28,28),dtype=np.uint8))
20
- # draw = ImageDraw.Draw(i)
21
- for i in range(num_strokes):
22
- sp, length = stroke_start_points[i], len_strokes[i]
23
- counter = 0
24
- while True:
25
- if counter % 3 == 0:
26
- length = random.randint(int(size * min_max_len[0]), int(size * min_max_len[1]))
27
- sp = (
28
- random.randint(int(size * min_max_start[0]), int(size * min_max_start[1])), random.randint(int(size * min_max_start[0]), int(size * min_max_start[1])))
29
- stroke_start_points[i], len_strokes[i] = sp, length
30
- radians = random.random() * 2 * math.pi
31
- x_vel = math.cos(radians) * length
32
- y_vel = math.sin(radians) * length
33
- new_p = (sp[0] + x_vel, sp[1] + y_vel)
34
- # print(math.degrees(radians),sp,new_p)
35
- if not any(n > size - 1 or n < 0 for n in new_p):
36
- break
37
- counter += 1
38
- stroke_directions.append(radians)
39
- # print([round(x) for x in sp+new_p])
40
- # draw.line([round(x) for x in sp+new_p], fill=128, width=3)
41
- classes.append((len_strokes, stroke_start_points, stroke_directions))
42
-
43
- generator_functions = []
44
- for c in classes:
45
- def g(c=c):
46
- len_strokes, stroke_start_points, stroke_directions = c
47
- i = Image.fromarray(np.zeros((size, size), dtype=np.uint8))
48
- draw = ImageDraw.Draw(i)
49
- width = random.randint(int(size * min_max_width[0]), int(size * min_max_width[1]))
50
- offset = random.randint(int(-size * max_offset), int(size * max_offset)), random.randint(int(- size * max_offset), int(size * max_offset))
51
- for sp, length, radians in zip(stroke_start_points, len_strokes, stroke_directions):
52
- sp = (sp[0] + offset[0], sp[1] + offset[1])
53
- x_vel = math.cos(radians) * length + random.randint(int(-size * max_target_offset), int(size * max_target_offset))
54
- y_vel = math.sin(radians) * length + random.randint(int(-size * max_target_offset), int(size * max_target_offset))
55
- new_p = (sp[0] + x_vel, sp[1] + y_vel)
56
- stroke_directions.append(radians)
57
- draw.line([round(x) for x in sp + new_p], fill=128, width=width)
58
- a_i = np.array(i)
59
- a_i[a_i == 128] = np.random.randint(200, 255, size=a_i.shape)[a_i == 128]
60
- return Image.fromarray(a_i).filter(ImageFilter.GaussianBlur(.2))
61
-
62
- generator_functions.append(g)
63
- return generator_functions
64
-
65
-
66
- # g1,g2 = mnist_prior(2)
67
-
68
- # for i in [g1() for _ in range(10)]:
69
- # display(i.resize((200,200)))
70
-
71
- from torchvision.transforms import ToTensor, ToPILImage
72
-
73
-
74
- def normalize(x):
75
- return (x-x.mean())/(x.std()+.000001)
76
-
77
- from os import path, listdir
78
- import random
79
-
80
- def get_batch(batch_size, seq_len, num_features=None, noisy_std=None, only_train_for_last_idx=False, normalize_x=False, num_outputs=2, use_saved_from=None, **kwargs): # num_features = 28*28=784
81
- if use_saved_from is not None:
82
- directory = path.join(use_saved_from, f'len_{seq_len}_out_{num_outputs}_features_{num_features}_bs_{batch_size}')
83
- filename = random.choice(listdir(directory))
84
- return torch.load(path.join(directory,filename))
85
-
86
- size = math.isqrt(num_features)
87
- assert size * size == num_features, 'num_features needs to be the square of an integer.'
88
- if only_train_for_last_idx:
89
- assert (seq_len-1) % num_outputs == 0
90
-
91
- # assert seq_len % 2 == 0, "assert seq_len % 2 == 0"
92
- batch = []
93
- y = []
94
- target_y = []
95
- for b_i in range(batch_size):
96
- gs = mnist_prior(num_outputs, size, **kwargs)
97
- if only_train_for_last_idx:
98
- generators = [i for i in range(len(gs)) for _ in range((seq_len-1) // num_outputs)]
99
- random.shuffle(generators)
100
- generators += [random.randint(0, len(gs) - 1)]
101
- target = [-100 for _ in generators]
102
- target[-1] = generators[-1]
103
- else:
104
- generators = [random.randint(0, len(gs) - 1) for _ in range(seq_len)]
105
- target = generators
106
- normalize_or_not = lambda x: normalize(x) if normalize_x else x
107
- s = torch.cat([normalize_or_not(ToTensor()(gs[f_i]())) for f_i in generators], 0)
108
- batch.append(s)
109
- y.append(torch.tensor(generators))
110
- target_y.append(torch.tensor(target))
111
- x = torch.stack(batch, 1).view(seq_len, batch_size, -1)
112
- y = torch.stack(y, 1)
113
- target_y = torch.stack(target_y, 1)
114
- return x,y,target_y
115
-
116
- DataLoader = get_batch_to_dataloader(get_batch)
117
- DataLoader.num_outputs = 2
118
-
119
- if __name__ == '__main__':
120
- g1, g2 = mnist_prior(2, size=3)
121
-
122
- # for i in range(10):
123
- # print(PILToTensor()(g1()))
124
- # display(ToPILImage()(PILToTensor()(g1())).resize((200,200)))
125
- # display(g2().resize((200,200)))
126
-
127
- size = 10
128
- x, y = get_batch(1, 10, num_features=size * size)
129
-
130
- x_ = x[..., :-1].squeeze(1)
131
- last_y = x[..., -1].squeeze(1)
132
- y = y.squeeze(1)
133
-
134
- # print(y)
135
-
136
- for i, y_, last_y_, x__ in zip(x_, y, last_y, x.squeeze(1)):
137
- # print(y_)
138
- # print(i.shape)
139
- # print(x__)
140
- img = ToPILImage()(i.view(size, size))
141
- # display(img.resize((200,200)))
142
-
143
- print(y, last_y)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/priors/utils.py DELETED
@@ -1,151 +0,0 @@
1
- import random
2
-
3
- import pandas as pd
4
- import torch
5
-
6
- from lcpfn.utils import set_locals_in_self
7
- from itertools import repeat
8
- from .prior import PriorDataLoader
9
- from torch import nn
10
- import numpy as np
11
- import matplotlib.pyplot as plt
12
- import matplotlib.gridspec as gridspec
13
- import scipy.stats as stats
14
- import math
15
-
16
- def get_batch_to_dataloader(get_batch_method_):
17
- class DL(PriorDataLoader):
18
- get_batch_method = get_batch_method_
19
-
20
- # Caution, you might need to set self.num_features manually if it is not part of the args.
21
- def __init__(self, num_steps, **get_batch_kwargs):
22
- set_locals_in_self(locals())
23
-
24
- # The stuff outside the or is set as class attribute before instantiation.
25
- self.num_features = get_batch_kwargs.get('num_features') or self.num_features
26
- print('DataLoader.__dict__', self.__dict__)
27
-
28
- @staticmethod
29
- def gbm(*args, eval_pos_seq_len_sampler, **kwargs):
30
- kwargs['single_eval_pos'], kwargs['seq_len'] = eval_pos_seq_len_sampler()
31
- # Scales the batch size dynamically with the power of 'dynamic_batch_size'.
32
- # A transformer with quadratic memory usage in the seq len would need a power of 2 to keep memory constant.
33
- if 'dynamic_batch_size' in kwargs and kwargs['dynamic_batch_size'] > 0:
34
- kwargs['batch_size'] = kwargs['batch_size'] * math.floor(math.pow(kwargs['seq_len_maximum'], kwargs['dynamic_batch_size']) / math.pow(kwargs['seq_len'], kwargs['dynamic_batch_size']))
35
- batch = get_batch_method_(*args, **kwargs)
36
- x, y, target_y, style = batch if len(batch) == 4 else (batch[0], batch[1], batch[2], None)
37
- return (style, x, y), target_y, kwargs['single_eval_pos']
38
-
39
- def __len__(self):
40
- return self.num_steps
41
-
42
- def __iter__(self):
43
- return iter(self.gbm(**self.get_batch_kwargs) for _ in range(self.num_steps))
44
-
45
- return DL
46
-
47
- """
48
- import seaborn as sns
49
- def plot_features(data, targets, fig=None):
50
- if torch.is_tensor(data):
51
- data = data.detach().cpu().numpy()
52
- targets = targets.detach().cpu().numpy()
53
- fig2 = plt.figure(figsize=(8, 8))
54
- spec2 = gridspec.GridSpec(ncols=data.shape[1], nrows=data.shape[1], figure=fig2)
55
- for d in range(0, data.shape[1]):
56
- for d2 in range(0, data.shape[1]):
57
- sub_ax = fig2.add_subplot(spec2[d, d2])
58
- if d == d2:
59
- sns.kdeplot(data[:, d],hue=targets[:],ax=sub_ax,legend=False, palette="deep")
60
- sub_ax.set(ylabel=None)
61
- else:
62
- sns.scatterplot(data[:, d], data[:, d2],
63
- hue=targets[:],legend=False, palette="deep")
64
- #plt.scatter(data[:, d], data[:, d2],
65
- # c=targets[:])
66
- sub_ax.get_xaxis().set_ticks([])
67
- sub_ax.get_yaxis().set_ticks([])
68
- plt.subplots_adjust(wspace=0.05, hspace=0.05)
69
- fig2.show()
70
-
71
-
72
- def plot_prior(prior):
73
- s = np.array([prior() for _ in range(0, 1000)])
74
- count, bins, ignored = plt.hist(s, 50, density=True)
75
- print(s.min())
76
- plt.show()
77
- """
78
-
79
- trunc_norm_sampler_f = lambda mu, sigma : lambda: stats.truncnorm((0 - mu) / sigma, (1000000 - mu) / sigma, loc=mu, scale=sigma).rvs(1)[0]
80
- beta_sampler_f = lambda a, b : lambda : np.random.beta(a, b)
81
- gamma_sampler_f = lambda a, b : lambda : np.random.gamma(a, b)
82
- uniform_sampler_f = lambda a, b : lambda : np.random.uniform(a, b)
83
- uniform_int_sampler_f = lambda a, b : lambda : round(np.random.uniform(a, b))
84
- def zipf_sampler_f(a, b, c):
85
- x = np.arange(b, c)
86
- weights = x ** (-a)
87
- weights /= weights.sum()
88
- return lambda : stats.rv_discrete(name='bounded_zipf', values=(x, weights)).rvs(1)
89
- scaled_beta_sampler_f = lambda a, b, scale, minimum : lambda : minimum + round(beta_sampler_f(a, b)() * (scale - minimum))
90
-
91
-
92
- def normalize_by_used_features_f(x, num_features_used, num_features, normalize_with_sqrt=False):
93
- if normalize_with_sqrt:
94
- return x / (num_features_used / num_features)**(1 / 2)
95
- return x / (num_features_used / num_features)
96
-
97
-
98
- def order_by_y(x, y):
99
- order = torch.argsort(y if random.randint(0, 1) else -y, dim=0)[:, 0, 0]
100
- order = order.reshape(2, -1).transpose(0, 1).reshape(-1)#.reshape(seq_len)
101
- x = x[order] # .reshape(2, -1).transpose(0, 1).reshape(-1).flip([0]).reshape(seq_len, 1, -1)
102
- y = y[order] # .reshape(2, -1).transpose(0, 1).reshape(-1).reshape(seq_len, 1, -1)
103
-
104
- return x, y
105
-
106
- def randomize_classes(x, num_classes):
107
- classes = torch.arange(0, num_classes, device=x.device)
108
- random_classes = torch.randperm(num_classes, device=x.device).type(x.type())
109
- x = ((x.unsqueeze(-1) == classes) * random_classes).sum(-1)
110
- return x
111
-
112
-
113
- class CategoricalActivation(nn.Module):
114
- def __init__(self, categorical_p=0.1, ordered_p=0.7
115
- , keep_activation_size=False
116
- , num_classes_sampler=zipf_sampler_f(0.8, 1, 10)):
117
- self.categorical_p = categorical_p
118
- self.ordered_p = ordered_p
119
- self.keep_activation_size = keep_activation_size
120
- self.num_classes_sampler = num_classes_sampler
121
-
122
- super().__init__()
123
-
124
- def forward(self, x):
125
- # x shape: T, B, H
126
-
127
- x = nn.Softsign()(x)
128
-
129
- num_classes = self.num_classes_sampler()
130
- hid_strength = torch.abs(x).mean(0).unsqueeze(0) if self.keep_activation_size else None
131
-
132
- categorical_classes = torch.rand((x.shape[1], x.shape[2])) < self.categorical_p
133
- class_boundaries = torch.zeros((num_classes - 1, x.shape[1], x.shape[2]), device=x.device, dtype=x.dtype)
134
- # Sample a different index for each hidden dimension, but shared for all batches
135
- for b in range(x.shape[1]):
136
- for h in range(x.shape[2]):
137
- ind = torch.randint(0, x.shape[0], (num_classes - 1,))
138
- class_boundaries[:, b, h] = x[ind, b, h]
139
-
140
- for b in range(x.shape[1]):
141
- x_rel = x[:, b, categorical_classes[b]]
142
- boundaries_rel = class_boundaries[:, b, categorical_classes[b]].unsqueeze(1)
143
- x[:, b, categorical_classes[b]] = (x_rel > boundaries_rel).sum(dim=0).float() - num_classes / 2
144
-
145
- ordered_classes = torch.rand((x.shape[1],x.shape[2])) < self.ordered_p
146
- ordered_classes = torch.logical_and(ordered_classes, categorical_classes)
147
- x[:, ordered_classes] = randomize_classes(x[:, ordered_classes], num_classes)
148
-
149
- x = x * hid_strength if self.keep_activation_size else x
150
-
151
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/train.py DELETED
@@ -1,336 +0,0 @@
1
- import itertools
2
- import time
3
- from contextlib import nullcontext
4
-
5
- import torch
6
- from torch import nn
7
-
8
- from lcpfn import utils
9
- from lcpfn.transformer import TransformerModel
10
- from lcpfn.bar_distribution import (
11
- BarDistribution,
12
- )
13
- from lcpfn.utils import (
14
- get_cosine_schedule_with_warmup,
15
- get_openai_lr,
16
- )
17
- from lcpfn import positional_encodings
18
- from lcpfn.utils import init_dist
19
- from torch.cuda.amp import autocast, GradScaler
20
-
21
-
22
- class Losses:
23
- gaussian = nn.GaussianNLLLoss(full=True, reduction="none")
24
- mse = nn.MSELoss(reduction="none")
25
- ce = lambda num_classes: nn.CrossEntropyLoss(
26
- reduction="none", weight=torch.ones(num_classes)
27
- )
28
- bce = nn.BCEWithLogitsLoss(reduction="none")
29
- get_BarDistribution = BarDistribution
30
-
31
-
32
- def train(
33
- priordataloader_class,
34
- criterion,
35
- encoder_generator,
36
- emsize=200,
37
- nhid=200,
38
- nlayers=6,
39
- nhead=2,
40
- dropout=0.2,
41
- epochs=10,
42
- steps_per_epoch=100,
43
- batch_size=200,
44
- bptt=10,
45
- lr=None,
46
- weight_decay=0.0,
47
- warmup_epochs=10,
48
- input_normalization=False,
49
- y_encoder_generator=None,
50
- pos_encoder_generator=None,
51
- decoder=None,
52
- extra_prior_kwargs_dict={},
53
- scheduler=get_cosine_schedule_with_warmup,
54
- load_weights_from_this_state_dict=None,
55
- validation_period=10,
56
- single_eval_pos_gen=None,
57
- bptt_extra_samples=None,
58
- gpu_device="cuda:0",
59
- aggregate_k_gradients=1,
60
- verbose=True,
61
- style_encoder_generator=None,
62
- epoch_callback=None,
63
- initializer=None,
64
- initialize_with_model=None,
65
- train_mixed_precision=False,
66
- saving_period=10,
67
- checkpoint_file=None,
68
- load_optimizer_from_this_state_dict=None,
69
- output_path=None,
70
- **model_extra_args,
71
- ):
72
- device = gpu_device if torch.cuda.is_available() else "cpu:0"
73
- print(f"Using {device} device")
74
- using_dist, rank, device = init_dist(device)
75
- single_eval_pos_gen = (
76
- single_eval_pos_gen
77
- if callable(single_eval_pos_gen)
78
- else lambda: single_eval_pos_gen
79
- )
80
-
81
- def eval_pos_seq_len_sampler():
82
- single_eval_pos = single_eval_pos_gen()
83
- if bptt_extra_samples:
84
- return single_eval_pos, single_eval_pos + bptt_extra_samples
85
- else:
86
- return single_eval_pos, bptt
87
-
88
- dl = priordataloader_class(
89
- num_steps=steps_per_epoch,
90
- batch_size=batch_size,
91
- eval_pos_seq_len_sampler=eval_pos_seq_len_sampler,
92
- seq_len_maximum=bptt + (bptt_extra_samples if bptt_extra_samples else 0),
93
- device=device,
94
- **extra_prior_kwargs_dict,
95
- )
96
-
97
- encoder = encoder_generator(dl.num_features, emsize)
98
- style_def = next(iter(dl))[0][
99
- 0
100
- ] # This is (style, x, y), target with x and y with batch size
101
- print(f"Style definition: {style_def}")
102
- style_encoder = (
103
- style_encoder_generator(hyperparameter_definitions=style_def[0], em_size=emsize)
104
- if (style_def is not None)
105
- else None
106
- )
107
- if isinstance(criterion, nn.GaussianNLLLoss):
108
- n_out = 2
109
- elif (
110
- isinstance(criterion, BarDistribution)
111
- or "BarDistribution" in criterion.__class__.__name__
112
- ): # TODO remove this fix (only for dev)
113
- n_out = criterion.num_bars
114
- elif isinstance(criterion, nn.CrossEntropyLoss):
115
- n_out = criterion.weight.shape[0]
116
- else:
117
- n_out = 1
118
- model = TransformerModel(
119
- encoder,
120
- n_out,
121
- emsize,
122
- nhead,
123
- nhid,
124
- nlayers,
125
- dropout,
126
- style_encoder=style_encoder,
127
- y_encoder=y_encoder_generator(1, emsize),
128
- input_normalization=input_normalization,
129
- pos_encoder=(
130
- pos_encoder_generator or positional_encodings.NoPositionalEncoding
131
- )(emsize, bptt * 2),
132
- decoder=decoder,
133
- init_method=initializer,
134
- **model_extra_args,
135
- )
136
- model.criterion = criterion
137
- if load_weights_from_this_state_dict is not None:
138
- model.load_state_dict(load_weights_from_this_state_dict)
139
- if initialize_with_model is not None:
140
- model.init_from_small_model(initialize_with_model)
141
-
142
- print(
143
- f"Using a Transformer with {sum(p.numel() for p in model.parameters())/1000/1000:.{2}f} M parameters"
144
- )
145
-
146
- try:
147
- for (k, v), (k2, v2) in zip(
148
- model.state_dict().items(), initialize_with_model.state_dict().items()
149
- ):
150
- print(k, ((v - v2) / v).abs().mean(), v.shape)
151
- except Exception:
152
- pass
153
-
154
- model.to(device)
155
- if using_dist:
156
- print("Distributed training")
157
- model = torch.nn.parallel.DistributedDataParallel(
158
- model, device_ids=[rank], output_device=rank, broadcast_buffers=False
159
- )
160
-
161
- # learning rate
162
- if lr is None:
163
- lr = get_openai_lr(model)
164
- print(f"Using OpenAI max lr of {lr}.")
165
- optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
166
- scheduler = scheduler(
167
- optimizer, warmup_epochs, epochs if epochs is not None else 100
168
- ) # when training for fixed time lr schedule takes 100 steps
169
-
170
- if load_optimizer_from_this_state_dict is not None:
171
- optimizer.load_state_dict(load_optimizer_from_this_state_dict)
172
- scaler = GradScaler() if train_mixed_precision else None
173
-
174
- # check that everything uses up-to-date APIs
175
- utils.check_compatibility(dl)
176
-
177
- def train_epoch():
178
- model.train() # Turn on the train mode
179
- total_loss = 0.0
180
- total_positional_losses = 0.0
181
- total_positional_losses_recorded = 0
182
- before_get_batch = time.time()
183
- assert (
184
- len(dl) % aggregate_k_gradients == 0
185
- ), "Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it."
186
- for batch, (data, targets, single_eval_pos) in enumerate(dl):
187
- if using_dist and not (
188
- batch % aggregate_k_gradients == aggregate_k_gradients - 1
189
- ):
190
- cm = model.no_sync()
191
- else:
192
- cm = nullcontext()
193
- with cm:
194
- time_to_get_batch = time.time() - before_get_batch
195
- before_forward = time.time()
196
-
197
- with autocast(enabled=scaler is not None):
198
- # If style is set to None, it should not be transferred to device
199
- output = model(
200
- tuple(e.to(device) if torch.is_tensor(e) else e for e in data)
201
- if isinstance(data, tuple)
202
- else data.to(device),
203
- single_eval_pos=single_eval_pos,
204
- )
205
-
206
- forward_time = time.time() - before_forward
207
-
208
- if single_eval_pos is not None:
209
- targets = targets[single_eval_pos:]
210
- if isinstance(criterion, nn.GaussianNLLLoss):
211
- assert (
212
- output.shape[-1] == 2
213
- ), "need to write a little bit of code to handle multiple regression targets at once"
214
-
215
- mean_pred = output[..., 0]
216
- var_pred = output[..., 1].abs()
217
- losses = criterion(
218
- mean_pred.flatten(),
219
- targets.to(device).flatten(),
220
- var=var_pred.flatten(),
221
- )
222
- elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)):
223
- losses = criterion(
224
- output.flatten(), targets.to(device).flatten()
225
- )
226
- elif isinstance(criterion, nn.CrossEntropyLoss):
227
- losses = criterion(
228
- output.reshape(-1, n_out),
229
- targets.to(device).long().flatten(),
230
- )
231
- else:
232
- losses = criterion(output, targets)
233
- losses = losses.view(*output.shape[0:2])
234
- loss = losses.mean() / aggregate_k_gradients
235
-
236
- if scaler:
237
- loss = scaler.scale(loss)
238
- loss.backward()
239
-
240
- if batch % aggregate_k_gradients == aggregate_k_gradients - 1:
241
- if scaler:
242
- scaler.unscale_(optimizer)
243
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
244
- try:
245
- if scaler:
246
- scaler.step(optimizer)
247
- scaler.update()
248
- else:
249
- optimizer.step()
250
- except:
251
- print("Invalid optimization step encountered")
252
- optimizer.zero_grad()
253
-
254
- step_time = time.time() - before_forward
255
-
256
- if not torch.isnan(loss):
257
- total_loss += losses.mean().cpu().detach()
258
- total_positional_losses += (
259
- losses.mean(1).cpu().detach()
260
- if single_eval_pos is None
261
- else nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
262
- * losses[: bptt - single_eval_pos].mean().cpu().detach()
263
- )
264
-
265
- total_positional_losses_recorded += (
266
- torch.ones(bptt)
267
- if single_eval_pos is None
268
- else nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
269
- )
270
-
271
- before_get_batch = time.time()
272
- return (
273
- total_loss / steps_per_epoch,
274
- (total_positional_losses / total_positional_losses_recorded).tolist(),
275
- time_to_get_batch,
276
- forward_time,
277
- step_time,
278
- )
279
-
280
- total_loss = float("inf")
281
- total_positional_losses = float("inf")
282
- list_losses = []
283
- try:
284
- for epoch in range(1, epochs + 1) if epochs is not None else itertools.count(1):
285
- epoch_start_time = time.time()
286
- (
287
- total_loss,
288
- total_positional_losses,
289
- time_to_get_batch,
290
- forward_time,
291
- step_time,
292
- ) = train_epoch()
293
- list_losses.append(total_loss.item())
294
- if hasattr(dl, "validate") and epoch % validation_period == 0:
295
- with torch.no_grad():
296
- val_score = dl.validate(model)
297
-
298
- else:
299
- val_score = None
300
-
301
- if epoch % saving_period == 0 and checkpoint_file is not None:
302
- checkpoint = {
303
- "model_state_dict": model.state_dict(),
304
- "optimizer_state_dict": optimizer.state_dict(),
305
- "epoch": epoch,
306
- }
307
- torch.save(checkpoint, checkpoint_file)
308
- full_model_path = checkpoint_file.split(".")[0] + "_full_model.pt"
309
- torch.save(model, full_model_path)
310
-
311
- if verbose:
312
- print("-" * 89)
313
- print(
314
- f"| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | "
315
- f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}"
316
- f" data time {time_to_get_batch:5.2f} step time {step_time:5.2f}"
317
- f" forward time {forward_time:5.2f}"
318
- + (f"val score {val_score}" if val_score is not None else "")
319
- )
320
- print("-" * 89)
321
-
322
- # stepping with wallclock time based scheduler
323
- if epoch_callback is not None and rank == 0:
324
- epoch_callback(model, epoch / epochs)
325
- scheduler.step()
326
- except KeyboardInterrupt:
327
- pass
328
-
329
- if rank == 0: # trivially true for non-parallel training
330
- if isinstance(model, torch.nn.parallel.DistributedDataParallel):
331
- model = model.module
332
- dl = None
333
- if output_path is not None:
334
- torch.save(model.to("cpu"), output_path)
335
- print("Checkpoint stored at ", output_path)
336
- return total_loss, total_positional_losses, model.to("cpu"), dl
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/train_lcpfn.py DELETED
@@ -1,96 +0,0 @@
1
- import math
2
-
3
- from torch import nn
4
-
5
- from lcpfn import bar_distribution, encoders, train
6
- from lcpfn import utils
7
-
8
- from lcpfn.priors import utils as putils
9
-
10
-
11
- def train_lcpfn(
12
- get_batch_func,
13
- seq_len: int = 100,
14
- emsize: int = 512,
15
- nlayers: int = 12,
16
- num_borders: int = 1000,
17
- lr: float = 0.0001,
18
- batch_size: int = 100,
19
- epochs: int = 1000,
20
- ):
21
- """
22
- Train a LCPFN model using the specified hyperparameters.
23
-
24
- Args:
25
- get_batch_func (callable): A function that returns a batch of learning curves.
26
- seq_len (int, optional): The length of the input sequence. Defaults to 100.
27
- emsize (int, optional): The size of the embedding layer. Defaults to 512.
28
- nlayers (int, optional): The number of layers in the model. Defaults to 12.
29
- num_borders_choices (int, optional): The number of borders to use. Defaults to 1000.
30
- lr (float, optional): The learning rate for the optimizer. Defaults to 0.0001.
31
- batch_size (int, optional): The batch size for training. Defaults to 100.
32
- epochs (int, optional): The number of epochs to train for. Defaults to 1000.
33
-
34
- Returns:
35
- torch.module: The trained model.
36
- """
37
-
38
- hps = {}
39
-
40
- # PFN training hyperparameters
41
- dataloader = putils.get_batch_to_dataloader(get_batch_func) # type: ignore
42
-
43
- num_features = 1
44
-
45
- ys = get_batch_func(
46
- 10_000,
47
- seq_len,
48
- num_features,
49
- hyperparameters=hps,
50
- single_eval_pos=seq_len,
51
- )
52
-
53
- bucket_limits = bar_distribution.get_bucket_limits(num_borders, ys=ys[2])
54
-
55
- # Discretization of the predictive distributions
56
- criterions = {
57
- num_features: {
58
- num_borders: bar_distribution.FullSupportBarDistribution(bucket_limits)
59
- }
60
- }
61
-
62
- config = dict(
63
- nlayers=nlayers,
64
- priordataloader_class=dataloader,
65
- criterion=criterions[num_features][num_borders],
66
- encoder_generator=lambda in_dim, out_dim: nn.Sequential(
67
- encoders.Normalize(0.0, 101.0),
68
- encoders.Normalize(0.5, math.sqrt(1 / 12)),
69
- encoders.Linear(in_dim, out_dim),
70
- ),
71
- emsize=emsize,
72
- nhead=(emsize // 128),
73
- warmup_epochs=(epochs // 4),
74
- y_encoder_generator=encoders.get_normalized_uniform_encoder(encoders.Linear),
75
- batch_size=batch_size,
76
- scheduler=utils.get_cosine_schedule_with_warmup,
77
- extra_prior_kwargs_dict={
78
- # "num_workers": 10,
79
- "num_features": num_features,
80
- "hyperparameters": {
81
- **hps,
82
- },
83
- },
84
- epochs=epochs,
85
- lr=lr,
86
- bptt=seq_len,
87
- single_eval_pos_gen=utils.get_uniform_single_eval_pos_sampler(
88
- seq_len, min_len=1
89
- ),
90
- aggregate_k_gradients=1,
91
- nhid=(emsize * 2),
92
- steps_per_epoch=100,
93
- train_mixed_precision=False,
94
- )
95
-
96
- return train.train(**config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/transformer.py DELETED
@@ -1,348 +0,0 @@
1
- import math
2
- from typing import Optional
3
-
4
- import torch
5
- import torch.nn as nn
6
- from torch import Tensor
7
- import torch.nn.functional as F
8
- from torch.nn import Module, TransformerEncoder
9
-
10
- from lcpfn.layer import TransformerEncoderLayer, _get_activation_fn
11
- from lcpfn.utils import SeqBN, bool_mask_to_att_mask
12
-
13
-
14
- class GELU(nn.Module):
15
- def forward(self, input: Tensor) -> Tensor:
16
- return F.gelu(input)
17
-
18
-
19
- class TransformerModel(nn.Module):
20
- def __init__(
21
- self,
22
- encoder,
23
- n_out,
24
- ninp,
25
- nhead,
26
- nhid,
27
- nlayers,
28
- dropout=0.0,
29
- style_encoder=None,
30
- y_encoder=None,
31
- pos_encoder=None,
32
- decoder=None,
33
- input_normalization=False,
34
- init_method=None,
35
- pre_norm=False,
36
- activation="gelu",
37
- recompute_attn=False,
38
- num_global_att_tokens=0,
39
- full_attention=False,
40
- all_layers_same_init=True,
41
- ):
42
- super().__init__()
43
- self.model_type = "Transformer"
44
- encoder_layer_creator = lambda: TransformerEncoderLayer(
45
- ninp,
46
- nhead,
47
- nhid,
48
- dropout,
49
- activation=activation,
50
- pre_norm=pre_norm,
51
- recompute_attn=recompute_attn,
52
- )
53
- self.transformer_encoder = (
54
- TransformerEncoder(encoder_layer_creator(), nlayers)
55
- if all_layers_same_init
56
- else TransformerEncoderDiffInit(encoder_layer_creator, nlayers)
57
- )
58
- self.ninp = ninp
59
- self.encoder = encoder
60
- self.y_encoder = y_encoder
61
- self.pos_encoder = pos_encoder
62
- self.decoder = (
63
- decoder(ninp, nhid, n_out)
64
- if decoder is not None
65
- else nn.Sequential(nn.Linear(ninp, nhid), GELU(), nn.Linear(nhid, n_out))
66
- )
67
- self.input_ln = SeqBN(ninp) if input_normalization else None
68
- self.style_encoder = style_encoder
69
- self.init_method = init_method
70
- if num_global_att_tokens is not None:
71
- assert not full_attention
72
- self.global_att_embeddings = (
73
- nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
74
- )
75
- self.full_attention = full_attention
76
-
77
- self.n_out = n_out
78
- self.nhid = nhid
79
-
80
- self.init_weights()
81
-
82
- @staticmethod
83
- def generate_square_subsequent_mask(sz):
84
- mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
85
- return bool_mask_to_att_mask(mask)
86
-
87
- @staticmethod
88
- def generate_D_q_matrix(sz, query_size):
89
- train_size = sz - query_size
90
- mask = torch.zeros(sz, sz) == 0
91
- mask[:, train_size:].zero_()
92
- mask |= torch.eye(sz) == 1
93
- return bool_mask_to_att_mask(mask)
94
-
95
- @staticmethod
96
- def generate_global_att_query_matrix(
97
- num_global_att_tokens, seq_len, num_query_tokens
98
- ):
99
- train_size = seq_len + num_global_att_tokens - num_query_tokens
100
- sz = seq_len + num_global_att_tokens
101
- mask = torch.zeros(num_query_tokens, sz) == 0
102
- mask[:, train_size:].zero_()
103
- mask[:, train_size:] |= torch.eye(num_query_tokens) == 1
104
- return bool_mask_to_att_mask(mask)
105
-
106
- @staticmethod
107
- def generate_global_att_trainset_matrix(
108
- num_global_att_tokens, seq_len, num_query_tokens
109
- ):
110
- train_size = seq_len + num_global_att_tokens - num_query_tokens
111
- trainset_size = seq_len - num_query_tokens
112
- mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
113
- # mask[:,num_global_att_tokens:].zero_()
114
- # mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1
115
- return bool_mask_to_att_mask(mask)
116
-
117
- @staticmethod
118
- def generate_global_att_globaltokens_matrix(
119
- num_global_att_tokens, seq_len, num_query_tokens
120
- ):
121
- mask = (
122
- torch.zeros(
123
- num_global_att_tokens,
124
- num_global_att_tokens + seq_len - num_query_tokens,
125
- )
126
- == 0
127
- )
128
- return bool_mask_to_att_mask(mask)
129
-
130
- def init_weights(self):
131
- initrange = 1.0
132
- # if isinstance(self.encoder,EmbeddingEncoder):
133
- # self.encoder.weight.data.uniform_(-initrange, initrange)
134
- # self.decoder.bias.data.zero_()
135
- # self.decoder.weight.data.uniform_(-initrange, initrange)
136
- if self.init_method is not None:
137
- self.apply(self.init_method)
138
- for layer in self.transformer_encoder.layers:
139
- nn.init.zeros_(layer.linear2.weight)
140
- nn.init.zeros_(layer.linear2.bias)
141
- attns = (
142
- layer.self_attn
143
- if isinstance(layer.self_attn, nn.ModuleList)
144
- else [layer.self_attn]
145
- )
146
- for attn in attns:
147
- nn.init.zeros_(attn.out_proj.weight)
148
- nn.init.zeros_(attn.out_proj.bias)
149
-
150
- def forward(self, src, src_mask=None, single_eval_pos=None):
151
- assert isinstance(
152
- src, tuple
153
- ), "inputs (src) have to be given as (x,y) or (style,x,y) tuple"
154
-
155
- if len(src) == 2: # (x,y) and no style
156
- src = (None,) + src
157
-
158
- style_src, style_src_size = (src[0], (0 if (src[0] is None) else 1))
159
- if src_mask is not None:
160
- assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
161
- if src_mask is None:
162
- x_src = src[1]
163
- if self.global_att_embeddings is None:
164
- full_len = len(x_src) + style_src_size
165
- if self.full_attention:
166
- src_mask = bool_mask_to_att_mask(
167
- torch.ones((full_len, full_len), dtype=torch.bool)
168
- ).to(x_src.device)
169
- else:
170
- src_mask = self.generate_D_q_matrix(
171
- len(x_src) + style_src_size,
172
- len(x_src) + style_src_size - single_eval_pos,
173
- ).to(x_src.device)
174
- else:
175
- src_mask_args = (
176
- self.global_att_embeddings.num_embeddings,
177
- len(x_src) + style_src_size,
178
- len(x_src) + style_src_size - single_eval_pos,
179
- )
180
- src_mask = (
181
- self.generate_global_att_globaltokens_matrix(*src_mask_args).to(
182
- x_src.device
183
- ),
184
- self.generate_global_att_trainset_matrix(*src_mask_args).to(
185
- x_src.device
186
- ),
187
- self.generate_global_att_query_matrix(*src_mask_args).to(
188
- x_src.device
189
- ),
190
- )
191
-
192
- style_src, x_src, y_src = src
193
- x_src = self.encoder(x_src)
194
- y_src = self.y_encoder(
195
- y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src
196
- )
197
- style_src = (
198
- self.style_encoder(style_src).unsqueeze(0)
199
- if self.style_encoder
200
- else torch.tensor([], device=x_src.device)
201
- )
202
- global_src = (
203
- torch.tensor([], device=x_src.device)
204
- if self.global_att_embeddings is None
205
- else self.global_att_embeddings.weight.unsqueeze(1).repeat(
206
- 1, x_src.shape[1], 1
207
- )
208
- )
209
- train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
210
- src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
211
-
212
- if self.input_ln is not None:
213
- src = self.input_ln(src)
214
-
215
- if self.pos_encoder is not None:
216
- src = self.pos_encoder(src)
217
-
218
- # If we have style input, drop its output
219
- output = self.transformer_encoder(src, src_mask)[style_src_size:]
220
- output = self.decoder(output)
221
- return output[
222
- single_eval_pos
223
- + (
224
- self.global_att_embeddings.num_embeddings
225
- if self.global_att_embeddings
226
- else 0
227
- ) :
228
- ]
229
-
230
- @torch.no_grad()
231
- def init_from_small_model(self, small_model):
232
- assert (
233
- isinstance(self.decoder, nn.Linear)
234
- and isinstance(self.encoder, (nn.Linear, nn.Sequential))
235
- and isinstance(self.y_encoder, (nn.Linear, nn.Sequential))
236
- )
237
-
238
- def set_encoder_weights(my_encoder, small_model_encoder):
239
- my_encoder_linear, small_encoder_linear = (
240
- (my_encoder, small_model_encoder)
241
- if isinstance(my_encoder, nn.Linear)
242
- else (my_encoder[-1], small_model_encoder[-1])
243
- )
244
- small_in_dim = small_encoder_linear.out_features
245
- my_encoder_linear.weight.zero_()
246
- my_encoder_linear.bias.zero_()
247
- my_encoder_linear.weight[:small_in_dim] = small_encoder_linear.weight
248
- my_encoder_linear.bias[:small_in_dim] = small_encoder_linear.bias
249
-
250
- set_encoder_weights(self.encoder, small_model.encoder)
251
- set_encoder_weights(self.y_encoder, small_model.y_encoder)
252
-
253
- small_in_dim = small_model.decoder.in_features
254
-
255
- self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
256
- self.decoder.bias = small_model.decoder.bias
257
-
258
- for my_layer, small_layer in zip(
259
- self.transformer_encoder.layers, small_model.transformer_encoder.layers
260
- ):
261
- small_hid_dim = small_layer.linear1.out_features
262
- my_in_dim = my_layer.linear1.in_features
263
-
264
- # packed along q,k,v order in first dim
265
- my_in_proj_w = my_layer.self_attn.in_proj_weight
266
- small_in_proj_w = small_layer.self_attn.in_proj_weight
267
-
268
- my_in_proj_w.view(3, my_in_dim, my_in_dim)[
269
- :, :small_in_dim, :small_in_dim
270
- ] = small_in_proj_w.view(3, small_in_dim, small_in_dim)
271
- my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:, :small_in_dim] = (
272
- small_layer.self_attn.in_proj_bias.view(3, small_in_dim)
273
- )
274
-
275
- my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = (
276
- small_layer.self_attn.out_proj.weight
277
- )
278
- my_layer.self_attn.out_proj.bias[:small_in_dim] = (
279
- small_layer.self_attn.out_proj.bias
280
- )
281
-
282
- my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = (
283
- small_layer.linear1.weight
284
- )
285
- my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias
286
-
287
- my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = (
288
- small_layer.linear2.weight
289
- )
290
- my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias
291
-
292
- my_layer.norm1.weight[:small_in_dim] = (
293
- math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight
294
- )
295
- my_layer.norm2.weight[:small_in_dim] = (
296
- math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight
297
- )
298
-
299
- my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
300
- my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias
301
-
302
-
303
- class TransformerEncoderDiffInit(Module):
304
- r"""TransformerEncoder is a stack of N encoder layers
305
-
306
- Args:
307
- encoder_layer_creator: a function generating objects of TransformerEncoderLayer class without args (required).
308
- num_layers: the number of sub-encoder-layers in the encoder (required).
309
- norm: the layer normalization component (optional).
310
- """
311
-
312
- __constants__ = ["norm"]
313
-
314
- def __init__(self, encoder_layer_creator, num_layers, norm=None):
315
- super().__init__()
316
- self.layers = nn.ModuleList(
317
- [encoder_layer_creator() for _ in range(num_layers)]
318
- )
319
- self.num_layers = num_layers
320
- self.norm = norm
321
-
322
- def forward(
323
- self,
324
- src: Tensor,
325
- mask: Optional[Tensor] = None,
326
- src_key_padding_mask: Optional[Tensor] = None,
327
- ) -> Tensor:
328
- r"""Pass the input through the encoder layers in turn.
329
-
330
- Args:
331
- src: the sequence to the encoder (required).
332
- mask: the mask for the src sequence (optional).
333
- src_key_padding_mask: the mask for the src keys per batch (optional).
334
-
335
- Shape:
336
- see the docs in Transformer class.
337
- """
338
- output = src
339
-
340
- for mod in self.layers:
341
- output = mod(
342
- output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
343
- )
344
-
345
- if self.norm is not None:
346
- output = self.norm(output)
347
-
348
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/utils.py DELETED
@@ -1,409 +0,0 @@
1
- import os
2
- import math
3
- import argparse
4
- import random
5
- import datetime
6
-
7
- import torch
8
- from torch import nn
9
- from torch.optim.lr_scheduler import LambdaLR
10
- import numpy as np
11
-
12
-
13
- # copied from huggingface
14
- def get_cosine_schedule_with_warmup(
15
- optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
16
- ):
17
- """Create a schedule with a learning rate that decreases following the
18
- values of the cosine function between 0 and `pi * cycles` after a warmup
19
- period during which it increases linearly between 0 and 1.
20
- """
21
-
22
- def lr_lambda(current_step):
23
- if current_step < num_warmup_steps:
24
- return float(current_step) / float(max(1, num_warmup_steps))
25
- progress = float(current_step - num_warmup_steps) / float(
26
- max(1, num_training_steps - num_warmup_steps)
27
- )
28
- return max(
29
- 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
30
- )
31
-
32
- return LambdaLR(optimizer, lr_lambda, last_epoch)
33
-
34
-
35
- # copied from huggingface
36
- def get_linear_schedule_with_warmup(
37
- optimizer, num_warmup_steps, num_training_steps, last_epoch=-1
38
- ):
39
- """
40
- Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
41
- a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
42
-
43
- Args:
44
- optimizer (:class:`~torch.optim.Optimizer`):
45
- The optimizer for which to schedule the learning rate.
46
- num_warmup_steps (:obj:`int`):
47
- The number of steps for the warmup phase.
48
- num_training_steps (:obj:`int`):
49
- The total number of training steps.
50
- last_epoch (:obj:`int`, `optional`, defaults to -1):
51
- The index of the last epoch when resuming training.
52
-
53
- Return:
54
- :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
55
- """
56
-
57
- def lr_lambda(current_step: int):
58
- if current_step < num_warmup_steps:
59
- return float(current_step) / float(max(1, num_warmup_steps))
60
- return max(
61
- 0.0,
62
- float(num_training_steps - current_step)
63
- / float(max(1, num_training_steps - num_warmup_steps)),
64
- )
65
-
66
- return LambdaLR(optimizer, lr_lambda, last_epoch)
67
-
68
-
69
- def get_openai_lr(transformer_model):
70
- num_params = sum(p.numel() for p in transformer_model.parameters())
71
- return 0.003239 - 0.0001395 * math.log(num_params)
72
-
73
-
74
- def get_weighted_single_eval_pos_sampler(max_len):
75
- """
76
- This gives a sampler that can be used for `single_eval_pos` which yields good performance for all positions p,
77
- where p <= `max_len`. At most `max_len` - 1 examples are shown to the Transformer.
78
- :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
79
- """
80
- return lambda: random.choices(
81
- range(max_len), [1 / (max_len - i) for i in range(max_len)]
82
- )[0]
83
-
84
-
85
- def get_uniform_single_eval_pos_sampler(max_len, min_len=0):
86
- """
87
- Just sample any evaluation position with the same weight
88
- :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
89
- """
90
- return lambda: random.choices(range(min_len, max_len))[0]
91
-
92
-
93
- class SeqBN(nn.Module):
94
- def __init__(self, d_model):
95
- super().__init__()
96
- self.bn = nn.BatchNorm1d(d_model)
97
- self.d_model = d_model
98
-
99
- def forward(self, x):
100
- assert self.d_model == x.shape[-1]
101
- flat_x = x.view(-1, self.d_model)
102
- flat_x = self.bn(flat_x)
103
- return flat_x.view(*x.shape)
104
-
105
-
106
- def set_locals_in_self(locals):
107
- """
108
- Call this function like `set_locals_in_self(locals())` to set all local variables as object variables.
109
- Especially useful right at the beginning of `__init__`.
110
- :param locals: `locals()`
111
- """
112
- self = locals["self"]
113
- for var_name, val in locals.items():
114
- if var_name != "self":
115
- setattr(self, var_name, val)
116
-
117
-
118
- default_device = "cuda:0" if torch.cuda.is_available() else "cpu:0"
119
-
120
-
121
- # Copied from StackOverflow, but we do an eval on the values additionally
122
- class StoreDictKeyPair(argparse.Action):
123
- def __init__(self, option_strings, dest, nargs=None, **kwargs):
124
- self._nargs = nargs
125
- super(StoreDictKeyPair, self).__init__(
126
- option_strings, dest, nargs=nargs, **kwargs
127
- )
128
-
129
- def __call__(self, parser, namespace, values, option_string=None):
130
- my_dict = {}
131
- for kv in values:
132
- k, v = kv.split("=")
133
- try:
134
- my_dict[k] = eval(v)
135
- except NameError:
136
- my_dict[k] = v
137
- setattr(namespace, self.dest, my_dict)
138
- print("dict values: {}".format(my_dict))
139
-
140
-
141
- def get_nan_value(v, set_value_to_nan=0.0):
142
- if random.random() < set_value_to_nan:
143
- return v
144
- else:
145
- return random.choice([-999, 0, 1, 999])
146
-
147
-
148
- def to_ranking(data):
149
- x = data >= data.unsqueeze(-3)
150
- x = x.sum(0)
151
- return x
152
-
153
-
154
- # TODO: Is there a better way to do this?
155
- # 1. Cmparing to unique elements: When all values are different we still get quadratic blowup
156
- # 2. Argsort(Argsort()) returns ranking, but with duplicate values there is an ordering which is problematic
157
- # 3. Argsort(Argsort(Unique))->Scatter seems a bit complicated, doesn't have quadratic blowup, but how fast?
158
- def to_ranking_low_mem(data):
159
- x = torch.zeros_like(data)
160
- for col in range(data.shape[-1]):
161
- x_ = data[:, :, col] >= data[:, :, col].unsqueeze(-2)
162
- x_ = x_.sum(0)
163
- x[:, :, col] = x_
164
- return x
165
-
166
-
167
- def nan_handling_missing_for_unknown_reason_value(set_value_to_nan=0.0):
168
- return get_nan_value(float("nan"), set_value_to_nan)
169
-
170
-
171
- def nan_handling_missing_for_no_reason_value(set_value_to_nan=0.0):
172
- return get_nan_value(float("-inf"), set_value_to_nan)
173
-
174
-
175
- def nan_handling_missing_for_a_reason_value(set_value_to_nan=0.0):
176
- return get_nan_value(float("inf"), set_value_to_nan)
177
-
178
-
179
- def torch_nanmean(x, axis=0):
180
- num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(
181
- axis=axis
182
- )
183
- value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
184
- return value / num
185
-
186
-
187
- def torch_nanstd(x, axis=0):
188
- num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(
189
- axis=axis
190
- )
191
- value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
192
- mean = value / num
193
- mean_broadcast = torch.repeat_interleave(
194
- mean.unsqueeze(axis), x.shape[axis], dim=axis
195
- )
196
- return torch.sqrt(
197
- torch.nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1)
198
- )
199
-
200
-
201
- def normalize_data(data, normalize_positions=-1):
202
- if normalize_positions > 0:
203
- mean = torch_nanmean(data[:normalize_positions], axis=0)
204
- std = torch_nanstd(data[:normalize_positions], axis=0) + 0.000001
205
- else:
206
- mean = torch_nanmean(data, axis=0)
207
- std = torch_nanstd(data, axis=0) + 0.000001
208
- data = (data - mean) / std
209
- data = torch.clip(data, min=-100, max=100)
210
-
211
- return data
212
-
213
-
214
- def remove_outliers(X, n_sigma=4):
215
- # Expects T, B, H
216
- assert len(X.shape) == 3, "X must be T,B,H"
217
- # for b in range(X.shape[1]):
218
- # for col in range(X.shape[2]):
219
- data = X
220
- data_mean, data_std = torch_nanmean(data, axis=0), torch_nanstd(data, axis=0)
221
- cut_off = data_std * n_sigma
222
- lower, upper = data_mean - cut_off, data_mean + cut_off
223
-
224
- data_clean = X[:].clone()
225
- data_clean[torch.logical_or(data > upper, data < lower)] = np.nan
226
- data_mean, data_std = (
227
- torch_nanmean(data_clean, axis=0),
228
- torch_nanstd(data_clean, axis=0),
229
- )
230
- cut_off = data_std * n_sigma
231
- lower, upper = data_mean - cut_off, data_mean + cut_off
232
-
233
- X = torch.maximum(-torch.log(1 + torch.abs(X)) + lower, X)
234
- X = torch.minimum(torch.log(1 + torch.abs(X)) + upper, X)
235
- # print(ds[1][data < lower, col], ds[1][data > upper, col], ds[1][~np.isnan(data), col].shape, data_mean, data_std)
236
- return X
237
-
238
-
239
- def bool_mask_to_att_mask(mask):
240
- return (
241
- mask.float()
242
- .masked_fill(mask == 0, float("-inf"))
243
- .masked_fill(mask == 1, float(0.0))
244
- )
245
-
246
-
247
- def print_on_master_only(is_master):
248
- import builtins as __builtin__
249
-
250
- builtin_print = __builtin__.print
251
-
252
- def print(*args, **kwargs):
253
- force = kwargs.pop("force", False)
254
- if is_master or force:
255
- builtin_print(*args, **kwargs)
256
-
257
- __builtin__.print = print
258
-
259
-
260
- def init_dist(device):
261
- print("init dist")
262
- if "LOCAL_RANK" in os.environ:
263
- # launched with torch.distributed.launch
264
- rank = int(os.environ["LOCAL_RANK"])
265
- print("torch.distributed.launch and my rank is", rank)
266
- torch.cuda.set_device(rank)
267
- os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
268
- torch.distributed.init_process_group(
269
- backend="nccl",
270
- init_method="env://",
271
- timeout=datetime.timedelta(seconds=20),
272
- world_size=torch.cuda.device_count(),
273
- rank=rank,
274
- )
275
- torch.distributed.barrier()
276
- print_on_master_only(rank == 0)
277
- print(
278
- f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
279
- "only I can print, but when using print(..., force=True) it will print on all ranks."
280
- )
281
- return True, rank, f"cuda:{rank}"
282
- elif "SLURM_PROCID" in os.environ and torch.cuda.device_count() > 1:
283
- # this is for multi gpu when starting with submitit
284
- assert device != "cpu:0"
285
- rank = int(os.environ["SLURM_PROCID"])
286
- os.environ["MASTER_ADDR"] = "localhost"
287
- os.environ["MASTER_PORT"] = "12355"
288
- torch.cuda.set_device(rank)
289
- os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
290
- print("distributed submitit launch and my rank is", rank)
291
- torch.distributed.init_process_group(
292
- backend="nccl",
293
- init_method="env://",
294
- timeout=datetime.timedelta(seconds=20),
295
- world_size=torch.cuda.device_count(),
296
- rank=rank,
297
- )
298
- torch.distributed.barrier()
299
- print_on_master_only(rank == 0)
300
- print(
301
- f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
302
- "only I can print, but when using print(..., force=True) it will print on all ranks."
303
- )
304
-
305
- return True, rank, f"cuda:{rank}"
306
- else:
307
- print("Not using distributed")
308
- # will not change any of the behavior of print, but allows putting the force=True in the print calls
309
- print_on_master_only(True)
310
- return False, 0, device
311
-
312
-
313
- def check_compatibility(dl):
314
- if hasattr(dl, "num_outputs"):
315
- print(
316
- "`num_outputs` for the DataLoader is deprecated. It is assumed to be 1 from now on."
317
- )
318
- assert dl.num_outputs != 1, (
319
- "We assume num_outputs to be 1. Instead of the num_ouputs change your loss."
320
- "We specify the number of classes in the CE loss."
321
- )
322
-
323
-
324
- def pfn_normalize(
325
- lb=torch.tensor(float("-inf")),
326
- ub=torch.tensor(float("inf")),
327
- soft_lb=0.0,
328
- soft_ub=1.0,
329
- minimize=False,
330
- ):
331
- """
332
- LC-PFN curve prior assumes curves to be normalized within the range [0,1] and to be maximized.
333
- This function allows to normalize and denormalize data to fit this assumption.
334
-
335
- Parameters:
336
- lb (torch.Tensor): Lower bound of the data.
337
- ub (torch.Tensor): Upper bound of the data.
338
- soft_lb (float): Soft lower bound for normalization. Default is 0.0.
339
- soft_ub (float): Soft upper bound for normalization. Default is 1.0.
340
- minimize (bool): If True, the original curve is a minization. Default is False.
341
-
342
- Returns: Two functions for normalizing and denormalizing the data.
343
- """
344
- assert lb <= soft_lb and soft_lb < soft_ub and soft_ub <= ub
345
- # step 1: linearly transform [soft_lb,soft_ub] [-1,1] (where the sigmoid behaves approx linearly)
346
- # 2.0/(soft_ub - soft_lb)*(x - soft_lb) - 1.0
347
- # step 2: apply a vertically scaled/shifted the sigmoid such that [lb,ub] --> [0,1]
348
-
349
- def cinv(x):
350
- return 1 - x if minimize else x
351
-
352
- def lin_soft(x):
353
- return 2 / (soft_ub - soft_lb) * (x - soft_lb) - 1
354
-
355
- def lin_soft_inv(y):
356
- return (y + 1) / 2 * (soft_ub - soft_lb) + soft_lb
357
-
358
- try:
359
- if torch.exp(-lin_soft(lb)) > 1e300:
360
- raise RuntimeError
361
- # otherwise overflow causes issues, treat these cases as if the lower bound was -infinite
362
- # print(f"WARNING: {lb} --> NINF to avoid overflows ({np.exp(-lin_soft(lb))})")
363
- except RuntimeError:
364
- lb = torch.tensor(float("-inf"))
365
- if torch.isinf(lb) and torch.isinf(ub):
366
- return lambda x: cinv(
367
- 1 / (1 + torch.exp(-lin_soft(x)))
368
- ), lambda y: lin_soft_inv(torch.log(cinv(y) / (1 - cinv(y))))
369
- elif torch.isinf(lb):
370
- a = 1 + torch.exp(-lin_soft(ub))
371
- return lambda x: cinv(
372
- a / (1 + torch.exp(-lin_soft(x)))
373
- ), lambda y: lin_soft_inv(torch.log((cinv(y) / a) / (1 - (cinv(y) / a))))
374
- elif torch.isinf(ub):
375
- a = 1 / (1 - 1 / (1 + torch.exp(-lin_soft(lb))))
376
- b = 1 - a
377
- return lambda x: cinv(
378
- a / (1 + torch.exp(-lin_soft(x))) + b
379
- ), lambda y: lin_soft_inv(
380
- torch.log(((cinv(y) - b) / a) / (1 - ((cinv(y) - b) / a)))
381
- )
382
- else:
383
- a = (
384
- 1
385
- + torch.exp(-lin_soft(ub))
386
- + torch.exp(-lin_soft(lb))
387
- + torch.exp(-lin_soft(ub) - lin_soft(lb))
388
- ) / (torch.exp(-lin_soft(lb)) - torch.exp(-lin_soft(ub)))
389
- b = -a / (1 + torch.exp(-lin_soft(lb)))
390
- return lambda x: cinv(
391
- a / (1 + torch.exp(-lin_soft(x))) + b
392
- ), lambda y: lin_soft_inv(
393
- torch.log(((cinv(y) - b) / a) / (1 - ((cinv(y) - b) / a)))
394
- )
395
-
396
-
397
- def get_default_normalizer():
398
- default_normalizer_kwargs = {
399
- "lb": torch.tensor(0.0),
400
- "ub": torch.tensor(1.0),
401
- "soft_lb": 0.0,
402
- "soft_ub": 1.0,
403
- "minimize": False,
404
- }
405
- return pfn_normalize(**default_normalizer_kwargs)
406
-
407
-
408
- def identity_normalizer():
409
- return lambda x: x, lambda x: x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lcpfn/version.py DELETED
@@ -1 +0,0 @@
1
- __version__ = "0.1.3"
 
 
pyproject.toml DELETED
@@ -1,42 +0,0 @@
1
- [project]
2
- name = "lcpfn"
3
- description = "In-context Bayesian Learning Curve Extrapolation"
4
- readme = {file = "readme.md", content-type = 'text/markdown'}
5
- license = {file = "LICENSE"}
6
- authors = [
7
- {name = "Steven Adriaensen", email= "adriaens@cs.uni-freiburg.de"},
8
- {name = "Herilalaina Rakotoarison", email = "rakotoah@cs.uni-freiburg.de"},
9
- {name = "Samuel Müller", email = "muellesa@cs.uni-freiburg.de"},
10
- {name = "Frank Hutter", email = "fh@cs.uni-freiburg.de"},
11
- ]
12
- requires-python = ">=3.9,<3.12"
13
- dependencies = [
14
- "torch<=1.11.0",
15
- "numpy>=1.21.2,<2",
16
- "requests>=2.23.0"
17
- ]
18
- dynamic = ["version"]
19
- classifiers = [
20
- 'Intended Audience :: Science/Research',
21
- 'License :: OSI Approved :: MIT License',
22
- 'Programming Language :: Python',
23
- 'Topic :: Software Development',
24
- 'Topic :: Scientific/Engineering',
25
- 'Operating System :: Unix',
26
- 'Operating System :: MacOS',
27
- 'Programming Language :: Python :: 3',
28
- 'Programming Language :: Python :: 3.9',
29
- 'Programming Language :: Python :: 3.10',
30
- 'Programming Language :: Python :: 3.11',
31
- ]
32
-
33
- [project.urls]
34
- homepage = "https://github.com/automl/lcpfn"
35
- repository = "https://github.com/automl/lcpfn"
36
- bugtracker = "https://github.com/automl/lcpfn/issues"
37
-
38
- [tool.setuptools.packages.find]
39
- include = ["lcpfn*"]
40
-
41
- [tool.setuptools.dynamic]
42
- version = {attr = "lcpfn.version.__version__"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch<=1.11.0
2
+ numpy>=1.21.2,<2
3
+ lcpfn==0.1.3