Spaces:
Running
Running
herilalaina
committed on
Commit
•
328c052
1
Parent(s):
b62776c
update
Browse files- lcpfn/.ipynb_checkpoints/__init__-checkpoint.py +0 -53
- lcpfn/.ipynb_checkpoints/curves-checkpoint.py +0 -277
- lcpfn/.ipynb_checkpoints/domhan_prior-checkpoint.py +0 -195
- lcpfn/__init__.py +0 -80
- lcpfn/bar_distribution.py +0 -349
- lcpfn/curves.py +0 -277
- lcpfn/decoders.py +0 -42
- lcpfn/domhan_prior.py +0 -199
- lcpfn/encoders.py +0 -190
- lcpfn/initializers.py +0 -11
- lcpfn/layer.py +0 -179
- lcpfn/model.py +0 -56
- lcpfn/positional_encodings.py +0 -78
- lcpfn/priors/__init__.py +0 -1
- lcpfn/priors/binarized_regression.py +0 -19
- lcpfn/priors/fast_gp.py +0 -143
- lcpfn/priors/fast_gp_mix.py +0 -394
- lcpfn/priors/gp.py +0 -69
- lcpfn/priors/prior.py +0 -25
- lcpfn/priors/pyro.py +0 -41
- lcpfn/priors/ridge.py +0 -37
- lcpfn/priors/stroke.py +0 -143
- lcpfn/priors/utils.py +0 -151
- lcpfn/train.py +0 -336
- lcpfn/train_lcpfn.py +0 -96
- lcpfn/transformer.py +0 -348
- lcpfn/utils.py +0 -409
- lcpfn/version.py +0 -1
- pyproject.toml +0 -42
- requirements.txt +3 -0
lcpfn/.ipynb_checkpoints/__init__-checkpoint.py
DELETED
@@ -1,53 +0,0 @@
|
|
1 |
-
import os, sys
|
2 |
-
sys.path.insert(0, os.path.dirname(__file__))
|
3 |
-
|
4 |
-
|
5 |
-
model_path = 'trained_models'
|
6 |
-
|
7 |
-
def prepare_models():
    """Make sure the pretrained LC-PFN checkpoints exist next to this package.

    For each known checkpoint name the function
      1. downloads ``<name>.gz`` into ``<package dir>/<model_path>/`` when
         neither the checkpoint nor its compressed form is present,
      2. decompresses the ``.gz`` file (the compressed file is kept),
      3. prints whether the checkpoint could finally be located.

    Failures are reported with ``print`` and never raised, so one missing
    checkpoint does not prevent the others from being prepared.
    """
    # Local imports: only needed on the (rare) download/unzip path.
    import gzip
    import shutil

    pfns4bo_dir = os.path.dirname(__file__)
    model_names = ['pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt',
                   'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt']

    for name in model_names:
        weights_path = os.path.join(pfns4bo_dir, model_path, name)
        compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + '.gz')
        if not os.path.exists(weights_path):
            if not os.path.exists(compressed_weights_path):
                print("Downloading", os.path.abspath(compressed_weights_path))
                import requests

                url = f'https://github.com/automl/lcpfn/raw/main/lcpfn/trained_models/{name + ".gz"}'
                try:
                    r = requests.get(url, allow_redirects=True)
                    # Do not store an HTTP error page as if it were the archive.
                    r.raise_for_status()
                except requests.RequestException as err:
                    print("Download failed:", err)
                else:
                    os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
                    with open(compressed_weights_path, 'wb') as f:
                        f.write(r.content)
            if os.path.exists(compressed_weights_path):
                print("Unzipping", name)
                # Decompress in-process instead of shelling out to ``gzip -dk``:
                # works without a gzip binary and with paths containing spaces
                # or shell metacharacters. The .gz file is kept, as with ``-k``.
                try:
                    with gzip.open(compressed_weights_path, 'rb') as src, \
                            open(weights_path, 'wb') as dst:
                        shutil.copyfileobj(src, dst)
                except (OSError, EOFError) as err:
                    print("Failed to unzip", compressed_weights_path, "-", err)
                    # Remove a partially written checkpoint so that it is not
                    # mistaken for a valid one on the next call.
                    if os.path.exists(weights_path):
                        os.remove(weights_path)
            else:
                print("Failed to find", compressed_weights_path)
                print("Make sure you have an internet connection to download the model automatically..")
        if os.path.exists(weights_path):
            print("Successfully located model at", weights_path)
|
32 |
-
|
33 |
-
|
34 |
-
# Maps the public model identifiers to the on-disk checkpoint locations
# inside this package's model directory.
_weights_dir = os.path.join(os.path.dirname(__file__), model_path)
model_dict = dict(
    EMSIZE512_NLAYERS12_NBUCKETS1000=os.path.join(
        _weights_dir, 'pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt'
    ),
    EMSIZE512_NLAYERS6_NBUCKETS1000=os.path.join(
        _weights_dir, 'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt'
    ),
)
|
40 |
-
|
41 |
-
|
42 |
-
def __getattr__(name):
    """Module-level attribute hook that resolves model names to checkpoint paths.

    Looking up a key of ``model_dict`` as an attribute of this module returns
    the checkpoint path; if the file does not exist yet, ``prepare_models()``
    is called first. Any other name raises ``AttributeError``.
    """
    if name not in model_dict:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    checkpoint = model_dict[name]
    if not os.path.exists(checkpoint):
        print("Can't find", os.path.abspath(checkpoint), "thus unzipping/downloading models now.")
        print("This might take a while..")
        prepare_models()
    return checkpoint
|
50 |
-
|
51 |
-
from lcpfn.model import LCPFN
|
52 |
-
from lcpfn.train_lcpfn import train_lcpfn
|
53 |
-
from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/.ipynb_checkpoints/curves-checkpoint.py
DELETED
@@ -1,277 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
from collections import OrderedDict
|
3 |
-
|
4 |
-
def _uniform_spec(low, high):
    """Spec entry for a parameter drawn from a uniform distribution."""
    return {"type": "uniform", "param1": low, "param2": high}


def _log_normal_spec(mean, sigma):
    """Spec entry for a parameter drawn from a log-normal distribution."""
    return {"type": "log_normal", "param1": mean, "param2": sigma}


# Parameter priors per curve family. Each family has a "uniform" and a
# "peaked" variant; every entry is a spec consumed by ``prior_sampler``.
prior = {
    "pow3": {
        "uniform": OrderedDict(
            [
                ("a", _uniform_spec(-1, 1)),
                ("c", _uniform_spec(0, 1)),
                ("alpha", _uniform_spec(0, 1)),
            ]
        ),
        "peaked": OrderedDict(
            [
                ("a", _uniform_spec(-0.6, 0.6)),
                ("c", _uniform_spec(0, 1.25)),
                ("alpha", _log_normal_spec(0, 2)),
            ]
        ),
    },
    "ilog2": {
        "uniform": OrderedDict(
            [
                ("c", _uniform_spec(0, 1)),
                ("a", _uniform_spec(-1, 1)),
            ]
        ),
        "peaked": OrderedDict(
            [
                ("c", _uniform_spec(0, 1)),
                ("a", _uniform_spec(-0.5, 0.5)),
            ]
        ),
    },
    "janoschek": {
        "uniform": OrderedDict(
            [
                ("a", _uniform_spec(0, 1)),
                ("beta", _uniform_spec(0, 2)),
                ("k", _uniform_spec(0, 1)),
                ("delta", _uniform_spec(-5, 5)),
            ]
        ),
        "peaked": OrderedDict(
            [
                ("a", _uniform_spec(0, 1)),
                ("beta", _uniform_spec(0, 2)),
                ("k", _log_normal_spec(-2, 1)),
                ("delta", _log_normal_spec(0, 0.5)),
            ]
        ),
    },
}
|
42 |
-
|
43 |
-
|
44 |
-
def prior_sampler(rng, type, param1, param2):
    """Draw one value from the distribution described by a prior spec.

    :param rng: random source exposing ``uniform`` and ``lognormal``
        (e.g. ``numpy.random.RandomState`` or the ``numpy.random`` module).
    :param type: ``"uniform"`` or ``"log_normal"``. The name shadows the
        builtin but is kept because it is part of the call signature.
    :param param1: lower bound (uniform) or mean of the underlying normal
        (log-normal).
    :param param2: upper bound (uniform) or sigma of the underlying normal
        (log-normal).
    :return: the sampled value.
    :raises ValueError: if ``type`` is not one of the supported names.
    """
    if type == "uniform":
        return rng.uniform(param1, param2)
    if type == "log_normal":
        return rng.lognormal(param1, param2)
    # ValueError is still an ``Exception``, so existing handlers keep working.
    raise ValueError("Unknown prior type: {}".format(type))
|
50 |
-
|
51 |
-
|
52 |
-
def pow3(x, c, a, alpha):
    """Three-parameter power law ``c - a * x**(-alpha)``."""
    decay = (x) ** (-alpha)
    return c - a * decay
|
54 |
-
|
55 |
-
|
56 |
-
def prior_pow3(rng):
    """Sample pow3 parameters ``a``, ``c``, ``alpha`` (in that order) from the "peaked" prior."""
    table = prior["pow3"]["peaked"]
    draws = {}
    for key in ["a", "c", "alpha"]:
        spec = table[key]
        draws[key] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return draws
|
66 |
-
|
67 |
-
|
68 |
-
def uniform_prior_pow3(rng):
    """Sample pow3 parameters ``a``, ``c``, ``alpha`` (in that order) from the "uniform" prior."""
    table = prior["pow3"]["uniform"]
    draws = {}
    for key in ["a", "c", "alpha"]:
        spec = table[key]
        draws[key] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return draws
|
78 |
-
|
79 |
-
|
80 |
-
def ilog2(x, c, a):
    """Inverse-log curve ``c - a / log(x + 1)``."""
    log_term = np.log(x + 1)
    return c - a / log_term
|
82 |
-
|
83 |
-
|
84 |
-
def prior_ilog2(rng):
    """Sample ilog2 parameters ``a`` then ``c`` from the "peaked" prior."""
    table = prior["ilog2"]["peaked"]
    draws = {}
    for key in ["a", "c"]:
        spec = table[key]
        draws[key] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return draws
|
94 |
-
|
95 |
-
|
96 |
-
def uniform_prior_ilog2(rng):
    """Sample ilog2 parameters ``a`` then ``c`` from the "uniform" prior."""
    table = prior["ilog2"]["uniform"]
    draws = {}
    for key in ["a", "c"]:
        spec = table[key]
        draws[key] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return draws
|
106 |
-
|
107 |
-
|
108 |
-
def janoschek(x, a, beta, k, delta):
    """Janoschek growth curve ``a - (a - beta) * exp(-k * x**delta)``.

    http://www.pisces-conservation.com/growthhelp/janoschek.htm
    """
    damping = np.exp(-k * x**delta)
    return a - (a - beta) * damping
|
113 |
-
|
114 |
-
|
115 |
-
def prior_janoschek(rng):
    """Sample Janoschek parameters ``a``, ``beta``, ``k``, ``delta`` (in that order) from the "peaked" prior."""
    table = prior["janoschek"]["peaked"]
    draws = {}
    for key in ["a", "beta", "k", "delta"]:
        spec = table[key]
        draws[key] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return draws
|
125 |
-
|
126 |
-
|
127 |
-
def uniform_prior_janoschek(rng):
    """Sample Janoschek parameters ``a``, ``beta``, ``k``, ``delta`` (in that order) from the "uniform" prior."""
    table = prior["janoschek"]["uniform"]
    draws = {}
    for key in ["a", "beta", "k", "delta"]:
        spec = table[key]
        draws[key] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return draws
|
137 |
-
|
138 |
-
|
139 |
-
def log_power(x, a, b, c):
    """Log-power curve ``a / (1 + (x / e**b) ** c)``.

    a: upper bound
    c: growth rate
    """
    scaled = x / np.exp(b)
    return a / (1.0 + scaled**c)
|
144 |
-
|
145 |
-
|
146 |
-
def prior_log_power(rng):
    """Sample log_power parameters: a ~ N(0.8, 0.1), b ~ N(1, 1), c ~ U(-3, 0)."""
    upper = rng.normal(0.8, 0.1)
    shift = rng.normal(1.0, 1.0)
    rate = rng.uniform(-3.0, 0.0)
    return dict(a=upper, b=shift, c=rate)
|
154 |
-
|
155 |
-
|
156 |
-
def weibull(x, alpha, beta, kappa, delta):
    """Weibull growth model ``alpha - (alpha - beta) * exp(-(kappa * x) ** delta)``.

    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote
    beta: lower asymptote
    kappa: growth rate
    delta: controls the x-ordinate for the point of inflection
    """
    exponent = -((kappa * x) ** delta)
    return alpha - (alpha - beta) * np.exp(exponent)
|
166 |
-
|
167 |
-
|
168 |
-
def prior_weibull(rng):
    """Sample Weibull parameters.

    alpha ~ U(0, 1.5), beta ~ U(0, 1), ln(kappa) ~ N(-2, 1), ln(delta) ~ N(0, 0.5).
    """
    upper = rng.uniform(0.0, 1.5)
    lower = rng.uniform(0.0, 1)
    log_kappa = rng.normal(-2.0, 1.0)
    log_delta = rng.normal(0, 0.5)
    return dict(alpha=upper, beta=lower, kappa=np.exp(log_kappa), delta=np.exp(log_delta))
|
174 |
-
|
175 |
-
|
176 |
-
def mmf(x, alpha, beta, kappa, delta):
    """Morgan-Mercer-Flodin curve ``alpha - (alpha - beta) / (1 + (kappa * x) ** delta)``.

    description:
    Nonlinear Regression page 342
    http://bit.ly/1jodG17
    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote
    kappa: growth rate
    beta: initial value
    delta: controls the point of inflection
    """
    denominator = 1.0 + (kappa * x) ** delta
    return alpha - (alpha - beta) / denominator
|
189 |
-
|
190 |
-
|
191 |
-
def prior_mmf(rng):
    """Sample MMF parameters.

    alpha ~ N(0.8, 0.1), beta ~ N(0.2, 0.1), ln(kappa) ~ N(0, 2), ln(delta) ~ N(0, 1).
    """
    upper = rng.normal(0.8, 0.1)
    start = rng.normal(0.2, 0.1)
    log_kappa = rng.normal(0, 2)
    log_delta = rng.normal(0, 1)
    return dict(alpha=upper, beta=start, kappa=np.exp(log_kappa), delta=np.exp(log_delta))
|
201 |
-
|
202 |
-
|
203 |
-
def vap(x, a, b, c):
    """Vapor pressure model ``exp(a + b / x + c * log(x))``.

    no upper bound if c > 0
    a = ln(upper bound) for c=0
    a+b = ln(initial)
    """
    log_value = a + b / x + c * np.log(x)
    return np.exp(log_value)
|
209 |
-
|
210 |
-
|
211 |
-
def prior_vap(rng):
    """Sample vap parameters: a ~ U(-2, 0), b ~ U(-4, 0), ln(c) ~ U(-8, 0)."""
    offset = rng.uniform(-2.0, 0.0)  # @heri: range check
    inverse_coef = rng.uniform(-4.0, 0.0)  # @heri: range check
    log_c = rng.uniform(-8.0, 0.0)  # @heri: same as weights
    return dict(a=offset, b=inverse_coef, c=np.exp(log_c))
|
216 |
-
|
217 |
-
|
218 |
-
def loglog_linear(x, a, b):
    """Log of a linear function of ``log(x)``: ``log(a * log(x) + b)``."""
    log_x = np.log(x)
    return np.log(a * log_x + b)
|
221 |
-
|
222 |
-
|
223 |
-
def prior_loglog_linear(rng):
    """Sample loglog_linear parameters: ln(a) ~ N(-2, 1), ln(b) ~ U(0, 1)."""
    log_a = rng.normal(-2.0, 1.0)
    slope = np.exp(log_a)
    log_b = rng.uniform(0.0, 1.0)
    intercept = np.exp(log_b)
    return dict(a=slope, b=intercept)
|
229 |
-
|
230 |
-
|
231 |
-
def exp4(x, c, a, b, alpha):
    """Four-parameter exponential curve ``c - exp(-a * x**alpha + b)``."""
    exponent = -a * (x**alpha) + b
    return c - np.exp(exponent)
|
233 |
-
|
234 |
-
|
235 |
-
def prior_exp4(rng):
    """Sample exp4 parameters.

    Draw order is c, a, alpha, b:
    c ~ N(0.8, 0.1), ln(a) ~ N(-2, 1), ln(alpha) ~ N(0, 1), ln(b) ~ N(0, 0.5).
    """
    asymptote = rng.normal(0.8, 0.1)
    log_a = rng.normal(-2, 1)
    log_alpha = rng.normal(0, 1)
    log_b = rng.normal(0, 0.5)
    return dict(a=np.exp(log_a), b=np.exp(log_b), c=asymptote, alpha=np.exp(log_alpha))
|
245 |
-
|
246 |
-
|
247 |
-
def pow4(x, c, a, b, alpha):
    """Four-parameter power law ``c - (a * x + b) ** (-alpha)``."""
    base = a * x + b
    return c - base ** -alpha
|
249 |
-
|
250 |
-
|
251 |
-
def prior_pow4(rng):
    """Sample pow4 parameters.

    Draw order is c, a, alpha, b:
    ln(1 - c) ~ U(-5, 0), ln(a) ~ N(-3, 2), ln(alpha) ~ N(0, 1), ln(b) ~ U(0, 1).
    """
    log_gap = rng.uniform(-5.0, 0)
    asymptote = 1 - np.exp(log_gap)
    log_a = rng.normal(-3.0, 2)
    log_alpha = rng.normal(0, 1)
    log_b = rng.uniform(0, 1)
    return dict(a=np.exp(log_a), b=np.exp(log_b), c=asymptote, alpha=np.exp(log_alpha))
|
261 |
-
|
262 |
-
|
263 |
-
def dr_hill_zero_background(x, theta, eta, kappa):
    """Hill curve ``theta * x**eta / (kappa**eta + x**eta)``.

    theta: upper bound
    eta: growth rate
    initial = theta/(kappa^eta + 1)
    """
    numerator = theta * x**eta
    denominator = kappa**eta + x**eta
    return (numerator) / (denominator)
|
268 |
-
|
269 |
-
|
270 |
-
def prior_dr_hill_zero_background(rng):
    """Sample Hill-curve parameters: theta ~ N(0.8, 0.1), ln(eta) ~ N(1, 1), ln(kappa) ~ N(1, 2)."""
    upper = rng.normal(0.8, 0.1)
    log_eta = rng.normal(1.0, 1.0)
    log_kappa = rng.normal(1.0, 2.0)
    return dict(theta=upper, eta=np.exp(log_eta), kappa=np.exp(log_kappa))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/.ipynb_checkpoints/domhan_prior-checkpoint.py
DELETED
@@ -1,195 +0,0 @@
|
|
1 |
-
from functools import partial
|
2 |
-
import torch
|
3 |
-
import numpy as np
|
4 |
-
from lcpfn.curves import (
|
5 |
-
pow3,
|
6 |
-
ilog2,
|
7 |
-
janoschek,
|
8 |
-
log_power,
|
9 |
-
prior_ilog2,
|
10 |
-
uniform_prior_pow3,
|
11 |
-
weibull,
|
12 |
-
mmf,
|
13 |
-
vap,
|
14 |
-
loglog_linear,
|
15 |
-
exp4,
|
16 |
-
pow4,
|
17 |
-
dr_hill_zero_background,
|
18 |
-
)
|
19 |
-
from lcpfn.curves import (
|
20 |
-
prior_pow3,
|
21 |
-
prior_janoschek,
|
22 |
-
prior_log_power,
|
23 |
-
prior_weibull,
|
24 |
-
prior_mmf,
|
25 |
-
prior_vap,
|
26 |
-
prior_loglog_linear,
|
27 |
-
prior_exp4,
|
28 |
-
prior_pow4,
|
29 |
-
prior_dr_hill_zero_background,
|
30 |
-
)
|
31 |
-
from lcpfn.curves import (
|
32 |
-
uniform_prior_pow3,
|
33 |
-
uniform_prior_ilog2,
|
34 |
-
uniform_prior_janoschek,
|
35 |
-
)
|
36 |
-
|
37 |
-
|
38 |
-
def prior_weights(
    rng,
    components=(
        "pow3",
        "ilog2",
        "janoschek",
        "log_power",
        "weibull",
        "mmf",
        "vap",
        "loglog_linear",
        "exp4",
        "pow4",
        "dr_hill_zero_background",
    ),
):
    """Draw an independent U(0, 1) weight for every curve component.

    :param rng: random source exposing ``uniform(low, high, size=...)``.
    :param components: sequence of component names; the default is an
        immutable tuple so that it cannot be modified between calls.
    :return: dict mapping each component name to its weight, in the order
        of ``components``.
    """
    K = len(components)
    weights = rng.uniform(0.0, 1, size=(K,))
    return dict(zip(components, weights))
|
57 |
-
|
58 |
-
|
59 |
-
def sample_from_prior(rng, seq_len=100):
    """Sample a curve from the pow3 + ilog2 + janoschek mixture with "peaked" priors.

    Thin wrapper around ``sample_prior_comb``; returns whatever that returns.
    """
    basis = ["pow3", "ilog2", "janoschek"]
    return sample_prior_comb(
        rng=rng, seq_len=seq_len, components=basis, distribution="peaked"
    )
|
63 |
-
|
64 |
-
|
65 |
-
def sample_prior_comb(
    rng,
    components,
    distribution,
    var_lnloc=-4,
    var_lnscale=1,
    range_constraint=True,
    seq_len=100,
):
    """Sample a learning curve as a weighted combination of parametric curves.

    A candidate curve is built on ``x = 1..seq_len`` by drawing a weight per
    component (``prior_weights``) and parameters per component (the matching
    prior function) and summing the weighted component curves. Candidates
    are rejected and redrawn until one is accepted.

    :param rng: random source exposing ``uniform``/``normal``/``lognormal``.
    :param components: names of the curve families to combine.
    :param distribution: ``"peaked"`` or ``"uniform"``; selects the parameter
        priors. ``"uniform"`` priors exist only for pow3, ilog2 and janoschek.
    :param var_lnloc: mean of the normal from which the log of the noise
        level is drawn.
    :param var_lnscale: standard deviation of that normal.
    :param range_constraint: if true, reject curves leaving ``[0, 1]``.
    :param seq_len: number of curve points.
    :return: a zero-argument function returning ``(y, y_noisy)``; each call
        adds fresh independent noise to the same noiseless curve ``y``.
    :raises NotImplementedError: for an unknown ``distribution``.
    """
    f_components = {
        "pow3": pow3,
        "ilog2": ilog2,
        "janoschek": janoschek,
        "log_power": log_power,
        "weibull": weibull,
        "mmf": mmf,
        "vap": vap,
        "loglog_linear": loglog_linear,
        "exp4": exp4,
        "pow4": pow4,
        "dr_hill_zero_background": dr_hill_zero_background,
    }

    if distribution == "peaked":
        f_priors = {
            "pow3": prior_pow3,
            "ilog2": prior_ilog2,
            "janoschek": prior_janoschek,
            "log_power": prior_log_power,
            "weibull": prior_weibull,
            "mmf": prior_mmf,
            "vap": prior_vap,
            "loglog_linear": prior_loglog_linear,
            "exp4": prior_exp4,
            "pow4": prior_pow4,
            "dr_hill_zero_background": prior_dr_hill_zero_background,
        }
    elif distribution == "uniform":
        f_priors = {
            "pow3": uniform_prior_pow3,
            "ilog2": uniform_prior_ilog2,
            "janoschek": uniform_prior_janoschek,
        }
    else:
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # calling it raised a confusing TypeError instead.
        raise NotImplementedError(
            f"Unknown prior distribution: {distribution!r}"
        )

    x = np.arange(1, seq_len + 1)

    while True:
        # sample the noiseless curve
        weights = prior_weights(rng, components=components)
        y = np.zeros(x.shape, dtype="float")
        for f, w in weights.items():
            kwargs = f_priors[f](rng)
            y += w * f_components[f](x, **kwargs)
        # add noise (can exceed [0,1], but afaik no way to implement this prior in Tobis work)
        # Drawn inside the loop, so rejected candidates also consume a draw.
        var = np.exp(
            rng.normal(var_lnloc, var_lnscale)
        )  # @heri: ln_prob =+ log(normal.pdf(log(var), loc=var_lnloc, scale=var_lnscale))

        # reject any curves that are non-increasing, exceed the [0,1] range
        # or contain NaNs
        rejected = (
            y[-1] <= y[0]
            or (range_constraint and (np.any(y < 0) or np.any(y > 1)))
            or np.isnan(y).any()
        )
        if not rejected:
            break

    def curve():  # generates a sample from the same model, but with independent noise
        # ``var`` is passed as the second positional argument of ``rng.normal``.
        y_noisy = y + rng.normal(np.zeros_like(y), var)
        return y, y_noisy

    return curve
|
142 |
-
|
143 |
-
|
144 |
-
def generate_prior_dataset(n, prior=sample_prior_comb, seed=42):
    """
    Returns a fixed sample from the prior (with fixed seq_len) as an n x seq_len np.ndarray

    Each row is the noisy curve (second element) returned by ``prior(rng)()``,
    using a ``RandomState`` seeded with ``seed``.
    """
    # NOTE(review): ``prior`` is called as ``prior(rng)``; the default
    # ``sample_prior_comb`` also requires ``components`` and ``distribution``,
    # so the default looks unusable as-is - confirm the intended default.
    rng = np.random.RandomState(seed)
    rows = []
    for _ in range(n):
        make_curve = prior(rng)
        rows.append(make_curve()[1])
    return np.stack(rows)
|
151 |
-
|
152 |
-
|
153 |
-
def create_get_batch_func(prior):
    """Return ``get_batch_domhan`` with its ``prior`` argument pre-bound."""
    batch_fn = partial(get_batch_domhan, prior=prior)
    return batch_fn
|
155 |
-
|
156 |
-
# function producing batches for PFN training
|
157 |
-
def get_batch_domhan(
    batch_size,
    seq_len,
    num_features,
    prior,
    device="cpu",
    noisy_target=True,
    **_,
):
    """Produce one training batch of learning curves sampled from ``prior``.

    :param batch_size: number of curves in the batch.
    :param seq_len: number of points per curve.
    :param num_features: must be 1.
    :param prior: callable ``prior(rng, seq_len=...)`` returning a
        zero-argument function that yields ``(y, y_noisy)``.
    :param device: torch device the returned tensors are moved to.
    :param noisy_target: if true, the target equals the noisy curve;
        otherwise the target is the noiseless curve.
    :param _: extra keyword arguments are accepted and ignored.
    :return: ``(x, y_noisy, y_target)`` as float32 tensors of shape
        ``(seq_len, batch_size, num_features)``, ``(seq_len, batch_size)``
        and ``(seq_len, batch_size)``; ``x`` holds the positions 1..seq_len.
    """
    assert num_features == 1

    y_target = np.empty((batch_size, seq_len), dtype=float)
    y_noisy = np.empty((batch_size, seq_len), dtype=float)

    for i in range(batch_size):
        curve_func = prior(np.random, seq_len=seq_len)  # uses numpy rng
        if noisy_target:
            _, y_noisy[i] = curve_func()
            y_target[i] = y_noisy[i]
        else:
            y_target[i], y_noisy[i] = curve_func()

    # turn numpy arrays into correctly shaped torch tensors & move them to device
    x = (
        torch.arange(1, seq_len + 1)
        .repeat((num_features, batch_size, 1))
        .transpose(2, 0)
        .to(device)
    )
    y_target = torch.from_numpy(y_target).transpose(1, 0).to(device)
    y_noisy = torch.from_numpy(y_noisy).transpose(1, 0).to(device)

    # cast everything to float32 (numpy produced float64, arange int64)
    return x.float(), y_noisy.float(), y_target.float()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/__init__.py
DELETED
@@ -1,80 +0,0 @@
|
|
1 |
-
import os, sys
|
2 |
-
|
3 |
-
sys.path.insert(0, os.path.dirname(__file__))
|
4 |
-
|
5 |
-
|
6 |
-
model_path = "trained_models"
|
7 |
-
|
8 |
-
|
9 |
-
def prepare_models():
    """Make sure the pretrained LC-PFN checkpoints exist next to this package.

    For each known checkpoint name the function
      1. downloads ``<name>.gz`` into ``<package dir>/<model_path>/`` when
         neither the checkpoint nor its compressed form is present,
      2. decompresses the ``.gz`` file (the compressed file is kept),
      3. prints whether the checkpoint could finally be located.

    Failures are reported with ``print`` and never raised, so one missing
    checkpoint does not prevent the others from being prepared.
    """
    # Local imports: only needed on the (rare) download/unzip path.
    import gzip
    import shutil

    pfns4bo_dir = os.path.dirname(__file__)
    model_names = [
        "pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt",
        "pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt",
    ]

    for name in model_names:
        weights_path = os.path.join(pfns4bo_dir, model_path, name)
        compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + ".gz")
        if not os.path.exists(weights_path):
            if not os.path.exists(compressed_weights_path):
                print("Downloading", os.path.abspath(compressed_weights_path))
                import requests

                url = f'https://ml.informatik.uni-freiburg.de/research-artifacts/lcpfn/{name + ".gz"}'
                try:
                    r = requests.get(url, allow_redirects=True)
                    # Do not store an HTTP error page as if it were the archive.
                    r.raise_for_status()
                except requests.RequestException as err:
                    print("Download failed:", err)
                else:
                    os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
                    with open(compressed_weights_path, "wb") as f:
                        f.write(r.content)
            if os.path.exists(compressed_weights_path):
                print("Unzipping", name)
                # Decompress in-process instead of shelling out to ``gzip -dk``:
                # works without a gzip binary and with paths containing spaces
                # or shell metacharacters. The .gz file is kept, as with ``-k``.
                try:
                    with gzip.open(compressed_weights_path, "rb") as src, \
                            open(weights_path, "wb") as dst:
                        shutil.copyfileobj(src, dst)
                except (OSError, EOFError) as err:
                    print("Failed to unzip", compressed_weights_path, "-", err)
                    # Remove a partially written checkpoint so that it is not
                    # mistaken for a valid one on the next call.
                    if os.path.exists(weights_path):
                        os.remove(weights_path)
            else:
                print("Failed to find", compressed_weights_path)
                print(
                    "Make sure you have an internet connection to download the model automatically.."
                )
        if os.path.exists(weights_path):
            print("Successfully located model at", weights_path)
|
39 |
-
|
40 |
-
|
41 |
-
# Maps the public model identifiers to the on-disk checkpoint locations
# inside this package's model directory.
_weights_dir = os.path.join(os.path.dirname(__file__), model_path)
model_dict = dict(
    EMSIZE512_NLAYERS12_NBUCKETS1000=os.path.join(
        _weights_dir, "pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt"
    ),
    EMSIZE512_NLAYERS6_NBUCKETS1000=os.path.join(
        _weights_dir, "pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt"
    ),
)
|
53 |
-
|
54 |
-
|
55 |
-
def __getattr__(name):
    """Module-level attribute hook that resolves model names to checkpoint paths.

    Looking up a key of ``model_dict`` as an attribute of this module returns
    the checkpoint path; if the file does not exist yet, ``prepare_models()``
    is called first. Any other name raises ``AttributeError``.
    """
    if name not in model_dict:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    checkpoint = model_dict[name]
    if not os.path.exists(checkpoint):
        print(
            "Can't find",
            os.path.abspath(checkpoint),
            "thus unzipping/downloading models now.",
        )
        print("This might take a while..")
        prepare_models()
    return checkpoint
|
67 |
-
|
68 |
-
|
69 |
-
from .version import __version__
|
70 |
-
from lcpfn.model import LCPFN
|
71 |
-
from lcpfn.train_lcpfn import train_lcpfn
|
72 |
-
from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func
|
73 |
-
|
74 |
-
__all__ = [
|
75 |
-
"LCPFN",
|
76 |
-
"train_lcpfn",
|
77 |
-
"sample_from_prior",
|
78 |
-
"create_get_batch_func",
|
79 |
-
"__version__",
|
80 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/bar_distribution.py
DELETED
@@ -1,349 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import nn
|
3 |
-
|
4 |
-
|
5 |
-
class BarDistribution(nn.Module):
    """Piecewise-constant ("bar") density on a bounded 1-D range.

    The range is cut into ``num_bars`` buckets by the sorted ``borders``.
    A logits vector of length ``num_bars`` defines, via softmax, the
    probability mass of each bucket; inside a bucket the density is constant.
    """

    def __init__(
        self, borders: torch.Tensor, smoothing=0.0
    ):  # here borders should start with min and end with max, where all values lie in (min,max) and are sorted
        """
        :param borders: 1-D sorted tensor of bucket borders (min first, max last).
        :param smoothing: weight of the smoothing term mixed into the
            training loss (see ``forward``).
        """
        super().__init__()
        assert len(borders.shape) == 1
        # Buffers (not parameters) so that they follow the module across
        # devices and are part of the state dict.
        self.register_buffer("borders", borders)
        self.register_buffer("smoothing", torch.tensor(smoothing))
        self.register_buffer("bucket_widths", self.borders[1:] - self.borders[:-1])
        full_width = self.bucket_widths.sum()
        border_order = torch.argsort(borders)
        assert (
            full_width - (self.borders[-1] - self.borders[0])
        ).abs() < 1e-4, f"diff: {full_width - (self.borders[-1] - self.borders[0])}"
        assert (
            border_order == torch.arange(len(borders)).to(border_order.device)
        ).all(), "Please provide sorted borders!"
        self.num_bars = len(borders) - 1

    def map_to_bucket_idx(self, y):
        """Return, for every value in ``y``, the index of the bucket containing it.

        Values equal to the first/last border are assigned to the first/last
        bucket. Values outside the borders yield ``-1`` or ``num_bars``.
        """
        target_sample = torch.searchsorted(self.borders, y) - 1
        target_sample[y == self.borders[0]] = 0
        target_sample[y == self.borders[-1]] = self.num_bars - 1
        return target_sample

    def forward(
        self, logits, y
    ):  # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
        """Negative log density of ``y`` under the distribution given by ``logits``.

        In training mode the result is mixed with a smoothing term (the mean
        negative scaled log probability over all buckets), weighted by
        ``self.smoothing``.
        """
        target_sample = self.map_to_bucket_idx(y)
        assert (target_sample >= 0).all() and (
            target_sample < self.num_bars
        ).all(), f"y {y} not in support set for borders (min_y, max_y) {self.borders}"
        assert (
            logits.shape[-1] == self.num_bars
        ), f"{logits.shape[-1]} vs {self.num_bars}"

        bucket_log_probs = torch.log_softmax(logits, -1)
        # Divide the bucket mass by the bucket width to get a density.
        scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)

        nll_loss = -scaled_bucket_log_probs.gather(
            -1, target_sample.unsqueeze(-1)
        ).squeeze(-1)

        smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
        smoothing = self.smoothing if self.training else 0.0
        loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss
        return loss

    def mean(self, logits):
        """Expected value: probability-weighted average of the bucket centres."""
        bucket_means = self.borders[:-1] + self.bucket_widths / 2
        p = torch.softmax(logits, -1)
        return p @ bucket_means

    def icdf(self, logits, left_prob):
        """
        Implementation of the quantile function
        :param logits: Tensor of any shape, with the last dimension being logits
        :param left_prob: float: The probability mass to the left of the result.
        :return: Position with `left_prob` probability weight to the left.
        """
        probs = logits.softmax(-1)
        cumprobs = torch.cumsum(probs, -1)
        idx = (
            torch.searchsorted(
                cumprobs,
                left_prob * torch.ones(*cumprobs.shape[:-1], 1, device=probs.device),
            )
            .squeeze(-1)
            .clamp(0, cumprobs.shape[-1] - 1)
        )  # this might not do the right for outliers
        # Prepend 0 so that cumprobs[idx] is the mass strictly left of bucket idx.
        cumprobs = torch.cat(
            [torch.zeros(*cumprobs.shape[:-1], 1, device=logits.device), cumprobs], -1
        )

        rest_prob = left_prob - cumprobs.gather(-1, idx[..., None]).squeeze(-1)
        left_border = self.borders[idx]
        right_border = self.borders[idx + 1]
        # Linear interpolation inside the bucket (density is constant there).
        return left_border + (right_border - left_border) * rest_prob / probs.gather(
            -1, idx[..., None]
        ).squeeze(-1)

    def quantile(self, logits, center_prob=0.682):
        """Return the central interval holding ``center_prob`` mass, stacked on the last dim."""
        side_probs = (1.0 - center_prob) / 2
        return torch.stack(
            (self.icdf(logits, side_probs), self.icdf(logits, 1.0 - side_probs)), -1
        )

    def ucb(self, logits, best_f, rest_prob=(1 - 0.682) / 2, maximize=True):
        """
        UCB utility. Rest Prob is the amount of utility above (below) the confidence interval that is ignored.
        Higher rest_prob is equivalent to lower beta in the standard GP-UCB formulation.
        :param logits: Logits, as returned by the Transformer.
        :param best_f: Only here, since the other utilities have it.
        :param rest_prob: The amount of utility above (below) the confidence interval that is ignored.
        The default is equivalent to using GP-UCB with `beta=1`.
        To get the corresponding `beta`, where `beta` is from
        the standard GP definition of UCB `ucb_utility = mean + beta * std`,
        you can use this computation: `beta = math.sqrt(2)*torch.erfinv(torch.tensor(2*rest_prob-1))`.
        :param maximize: if true, the upper quantile (``1 - rest_prob``) is used.
        :return: utility
        """
        if maximize:
            rest_prob = 1 - rest_prob
        return self.icdf(logits, rest_prob)

    def mode(self, logits):
        """Centre of the bucket with the largest logit."""
        mode_inds = logits.argmax(-1)
        bucket_means = self.borders[:-1] + self.bucket_widths / 2
        return bucket_means[mode_inds]

    def ei(
        self, logits, best_f, maximize=True
    ):  # logits: evaluation_points x batch x feature_dim
        """Acquisition function: expected-improvement style utility.

        Every bucket contributes the distance between ``best_f`` and the
        midpoint of the part of the bucket that improves on ``best_f``
        (zero if no part does); contributions are weighted by bucket mass.

        :param logits: as returned by the Transformer
        :param best_f: best evaluation so far (the incumbent), a scalar
        :param maximize: whether larger values are improvements
        :return: utility
        """
        lower = self.borders[:-1]
        upper = self.borders[1:]
        best = torch.as_tensor(best_f, dtype=lower.dtype, device=lower.device)
        # Vectorised over buckets instead of a Python loop of 0-d tensor ops.
        if maximize:
            bucket_contributions = (
                (upper + torch.maximum(lower, best)) / 2 - best
            ).clamp(min=0)
        else:
            # min on max instead of max on min, and compare min < instead of max >
            bucket_contributions = -(
                (torch.minimum(upper, best) + lower) / 2 - best
            ).clamp(max=0)
        bucket_contributions = bucket_contributions.to(
            dtype=logits.dtype, device=logits.device
        )
        p = torch.softmax(logits, -1)
        return p @ bucket_contributions

    def pi(
        self, logits, best_f, maximize=True
    ):  # logits: evaluation_points x batch x feature_dim
        """
        Acquisition Function: Probability of Improvement
        :param logits: as returned by Transformer
        :param best_f: best evaluation so far (the incumbent)
        :param maximize: whether to maximize (only ``True`` is supported)
        :return: utility
        """
        assert maximize is True
        p = torch.softmax(logits, -1)
        # Fraction of each bucket that lies above best_f.
        factor = 1.0 - ((best_f - self.borders[:-1]) / self.bucket_widths).clamp(
            0.0, 1.0
        )
        return (p * factor).sum(-1)

    def mean_of_square(self, logits):
        """
        Computes E[x^2].
        :param logits: Output of the model.
        """
        left_borders = self.borders[:-1]
        right_borders = self.borders[1:]
        # E[x^2] of a uniform distribution on [l, r] is (l^2 + l*r + r^2) / 3.
        bucket_mean_of_square = (
            left_borders.square()
            + right_borders.square()
            + left_borders * right_borders
        ) / 3.0
        p = torch.softmax(logits, -1)
        return p @ bucket_mean_of_square

    def variance(self, logits):
        """Variance, computed as ``E[x^2] - E[x]^2``."""
        return self.mean_of_square(logits) - self.mean(logits).square()
|
181 |
-
|
182 |
-
|
183 |
-
class FullSupportBarDistribution(BarDistribution):
    """
    Bar distribution whose first and last bucket use half-normal tails.

    Relies on ``borders``, ``bucket_widths``, ``num_bars``, ``smoothing`` and
    ``map_to_bucket_idx`` provided by ``BarDistribution``.
    """

    @staticmethod
    def halfnormal_with_p_weight_before(range_max, p=0.5):
        """Return a ``HalfNormal`` whose CDF evaluated at ``range_max`` is ``p``."""
        unit_quantile = torch.distributions.HalfNormal(torch.tensor(1.0)).icdf(
            torch.tensor(p)
        )
        return torch.distributions.HalfNormal(range_max / unit_quantile)

    def forward(
        self, logits, y
    ):  # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
        """
        Negative log density of ``y`` under the distribution given by ``logits``.

        :param logits: T x B x self.num_bars
        :param y: T x B targets
        :return: loss with the shape of ``y``
        """
        assert self.num_bars > 1
        target_sample = self.map_to_bucket_idx(y)
        # Targets outside the borders are assigned to the outermost buckets.
        target_sample.clamp_(0, self.num_bars - 1)
        assert logits.shape[-1] == self.num_bars

        bucket_log_probs = torch.log_softmax(logits, -1)
        scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
        log_probs = scaled_bucket_log_probs.gather(
            -1, target_sample.unsqueeze(-1)
        ).squeeze(-1)

        side_normals = (
            self.halfnormal_with_p_weight_before(self.bucket_widths[0]),
            self.halfnormal_with_p_weight_before(self.bucket_widths[-1]),
        )

        in_first = target_sample == 0
        in_last = target_sample == self.num_bars - 1

        # For the two side buckets, adding log(width) cancels the uniform
        # density scaling applied above, so the result is
        # log(bucket probability) + half-normal log density.
        log_probs[in_first] += side_normals[0].log_prob(
            (self.borders[1] - y[in_first]).clamp(min=0.00000001)
        ) + torch.log(self.bucket_widths[0])
        # The distance is clamped exactly like on the left side, so a target
        # sitting on the border never produces a zero/negative argument.
        log_probs[in_last] += side_normals[1].log_prob(
            (y[in_last] - self.borders[-2]).clamp(min=0.00000001)
        ) + torch.log(self.bucket_widths[-1])

        nll_loss = -log_probs

        # Label smoothing is only active in training mode.
        smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
        smoothing = self.smoothing if self.training else 0.0
        loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss

        return loss

    def mean(self, logits):
        """
        Return the mean of the distribution given by ``logits``.

        Inner buckets use their midpoint; the two side buckets use the mean of
        their half-normal tail, measured from the inner border.
        """
        bucket_means = self.borders[:-1] + self.bucket_widths / 2
        p = torch.softmax(logits, -1)
        side_normals = (
            self.halfnormal_with_p_weight_before(self.bucket_widths[0]),
            self.halfnormal_with_p_weight_before(self.bucket_widths[-1]),
        )
        bucket_means[0] = -side_normals[0].mean + self.borders[1]
        bucket_means[-1] = side_normals[1].mean + self.borders[-2]
        return p @ bucket_means
|
237 |
-
|
238 |
-
|
239 |
-
def get_bucket_limits_(
    num_outputs: int,
    full_range: tuple = None,
    ys: torch.Tensor = None,
    verbose: bool = False,
):
    """
    Compute ``num_outputs + 1`` bucket borders.

    With ``ys`` the borders are placed so that every bucket holds the same
    number of the given values (surplus values at the end of the flattened
    tensor are dropped). Without ``ys`` the range ``full_range`` is split into
    equally wide buckets.

    :param num_outputs: number of buckets
    :param full_range: ``(lower, upper)`` outer borders; derived from ``ys``
        when omitted
    :param ys: values used to estimate the borders
    :param verbose: print additional information
    :return: 1-D tensor with ``num_outputs + 1`` borders
    """
    assert (ys is not None) or (full_range is not None)
    if ys is not None:
        ys = ys.flatten()
        # Remember how many values are dropped *before* truncating, so the
        # message below reports the real number instead of always 0.
        num_cut = len(ys) % num_outputs
        if num_cut:
            ys = ys[:-num_cut]
        usage_msg = (
            f"Using {len(ys)} y evals to estimate {num_outputs} buckets. "
            f"Cut off the last {num_cut} ys."
        )
        print(usage_msg)
        ys_per_bucket = len(ys) // num_outputs
        if full_range is None:
            full_range = (ys.min(), ys.max())
        else:
            assert full_range[0] <= ys.min() and full_range[1] >= ys.max()
        full_range = torch.tensor(full_range)
        ys_sorted, _ = ys.sort(0)
        # Inner borders: midpoint between the last value of one bucket and
        # the first value of the next.
        bucket_limits = (
            ys_sorted[ys_per_bucket - 1 :: ys_per_bucket][:-1]
            + ys_sorted[ys_per_bucket::ys_per_bucket]
        ) / 2
        if verbose:
            print(usage_msg)
            print(full_range)
        bucket_limits = torch.cat(
            [full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0
        )

    else:
        class_width = (full_range[1] - full_range[0]) / num_outputs
        bucket_limits = torch.cat(
            [
                full_range[0] + torch.arange(num_outputs).float() * class_width,
                torch.tensor(full_range[1]).unsqueeze(0),
            ],
            0,
        )

    assert (
        len(bucket_limits) - 1 == num_outputs
        and full_range[0] == bucket_limits[0]
        and full_range[-1] == bucket_limits[-1]
    )
    return bucket_limits
|
289 |
-
|
290 |
-
|
291 |
-
def get_bucket_limits(
    num_outputs: int,
    full_range: tuple = None,
    ys: torch.Tensor = None,
    verbose: bool = False,
):
    """
    Compute ``num_outputs + 1`` bucket borders from either ``ys`` or ``full_range``.

    With ``ys`` (NaNs are ignored) the borders are placed so that every bucket
    holds the same number of the given values; surplus values at the end of
    the flattened tensor are dropped. With ``full_range`` the range is split
    into equally wide buckets.

    :param num_outputs: number of buckets
    :param full_range: ``(lower, upper)`` outer borders
    :param ys: values used to estimate the borders
    :param verbose: print additional information
    :return: 1-D tensor with ``num_outputs + 1`` borders
    """
    assert (ys is None) != (
        full_range is None
    ), "Either full_range or ys must be passed."

    if ys is not None:
        ys = ys.flatten()
        ys = ys[~torch.isnan(ys)]
        # Remember how many values are dropped *before* truncating, so the
        # message below reports the real number instead of always 0.
        num_cut = len(ys) % num_outputs
        if num_cut:
            ys = ys[:-num_cut]
        usage_msg = (
            f"Using {len(ys)} y evals to estimate {num_outputs} buckets. "
            f"Cut off the last {num_cut} ys."
        )
        print(usage_msg)
        ys_per_bucket = len(ys) // num_outputs
        if full_range is None:
            full_range = (ys.min(), ys.max())
        else:
            assert (
                full_range[0] <= ys.min() and full_range[1] >= ys.max()
            ), f"full_range {full_range} not in range of ys {ys.min(), ys.max()}"
        full_range = torch.tensor(full_range)
        ys_sorted, _ = ys.sort(0)
        # Inner borders: midpoint between the last value of one bucket and
        # the first value of the next.
        bucket_limits = (
            ys_sorted[ys_per_bucket - 1 :: ys_per_bucket][:-1]
            + ys_sorted[ys_per_bucket::ys_per_bucket]
        ) / 2
        if verbose:
            print(usage_msg)
            print(full_range)
        bucket_limits = torch.cat(
            [full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0
        )

    else:
        class_width = (full_range[1] - full_range[0]) / num_outputs
        bucket_limits = torch.cat(
            [
                full_range[0] + torch.arange(num_outputs).float() * class_width,
                torch.tensor(full_range[1]).unsqueeze(0),
            ],
            0,
        )

    assert (
        len(bucket_limits) - 1 == num_outputs
    ), f"len(bucket_limits) - 1 == {len(bucket_limits) - 1} != {num_outputs} == num_outputs"
    assert full_range[0] == bucket_limits[0], f"{full_range[0]} != {bucket_limits[0]}"
    assert (
        full_range[-1] == bucket_limits[-1]
    ), f"{full_range[-1]} != {bucket_limits[-1]}"

    return bucket_limits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/curves.py
DELETED
@@ -1,277 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
from collections import OrderedDict
|
3 |
-
|
4 |
-
# Parameter priors per curve family. Every family offers a "uniform" and a
# "peaked" variant; each parameter is described by the sampler type and its
# two arguments (see prior_sampler). The key order inside each OrderedDict is
# part of the data.
prior = {
    "pow3": {
        "uniform": OrderedDict(
            [
                ("a", dict(type="uniform", param1=-1, param2=1)),
                ("c", dict(type="uniform", param1=0, param2=1)),
                ("alpha", dict(type="uniform", param1=0, param2=1)),
            ]
        ),
        "peaked": OrderedDict(
            [
                ("a", dict(type="uniform", param1=-0.6, param2=0.6)),
                ("c", dict(type="uniform", param1=0, param2=1.25)),
                ("alpha", dict(type="log_normal", param1=0, param2=2)),
            ]
        ),
    },
    "ilog2": {
        "uniform": OrderedDict(
            [
                ("c", dict(type="uniform", param1=0, param2=1)),
                ("a", dict(type="uniform", param1=-1, param2=1)),
            ]
        ),
        "peaked": OrderedDict(
            [
                ("c", dict(type="uniform", param1=0, param2=1)),
                ("a", dict(type="uniform", param1=-0.5, param2=0.5)),
            ]
        ),
    },
    "janoschek": {
        "uniform": OrderedDict(
            [
                ("a", dict(type="uniform", param1=0, param2=1)),
                ("beta", dict(type="uniform", param1=0, param2=2)),
                ("k", dict(type="uniform", param1=0, param2=1)),
                ("delta", dict(type="uniform", param1=-5, param2=5)),
            ]
        ),
        "peaked": OrderedDict(
            [
                ("a", dict(type="uniform", param1=0, param2=1)),
                ("beta", dict(type="uniform", param1=0, param2=2)),
                ("k", dict(type="log_normal", param1=-2, param2=1)),
                ("delta", dict(type="log_normal", param1=0, param2=0.5)),
            ]
        ),
    },
}
|
42 |
-
|
43 |
-
|
44 |
-
def prior_sampler(rng, type, param1, param2):
    """
    Draw one value from ``rng``.

    :param rng: object offering ``uniform`` and ``lognormal``
    :param type: ``"uniform"`` or ``"log_normal"``
    :param param1: first argument handed to the sampler
    :param param2: second argument handed to the sampler
    :raises Exception: for any other ``type``
    """
    if type == "uniform":
        draw = rng.uniform
    elif type == "log_normal":
        draw = rng.lognormal
    else:
        raise Exception("Unknown prior type: {}".format(type))
    return draw(param1, param2)
|
50 |
-
|
51 |
-
|
52 |
-
def pow3(x, c, a, alpha):
    """Power law with offset: ``c - a * x^(-alpha)``."""
    decay = x ** (-alpha)
    return c - a * decay
|
54 |
-
|
55 |
-
|
56 |
-
def prior_pow3(rng):
    """Sample the pow3 parameters ``a``, ``c``, ``alpha`` (in this order) from the "peaked" prior."""
    spec = prior["pow3"]["peaked"]
    sampled = {}
    for p in ("a", "c", "alpha"):
        sampled[p] = prior_sampler(
            rng,
            spec[p]["type"],
            param1=spec[p]["param1"],
            param2=spec[p]["param2"],
        )
    return sampled
|
66 |
-
|
67 |
-
|
68 |
-
def uniform_prior_pow3(rng):
    """Sample the pow3 parameters ``a``, ``c``, ``alpha`` (in this order) from the "uniform" prior."""
    spec = prior["pow3"]["uniform"]
    sampled = {}
    for p in ("a", "c", "alpha"):
        sampled[p] = prior_sampler(
            rng,
            spec[p]["type"],
            param1=spec[p]["param1"],
            param2=spec[p]["param2"],
        )
    return sampled
|
78 |
-
|
79 |
-
|
80 |
-
def ilog2(x, c, a):
    """Inverse-log curve: ``c - a / log(x + 1)``."""
    log_term = np.log(x + 1)
    return c - a / log_term
|
82 |
-
|
83 |
-
|
84 |
-
def prior_ilog2(rng):
    """Sample the ilog2 parameters ``a``, ``c`` (in this order) from the "peaked" prior."""
    spec = prior["ilog2"]["peaked"]
    sampled = {}
    for p in ("a", "c"):
        sampled[p] = prior_sampler(
            rng,
            spec[p]["type"],
            param1=spec[p]["param1"],
            param2=spec[p]["param2"],
        )
    return sampled
|
94 |
-
|
95 |
-
|
96 |
-
def uniform_prior_ilog2(rng):
    """Sample the ilog2 parameters ``a``, ``c`` (in this order) from the "uniform" prior."""
    spec = prior["ilog2"]["uniform"]
    sampled = {}
    for p in ("a", "c"):
        sampled[p] = prior_sampler(
            rng,
            spec[p]["type"],
            param1=spec[p]["param1"],
            param2=spec[p]["param2"],
        )
    return sampled
|
106 |
-
|
107 |
-
|
108 |
-
def janoschek(x, a, beta, k, delta):
    """
    Janoschek growth curve: ``a - (a - beta) * exp(-k * x^delta)``.

    http://www.pisces-conservation.com/growthhelp/janoschek.htm
    """
    damping = np.exp(-k * x**delta)
    return a - (a - beta) * damping
|
113 |
-
|
114 |
-
|
115 |
-
def prior_janoschek(rng):
    """Sample the janoschek parameters ``a``, ``beta``, ``k``, ``delta`` (in this order) from the "peaked" prior."""
    spec = prior["janoschek"]["peaked"]
    sampled = {}
    for p in ("a", "beta", "k", "delta"):
        sampled[p] = prior_sampler(
            rng,
            spec[p]["type"],
            param1=spec[p]["param1"],
            param2=spec[p]["param2"],
        )
    return sampled
|
125 |
-
|
126 |
-
|
127 |
-
def uniform_prior_janoschek(rng):
    """Sample the janoschek parameters ``a``, ``beta``, ``k``, ``delta`` (in this order) from the "uniform" prior."""
    spec = prior["janoschek"]["uniform"]
    sampled = {}
    for p in ("a", "beta", "k", "delta"):
        sampled[p] = prior_sampler(
            rng,
            spec[p]["type"],
            param1=spec[p]["param1"],
            param2=spec[p]["param2"],
        )
    return sampled
|
137 |
-
|
138 |
-
|
139 |
-
def log_power(x, a, b, c):
    """
    Log-power curve: ``a / (1 + (x / e^b)^c)``.

    a: upper bound
    c: growth rate
    value at x = 1: a / (1 + (1 / e^b)^c)
    """
    ratio = x / np.exp(b)
    return a / (1.0 + ratio**c)
|
144 |
-
|
145 |
-
|
146 |
-
def prior_log_power(rng):
    """
    Sample log_power parameters (drawn in the order a, b, c).

    a ~ N(0.8, 0.1)
    b ~ N(1, 1)
    c ~ U(-3, 0)
    """
    sampled = {}
    sampled["a"] = rng.normal(0.8, 0.1)
    sampled["b"] = rng.normal(1.0, 1.0)
    sampled["c"] = rng.uniform(-3.0, 0.0)
    return sampled
|
154 |
-
|
155 |
-
|
156 |
-
def weibull(x, alpha, beta, kappa, delta):
    """
    Weibull model: ``alpha - (alpha - beta) * exp(-(kappa * x)^delta)``.

    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote
    beta: lower asymptote
    kappa: growth rate
    delta: controls the x-ordinate for the point of inflection
    """
    exponent = (kappa * x) ** delta
    return alpha - (alpha - beta) * np.exp(-exponent)
|
166 |
-
|
167 |
-
|
168 |
-
def prior_weibull(rng):
    """
    Sample weibull parameters (drawn in the order alpha, beta, kappa, delta).

    alpha ~ U(0, 1.5)
    beta ~ U(0, 1)
    ln(kappa) ~ N(-2, 1)
    ln(delta) ~ N(0, 0.5)
    """
    alpha = rng.uniform(0.0, 1.5)
    beta = rng.uniform(0.0, 1)
    ln_kappa = rng.normal(-2.0, 1.0)
    ln_delta = rng.normal(0, 0.5)
    return {
        "alpha": alpha,
        "beta": beta,
        "kappa": np.exp(ln_kappa),
        "delta": np.exp(ln_delta),
    }
|
174 |
-
|
175 |
-
|
176 |
-
def mmf(x, alpha, beta, kappa, delta):
    """
    Morgan-Mercer-Flodin: ``alpha - (alpha - beta) / (1 + (kappa * x)^delta)``.

    description:
    Nonlinear Regression page 342
    http://bit.ly/1jodG17
    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote
    kappa: growth rate
    beta: initial value
    delta: controls the point of inflection
    """
    denominator = 1.0 + (kappa * x) ** delta
    return alpha - (alpha - beta) / denominator
|
189 |
-
|
190 |
-
|
191 |
-
def prior_mmf(rng):
    """
    Sample mmf parameters (drawn in the order alpha, beta, kappa, delta).

    alpha ~ N(0.8, 0.1)
    beta ~ N(0.2, 0.1)
    ln(kappa) ~ N(0, 2)
    ln(delta) ~ N(0, 1)
    """
    alpha = rng.normal(0.8, 0.1)
    beta = rng.normal(0.2, 0.1)
    ln_kappa = rng.normal(0, 2)
    ln_delta = rng.normal(0, 1)
    return {
        "alpha": alpha,
        "beta": beta,
        "kappa": np.exp(ln_kappa),
        "delta": np.exp(ln_delta),
    }
|
201 |
-
|
202 |
-
|
203 |
-
def vap(x, a, b, c):
    """
    Vapor pressure model: ``exp(a + b / x + c * log(x))``.

    no upper bound if c > 0
    a = ln(upper bound) for c=0
    a+b = ln(initial)
    """
    log_value = a + b / x + c * np.log(x)
    return np.exp(log_value)
|
209 |
-
|
210 |
-
|
211 |
-
def prior_vap(rng):
    """
    Sample vap parameters (drawn in the order a, b, c).

    a ~ U(-2, 0)  (@heri: range check)
    b ~ U(-4, 0)  (@heri: range check)
    ln(c) ~ U(-8, 0)  (@heri: same as weights)
    """
    a = rng.uniform(-2.0, 0.0)
    b = rng.uniform(-4.0, 0.0)
    ln_c = rng.uniform(-8.0, 0.0)
    return {"a": a, "b": b, "c": np.exp(ln_c)}
|
216 |
-
|
217 |
-
|
218 |
-
def loglog_linear(x, a, b):
    """Log of a linear function of log(x): ``log(a * log(x) + b)``."""
    log_x = np.log(x)
    return np.log(a * log_x + b)
|
221 |
-
|
222 |
-
|
223 |
-
def prior_loglog_linear(rng):
    """
    Sample loglog_linear parameters (drawn in the order a, b).

    ln(a) ~ N(-2, 1)
    ln(b) ~ U(0, 1)
    """
    ln_a = rng.normal(-2.0, 1.0)
    ln_b = rng.uniform(0.0, 1.0)
    return {"a": np.exp(ln_a), "b": np.exp(ln_b)}
|
229 |
-
|
230 |
-
|
231 |
-
def exp4(x, c, a, b, alpha):
    """Exponential curve: ``c - exp(-a * x^alpha + b)``."""
    exponent = -a * (x**alpha) + b
    return c - np.exp(exponent)
|
233 |
-
|
234 |
-
|
235 |
-
def prior_exp4(rng):
    """
    Sample exp4 parameters (drawn in the order c, a, alpha, b).

    c ~ N(0.8, 0.1)
    ln(a) ~ N(-2, 1)
    ln(alpha) ~ N(0, 1)
    ln(b) ~ N(0, 0.5)
    """
    c = rng.normal(0.8, 0.1)
    ln_a = rng.normal(-2, 1)
    ln_alpha = rng.normal(0, 1)
    ln_b = rng.normal(0, 0.5)
    return {"a": np.exp(ln_a), "b": np.exp(ln_b), "c": c, "alpha": np.exp(ln_alpha)}
|
245 |
-
|
246 |
-
|
247 |
-
def pow4(x, c, a, b, alpha):
    """Shifted power law: ``c - (a * x + b)^(-alpha)``."""
    base = a * x + b
    return c - base**-alpha
|
249 |
-
|
250 |
-
|
251 |
-
def prior_pow4(rng):
    """
    Sample pow4 parameters (drawn in the order c, a, alpha, b).

    ln(1 - c) ~ U(-5, 0)
    ln(a) ~ N(-3, 2)
    ln(alpha) ~ N(0, 1)
    ln(b) ~ U(0, 1)
    """
    ln_one_minus_c = rng.uniform(-5.0, 0)
    ln_a = rng.normal(-3.0, 2)
    ln_alpha = rng.normal(0, 1)
    ln_b = rng.uniform(0, 1)
    return {
        "a": np.exp(ln_a),
        "b": np.exp(ln_b),
        "c": 1 - np.exp(ln_one_minus_c),
        "alpha": np.exp(ln_alpha),
    }
|
261 |
-
|
262 |
-
|
263 |
-
def dr_hill_zero_background(x, theta, eta, kappa):
    """
    Hill curve with zero background: ``theta * x^eta / (kappa^eta + x^eta)``.

    theta: upper bound
    eta: growth rate
    value at x = 1: theta / (kappa^eta + 1)
    """
    x_pow = x**eta
    return (theta * x_pow) / (kappa**eta + x_pow)
|
268 |
-
|
269 |
-
|
270 |
-
def prior_dr_hill_zero_background(rng):
    """
    Sample dr_hill_zero_background parameters (drawn in the order theta, eta, kappa).

    theta ~ N(0.8, 0.1)
    ln(eta) ~ N(1, 1)
    ln(kappa) ~ N(1, 2)
    """
    theta = rng.normal(0.8, 0.1)
    ln_eta = rng.normal(1.0, 1.0)
    ln_kappa = rng.normal(1.0, 2.0)
    return {"theta": theta, "eta": np.exp(ln_eta), "kappa": np.exp(ln_kappa)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/decoders.py
DELETED
@@ -1,42 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import nn
|
3 |
-
import random
|
4 |
-
|
5 |
-
from torch import Tensor
|
6 |
-
import torch.nn.functional as F
|
7 |
-
|
8 |
-
|
9 |
-
class GELU(nn.Module):
    """Module wrapper around ``torch.nn.functional.gelu``."""

    def forward(self, input: Tensor) -> Tensor:
        activated = F.gelu(input)
        return activated
|
12 |
-
|
13 |
-
|
14 |
-
class ScaledDecoder(nn.Module):
    """
    Decoder whose outputs are divided by a predicted temperature.

    A hidden layer feeds two heads: ``linear1`` produces the ``nout`` outputs,
    ``linear2`` produces a softmax over ten fixed temperature values. The
    outputs are divided by the resulting expected temperature.
    """

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        self.linear = nn.Linear(ninp, nhid)
        self.linear1 = nn.Linear(nhid, nout)
        self.linear2 = nn.Linear(nhid, 10)

    def forward(self, x):
        """Map ``x`` (..., ninp) to temperature-scaled outputs (..., nout)."""
        x = F.gelu(self.linear(x))
        # Built with the dtype of the activations so that non-float32 inputs
        # do not fail in the matmul below.
        temperature_grid = torch.tensor(
            [1.0, 1.4, 1.7, 2.0, 5.0, 10.0, 20.0, 40.0, 80.0, 160.0],
            device=x.device,
            dtype=x.dtype,
        )
        temps = self.linear2(x).softmax(-1) @ temperature_grid
        # NOTE: the former random debug print of `temps[:, :2]` was removed:
        # it consumed the global `random` stream, wrote to stdout and raised
        # an IndexError for inputs with a single leading dimension.
        return self.linear1(x) / temps.unsqueeze(-1)
|
31 |
-
|
32 |
-
|
33 |
-
class FixedScaledDecoder(nn.Module):
    """
    Two-layer GELU decoder whose output is divided by one learned scalar.

    The scalar is the sum of the 10000-element parameter ``T``, which is
    initialised so that the sum is 1.
    """

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        hidden = nn.Linear(ninp, nhid)
        activation = nn.GELU()
        output = nn.Linear(nhid, nout)
        self.mapper = nn.Sequential(hidden, activation, output)
        self.T = nn.Parameter(torch.ones(10000) / 10000)

    def forward(self, x):
        temperature = self.T.sum()
        return self.mapper(x) / temperature
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/domhan_prior.py
DELETED
@@ -1,199 +0,0 @@
|
|
1 |
-
from functools import partial
|
2 |
-
import torch
|
3 |
-
import numpy as np
|
4 |
-
from lcpfn.curves import (
|
5 |
-
pow3,
|
6 |
-
ilog2,
|
7 |
-
janoschek,
|
8 |
-
log_power,
|
9 |
-
prior_ilog2,
|
10 |
-
uniform_prior_pow3,
|
11 |
-
weibull,
|
12 |
-
mmf,
|
13 |
-
vap,
|
14 |
-
loglog_linear,
|
15 |
-
exp4,
|
16 |
-
pow4,
|
17 |
-
dr_hill_zero_background,
|
18 |
-
)
|
19 |
-
from lcpfn.curves import (
|
20 |
-
prior_pow3,
|
21 |
-
prior_janoschek,
|
22 |
-
prior_log_power,
|
23 |
-
prior_weibull,
|
24 |
-
prior_mmf,
|
25 |
-
prior_vap,
|
26 |
-
prior_loglog_linear,
|
27 |
-
prior_exp4,
|
28 |
-
prior_pow4,
|
29 |
-
prior_dr_hill_zero_background,
|
30 |
-
)
|
31 |
-
from lcpfn.curves import (
|
32 |
-
uniform_prior_pow3,
|
33 |
-
uniform_prior_ilog2,
|
34 |
-
uniform_prior_janoschek,
|
35 |
-
)
|
36 |
-
|
37 |
-
|
38 |
-
def prior_weights(
    rng,
    components=(
        "pow3",
        "ilog2",
        "janoschek",
        "log_power",
        "weibull",
        "mmf",
        "vap",
        "loglog_linear",
        "exp4",
        "pow4",
        "dr_hill_zero_background",
    ),
):
    """
    Draw one U(0, 1) weight per component.

    :param rng: random generator offering ``uniform(low, high, size=...)``
    :param components: names of the curve families to weight; the default is
        an immutable tuple so it cannot be altered between calls
    :return: dict mapping each component name to its weight, in the order of
        ``components``
    """
    K = len(components)
    weights = rng.uniform(0.0, 1, size=(K,))
    return dict(zip(components, weights))
|
57 |
-
|
58 |
-
|
59 |
-
def sample_from_prior(rng, seq_len=100):
    """Sample a curve from the "peaked" combination of pow3, ilog2 and janoschek."""
    default_components = ["pow3", "ilog2", "janoschek"]
    return sample_prior_comb(
        rng=rng,
        seq_len=seq_len,
        components=default_components,
        distribution="peaked",
    )
|
66 |
-
|
67 |
-
|
68 |
-
def sample_prior_comb(
    rng,
    components,
    distribution,
    var_lnloc=-4,
    var_lnscale=1,
    range_constraint=True,
    seq_len=100,
):
    """
    Sample a learning curve as a weighted sum of parametric curve families.

    Candidate curves are drawn until one is accepted; a candidate is rejected
    when its last value is not larger than its first, when it contains NaNs
    or, with ``range_constraint``, when it leaves [0, 1].

    :param rng: numpy-style random generator
    :param components: names of the curve families to combine
    :param distribution: ``"peaked"`` or ``"uniform"`` parameter priors
        ("uniform" only defines priors for pow3, ilog2 and janoschek)
    :param var_lnloc: mean of the normal from which the log of the noise
        scale is drawn
    :param var_lnscale: standard deviation of that normal
    :param range_constraint: reject curves leaving [0, 1]
    :param seq_len: number of points; the curve is evaluated at 1..seq_len
    :return: function without arguments returning ``(y, y_noisy)``
    :raises NotImplementedError: for an unknown ``distribution``
    """
    f_components = {
        "pow3": pow3,
        "ilog2": ilog2,
        "janoschek": janoschek,
        "log_power": log_power,
        "weibull": weibull,
        "mmf": mmf,
        "vap": vap,
        "loglog_linear": loglog_linear,
        "exp4": exp4,
        "pow4": pow4,
        "dr_hill_zero_background": dr_hill_zero_background,
    }

    if distribution == "peaked":
        f_priors = {
            "pow3": prior_pow3,
            "ilog2": prior_ilog2,
            "janoschek": prior_janoschek,
            "log_power": prior_log_power,
            "weibull": prior_weibull,
            "mmf": prior_mmf,
            "vap": prior_vap,
            "loglog_linear": prior_loglog_linear,
            "exp4": prior_exp4,
            "pow4": prior_pow4,
            "dr_hill_zero_background": prior_dr_hill_zero_background,
        }
    elif distribution == "uniform":
        f_priors = {
            "pow3": uniform_prior_pow3,
            "ilog2": uniform_prior_ilog2,
            "janoschek": uniform_prior_janoschek,
        }
    else:
        # `NotImplemented` is a non-callable constant; the exception class is
        # `NotImplementedError`.
        raise NotImplementedError(f"Unknown distribution: {distribution!r}")

    x = np.arange(1, seq_len + 1)

    while True:
        # sample the noiseless curve
        weights = prior_weights(rng, components=components)
        y = np.zeros(x.shape, dtype="float")
        for f, w in weights.items():
            kwargs = f_priors[f](rng)
            y += w * f_components[f](x, **kwargs)
        # add noise (can exceed [0,1], but afaik no way to implement this prior in Tobis work)
        # NOTE: `var` is handed to rng.normal below as the *scale* argument.
        var = np.exp(
            rng.normal(var_lnloc, var_lnscale)
        )  # @heri: ln_prob =+ log(normal.pdf(log(var), loc=var_lnloc, scale=var_lnscale))

        # reject any curves that are non-increasing, exceed the [0,1] range
        rejected = (
            y[-1] <= y[0]
            or (range_constraint and (np.any(y < 0) or np.any(y > 1)))
            or np.isnan(y).any()
        )
        if not rejected:
            break

    def curve():  # generates a sample from the same model, but with independent noise
        y_noisy = y + rng.normal(np.zeros_like(y), var)
        return y, y_noisy

    return curve
|
145 |
-
|
146 |
-
|
147 |
-
def generate_prior_dataset(n, prior=None, seed=42):
    """
    Returns a fixed sample from the prior (with fixed seq_len) as an n x seq_len np.ndarray

    :param n: number of curves
    :param prior: callable ``prior(rng)`` returning a curve function that
        yields ``(y, y_noisy)``; defaults to ``sample_from_prior``. (The
        former default ``sample_prior_comb`` could not be called with only
        ``rng``, since it requires ``components`` and ``distribution``.)
    :param seed: seed of the ``np.random.RandomState`` used for sampling
    :return: array stacking the noisy curve of each sample
    """
    if prior is None:
        prior = sample_from_prior
    rng = np.random.RandomState(seed)
    prior_data = np.stack([prior(rng)()[1] for _ in range(n)])
    return prior_data
|
154 |
-
|
155 |
-
|
156 |
-
def create_get_batch_func(prior):
    """Return ``get_batch_domhan`` with its ``prior`` argument pre-bound."""
    bound_get_batch = partial(get_batch_domhan, prior=prior)
    return bound_get_batch
|
158 |
-
|
159 |
-
|
160 |
-
# function producing batches for PFN training
def get_batch_domhan(
    batch_size,
    seq_len,
    num_features,
    prior,
    device="cpu",
    noisy_target=True,
    **_,
):
    """
    Sample ``batch_size`` curves from ``prior`` and return them as tensors.

    :param batch_size: number of curves
    :param seq_len: length of every curve
    :param num_features: must be 1
    :param prior: callable ``prior(rng, seq_len=...)`` returning a function
        that yields ``(y, y_noisy)``; it is called with the global
        ``np.random`` module as ``rng``
    :param device: device the tensors are moved to
    :param noisy_target: use the noisy curve as target instead of the
        noiseless one
    :return: ``(x, y_noisy, y_target)`` as float32 tensors with shapes
        (seq_len, batch_size, num_features), (seq_len, batch_size) and
        (seq_len, batch_size); ``x`` holds the positions 1..seq_len
    """
    assert num_features == 1

    targets = np.empty((batch_size, seq_len), dtype=float)
    noisy = np.empty((batch_size, seq_len), dtype=float)

    for row in range(batch_size):
        sample_curve = prior(np.random, seq_len=seq_len)  # uses numpy rng
        clean_row, noisy_row = sample_curve()
        noisy[row] = noisy_row
        targets[row] = noisy[row] if noisy_target else clean_row

    # turn numpy arrays into correctly shaped torch tensors & move them to device
    positions = torch.arange(1, seq_len + 1)
    x = (
        positions.repeat((num_features, batch_size, 1))
        .transpose(2, 0)
        .to(device)
        .float()
    )
    y_target = torch.from_numpy(targets).transpose(1, 0).to(device).float()
    y_noisy = torch.from_numpy(noisy).transpose(1, 0).to(device).float()

    return x, y_noisy, y_target
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/encoders.py
DELETED
@@ -1,190 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
from lcpfn.utils import normalize_data
|
6 |
-
import torch.nn.functional as F
|
7 |
-
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
8 |
-
|
9 |
-
|
10 |
-
class StyleEncoder(nn.Module):
    """Linear embedding of a hyperparameter vector into ``em_size`` dimensions."""

    def __init__(self, em_size, hyperparameter_definitions):
        super().__init__()
        self.em_size = em_size
        # The input width is the first dimension of `hyperparameter_definitions`.
        num_hps = hyperparameter_definitions.shape[0]
        self.embedding = nn.Linear(num_hps, self.em_size)

    def forward(self, hyperparameters):  # T x B x num_hps
        embedded = self.embedding(hyperparameters)
        return embedded
|
18 |
-
|
19 |
-
|
20 |
-
class _PositionalEncoding(nn.Module):
|
21 |
-
def __init__(self, d_model, dropout=0.0):
|
22 |
-
super().__init__()
|
23 |
-
self.dropout = nn.Dropout(p=dropout)
|
24 |
-
self.d_model = d_model
|
25 |
-
self.device_test_tensor = nn.Parameter(torch.tensor(1.0))
|
26 |
-
|
27 |
-
def forward(self, x): # T x B x num_features
|
28 |
-
assert self.d_model % x.shape[-1] * 2 == 0
|
29 |
-
d_per_feature = self.d_model // x.shape[-1]
|
30 |
-
pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
|
31 |
-
# position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
32 |
-
interval_size = 10
|
33 |
-
div_term = (
|
34 |
-
(1.0 / interval_size)
|
35 |
-
* 2
|
36 |
-
* math.pi
|
37 |
-
* torch.exp(
|
38 |
-
torch.arange(
|
39 |
-
0, d_per_feature, 2, device=self.device_test_tensor.device
|
40 |
-
).float()
|
41 |
-
* math.log(math.sqrt(2))
|
42 |
-
)
|
43 |
-
)
|
44 |
-
# print(div_term/2/math.pi)
|
45 |
-
pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
|
46 |
-
pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
|
47 |
-
return self.dropout(pe).view(x.shape[0], x.shape[1], self.d_model)
|
48 |
-
|
49 |
-
|
50 |
-
def Positional(_, emsize):
    """Encoder factory: ignores the first argument and returns a ``_PositionalEncoding`` of width ``emsize``."""
    return _PositionalEncoding(d_model=emsize)
|
51 |
-
|
52 |
-
|
53 |
-
class EmbeddingEncoder(nn.Module):
    """
    Encode features by discretising them and averaging learned embeddings.

    The interval ``min_max`` is split into ``num_embs`` equally wide bins per
    feature; every feature owns its own block of ``num_embs`` embeddings.
    """

    def __init__(self, num_features, em_size, num_embs=100):
        super().__init__()
        self.num_embs = num_embs
        # NOTE: `max_norm=True` is kept as in the original.
        self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
        self.init_weights(0.1)
        self.min_max = (-2, +2)

    @property
    def width(self):
        """Width of the interval covered by the bins."""
        return self.min_max[1] - self.min_max[0]

    def init_weights(self, initrange):
        """Initialise the embeddings uniformly in [-initrange, initrange]."""
        self.embeddings.weight.data.uniform_(-initrange, initrange)

    def discretize(self, x):
        """Return the bin index in [0, num_embs - 1] for every value of ``x``."""
        split_size = self.width / self.num_embs
        # The offset has to be removed *before* the floor division; without
        # the parentheses `//` binds tighter than `-`, so the lower bound
        # alone was divided and the result was not a bin index.
        # NOTE(review): this changes the indices produced for weights trained
        # with the former expression - confirm no checkpoint depends on it.
        return ((x - self.min_max[0]) // split_size).int().clamp(0, self.num_embs - 1)

    def forward(self, x):  # T x B x num_features
        x_idxs = self.discretize(x)
        # Shift the indices of feature i into its own block of embeddings.
        x_idxs += (
            torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
        )
        return self.embeddings(x_idxs).mean(-2)
|
79 |
-
|
80 |
-
|
81 |
-
class Normalize(nn.Module):
    """Shift the input by a fixed ``mean`` and divide it by a fixed ``std``."""

    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, x):
        centered = x - self.mean
        return centered / self.std
|
89 |
-
|
90 |
-
|
91 |
-
def get_normalized_uniform_encoder(encoder_creator):
    """
    Wrap an encoder that is fed uniform samples in [0,1] so that its input is
    first normalized to 0 mean and 1 std.

    For example, it can be used as
    `encoder_creator = get_normalized_uniform_encoder(encoders.Linear)`; the
    result can then be initialized with `encoder_creator(feature_dim, in_dim)`.

    :param encoder_creator: callable ``(in_dim, out_dim) -> nn.Module``
    :return: callable ``(in_dim, out_dim) -> nn.Sequential``
    """

    def create(in_dim, out_dim):
        # U(0, 1) has mean 0.5 and standard deviation sqrt(1/12).
        normalizer = Normalize(0.5, math.sqrt(1 / 12))
        return nn.Sequential(normalizer, encoder_creator(in_dim, out_dim))

    return create
|
102 |
-
|
103 |
-
|
104 |
-
Linear = nn.Linear


def MLP(num_features, emsize):
    """Two-layer ReLU encoder taking ``num_features + 1`` inputs and returning ``emsize`` outputs."""
    hidden = nn.Linear(num_features + 1, emsize * 2)
    activation = nn.ReLU()
    output = nn.Linear(emsize * 2, emsize)
    return nn.Sequential(hidden, activation, output)
|
108 |
-
|
109 |
-
|
110 |
-
class NanHandlingEncoder(nn.Module):
    """
    Linear encoder that replaces NaNs by 0.

    With ``keep_nans`` the input is extended by one indicator per feature
    (-1 for NaN, 1 for +inf, 2 for -inf, passed through ``normalize_data``),
    which doubles the input width of the linear layer.
    """

    def __init__(self, num_features, emsize, keep_nans=True):
        super().__init__()
        self.num_features = 2 * num_features if keep_nans else num_features
        self.emsize = emsize
        self.keep_nans = keep_nans
        self.layer = nn.Linear(self.num_features, self.emsize)

    def forward(self, x):
        if not self.keep_nans:
            return self.layer(torch.nan_to_num(x, nan=0.0))

        cleaned = torch.nan_to_num(x, nan=0.0)
        is_pos_inf = torch.logical_and(torch.isinf(x), torch.sign(x) == 1)
        is_neg_inf = torch.logical_and(torch.isinf(x), torch.sign(x) == -1)
        indicator = torch.isnan(x) * -1 + is_pos_inf * 1 + is_neg_inf * 2
        features = torch.cat([cleaned, normalize_data(indicator)], -1)
        return self.layer(features)
|
134 |
-
|
135 |
-
|
136 |
-
class Linear(nn.Linear):
    """``nn.Linear`` that replaces NaNs in its input by 0."""

    def __init__(self, num_features, emsize):
        super().__init__(num_features, emsize)
        self.num_features = num_features
        self.emsize = emsize

    def forward(self, x):
        cleaned = torch.nan_to_num(x, nan=0.0)
        return super().forward(cleaned)
|
145 |
-
|
146 |
-
|
147 |
-
class Conv(nn.Module):
    """
    Convolutional encoder for inputs whose last dimension is a flattened square.

    Up to five 3x3 convolutions with ReLU are applied while the spatial size
    is at least 4; the result is average pooled and mapped to ``emsize``.
    """

    def __init__(self, input_size, emsize):
        super().__init__()
        layers = []
        for i in range(5):
            # Only the first convolution sees the single input channel.
            in_channels = 1 if i == 0 else 64
            layers.append(nn.Conv2d(in_channels, 64, 3))
        self.convs = torch.nn.ModuleList(layers)
        self.linear = nn.Linear(64, emsize)

    def forward(self, x):
        side = math.isqrt(x.shape[-1])
        assert side * side == x.shape[-1]
        x = x.reshape(*x.shape[:-1], 1, side, side)
        for conv in self.convs:
            if x.shape[-1] < 4:
                break
            x = conv(x)
            x.relu_()
        pooled = nn.AdaptiveAvgPool2d((1, 1))(x)
        return self.linear(pooled.squeeze(-1).squeeze(-1))
|
166 |
-
|
167 |
-
|
168 |
-
class CanEmb(nn.Embedding):
    """Embedding that looks up every feature and concatenates the results.

    The requested ``embedding_dim`` is split evenly across ``num_features``,
    so each feature gets ``embedding_dim // num_features`` dimensions and the
    flattened output has ``embedding_dim`` entries per item.
    """

    def __init__(
        self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs
    ):
        assert embedding_dim % num_features == 0
        per_feature_dim = embedding_dim // num_features
        super().__init__(num_embeddings, per_feature_dim, *args, **kwargs)

    def forward(self, x):
        """Embed whole-number tensor ``x`` and merge its last two output axes.

        Raises:
            AssertionError: if ``x`` contains non-integral values.
        """
        indices = x.long()
        assert (indices == x).all(), "CanEmb only works with tensors of whole numbers"
        embedded = super().forward(indices)
        leading = embedded.shape[:-2]
        return embedded.view(*leading, -1)
|
181 |
-
|
182 |
-
|
183 |
-
def get_Canonical(num_classes):
    """Return an encoder factory that builds a ``CanEmb`` over ``num_classes``.

    The returned callable takes ``(num_features, emsize)``.
    """

    def build_encoder(num_features, emsize):
        return CanEmb(num_features, num_classes, emsize)

    return build_encoder
|
185 |
-
|
186 |
-
|
187 |
-
def get_Embedding(num_embs_per_feature=100):
    """Return an encoder factory that builds an ``EmbeddingEncoder``.

    The returned callable takes ``(num_features, emsize)`` and forwards
    ``num_embs_per_feature`` as the encoder's ``num_embs`` argument.
    """

    def build_encoder(num_features, emsize):
        return EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)

    return build_encoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/initializers.py
DELETED
@@ -1,11 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import nn
|
3 |
-
|
4 |
-
|
5 |
-
def get_NormalInitializer(std):
    """Return a function that normal-initialises ``nn.Linear`` modules.

    The returned callable is suitable for ``module.apply``: for every
    ``nn.Linear`` it draws the weight (and the bias, when the layer has one)
    from a zero-mean normal distribution with standard deviation ``std``.
    Modules of any other type are left untouched.
    """

    def initializer(m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, std)
            # Layers built with bias=False have no bias tensor to initialise.
            if m.bias is not None:
                nn.init.normal_(m.bias, 0, std)

    return initializer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/layer.py
DELETED
@@ -1,179 +0,0 @@
|
|
1 |
-
from functools import partial
|
2 |
-
from typing import Optional
|
3 |
-
from torch import Tensor
|
4 |
-
from torch import nn
|
5 |
-
from torch.nn.modules.transformer import *
|
6 |
-
from torch.nn.modules.transformer import _get_activation_fn
|
7 |
-
|
8 |
-
from torch.utils.checkpoint import checkpoint
|
9 |
-
|
10 |
-
|
11 |
-
class TransformerEncoderLayer(nn.Module):
    r"""Encoder layer made of self-attention followed by a feedforward network.

    Based on the standard encoder layer from "Attention Is All You Need"
    (Vaswani et al., 2017), with three additions visible in this class:
    optional pre-norm, optional recomputation of the attention via
    ``torch.utils.checkpoint.checkpoint``, and support for a tuple of three
    attention masks that splits the sequence into global, train and eval
    tokens.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.
        pre_norm: If ``True``, layer normalization is applied to the input of
            each sub-block instead of to its residual output. Default: ``False``.
        device, dtype: forwarded to the sub-modules that accept them.
        recompute_attn: If ``True``, the self-attention call is wrapped in
            ``checkpoint``. Default: ``False``.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    __constants__ = ["batch_first"]

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        layer_norm_eps=1e-5,
        batch_first=False,
        pre_norm=False,
        device=None,
        dtype=None,
        recompute_attn=False,
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        # Sub-modules are created in a fixed order so parameter registration
        # (and therefore random initialisation order) stays stable.
        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs
        )

        # Feedforward block
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.pre_norm = pre_norm
        self.recompute_attn = recompute_attn
        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # States without an "activation" entry fall back to ReLU.
        if "activation" not in state:
            state["activation"] = F.relu
        super().__setstate__(state)

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional). May also be a
                tuple ``(global_mask, trainset_mask, valset_mask)``; in that
                case the first dimension of the first two masks determines how
                many leading tokens are global and train tokens, and the
                remaining tokens are treated as eval tokens.
            src_key_padding_mask: the mask for the src keys per batch
                (optional). Must be ``None`` when ``src_mask`` is a tuple.

        Shape:
            see the docs in Transformer class.
        """
        attn_in = self.norm1(src) if self.pre_norm else src

        # Arguments below are positional in MultiheadAttention.forward order:
        # (query, key, value, key_padding_mask, need_weights, attn_mask).
        attend = (
            partial(checkpoint, self.self_attn)
            if self.recompute_attn
            else self.self_attn
        )

        if isinstance(src_mask, tuple):
            # global attention setup
            assert not self.self_attn.batch_first
            assert src_key_padding_mask is None

            global_mask, train_mask, eval_mask = src_mask
            n_global = global_mask.shape[0]
            n_context = n_global + train_mask.shape[0]

            global_tokens = attn_in[:n_global]
            train_tokens = attn_in[n_global:n_context]
            context_tokens = attn_in[:n_context]
            eval_tokens = attn_in[n_context:]

            # Global tokens attend to global + train tokens, train tokens only
            # to the global tokens, and eval tokens to the whole sequence.
            global_out = attend(
                global_tokens, context_tokens, context_tokens, None, True, global_mask
            )[0]
            train_out = attend(
                train_tokens, global_tokens, global_tokens, None, True, train_mask
            )[0]
            eval_out = attend(eval_tokens, attn_in, attn_in, None, True, eval_mask)[0]

            attn_out = torch.cat([global_out, train_out, eval_out], dim=0)
        else:
            attn_out = attend(
                attn_in, attn_in, attn_in, src_key_padding_mask, True, src_mask
            )[0]

        src = src + self.dropout1(attn_out)
        if not self.pre_norm:
            src = self.norm1(src)

        ff_in = self.norm2(src) if self.pre_norm else src
        hidden = self.dropout(self.activation(self.linear1(ff_in)))
        src = src + self.dropout2(self.linear2(hidden))

        if not self.pre_norm:
            src = self.norm2(src)
        return src
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/model.py
DELETED
@@ -1,56 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import lcpfn
|
3 |
-
import warnings
|
4 |
-
from lcpfn import utils
|
5 |
-
|
6 |
-
|
7 |
-
class LCPFN(torch.nn.Module):
    """Inference wrapper around a pre-trained LC-PFN model.

    Args:
        model_name: either a key of ``lcpfn.model_dict`` (resolved through the
            attribute of the same name on the ``lcpfn`` package) or anything
            ``torch.load`` accepts, e.g. a checkpoint path.
    """

    def __init__(self, model_name="EMSIZE512_NLAYERS12_NBUCKETS1000"):
        super().__init__()
        # NOTE(review): torch.load unpickles the checkpoint, which can run
        # arbitrary code; only load checkpoints from a trusted source.
        self.model = torch.load(
            getattr(lcpfn, model_name) if model_name in lcpfn.model_dict else model_name
        )
        self.model.eval()

    def check_input(self, x_train, x_test, y_train, y_test=None):
        """Validate the value ranges of the inputs.

        Raises:
            Exception: if any x value is negative, or if any value of
                ``y_train`` (or of ``y_test`` when given) lies outside [0, 1].
        """
        if torch.any(x_train < 0) or torch.any(x_test < 0):
            # reject inputs with negative x values
            raise Exception("x values should be non-negative")
        # Element-wise range checks; a chained comparison such as
        # ``0 < y < 1`` cannot be used on tensors.
        y_train_out_of_range = torch.any((y_train < 0) | (y_train > 1))
        y_test_out_of_range = y_test is not None and torch.any(
            (y_test < 0) | (y_test > 1)
        )
        if y_train_out_of_range or y_test_out_of_range:
            # reject inputs with y values outside [0,1]
            raise Exception(
                "y values should be in the range [0,1]. Please set normalizer_kwargs accordingly."
            )

    @torch.no_grad()
    def predict_mean(
        self, x_train, y_train, x_test, normalizer=utils.identity_normalizer()
    ):
        """Return the predicted mean at ``x_test``.

        ``normalizer[0]`` is applied to ``y_train`` before the forward pass and
        ``normalizer[1]`` to the mean computed by the model's criterion.
        """
        y_train_norm = normalizer[0](y_train)
        logits = self(x_train=x_train, y_train=y_train_norm, x_test=x_test)
        return normalizer[1](self.model.criterion.mean(logits))

    @torch.no_grad()
    def predict_quantiles(
        self, x_train, y_train, x_test, qs, normalizer=utils.identity_normalizer()
    ):
        """Return the predicted quantiles ``qs`` at ``x_test``.

        The criterion's ``icdf`` is evaluated once per quantile and the results
        are concatenated along dimension 1 before ``normalizer[1]`` is applied.
        """
        y_train_norm = normalizer[0](y_train)
        logits = self(x_train=x_train, y_train=y_train_norm, x_test=x_test)
        return normalizer[1](
            torch.cat([self.model.criterion.icdf(logits, q) for q in qs], dim=1)
        )

    @torch.no_grad()
    def nll_loss(self, x_train, y_train, x_test, y_test):
        """Return the model criterion evaluated on ``y_test``."""
        # TODO add normalizer_kwargs
        logits = self(x_train=x_train, y_train=y_train, x_test=x_test)
        return self.model.criterion(logits, y_test)

    def forward(self, x_train, y_train, x_test):
        """Run the wrapped model on the concatenated train/test inputs.

        Train and test x values are concatenated along dimension 0 and a
        singleton dimension is inserted at position 1 of both x and y; the
        number of training points is passed as ``single_eval_pos``.
        """
        self.check_input(x_train, x_test, y_train)
        single_eval_pos = x_train.shape[0]
        x = torch.cat([x_train, x_test], dim=0).unsqueeze(1)
        y = y_train.unsqueeze(1)
        return self.model((x, y), single_eval_pos=single_eval_pos)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/positional_encodings.py
DELETED
@@ -1,78 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch import nn
|
5 |
-
|
6 |
-
|
7 |
-
# Protocol for positional encodings.
|
8 |
-
# __init__(d_model, max_len=..[, more optionals])
|
9 |
-
# forward(x: (seq_len, bs, d_model)) -> Tensor of shape (*x.shape[:2],d_model) containing pos. embeddings
|
10 |
-
|
11 |
-
|
12 |
-
class NoPositionalEncoding(nn.Module):
    """Positional-encoding placeholder that returns its input unchanged.

    ``d_model`` and ``max_len`` are accepted to match the constructor
    protocol of the other positional encodings and are otherwise ignored.
    """

    def __init__(self, d_model, max_len=None):
        super().__init__()

    def forward(self, x):
        """Return ``x`` as is (no scaling, no added embedding)."""
        return x
|
19 |
-
|
20 |
-
|
21 |
-
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input.

    Even feature indices hold sines and odd indices hold cosines of
    ``position * exp(-log(10000) * i / d_model)``. The table is stored in the
    buffer ``pe`` with shape ``(max_len, 1, d_model)``.

    Args:
        d_model: feature dimension; both even and odd values are supported.
        max_len: number of positions in the precomputed table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x):
        """Add the first ``x.size(0)`` rows of the table to ``x``."""
        x = self.pe[: x.size(0), :] + x  # * math.sqrt(x.shape[-1])
        return x
|
37 |
-
|
38 |
-
|
39 |
-
class LearnedPositionalEncoding(nn.Module):
    """Trainable positional embeddings added to the input.

    The embedding table is a parameter of shape ``(max_len, d_model)``
    initialised from a zero-mean normal with standard deviation
    ``d_model ** -0.5``.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.max_seq_len = max_len
        # self.positional_embeddings = nn.Embedding(max_len, d_model)
        table = torch.empty(max_len, d_model)
        self.positional_embeddings = nn.Parameter(table)
        nn.init.normal_(self.positional_embeddings, mean=0, std=d_model**-0.5)

    def forward(self, x):
        """Add the embeddings of the first ``seq_len`` positions to ``x``.

        ``x`` is unpacked as ``(seq_len, bs, d_model)``.

        Raises:
            AssertionError: if ``seq_len`` exceeds the table length.
        """
        seq_len, bs, d_model = x.shape
        table_len = len(self.positional_embeddings)
        assert seq_len <= table_len, "seq_len can be at most max_len."
        per_position = self.positional_embeddings[:seq_len].unsqueeze(1)
        return per_position.expand(seq_len, bs, d_model) + x  # * math.sqrt(x.shape[-1])
|
56 |
-
|
57 |
-
|
58 |
-
class PairedScrambledPositionalEncodings(LearnedPositionalEncoding):
    """Learned positional embeddings applied in a freshly shuffled order.

    Each forward call draws one random permutation of the embedding table's
    rows and adds the first ``seq_len`` permuted rows to the input.
    """

    # TODO check whether it is a problem to use the same perm. for full batch
    def forward(self, x):
        """Add randomly permuted positional embeddings to ``x``.

        Raises:
            AssertionError: if ``seq_len`` exceeds the table length or the
                table length is odd.
        """
        seq_len, bs, d_model = x.shape
        table = self.positional_embeddings
        table_len = len(table)
        assert seq_len <= table_len, "seq_len can be at most max_len."
        assert table_len % 2 == 0, "Please specify an even max_len."

        # NOTE(review): the view keeps one row per position, so the
        # permutation shuffles single positions rather than pairs of
        # positions, and the reshape needs an even d_model. Confirm whether
        # pairing consecutive positions was the intent.
        paired_embs = table.view(table_len, -1, 2)
        order = torch.randperm(len(paired_embs))
        shuffled = paired_embs[order].view(*table.shape)
        pos_emb = shuffled[:seq_len]

        return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x  # * math.sqrt(x.shape[-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/priors/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
from . import gp, ridge
|
|
|
|
lcpfn/priors/binarized_regression.py
DELETED
@@ -1,19 +0,0 @@
|
|
1 |
-
from . import fast_gp, fast_gp_mix
|
2 |
-
from .utils import get_batch_to_dataloader
|
3 |
-
|
4 |
-
def regression_prior_to_binary(get_batch_function):
    """Wrap a regression ``get_batch`` so that it yields binary targets.

    The wrapper calls ``get_batch_function`` and replaces the targets with
    Bernoulli draws whose probabilities are the sigmoid of ``y``. The same
    binarised tensor is returned as both ``y`` and ``target_y``. Passing
    ``assert_on=True`` additionally asserts that the wrapped function returned
    the very same object for ``y`` and ``target_y``.
    """

    def binarized_get_batch_function(*args, assert_on=False, **kwargs):
        x, y, target_y = get_batch_function(*args, **kwargs)
        if assert_on:
            assert y is target_y, "y == target_y is assumed by this function"
        labels = y.sigmoid().bernoulli()
        return x, labels, labels

    return binarized_get_batch_function
|
14 |
-
|
15 |
-
|
16 |
-
# Dataloader built from fast_gp.get_batch wrapped by regression_prior_to_binary.
Binarized_fast_gp_dataloader = get_batch_to_dataloader(regression_prior_to_binary(fast_gp.get_batch))
|
17 |
-
|
18 |
-
|
19 |
-
# Dataloader built from fast_gp_mix.get_batch wrapped by regression_prior_to_binary.
Binarized_fast_gp_mix_dataloader = get_batch_to_dataloader(regression_prior_to_binary(fast_gp_mix.get_batch))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/priors/fast_gp.py
DELETED
@@ -1,143 +0,0 @@
|
|
1 |
-
import time
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch import nn
|
5 |
-
import gpytorch
|
6 |
-
|
7 |
-
from .utils import get_batch_to_dataloader
|
8 |
-
from utils import default_device
|
9 |
-
|
10 |
-
|
11 |
-
# We will use the simplest form of GP model, exact inference
|
12 |
-
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and a scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf)

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )
|
22 |
-
|
23 |
-
|
24 |
-
def get_model(x, y, hyperparameters):
    """Build an ``ExactGPModel`` with fixed hyperparameters.

    Args:
        x: training inputs handed to the GP.
        y: training targets handed to the GP.
        hyperparameters: mapping with the keys ``"noise"``, ``"outputscale"``
            and ``"lengthscale"``; each value is broadcast over the
            corresponding model parameter.

    Returns:
        The ``(model, likelihood)`` pair.
    """
    noise_floor = gpytorch.constraints.GreaterThan(1.e-9)
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=noise_floor)
    model = ExactGPModel(x, y, likelihood)

    # Every parameter is overwritten with a constant tensor of its own shape.
    model.likelihood.noise = (
        torch.ones_like(model.likelihood.noise) * hyperparameters["noise"]
    )
    model.covar_module.outputscale = (
        torch.ones_like(model.covar_module.outputscale) * hyperparameters["outputscale"]
    )
    model.covar_module.base_kernel.lengthscale = (
        torch.ones_like(model.covar_module.base_kernel.lengthscale)
        * hyperparameters["lengthscale"]
    )
    return model, likelihood
|
32 |
-
|
33 |
-
|
34 |
-
@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
              equidistant_x=False, fix_x=None, **kwargs):
    """Sample a batch of functions from a GP prior.

    Args:
        batch_size: number of functions to sample.
        seq_len: number of points per function.
        num_features: dimensionality of each input point.
        device: device on which the inputs and the model are placed.
        hyperparameters: dict with at least ``noise``, ``outputscale`` and
            ``lengthscale``; or a tuple/list whose entries are, in order,
            noise, outputscale, lengthscale, is_binary_classification,
            (ignored), normalize_by_used_features, order_y, sampling. Shorter
            sequences fill only the leading keys. ``None`` selects 0.1 for the
            three GP hyperparameters.
        equidistant_x: use an evenly spaced grid on [0, 1] (requires
            ``num_features == 1``).
        fix_x: fixed inputs of shape ``(seq_len, num_features)`` shared by the
            whole batch; mutually exclusive with ``equidistant_x``.
        **kwargs: accepted and ignored.

    Returns:
        ``(x, y, target_y)`` where ``x`` is the input tensor transposed so that
        the sequence axis comes first, ``y`` is the noisy sample, and
        ``target_y`` is ``y`` unless ``hyperparameters["observation_noise"]``
        is falsy, in which case it is the noise-free sample.
    """
    if isinstance(hyperparameters, (tuple, list)):
        # Position 4 (num_features_used) is intentionally not mapped.
        positional_keys = ("noise", "outputscale", "lengthscale",
                           "is_binary_classification", None,
                           "normalize_by_used_features", "order_y", "sampling")
        hyperparameters = {key: value
                           for key, value in zip(positional_keys, hyperparameters)
                           if key is not None}
    elif hyperparameters is None:
        hyperparameters = {"noise": .1, "outputscale": .1, "lengthscale": .1}

    if hyperparameters.get('verbose'):
        print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
                  , "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size,
               'sampling': hyperparameters.get('sampling', 'uniform')})

    assert not (equidistant_x and (fix_x is not None))

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
        if equidistant_x:
            assert num_features == 1
            # Created on `device` so it matches the model moved there below.
            x = torch.linspace(0, 1., seq_len, device=device).unsqueeze(0).repeat(batch_size, 1).unsqueeze(-1)
        elif fix_x is not None:
            assert fix_x.shape == (seq_len, num_features)
            x = fix_x.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
        else:
            if hyperparameters.get('sampling', 'uniform') == 'uniform':
                x = torch.rand(batch_size, seq_len, num_features, device=device)
            else:
                x = torch.randn(batch_size, seq_len, num_features, device=device)

        # The model is rebuilt on every attempt, so no model is needed up front.
        successful_sample = False
        while not successful_sample:
            try:
                with gpytorch.settings.prior_mode(True):
                    model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
                    model.to(device)

                    d = model(x)
                    sample_wo_noise = d.sample().transpose(0, 1)  # this will be the target for the loss
                    sample = likelihood(sample_wo_noise).sample()  # this will be the input to the Transformer
                    successful_sample = True
            except RuntimeError:  # This can happen when torch.linalg.eigh fails. Restart with new init resolves this.
                print('GP Sampling unsuccessful, retrying.. ')
                print(x)
                print(hyperparameters)

    if bool(torch.any(torch.isnan(x)).detach().cpu().numpy()):
        print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
                  , "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size})

    # TODO: Multi output
    target = sample if hyperparameters.get("observation_noise", True) else sample_wo_noise
    return x.transpose(0, 1), sample, target
|
96 |
-
|
97 |
-
# Dataloader class derived from this module's get_batch.
DataLoader = get_batch_to_dataloader(get_batch)
|
98 |
-
|
99 |
-
def get_model_on_device(x, y, hyperparameters, device):
    """Build the GP via ``get_model`` and move the model to ``device``.

    Returns:
        The ``(model, likelihood)`` pair.
    """
    built = get_model(x, y, hyperparameters)
    gp_model = built[0]
    gp_model.to(device)
    return gp_model, built[1]
|
103 |
-
|
104 |
-
|
105 |
-
@torch.no_grad()
def evaluate(x, y, y_non_noisy, use_mse=False, hyperparameters=None, get_model_on_device=get_model_on_device, device=default_device, step_size=1, start_pos=0):
    """Evaluate GP predictions for growing training prefixes.

    For every ``t`` in ``range(max(start_pos, 1), len(x), step_size)`` a model
    is built on ``x[:t], y[:t]`` and scored on position ``t``, using either the
    squared error of the predictive mean (``use_mse``) or the negative log
    probability of ``y[t]``.

    Args:
        x, y: inputs and targets indexed by sequence position first.
        y_non_noisy: accepted for interface compatibility; not used here.
        use_mse: score with MSE instead of negative log-likelihood.
        hyperparameters: mapping passed to ``get_model_on_device``; ``None`` is
            treated as an empty dict.
        get_model_on_device: callable ``(x, y, hyperparameters, device)``
            returning ``(model, likelihood)``.
        device: device passed to ``get_model_on_device``.
        step_size: stride over evaluation positions.
        start_pos: first evaluation position (position 0 is never evaluated).

    Returns:
        ``(all_losses, mean_losses, elapsed_seconds)``; ``mean_losses`` starts
        with a ``0.0`` placeholder when ``start_pos == 0``.
    """
    hyperparameters = {} if hyperparameters is None else hyperparameters
    start_time = time.time()
    losses_after_t = [.0] if start_pos == 0 else []
    all_losses_after_t = []

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))), gpytorch.settings.fast_pred_var(False):
        for t in range(max(start_pos, 1), len(x), step_size):
            model, likelihood = get_model_on_device(x[:t].transpose(0, 1), y[:t].transpose(0, 1), hyperparameters, device)

            model.eval()
            f = model(x[t].unsqueeze(1))
            l = likelihood(f)
            means = l.mean.squeeze()
            varis = l.covariance_matrix.squeeze()

            # One predictive mean and variance per batch element is expected.
            assert len(means.shape) == len(varis.shape) == 1
            assert len(means) == len(varis) == x.shape[1]

            if use_mse:
                c = nn.MSELoss(reduction='none')
                ls = c(means, y[t])
            else:
                ls = -l.log_prob(y[t].unsqueeze(1))

            losses_after_t.append(ls.mean())
            all_losses_after_t.append(ls.flatten())
    return torch.stack(all_losses_after_t).to('cpu'), torch.tensor(losses_after_t).to('cpu'), time.time() - start_time
|
138 |
-
|
139 |
-
if __name__ == '__main__':
    # A dict is required here: evaluate() calls hyperparameters.get(), and the
    # tuple form accepted by get_batch() indexes positions beyond a 3-tuple.
    hps = {"noise": .1, "outputscale": .1, "lengthscale": .1}
    for redo_idx in range(1):
        print(
            evaluate(*get_batch(1000, 10, hyperparameters=hps, num_features=10), use_mse=False, hyperparameters=hps))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/priors/fast_gp_mix.py
DELETED
@@ -1,394 +0,0 @@
|
|
1 |
-
import time
|
2 |
-
import functools
|
3 |
-
import random
|
4 |
-
import math
|
5 |
-
import traceback
|
6 |
-
|
7 |
-
import numpy as np
|
8 |
-
import torch
|
9 |
-
from torch import nn
|
10 |
-
import gpytorch
|
11 |
-
from botorch.models import SingleTaskGP
|
12 |
-
from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
|
13 |
-
from botorch.fit import fit_gpytorch_model
|
14 |
-
from gpytorch.mlls import ExactMarginalLogLikelihood
|
15 |
-
from gpytorch.likelihoods import GaussianLikelihood
|
16 |
-
from gpytorch.priors.torch_priors import GammaPrior, UniformPrior
|
17 |
-
from gpytorch.constraints import GreaterThan
|
18 |
-
|
19 |
-
|
20 |
-
from bar_distribution import BarDistribution
|
21 |
-
from utils import default_device
|
22 |
-
from .utils import get_batch_to_dataloader
|
23 |
-
from . import fast_gp
|
24 |
-
|
25 |
-
def _build_handmade_gp(x, y):
    """Build the hand-written exact GP with uniform priors on its parameters."""

    # We will use the simplest form of GP model, exact inference
    class ExactGPModel(gpytorch.models.ExactGP):
        def __init__(self, train_x, train_y, likelihood):
            super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel())
            self.mean_module.register_prior("mean_prior", UniformPrior(-1, 1), "constant")
            self.covar_module.base_kernel.register_prior("lengthscale_prior", UniformPrior(0.01, 0.5),
                                                         "lengthscale")
            # model.covar_module.base_kernel.register_prior("period_length_prior", UniformPrior(0.05, 2.5), "period_length")
            self.covar_module.register_prior("outputscale_prior", UniformPrior(1, 2), "outputscale")
            likelihood.register_prior("noise_prior", UniformPrior(0.001, 0.01), "noise")
            self.to(x)

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())
    return ExactGPModel(x, y, likelihood)


def _build_botorch_gp(x, y, hyperparameters):
    """Build a ``SingleTaskGP`` with a Matern kernel and Gamma priors."""
    targets = y.unsqueeze(-1)
    aug_batch_shape = SingleTaskGP(x, targets)._aug_batch_shape
    noise_prior = GammaPrior(hyperparameters.get('noise_concentration', 1.1),
                             hyperparameters.get('noise_rate', 0.05))
    noise_prior_mode = (noise_prior.concentration - 1) / noise_prior.rate
    likelihood = GaussianLikelihood(
        noise_prior=noise_prior,
        batch_shape=aug_batch_shape,
        noise_constraint=GreaterThan(
            MIN_INFERRED_NOISE_LEVEL,
            transform=None,
            initial_value=noise_prior_mode,
        ),
    )
    base_kernel = gpytorch.kernels.MaternKernel(
        nu=hyperparameters.get('nu', 2.5),
        ard_num_dims=x.shape[-1],
        batch_shape=aug_batch_shape,
        lengthscale_prior=gpytorch.priors.GammaPrior(
            hyperparameters.get('lengthscale_concentration', 3.0),
            hyperparameters.get('lengthscale_rate', 6.0)),
    )
    covar_module = gpytorch.kernels.ScaleKernel(
        base_kernel,
        batch_shape=aug_batch_shape,
        outputscale_prior=gpytorch.priors.GammaPrior(
            hyperparameters.get('outputscale_concentration', .5),
            hyperparameters.get('outputscale_rate', 0.15)),
    )
    return SingleTaskGP(x, targets, covar_module=covar_module, likelihood=likelihood)


def get_model(x, y, hyperparameters: dict, sample=True):
    """Build a GP and, by default, sample its hyperparameters from the priors.

    Args:
        x, y: training inputs and targets handed to the GP.
        hyperparameters: configuration dict. ``handmade`` selects the
            hand-written exact GP; otherwise a ``SingleTaskGP`` is built whose
            prior parameters are read from the dict, with defaults.
        sample: if true, return ``model.pyro_sample_from_prior()`` and its
            likelihood; otherwise return the un-sampled model.

    Returns:
        The ``(model, likelihood)`` pair.

    Raises:
        AssertionError: when ``sample`` is false but ``sigmoid`` or
            ``y_minmax_norm`` is set in ``hyperparameters``.
    """
    if hyperparameters.get('handmade', False):
        model = _build_handmade_gp(x, y)
    else:
        model = _build_botorch_gp(x, y, hyperparameters)

    likelihood = model.likelihood
    model.to(x.device)
    if sample:
        sampled_model = model.pyro_sample_from_prior()
        return sampled_model, sampled_model.likelihood

    assert not(hyperparameters.get('sigmoid', False)) and not(hyperparameters.get('y_minmax_norm', False)), "Sigmoid and y_minmax_norm can only be used to sample models..."
    return model, likelihood
|
84 |
-
|
85 |
-
|
86 |
-
def _minmax_normalize_rows(values):
    """Scale each row of a 2-D tensor to [0, 1] using its own min and max."""
    row_min = values.min(1, keepdim=True)[0]
    row_max = values.max(1, keepdim=True)[0]
    return (values - row_min) / (row_max - row_min)


@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
              batch_size_per_gp_sample=None,
              fix_to_range=None, equidistant_x=False, **kwargs):
    '''
    This function is very similar to the equivalent in .fast_gp. The only difference is that this function operates over
    a mixture of GP priors.

    The batch is produced in groups of ``batch_size_per_gp_sample`` functions;
    each group uses one model returned by ``get_model``.

    :param batch_size: number of functions to return; must be divisible by ``batch_size_per_gp_sample``.
    :param seq_len: number of points per function.
    :param num_features: dimensionality of each input point.
    :param device: device for the inputs and the models.
    :param hyperparameters: configuration dict forwarded to ``get_model``; also read for
        ``fast_computations``, ``handmade``, ``observation_noise``, ``y_minmax_norm`` and ``sigmoid``.
    :param batch_size_per_gp_sample: group size; defaults to ``max(batch_size // 10, 1)``.
    :param fix_to_range: optional ``(low, high)``; twice as many candidates are drawn per group and only
        functions whose values all lie in ``[low, high)`` are kept.
    :param equidistant_x: use an evenly spaced grid on [0, 1] (requires ``num_features == 1``).
    :return: ``(x, sample, target_sample)`` with the sequence axis first; ``target_sample`` is the
        noise-free sample when ``observation_noise`` is falsy, otherwise it is ``sample``.
    '''
    hyperparameters = hyperparameters or {}
    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
        batch_size_per_gp_sample = (batch_size_per_gp_sample or max(batch_size // 10, 1))
        assert batch_size % batch_size_per_gp_sample == 0

        total_num_candidates = batch_size * (2 ** (fix_to_range is not None))
        num_candidates = batch_size_per_gp_sample * (2 ** (fix_to_range is not None))
        if equidistant_x:
            assert num_features == 1
            # Created on `device` so it matches the models moved there below.
            x = torch.linspace(0, 1., seq_len, device=device).unsqueeze(0).repeat(total_num_candidates, 1).unsqueeze(-1)
        else:
            x = torch.rand(total_num_candidates, seq_len, num_features, device=device)

        sampling_with_observation_noise = hyperparameters.get("observation_noise", True)
        kept_x = []
        samples = []
        samples_wo_noise = []
        for i in range(0, total_num_candidates, num_candidates):
            candidate_x = x[i:i + num_candidates]
            model, likelihood = get_model(candidate_x, torch.zeros(num_candidates, x.shape[1]).to(device), hyperparameters)
            model.to(device)
            likelihood.to(device)
            if hyperparameters.get('handmade', False):
                model.covar_module.base_kernel.lengthscale = model.covar_module.base_kernel.lengthscale.to(device)
                model.covar_module.outputscale = model.covar_module.outputscale.to(device)
                likelihood.noise = likelihood.noise.to(device)
                model.mean_module.constant = model.mean_module.constant.to(device)

            successful_sample = 0
            throwaway_share = 0.
            while successful_sample < 1:
                with gpytorch.settings.prior_mode(True):
                    d = model(candidate_x)
                    if sampling_with_observation_noise:
                        sample = likelihood(d).sample()  # bs_per_gp_s x T
                        sample_wo_noise = None
                    else:
                        sample_wo_noise = d.sample()
                        sample = likelihood(sample_wo_noise).sample()

                    if hyperparameters.get('y_minmax_norm'):
                        # keepdim-based scaling so each function is normalised
                        # by its own extrema whatever the batch size is.
                        sample = _minmax_normalize_rows(sample)
                    if hyperparameters.get('sigmoid'):
                        sample = sample.sigmoid()

                    if not sampling_with_observation_noise:
                        if hyperparameters.get('y_minmax_norm'):
                            sample_wo_noise = _minmax_normalize_rows(sample_wo_noise)
                        if hyperparameters.get('sigmoid'):
                            sample_wo_noise = sample_wo_noise.sigmoid()

                    if fix_to_range is None:
                        kept_x.append(candidate_x)
                        samples.append(sample.transpose(0, 1))
                        if not sampling_with_observation_noise:
                            samples_wo_noise.append(sample_wo_noise.transpose(0, 1))
                        successful_sample = True
                        continue

                    smaller_mask = sample < fix_to_range[0]
                    larger_mask = sample >= fix_to_range[1]
                    in_range_mask = ~ (smaller_mask | larger_mask).any(1)
                    throwaway_share += (~in_range_mask[:batch_size_per_gp_sample]).sum() / batch_size_per_gp_sample
                    if in_range_mask.sum() < batch_size_per_gp_sample:
                        successful_sample -= 1
                        # NOTE(review): this condition holds on every failed
                        # attempt, so the hint is printed each time; confirm
                        # whether a threshold such as -100 was intended.
                        if successful_sample < 100:
                            print("Please change hyper-parameters (e.g. decrease outputscale_mean) it"
                                  "seems like the range is set to tight for your hyper-parameters.")
                        continue

                    # Keep inputs, noisy samples and noise-free samples of the
                    # same accepted candidates so that they stay aligned.
                    kept_x.append(candidate_x[in_range_mask][:batch_size_per_gp_sample])
                    samples.append(sample[in_range_mask][:batch_size_per_gp_sample].transpose(0, 1))
                    if not sampling_with_observation_noise:
                        samples_wo_noise.append(
                            sample_wo_noise[in_range_mask][:batch_size_per_gp_sample].transpose(0, 1))
                    successful_sample = True

        if random.random() < .01:
            print('throwaway share', throwaway_share / (batch_size // batch_size_per_gp_sample))

        x = torch.cat(kept_x, 0)
        # TODO think about enabling the line below
        # sample = sample - sample[0, :].unsqueeze(0).expand(*sample.shape)
        x = x.transpose(0, 1)
        sample = torch.cat(samples, 1)

        if sampling_with_observation_noise:
            target_sample = sample
        else:
            target_sample = torch.cat(samples_wo_noise, 1)

        assert x.shape[:2] == sample.shape[:2]

    return x, sample, target_sample  # x.shape = (T,B,H)
|
198 |
-
|
199 |
-
|
200 |
-
class DataLoader(get_batch_to_dataloader(get_batch)):
    """Prior data loader for this module's ``get_batch`` with an MSE-based validation hook."""

    @torch.no_grad()
    def validate(self, model, step_size=1, start_pos=0):
        """Return the MSE of the model's predicted means on one freshly sampled batch.

        One batch is drawn through ``self.gbm``; for every evaluation position in
        ``range(start_pos, len(x), step_size)`` the model is queried and the mean
        of its bar distribution is compared with the target at that position.

        :param model: model exposing ``criterion`` and callable as
            ``model((x, y), single_eval_pos=...)``
        :param step_size: stride between evaluated positions
        :param start_pos: first evaluated position
        :return: stacked per-position losses, or the placeholder ``123.`` when the
            criterion is not a ``BarDistribution``
        """
        if not isinstance(model.criterion, BarDistribution):
            # Placeholder for criteria this validation does not support.
            return 123.

        (_, inputs, observed), targets, _ = self.gbm(**self.get_batch_kwargs)
        squared_error = nn.MSELoss()
        model.eval()
        per_position = []
        for cutoff in range(start_pos, len(inputs), step_size):
            logits = model((inputs, observed), single_eval_pos=cutoff)
            predicted_means = model.criterion.mean(logits)  # num_evals x batch_size
            per_position.append(squared_error(predicted_means[0], targets[cutoff]))
        # The model is always left in training mode afterwards.
        model.train()
        return torch.stack(per_position)
|
216 |
-
|
217 |
-
|
218 |
-
@torch.enable_grad()
def get_fitted_model(x, y, hyperparameters, device):
    """Build a GP for ``(x, y)`` and fit it via ``fit_gpytorch_model``.

    Gradients are force-enabled by the decorator, so this works even when the
    caller runs under ``torch.no_grad()``.

    :return: tuple ``(model, likelihood)`` as produced by ``get_model``
    """
    gp, gp_likelihood = get_model(x, y, hyperparameters, sample=False)
    gp.to(device)
    objective = ExactMarginalLogLikelihood(gp_likelihood, gp)
    gp.train()
    fit_gpytorch_model(objective)
    return gp, gp_likelihood
|
229 |
-
|
230 |
-
|
231 |
-
# `fast_gp.evaluate` with its `get_model_on_device` argument bound to `get_fitted_model` above.
evaluate = functools.partial(fast_gp.evaluate, get_model_on_device=get_fitted_model)
|
232 |
-
|
233 |
-
def get_mcmc_model(x, y, hyperparameters, device, num_samples, warmup_steps, obs=True):
    """Run NUTS over the GP hyperparameters and load the drawn samples into the model.

    :param x: training inputs, moved to ``device``
    :param y: training targets, moved to ``device``
    :param hyperparameters: forwarded to ``get_model``
    :param num_samples: number of MCMC samples to keep
    :param warmup_steps: number of MCMC warm-up steps
    :param obs: when true the pyro model conditions on ``y``; otherwise the pyro
        model has no observation site
    :return: tuple ``(model, likelihood)`` with the model switched to eval mode
    """
    from pyro.infer.mcmc import NUTS, MCMC, HMC
    import pyro

    x, y = x.to(device), y.to(device)
    gp, gp_likelihood = get_model(x, y, hyperparameters, sample=False)
    gp.to(device)

    def pyro_model(inputs, targets):
        # Draw hyperparameters from the prior, then score the observations under them.
        drawn = gp.pyro_sample_from_prior()
        predictive = drawn.likelihood(drawn(inputs))
        if obs:
            return pyro.sample("obs", predictive, obs=targets)

    sampler = MCMC(NUTS(pyro_model), num_samples=num_samples, warmup_steps=warmup_steps, num_chains=1)
    sampler.run(x, y)
    gp.pyro_load_from_samples(sampler.get_samples())
    gp.eval()
    return gp, gp_likelihood
|
262 |
-
# output = model(x[-1].unsqueeze(1).repeat(1, num_samples 1))
|
263 |
-
# return output.mean
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
def get_mean_logdensity(dists, x: torch.Tensor, full_range=None):
    """Return the log of the average density that a set of Gaussians assigns to ``x``.

    The means and variances of all ``dists`` are flattened into one vector each,
    element ``i`` of ``x`` is scored under Gaussian ``i``, and the result is
    ``logsumexp(logprobs) - log(n)``, i.e. the log of the mean density.

    :param dists: iterable of objects exposing ``mean`` and ``variance`` tensors
    :param x: 1-d tensor with one value per flattened Gaussian
    :param full_range: optional ``(low, high)``; when given, each density is
        renormalised to the probability mass it places inside that interval
    :return: 0-dim tensor holding the log mean density
    :raises AssertionError: if the concatenated statistics or the log
        probabilities are not 1-dimensional
    """
    def _as_vector(t):
        # squeeze() collapses a single-element tensor to 0-dim, which torch.cat
        # refuses; keep at least one dimension so a single sample still works.
        t = t.squeeze()
        return t.unsqueeze(0) if t.dim() == 0 else t

    means = torch.cat([_as_vector(d.mean) for d in dists], 0)
    variances = torch.cat([_as_vector(d.variance) for d in dists], 0)
    assert len(means.shape) == 1 and len(variances.shape) == 1
    dist = torch.distributions.Normal(means, variances.sqrt())
    logprobs = dist.log_prob(x)
    if full_range is not None:
        # Build the bounds with the dtype/device of the means so the cdf call
        # never mixes devices or integer and floating tensors.
        lower = means.new_tensor(full_range[0])
        upper = means.new_tensor(full_range[1])
        used_weight = 1. - (dist.cdf(lower) + (1. - dist.cdf(upper)))
        log_used_weight = torch.log(used_weight)
        if torch.isinf(log_used_weight).any():
            print('factor is inf', -log_used_weight)
        logprobs -= log_used_weight
    assert len(logprobs.shape) == 1
    return torch.logsumexp(logprobs, 0) - math.log(len(logprobs))
|
283 |
-
|
284 |
-
|
285 |
-
def evaluate_(x, y, y_non_noisy, hyperparameters=None, device=default_device, num_samples=100, warmup_steps=300,
              full_range=None, min_seq_len=0, use_likelihood=False, obs=True):
    """Score an MCMC-fitted GP on every next-point prediction of every dataset in a batch.

    For each time step ``t`` and each batch column ``b_i`` a GP is built on the
    first ``t`` points, its hyperparameters are sampled with NUTS, and the
    negative log mean density of point ``t`` is recorded.

    :param x: inputs, indexed as ``x[t, b_i, :]`` (sequence, batch, features)
    :param y: targets, indexed as ``y[t, b_i]``
    :param y_non_noisy: not used by this function
    :param hyperparameters: dict forwarded to ``get_model``
        NOTE(review): the default ``None`` fails immediately on ``.get`` below,
        so callers effectively have to pass a dict - confirm and drop the default.
    :param num_samples: MCMC samples per fit
    :param warmup_steps: MCMC warm-up steps per fit
    :param full_range: forwarded to ``get_mean_logdensity``
    :param min_seq_len: first evaluated step is ``max(min_seq_len, 1)``
    :param use_likelihood: when true, predictions are passed through the likelihood
    :param obs: when true the pyro model conditions on the training targets
    :return: ``(losses_after_t tensor, elapsed seconds, per-step lists of losses)``
    """
    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations',(True,True,True))), gpytorch.settings.fast_pred_var(False):
        x = x.to(device).double()
        y = y.to(device).double()
        start_time = time.time()
        # Step 0 has no training data; it is reported as a 0 loss when it is in range.
        losses_after_t = [.0] if min_seq_len == 0 else []
        all_losses = []

        for t in range(max(min_seq_len,1), len(x)):
            #print('Timestep', t)
            loss_sum = 0.
            step_losses = []
            start_step = time.time()
            print(x.shape, y.shape)
            for b_i in range(x.shape[1]):
                # Training prefix of dataset b_i.
                x_train = x[:t,b_i]
                y_train = y[:t,b_i]
                from pyro.infer.mcmc import NUTS, MCMC, HMC
                import pyro
                x_train = x_train.to(device)
                y_train = y_train.to(device)
                print(x_train.shape, y_train.shape)
                model, likelihood = get_model(x_train, y_train, hyperparameters, sample=False)
                model.to(device)

                def pyro_model(x, y):
                    sampled_model = model.pyro_sample_from_prior()
                    output = sampled_model.likelihood(sampled_model(x))
                    if obs:
                        return pyro.sample("obs", output, obs=y)

                nuts_kernel = NUTS(pyro_model)
                mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, num_chains=1, disable_progbar=True)
                # print(x.shape)
                mcmc_run.run(x_train, y_train)
                # print(mcmc_run.get_samples())
                model.pyro_load_from_samples(mcmc_run.get_samples())
                model.eval()

                with torch.no_grad():
                    # The query point is repeated once per MCMC sample.
                    dists = model(x[t, b_i, :].unsqueeze(
                        0).repeat(num_samples, 1, 1))
                    if use_likelihood:
                        dists = likelihood(dists)
                    l = -get_mean_logdensity([dists], y[t, b_i].repeat(num_samples), full_range)
                    print(l)

                step_losses.append(l.item())
                #print('loss',l.item())
                print(f'current average loss at step {t} is {sum(step_losses)/len(step_losses)} with {(time.time()-start_step)/len(step_losses)} s per eval.')
                loss_sum += l

            # Average over the batch dimension.
            loss_sum /= x.shape[1]
            all_losses.append(step_losses)
            print(f'loss after step {t} is {loss_sum}')
            losses_after_t.append(loss_sum)
            print(f'losses so far {torch.tensor(losses_after_t)}')
        return torch.tensor(losses_after_t), time.time() - start_time, all_losses
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
# Command-line entry point: sample one batch from the prior and score the MCMC baseline on it.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--seq_len', type=int)
    parser.add_argument('--min_seq_len', type=int, default=0)
    parser.add_argument('--warmup_steps', type=int)
    parser.add_argument('--num_samples', type=int)
    parser.add_argument('--min_y', type=int)
    parser.add_argument('--max_y', type=int)
    parser.add_argument('--dim', type=int, default=1)
    parser.add_argument('--use_likelihood', action='store_true')
    parser.add_argument('--device', default='cpu')
    # NOTE(review): the flag name is misspelled ("concentraion"); it is part of the
    # CLI, so renaming it would break existing invocations.
    parser.add_argument('--outputscale_concentraion', default=2., type=float)
    parser.add_argument('--noise_concentration', default=1.1, type=float)
    parser.add_argument('--noise_rate', default=.05, type=float)
    parser.add_argument('--handmade', action='store_true')
    parser.add_argument('--no_obs', action='store_true')
    parser.add_argument('--seed', type=int, default=0)

    args = parser.parse_args()
    import pyro
    import gpytorch
    print(pyro.__version__)
    print(gpytorch.__version__)


    print('min_y:', args.min_y)
    # The range is only used when --min_y is given; --max_y is not checked separately.
    full_range = (None if args.min_y is None else (args.min_y,args.max_y))

    hps = {'handmade': args.handmade, 'outputscale_concentration': args.outputscale_concentraion, 'noise_concentration': args.noise_concentration,
           'noise_rate': args.noise_rate, 'fast_computations': (False,False,False)}
    # A seed of 0 (the default) leaves all generators unseeded.
    if args.seed:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)
    x, y, _ = get_batch(args.batch_size, args.seq_len, args.dim, fix_to_range=full_range, hyperparameters=hps)
    #assert args.seq_len == 7 and args.min_seq_len == 6
    #x = torch.cat([torch.linspace(0, 1, 6), torch.tensor([.33])]).unsqueeze(1).repeat(1,args.batch_size).unsqueeze(-1)
    #y = torch.sin(x * (2 * math.pi)).squeeze(-1)
    print('RESULT:', evaluate_(x, y, y, device=args.device, warmup_steps=args.warmup_steps,
                               num_samples=args.num_samples, full_range=full_range, min_seq_len=args.min_seq_len,
                               hyperparameters=hps, use_likelihood=args.use_likelihood, obs=not args.no_obs))
|
393 |
-
|
394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/priors/gp.py
DELETED
@@ -1,69 +0,0 @@
|
|
1 |
-
import time
|
2 |
-
import random
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
from torch import nn
|
7 |
-
from sklearn.gaussian_process import GaussianProcessRegressor
|
8 |
-
from sklearn.gaussian_process.kernels import RBF, DotProduct, WhiteKernel
|
9 |
-
from .utils import get_batch_to_dataloader
|
10 |
-
|
11 |
-
|
12 |
-
# Default RBF length scale, used by `get_gp` when none is given and as the default of `evaluate`.
length_scale_sampling_gp = .6
|
13 |
-
|
14 |
-
def get_gp(length_scale=None):
    """Create an sklearn GP regressor with a fixed-length-scale RBF kernel.

    A falsy ``length_scale`` (``None`` or ``0``) falls back to the module
    default ``length_scale_sampling_gp``. The kernel is not optimised
    (``optimizer=None``, bounds ``'fixed'``).
    """
    effective_scale = length_scale or length_scale_sampling_gp
    rbf_kernel = RBF(length_scale=effective_scale, length_scale_bounds='fixed')
    return GaussianProcessRegressor(kernel=rbf_kernel, random_state=0, optimizer=None)
|
18 |
-
|
19 |
-
|
20 |
-
def get_batch(batch_size, seq_len, num_features, noisy_std=None):
    """Sample a batch of functions from a fixed RBF GP prior at uniformly random inputs.

    :param batch_size: number of independent datasets
    :param seq_len: number of points per dataset
    :param num_features: input dimensionality
    :param noisy_std: despite its name, this is passed to ``get_gp`` as the
        kernel length scale; ``None`` selects the module default
    :return: ``(x, y, y)`` with ``x`` of shape ``(seq_len, batch_size, num_features)``
        and ``y`` of shape ``(seq_len, batch_size)``; targets equal the observations
    """
    # m = torch.normal(0.,.1,size=(batch_size,num_features))
    # m2 = torch.rand(batch_size,num_features)
    # b = 0 # torch.rand(batch_size)
    x_t = torch.rand(batch_size, seq_len, num_features)
    # gp_b = TensorGP(kernel=TensorRBF(noisy_std))
    # y_t = gp_b.sample_from_GP_prior(x_t).detach()

    gpr = get_gp(noisy_std)
    y_t = torch.zeros(batch_size, seq_len)

    for i in range(len(y_t)):
        # randint is inclusive at both ends and valid seeds stop at 2**32 - 1,
        # so the upper bound must be 2**32 - 1 rather than 2**32.
        y_t[i] += gpr.sample_y(x_t[i], random_state=random.randint(0, 2 ** 32 - 1)).squeeze()
    x, y = x_t.transpose(0, 1), y_t.transpose(0, 1)
    # x, _ = torch.sort(x,dim=0)
    return x, y, y
|
36 |
-
|
37 |
-
|
38 |
-
# Data loader class generated from this module's `get_batch`.
DataLoader = get_batch_to_dataloader(get_batch)
|
39 |
-
|
40 |
-
def evaluate(x, y, y_non_noisy, use_mse=False, length_scale=length_scale_sampling_gp):
    """Score the fixed-kernel GP baseline on next-point prediction.

    For each position ``t >= 1`` and each dataset, a GP is fitted on the first
    ``t`` points and evaluated on point ``t``.

    :param x: inputs indexed as ``x[t, b_i]``
    :param y: targets indexed as ``y[t, b_i]``
    :param y_non_noisy: not used by this function
    :param use_mse: use MSE on the predictive mean instead of the Gaussian NLL
    :param length_scale: RBF length scale handed to ``get_gp``
    :return: ``(per-position mean losses with a leading 0, elapsed seconds)``
    """
    started = time.time()
    criterion = nn.MSELoss() if use_mse else nn.GaussianNLLLoss(full=True)
    num_datasets = x.shape[1]
    mean_losses = [.0]
    for cutoff in range(1, len(x)):
        running = 0.
        for col in range(num_datasets):
            fitted = get_gp(length_scale).fit(x[:cutoff, col], y[:cutoff, col])
            means, stds = fitted.predict(x[cutoff, col].unsqueeze(0), return_std=True)
            assert len(means) == 1 == len(stds)
            prediction = torch.tensor(means)
            target = y[cutoff, col].unsqueeze(-1)
            if use_mse:
                running += criterion(prediction, target)
            else:
                running += criterion(prediction, target, var=torch.tensor(stds) ** 2)

        mean_losses.append(running / num_datasets)

    return torch.tensor(mean_losses), time.time()-started
|
62 |
-
|
63 |
-
# Quick check: evaluate with the sampling length scale and with values 10% above and below it.
if __name__ == '__main__':
    ls = .1
    for alpha in {ls, ls * 1.1, ls * .9}:
        print(alpha)
        for redo_idx in range(1):
            batch = get_batch(1000, 10, noisy_std=ls, num_features=10)
            print(
                evaluate(*batch, use_mse=False, length_scale=alpha))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/priors/prior.py
DELETED
@@ -1,25 +0,0 @@
|
|
1 |
-
from abc import ABCMeta, abstractmethod
|
2 |
-
from torch.utils.data import DataLoader
|
3 |
-
|
4 |
-
|
5 |
-
class PriorDataLoader(DataLoader, metaclass=ABCMeta):
    # Abstract interface that every prior data loader has to implement.
    @abstractmethod
    def __init__(self, num_steps, batch_size, eval_pos_seq_len_sampler, seq_len_maximum, device, **kwargs):
        """
        Set up a data loader that draws batches from a prior.

        :param num_steps: int, first argument, the number of steps to take per epoch, i.e. iteration of the DataLoader
        :param batch_size: int, number of datasets per batch
        :param eval_pos_seq_len_sampler: callable, it takes no arguments and returns a tuple (single eval pos, bptt)
        :param seq_len_maximum: presumably the largest sequence length the sampler can return - TODO confirm
        :param device: presumably the device the batches should be created on - TODO confirm
        :param kwargs: for future compatibility it is good to have a final catch-all, as new kwargs might be introduced
        """
        pass

    # A class or object variable `num_features`: int
    # Optional: `validate` function that accepts a transformer model

    # The DataLoader iter should return batches of the form ([style], x, y), target_y, single_eval_pos
    # We follow sequence len (s) first, batch size (b) second. So x: (s,b,num_features), y,target_y: (s,b)
    # and style: Optional[(b,num_style_params)], style can be omitted or set to None, if it is not intended to be used.

    # For more references, see `priors/utils.py` for a pretty general implementation of a DataLoader
    # and `train.py` for the only call of it.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/priors/pyro.py
DELETED
@@ -1,41 +0,0 @@
|
|
1 |
-
import random
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch import nn
|
5 |
-
|
6 |
-
from utils import default_device
|
7 |
-
from .utils import get_batch_to_dataloader
|
8 |
-
|
9 |
-
|
10 |
-
def get_batch(batch_size, seq_len, batch_size_per_gp_sample=None, **config):
    """Sample a batch of datasets from freshly constructed generative models.

    ``config['model']`` is called once per group to build a model; every model
    is then called ``batch_size_per_gp_sample`` times as ``model(seq_len=...)``
    and must return an ``(x, y)`` pair.

    :param batch_size: number of datasets in the batch
    :param seq_len: number of points per dataset
    :param batch_size_per_gp_sample: datasets drawn per model; defaults to
        ``batch_size // 16``, but never less than 1
    :param config: must contain ``'model'`` (a zero-argument factory). Optional
        flags: ``'normalize_y'`` standardises both x and y; otherwise
        ``'normalize'`` standardises x only
    :return: ``(x, y, y)`` with the datasets stacked along dimension 1
    :raises AssertionError: if ``batch_size`` is not divisible by
        ``batch_size_per_gp_sample``
    """
    # Clamp to 1 so batches smaller than 16 do not end up dividing by zero below.
    batch_size_per_gp_sample = batch_size_per_gp_sample or max(batch_size // 16, 1)
    assert batch_size % batch_size_per_gp_sample == 0, 'Please choose a batch_size divisible by batch_size_per_gp_sample.'
    num_models = batch_size // batch_size_per_gp_sample
    # standard kaiming uniform init currently...

    models = [config['model']() for _ in range(num_models)]

    # Model-major order: all draws of the first model, then the second, ...
    sample = [model(seq_len=seq_len) for model in models for _ in range(batch_size_per_gp_sample)]

    def normalize_data(data):
        # Standardise along the sequence dimension; the epsilon guards constant columns.
        mean = data.mean(0)
        std = data.std(0) + .000001
        return (data - mean) / std

    x, y = zip(*sample)

    y = torch.stack(y, 1).squeeze(-1).detach()
    x = torch.stack(x, 1).detach()

    # Each flag is looked up under its own key; the previous code tested for
    # 'normalize_y' but then read 'normalize', which either raised KeyError or
    # made the x-only branch unreachable.
    if config.get('normalize_y'):
        x, y = normalize_data(x), normalize_data(y)
    elif config.get('normalize'):
        x = normalize_data(x)

    return x, y, y
|
38 |
-
|
39 |
-
|
40 |
-
# Data loader class generated from this module's `get_batch`.
DataLoader = get_batch_to_dataloader(get_batch)
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/priors/ridge.py
DELETED
@@ -1,37 +0,0 @@
|
|
1 |
-
import random
|
2 |
-
import time
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
from torch import nn
|
7 |
-
from sklearn.linear_model import Ridge
|
8 |
-
from .utils import get_batch_to_dataloader
|
9 |
-
|
10 |
-
def get_batch(batch_size, seq_len, num_features, noisy_std = .1):
    """Sample noisy linear-regression datasets without an intercept.

    Each dataset gets its own weight vector drawn from N(0, 0.1^2); inputs are
    uniform on [0, 1).

    :param noisy_std: standard deviation of the Gaussian observation noise
    :return: ``(x, y, y_non_noisy)`` with ``x`` of shape
        ``(seq_len, batch_size, num_features)`` and both targets of shape
        ``(seq_len, batch_size)``
    """
    weights = torch.normal(0., .1, size=(batch_size, num_features))
    x = torch.rand(seq_len, batch_size, num_features)
    # Per-dataset dot product of weights and inputs.
    y_non_noisy = torch.einsum('bf,tbf->tb', weights, x)
    noise_mean = torch.zeros_like(y_non_noisy)
    y = y_non_noisy + torch.normal(noise_mean, noisy_std)  # noisy_std is alpha
    return x, y, y_non_noisy
|
17 |
-
|
18 |
-
# Data loader class generated from this module's `get_batch`.
DataLoader = get_batch_to_dataloader(get_batch)
|
19 |
-
|
20 |
-
|
21 |
-
def evaluate(x,y,y_non_noisy, alpha=0.):
    """Score a ridge-regression baseline on next-point prediction.

    For each position ``t >= 1`` and each dataset, a ``Ridge`` model is fitted
    on the first ``t`` noisy points and its prediction for point ``t`` is
    compared with the noise-free target via MSE.

    :param alpha: ridge regularisation strength
    :return: ``(per-position mean losses with a leading 0, elapsed seconds)``
    """
    started = time.time()
    squared_error = nn.MSELoss()
    num_datasets = x.shape[1]
    mean_losses = [.0]
    for cutoff in range(1, len(x)):
        running = 0.
        for col in range(num_datasets):
            regressor = Ridge(alpha=alpha)
            regressor.fit(x[:cutoff, col], y[:cutoff, col])
            predicted = regressor.predict(x[cutoff, col].unsqueeze(0))
            running += squared_error(y_non_noisy[cutoff, col].unsqueeze(0), torch.tensor(predicted))
        mean_losses.append(running / num_datasets)
    return torch.tensor(mean_losses), time.time()-started
|
34 |
-
|
35 |
-
# Quick check: compare several ridge strengths on freshly sampled batches.
if __name__ == '__main__':
    for alpha in [.001,.01,.5,1.]:
        # get_batch requires num_features; the call previously omitted it and
        # raised TypeError. NOTE(review): 10 is an arbitrary choice - adjust as needed.
        print(alpha, evaluate(*get_batch(1000,10,num_features=10,noisy_std=.01),alpha=alpha))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/priors/stroke.py
DELETED
@@ -1,143 +0,0 @@
|
|
1 |
-
from PIL import Image, ImageDraw, ImageFilter
|
2 |
-
import random
|
3 |
-
import math
|
4 |
-
|
5 |
-
import torch
|
6 |
-
import numpy as np
|
7 |
-
from .utils import get_batch_to_dataloader
|
8 |
-
|
9 |
-
def mnist_prior(num_classes=2, size=28, min_max_strokes=(1,3), min_max_len=(5/28,20/28), min_max_start=(2/28,25/28),
                min_max_width=(1/28,4/28), max_offset=4/28, max_target_offset=2/28):
    """Create one random stroke-image generator per class.

    Every class is a fixed set of strokes (start point, length, direction).
    Each returned generator draws its class's strokes with a random line
    width, a random global offset and random per-stroke end-point jitter.

    All ``min_max_*`` and ``max_*`` arguments are fractions of ``size``; they
    are multiplied by ``size`` and truncated to integers before use.

    :param num_classes: number of generators to create
    :param size: side length of the square image in pixels
    :param min_max_strokes: inclusive range for the number of strokes per class
    :param min_max_len: range of stroke lengths
    :param min_max_start: range of the stroke start coordinates
    :param min_max_width: range of the line width
    :param max_offset: maximum absolute shift applied to the whole image
    :param max_target_offset: maximum absolute jitter of each stroke end point
    :return: list of zero-argument callables, each returning a PIL image
    """
    classes = []
    for _ in range(num_classes):
        num_strokes = random.randint(*min_max_strokes)
        len_strokes = [random.randint(int(size * min_max_len[0]), int(size * min_max_len[1])) for _ in range(num_strokes)]
        stroke_start_points = [
            (random.randint(int(size * min_max_start[0]), int(size * min_max_start[1])), random.randint(int(size * min_max_start[0]), int(size * min_max_start[1]))) for _ in
            range(num_strokes)]
        stroke_directions = []
        for stroke_idx in range(num_strokes):
            sp, length = stroke_start_points[stroke_idx], len_strokes[stroke_idx]
            counter = 0
            # Rejection-sample a direction whose end point stays inside the image;
            # on every third attempt (including the first) the start point and
            # length are resampled as well.
            while True:
                if counter % 3 == 0:
                    length = random.randint(int(size * min_max_len[0]), int(size * min_max_len[1]))
                    sp = (
                        random.randint(int(size * min_max_start[0]), int(size * min_max_start[1])), random.randint(int(size * min_max_start[0]), int(size * min_max_start[1])))
                    stroke_start_points[stroke_idx], len_strokes[stroke_idx] = sp, length
                radians = random.random() * 2 * math.pi
                x_vel = math.cos(radians) * length
                y_vel = math.sin(radians) * length
                new_p = (sp[0] + x_vel, sp[1] + y_vel)
                if not any(n > size - 1 or n < 0 for n in new_p):
                    break
                counter += 1
            stroke_directions.append(radians)
        classes.append((len_strokes, stroke_start_points, stroke_directions))

    generator_functions = []
    for c in classes:
        # Bind the class as a default argument to avoid the late-binding closure pitfall.
        def g(c=c):
            len_strokes, stroke_start_points, stroke_directions = c
            canvas = Image.fromarray(np.zeros((size, size), dtype=np.uint8))
            draw = ImageDraw.Draw(canvas)
            width = random.randint(int(size * min_max_width[0]), int(size * min_max_width[1]))
            offset = random.randint(int(-size * max_offset), int(size * max_offset)), random.randint(int(- size * max_offset), int(size * max_offset))
            for sp, length, radians in zip(stroke_start_points, len_strokes, stroke_directions):
                sp = (sp[0] + offset[0], sp[1] + offset[1])
                x_vel = math.cos(radians) * length + random.randint(int(-size * max_target_offset), int(size * max_target_offset))
                y_vel = math.sin(radians) * length + random.randint(int(-size * max_target_offset), int(size * max_target_offset))
                new_p = (sp[0] + x_vel, sp[1] + y_vel)
                # The class definition is deliberately not modified here: the old
                # code appended to stroke_directions on every call, growing the
                # shared list without bound while it was being iterated.
                draw.line([round(x) for x in sp + new_p], fill=128, width=width)
            pixels = np.array(canvas)
            # Replace the flat stroke colour by random bright intensities.
            pixels[pixels == 128] = np.random.randint(200, 255, size=pixels.shape)[pixels == 128]
            return Image.fromarray(pixels).filter(ImageFilter.GaussianBlur(.2))

        generator_functions.append(g)
    return generator_functions
|
64 |
-
|
65 |
-
|
66 |
-
# g1,g2 = mnist_prior(2)
|
67 |
-
|
68 |
-
# for i in [g1() for _ in range(10)]:
|
69 |
-
# display(i.resize((200,200)))
|
70 |
-
|
71 |
-
from torchvision.transforms import ToTensor, ToPILImage
|
72 |
-
|
73 |
-
|
74 |
-
def normalize(x):
    """Standardise ``x`` with its global mean and standard deviation.

    A small epsilon is added to the standard deviation so constant inputs do
    not divide by zero.
    """
    centered = x - x.mean()
    spread = x.std() + .000001
    return centered / spread
|
76 |
-
|
77 |
-
from os import path, listdir
|
78 |
-
import random
|
79 |
-
|
80 |
-
def get_batch(batch_size, seq_len, num_features=None, noisy_std=None, only_train_for_last_idx=False, normalize_x=False, num_outputs=2, use_saved_from=None, **kwargs):  # num_features = 28*28=784
    """Sample a batch of stroke-image classification tasks.

    :param batch_size: number of independent tasks
    :param seq_len: number of images per task
    :param num_features: flattened image size; must be a perfect square
    :param noisy_std: not used by this function
    :param only_train_for_last_idx: build a class-balanced, shuffled context
        followed by one random query; all targets except the last are ``-100``
    :param normalize_x: standardise every image individually
    :param num_outputs: number of classes per task
    :param use_saved_from: when given, a random file from the matching
        sub-directory is loaded and returned instead of sampling
    :param kwargs: forwarded to ``mnist_prior``
    :return: ``(x, y, target_y)``
    """
    if use_saved_from is not None:
        directory = path.join(use_saved_from, f'len_{seq_len}_out_{num_outputs}_features_{num_features}_bs_{batch_size}')
        chosen_file = random.choice(listdir(directory))
        # NOTE(review): torch.load unpickles the file; only point this at trusted directories.
        return torch.load(path.join(directory, chosen_file))

    size = math.isqrt(num_features)
    assert size * size == num_features, 'num_features needs to be the square of an integer.'
    if only_train_for_last_idx:
        assert (seq_len-1) % num_outputs == 0

    def maybe_normalize(image):
        return normalize(image) if normalize_x else image

    images_per_task = []
    labels_per_task = []
    targets_per_task = []
    for _ in range(batch_size):
        gs = mnist_prior(num_outputs, size, **kwargs)
        last_class = len(gs) - 1
        if only_train_for_last_idx:
            repeats = (seq_len-1) // num_outputs
            class_ids = [cls for cls in range(len(gs)) for _ in range(repeats)]
            random.shuffle(class_ids)
            class_ids += [random.randint(0, last_class)]
            target = [-100 for _ in class_ids]
            target[-1] = class_ids[-1]
        else:
            class_ids = [random.randint(0, last_class) for _ in range(seq_len)]
            target = class_ids
        rendered = [maybe_normalize(ToTensor()(gs[cls]())) for cls in class_ids]
        images_per_task.append(torch.cat(rendered, 0))
        labels_per_task.append(torch.tensor(class_ids))
        targets_per_task.append(torch.tensor(target))
    x = torch.stack(images_per_task, 1).view(seq_len, batch_size, -1)
    y = torch.stack(labels_per_task, 1)
    target_y = torch.stack(targets_per_task, 1)
    return x,y,target_y
|
115 |
-
|
116 |
-
# Data loader class generated from this module's `get_batch`.
DataLoader = get_batch_to_dataloader(get_batch)
# Class attribute set to 2, the same value as the default `num_outputs` of `get_batch`.
DataLoader.num_outputs = 2
|
118 |
-
|
119 |
-
# Smoke test: build two generators, sample one small batch and render each image.
if __name__ == '__main__':
    g1, g2 = mnist_prior(2, size=3)

    # for i in range(10):
    #    print(PILToTensor()(g1()))
    #    display(ToPILImage()(PILToTensor()(g1())).resize((200,200)))
    #    display(g2().resize((200,200)))

    size = 10
    # get_batch returns (x, y, target_y); unpacking only two names raised ValueError.
    x, y, _ = get_batch(1, 10, num_features=size * size)

    x_ = x[..., :-1].squeeze(1)
    last_y = x[..., -1].squeeze(1)
    y = y.squeeze(1)

    # print(y)

    for i, y_, last_y_, x__ in zip(x_, y, last_y, x.squeeze(1)):
        # print(y_)
        # print(i.shape)
        # print(x__)
        # Render from the full feature row: the truncated row `i` has
        # size * size - 1 elements and cannot be reshaped to (size, size).
        img = ToPILImage()(x__.reshape(size, size))
        # display(img.resize((200,200)))

    print(y, last_y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/priors/utils.py
DELETED
@@ -1,151 +0,0 @@
|
|
1 |
-
import random
|
2 |
-
|
3 |
-
import pandas as pd
|
4 |
-
import torch
|
5 |
-
|
6 |
-
from lcpfn.utils import set_locals_in_self
|
7 |
-
from itertools import repeat
|
8 |
-
from .prior import PriorDataLoader
|
9 |
-
from torch import nn
|
10 |
-
import numpy as np
|
11 |
-
import matplotlib.pyplot as plt
|
12 |
-
import matplotlib.gridspec as gridspec
|
13 |
-
import scipy.stats as stats
|
14 |
-
import math
|
15 |
-
|
16 |
-
def get_batch_to_dataloader(get_batch_method_):
    """Wrap a ``get_batch`` function in a ``PriorDataLoader`` subclass and return that class."""
    class DL(PriorDataLoader):
        get_batch_method = get_batch_method_

        # Caution, you might need to set self.num_features manually if it is not part of the args.
        def __init__(self, num_steps, **get_batch_kwargs):
            # Passes this method's locals to the helper; judging by the later use of
            # self.num_steps and self.get_batch_kwargs it stores them on the
            # instance - TODO confirm. Do not introduce locals before this call.
            set_locals_in_self(locals())

            # The stuff outside the or is set as class attribute before instantiation.
            self.num_features = get_batch_kwargs.get('num_features') or self.num_features
            print('DataLoader.__dict__', self.__dict__)

        @staticmethod
        def gbm(*args, eval_pos_seq_len_sampler, **kwargs):
            # Draws one batch and returns ((style, x, y), target_y, single_eval_pos).
            kwargs['single_eval_pos'], kwargs['seq_len'] = eval_pos_seq_len_sampler()
            # Scales the batch size dynamically with the power of 'dynamic_batch_size'.
            # A transformer with quadratic memory usage in the seq len would need a power of 2 to keep memory constant.
            if 'dynamic_batch_size' in kwargs and kwargs['dynamic_batch_size'] > 0:
                kwargs['batch_size'] = kwargs['batch_size'] * math.floor(math.pow(kwargs['seq_len_maximum'], kwargs['dynamic_batch_size']) / math.pow(kwargs['seq_len'], kwargs['dynamic_batch_size']))
            batch = get_batch_method_(*args, **kwargs)
            # A 4-tuple carries a style as its last element; a 3-tuple gets style None.
            x, y, target_y, style = batch if len(batch) == 4 else (batch[0], batch[1], batch[2], None)
            return (style, x, y), target_y, kwargs['single_eval_pos']

        def __len__(self):
            return self.num_steps

        def __iter__(self):
            # Lazily yields num_steps freshly sampled batches.
            return iter(self.gbm(**self.get_batch_kwargs) for _ in range(self.num_steps))

    return DL
|
46 |
-
|
47 |
-
"""
|
48 |
-
import seaborn as sns
|
49 |
-
def plot_features(data, targets, fig=None):
|
50 |
-
if torch.is_tensor(data):
|
51 |
-
data = data.detach().cpu().numpy()
|
52 |
-
targets = targets.detach().cpu().numpy()
|
53 |
-
fig2 = plt.figure(figsize=(8, 8))
|
54 |
-
spec2 = gridspec.GridSpec(ncols=data.shape[1], nrows=data.shape[1], figure=fig2)
|
55 |
-
for d in range(0, data.shape[1]):
|
56 |
-
for d2 in range(0, data.shape[1]):
|
57 |
-
sub_ax = fig2.add_subplot(spec2[d, d2])
|
58 |
-
if d == d2:
|
59 |
-
sns.kdeplot(data[:, d],hue=targets[:],ax=sub_ax,legend=False, palette="deep")
|
60 |
-
sub_ax.set(ylabel=None)
|
61 |
-
else:
|
62 |
-
sns.scatterplot(data[:, d], data[:, d2],
|
63 |
-
hue=targets[:],legend=False, palette="deep")
|
64 |
-
#plt.scatter(data[:, d], data[:, d2],
|
65 |
-
# c=targets[:])
|
66 |
-
sub_ax.get_xaxis().set_ticks([])
|
67 |
-
sub_ax.get_yaxis().set_ticks([])
|
68 |
-
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
69 |
-
fig2.show()
|
70 |
-
|
71 |
-
|
72 |
-
def plot_prior(prior):
|
73 |
-
s = np.array([prior() for _ in range(0, 1000)])
|
74 |
-
count, bins, ignored = plt.hist(s, 50, density=True)
|
75 |
-
print(s.min())
|
76 |
-
plt.show()
|
77 |
-
"""
|
78 |
-
|
79 |
-
# Sampler factories: each takes distribution parameters and returns a zero-argument
# callable that draws one value per call.
# Normal(mu, sigma) truncated to the interval [0, 1000000].
trunc_norm_sampler_f = lambda mu, sigma : lambda: stats.truncnorm((0 - mu) / sigma, (1000000 - mu) / sigma, loc=mu, scale=sigma).rvs(1)[0]
beta_sampler_f = lambda a, b : lambda : np.random.beta(a, b)
gamma_sampler_f = lambda a, b : lambda : np.random.gamma(a, b)
uniform_sampler_f = lambda a, b : lambda : np.random.uniform(a, b)
# Uniform draw from [a, b) rounded to the nearest integer.
uniform_int_sampler_f = lambda a, b : lambda : round(np.random.uniform(a, b))
|
84 |
-
def zipf_sampler_f(a, b, c):
    """Return a sampler for a bounded Zipf-like distribution over ``arange(b, c)``.

    The probability of value ``k`` is proportional to ``k ** (-a)``.

    :param a: exponent of the power law; integer and float values both work
    :param b: smallest support value (inclusive)
    :param c: end of the support (exclusive)
    :return: zero-argument callable that draws via ``rvs(1)``
        NOTE(review): in ``rv_discrete.rvs`` positional arguments come before
        ``size``, so the ``1`` may be interpreted as ``loc=1`` (shifting draws
        by one) rather than as a sample count - confirm against the SciPy docs
        before relying on the exact range or shape of the result.
    """
    x = np.arange(b, c)
    # Cast to float first: an integer array raised to a negative integer power raises ValueError.
    weights = x.astype(float) ** (-a)
    weights /= weights.sum()
    # Build the distribution once instead of on every draw.
    distribution = stats.rv_discrete(name='bounded_zipf', values=(x, weights))
    return lambda : distribution.rvs(1)
|
89 |
-
# Beta(a, b) draw stretched to the interval [minimum, scale]; the stretched part is rounded to an integer.
scaled_beta_sampler_f = lambda a, b, scale, minimum : lambda : minimum + round(beta_sampler_f(a, b)() * (scale - minimum))
|
90 |
-
|
91 |
-
|
92 |
-
def normalize_by_used_features_f(x, num_features_used, num_features, normalize_with_sqrt=False):
    """Rescale ``x`` by the inverse of the fraction of features in use.

    :param x: value or tensor to rescale
    :param num_features_used: number of active features
    :param num_features: total number of features
    :param normalize_with_sqrt: divide by the square root of the fraction
        instead of the fraction itself
    """
    used_fraction = num_features_used / num_features
    divisor = used_fraction**(1 / 2) if normalize_with_sqrt else used_fraction
    return x / divisor
|
96 |
-
|
97 |
-
|
98 |
-
def order_by_y(x, y):
    """Reorder ``x`` and ``y`` along dimension 0 by an interleaved sort of ``y``.

    The sort key is ``y[:, 0, 0]``, taken in ascending or descending order at
    random. The sorted index list is split into two halves which are then
    interleaved, so the sequence length must be even.

    :return: the reordered ``(x, y)``
    """
    sort_ascending = random.randint(0, 1)
    keys = y if sort_ascending else -y
    ranking = torch.argsort(keys, dim=0)[:, 0, 0]
    # [r0, r1, r2, r3] -> [[r0, r1], [r2, r3]] -> [r0, r2, r1, r3]
    interleaved = ranking.reshape(2, -1).transpose(0, 1).reshape(-1)
    return x[interleaved], y[interleaved]
|
105 |
-
|
106 |
-
def randomize_classes(x, num_classes):
    """Relabel the class ids in ``x`` with a random permutation of ``0..num_classes-1``.

    Entries of ``x`` that do not equal any id in ``0..num_classes-1`` match no
    class and therefore come out as 0.
    """
    class_ids = torch.arange(0, num_classes, device=x.device)
    new_labels = torch.randperm(num_classes, device=x.device).type(x.type())
    # One-hot membership along a new trailing axis, weighted by the new labels.
    membership = x.unsqueeze(-1) == class_ids
    return (membership * new_labels).sum(-1)
|
111 |
-
|
112 |
-
|
113 |
-
class CategoricalActivation(nn.Module):
    # Activation that squashes its input with Softsign and then discretises a random
    # subset of the (batch, hidden) columns into class indices.
    def __init__(self, categorical_p=0.1, ordered_p=0.7
                 , keep_activation_size=False
                 , num_classes_sampler=zipf_sampler_f(0.8, 1, 10)):
        # categorical_p: probability that a (batch, hidden) column is discretised.
        # ordered_p: probability that a discretised column additionally gets its
        #   class values passed through randomize_classes.
        # keep_activation_size: rescale the output by the mean absolute activation.
        # num_classes_sampler: zero-argument callable returning the number of classes.
        self.categorical_p = categorical_p
        self.ordered_p = ordered_p
        self.keep_activation_size = keep_activation_size
        self.num_classes_sampler = num_classes_sampler

        super().__init__()

    def forward(self, x):
        # x shape: T, B, H

        x = nn.Softsign()(x)

        num_classes = self.num_classes_sampler()
        # Mean absolute activation over the sequence, used to rescale at the end.
        hid_strength = torch.abs(x).mean(0).unsqueeze(0) if self.keep_activation_size else None

        # Boolean mask over (B, H): which columns become categorical.
        categorical_classes = torch.rand((x.shape[1], x.shape[2])) < self.categorical_p
        class_boundaries = torch.zeros((num_classes - 1, x.shape[1], x.shape[2]), device=x.device, dtype=x.dtype)
        # Sample num_classes - 1 boundary values per (batch, hidden) column from random
        # time steps of that column; the indices are drawn separately for every pair.
        for b in range(x.shape[1]):
            for h in range(x.shape[2]):
                ind = torch.randint(0, x.shape[0], (num_classes - 1,))
                class_boundaries[:, b, h] = x[ind, b, h]

        for b in range(x.shape[1]):
            x_rel = x[:, b, categorical_classes[b]]
            boundaries_rel = class_boundaries[:, b, categorical_classes[b]].unsqueeze(1)
            # Class value = number of boundaries below the activation, shifted by num_classes / 2.
            x[:, b, categorical_classes[b]] = (x_rel > boundaries_rel).sum(dim=0).float() - num_classes / 2

        ordered_classes = torch.rand((x.shape[1],x.shape[2])) < self.ordered_p
        ordered_classes = torch.logical_and(ordered_classes, categorical_classes)
        # NOTE(review): the values written above are shifted by num_classes / 2, while
        # randomize_classes compares against 0..num_classes-1 - confirm this is intended.
        x[:, ordered_classes] = randomize_classes(x[:, ordered_classes], num_classes)

        x = x * hid_strength if self.keep_activation_size else x

        return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/train.py
DELETED
@@ -1,336 +0,0 @@
|
|
1 |
-
import itertools
|
2 |
-
import time
|
3 |
-
from contextlib import nullcontext
|
4 |
-
|
5 |
-
import torch
|
6 |
-
from torch import nn
|
7 |
-
|
8 |
-
from lcpfn import utils
|
9 |
-
from lcpfn.transformer import TransformerModel
|
10 |
-
from lcpfn.bar_distribution import (
|
11 |
-
BarDistribution,
|
12 |
-
)
|
13 |
-
from lcpfn.utils import (
|
14 |
-
get_cosine_schedule_with_warmup,
|
15 |
-
get_openai_lr,
|
16 |
-
)
|
17 |
-
from lcpfn import positional_encodings
|
18 |
-
from lcpfn.utils import init_dist
|
19 |
-
from torch.cuda.amp import autocast, GradScaler
|
20 |
-
|
21 |
-
|
22 |
-
class Losses:
    """Namespace collecting the criteria that ``train`` knows how to handle."""

    # All criteria use reduction="none" so per-element losses are returned.
    gaussian = nn.GaussianNLLLoss(full=True, reduction="none")
    mse = nn.MSELoss(reduction="none")
    bce = nn.BCEWithLogitsLoss(reduction="none")
    get_BarDistribution = BarDistribution

    @staticmethod
    def ce(num_classes):
        """Return an unreduced cross-entropy loss with unit weight per class.

        Declared as a staticmethod so that it works identically when accessed
        on the class or on an instance (a plain function attribute would bind
        the instance as ``num_classes``).
        """
        return nn.CrossEntropyLoss(
            reduction="none", weight=torch.ones(num_classes)
        )
|
30 |
-
|
31 |
-
|
32 |
-
def train(
    priordataloader_class,
    criterion,
    encoder_generator,
    emsize=200,
    nhid=200,
    nlayers=6,
    nhead=2,
    dropout=0.2,
    epochs=10,
    steps_per_epoch=100,
    batch_size=200,
    bptt=10,
    lr=None,
    weight_decay=0.0,
    warmup_epochs=10,
    input_normalization=False,
    y_encoder_generator=None,
    pos_encoder_generator=None,
    decoder=None,
    extra_prior_kwargs_dict=None,
    scheduler=get_cosine_schedule_with_warmup,
    load_weights_from_this_state_dict=None,
    validation_period=10,
    single_eval_pos_gen=None,
    bptt_extra_samples=None,
    gpu_device="cuda:0",
    aggregate_k_gradients=1,
    verbose=True,
    style_encoder_generator=None,
    epoch_callback=None,
    initializer=None,
    initialize_with_model=None,
    train_mixed_precision=False,
    saving_period=10,
    checkpoint_file=None,
    load_optimizer_from_this_state_dict=None,
    output_path=None,
    **model_extra_args,
):
    """Build a ``TransformerModel`` and train it on batches drawn from a prior.

    Args:
        priordataloader_class: Callable creating the data loader; it is called
            with ``num_steps``, ``batch_size``, ``eval_pos_seq_len_sampler``,
            ``seq_len_maximum``, ``device`` and ``extra_prior_kwargs_dict``.
        criterion: Loss object. Its type decides the number of model outputs
            (2 for ``GaussianNLLLoss``, ``num_bars`` for bar distributions, the
            number of class weights for ``CrossEntropyLoss``, otherwise 1).
        encoder_generator: Called as ``encoder_generator(dl.num_features, emsize)``.
        emsize, nhid, nlayers, nhead, dropout: Forwarded to ``TransformerModel``.
        epochs: Number of epochs, or ``None`` to loop until interrupted.
        steps_per_epoch: Number of batches per epoch requested from the loader.
        batch_size: Batch size requested from the loader.
        bptt: Sequence length used when ``bptt_extra_samples`` is not set.
        lr: Learning rate; ``None`` selects ``get_openai_lr(model)``.
        weight_decay: Weight decay of the AdamW optimizer.
        warmup_epochs: Forwarded to ``scheduler``.
        input_normalization, decoder, initializer: Forwarded to ``TransformerModel``.
        y_encoder_generator: Called as ``y_encoder_generator(1, emsize)``.
            NOTE(review): the default ``None`` is called unconditionally, so a
            generator must effectively always be supplied.
        pos_encoder_generator: Positional encoder factory; defaults to
            ``positional_encodings.NoPositionalEncoding``.
        extra_prior_kwargs_dict: Extra keyword arguments for the data loader.
        scheduler: Factory called as ``scheduler(optimizer, warmup_epochs, epochs)``.
        load_weights_from_this_state_dict: Optional model state dict to load.
        validation_period: ``dl.validate`` (if present) is run every this many epochs.
        single_eval_pos_gen: Callable or constant giving the evaluation position.
        bptt_extra_samples: If set, sequence length is the evaluation position
            plus this value.
        gpu_device: Device used when CUDA is available.
        aggregate_k_gradients: Number of batches whose gradients are accumulated
            before an optimizer step; must divide ``len(dl)``.
        verbose: Print a summary line after every epoch.
        style_encoder_generator: Factory for the style encoder, used when the
            first batch carries a style definition.
        epoch_callback: Called on rank 0 as ``epoch_callback(model, epoch / epochs)``.
            NOTE(review): this division fails when ``epochs`` is ``None``.
        initialize_with_model: Optional smaller model passed to
            ``model.init_from_small_model``.
        train_mixed_precision: Enable autocast and gradient scaling.
        saving_period: A checkpoint is written every this many epochs.
        checkpoint_file: Path of the checkpoint; ``None`` disables checkpoints.
        load_optimizer_from_this_state_dict: Optional optimizer state to load.
        output_path: If set, the final model is saved there on rank 0.
        **model_extra_args: Additional keyword arguments for ``TransformerModel``.

    Returns:
        On rank 0 the tuple ``(total_loss, total_positional_losses, model, dl)``
        with the model moved to the CPU and ``dl`` set to ``None``; other ranks
        return ``None``.
    """
    import os  # local import: only needed to derive the full-model checkpoint path

    # A None default avoids sharing one dict object between calls.
    extra_prior_kwargs_dict = (
        {} if extra_prior_kwargs_dict is None else extra_prior_kwargs_dict
    )

    device = gpu_device if torch.cuda.is_available() else "cpu:0"
    print(f"Using {device} device")
    using_dist, rank, device = init_dist(device)
    single_eval_pos_gen = (
        single_eval_pos_gen
        if callable(single_eval_pos_gen)
        else lambda: single_eval_pos_gen
    )

    def eval_pos_seq_len_sampler():
        single_eval_pos = single_eval_pos_gen()
        if bptt_extra_samples:
            return single_eval_pos, single_eval_pos + bptt_extra_samples
        else:
            return single_eval_pos, bptt

    dl = priordataloader_class(
        num_steps=steps_per_epoch,
        batch_size=batch_size,
        eval_pos_seq_len_sampler=eval_pos_seq_len_sampler,
        seq_len_maximum=bptt + (bptt_extra_samples if bptt_extra_samples else 0),
        device=device,
        **extra_prior_kwargs_dict,
    )

    encoder = encoder_generator(dl.num_features, emsize)
    style_def = next(iter(dl))[0][
        0
    ]  # This is (style, x, y), target with x and y with batch size
    print(f"Style definition: {style_def}")
    style_encoder = (
        style_encoder_generator(hyperparameter_definitions=style_def[0], em_size=emsize)
        if (style_def is not None)
        else None
    )
    if isinstance(criterion, nn.GaussianNLLLoss):
        n_out = 2
    elif (
        isinstance(criterion, BarDistribution)
        or "BarDistribution" in criterion.__class__.__name__
    ):  # TODO remove this fix (only for dev)
        n_out = criterion.num_bars
    elif isinstance(criterion, nn.CrossEntropyLoss):
        n_out = criterion.weight.shape[0]
    else:
        n_out = 1
    model = TransformerModel(
        encoder,
        n_out,
        emsize,
        nhead,
        nhid,
        nlayers,
        dropout,
        style_encoder=style_encoder,
        y_encoder=y_encoder_generator(1, emsize),
        input_normalization=input_normalization,
        pos_encoder=(
            pos_encoder_generator or positional_encodings.NoPositionalEncoding
        )(emsize, bptt * 2),
        decoder=decoder,
        init_method=initializer,
        **model_extra_args,
    )
    model.criterion = criterion
    if load_weights_from_this_state_dict is not None:
        model.load_state_dict(load_weights_from_this_state_dict)
    if initialize_with_model is not None:
        model.init_from_small_model(initialize_with_model)

    print(
        f"Using a Transformer with {sum(p.numel() for p in model.parameters())/1000/1000:.{2}f} M parameters"
    )

    if initialize_with_model is not None:
        # Best-effort debug output comparing the new weights with the source
        # model; tensors of different shapes cannot be compared and are skipped
        # deliberately.
        try:
            for (k, v), (k2, v2) in zip(
                model.state_dict().items(), initialize_with_model.state_dict().items()
            ):
                print(k, ((v - v2) / v).abs().mean(), v.shape)
        except Exception:
            pass

    model.to(device)
    if using_dist:
        print("Distributed training")
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank, broadcast_buffers=False
        )

    # learning rate
    if lr is None:
        lr = get_openai_lr(model)
        print(f"Using OpenAI max lr of {lr}.")
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = scheduler(
        optimizer, warmup_epochs, epochs if epochs is not None else 100
    )  # when training for fixed time lr schedule takes 100 steps

    if load_optimizer_from_this_state_dict is not None:
        optimizer.load_state_dict(load_optimizer_from_this_state_dict)
    scaler = GradScaler() if train_mixed_precision else None

    # check that everything uses up-to-date APIs
    utils.check_compatibility(dl)

    def train_epoch():
        model.train()  # Turn on the train mode
        total_loss = 0.0
        total_positional_losses = 0.0
        total_positional_losses_recorded = 0
        before_get_batch = time.time()
        assert (
            len(dl) % aggregate_k_gradients == 0
        ), "Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it."
        for batch, (data, targets, single_eval_pos) in enumerate(dl):
            # Under DDP, skip gradient synchronisation on all but the last
            # batch of every accumulation group.
            if using_dist and not (
                batch % aggregate_k_gradients == aggregate_k_gradients - 1
            ):
                cm = model.no_sync()
            else:
                cm = nullcontext()
            with cm:
                time_to_get_batch = time.time() - before_get_batch
                before_forward = time.time()

                with autocast(enabled=scaler is not None):
                    # If style is set to None, it should not be transferred to device
                    output = model(
                        tuple(e.to(device) if torch.is_tensor(e) else e for e in data)
                        if isinstance(data, tuple)
                        else data.to(device),
                        single_eval_pos=single_eval_pos,
                    )

                    forward_time = time.time() - before_forward

                    if single_eval_pos is not None:
                        targets = targets[single_eval_pos:]
                    if isinstance(criterion, nn.GaussianNLLLoss):
                        assert (
                            output.shape[-1] == 2
                        ), "need to write a little bit of code to handle multiple regression targets at once"

                        mean_pred = output[..., 0]
                        var_pred = output[..., 1].abs()
                        losses = criterion(
                            mean_pred.flatten(),
                            targets.to(device).flatten(),
                            var=var_pred.flatten(),
                        )
                    elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)):
                        losses = criterion(
                            output.flatten(), targets.to(device).flatten()
                        )
                    elif isinstance(criterion, nn.CrossEntropyLoss):
                        losses = criterion(
                            output.reshape(-1, n_out),
                            targets.to(device).long().flatten(),
                        )
                    else:
                        losses = criterion(output, targets)
                    losses = losses.view(*output.shape[0:2])
                    loss = losses.mean() / aggregate_k_gradients

                if scaler:
                    loss = scaler.scale(loss)
                loss.backward()

                if batch % aggregate_k_gradients == aggregate_k_gradients - 1:
                    if scaler:
                        scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    try:
                        if scaler:
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            optimizer.step()
                    except Exception:
                        # Only swallow genuine errors: a bare `except` would also
                        # eat KeyboardInterrupt, which the epoch loop relies on
                        # to stop training early.
                        print("Invalid optimization step encountered")
                    optimizer.zero_grad()

                step_time = time.time() - before_forward

                if not torch.isnan(loss):
                    total_loss += losses.mean().cpu().detach()
                    total_positional_losses += (
                        losses.mean(1).cpu().detach()
                        if single_eval_pos is None
                        else nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
                        * losses[: bptt - single_eval_pos].mean().cpu().detach()
                    )

                    total_positional_losses_recorded += (
                        torch.ones(bptt)
                        if single_eval_pos is None
                        else nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
                    )

            before_get_batch = time.time()
        return (
            total_loss / steps_per_epoch,
            (total_positional_losses / total_positional_losses_recorded).tolist(),
            time_to_get_batch,
            forward_time,
            step_time,
        )

    total_loss = float("inf")
    total_positional_losses = float("inf")
    try:
        for epoch in range(1, epochs + 1) if epochs is not None else itertools.count(1):
            epoch_start_time = time.time()
            (
                total_loss,
                total_positional_losses,
                time_to_get_batch,
                forward_time,
                step_time,
            ) = train_epoch()
            if hasattr(dl, "validate") and epoch % validation_period == 0:
                with torch.no_grad():
                    val_score = dl.validate(model)

            else:
                val_score = None

            if epoch % saving_period == 0 and checkpoint_file is not None:
                checkpoint = {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch,
                }
                torch.save(checkpoint, checkpoint_file)
                # splitext only strips the final extension, so dots in directory
                # names (e.g. "./ckpt/model.pt") no longer truncate the path.
                full_model_path = (
                    os.path.splitext(checkpoint_file)[0] + "_full_model.pt"
                )
                torch.save(model, full_model_path)

            if verbose:
                print("-" * 89)
                print(
                    f"| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | "
                    f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}"
                    f" data time {time_to_get_batch:5.2f} step time {step_time:5.2f}"
                    f" forward time {forward_time:5.2f}"
                    + (f"val score {val_score}" if val_score is not None else "")
                )
                print("-" * 89)

            # stepping with wallclock time based scheduler
            if epoch_callback is not None and rank == 0:
                epoch_callback(model, epoch / epochs)
            scheduler.step()
    except KeyboardInterrupt:
        # Ctrl-C ends training early; the model trained so far is still returned.
        pass

    if rank == 0:  # trivially true for non-parallel training
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model = model.module
            dl = None
        if output_path is not None:
            torch.save(model.to("cpu"), output_path)
            print("Checkpoint stored at ", output_path)
        return total_loss, total_positional_losses, model.to("cpu"), dl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/train_lcpfn.py
DELETED
@@ -1,96 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
|
3 |
-
from torch import nn
|
4 |
-
|
5 |
-
from lcpfn import bar_distribution, encoders, train
|
6 |
-
from lcpfn import utils
|
7 |
-
|
8 |
-
from lcpfn.priors import utils as putils
|
9 |
-
|
10 |
-
|
11 |
-
def train_lcpfn(
    get_batch_func,
    seq_len: int = 100,
    emsize: int = 512,
    nlayers: int = 12,
    num_borders: int = 1000,
    lr: float = 0.0001,
    batch_size: int = 100,
    epochs: int = 1000,
):
    """
    Train a LCPFN model using the specified hyperparameters.

    Args:
        get_batch_func (callable): A function that returns a batch of learning curves.
        seq_len (int, optional): The length of the input sequence. Defaults to 100.
        emsize (int, optional): The size of the embedding layer. Defaults to 512.
        nlayers (int, optional): The number of layers in the model. Defaults to 12.
        num_borders (int, optional): The number of borders to use. Defaults to 1000.
        lr (float, optional): The learning rate for the optimizer. Defaults to 0.0001.
        batch_size (int, optional): The batch size for training. Defaults to 100.
        epochs (int, optional): The number of epochs to train for. Defaults to 1000.

    Returns:
        Whatever ``train.train`` returns for this configuration, i.e. its
        result tuple containing the trained model rather than the bare module.
    """

    hps = {}

    # PFN training hyperparameters
    dataloader = putils.get_batch_to_dataloader(get_batch_func)  # type: ignore

    num_features = 1

    # Draw one large batch from the prior; its third element supplies the
    # values from which the bucket limits are derived.
    ys = get_batch_func(
        10_000,
        seq_len,
        num_features,
        hyperparameters=hps,
        single_eval_pos=seq_len,
    )

    bucket_limits = bar_distribution.get_bucket_limits(num_borders, ys=ys[2])

    # Discretization of the predictive distributions
    criterion = bar_distribution.FullSupportBarDistribution(bucket_limits)

    config = dict(
        nlayers=nlayers,
        priordataloader_class=dataloader,
        criterion=criterion,
        encoder_generator=lambda in_dim, out_dim: nn.Sequential(
            encoders.Normalize(0.0, 101.0),
            encoders.Normalize(0.5, math.sqrt(1 / 12)),
            encoders.Linear(in_dim, out_dim),
        ),
        emsize=emsize,
        nhead=(emsize // 128),
        warmup_epochs=(epochs // 4),
        y_encoder_generator=encoders.get_normalized_uniform_encoder(encoders.Linear),
        batch_size=batch_size,
        scheduler=utils.get_cosine_schedule_with_warmup,
        extra_prior_kwargs_dict={
            # "num_workers": 10,
            "num_features": num_features,
            "hyperparameters": {
                **hps,
            },
        },
        epochs=epochs,
        lr=lr,
        bptt=seq_len,
        single_eval_pos_gen=utils.get_uniform_single_eval_pos_sampler(
            seq_len, min_len=1
        ),
        aggregate_k_gradients=1,
        nhid=(emsize * 2),
        steps_per_epoch=100,
        train_mixed_precision=False,
    )

    return train.train(**config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/transformer.py
DELETED
@@ -1,348 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
from typing import Optional
|
3 |
-
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
from torch import Tensor
|
7 |
-
import torch.nn.functional as F
|
8 |
-
from torch.nn import Module, TransformerEncoder
|
9 |
-
|
10 |
-
from lcpfn.layer import TransformerEncoderLayer, _get_activation_fn
|
11 |
-
from lcpfn.utils import SeqBN, bool_mask_to_att_mask
|
12 |
-
|
13 |
-
|
14 |
-
class GELU(nn.Module):
    """Module wrapper around ``torch.nn.functional.gelu``."""

    def forward(self, input: Tensor) -> Tensor:
        """Apply GELU to ``input`` and return the result."""
        activated = F.gelu(input)
        return activated
|
17 |
-
|
18 |
-
|
19 |
-
class TransformerModel(nn.Module):
    """Transformer encoder over ``(style, x, y)`` inputs with a decoder head.

    The first ``single_eval_pos`` positions are fed as ``x + y`` embeddings,
    the remaining positions as ``x`` embeddings only; ``forward`` returns the
    decoder output for those remaining positions.
    """

    def __init__(
        self,
        encoder,
        n_out,
        ninp,
        nhead,
        nhid,
        nlayers,
        dropout=0.0,
        style_encoder=None,
        y_encoder=None,
        pos_encoder=None,
        decoder=None,
        input_normalization=False,
        init_method=None,
        pre_norm=False,
        activation="gelu",
        recompute_attn=False,
        num_global_att_tokens=0,
        full_attention=False,
        all_layers_same_init=True,
    ):
        """Assemble encoder stack, input encoders and decoder head.

        ``decoder`` is a factory called as ``decoder(ninp, nhid, n_out)``; when
        it is ``None`` a Linear-GELU-Linear head is used. With
        ``all_layers_same_init`` the stock ``TransformerEncoder`` is built from
        one layer instance, otherwise ``TransformerEncoderDiffInit`` creates
        every layer separately.
        """
        super().__init__()
        self.model_type = "Transformer"
        encoder_layer_creator = lambda: TransformerEncoderLayer(
            ninp,
            nhead,
            nhid,
            dropout,
            activation=activation,
            pre_norm=pre_norm,
            recompute_attn=recompute_attn,
        )
        self.transformer_encoder = (
            TransformerEncoder(encoder_layer_creator(), nlayers)
            if all_layers_same_init
            else TransformerEncoderDiffInit(encoder_layer_creator, nlayers)
        )
        self.ninp = ninp
        self.encoder = encoder
        self.y_encoder = y_encoder
        self.pos_encoder = pos_encoder
        self.decoder = (
            decoder(ninp, nhid, n_out)
            if decoder is not None
            else nn.Sequential(nn.Linear(ninp, nhid), GELU(), nn.Linear(nhid, n_out))
        )
        self.input_ln = SeqBN(ninp) if input_normalization else None
        self.style_encoder = style_encoder
        self.init_method = init_method
        # NOTE(review): with the default num_global_att_tokens=0 this assertion
        # still fires, so full_attention=True requires passing
        # num_global_att_tokens=None explicitly -- confirm that is intended.
        if num_global_att_tokens is not None:
            assert not full_attention
        self.global_att_embeddings = (
            nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
        )
        self.full_attention = full_attention

        self.n_out = n_out
        self.nhid = nhid

        self.init_weights()

    @staticmethod
    def generate_square_subsequent_mask(sz):
        """Return the attention mask built from a lower-triangular boolean matrix."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_D_q_matrix(sz, query_size):
        """Return the mask where every row sees the first ``sz - query_size`` columns and itself."""
        train_size = sz - query_size
        mask = torch.zeros(sz, sz) == 0
        mask[:, train_size:].zero_()
        mask |= torch.eye(sz) == 1
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_query_matrix(
        num_global_att_tokens, seq_len, num_query_tokens
    ):
        """Return the query-row mask: all non-query columns plus the row's own query column."""
        train_size = seq_len + num_global_att_tokens - num_query_tokens
        sz = seq_len + num_global_att_tokens
        mask = torch.zeros(num_query_tokens, sz) == 0
        mask[:, train_size:].zero_()
        mask[:, train_size:] |= torch.eye(num_query_tokens) == 1
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_trainset_matrix(
        num_global_att_tokens, seq_len, num_query_tokens
    ):
        """Return an all-true mask of shape (training positions, global tokens)."""
        trainset_size = seq_len - num_query_tokens
        mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_globaltokens_matrix(
        num_global_att_tokens, seq_len, num_query_tokens
    ):
        """Return an all-true mask of shape (global tokens, global tokens + training positions)."""
        mask = (
            torch.zeros(
                num_global_att_tokens,
                num_global_att_tokens + seq_len - num_query_tokens,
            )
            == 0
        )
        return bool_mask_to_att_mask(mask)

    def init_weights(self):
        """Apply ``init_method`` (if any) and zero the output projections of every layer."""
        if self.init_method is not None:
            self.apply(self.init_method)
        for layer in self.transformer_encoder.layers:
            nn.init.zeros_(layer.linear2.weight)
            nn.init.zeros_(layer.linear2.bias)
            attns = (
                layer.self_attn
                if isinstance(layer.self_attn, nn.ModuleList)
                else [layer.self_attn]
            )
            for attn in attns:
                nn.init.zeros_(attn.out_proj.weight)
                nn.init.zeros_(attn.out_proj.bias)

    def forward(self, src, src_mask=None, single_eval_pos=None):
        """Run the model on ``src`` and return outputs for positions after ``single_eval_pos``.

        Args:
            src: Tuple ``(x, y)`` or ``(style, x, y)``; sequence positions are
                indexed along dim 0 of ``x`` and ``y``.
            src_mask: Optional attention mask; when omitted it is generated
                from ``single_eval_pos``. With global attention tokens a tuple
                of masks is expected.
            single_eval_pos: Number of leading positions whose ``y`` is added
                to the input. NOTE(review): must not be ``None`` when the mask
                is generated here, because it is used in arithmetic.

        Returns:
            Decoder output sliced to the positions from ``single_eval_pos``
            (offset by the number of global tokens, if any) onward.
        """
        assert isinstance(
            src, tuple
        ), "inputs (src) have to be given as (x,y) or (style,x,y) tuple"

        if len(src) == 2:  # (x,y) and no style
            src = (None,) + src

        style_src, style_src_size = (src[0], (0 if (src[0] is None) else 1))
        if src_mask is not None:
            assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
        if src_mask is None:
            x_src = src[1]
            if self.global_att_embeddings is None:
                full_len = len(x_src) + style_src_size
                if self.full_attention:
                    src_mask = bool_mask_to_att_mask(
                        torch.ones((full_len, full_len), dtype=torch.bool)
                    ).to(x_src.device)
                else:
                    src_mask = self.generate_D_q_matrix(
                        len(x_src) + style_src_size,
                        len(x_src) + style_src_size - single_eval_pos,
                    ).to(x_src.device)
            else:
                src_mask_args = (
                    self.global_att_embeddings.num_embeddings,
                    len(x_src) + style_src_size,
                    len(x_src) + style_src_size - single_eval_pos,
                )
                src_mask = (
                    self.generate_global_att_globaltokens_matrix(*src_mask_args).to(
                        x_src.device
                    ),
                    self.generate_global_att_trainset_matrix(*src_mask_args).to(
                        x_src.device
                    ),
                    self.generate_global_att_query_matrix(*src_mask_args).to(
                        x_src.device
                    ),
                )

        style_src, x_src, y_src = src
        x_src = self.encoder(x_src)
        y_src = self.y_encoder(
            y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src
        )
        # Compare against None explicitly: the truth value of a Module is not a
        # reliable "is configured" test (containers define __len__).
        style_src = (
            self.style_encoder(style_src).unsqueeze(0)
            if self.style_encoder is not None
            else torch.tensor([], device=x_src.device)
        )
        global_src = (
            torch.tensor([], device=x_src.device)
            if self.global_att_embeddings is None
            else self.global_att_embeddings.weight.unsqueeze(1).repeat(
                1, x_src.shape[1], 1
            )
        )
        train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
        src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)

        if self.input_ln is not None:
            src = self.input_ln(src)

        if self.pos_encoder is not None:
            src = self.pos_encoder(src)

        # If we have style input, drop its output
        output = self.transformer_encoder(src, src_mask)[style_src_size:]
        output = self.decoder(output)
        return output[
            single_eval_pos
            + (
                self.global_att_embeddings.num_embeddings
                if self.global_att_embeddings is not None
                else 0
            ) :
        ]

    @torch.no_grad()
    def init_from_small_model(self, small_model):
        """Copy the weights of a narrower model into the leading slices of this one.

        Requires a plain ``nn.Linear`` decoder and Linear/Sequential encoders.
        All values are copied; no parameter object is shared with
        ``small_model`` afterwards.
        """
        assert (
            isinstance(self.decoder, nn.Linear)
            and isinstance(self.encoder, (nn.Linear, nn.Sequential))
            and isinstance(self.y_encoder, (nn.Linear, nn.Sequential))
        )

        def set_encoder_weights(my_encoder, small_model_encoder):
            # For Sequential encoders only the last (linear) layer is transferred.
            my_encoder_linear, small_encoder_linear = (
                (my_encoder, small_model_encoder)
                if isinstance(my_encoder, nn.Linear)
                else (my_encoder[-1], small_model_encoder[-1])
            )
            small_out_dim = small_encoder_linear.out_features
            my_encoder_linear.weight.zero_()
            my_encoder_linear.bias.zero_()
            my_encoder_linear.weight[:small_out_dim] = small_encoder_linear.weight
            my_encoder_linear.bias[:small_out_dim] = small_encoder_linear.bias

        set_encoder_weights(self.encoder, small_model.encoder)
        set_encoder_weights(self.y_encoder, small_model.y_encoder)

        small_in_dim = small_model.decoder.in_features

        self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
        # Copy the values instead of rebinding the attribute: assigning the
        # Parameter itself would make both models share (and jointly update)
        # one bias tensor.
        self.decoder.bias.copy_(small_model.decoder.bias)

        for my_layer, small_layer in zip(
            self.transformer_encoder.layers, small_model.transformer_encoder.layers
        ):
            small_hid_dim = small_layer.linear1.out_features
            my_in_dim = my_layer.linear1.in_features

            # packed along q,k,v order in first dim
            my_in_proj_w = my_layer.self_attn.in_proj_weight
            small_in_proj_w = small_layer.self_attn.in_proj_weight

            my_in_proj_w.view(3, my_in_dim, my_in_dim)[
                :, :small_in_dim, :small_in_dim
            ] = small_in_proj_w.view(3, small_in_dim, small_in_dim)
            my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:, :small_in_dim] = (
                small_layer.self_attn.in_proj_bias.view(3, small_in_dim)
            )

            my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = (
                small_layer.self_attn.out_proj.weight
            )
            my_layer.self_attn.out_proj.bias[:small_in_dim] = (
                small_layer.self_attn.out_proj.bias
            )

            my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = (
                small_layer.linear1.weight
            )
            my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias

            my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = (
                small_layer.linear2.weight
            )
            my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias

            # Norm gains are rescaled by sqrt(small width / own width).
            my_layer.norm1.weight[:small_in_dim] = (
                math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight
            )
            my_layer.norm2.weight[:small_in_dim] = (
                math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight
            )

            my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
            my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias
|
301 |
-
|
302 |
-
|
303 |
-
class TransformerEncoderDiffInit(Module):
    r"""Stack of N encoder layers, each created by its own factory call.

    Args:
        encoder_layer_creator: zero-argument callable returning a new encoder layer (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: optional callable applied to the output of the last layer.
    """

    __constants__ = ["norm"]

    def __init__(self, encoder_layer_creator, num_layers, norm=None):
        super().__init__()
        # One factory call per layer, so every layer is a distinct object.
        fresh_layers = (encoder_layer_creator() for _ in range(num_layers))
        self.layers = nn.ModuleList(fresh_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Feed ``src`` through all layers in order, then through ``norm`` if set.

        Args:
            src: the sequence to the encoder (required).
            mask: passed to every layer as ``src_mask`` (optional).
            src_key_padding_mask: passed to every layer unchanged (optional).
        """
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(
                hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask
            )

        if self.norm is None:
            return hidden
        return self.norm(hidden)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/utils.py
DELETED
@@ -1,409 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import math
|
3 |
-
import argparse
|
4 |
-
import random
|
5 |
-
import datetime
|
6 |
-
|
7 |
-
import torch
|
8 |
-
from torch import nn
|
9 |
-
from torch.optim.lr_scheduler import LambdaLR
|
10 |
-
import numpy as np
|
11 |
-
|
12 |
-
|
13 |
-
# copied from huggingface
|
14 |
-
def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """Return a ``LambdaLR`` with linear warmup followed by cosine decay.

    For steps below ``num_warmup_steps`` the multiplier grows linearly from 0
    towards 1. Afterwards it equals
    ``0.5 * (1 + cos(pi * num_cycles * 2 * progress))`` clamped at 0, where
    ``progress`` is the fraction of the post-warmup steps already done.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        in_warmup = current_step < num_warmup_steps
        if in_warmup:
            return float(current_step) / warmup_span
        progress = float(current_step - num_warmup_steps) / decay_span
        cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
33 |
-
|
34 |
-
|
35 |
-
# copied from huggingface
|
36 |
-
def get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, last_epoch=-1
):
    """
    Return a ``LambdaLR`` with linear warmup followed by linear decay.

    The multiplier rises linearly from 0 to 1 during the first
    ``num_warmup_steps`` steps and then falls linearly, reaching 0 at
    ``num_training_steps`` (it never goes below 0).

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer whose learning rate is scheduled.
        num_warmup_steps (:obj:`int`):
            Length of the warmup phase in steps.
        num_training_steps (:obj:`int`):
            Total number of training steps.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            Index of the last epoch, passed through to ``LambdaLR``.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` using the multiplier above.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step: int):
        if current_step >= num_warmup_steps:
            remaining = float(num_training_steps - current_step)
            return max(0.0, remaining / decay_span)
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
67 |
-
|
68 |
-
|
69 |
-
def get_openai_lr(transformer_model):
    """Return ``0.003239 - 0.0001395 * ln(P)`` where ``P`` is the total number
    of elements across ``transformer_model.parameters()``.
    """
    num_params = 0
    for param in transformer_model.parameters():
        num_params += param.numel()
    return 0.003239 - 0.0001395 * math.log(num_params)
|
72 |
-
|
73 |
-
|
74 |
-
def get_weighted_single_eval_pos_sampler(max_len):
    """
    Build a sampler for `single_eval_pos` over the positions ``0 .. max_len - 1``.

    Position ``i`` is drawn with weight ``1 / (max_len - i)``, so later
    positions are more likely than earlier ones.
    :return: Zero-argument callable returning one sampled position; intended to
        be passed to `train()` as `single_eval_pos_gen`.
    """
    positions = range(max_len)
    weights = [1 / (max_len - i) for i in positions]

    def sample():
        drawn = random.choices(positions, weights)
        return drawn[0]

    return sample
|
83 |
-
|
84 |
-
|
85 |
-
def get_uniform_single_eval_pos_sampler(max_len, min_len=0):
    """
    Build a sampler that draws every position in ``range(min_len, max_len)``
    with equal weight.
    :return: Zero-argument callable returning one sampled position; intended to
        be passed to `train()` as `single_eval_pos_gen`.
    """
    positions = range(min_len, max_len)

    def sample():
        drawn = random.choices(positions)
        return drawn[0]

    return sample
|
91 |
-
|
92 |
-
|
93 |
-
class SeqBN(nn.Module):
    """Apply ``nn.BatchNorm1d`` over the last dimension of a tensor.

    All leading dimensions are flattened into one batch dimension, batch norm
    is applied, and the result is restored to the input's shape.

    Args:
        d_model: size of the last dimension of the inputs.
    """

    def __init__(self, d_model):
        super().__init__()
        self.bn = nn.BatchNorm1d(d_model)
        self.d_model = d_model

    def forward(self, x):
        """Normalise ``x`` and return a tensor of the same shape.

        Raises:
            AssertionError: if the last dimension of ``x`` is not ``d_model``.
        """
        assert self.d_model == x.shape[-1]
        # reshape (not view) so that non-contiguous inputs, e.g. the result of
        # a transpose, are accepted instead of raising a RuntimeError.
        flat_x = x.reshape(-1, self.d_model)
        flat_x = self.bn(flat_x)
        return flat_x.reshape(x.shape)
|
104 |
-
|
105 |
-
|
106 |
-
def set_locals_in_self(locals):
    """
    Copy every entry of a ``locals()`` mapping onto the object stored under its ``"self"`` key.

    Typical use is `set_locals_in_self(locals())` as the first statement of `__init__`,
    which turns all constructor arguments into attributes.
    :param locals: `locals()` of the calling method; must contain ``"self"``.
    """
    target = locals["self"]
    for name in locals:
        if name == "self":
            continue
        setattr(target, name, locals[name])
|
116 |
-
|
117 |
-
|
118 |
-
# Default device string: the first CUDA device when torch reports CUDA as
# available, otherwise the CPU.
default_device = "cuda:0" if torch.cuda.is_available() else "cpu:0"
|
119 |
-
|
120 |
-
|
121 |
-
# Copied from StackOverflow, but we do an eval on the values additionally
|
122 |
-
class StoreDictKeyPair(argparse.Action):
    """argparse action that gathers ``KEY=VALUE`` tokens into a dict.

    Each value is passed through ``eval``; if that fails because the text is
    not a known name or not a valid expression, the raw string is stored.
    The resulting dict is set on the namespace under ``dest`` and printed.
    """

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        self._nargs = nargs
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse ``values`` (an iterable of ``KEY=VALUE`` strings) into a dict.

        Raises:
            ValueError: if a token contains no ``=``.
        """
        my_dict = {}
        for kv in values:
            # Split on the first "=" only, so values may themselves contain "=".
            k, v = kv.split("=", 1)
            try:
                # SECURITY NOTE: eval executes arbitrary expressions taken from
                # the command line; only use this action with trusted input.
                my_dict[k] = eval(v)
            except (NameError, SyntaxError):
                # Plain strings (unknown names, text with spaces, paths, ...)
                # are kept verbatim instead of aborting argument parsing.
                my_dict[k] = v
        setattr(namespace, self.dest, my_dict)
        print("dict values: {}".format(my_dict))
|
139 |
-
|
140 |
-
|
141 |
-
def get_nan_value(v, set_value_to_nan=0.0):
    """Return ``v`` with probability ``set_value_to_nan``; otherwise return a
    value picked at random from ``[-999, 0, 1, 999]``.
    """
    keep_v = random.random() < set_value_to_nan
    return v if keep_v else random.choice([-999, 0, 1, 999])
|
146 |
-
|
147 |
-
|
148 |
-
def to_ranking(data):
    """Replace each entry by the number of entries along the first dimension
    that are less than or equal to it (ties share the larger rank).

    The pairwise comparison is materialised in one go, so memory grows
    quadratically with the size of the first dimension.
    NOTE(review): written for 3-D input (``unsqueeze(-3)``) -- confirm callers.
    """
    return (data >= data.unsqueeze(-3)).sum(0)
|
152 |
-
|
153 |
-
|
154 |
-
# TODO: Is there a better way to do this?
|
155 |
-
# 1. Comparing to unique elements: When all values are different we still get quadratic blowup
|
156 |
-
# 2. Argsort(Argsort()) returns ranking, but with duplicate values there is an ordering which is problematic
|
157 |
-
# 3. Argsort(Argsort(Unique))->Scatter seems a bit complicated, doesn't have quadratic blowup, but how fast?
|
158 |
-
def to_ranking_low_mem(data):
    """Rank entries along the first dimension, one last-dimension column at a time.

    For every column, each entry is replaced by the number of entries along
    the first dimension that are less than or equal to it. Processing column
    by column keeps the pairwise comparison tensor smaller than doing all
    columns at once. The result has the shape and dtype of ``data``.
    """
    ranks = torch.zeros_like(data)
    for col in range(data.shape[-1]):
        column = data[:, :, col]
        ranks[:, :, col] = (column >= column.unsqueeze(-2)).sum(0)
    return ranks
|
165 |
-
|
166 |
-
|
167 |
-
def nan_handling_missing_for_unknown_reason_value(set_value_to_nan=0.0):
    """Call ``get_nan_value`` with NaN as the candidate value."""
    marker = float("nan")
    return get_nan_value(marker, set_value_to_nan)
|
169 |
-
|
170 |
-
|
171 |
-
def nan_handling_missing_for_no_reason_value(set_value_to_nan=0.0):
    """Call ``get_nan_value`` with negative infinity as the candidate value."""
    marker = float("-inf")
    return get_nan_value(marker, set_value_to_nan)
|
173 |
-
|
174 |
-
|
175 |
-
def nan_handling_missing_for_a_reason_value(set_value_to_nan=0.0):
    """Call ``get_nan_value`` with positive infinity as the candidate value."""
    marker = float("inf")
    return get_nan_value(marker, set_value_to_nan)
|
177 |
-
|
178 |
-
|
179 |
-
def torch_nanmean(x, axis=0):
    """Mean of ``x`` along ``axis`` with NaN entries left out.

    NaNs contribute neither to the sum nor to the element count; a slice
    consisting only of NaNs yields ``0 / 0``.
    """
    nan_mask = torch.isnan(x)
    zeros = torch.full_like(x, 0)
    count = torch.where(nan_mask, zeros, torch.full_like(x, 1)).sum(axis=axis)
    total = torch.where(nan_mask, zeros, x).sum(axis=axis)
    return total / count
|
185 |
-
|
186 |
-
|
187 |
-
def torch_nanstd(x, axis=0):
    """Standard deviation of ``x`` along ``axis`` with NaN entries left out.

    Uses the ``count - 1`` denominator, where ``count`` is the number of
    non-NaN entries along ``axis``.
    """
    nan_mask = torch.isnan(x)
    zeros = torch.full_like(x, 0)
    count = torch.where(nan_mask, zeros, torch.full_like(x, 1)).sum(axis=axis)
    total = torch.where(nan_mask, zeros, x).sum(axis=axis)
    mean = total / count
    # Expand the mean back to the shape of x along the reduced axis.
    mean_broadcast = torch.repeat_interleave(
        mean.unsqueeze(axis), x.shape[axis], dim=axis
    )
    squared_dev = torch.square(mean_broadcast - x)
    variance = torch.nansum(squared_dev, axis=axis) / (count - 1)
    return torch.sqrt(variance)
|
199 |
-
|
200 |
-
|
201 |
-
def normalize_data(data, normalize_positions=-1):
    """Standardise ``data`` along its first dimension and clip to ``[-100, 100]``.

    Mean and standard deviation ignore NaNs. When ``normalize_positions`` is
    positive, they are estimated from ``data[:normalize_positions]`` only;
    otherwise from all of ``data``. A small constant is added to the standard
    deviation before dividing.
    """
    reference = data[:normalize_positions] if normalize_positions > 0 else data
    mean = torch_nanmean(reference, axis=0)
    std = torch_nanstd(reference, axis=0) + 0.000001
    standardised = (data - mean) / std
    return torch.clip(standardised, min=-100, max=100)
|
212 |
-
|
213 |
-
|
214 |
-
def remove_outliers(X, n_sigma=4):
    """Softly squash values lying more than ``n_sigma`` deviations from the mean.

    Bounds are estimated in two passes along the first dimension: first from
    all values, then again after values outside the first-pass bounds have
    been replaced by NaN. Values beyond the final bounds are not hard-clipped;
    they are limited to ``bound +/- log(1 + |X|)``.

    Raises:
        AssertionError: if ``X`` does not have exactly three dimensions.
    """
    # Expects T, B, H
    assert len(X.shape) == 3, "X must be T,B,H"

    def sigma_bounds(values):
        # (lower, upper) = nan-mean -/+ n_sigma * nan-std along dim 0
        center = torch_nanmean(values, axis=0)
        spread = torch_nanstd(values, axis=0)
        margin = spread * n_sigma
        return center - margin, center + margin

    lower, upper = sigma_bounds(X)

    # Second pass: re-estimate the bounds with first-pass outliers masked out.
    data_clean = X[:].clone()
    data_clean[torch.logical_or(X > upper, X < lower)] = np.nan
    lower, upper = sigma_bounds(data_clean)

    X = torch.maximum(-torch.log(1 + torch.abs(X)) + lower, X)
    X = torch.minimum(torch.log(1 + torch.abs(X)) + upper, X)
    return X
|
237 |
-
|
238 |
-
|
239 |
-
def bool_mask_to_att_mask(mask):
    """Convert a 0/1 mask into a float mask: entries equal to 0 become
    ``-inf`` and entries equal to 1 become ``0.0``.
    """
    blocked = mask == 0
    allowed = mask == 1
    att_mask = mask.float().masked_fill(blocked, float("-inf"))
    return att_mask.masked_fill(allowed, float(0.0))
|
245 |
-
|
246 |
-
|
247 |
-
def print_on_master_only(is_master):
    """Replace the builtin ``print`` with a version that is silent unless
    ``is_master`` is true or the call passes ``force=True``.

    The replacement is installed on the ``builtins`` module, so it affects
    every later ``print`` call in the process.
    """
    import builtins as __builtin__

    original_print = __builtin__.print

    def print(*args, force=False, **kwargs):
        if not (is_master or force):
            return
        original_print(*args, **kwargs)

    __builtin__.print = print
|
258 |
-
|
259 |
-
|
260 |
-
def init_dist(device):
    """Set up ``torch.distributed`` when a distributed launch is detected.

    Two launch modes are recognised through environment variables:
    ``LOCAL_RANK`` and, when more than one CUDA device is visible,
    ``SLURM_PROCID``. In both cases the NCCL process group is initialised and
    ``print`` is restricted to rank 0 via ``print_on_master_only``.

    Args:
        device: device to report back when no distributed launch is detected.

    Returns:
        Tuple ``(using_dist, rank, device)``: ``(True, rank, f"cuda:{rank}")``
        in distributed mode, otherwise ``(False, 0, device)``.
    """
    print("init dist")

    def _join_process_group(rank):
        # Shared tail of both launch modes: join the group, sync, gate printing.
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            timeout=datetime.timedelta(seconds=20),
            world_size=torch.cuda.device_count(),
            rank=rank,
        )
        torch.distributed.barrier()
        print_on_master_only(rank == 0)
        print(
            f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
            "only I can print, but when using print(..., force=True) it will print on all ranks."
        )
        return True, rank, f"cuda:{rank}"

    env = os.environ

    if "LOCAL_RANK" in env:
        # launched with torch.distributed.launch
        rank = int(env["LOCAL_RANK"])
        print("torch.distributed.launch and my rank is", rank)
        torch.cuda.set_device(rank)
        env["CUDA_VISIBLE_DEVICES"] = str(rank)
        return _join_process_group(rank)

    if "SLURM_PROCID" in env and torch.cuda.device_count() > 1:
        # this is for multi gpu when starting with submitit
        assert device != "cpu:0"
        rank = int(env["SLURM_PROCID"])
        env["MASTER_ADDR"] = "localhost"
        env["MASTER_PORT"] = "12355"
        torch.cuda.set_device(rank)
        env["CUDA_VISIBLE_DEVICES"] = str(rank)
        print("distributed submitit launch and my rank is", rank)
        return _join_process_group(rank)

    print("Not using distributed")
    # will not change any of the behavior of print, but allows putting the force=True in the print calls
    print_on_master_only(True)
    return False, 0, device
|
311 |
-
|
312 |
-
|
313 |
-
def check_compatibility(dl):
    """Reject data loaders that still declare a ``num_outputs`` other than 1.

    Objects without a ``num_outputs`` attribute pass silently. If the
    attribute exists, a deprecation notice is printed and the value must be 1.

    Args:
        dl: the data loader (or any object) to inspect.

    Raises:
        AssertionError: if ``dl.num_outputs`` exists and is not 1.
    """
    if not hasattr(dl, "num_outputs"):
        return
    print(
        "`num_outputs` for the DataLoader is deprecated. It is assumed to be 1 from now on."
    )
    # The original check was inverted: it failed exactly when num_outputs was
    # 1, i.e. for the only value the message declares acceptable.
    if dl.num_outputs != 1:
        raise AssertionError(
            "We assume num_outputs to be 1. Instead of the num_outputs change your loss. "
            "We specify the number of classes in the CE loss."
        )
|
322 |
-
|
323 |
-
|
324 |
-
def pfn_normalize(
    lb=torch.tensor(float("-inf")),
    ub=torch.tensor(float("inf")),
    soft_lb=0.0,
    soft_ub=1.0,
    minimize=False,
):
    """
    Build a pair of functions mapping data into [0, 1] (to be maximized) and back.

    The LC-PFN curve prior assumes curves normalized to [0, 1] that are to be
    maximized. The forward map first rescales [soft_lb, soft_ub] linearly to
    [-1, 1] and then applies a sigmoid that is scaled and shifted so that
    finite hard bounds land on 0 (``lb``) and 1 (``ub``).

    Parameters:
        lb (torch.Tensor): Hard lower bound of the data (may be -inf).
        ub (torch.Tensor): Hard upper bound of the data (may be +inf).
        soft_lb (float): Soft lower bound for normalization. Default is 0.0.
        soft_ub (float): Soft upper bound for normalization. Default is 1.0.
        minimize (bool): If True, the original curve is a minimization and the
            normalized value is flipped (``1 - value``). Default is False.

    Returns: Tuple ``(normalize, denormalize)`` of callables.
    """
    assert lb <= soft_lb and soft_lb < soft_ub and soft_ub <= ub

    def cinv(x):
        # Flip orientation for minimization problems; identity otherwise.
        if minimize:
            return 1 - x
        return x

    def lin_soft(x):
        # [soft_lb, soft_ub] -> [-1, 1]
        return 2 / (soft_ub - soft_lb) * (x - soft_lb) - 1

    def lin_soft_inv(y):
        # [-1, 1] -> [soft_lb, soft_ub]
        return (y + 1) / 2 * (soft_ub - soft_lb) + soft_lb

    def denom(x):
        # Denominator of the sigmoid applied to the rescaled input.
        return 1 + torch.exp(-lin_soft(x))

    def logit(p):
        return torch.log(p / (1 - p))

    try:
        if torch.exp(-lin_soft(lb)) > 1e300:
            raise RuntimeError
    except RuntimeError:
        # otherwise overflow causes issues, treat these cases as if the lower bound was -infinite
        lb = torch.tensor(float("-inf"))

    lb_unbounded = torch.isinf(lb)
    ub_unbounded = torch.isinf(ub)

    if lb_unbounded and ub_unbounded:
        # Plain sigmoid, no scaling or shifting needed.

        def normalize(x):
            return cinv(1 / denom(x))

        def denormalize(y):
            return lin_soft_inv(logit(cinv(y)))

        return normalize, denormalize

    if lb_unbounded:
        # Scale so that ub maps to 1.
        a = 1 + torch.exp(-lin_soft(ub))

        def normalize(x):
            return cinv(a / denom(x))

        def denormalize(y):
            return lin_soft_inv(logit(cinv(y) / a))

        return normalize, denormalize

    if ub_unbounded:
        # Scale and shift so that lb maps to 0 while the limit at +inf stays 1.
        a = 1 / (1 - 1 / (1 + torch.exp(-lin_soft(lb))))
        b = 1 - a
    else:
        # Scale and shift so that lb maps to 0 and ub maps to 1.
        exp_ub = torch.exp(-lin_soft(ub))
        exp_lb = torch.exp(-lin_soft(lb))
        a = (1 + exp_ub + exp_lb + torch.exp(-lin_soft(ub) - lin_soft(lb))) / (
            exp_lb - exp_ub
        )
        b = -a / (1 + exp_lb)

    def normalize(x):
        return cinv(a / denom(x) + b)

    def denormalize(y):
        return lin_soft_inv(logit((cinv(y) - b) / a))

    return normalize, denormalize
|
395 |
-
|
396 |
-
|
397 |
-
def get_default_normalizer():
    """Return the ``pfn_normalize`` function pair for hard and soft bounds
    both equal to [0, 1], without flipping (``minimize=False``).
    """
    return pfn_normalize(
        lb=torch.tensor(0.0),
        ub=torch.tensor(1.0),
        soft_lb=0.0,
        soft_ub=1.0,
        minimize=False,
    )
|
406 |
-
|
407 |
-
|
408 |
-
def identity_normalizer():
    """Return a ``(normalize, denormalize)`` pair in which both functions
    hand their argument back unchanged.
    """

    def normalize(x):
        return x

    def denormalize(x):
        return x

    return normalize, denormalize
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/version.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
# Version string of the lcpfn package.
__version__ = "0.1.3"
|
|
|
|
pyproject.toml
DELETED
@@ -1,42 +0,0 @@
|
|
1 |
-
[project]
|
2 |
-
name = "lcpfn"
|
3 |
-
description = "In-context Bayesian Learning Curve Extrapolation"
|
4 |
-
readme = {file = "readme.md", content-type = 'text/markdown'}
|
5 |
-
license = {file = "LICENSE"}
|
6 |
-
authors = [
|
7 |
-
{name = "Steven Adriaensen", email= "adriaens@cs.uni-freiburg.de"},
|
8 |
-
{name = "Herilalaina Rakotoarison", email = "rakotoah@cs.uni-freiburg.de"},
|
9 |
-
{name = "Samuel Müller", email = "muellesa@cs.uni-freiburg.de"},
|
10 |
-
{name = "Frank Hutter", email = "fh@cs.uni-freiburg.de"},
|
11 |
-
]
|
12 |
-
requires-python = ">=3.9,<3.12"
|
13 |
-
dependencies = [
|
14 |
-
"torch<=1.11.0",
|
15 |
-
"numpy>=1.21.2,<2",
|
16 |
-
"requests>=2.23.0"
|
17 |
-
]
|
18 |
-
dynamic = ["version"]
|
19 |
-
classifiers = [
|
20 |
-
'Intended Audience :: Science/Research',
|
21 |
-
'License :: OSI Approved :: MIT License',
|
22 |
-
'Programming Language :: Python',
|
23 |
-
'Topic :: Software Development',
|
24 |
-
'Topic :: Scientific/Engineering',
|
25 |
-
'Operating System :: Unix',
|
26 |
-
'Operating System :: MacOS',
|
27 |
-
'Programming Language :: Python :: 3',
|
28 |
-
'Programming Language :: Python :: 3.9',
|
29 |
-
'Programming Language :: Python :: 3.10',
|
30 |
-
'Programming Language :: Python :: 3.11',
|
31 |
-
]
|
32 |
-
|
33 |
-
[project.urls]
|
34 |
-
homepage = "https://github.com/automl/lcpfn"
|
35 |
-
repository = "https://github.com/automl/lcpfn"
|
36 |
-
bugtracker = "https://github.com/automl/lcpfn/issues"
|
37 |
-
|
38 |
-
[tool.setuptools.packages.find]
|
39 |
-
include = ["lcpfn*"]
|
40 |
-
|
41 |
-
[tool.setuptools.dynamic]
|
42 |
-
version = {attr = "lcpfn.version.__version__"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
torch<=1.11.0
|
2 |
+
numpy>=1.21.2,<2
|
3 |
+
lcpfn==0.1.3
|