Spaces:
Running
Running
herilalaina
commited on
Commit
•
b62776c
1
Parent(s):
3866017
update lcpfn
Browse files- lcpfn/__init__.py +40 -13
- lcpfn/bar_distribution.py +143 -63
- lcpfn/decoders.py +21 -9
- lcpfn/domhan_prior.py +7 -3
- lcpfn/encoders.py +55 -26
- lcpfn/initializers.py +3 -1
- lcpfn/layer.py +80 -27
- lcpfn/model.py +35 -8
- lcpfn/positional_encodings.py +31 -23
- lcpfn/train.py +0 -266
- lcpfn/train_lcpfn.py +9 -5
- lcpfn/transformer.py +184 -62
- lcpfn/utils.py +206 -55
- lcpfn/version.py +1 -0
- pyproject.toml +42 -0
- requirements.txt +0 -4
lcpfn/__init__.py
CHANGED
@@ -1,53 +1,80 @@
|
|
1 |
import os, sys
|
|
|
2 |
sys.path.insert(0, os.path.dirname(__file__))
|
3 |
|
4 |
|
5 |
-
model_path =
|
|
|
6 |
|
7 |
def prepare_models():
|
8 |
pfns4bo_dir = os.path.dirname(__file__)
|
9 |
-
model_names = [
|
10 |
-
|
|
|
|
|
11 |
|
12 |
for name in model_names:
|
13 |
weights_path = os.path.join(pfns4bo_dir, model_path, name)
|
14 |
-
compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name +
|
15 |
if not os.path.exists(weights_path):
|
16 |
if not os.path.exists(compressed_weights_path):
|
17 |
print("Downloading", os.path.abspath(compressed_weights_path))
|
18 |
import requests
|
19 |
-
|
|
|
20 |
r = requests.get(url, allow_redirects=True)
|
21 |
os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
|
22 |
-
with open(compressed_weights_path,
|
23 |
f.write(r.content)
|
24 |
if os.path.exists(compressed_weights_path):
|
25 |
print("Unzipping", name)
|
26 |
os.system(f"gzip -dk {compressed_weights_path}")
|
27 |
else:
|
28 |
print("Failed to find", compressed_weights_path)
|
29 |
-
print(
|
|
|
|
|
30 |
if os.path.exists(weights_path):
|
31 |
print("Successfully located model at", weights_path)
|
32 |
|
33 |
|
34 |
model_dict = {
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
}
|
40 |
|
41 |
|
42 |
def __getattr__(name):
|
43 |
if name in model_dict:
|
44 |
if not os.path.exists(model_dict[name]):
|
45 |
-
print(
|
|
|
|
|
|
|
|
|
46 |
print("This might take a while..")
|
47 |
prepare_models()
|
48 |
return model_dict[name]
|
49 |
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
50 |
|
|
|
|
|
51 |
from lcpfn.model import LCPFN
|
52 |
from lcpfn.train_lcpfn import train_lcpfn
|
53 |
-
from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os, sys
|
2 |
+
|
3 |
sys.path.insert(0, os.path.dirname(__file__))
|
4 |
|
5 |
|
6 |
+
model_path = "trained_models"
|
7 |
+
|
8 |
|
9 |
def prepare_models():
|
10 |
pfns4bo_dir = os.path.dirname(__file__)
|
11 |
+
model_names = [
|
12 |
+
"pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt",
|
13 |
+
"pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt",
|
14 |
+
]
|
15 |
|
16 |
for name in model_names:
|
17 |
weights_path = os.path.join(pfns4bo_dir, model_path, name)
|
18 |
+
compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + ".gz")
|
19 |
if not os.path.exists(weights_path):
|
20 |
if not os.path.exists(compressed_weights_path):
|
21 |
print("Downloading", os.path.abspath(compressed_weights_path))
|
22 |
import requests
|
23 |
+
|
24 |
+
url = f'https://ml.informatik.uni-freiburg.de/research-artifacts/lcpfn/{name + ".gz"}'
|
25 |
r = requests.get(url, allow_redirects=True)
|
26 |
os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
|
27 |
+
with open(compressed_weights_path, "wb") as f:
|
28 |
f.write(r.content)
|
29 |
if os.path.exists(compressed_weights_path):
|
30 |
print("Unzipping", name)
|
31 |
os.system(f"gzip -dk {compressed_weights_path}")
|
32 |
else:
|
33 |
print("Failed to find", compressed_weights_path)
|
34 |
+
print(
|
35 |
+
"Make sure you have an internet connection to download the model automatically.."
|
36 |
+
)
|
37 |
if os.path.exists(weights_path):
|
38 |
print("Successfully located model at", weights_path)
|
39 |
|
40 |
|
41 |
model_dict = {
|
42 |
+
"EMSIZE512_NLAYERS12_NBUCKETS1000": os.path.join(
|
43 |
+
os.path.dirname(__file__),
|
44 |
+
model_path,
|
45 |
+
"pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt",
|
46 |
+
),
|
47 |
+
"EMSIZE512_NLAYERS6_NBUCKETS1000": os.path.join(
|
48 |
+
os.path.dirname(__file__),
|
49 |
+
model_path,
|
50 |
+
"pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt",
|
51 |
+
),
|
52 |
}
|
53 |
|
54 |
|
55 |
def __getattr__(name):
|
56 |
if name in model_dict:
|
57 |
if not os.path.exists(model_dict[name]):
|
58 |
+
print(
|
59 |
+
"Can't find",
|
60 |
+
os.path.abspath(model_dict[name]),
|
61 |
+
"thus unzipping/downloading models now.",
|
62 |
+
)
|
63 |
print("This might take a while..")
|
64 |
prepare_models()
|
65 |
return model_dict[name]
|
66 |
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
67 |
|
68 |
+
|
69 |
+
from .version import __version__
|
70 |
from lcpfn.model import LCPFN
|
71 |
from lcpfn.train_lcpfn import train_lcpfn
|
72 |
+
from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func
|
73 |
+
|
74 |
+
__all__ = [
|
75 |
+
"LCPFN",
|
76 |
+
"train_lcpfn",
|
77 |
+
"sample_from_prior",
|
78 |
+
"create_get_batch_func",
|
79 |
+
"__version__",
|
80 |
+
]
|
lcpfn/bar_distribution.py
CHANGED
@@ -3,19 +3,25 @@ from torch import nn
|
|
3 |
|
4 |
|
5 |
class BarDistribution(nn.Module):
|
6 |
-
def __init__(
|
|
|
|
|
7 |
# sorted list of borders
|
8 |
super().__init__()
|
9 |
assert len(borders.shape) == 1
|
10 |
-
#self.borders = borders
|
11 |
-
self.register_buffer(
|
12 |
-
self.register_buffer(
|
13 |
-
#self.bucket_widths = self.borders[1:] - self.borders[:-1]
|
14 |
-
self.register_buffer(
|
15 |
full_width = self.bucket_widths.sum()
|
16 |
border_order = torch.argsort(borders)
|
17 |
-
assert (
|
18 |
-
|
|
|
|
|
|
|
|
|
19 |
self.num_bars = len(borders) - 1
|
20 |
|
21 |
def map_to_bucket_idx(self, y):
|
@@ -24,28 +30,35 @@ class BarDistribution(nn.Module):
|
|
24 |
target_sample[y == self.borders[-1]] = self.num_bars - 1
|
25 |
return target_sample
|
26 |
|
27 |
-
def forward(
|
|
|
|
|
28 |
target_sample = self.map_to_bucket_idx(y)
|
29 |
-
assert (target_sample >= 0).all() and (
|
30 |
-
|
|
|
|
|
|
|
|
|
31 |
|
32 |
bucket_log_probs = torch.log_softmax(logits, -1)
|
33 |
scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
|
34 |
-
#print(bucket_log_probs, logits.shape)
|
35 |
|
36 |
-
nll_loss = -scaled_bucket_log_probs.gather(
|
|
|
|
|
37 |
|
38 |
smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
|
39 |
-
smoothing = self.smoothing if self.training else 0.
|
40 |
-
loss = (1. - smoothing) * nll_loss + smoothing * smooth_loss
|
41 |
return loss
|
42 |
|
43 |
def mean(self, logits):
|
44 |
-
bucket_means = self.borders[:-1] + self.bucket_widths/2
|
45 |
p = torch.softmax(logits, -1)
|
46 |
return p @ bucket_means
|
47 |
|
48 |
-
|
49 |
def icdf(self, logits, left_prob):
|
50 |
"""
|
51 |
Implementation of the quantile function
|
@@ -55,22 +68,32 @@ class BarDistribution(nn.Module):
|
|
55 |
"""
|
56 |
probs = logits.softmax(-1)
|
57 |
cumprobs = torch.cumsum(probs, -1)
|
58 |
-
idx =
|
59 |
-
.
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
cumprobs = torch.cat(
|
61 |
[torch.zeros(*cumprobs.shape[:-1], 1, device=logits.device), cumprobs], -1
|
62 |
)
|
63 |
|
64 |
rest_prob = left_prob - cumprobs.gather(-1, idx[..., None]).squeeze(-1)
|
65 |
left_border = self.borders[idx]
|
66 |
-
right_border = self.borders[idx+1]
|
67 |
-
return left_border + (right_border - left_border) * rest_prob / probs.gather(
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
72 |
|
73 |
-
def ucb(self, logits, best_f, rest_prob=(1
|
74 |
"""
|
75 |
UCB utility. Rest Prob is the amount of utility above (below) the confidence interval that is ignored.
|
76 |
Higher rest_prob is equivalent to lower beta in the standard GP-UCB formulation.
|
@@ -90,23 +113,41 @@ class BarDistribution(nn.Module):
|
|
90 |
|
91 |
def mode(self, logits):
|
92 |
mode_inds = logits.argmax(-1)
|
93 |
-
bucket_means = self.borders[:-1] + self.bucket_widths/2
|
94 |
return bucket_means[mode_inds]
|
95 |
|
96 |
-
def ei(
|
97 |
-
|
|
|
|
|
98 |
if maximize:
|
99 |
bucket_contributions = torch.tensor(
|
100 |
-
[
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
else:
|
103 |
bucket_contributions = torch.tensor(
|
104 |
-
[
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
p = torch.softmax(logits, -1)
|
107 |
return p @ bucket_contributions
|
108 |
|
109 |
-
def pi(
|
|
|
|
|
110 |
"""
|
111 |
Acquisition Function: Probability of Improvement
|
112 |
:param logits: as returned by Transformer
|
@@ -117,10 +158,9 @@ class BarDistribution(nn.Module):
|
|
117 |
assert maximize is True
|
118 |
p = torch.softmax(logits, -1)
|
119 |
border_widths = self.borders[1:] - self.borders[:-1]
|
120 |
-
factor = 1. - ((best_f - self.borders[:-1]) / border_widths).clamp(0
|
121 |
return (p * factor).sum(-1)
|
122 |
|
123 |
-
|
124 |
def mean_of_square(self, logits):
|
125 |
"""
|
126 |
Computes E[x^2].
|
@@ -128,7 +168,11 @@ class BarDistribution(nn.Module):
|
|
128 |
"""
|
129 |
left_borders = self.borders[:-1]
|
130 |
right_borders = self.borders[1:]
|
131 |
-
bucket_mean_of_square = (
|
|
|
|
|
|
|
|
|
132 |
p = torch.softmax(logits, -1)
|
133 |
return p @ bucket_mean_of_square
|
134 |
|
@@ -138,54 +182,74 @@ class BarDistribution(nn.Module):
|
|
138 |
|
139 |
class FullSupportBarDistribution(BarDistribution):
|
140 |
@staticmethod
|
141 |
-
def halfnormal_with_p_weight_before(range_max,p
|
142 |
-
s = range_max / torch.distributions.HalfNormal(torch.tensor(1.)).icdf(
|
|
|
|
|
143 |
return torch.distributions.HalfNormal(s)
|
144 |
|
145 |
-
def forward(
|
|
|
|
|
146 |
assert self.num_bars > 1
|
147 |
target_sample = self.map_to_bucket_idx(y)
|
148 |
-
target_sample.clamp_(0,self.num_bars-1)
|
149 |
assert logits.shape[-1] == self.num_bars
|
150 |
|
151 |
bucket_log_probs = torch.log_softmax(logits, -1)
|
152 |
scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
|
153 |
-
#print(bucket_log_probs, logits.shape)
|
154 |
-
log_probs = scaled_bucket_log_probs.gather(
|
155 |
-
|
156 |
-
|
157 |
-
|
|
|
|
|
|
|
|
|
158 |
|
159 |
# TODO look over it again
|
160 |
-
log_probs[target_sample == 0] += side_normals[0].log_prob(
|
161 |
-
|
|
|
|
|
|
|
|
|
162 |
|
163 |
nll_loss = -log_probs
|
164 |
|
165 |
smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
|
166 |
-
smoothing = self.smoothing if self.training else 0.
|
167 |
-
loss = (1. - smoothing) * nll_loss + smoothing * smooth_loss
|
168 |
-
|
169 |
|
170 |
return loss
|
171 |
|
172 |
def mean(self, logits):
|
173 |
bucket_means = self.borders[:-1] + self.bucket_widths / 2
|
174 |
p = torch.softmax(logits, -1)
|
175 |
-
side_normals = (
|
176 |
-
|
|
|
|
|
177 |
bucket_means[0] = -side_normals[0].mean + self.borders[1]
|
178 |
bucket_means[-1] = side_normals[1].mean + self.borders[-2]
|
179 |
return p @ bucket_means
|
180 |
|
181 |
|
182 |
-
|
183 |
-
|
|
|
|
|
|
|
|
|
184 |
assert (ys is not None) or (full_range is not None)
|
185 |
if ys is not None:
|
186 |
ys = ys.flatten()
|
187 |
-
if len(ys) % num_outputs:
|
188 |
-
|
|
|
|
|
|
|
189 |
ys_per_bucket = len(ys) // num_outputs
|
190 |
if full_range is None:
|
191 |
full_range = (ys.min(), ys.max())
|
@@ -193,17 +257,34 @@ def get_bucket_limits_(num_outputs:int, full_range:tuple=None, ys:torch.Tensor=N
|
|
193 |
assert full_range[0] <= ys.min() and full_range[1] >= ys.max()
|
194 |
full_range = torch.tensor(full_range)
|
195 |
ys_sorted, ys_order = ys.sort(0)
|
196 |
-
bucket_limits = (
|
|
|
|
|
|
|
197 |
if verbose:
|
198 |
-
print(
|
|
|
|
|
199 |
print(full_range)
|
200 |
-
bucket_limits = torch.cat(
|
|
|
|
|
201 |
|
202 |
else:
|
203 |
class_width = (full_range[1] - full_range[0]) / num_outputs
|
204 |
-
bucket_limits = torch.cat(
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
|
206 |
-
assert
|
|
|
|
|
|
|
|
|
207 |
return bucket_limits
|
208 |
|
209 |
|
@@ -266,4 +347,3 @@ def get_bucket_limits(
|
|
266 |
), f"{full_range[-1]} != {bucket_limits[-1]}"
|
267 |
|
268 |
return bucket_limits
|
269 |
-
|
|
|
3 |
|
4 |
|
5 |
class BarDistribution(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self, borders: torch.Tensor, smoothing=0.0
|
8 |
+
): # here borders should start with min and end with max, where all values lie in (min,max) and are sorted
|
9 |
# sorted list of borders
|
10 |
super().__init__()
|
11 |
assert len(borders.shape) == 1
|
12 |
+
# self.borders = borders
|
13 |
+
self.register_buffer("borders", borders)
|
14 |
+
self.register_buffer("smoothing", torch.tensor(smoothing))
|
15 |
+
# self.bucket_widths = self.borders[1:] - self.borders[:-1]
|
16 |
+
self.register_buffer("bucket_widths", self.borders[1:] - self.borders[:-1])
|
17 |
full_width = self.bucket_widths.sum()
|
18 |
border_order = torch.argsort(borders)
|
19 |
+
assert (
|
20 |
+
full_width - (self.borders[-1] - self.borders[0])
|
21 |
+
).abs() < 1e-4, f"diff: {full_width - (self.borders[-1] - self.borders[0])}"
|
22 |
+
assert (
|
23 |
+
border_order == torch.arange(len(borders)).to(border_order.device)
|
24 |
+
).all(), "Please provide sorted borders!"
|
25 |
self.num_bars = len(borders) - 1
|
26 |
|
27 |
def map_to_bucket_idx(self, y):
|
|
|
30 |
target_sample[y == self.borders[-1]] = self.num_bars - 1
|
31 |
return target_sample
|
32 |
|
33 |
+
def forward(
|
34 |
+
self, logits, y
|
35 |
+
): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
|
36 |
target_sample = self.map_to_bucket_idx(y)
|
37 |
+
assert (target_sample >= 0).all() and (
|
38 |
+
target_sample < self.num_bars
|
39 |
+
).all(), f"y {y} not in support set for borders (min_y, max_y) {self.borders}"
|
40 |
+
assert (
|
41 |
+
logits.shape[-1] == self.num_bars
|
42 |
+
), f"{logits.shape[-1]} vs {self.num_bars}"
|
43 |
|
44 |
bucket_log_probs = torch.log_softmax(logits, -1)
|
45 |
scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
|
46 |
+
# print(bucket_log_probs, logits.shape)
|
47 |
|
48 |
+
nll_loss = -scaled_bucket_log_probs.gather(
|
49 |
+
-1, target_sample.unsqueeze(-1)
|
50 |
+
).squeeze(-1)
|
51 |
|
52 |
smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
|
53 |
+
smoothing = self.smoothing if self.training else 0.0
|
54 |
+
loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss
|
55 |
return loss
|
56 |
|
57 |
def mean(self, logits):
|
58 |
+
bucket_means = self.borders[:-1] + self.bucket_widths / 2
|
59 |
p = torch.softmax(logits, -1)
|
60 |
return p @ bucket_means
|
61 |
|
|
|
62 |
def icdf(self, logits, left_prob):
|
63 |
"""
|
64 |
Implementation of the quantile function
|
|
|
68 |
"""
|
69 |
probs = logits.softmax(-1)
|
70 |
cumprobs = torch.cumsum(probs, -1)
|
71 |
+
idx = (
|
72 |
+
torch.searchsorted(
|
73 |
+
cumprobs,
|
74 |
+
left_prob * torch.ones(*cumprobs.shape[:-1], 1, device=probs.device),
|
75 |
+
)
|
76 |
+
.squeeze(-1)
|
77 |
+
.clamp(0, cumprobs.shape[-1] - 1)
|
78 |
+
) # this might not do the right for outliers
|
79 |
cumprobs = torch.cat(
|
80 |
[torch.zeros(*cumprobs.shape[:-1], 1, device=logits.device), cumprobs], -1
|
81 |
)
|
82 |
|
83 |
rest_prob = left_prob - cumprobs.gather(-1, idx[..., None]).squeeze(-1)
|
84 |
left_border = self.borders[idx]
|
85 |
+
right_border = self.borders[idx + 1]
|
86 |
+
return left_border + (right_border - left_border) * rest_prob / probs.gather(
|
87 |
+
-1, idx[..., None]
|
88 |
+
).squeeze(-1)
|
89 |
+
|
90 |
+
def quantile(self, logits, center_prob=0.682):
|
91 |
+
side_probs = (1.0 - center_prob) / 2
|
92 |
+
return torch.stack(
|
93 |
+
(self.icdf(logits, side_probs), self.icdf(logits, 1.0 - side_probs)), -1
|
94 |
+
)
|
95 |
|
96 |
+
def ucb(self, logits, best_f, rest_prob=(1 - 0.682) / 2, maximize=True):
|
97 |
"""
|
98 |
UCB utility. Rest Prob is the amount of utility above (below) the confidence interval that is ignored.
|
99 |
Higher rest_prob is equivalent to lower beta in the standard GP-UCB formulation.
|
|
|
113 |
|
114 |
def mode(self, logits):
|
115 |
mode_inds = logits.argmax(-1)
|
116 |
+
bucket_means = self.borders[:-1] + self.bucket_widths / 2
|
117 |
return bucket_means[mode_inds]
|
118 |
|
119 |
+
def ei(
|
120 |
+
self, logits, best_f, maximize=True
|
121 |
+
): # logits: evaluation_points x batch x feature_dim
|
122 |
+
bucket_means = self.borders[:-1] + self.bucket_widths / 2
|
123 |
if maximize:
|
124 |
bucket_contributions = torch.tensor(
|
125 |
+
[
|
126 |
+
max((bucket_max + max(bucket_min, best_f)) / 2 - best_f, 0)
|
127 |
+
for bucket_min, bucket_max, bucket_mean in zip(
|
128 |
+
self.borders[:-1], self.borders[1:], bucket_means
|
129 |
+
)
|
130 |
+
],
|
131 |
+
dtype=logits.dtype,
|
132 |
+
device=logits.device,
|
133 |
+
)
|
134 |
else:
|
135 |
bucket_contributions = torch.tensor(
|
136 |
+
[
|
137 |
+
-min((min(bucket_max, best_f) + bucket_min) / 2 - best_f, 0)
|
138 |
+
for bucket_min, bucket_max, bucket_mean in zip( # min on max instead of max on min, and compare min < instead of max >
|
139 |
+
self.borders[:-1], self.borders[1:], bucket_means
|
140 |
+
)
|
141 |
+
],
|
142 |
+
dtype=logits.dtype,
|
143 |
+
device=logits.device,
|
144 |
+
)
|
145 |
p = torch.softmax(logits, -1)
|
146 |
return p @ bucket_contributions
|
147 |
|
148 |
+
def pi(
|
149 |
+
self, logits, best_f, maximize=True
|
150 |
+
): # logits: evaluation_points x batch x feature_dim
|
151 |
"""
|
152 |
Acquisition Function: Probability of Improvement
|
153 |
:param logits: as returned by Transformer
|
|
|
158 |
assert maximize is True
|
159 |
p = torch.softmax(logits, -1)
|
160 |
border_widths = self.borders[1:] - self.borders[:-1]
|
161 |
+
factor = 1.0 - ((best_f - self.borders[:-1]) / border_widths).clamp(0.0, 1.0)
|
162 |
return (p * factor).sum(-1)
|
163 |
|
|
|
164 |
def mean_of_square(self, logits):
|
165 |
"""
|
166 |
Computes E[x^2].
|
|
|
168 |
"""
|
169 |
left_borders = self.borders[:-1]
|
170 |
right_borders = self.borders[1:]
|
171 |
+
bucket_mean_of_square = (
|
172 |
+
left_borders.square()
|
173 |
+
+ right_borders.square()
|
174 |
+
+ left_borders * right_borders
|
175 |
+
) / 3.0
|
176 |
p = torch.softmax(logits, -1)
|
177 |
return p @ bucket_mean_of_square
|
178 |
|
|
|
182 |
|
183 |
class FullSupportBarDistribution(BarDistribution):
|
184 |
@staticmethod
|
185 |
+
def halfnormal_with_p_weight_before(range_max, p=0.5):
|
186 |
+
s = range_max / torch.distributions.HalfNormal(torch.tensor(1.0)).icdf(
|
187 |
+
torch.tensor(p)
|
188 |
+
)
|
189 |
return torch.distributions.HalfNormal(s)
|
190 |
|
191 |
+
def forward(
|
192 |
+
self, logits, y
|
193 |
+
): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
|
194 |
assert self.num_bars > 1
|
195 |
target_sample = self.map_to_bucket_idx(y)
|
196 |
+
target_sample.clamp_(0, self.num_bars - 1)
|
197 |
assert logits.shape[-1] == self.num_bars
|
198 |
|
199 |
bucket_log_probs = torch.log_softmax(logits, -1)
|
200 |
scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
|
201 |
+
# print(bucket_log_probs, logits.shape)
|
202 |
+
log_probs = scaled_bucket_log_probs.gather(
|
203 |
+
-1, target_sample.unsqueeze(-1)
|
204 |
+
).squeeze(-1)
|
205 |
+
|
206 |
+
side_normals = (
|
207 |
+
self.halfnormal_with_p_weight_before(self.bucket_widths[0]),
|
208 |
+
self.halfnormal_with_p_weight_before(self.bucket_widths[-1]),
|
209 |
+
)
|
210 |
|
211 |
# TODO look over it again
|
212 |
+
log_probs[target_sample == 0] += side_normals[0].log_prob(
|
213 |
+
(self.borders[1] - y[target_sample == 0]).clamp(min=0.00000001)
|
214 |
+
) + torch.log(self.bucket_widths[0])
|
215 |
+
log_probs[target_sample == self.num_bars - 1] += side_normals[1].log_prob(
|
216 |
+
y[target_sample == self.num_bars - 1] - self.borders[-2]
|
217 |
+
) + torch.log(self.bucket_widths[-1])
|
218 |
|
219 |
nll_loss = -log_probs
|
220 |
|
221 |
smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
|
222 |
+
smoothing = self.smoothing if self.training else 0.0
|
223 |
+
loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss
|
|
|
224 |
|
225 |
return loss
|
226 |
|
227 |
def mean(self, logits):
|
228 |
bucket_means = self.borders[:-1] + self.bucket_widths / 2
|
229 |
p = torch.softmax(logits, -1)
|
230 |
+
side_normals = (
|
231 |
+
self.halfnormal_with_p_weight_before(self.bucket_widths[0]),
|
232 |
+
self.halfnormal_with_p_weight_before(self.bucket_widths[-1]),
|
233 |
+
)
|
234 |
bucket_means[0] = -side_normals[0].mean + self.borders[1]
|
235 |
bucket_means[-1] = side_normals[1].mean + self.borders[-2]
|
236 |
return p @ bucket_means
|
237 |
|
238 |
|
239 |
+
def get_bucket_limits_(
|
240 |
+
num_outputs: int,
|
241 |
+
full_range: tuple = None,
|
242 |
+
ys: torch.Tensor = None,
|
243 |
+
verbose: bool = False,
|
244 |
+
):
|
245 |
assert (ys is not None) or (full_range is not None)
|
246 |
if ys is not None:
|
247 |
ys = ys.flatten()
|
248 |
+
if len(ys) % num_outputs:
|
249 |
+
ys = ys[: -(len(ys) % num_outputs)]
|
250 |
+
print(
|
251 |
+
f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys."
|
252 |
+
)
|
253 |
ys_per_bucket = len(ys) // num_outputs
|
254 |
if full_range is None:
|
255 |
full_range = (ys.min(), ys.max())
|
|
|
257 |
assert full_range[0] <= ys.min() and full_range[1] >= ys.max()
|
258 |
full_range = torch.tensor(full_range)
|
259 |
ys_sorted, ys_order = ys.sort(0)
|
260 |
+
bucket_limits = (
|
261 |
+
ys_sorted[ys_per_bucket - 1 :: ys_per_bucket][:-1]
|
262 |
+
+ ys_sorted[ys_per_bucket::ys_per_bucket]
|
263 |
+
) / 2
|
264 |
if verbose:
|
265 |
+
print(
|
266 |
+
f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys."
|
267 |
+
)
|
268 |
print(full_range)
|
269 |
+
bucket_limits = torch.cat(
|
270 |
+
[full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0
|
271 |
+
)
|
272 |
|
273 |
else:
|
274 |
class_width = (full_range[1] - full_range[0]) / num_outputs
|
275 |
+
bucket_limits = torch.cat(
|
276 |
+
[
|
277 |
+
full_range[0] + torch.arange(num_outputs).float() * class_width,
|
278 |
+
torch.tensor(full_range[1]).unsqueeze(0),
|
279 |
+
],
|
280 |
+
0,
|
281 |
+
)
|
282 |
|
283 |
+
assert (
|
284 |
+
len(bucket_limits) - 1 == num_outputs
|
285 |
+
and full_range[0] == bucket_limits[0]
|
286 |
+
and full_range[-1] == bucket_limits[-1]
|
287 |
+
)
|
288 |
return bucket_limits
|
289 |
|
290 |
|
|
|
347 |
), f"{full_range[-1]} != {bucket_limits[-1]}"
|
348 |
|
349 |
return bucket_limits
|
|
lcpfn/decoders.py
CHANGED
@@ -2,6 +2,14 @@ import torch
|
|
2 |
from torch import nn
|
3 |
import random
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
class ScaledDecoder(nn.Module):
|
7 |
def __init__(self, ninp, nhid, nout):
|
@@ -11,20 +19,24 @@ class ScaledDecoder(nn.Module):
|
|
11 |
self.linear2 = nn.Linear(nhid, 10)
|
12 |
|
13 |
def forward(self, x):
|
14 |
-
#return torch.cat([self.linear1(x), self.linear2(x)], -1)
|
15 |
x = self.linear(x)
|
16 |
-
x =
|
17 |
-
temps = self.linear2(x).softmax(-1) @ torch.tensor(
|
18 |
-
|
19 |
-
|
|
|
|
|
20 |
return self.linear1(x) / temps.unsqueeze(-1)
|
21 |
|
|
|
22 |
class FixedScaledDecoder(nn.Module):
|
23 |
def __init__(self, ninp, nhid, nout):
|
24 |
super().__init__()
|
25 |
-
self.mapper = nn.Sequential(
|
26 |
-
|
|
|
|
|
27 |
|
28 |
def forward(self, x):
|
29 |
-
return self.mapper(x)/self.T.sum()
|
30 |
-
|
|
|
2 |
from torch import nn
|
3 |
import random
|
4 |
|
5 |
+
from torch import Tensor
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class GELU(nn.Module):
|
10 |
+
def forward(self, input: Tensor) -> Tensor:
|
11 |
+
return F.gelu(input)
|
12 |
+
|
13 |
|
14 |
class ScaledDecoder(nn.Module):
|
15 |
def __init__(self, ninp, nhid, nout):
|
|
|
19 |
self.linear2 = nn.Linear(nhid, 10)
|
20 |
|
21 |
def forward(self, x):
|
22 |
+
# return torch.cat([self.linear1(x), self.linear2(x)], -1)
|
23 |
x = self.linear(x)
|
24 |
+
x = GELU()(x)
|
25 |
+
temps = self.linear2(x).softmax(-1) @ torch.tensor(
|
26 |
+
[1.0, 1.4, 1.7, 2.0, 5.0, 10.0, 20.0, 40.0, 80.0, 160.0], device=x.device
|
27 |
+
)
|
28 |
+
if random.random() > 0.99:
|
29 |
+
print(temps.shape, temps[:, :2])
|
30 |
return self.linear1(x) / temps.unsqueeze(-1)
|
31 |
|
32 |
+
|
33 |
class FixedScaledDecoder(nn.Module):
|
34 |
def __init__(self, ninp, nhid, nout):
|
35 |
super().__init__()
|
36 |
+
self.mapper = nn.Sequential(
|
37 |
+
nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout)
|
38 |
+
)
|
39 |
+
self.T = nn.Parameter(torch.ones(10000) / 10000)
|
40 |
|
41 |
def forward(self, x):
|
42 |
+
return self.mapper(x) / self.T.sum()
|
|
lcpfn/domhan_prior.py
CHANGED
@@ -58,7 +58,10 @@ def prior_weights(
|
|
58 |
|
59 |
def sample_from_prior(rng, seq_len=100):
|
60 |
return sample_prior_comb(
|
61 |
-
rng=rng,
|
|
|
|
|
|
|
62 |
)
|
63 |
|
64 |
|
@@ -103,7 +106,7 @@ def sample_prior_comb(
|
|
103 |
f_priors = {
|
104 |
"pow3": uniform_prior_pow3,
|
105 |
"ilog2": uniform_prior_ilog2,
|
106 |
-
"janoschek": uniform_prior_janoschek
|
107 |
}
|
108 |
else:
|
109 |
raise NotImplemented()
|
@@ -153,6 +156,7 @@ def generate_prior_dataset(n, prior=sample_prior_comb, seed=42):
|
|
153 |
def create_get_batch_func(prior):
|
154 |
return partial(get_batch_domhan, prior=prior)
|
155 |
|
|
|
156 |
# function producing batches for PFN training
|
157 |
def get_batch_domhan(
|
158 |
batch_size,
|
@@ -192,4 +196,4 @@ def get_batch_domhan(
|
|
192 |
y_target = y_target.float()
|
193 |
y_noisy = y_noisy.float()
|
194 |
|
195 |
-
return x, y_noisy, y_target
|
|
|
58 |
|
59 |
def sample_from_prior(rng, seq_len=100):
|
60 |
return sample_prior_comb(
|
61 |
+
rng=rng,
|
62 |
+
seq_len=seq_len,
|
63 |
+
components=["pow3", "ilog2", "janoschek"],
|
64 |
+
distribution="peaked",
|
65 |
)
|
66 |
|
67 |
|
|
|
106 |
f_priors = {
|
107 |
"pow3": uniform_prior_pow3,
|
108 |
"ilog2": uniform_prior_ilog2,
|
109 |
+
"janoschek": uniform_prior_janoschek,
|
110 |
}
|
111 |
else:
|
112 |
raise NotImplemented()
|
|
|
156 |
def create_get_batch_func(prior):
|
157 |
return partial(get_batch_domhan, prior=prior)
|
158 |
|
159 |
+
|
160 |
# function producing batches for PFN training
|
161 |
def get_batch_domhan(
|
162 |
batch_size,
|
|
|
196 |
y_target = y_target.float()
|
197 |
y_noisy = y_noisy.float()
|
198 |
|
199 |
+
return x, y_noisy, y_target
|
lcpfn/encoders.py
CHANGED
@@ -18,34 +18,45 @@ class StyleEncoder(nn.Module):
|
|
18 |
|
19 |
|
20 |
class _PositionalEncoding(nn.Module):
|
21 |
-
def __init__(self, d_model, dropout=0.):
|
22 |
super().__init__()
|
23 |
self.dropout = nn.Dropout(p=dropout)
|
24 |
self.d_model = d_model
|
25 |
-
self.device_test_tensor = nn.Parameter(torch.tensor(1.))
|
26 |
|
27 |
-
def forward(self, x)
|
28 |
-
assert self.d_model % x.shape[-1]*2 == 0
|
29 |
d_per_feature = self.d_model // x.shape[-1]
|
30 |
pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
|
31 |
-
#position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
32 |
interval_size = 10
|
33 |
-
div_term = (
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
|
36 |
pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
|
37 |
-
return self.dropout(pe).view(x.shape[0],x.shape[1],self.d_model)
|
38 |
|
39 |
|
40 |
Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)
|
41 |
|
|
|
42 |
class EmbeddingEncoder(nn.Module):
|
43 |
def __init__(self, num_features, em_size, num_embs=100):
|
44 |
super().__init__()
|
45 |
self.num_embs = num_embs
|
46 |
self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
|
47 |
-
self.init_weights(.1)
|
48 |
-
self.min_max = (-2
|
49 |
|
50 |
@property
|
51 |
def width(self):
|
@@ -60,7 +71,9 @@ class EmbeddingEncoder(nn.Module):
|
|
60 |
|
61 |
def forward(self, x): # T x B x num_features
|
62 |
x_idxs = self.discretize(x)
|
63 |
-
x_idxs +=
|
|
|
|
|
64 |
# print(x_idxs,self.embeddings.weight.shape)
|
65 |
return self.embeddings(x_idxs).mean(-2)
|
66 |
|
@@ -72,7 +85,7 @@ class Normalize(nn.Module):
|
|
72 |
self.std = std
|
73 |
|
74 |
def forward(self, x):
|
75 |
-
return (x-self.mean)/self.std
|
76 |
|
77 |
|
78 |
def get_normalized_uniform_encoder(encoder_creator):
|
@@ -83,13 +96,16 @@ def get_normalized_uniform_encoder(encoder_creator):
|
|
83 |
:param encoder:
|
84 |
:return:
|
85 |
"""
|
86 |
-
return lambda in_dim, out_dim: nn.Sequential(
|
|
|
|
|
87 |
|
88 |
|
89 |
Linear = nn.Linear
|
90 |
-
MLP = lambda num_features, emsize: nn.Sequential(
|
91 |
-
|
92 |
-
|
|
|
93 |
|
94 |
class NanHandlingEncoder(nn.Module):
|
95 |
def __init__(self, num_features, emsize, keep_nans=True):
|
@@ -101,10 +117,17 @@ class NanHandlingEncoder(nn.Module):
|
|
101 |
|
102 |
def forward(self, x):
|
103 |
if self.keep_nans:
|
104 |
-
x = torch.cat(
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
else:
|
109 |
x = torch.nan_to_num(x, nan=0.0)
|
110 |
return self.layer(x)
|
@@ -124,24 +147,28 @@ class Linear(nn.Linear):
|
|
124 |
class Conv(nn.Module):
|
125 |
def __init__(self, input_size, emsize):
|
126 |
super().__init__()
|
127 |
-
self.convs = torch.nn.ModuleList(
|
128 |
-
|
|
|
|
|
129 |
|
130 |
def forward(self, x):
|
131 |
size = math.isqrt(x.shape[-1])
|
132 |
-
assert size*size == x.shape[-1]
|
133 |
x = x.reshape(*x.shape[:-1], 1, size, size)
|
134 |
for conv in self.convs:
|
135 |
if x.shape[-1] < 4:
|
136 |
break
|
137 |
x = conv(x)
|
138 |
x.relu_()
|
139 |
-
x = nn.AdaptiveAvgPool2d((1,1))(x).squeeze(-1).squeeze(-1)
|
140 |
return self.linear(x)
|
141 |
|
142 |
|
143 |
class CanEmb(nn.Embedding):
|
144 |
-
def __init__(
|
|
|
|
|
145 |
assert embedding_dim % num_features == 0
|
146 |
embedding_dim = embedding_dim // num_features
|
147 |
super().__init__(num_embeddings, embedding_dim, *args, **kwargs)
|
@@ -158,4 +185,6 @@ def get_Canonical(num_classes):
|
|
158 |
|
159 |
|
160 |
def get_Embedding(num_embs_per_feature=100):
|
161 |
-
return lambda num_features, emsize: EmbeddingEncoder(
|
|
|
|
|
|
18 |
|
19 |
|
20 |
class _PositionalEncoding(nn.Module):
|
21 |
+
def __init__(self, d_model, dropout=0.0):
|
22 |
super().__init__()
|
23 |
self.dropout = nn.Dropout(p=dropout)
|
24 |
self.d_model = d_model
|
25 |
+
self.device_test_tensor = nn.Parameter(torch.tensor(1.0))
|
26 |
|
27 |
+
def forward(self, x): # T x B x num_features
|
28 |
+
assert self.d_model % x.shape[-1] * 2 == 0
|
29 |
d_per_feature = self.d_model // x.shape[-1]
|
30 |
pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
|
31 |
+
# position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
32 |
interval_size = 10
|
33 |
+
div_term = (
|
34 |
+
(1.0 / interval_size)
|
35 |
+
* 2
|
36 |
+
* math.pi
|
37 |
+
* torch.exp(
|
38 |
+
torch.arange(
|
39 |
+
0, d_per_feature, 2, device=self.device_test_tensor.device
|
40 |
+
).float()
|
41 |
+
* math.log(math.sqrt(2))
|
42 |
+
)
|
43 |
+
)
|
44 |
+
# print(div_term/2/math.pi)
|
45 |
pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
|
46 |
pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
|
47 |
+
return self.dropout(pe).view(x.shape[0], x.shape[1], self.d_model)
|
48 |
|
49 |
|
50 |
Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)
|
51 |
|
52 |
+
|
53 |
class EmbeddingEncoder(nn.Module):
|
54 |
def __init__(self, num_features, em_size, num_embs=100):
|
55 |
super().__init__()
|
56 |
self.num_embs = num_embs
|
57 |
self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
|
58 |
+
self.init_weights(0.1)
|
59 |
+
self.min_max = (-2, +2)
|
60 |
|
61 |
@property
|
62 |
def width(self):
|
|
|
71 |
|
72 |
def forward(self, x): # T x B x num_features
|
73 |
x_idxs = self.discretize(x)
|
74 |
+
x_idxs += (
|
75 |
+
torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
|
76 |
+
)
|
77 |
# print(x_idxs,self.embeddings.weight.shape)
|
78 |
return self.embeddings(x_idxs).mean(-2)
|
79 |
|
|
|
85 |
self.std = std
|
86 |
|
87 |
def forward(self, x):
|
88 |
+
return (x - self.mean) / self.std
|
89 |
|
90 |
|
91 |
def get_normalized_uniform_encoder(encoder_creator):
|
|
|
96 |
:param encoder:
|
97 |
:return:
|
98 |
"""
|
99 |
+
return lambda in_dim, out_dim: nn.Sequential(
|
100 |
+
Normalize(0.5, math.sqrt(1 / 12)), encoder_creator(in_dim, out_dim)
|
101 |
+
)
|
102 |
|
103 |
|
104 |
Linear = nn.Linear
|
105 |
+
MLP = lambda num_features, emsize: nn.Sequential(
|
106 |
+
nn.Linear(num_features + 1, emsize * 2), nn.ReLU(), nn.Linear(emsize * 2, emsize)
|
107 |
+
)
|
108 |
+
|
109 |
|
110 |
class NanHandlingEncoder(nn.Module):
|
111 |
def __init__(self, num_features, emsize, keep_nans=True):
|
|
|
117 |
|
118 |
def forward(self, x):
|
119 |
if self.keep_nans:
|
120 |
+
x = torch.cat(
|
121 |
+
[
|
122 |
+
torch.nan_to_num(x, nan=0.0),
|
123 |
+
normalize_data(
|
124 |
+
torch.isnan(x) * -1
|
125 |
+
+ torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1
|
126 |
+
+ torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2
|
127 |
+
),
|
128 |
+
],
|
129 |
+
-1,
|
130 |
+
)
|
131 |
else:
|
132 |
x = torch.nan_to_num(x, nan=0.0)
|
133 |
return self.layer(x)
|
|
|
147 |
class Conv(nn.Module):
    """CNN encoder for feature vectors that are flattened square images.

    The last input dimension must be a perfect square ``size * size``. It is
    reshaped to a one-channel image, passed through up to five unpadded 3x3
    convolutions (each followed by ReLU), average-pooled to one value per
    channel and projected to ``emsize``.

    Args:
        input_size: Unused; kept so the constructor matches the
            ``(num_features, emsize)`` encoder-factory signature.
        emsize: Output embedding width.
    """

    def __init__(self, input_size, emsize):
        super().__init__()
        self.convs = torch.nn.ModuleList(
            [nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)]
        )
        self.linear = nn.Linear(64, emsize)

    def forward(self, x):
        """Encode ``x`` of shape ``(*leading, size * size)`` to ``(*leading, emsize)``.

        Raises:
            AssertionError: If the last dimension is not a perfect square.
        """
        size = math.isqrt(x.shape[-1])
        assert size * size == x.shape[-1]
        # Conv2d only accepts 3-D/4-D input, so fold every leading dimension
        # (e.g. sequence and batch) into a single batch axis and restore it
        # at the end.
        leading_shape = x.shape[:-1]
        x = x.reshape(-1, 1, size, size)
        for conv in self.convs:
            # Each unpadded 3x3 conv shrinks the image by 2; stop before it
            # gets smaller than 2x2. NOTE: if the input is already smaller
            # than 4x4 no conv runs and the 1-channel tensor will not match
            # the 64-wide linear layer below.
            if x.shape[-1] < 4:
                break
            x = conv(x)
            x.relu_()
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)
        out = self.linear(x)
        return out.reshape(*leading_shape, out.shape[-1])
|
166 |
|
167 |
|
168 |
class CanEmb(nn.Embedding):
|
169 |
+
def __init__(
|
170 |
+
self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs
|
171 |
+
):
|
172 |
assert embedding_dim % num_features == 0
|
173 |
embedding_dim = embedding_dim // num_features
|
174 |
super().__init__(num_embeddings, embedding_dim, *args, **kwargs)
|
|
|
185 |
|
186 |
|
187 |
def get_Embedding(num_embs_per_feature=100):
    """Return an encoder factory that builds ``EmbeddingEncoder`` instances.

    The returned callable takes ``(num_features, emsize)`` and forwards
    ``num_embs_per_feature`` as the encoder's ``num_embs`` argument.
    """

    def make_encoder(num_features, emsize):
        return EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)

    return make_encoder
|
lcpfn/initializers.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
|
|
|
4 |
def get_NormalInitializer(std):
|
5 |
def initializer(m):
|
6 |
if isinstance(m, nn.Linear):
|
7 |
nn.init.normal_(m.weight, 0, std)
|
8 |
nn.init.normal_(m.bias, 0, std)
|
9 |
-
|
|
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
|
4 |
+
|
5 |
def get_NormalInitializer(std):
    """Return a callable that re-initialises ``nn.Linear`` modules in place.

    The returned ``initializer(m)`` draws the weight (and the bias, when the
    layer has one) from a normal distribution with mean 0 and standard
    deviation ``std``. Modules that are not ``nn.Linear`` are left untouched.
    """

    def initializer(m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, std)
            # ``nn.Linear(..., bias=False)`` stores ``bias`` as None.
            if m.bias is not None:
                nn.init.normal_(m.bias, 0, std)

    return initializer
|
lcpfn/layer.py
CHANGED
@@ -36,15 +36,28 @@ class TransformerEncoderLayer(nn.Module):
|
|
36 |
>>> src = torch.rand(32, 10, 512)
|
37 |
>>> out = encoder_layer(src)
|
38 |
"""
|
39 |
-
__constants__ = ['batch_first']
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
super().__init__()
|
46 |
-
self.self_attn = MultiheadAttention(
|
47 |
-
|
|
|
48 |
# Implementation of Feedforward model
|
49 |
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
|
50 |
self.dropout = Dropout(dropout)
|
@@ -60,11 +73,16 @@ class TransformerEncoderLayer(nn.Module):
|
|
60 |
self.activation = _get_activation_fn(activation)
|
61 |
|
62 |
def __setstate__(self, state):
|
63 |
-
if
|
64 |
-
state[
|
65 |
super().__setstate__(state)
|
66 |
|
67 |
-
def forward(
|
|
|
|
|
|
|
|
|
|
|
68 |
r"""Pass the input through the encoder layer.
|
69 |
|
70 |
Args:
|
@@ -90,26 +108,61 @@ class TransformerEncoderLayer(nn.Module):
|
|
90 |
num_train_tokens = trainset_src_mask.shape[0]
|
91 |
|
92 |
global_tokens_src = src_[:num_global_tokens]
|
93 |
-
train_tokens_src = src_[
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
else:
|
108 |
if self.recompute_attn:
|
109 |
-
src2 = checkpoint(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
else:
|
111 |
-
src2 = self.self_attn(
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
113 |
src = src + self.dropout1(src2)
|
114 |
if not self.pre_norm:
|
115 |
src = self.norm1(src)
|
@@ -123,4 +176,4 @@ class TransformerEncoderLayer(nn.Module):
|
|
123 |
|
124 |
if not self.pre_norm:
|
125 |
src = self.norm2(src)
|
126 |
-
return src
|
|
|
36 |
>>> src = torch.rand(32, 10, 512)
|
37 |
>>> out = encoder_layer(src)
|
38 |
"""
|
|
|
39 |
|
40 |
+
__constants__ = ["batch_first"]
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
d_model,
|
45 |
+
nhead,
|
46 |
+
dim_feedforward=2048,
|
47 |
+
dropout=0.1,
|
48 |
+
activation="relu",
|
49 |
+
layer_norm_eps=1e-5,
|
50 |
+
batch_first=False,
|
51 |
+
pre_norm=False,
|
52 |
+
device=None,
|
53 |
+
dtype=None,
|
54 |
+
recompute_attn=False,
|
55 |
+
) -> None:
|
56 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
57 |
super().__init__()
|
58 |
+
self.self_attn = MultiheadAttention(
|
59 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs
|
60 |
+
)
|
61 |
# Implementation of Feedforward model
|
62 |
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
|
63 |
self.dropout = Dropout(dropout)
|
|
|
73 |
self.activation = _get_activation_fn(activation)
|
74 |
|
75 |
def __setstate__(self, state):
    """Restore module state, filling in ``F.relu`` when ``state`` has no "activation" entry."""
    state.setdefault("activation", F.relu)
    super().__setstate__(state)
|
79 |
|
80 |
+
def forward(
|
81 |
+
self,
|
82 |
+
src: Tensor,
|
83 |
+
src_mask: Optional[Tensor] = None,
|
84 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
85 |
+
) -> Tensor:
|
86 |
r"""Pass the input through the encoder layer.
|
87 |
|
88 |
Args:
|
|
|
108 |
num_train_tokens = trainset_src_mask.shape[0]
|
109 |
|
110 |
global_tokens_src = src_[:num_global_tokens]
|
111 |
+
train_tokens_src = src_[
|
112 |
+
num_global_tokens : num_global_tokens + num_train_tokens
|
113 |
+
]
|
114 |
+
global_and_train_tokens_src = src_[: num_global_tokens + num_train_tokens]
|
115 |
+
eval_tokens_src = src_[num_global_tokens + num_train_tokens :]
|
116 |
+
|
117 |
+
attn = (
|
118 |
+
partial(checkpoint, self.self_attn)
|
119 |
+
if self.recompute_attn
|
120 |
+
else self.self_attn
|
121 |
+
)
|
122 |
+
|
123 |
+
global_tokens_src2 = attn(
|
124 |
+
global_tokens_src,
|
125 |
+
global_and_train_tokens_src,
|
126 |
+
global_and_train_tokens_src,
|
127 |
+
None,
|
128 |
+
True,
|
129 |
+
global_src_mask,
|
130 |
+
)[0]
|
131 |
+
train_tokens_src2 = attn(
|
132 |
+
train_tokens_src,
|
133 |
+
global_tokens_src,
|
134 |
+
global_tokens_src,
|
135 |
+
None,
|
136 |
+
True,
|
137 |
+
trainset_src_mask,
|
138 |
+
)[0]
|
139 |
+
eval_tokens_src2 = attn(
|
140 |
+
eval_tokens_src, src_, src_, None, True, valset_src_mask
|
141 |
+
)[0]
|
142 |
+
|
143 |
+
src2 = torch.cat(
|
144 |
+
[global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0
|
145 |
+
)
|
146 |
|
147 |
else:
|
148 |
if self.recompute_attn:
|
149 |
+
src2 = checkpoint(
|
150 |
+
self.self_attn,
|
151 |
+
src_,
|
152 |
+
src_,
|
153 |
+
src_,
|
154 |
+
src_key_padding_mask,
|
155 |
+
True,
|
156 |
+
src_mask,
|
157 |
+
)[0]
|
158 |
else:
|
159 |
+
src2 = self.self_attn(
|
160 |
+
src_,
|
161 |
+
src_,
|
162 |
+
src_,
|
163 |
+
attn_mask=src_mask,
|
164 |
+
key_padding_mask=src_key_padding_mask,
|
165 |
+
)[0]
|
166 |
src = src + self.dropout1(src2)
|
167 |
if not self.pre_norm:
|
168 |
src = self.norm1(src)
|
|
|
176 |
|
177 |
if not self.pre_norm:
|
178 |
src = self.norm2(src)
|
179 |
+
return src
|
lcpfn/model.py
CHANGED
@@ -1,29 +1,56 @@
|
|
1 |
import torch
|
2 |
import lcpfn
|
|
|
|
|
|
|
3 |
|
4 |
class LCPFN(torch.nn.Module):
|
5 |
def __init__(self, model_name="EMSIZE512_NLAYERS12_NBUCKETS1000"):
|
6 |
super(LCPFN, self).__init__()
|
7 |
-
self.model = torch.load(
|
|
|
|
|
8 |
self.model.eval()
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
@torch.no_grad()
|
11 |
-
def predict_mean(
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
14 |
|
15 |
@torch.no_grad()
|
16 |
-
def predict_quantiles(
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
@torch.no_grad()
|
21 |
def nll_loss(self, x_train, y_train, x_test, y_test):
|
|
|
22 |
logits = self(x_train=x_train, y_train=y_train, x_test=x_test)
|
23 |
return self.model.criterion(logits, y_test)
|
24 |
|
25 |
def forward(self, x_train, y_train, x_test):
|
|
|
26 |
single_eval_pos = x_train.shape[0]
|
27 |
x = torch.cat([x_train, x_test], dim=0).unsqueeze(1)
|
28 |
y = y_train.unsqueeze(1)
|
29 |
-
return self.model((x, y), single_eval_pos=single_eval_pos)
|
|
|
1 |
import torch
|
2 |
import lcpfn
|
3 |
+
import warnings
|
4 |
+
from lcpfn import utils
|
5 |
+
|
6 |
|
7 |
class LCPFN(torch.nn.Module):
    """Inference wrapper around a serialized LC-PFN model.

    Args:
        model_name: Either a key of ``lcpfn.model_dict`` (resolved through
            attribute access on the ``lcpfn`` package) or a path that is
            passed to ``torch.load`` directly.
    """

    def __init__(self, model_name="EMSIZE512_NLAYERS12_NBUCKETS1000"):
        super(LCPFN, self).__init__()
        self.model = torch.load(
            getattr(lcpfn, model_name) if model_name in lcpfn.model_dict else model_name
        )
        self.model.eval()

    def check_input(self, x_train, x_test, y_train, y_test=None):
        """Validate the value ranges of the inputs.

        Raises:
            Exception: If any x value is negative, or if any value of
                ``y_train`` (or of ``y_test``, when given) lies outside
                [0, 1].
        """
        if torch.any(x_train < 0) or torch.any(x_test < 0):
            # reject inputs that contain negative x values
            raise Exception("x values should be non-negative")
        # Element-wise range test. A chained comparison such as
        # ``0 < y < 1`` cannot be used on tensors: it calls ``bool()`` on the
        # intermediate tensor and raises for more than one element.
        y_out_of_range = torch.any((y_train < 0) | (y_train > 1))
        if not y_out_of_range and y_test is not None:
            y_out_of_range = torch.any((y_test < 0) | (y_test > 1))
        if y_out_of_range:
            # reject y values outside [0,1]
            raise Exception(
                "y values should be in the range [0,1]. Please set normalizer_kwargs accordingly."
            )

    @torch.no_grad()
    def predict_mean(
        self, x_train, y_train, x_test, normalizer=utils.identity_normalizer()
    ):
        """Return the predictive mean at ``x_test``.

        ``normalizer`` is a pair: element 0 is applied to ``y_train`` before
        the model is called, element 1 is applied to the model's mean
        prediction before it is returned.
        """
        y_train_norm = normalizer[0](y_train)
        logits = self(x_train=x_train, y_train=y_train_norm, x_test=x_test)
        return normalizer[1](self.model.criterion.mean(logits))

    @torch.no_grad()
    def predict_quantiles(
        self, x_train, y_train, x_test, qs, normalizer=utils.identity_normalizer()
    ):
        """Return the predictive quantiles ``qs`` at ``x_test``.

        One ``criterion.icdf`` result per entry of ``qs`` is concatenated
        along dimension 1; ``normalizer`` is used as in ``predict_mean``.
        """
        y_train_norm = normalizer[0](y_train)
        logits = self(x_train=x_train, y_train=y_train_norm, x_test=x_test)
        return normalizer[1](
            torch.cat([self.model.criterion.icdf(logits, q) for q in qs], dim=1)
        )

    @torch.no_grad()
    def nll_loss(self, x_train, y_train, x_test, y_test):
        """Return ``criterion(logits, y_test)`` for the model's output at ``x_test``."""
        # TODO add normalizer_kwargs
        logits = self(x_train=x_train, y_train=y_train, x_test=x_test)
        return self.model.criterion(logits, y_test)

    def forward(self, x_train, y_train, x_test):
        """Run the wrapped model on one curve.

        ``x_train`` and ``x_test`` are concatenated along dimension 0 and a
        batch axis of size 1 is inserted at dimension 1; the model is told
        that the first ``x_train.shape[0]`` positions are the observed part.
        """
        self.check_input(x_train, x_test, y_train)
        single_eval_pos = x_train.shape[0]
        x = torch.cat([x_train, x_test], dim=0).unsqueeze(1)
        y = y_train.unsqueeze(1)
        return self.model((x, y), single_eval_pos=single_eval_pos)
|
lcpfn/positional_encodings.py
CHANGED
@@ -15,7 +15,7 @@ class NoPositionalEncoding(nn.Module):
|
|
15 |
pass
|
16 |
|
17 |
def forward(self, x):
|
18 |
-
return x
|
19 |
|
20 |
|
21 |
class PositionalEncoding(nn.Module):
|
@@ -23,14 +23,16 @@ class PositionalEncoding(nn.Module):
|
|
23 |
super(PositionalEncoding, self).__init__()
|
24 |
pe = torch.zeros(max_len, d_model)
|
25 |
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
26 |
-
div_term = torch.exp(
|
|
|
|
|
27 |
pe[:, 0::2] = torch.sin(position * div_term)
|
28 |
pe[:, 1::2] = torch.cos(position * div_term)
|
29 |
pe = pe.unsqueeze(0).transpose(0, 1)
|
30 |
-
self.register_buffer(
|
31 |
|
32 |
def forward(self, x):
|
33 |
-
x = self.pe[:x.size(0), :] + x
|
34 |
return x
|
35 |
|
36 |
|
@@ -38,33 +40,39 @@ class LearnedPositionalEncoding(nn.Module):
|
|
38 |
def __init__(self, d_model, max_len=5000):
|
39 |
super(LearnedPositionalEncoding, self).__init__()
|
40 |
self.max_seq_len = max_len
|
41 |
-
#self.positional_embeddings = nn.Embedding(max_len, d_model)
|
42 |
self.positional_embeddings = nn.Parameter(torch.empty(max_len, d_model))
|
43 |
-
nn.init.normal_(self.positional_embeddings, mean=0, std=d_model
|
44 |
|
45 |
def forward(self, x):
|
46 |
seq_len, bs, d_model = x.shape
|
47 |
-
assert seq_len <= len(
|
|
|
|
|
48 |
pos_emb = self.positional_embeddings[:seq_len]
|
49 |
-
return
|
|
|
|
|
50 |
|
51 |
|
52 |
class PairedScrambledPositionalEncodings(LearnedPositionalEncoding):
|
53 |
# TODO check whether it is a problem to use the same perm. for full batch
|
54 |
def forward(self, x):
|
55 |
seq_len, bs, d_model = x.shape
|
56 |
-
assert seq_len <= len(
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
|
15 |
pass
|
16 |
|
17 |
def forward(self, x):
|
18 |
+
return x # * math.sqrt(x.shape[-1])
|
19 |
|
20 |
|
21 |
class PositionalEncoding(nn.Module):
|
|
|
23 |
super(PositionalEncoding, self).__init__()
|
24 |
pe = torch.zeros(max_len, d_model)
|
25 |
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
26 |
+
div_term = torch.exp(
|
27 |
+
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
|
28 |
+
)
|
29 |
pe[:, 0::2] = torch.sin(position * div_term)
|
30 |
pe[:, 1::2] = torch.cos(position * div_term)
|
31 |
pe = pe.unsqueeze(0).transpose(0, 1)
|
32 |
+
self.register_buffer("pe", pe)
|
33 |
|
34 |
def forward(self, x):
|
35 |
+
x = self.pe[: x.size(0), :] + x # * math.sqrt(x.shape[-1])
|
36 |
return x
|
37 |
|
38 |
|
|
|
40 |
def __init__(self, d_model, max_len=5000):
    """Create a learnable ``max_len x d_model`` table of position vectors.

    The table is initialised from a normal distribution with mean 0 and
    standard deviation ``d_model ** -0.5``.
    """
    super(LearnedPositionalEncoding, self).__init__()
    self.max_seq_len = max_len
    table = torch.empty(max_len, d_model)
    nn.init.normal_(table, mean=0, std=d_model**-0.5)
    self.positional_embeddings = nn.Parameter(table)
|
46 |
|
47 |
def forward(self, x):
    """Add the first ``seq_len`` learned position vectors to ``x``.

    ``x`` is unpacked as ``(seq_len, batch, d_model)``; the same position
    vector is added to every batch element.

    Raises:
        AssertionError: If ``seq_len`` exceeds the number of stored positions.
    """
    seq_len, bs, d_model = x.shape
    table = self.positional_embeddings
    assert seq_len <= len(table), "seq_len can be at most max_len."
    # Insert a batch axis and broadcast the position vectors across it.
    return table[:seq_len].unsqueeze(1).expand(seq_len, bs, d_model) + x
|
56 |
|
57 |
|
58 |
class PairedScrambledPositionalEncodings(LearnedPositionalEncoding):
    """Learned positional encoding whose table rows are shuffled on every call."""

    # TODO check whether it is a problem to use the same perm. for full batch
    def forward(self, x):
        """Add randomly permuted position vectors to ``x`` (seq_len x batch x d_model).

        A fresh ``torch.randperm`` over the table rows is drawn per call and
        shared by all batch elements; the first ``seq_len`` permuted rows are
        used.

        Raises:
            AssertionError: If ``seq_len`` exceeds the table length or the
                table length is odd.
        """
        seq_len, bs, d_model = x.shape
        table = self.positional_embeddings
        max_len = len(table)
        assert seq_len <= max_len, "seq_len can be at most max_len."
        assert max_len % 2 == 0, "Please specify an even max_len."

        # View each row as (d_model / 2) pairs, shuffle the rows, then
        # restore the original table shape before slicing.
        paired_embs = table.view(max_len, -1, 2)
        row_order = torch.randperm(len(paired_embs))
        pos_emb = paired_embs[row_order].view(*table.shape)[:seq_len]

        return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x
|
lcpfn/train.py
CHANGED
@@ -1,12 +1,7 @@
|
|
1 |
-
import os
|
2 |
import itertools
|
3 |
-
import argparse
|
4 |
import time
|
5 |
-
import datetime
|
6 |
-
import yaml
|
7 |
from contextlib import nullcontext
|
8 |
|
9 |
-
import pickle
|
10 |
import torch
|
11 |
from torch import nn
|
12 |
|
@@ -14,18 +9,11 @@ from lcpfn import utils
|
|
14 |
from lcpfn.transformer import TransformerModel
|
15 |
from lcpfn.bar_distribution import (
|
16 |
BarDistribution,
|
17 |
-
FullSupportBarDistribution,
|
18 |
-
get_bucket_limits,
|
19 |
)
|
20 |
from lcpfn.utils import (
|
21 |
get_cosine_schedule_with_warmup,
|
22 |
get_openai_lr,
|
23 |
-
StoreDictKeyPair,
|
24 |
-
get_weighted_single_eval_pos_sampler,
|
25 |
-
get_uniform_single_eval_pos_sampler,
|
26 |
)
|
27 |
-
from lcpfn import priors
|
28 |
-
from lcpfn import encoders
|
29 |
from lcpfn import positional_encodings
|
30 |
from lcpfn.utils import init_dist
|
31 |
from torch.cuda.amp import autocast, GradScaler
|
@@ -294,7 +282,6 @@ def train(
|
|
294 |
list_losses = []
|
295 |
try:
|
296 |
for epoch in range(1, epochs + 1) if epochs is not None else itertools.count(1):
|
297 |
-
|
298 |
epoch_start_time = time.time()
|
299 |
(
|
300 |
total_loss,
|
@@ -347,256 +334,3 @@ def train(
|
|
347 |
torch.save(model.to("cpu"), output_path)
|
348 |
print("Checkpoint stored at ", output_path)
|
349 |
return total_loss, total_positional_losses, model.to("cpu"), dl
|
350 |
-
|
351 |
-
|
352 |
-
def _parse_args(config_parser, parser):
|
353 |
-
# Do we have a config file to parse?
|
354 |
-
args_config, remaining = config_parser.parse_known_args()
|
355 |
-
if args_config.config:
|
356 |
-
with open(args_config.config, "r") as f:
|
357 |
-
cfg = yaml.safe_load(f)
|
358 |
-
parser.set_defaults(**cfg)
|
359 |
-
|
360 |
-
# The main arg parser parses the rest of the args, the usual
|
361 |
-
# defaults will have been overridden if config file specified.
|
362 |
-
args = parser.parse_args(remaining)
|
363 |
-
|
364 |
-
# Cache the args as a text string to save them in the output dir later
|
365 |
-
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
|
366 |
-
return args, args_text
|
367 |
-
|
368 |
-
|
369 |
-
if __name__ == "__main__":
|
370 |
-
config_parser = argparse.ArgumentParser(
|
371 |
-
description="Only used as a first parser for the config file path."
|
372 |
-
)
|
373 |
-
config_parser.add_argument("--config")
|
374 |
-
parser = argparse.ArgumentParser()
|
375 |
-
parser.add_argument("prior")
|
376 |
-
parser.add_argument("--loss_function", default="barnll")
|
377 |
-
# Optional Arg's for `--loss_function barnll`
|
378 |
-
parser.add_argument(
|
379 |
-
"--min_y",
|
380 |
-
type=float,
|
381 |
-
help="barnll can only model y in strict ranges, this is the minimum y can take.",
|
382 |
-
)
|
383 |
-
parser.add_argument(
|
384 |
-
"--max_y",
|
385 |
-
type=float,
|
386 |
-
help="barnll can only model y in strict ranges, this is the maximum y can take.",
|
387 |
-
)
|
388 |
-
parser.add_argument("--num_buckets", default=100, type=int)
|
389 |
-
# parser.add_argument('--num_features', default=None, type=int, help='Specify depending on the prior.')
|
390 |
-
parser.add_argument(
|
391 |
-
"--extra_prior_kwargs_dict",
|
392 |
-
default={},
|
393 |
-
dest="extra_prior_kwargs_dict",
|
394 |
-
action=StoreDictKeyPair,
|
395 |
-
nargs="+",
|
396 |
-
metavar="KEY=VAL",
|
397 |
-
help="Specify depending on the prior.",
|
398 |
-
)
|
399 |
-
parser.add_argument(
|
400 |
-
"--encoder", default="linear", type=str, help="Specify depending on the prior."
|
401 |
-
)
|
402 |
-
parser.add_argument(
|
403 |
-
"--y_encoder",
|
404 |
-
default="linear",
|
405 |
-
type=str,
|
406 |
-
help="Specify depending on the prior. You should specify this if you do not fuse x and y.",
|
407 |
-
)
|
408 |
-
parser.add_argument(
|
409 |
-
"--pos_encoder",
|
410 |
-
default="none",
|
411 |
-
type=str,
|
412 |
-
help="Specify depending on the prior.",
|
413 |
-
)
|
414 |
-
parser.add_argument("--bptt", default=10, type=int)
|
415 |
-
parser.add_argument("--epochs", default=200, type=int)
|
416 |
-
parser.add_argument("--warmup_epochs", default=50, type=int)
|
417 |
-
parser.add_argument("--validation_period", default=10, type=int)
|
418 |
-
parser.add_argument(
|
419 |
-
"--permutation_invariant_max_eval_pos",
|
420 |
-
default=None,
|
421 |
-
type=int,
|
422 |
-
help="Set this to an int to ",
|
423 |
-
)
|
424 |
-
parser.add_argument(
|
425 |
-
"--permutation_invariant_sampling",
|
426 |
-
default="weighted",
|
427 |
-
help="Only relevant if --permutation_invariant_max_eval_pos is set.",
|
428 |
-
)
|
429 |
-
parser.add_argument("--train_mixed_precision", action="store_true")
|
430 |
-
|
431 |
-
# these can likely be mostly left at defaults
|
432 |
-
parser.add_argument(
|
433 |
-
"--emsize", default=512, type=int
|
434 |
-
) # sometimes even larger is better e.g. 1024
|
435 |
-
parser.add_argument("--nlayers", default=6, type=int)
|
436 |
-
parser.add_argument("--nhid", default=None, type=int) # 2*emsize is the default
|
437 |
-
parser.add_argument(
|
438 |
-
"--nhead", default=4, type=int
|
439 |
-
) # nhead = emsize / 64 in the original paper
|
440 |
-
parser.add_argument("--dropout", default=0.0, type=float)
|
441 |
-
parser.add_argument("--steps_per_epoch", default=10, type=int)
|
442 |
-
parser.add_argument("--batch_size", default=1000, type=int)
|
443 |
-
parser.add_argument(
|
444 |
-
"--lr", "--learning_rate", default=0.001, type=float
|
445 |
-
) # try also .0003, .0001, go lower with lower batch size
|
446 |
-
parser.add_argument("--gpu_device", default="cuda", type=str)
|
447 |
-
|
448 |
-
# for model checkpointing
|
449 |
-
parser.add_argument(
|
450 |
-
"--checkpoint_file",
|
451 |
-
help="absolute or relative-to-the-project-rootdir path to the file storing the state dicts.",
|
452 |
-
default=None,
|
453 |
-
type=str,
|
454 |
-
)
|
455 |
-
parser.add_argument("--saving_period", default=10, type=str)
|
456 |
-
|
457 |
-
args, _ = _parse_args(config_parser, parser)
|
458 |
-
|
459 |
-
if args.nhid is None:
|
460 |
-
args.nhid = 2 * args.emsize
|
461 |
-
|
462 |
-
prior = args.__dict__.pop("prior")
|
463 |
-
|
464 |
-
if prior == "gp":
|
465 |
-
prior = priors.fast_gp.DataLoader
|
466 |
-
elif prior == "ridge":
|
467 |
-
prior = priors.ridge.DataLoader
|
468 |
-
elif prior == "stroke":
|
469 |
-
prior = priors.stroke.DataLoader
|
470 |
-
elif prior == "mix_gp":
|
471 |
-
prior = priors.fast_gp_mix.DataLoader
|
472 |
-
else:
|
473 |
-
raise NotImplementedError(f"Prior == {prior}.")
|
474 |
-
|
475 |
-
loss_function = args.__dict__.pop("loss_function")
|
476 |
-
|
477 |
-
criterion = nn.GaussianNLLLoss(reduction="none", full=True)
|
478 |
-
classificiation_criterion = nn.CrossEntropyLoss(reduction="none")
|
479 |
-
num_buckets = args.__dict__.pop("num_buckets")
|
480 |
-
max_y = args.__dict__.pop("max_y")
|
481 |
-
min_y = args.__dict__.pop("min_y")
|
482 |
-
# criterion = nn.MSELoss(reduction='none')
|
483 |
-
|
484 |
-
device = args.gpu_device if torch.cuda.is_available() else "cpu:0"
|
485 |
-
|
486 |
-
def get_y_sample():
|
487 |
-
args.__dict__["extra_prior_kwargs_dict"]["eval_pos_seq_len_sampler"] = lambda: (
|
488 |
-
args.bptt,
|
489 |
-
args.bptt,
|
490 |
-
)
|
491 |
-
dl = prior(
|
492 |
-
num_steps=1,
|
493 |
-
batch_size=args.batch_size * args.steps_per_epoch,
|
494 |
-
seq_len=args.bptt,
|
495 |
-
device=device,
|
496 |
-
**args.extra_prior_kwargs_dict,
|
497 |
-
)
|
498 |
-
args.__dict__["extra_prior_kwargs_dict"].pop("eval_pos_seq_len_sampler")
|
499 |
-
|
500 |
-
y_sample = next(iter(dl))[-2]
|
501 |
-
print(
|
502 |
-
f"Creating Bar distribution with borders from y sample of size {y_sample.numel()}"
|
503 |
-
)
|
504 |
-
return y_sample
|
505 |
-
|
506 |
-
if loss_function == "ce":
|
507 |
-
criterion = nn.CrossEntropyLoss(reduction="none")
|
508 |
-
elif loss_function == "gaussnll":
|
509 |
-
criterion = nn.GaussianNLLLoss(reduction="none", full=True)
|
510 |
-
elif loss_function == "mse":
|
511 |
-
criterion = nn.MSELoss(reduction="none")
|
512 |
-
elif loss_function == "barnll":
|
513 |
-
criterion = BarDistribution(
|
514 |
-
borders=get_bucket_limits(num_buckets, full_range=(min_y, max_y))
|
515 |
-
)
|
516 |
-
elif loss_function == "adaptivebarnll":
|
517 |
-
borders = get_bucket_limits(
|
518 |
-
num_buckets, ys=get_y_sample(), full_range=(min_y, max_y)
|
519 |
-
)
|
520 |
-
criterion = BarDistribution(borders=borders)
|
521 |
-
elif loss_function == "adaptivefullsupportbarnll":
|
522 |
-
assert (
|
523 |
-
min_y is None and max_y is None
|
524 |
-
), "Please do not specify `min_y` and `max_y` with `unboundedadaptivebarnll`."
|
525 |
-
borders = get_bucket_limits(num_buckets, ys=get_y_sample())
|
526 |
-
criterion = FullSupportBarDistribution(borders=borders)
|
527 |
-
else:
|
528 |
-
raise NotImplementedError(f"loss_function == {loss_function}.")
|
529 |
-
|
530 |
-
encoder = args.__dict__.pop("encoder")
|
531 |
-
y_encoder = args.__dict__.pop("y_encoder")
|
532 |
-
|
533 |
-
def get_encoder_generator(encoder):
|
534 |
-
if encoder == "linear":
|
535 |
-
encoder_generator = encoders.Linear
|
536 |
-
elif encoder == "mlp":
|
537 |
-
encoder_generator = encoders.MLP
|
538 |
-
elif encoder == "positional":
|
539 |
-
encoder_generator = encoders.Positional
|
540 |
-
else:
|
541 |
-
raise NotImplementedError(f"A {encoder} encoder is not valid.")
|
542 |
-
return encoder_generator
|
543 |
-
|
544 |
-
encoder_generator = get_encoder_generator(encoder)
|
545 |
-
y_encoder_generator = get_encoder_generator(y_encoder)
|
546 |
-
|
547 |
-
pos_encoder = args.__dict__.pop("pos_encoder")
|
548 |
-
|
549 |
-
if pos_encoder == "none":
|
550 |
-
pos_encoder_generator = None
|
551 |
-
elif pos_encoder == "sinus":
|
552 |
-
pos_encoder_generator = positional_encodings.PositionalEncoding
|
553 |
-
elif pos_encoder == "learned":
|
554 |
-
pos_encoder_generator = positional_encodings.LearnedPositionalEncoding
|
555 |
-
elif pos_encoder == "paired_scrambled_learned":
|
556 |
-
pos_encoder_generator = positional_encodings.PairedScrambledPositionalEncodings
|
557 |
-
else:
|
558 |
-
raise NotImplementedError(f"pos_encoer == {pos_encoder} is not valid.")
|
559 |
-
|
560 |
-
permutation_invariant_max_eval_pos = args.__dict__.pop(
|
561 |
-
"permutation_invariant_max_eval_pos"
|
562 |
-
)
|
563 |
-
permutation_invariant_sampling = args.__dict__.pop("permutation_invariant_sampling")
|
564 |
-
if permutation_invariant_max_eval_pos is not None:
|
565 |
-
if permutation_invariant_sampling == "weighted":
|
566 |
-
get_sampler = get_weighted_single_eval_pos_sampler
|
567 |
-
elif permutation_invariant_sampling == "uniform":
|
568 |
-
get_sampler = get_uniform_single_eval_pos_sampler
|
569 |
-
else:
|
570 |
-
raise ValueError()
|
571 |
-
args.__dict__["single_eval_pos_gen"] = get_sampler(
|
572 |
-
permutation_invariant_max_eval_pos
|
573 |
-
)
|
574 |
-
|
575 |
-
print("ARGS for `train`:", args.__dict__)
|
576 |
-
|
577 |
-
if args.__dict__["checkpoint_file"] is not None:
|
578 |
-
rootdir = os.path.dirname(os.path.realpath(__file__))
|
579 |
-
args.__dict__["checkpoint_file"] = os.path.join(
|
580 |
-
rootdir, args.__dict__["checkpoint_file"]
|
581 |
-
)
|
582 |
-
|
583 |
-
if os.path.exists(args.__dict__["checkpoint_file"]):
|
584 |
-
state_dicts = torch.load(args.__dict__["checkpoint_file"])
|
585 |
-
args.__dict__["load_weights_from_this_state_dict"] = state_dicts[
|
586 |
-
"model_state_dict"
|
587 |
-
]
|
588 |
-
args.__dict__["load_optimizer_from_this_state_dict"] = state_dicts[
|
589 |
-
"optimizer_state_dict"
|
590 |
-
]
|
591 |
-
else:
|
592 |
-
args.__dict__["load_weights_from_this_state_dict"] = None
|
593 |
-
args.__dict__["load_optimizer_from_this_state_dict"] = None
|
594 |
-
|
595 |
-
train(
|
596 |
-
prior,
|
597 |
-
criterion,
|
598 |
-
encoder_generator,
|
599 |
-
y_encoder_generator=y_encoder_generator,
|
600 |
-
pos_encoder_generator=pos_encoder_generator,
|
601 |
-
**args.__dict__,
|
602 |
-
)
|
|
|
|
|
1 |
import itertools
|
|
|
2 |
import time
|
|
|
|
|
3 |
from contextlib import nullcontext
|
4 |
|
|
|
5 |
import torch
|
6 |
from torch import nn
|
7 |
|
|
|
9 |
from lcpfn.transformer import TransformerModel
|
10 |
from lcpfn.bar_distribution import (
|
11 |
BarDistribution,
|
|
|
|
|
12 |
)
|
13 |
from lcpfn.utils import (
|
14 |
get_cosine_schedule_with_warmup,
|
15 |
get_openai_lr,
|
|
|
|
|
|
|
16 |
)
|
|
|
|
|
17 |
from lcpfn import positional_encodings
|
18 |
from lcpfn.utils import init_dist
|
19 |
from torch.cuda.amp import autocast, GradScaler
|
|
|
282 |
list_losses = []
|
283 |
try:
|
284 |
for epoch in range(1, epochs + 1) if epochs is not None else itertools.count(1):
|
|
|
285 |
epoch_start_time = time.time()
|
286 |
(
|
287 |
total_loss,
|
|
|
334 |
torch.save(model.to("cpu"), output_path)
|
335 |
print("Checkpoint stored at ", output_path)
|
336 |
return total_loss, total_positional_losses, model.to("cpu"), dl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lcpfn/train_lcpfn.py
CHANGED
@@ -2,9 +2,11 @@ import math
|
|
2 |
|
3 |
from torch import nn
|
4 |
|
5 |
-
from lcpfn import bar_distribution, encoders,
|
6 |
from lcpfn import utils
|
7 |
|
|
|
|
|
8 |
|
9 |
def train_lcpfn(
|
10 |
get_batch_func,
|
@@ -12,7 +14,7 @@ def train_lcpfn(
|
|
12 |
emsize: int = 512,
|
13 |
nlayers: int = 12,
|
14 |
num_borders: int = 1000,
|
15 |
-
lr: float = 0.
|
16 |
batch_size: int = 100,
|
17 |
epochs: int = 1000,
|
18 |
):
|
@@ -25,7 +27,7 @@ def train_lcpfn(
|
|
25 |
emsize (int, optional): The size of the embedding layer. Defaults to 512.
|
26 |
nlayers (int, optional): The number of layers in the model. Defaults to 12.
|
27 |
num_borders_choices (int, optional): The number of borders to use. Defaults to 1000.
|
28 |
-
lr (float, optional): The learning rate for the optimizer. Defaults to 0.
|
29 |
batch_size (int, optional): The batch size for training. Defaults to 100.
|
30 |
epochs (int, optional): The number of epochs to train for. Defaults to 1000.
|
31 |
|
@@ -36,7 +38,7 @@ def train_lcpfn(
|
|
36 |
hps = {}
|
37 |
|
38 |
# PFN training hyperparameters
|
39 |
-
dataloader =
|
40 |
|
41 |
num_features = 1
|
42 |
|
@@ -82,7 +84,9 @@ def train_lcpfn(
|
|
82 |
epochs=epochs,
|
83 |
lr=lr,
|
84 |
bptt=seq_len,
|
85 |
-
single_eval_pos_gen=utils.get_uniform_single_eval_pos_sampler(
|
|
|
|
|
86 |
aggregate_k_gradients=1,
|
87 |
nhid=(emsize * 2),
|
88 |
steps_per_epoch=100,
|
|
|
2 |
|
3 |
from torch import nn
|
4 |
|
5 |
+
from lcpfn import bar_distribution, encoders, train
|
6 |
from lcpfn import utils
|
7 |
|
8 |
+
from lcpfn.priors import utils as putils
|
9 |
+
|
10 |
|
11 |
def train_lcpfn(
|
12 |
get_batch_func,
|
|
|
14 |
emsize: int = 512,
|
15 |
nlayers: int = 12,
|
16 |
num_borders: int = 1000,
|
17 |
+
lr: float = 0.0001,
|
18 |
batch_size: int = 100,
|
19 |
epochs: int = 1000,
|
20 |
):
|
|
|
27 |
emsize (int, optional): The size of the embedding layer. Defaults to 512.
|
28 |
nlayers (int, optional): The number of layers in the model. Defaults to 12.
|
29 |
num_borders_choices (int, optional): The number of borders to use. Defaults to 1000.
|
30 |
+
lr (float, optional): The learning rate for the optimizer. Defaults to 0.0001.
|
31 |
batch_size (int, optional): The batch size for training. Defaults to 100.
|
32 |
epochs (int, optional): The number of epochs to train for. Defaults to 1000.
|
33 |
|
|
|
38 |
hps = {}
|
39 |
|
40 |
# PFN training hyperparameters
|
41 |
+
dataloader = putils.get_batch_to_dataloader(get_batch_func) # type: ignore
|
42 |
|
43 |
num_features = 1
|
44 |
|
|
|
84 |
epochs=epochs,
|
85 |
lr=lr,
|
86 |
bptt=seq_len,
|
87 |
+
single_eval_pos_gen=utils.get_uniform_single_eval_pos_sampler(
|
88 |
+
seq_len, min_len=1
|
89 |
+
),
|
90 |
aggregate_k_gradients=1,
|
91 |
nhid=(emsize * 2),
|
92 |
steps_per_epoch=100,
|
lcpfn/transformer.py
CHANGED
@@ -4,35 +4,74 @@ from typing import Optional
|
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
from torch import Tensor
|
|
|
7 |
from torch.nn import Module, TransformerEncoder
|
8 |
|
9 |
from lcpfn.layer import TransformerEncoderLayer, _get_activation_fn
|
10 |
from lcpfn.utils import SeqBN, bool_mask_to_att_mask
|
11 |
|
12 |
|
|
|
|
|
|
|
|
|
13 |
|
14 |
class TransformerModel(nn.Module):
|
15 |
-
def __init__(
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
super().__init__()
|
20 |
-
self.model_type =
|
21 |
-
encoder_layer_creator = lambda: TransformerEncoderLayer(
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
self.ninp = ninp
|
26 |
self.encoder = encoder
|
27 |
self.y_encoder = y_encoder
|
28 |
self.pos_encoder = pos_encoder
|
29 |
-
self.decoder =
|
|
|
|
|
|
|
|
|
30 |
self.input_ln = SeqBN(ninp) if input_normalization else None
|
31 |
self.style_encoder = style_encoder
|
32 |
self.init_method = init_method
|
33 |
if num_global_att_tokens is not None:
|
34 |
assert not full_attention
|
35 |
-
self.global_att_embeddings =
|
|
|
|
|
36 |
self.full_attention = full_attention
|
37 |
|
38 |
self.n_out = n_out
|
@@ -47,37 +86,49 @@ class TransformerModel(nn.Module):
|
|
47 |
|
48 |
@staticmethod
|
49 |
def generate_D_q_matrix(sz, query_size):
|
50 |
-
train_size = sz-query_size
|
51 |
-
mask = torch.zeros(sz,sz) == 0
|
52 |
-
mask[:,train_size:].zero_()
|
53 |
mask |= torch.eye(sz) == 1
|
54 |
return bool_mask_to_att_mask(mask)
|
55 |
|
56 |
@staticmethod
|
57 |
-
def generate_global_att_query_matrix(
|
|
|
|
|
58 |
train_size = seq_len + num_global_att_tokens - num_query_tokens
|
59 |
sz = seq_len + num_global_att_tokens
|
60 |
mask = torch.zeros(num_query_tokens, sz) == 0
|
61 |
-
mask[:,train_size:].zero_()
|
62 |
-
mask[:,train_size:] |= torch.eye(num_query_tokens) == 1
|
63 |
return bool_mask_to_att_mask(mask)
|
64 |
|
65 |
@staticmethod
|
66 |
-
def generate_global_att_trainset_matrix(
|
|
|
|
|
67 |
train_size = seq_len + num_global_att_tokens - num_query_tokens
|
68 |
trainset_size = seq_len - num_query_tokens
|
69 |
mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
|
70 |
-
#mask[:,num_global_att_tokens:].zero_()
|
71 |
-
#mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1
|
72 |
return bool_mask_to_att_mask(mask)
|
73 |
|
74 |
@staticmethod
|
75 |
-
def generate_global_att_globaltokens_matrix(
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
return bool_mask_to_att_mask(mask)
|
78 |
|
79 |
def init_weights(self):
|
80 |
-
initrange = 1.
|
81 |
# if isinstance(self.encoder,EmbeddingEncoder):
|
82 |
# self.encoder.weight.data.uniform_(-initrange, initrange)
|
83 |
# self.decoder.bias.data.zero_()
|
@@ -87,41 +138,74 @@ class TransformerModel(nn.Module):
|
|
87 |
for layer in self.transformer_encoder.layers:
|
88 |
nn.init.zeros_(layer.linear2.weight)
|
89 |
nn.init.zeros_(layer.linear2.bias)
|
90 |
-
attns =
|
|
|
|
|
|
|
|
|
91 |
for attn in attns:
|
92 |
nn.init.zeros_(attn.out_proj.weight)
|
93 |
nn.init.zeros_(attn.out_proj.bias)
|
94 |
|
95 |
def forward(self, src, src_mask=None, single_eval_pos=None):
|
96 |
-
assert isinstance(
|
|
|
|
|
97 |
|
98 |
-
if len(src) == 2:
|
99 |
src = (None,) + src
|
100 |
|
101 |
style_src, style_src_size = (src[0], (0 if (src[0] is None) else 1))
|
102 |
-
if src_mask is not None:
|
|
|
103 |
if src_mask is None:
|
104 |
x_src = src[1]
|
105 |
if self.global_att_embeddings is None:
|
106 |
full_len = len(x_src) + style_src_size
|
107 |
if self.full_attention:
|
108 |
-
src_mask = bool_mask_to_att_mask(
|
|
|
|
|
109 |
else:
|
110 |
-
src_mask = self.generate_D_q_matrix(
|
|
|
|
|
|
|
111 |
else:
|
112 |
-
src_mask_args = (
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
style_src, x_src, y_src = src
|
120 |
x_src = self.encoder(x_src)
|
121 |
-
y_src = self.y_encoder(
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
|
126 |
src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
|
127 |
|
@@ -134,16 +218,29 @@ class TransformerModel(nn.Module):
|
|
134 |
# If we have style input, drop its output
|
135 |
output = self.transformer_encoder(src, src_mask)[style_src_size:]
|
136 |
output = self.decoder(output)
|
137 |
-
return output[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
|
139 |
@torch.no_grad()
|
140 |
def init_from_small_model(self, small_model):
|
141 |
-
assert
|
142 |
-
|
|
|
|
|
|
|
143 |
|
144 |
def set_encoder_weights(my_encoder, small_model_encoder):
|
145 |
-
my_encoder_linear, small_encoder_linear = (
|
146 |
-
|
|
|
|
|
|
|
147 |
small_in_dim = small_encoder_linear.out_features
|
148 |
my_encoder_linear.weight.zero_()
|
149 |
my_encoder_linear.bias.zero_()
|
@@ -158,7 +255,9 @@ class TransformerModel(nn.Module):
|
|
158 |
self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
|
159 |
self.decoder.bias = small_model.decoder.bias
|
160 |
|
161 |
-
for my_layer, small_layer in zip(
|
|
|
|
|
162 |
small_hid_dim = small_layer.linear1.out_features
|
163 |
my_in_dim = my_layer.linear1.in_features
|
164 |
|
@@ -166,23 +265,36 @@ class TransformerModel(nn.Module):
|
|
166 |
my_in_proj_w = my_layer.self_attn.in_proj_weight
|
167 |
small_in_proj_w = small_layer.self_attn.in_proj_weight
|
168 |
|
169 |
-
my_in_proj_w.view(3, my_in_dim, my_in_dim)[
|
170 |
-
|
171 |
-
|
172 |
-
my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:,
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
my_layer.self_attn.out_proj.
|
177 |
-
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias
|
180 |
|
181 |
-
my_layer.linear2.weight[:small_in_dim, :small_hid_dim] =
|
|
|
|
|
182 |
my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias
|
183 |
|
184 |
-
my_layer.norm1.weight[:small_in_dim] =
|
185 |
-
|
|
|
|
|
|
|
|
|
186 |
|
187 |
my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
|
188 |
my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias
|
@@ -196,15 +308,23 @@ class TransformerEncoderDiffInit(Module):
|
|
196 |
num_layers: the number of sub-encoder-layers in the encoder (required).
|
197 |
norm: the layer normalization component (optional).
|
198 |
"""
|
199 |
-
|
|
|
200 |
|
201 |
def __init__(self, encoder_layer_creator, num_layers, norm=None):
|
202 |
super().__init__()
|
203 |
-
self.layers = nn.ModuleList(
|
|
|
|
|
204 |
self.num_layers = num_layers
|
205 |
self.norm = norm
|
206 |
|
207 |
-
def forward(
|
|
|
|
|
|
|
|
|
|
|
208 |
r"""Pass the input through the encoder layers in turn.
|
209 |
|
210 |
Args:
|
@@ -218,7 +338,9 @@ class TransformerEncoderDiffInit(Module):
|
|
218 |
output = src
|
219 |
|
220 |
for mod in self.layers:
|
221 |
-
output = mod(
|
|
|
|
|
222 |
|
223 |
if self.norm is not None:
|
224 |
output = self.norm(output)
|
|
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
from torch import Tensor
|
7 |
+
import torch.nn.functional as F
|
8 |
from torch.nn import Module, TransformerEncoder
|
9 |
|
10 |
from lcpfn.layer import TransformerEncoderLayer, _get_activation_fn
|
11 |
from lcpfn.utils import SeqBN, bool_mask_to_att_mask
|
12 |
|
13 |
|
14 |
+
class GELU(nn.Module):
|
15 |
+
def forward(self, input: Tensor) -> Tensor:
|
16 |
+
return F.gelu(input)
|
17 |
+
|
18 |
|
19 |
class TransformerModel(nn.Module):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
encoder,
|
23 |
+
n_out,
|
24 |
+
ninp,
|
25 |
+
nhead,
|
26 |
+
nhid,
|
27 |
+
nlayers,
|
28 |
+
dropout=0.0,
|
29 |
+
style_encoder=None,
|
30 |
+
y_encoder=None,
|
31 |
+
pos_encoder=None,
|
32 |
+
decoder=None,
|
33 |
+
input_normalization=False,
|
34 |
+
init_method=None,
|
35 |
+
pre_norm=False,
|
36 |
+
activation="gelu",
|
37 |
+
recompute_attn=False,
|
38 |
+
num_global_att_tokens=0,
|
39 |
+
full_attention=False,
|
40 |
+
all_layers_same_init=True,
|
41 |
+
):
|
42 |
super().__init__()
|
43 |
+
self.model_type = "Transformer"
|
44 |
+
encoder_layer_creator = lambda: TransformerEncoderLayer(
|
45 |
+
ninp,
|
46 |
+
nhead,
|
47 |
+
nhid,
|
48 |
+
dropout,
|
49 |
+
activation=activation,
|
50 |
+
pre_norm=pre_norm,
|
51 |
+
recompute_attn=recompute_attn,
|
52 |
+
)
|
53 |
+
self.transformer_encoder = (
|
54 |
+
TransformerEncoder(encoder_layer_creator(), nlayers)
|
55 |
+
if all_layers_same_init
|
56 |
+
else TransformerEncoderDiffInit(encoder_layer_creator, nlayers)
|
57 |
+
)
|
58 |
self.ninp = ninp
|
59 |
self.encoder = encoder
|
60 |
self.y_encoder = y_encoder
|
61 |
self.pos_encoder = pos_encoder
|
62 |
+
self.decoder = (
|
63 |
+
decoder(ninp, nhid, n_out)
|
64 |
+
if decoder is not None
|
65 |
+
else nn.Sequential(nn.Linear(ninp, nhid), GELU(), nn.Linear(nhid, n_out))
|
66 |
+
)
|
67 |
self.input_ln = SeqBN(ninp) if input_normalization else None
|
68 |
self.style_encoder = style_encoder
|
69 |
self.init_method = init_method
|
70 |
if num_global_att_tokens is not None:
|
71 |
assert not full_attention
|
72 |
+
self.global_att_embeddings = (
|
73 |
+
nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
|
74 |
+
)
|
75 |
self.full_attention = full_attention
|
76 |
|
77 |
self.n_out = n_out
|
|
|
86 |
|
87 |
@staticmethod
|
88 |
def generate_D_q_matrix(sz, query_size):
|
89 |
+
train_size = sz - query_size
|
90 |
+
mask = torch.zeros(sz, sz) == 0
|
91 |
+
mask[:, train_size:].zero_()
|
92 |
mask |= torch.eye(sz) == 1
|
93 |
return bool_mask_to_att_mask(mask)
|
94 |
|
95 |
@staticmethod
|
96 |
+
def generate_global_att_query_matrix(
|
97 |
+
num_global_att_tokens, seq_len, num_query_tokens
|
98 |
+
):
|
99 |
train_size = seq_len + num_global_att_tokens - num_query_tokens
|
100 |
sz = seq_len + num_global_att_tokens
|
101 |
mask = torch.zeros(num_query_tokens, sz) == 0
|
102 |
+
mask[:, train_size:].zero_()
|
103 |
+
mask[:, train_size:] |= torch.eye(num_query_tokens) == 1
|
104 |
return bool_mask_to_att_mask(mask)
|
105 |
|
106 |
@staticmethod
|
107 |
+
def generate_global_att_trainset_matrix(
|
108 |
+
num_global_att_tokens, seq_len, num_query_tokens
|
109 |
+
):
|
110 |
train_size = seq_len + num_global_att_tokens - num_query_tokens
|
111 |
trainset_size = seq_len - num_query_tokens
|
112 |
mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
|
113 |
+
# mask[:,num_global_att_tokens:].zero_()
|
114 |
+
# mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1
|
115 |
return bool_mask_to_att_mask(mask)
|
116 |
|
117 |
@staticmethod
|
118 |
+
def generate_global_att_globaltokens_matrix(
|
119 |
+
num_global_att_tokens, seq_len, num_query_tokens
|
120 |
+
):
|
121 |
+
mask = (
|
122 |
+
torch.zeros(
|
123 |
+
num_global_att_tokens,
|
124 |
+
num_global_att_tokens + seq_len - num_query_tokens,
|
125 |
+
)
|
126 |
+
== 0
|
127 |
+
)
|
128 |
return bool_mask_to_att_mask(mask)
|
129 |
|
130 |
def init_weights(self):
|
131 |
+
initrange = 1.0
|
132 |
# if isinstance(self.encoder,EmbeddingEncoder):
|
133 |
# self.encoder.weight.data.uniform_(-initrange, initrange)
|
134 |
# self.decoder.bias.data.zero_()
|
|
|
138 |
for layer in self.transformer_encoder.layers:
|
139 |
nn.init.zeros_(layer.linear2.weight)
|
140 |
nn.init.zeros_(layer.linear2.bias)
|
141 |
+
attns = (
|
142 |
+
layer.self_attn
|
143 |
+
if isinstance(layer.self_attn, nn.ModuleList)
|
144 |
+
else [layer.self_attn]
|
145 |
+
)
|
146 |
for attn in attns:
|
147 |
nn.init.zeros_(attn.out_proj.weight)
|
148 |
nn.init.zeros_(attn.out_proj.bias)
|
149 |
|
150 |
def forward(self, src, src_mask=None, single_eval_pos=None):
|
151 |
+
assert isinstance(
|
152 |
+
src, tuple
|
153 |
+
), "inputs (src) have to be given as (x,y) or (style,x,y) tuple"
|
154 |
|
155 |
+
if len(src) == 2: # (x,y) and no style
|
156 |
src = (None,) + src
|
157 |
|
158 |
style_src, style_src_size = (src[0], (0 if (src[0] is None) else 1))
|
159 |
+
if src_mask is not None:
|
160 |
+
assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
|
161 |
if src_mask is None:
|
162 |
x_src = src[1]
|
163 |
if self.global_att_embeddings is None:
|
164 |
full_len = len(x_src) + style_src_size
|
165 |
if self.full_attention:
|
166 |
+
src_mask = bool_mask_to_att_mask(
|
167 |
+
torch.ones((full_len, full_len), dtype=torch.bool)
|
168 |
+
).to(x_src.device)
|
169 |
else:
|
170 |
+
src_mask = self.generate_D_q_matrix(
|
171 |
+
len(x_src) + style_src_size,
|
172 |
+
len(x_src) + style_src_size - single_eval_pos,
|
173 |
+
).to(x_src.device)
|
174 |
else:
|
175 |
+
src_mask_args = (
|
176 |
+
self.global_att_embeddings.num_embeddings,
|
177 |
+
len(x_src) + style_src_size,
|
178 |
+
len(x_src) + style_src_size - single_eval_pos,
|
179 |
+
)
|
180 |
+
src_mask = (
|
181 |
+
self.generate_global_att_globaltokens_matrix(*src_mask_args).to(
|
182 |
+
x_src.device
|
183 |
+
),
|
184 |
+
self.generate_global_att_trainset_matrix(*src_mask_args).to(
|
185 |
+
x_src.device
|
186 |
+
),
|
187 |
+
self.generate_global_att_query_matrix(*src_mask_args).to(
|
188 |
+
x_src.device
|
189 |
+
),
|
190 |
+
)
|
191 |
|
192 |
style_src, x_src, y_src = src
|
193 |
x_src = self.encoder(x_src)
|
194 |
+
y_src = self.y_encoder(
|
195 |
+
y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src
|
196 |
+
)
|
197 |
+
style_src = (
|
198 |
+
self.style_encoder(style_src).unsqueeze(0)
|
199 |
+
if self.style_encoder
|
200 |
+
else torch.tensor([], device=x_src.device)
|
201 |
+
)
|
202 |
+
global_src = (
|
203 |
+
torch.tensor([], device=x_src.device)
|
204 |
+
if self.global_att_embeddings is None
|
205 |
+
else self.global_att_embeddings.weight.unsqueeze(1).repeat(
|
206 |
+
1, x_src.shape[1], 1
|
207 |
+
)
|
208 |
+
)
|
209 |
train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
|
210 |
src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
|
211 |
|
|
|
218 |
# If we have style input, drop its output
|
219 |
output = self.transformer_encoder(src, src_mask)[style_src_size:]
|
220 |
output = self.decoder(output)
|
221 |
+
return output[
|
222 |
+
single_eval_pos
|
223 |
+
+ (
|
224 |
+
self.global_att_embeddings.num_embeddings
|
225 |
+
if self.global_att_embeddings
|
226 |
+
else 0
|
227 |
+
) :
|
228 |
+
]
|
229 |
|
230 |
@torch.no_grad()
|
231 |
def init_from_small_model(self, small_model):
|
232 |
+
assert (
|
233 |
+
isinstance(self.decoder, nn.Linear)
|
234 |
+
and isinstance(self.encoder, (nn.Linear, nn.Sequential))
|
235 |
+
and isinstance(self.y_encoder, (nn.Linear, nn.Sequential))
|
236 |
+
)
|
237 |
|
238 |
def set_encoder_weights(my_encoder, small_model_encoder):
|
239 |
+
my_encoder_linear, small_encoder_linear = (
|
240 |
+
(my_encoder, small_model_encoder)
|
241 |
+
if isinstance(my_encoder, nn.Linear)
|
242 |
+
else (my_encoder[-1], small_model_encoder[-1])
|
243 |
+
)
|
244 |
small_in_dim = small_encoder_linear.out_features
|
245 |
my_encoder_linear.weight.zero_()
|
246 |
my_encoder_linear.bias.zero_()
|
|
|
255 |
self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
|
256 |
self.decoder.bias = small_model.decoder.bias
|
257 |
|
258 |
+
for my_layer, small_layer in zip(
|
259 |
+
self.transformer_encoder.layers, small_model.transformer_encoder.layers
|
260 |
+
):
|
261 |
small_hid_dim = small_layer.linear1.out_features
|
262 |
my_in_dim = my_layer.linear1.in_features
|
263 |
|
|
|
265 |
my_in_proj_w = my_layer.self_attn.in_proj_weight
|
266 |
small_in_proj_w = small_layer.self_attn.in_proj_weight
|
267 |
|
268 |
+
my_in_proj_w.view(3, my_in_dim, my_in_dim)[
|
269 |
+
:, :small_in_dim, :small_in_dim
|
270 |
+
] = small_in_proj_w.view(3, small_in_dim, small_in_dim)
|
271 |
+
my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:, :small_in_dim] = (
|
272 |
+
small_layer.self_attn.in_proj_bias.view(3, small_in_dim)
|
273 |
+
)
|
274 |
+
|
275 |
+
my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = (
|
276 |
+
small_layer.self_attn.out_proj.weight
|
277 |
+
)
|
278 |
+
my_layer.self_attn.out_proj.bias[:small_in_dim] = (
|
279 |
+
small_layer.self_attn.out_proj.bias
|
280 |
+
)
|
281 |
+
|
282 |
+
my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = (
|
283 |
+
small_layer.linear1.weight
|
284 |
+
)
|
285 |
my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias
|
286 |
|
287 |
+
my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = (
|
288 |
+
small_layer.linear2.weight
|
289 |
+
)
|
290 |
my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias
|
291 |
|
292 |
+
my_layer.norm1.weight[:small_in_dim] = (
|
293 |
+
math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight
|
294 |
+
)
|
295 |
+
my_layer.norm2.weight[:small_in_dim] = (
|
296 |
+
math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight
|
297 |
+
)
|
298 |
|
299 |
my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
|
300 |
my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias
|
|
|
308 |
num_layers: the number of sub-encoder-layers in the encoder (required).
|
309 |
norm: the layer normalization component (optional).
|
310 |
"""
|
311 |
+
|
312 |
+
__constants__ = ["norm"]
|
313 |
|
314 |
def __init__(self, encoder_layer_creator, num_layers, norm=None):
|
315 |
super().__init__()
|
316 |
+
self.layers = nn.ModuleList(
|
317 |
+
[encoder_layer_creator() for _ in range(num_layers)]
|
318 |
+
)
|
319 |
self.num_layers = num_layers
|
320 |
self.norm = norm
|
321 |
|
322 |
+
def forward(
|
323 |
+
self,
|
324 |
+
src: Tensor,
|
325 |
+
mask: Optional[Tensor] = None,
|
326 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
327 |
+
) -> Tensor:
|
328 |
r"""Pass the input through the encoder layers in turn.
|
329 |
|
330 |
Args:
|
|
|
338 |
output = src
|
339 |
|
340 |
for mod in self.layers:
|
341 |
+
output = mod(
|
342 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
|
343 |
+
)
|
344 |
|
345 |
if self.norm is not None:
|
346 |
output = self.norm(output)
|
lcpfn/utils.py
CHANGED
@@ -9,9 +9,12 @@ from torch import nn
|
|
9 |
from torch.optim.lr_scheduler import LambdaLR
|
10 |
import numpy as np
|
11 |
|
|
|
12 |
# copied from huggingface
|
13 |
-
def get_cosine_schedule_with_warmup(
|
14 |
-
|
|
|
|
|
15 |
values of the cosine function between 0 and `pi * cycles` after a warmup
|
16 |
period during which it increases linearly between 0 and 1.
|
17 |
"""
|
@@ -19,13 +22,20 @@ def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_st
|
|
19 |
def lr_lambda(current_step):
|
20 |
if current_step < num_warmup_steps:
|
21 |
return float(current_step) / float(max(1, num_warmup_steps))
|
22 |
-
progress = float(current_step - num_warmup_steps) / float(
|
23 |
-
|
|
|
|
|
|
|
|
|
24 |
|
25 |
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
26 |
|
|
|
27 |
# copied from huggingface
|
28 |
-
def get_linear_schedule_with_warmup(
|
|
|
|
|
29 |
"""
|
30 |
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
|
31 |
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
|
@@ -48,7 +58,9 @@ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_st
|
|
48 |
if current_step < num_warmup_steps:
|
49 |
return float(current_step) / float(max(1, num_warmup_steps))
|
50 |
return max(
|
51 |
-
0.0,
|
|
|
|
|
52 |
)
|
53 |
|
54 |
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
@@ -65,7 +77,9 @@ def get_weighted_single_eval_pos_sampler(max_len):
|
|
65 |
where p <= `max_len`. At most `max_len` - 1 examples are shown to the Transformer.
|
66 |
:return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
|
67 |
"""
|
68 |
-
return lambda: random.choices(
|
|
|
|
|
69 |
|
70 |
|
71 |
def get_uniform_single_eval_pos_sampler(max_len, min_len=0):
|
@@ -95,19 +109,22 @@ def set_locals_in_self(locals):
|
|
95 |
Especially useful right at the beginning of `__init__`.
|
96 |
:param locals: `locals()`
|
97 |
"""
|
98 |
-
self = locals[
|
99 |
for var_name, val in locals.items():
|
100 |
-
if var_name !=
|
|
|
101 |
|
102 |
|
103 |
-
default_device =
|
104 |
|
105 |
|
106 |
# Copied from StackOverflow, but we do an eval on the values additionally
|
107 |
class StoreDictKeyPair(argparse.Action):
|
108 |
def __init__(self, option_strings, dest, nargs=None, **kwargs):
|
109 |
self._nargs = nargs
|
110 |
-
super(StoreDictKeyPair, self).__init__(
|
|
|
|
|
111 |
|
112 |
def __call__(self, parser, namespace, values, option_string=None):
|
113 |
my_dict = {}
|
@@ -120,16 +137,20 @@ class StoreDictKeyPair(argparse.Action):
|
|
120 |
setattr(namespace, self.dest, my_dict)
|
121 |
print("dict values: {}".format(my_dict))
|
122 |
|
|
|
123 |
def get_nan_value(v, set_value_to_nan=0.0):
|
124 |
if random.random() < set_value_to_nan:
|
125 |
return v
|
126 |
else:
|
127 |
return random.choice([-999, 0, 1, 999])
|
128 |
|
|
|
129 |
def to_ranking(data):
|
130 |
-
x =
|
131 |
x = x.sum(0)
|
132 |
return x
|
|
|
|
|
133 |
# TODO: Is there a better way to do this?
|
134 |
# 1. Cmparing to unique elements: When all values are different we still get quadratic blowup
|
135 |
# 2. Argsort(Argsort()) returns ranking, but with duplicate values there is an ordering which is problematic
|
@@ -137,49 +158,64 @@ def to_ranking(data):
|
|
137 |
def to_ranking_low_mem(data):
|
138 |
x = torch.zeros_like(data)
|
139 |
for col in range(data.shape[-1]):
|
140 |
-
x_ =
|
141 |
x_ = x_.sum(0)
|
142 |
x[:, :, col] = x_
|
143 |
return x
|
144 |
|
|
|
145 |
def nan_handling_missing_for_unknown_reason_value(set_value_to_nan=0.0):
|
146 |
-
return get_nan_value(float(
|
|
|
147 |
|
148 |
def nan_handling_missing_for_no_reason_value(set_value_to_nan=0.0):
|
149 |
-
return get_nan_value(float(
|
|
|
150 |
|
151 |
def nan_handling_missing_for_a_reason_value(set_value_to_nan=0.0):
|
152 |
-
return get_nan_value(float(
|
|
|
153 |
|
154 |
def torch_nanmean(x, axis=0):
|
155 |
-
num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(
|
|
|
|
|
156 |
value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
|
157 |
return value / num
|
158 |
|
|
|
159 |
def torch_nanstd(x, axis=0):
|
160 |
-
num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(
|
|
|
|
|
161 |
value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
|
162 |
mean = value / num
|
163 |
-
mean_broadcast = torch.repeat_interleave(
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
def normalize_data(data, normalize_positions=-1):
|
167 |
if normalize_positions > 0:
|
168 |
mean = torch_nanmean(data[:normalize_positions], axis=0)
|
169 |
-
std = torch_nanstd(data[:normalize_positions], axis=0) + .000001
|
170 |
else:
|
171 |
mean = torch_nanmean(data, axis=0)
|
172 |
-
std = torch_nanstd(data, axis=0) + .000001
|
173 |
data = (data - mean) / std
|
174 |
data = torch.clip(data, min=-100, max=100)
|
175 |
|
176 |
return data
|
177 |
|
|
|
178 |
def remove_outliers(X, n_sigma=4):
|
179 |
# Expects T, B, H
|
180 |
assert len(X.shape) == 3, "X must be T,B,H"
|
181 |
-
#for b in range(X.shape[1]):
|
182 |
-
|
183 |
data = X
|
184 |
data_mean, data_std = torch_nanmean(data, axis=0), torch_nanstd(data, axis=0)
|
185 |
cut_off = data_std * n_sigma
|
@@ -187,17 +223,26 @@ def remove_outliers(X, n_sigma=4):
|
|
187 |
|
188 |
data_clean = X[:].clone()
|
189 |
data_clean[torch.logical_or(data > upper, data < lower)] = np.nan
|
190 |
-
data_mean, data_std =
|
|
|
|
|
|
|
191 |
cut_off = data_std * n_sigma
|
192 |
lower, upper = data_mean - cut_off, data_mean + cut_off
|
193 |
|
194 |
-
X = torch.maximum(-torch.log(1+torch.abs(X)) + lower, X)
|
195 |
-
X = torch.minimum(torch.log(1+torch.abs(X)) + upper, X)
|
196 |
-
|
197 |
return X
|
198 |
|
|
|
199 |
def bool_mask_to_att_mask(mask):
|
200 |
-
return
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
def print_on_master_only(is_master):
|
203 |
import builtins as __builtin__
|
@@ -213,46 +258,152 @@ def print_on_master_only(is_master):
|
|
213 |
|
214 |
|
215 |
def init_dist(device):
|
216 |
-
print(
|
217 |
-
if
|
218 |
# launched with torch.distributed.launch
|
219 |
rank = int(os.environ["LOCAL_RANK"])
|
220 |
-
print(
|
221 |
torch.cuda.set_device(rank)
|
222 |
-
os.environ[
|
223 |
-
torch.distributed.init_process_group(
|
224 |
-
|
|
|
|
|
|
|
|
|
|
|
225 |
torch.distributed.barrier()
|
226 |
print_on_master_only(rank == 0)
|
227 |
-
print(
|
228 |
-
|
229 |
-
|
230 |
-
|
|
|
|
|
231 |
# this is for multi gpu when starting with submitit
|
232 |
-
assert device !=
|
233 |
-
rank = int(os.environ[
|
234 |
-
os.environ[
|
235 |
-
os.environ[
|
236 |
torch.cuda.set_device(rank)
|
237 |
-
os.environ[
|
238 |
-
print(
|
239 |
-
torch.distributed.init_process_group(
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
241 |
torch.distributed.barrier()
|
242 |
print_on_master_only(rank == 0)
|
243 |
-
print(
|
244 |
-
|
|
|
|
|
245 |
|
246 |
-
return True, rank, f
|
247 |
else:
|
248 |
-
print(
|
249 |
# will not change any of the behavior of print, but allows putting the force=True in the print calls
|
250 |
print_on_master_only(True)
|
251 |
return False, 0, device
|
252 |
|
253 |
|
254 |
def check_compatibility(dl):
|
255 |
-
if hasattr(dl,
|
256 |
-
print(
|
257 |
-
|
258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from torch.optim.lr_scheduler import LambdaLR
|
10 |
import numpy as np
|
11 |
|
12 |
+
|
13 |
# copied from huggingface
|
14 |
+
def get_cosine_schedule_with_warmup(
|
15 |
+
optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
|
16 |
+
):
|
17 |
+
"""Create a schedule with a learning rate that decreases following the
|
18 |
values of the cosine function between 0 and `pi * cycles` after a warmup
|
19 |
period during which it increases linearly between 0 and 1.
|
20 |
"""
|
|
|
22 |
def lr_lambda(current_step):
|
23 |
if current_step < num_warmup_steps:
|
24 |
return float(current_step) / float(max(1, num_warmup_steps))
|
25 |
+
progress = float(current_step - num_warmup_steps) / float(
|
26 |
+
max(1, num_training_steps - num_warmup_steps)
|
27 |
+
)
|
28 |
+
return max(
|
29 |
+
0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
|
30 |
+
)
|
31 |
|
32 |
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
33 |
|
34 |
+
|
35 |
# copied from huggingface
|
36 |
+
def get_linear_schedule_with_warmup(
|
37 |
+
optimizer, num_warmup_steps, num_training_steps, last_epoch=-1
|
38 |
+
):
|
39 |
"""
|
40 |
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
|
41 |
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
|
|
|
58 |
if current_step < num_warmup_steps:
|
59 |
return float(current_step) / float(max(1, num_warmup_steps))
|
60 |
return max(
|
61 |
+
0.0,
|
62 |
+
float(num_training_steps - current_step)
|
63 |
+
/ float(max(1, num_training_steps - num_warmup_steps)),
|
64 |
)
|
65 |
|
66 |
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
|
|
77 |
where p <= `max_len`. At most `max_len` - 1 examples are shown to the Transformer.
|
78 |
:return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
|
79 |
"""
|
80 |
+
return lambda: random.choices(
|
81 |
+
range(max_len), [1 / (max_len - i) for i in range(max_len)]
|
82 |
+
)[0]
|
83 |
|
84 |
|
85 |
def get_uniform_single_eval_pos_sampler(max_len, min_len=0):
|
|
|
109 |
Especially useful right at the beginning of `__init__`.
|
110 |
:param locals: `locals()`
|
111 |
"""
|
112 |
+
self = locals["self"]
|
113 |
for var_name, val in locals.items():
|
114 |
+
if var_name != "self":
|
115 |
+
setattr(self, var_name, val)
|
116 |
|
117 |
|
118 |
+
default_device = "cuda:0" if torch.cuda.is_available() else "cpu:0"
|
119 |
|
120 |
|
121 |
# Copied from StackOverflow, but we do an eval on the values additionally
|
122 |
class StoreDictKeyPair(argparse.Action):
|
123 |
def __init__(self, option_strings, dest, nargs=None, **kwargs):
|
124 |
self._nargs = nargs
|
125 |
+
super(StoreDictKeyPair, self).__init__(
|
126 |
+
option_strings, dest, nargs=nargs, **kwargs
|
127 |
+
)
|
128 |
|
129 |
def __call__(self, parser, namespace, values, option_string=None):
|
130 |
my_dict = {}
|
|
|
137 |
setattr(namespace, self.dest, my_dict)
|
138 |
print("dict values: {}".format(my_dict))
|
139 |
|
140 |
+
|
141 |
def get_nan_value(v, set_value_to_nan=0.0):
|
142 |
if random.random() < set_value_to_nan:
|
143 |
return v
|
144 |
else:
|
145 |
return random.choice([-999, 0, 1, 999])
|
146 |
|
147 |
+
|
148 |
def to_ranking(data):
|
149 |
+
x = data >= data.unsqueeze(-3)
|
150 |
x = x.sum(0)
|
151 |
return x
|
152 |
+
|
153 |
+
|
154 |
# TODO: Is there a better way to do this?
|
155 |
# 1. Cmparing to unique elements: When all values are different we still get quadratic blowup
|
156 |
# 2. Argsort(Argsort()) returns ranking, but with duplicate values there is an ordering which is problematic
|
|
|
158 |
def to_ranking_low_mem(data):
|
159 |
x = torch.zeros_like(data)
|
160 |
for col in range(data.shape[-1]):
|
161 |
+
x_ = data[:, :, col] >= data[:, :, col].unsqueeze(-2)
|
162 |
x_ = x_.sum(0)
|
163 |
x[:, :, col] = x_
|
164 |
return x
|
165 |
|
166 |
+
|
167 |
def nan_handling_missing_for_unknown_reason_value(set_value_to_nan=0.0):
|
168 |
+
return get_nan_value(float("nan"), set_value_to_nan)
|
169 |
+
|
170 |
|
171 |
def nan_handling_missing_for_no_reason_value(set_value_to_nan=0.0):
|
172 |
+
return get_nan_value(float("-inf"), set_value_to_nan)
|
173 |
+
|
174 |
|
175 |
def nan_handling_missing_for_a_reason_value(set_value_to_nan=0.0):
|
176 |
+
return get_nan_value(float("inf"), set_value_to_nan)
|
177 |
+
|
178 |
|
179 |
def torch_nanmean(x, axis=0):
|
180 |
+
num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(
|
181 |
+
axis=axis
|
182 |
+
)
|
183 |
value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
|
184 |
return value / num
|
185 |
|
186 |
+
|
187 |
def torch_nanstd(x, axis=0):
|
188 |
+
num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(
|
189 |
+
axis=axis
|
190 |
+
)
|
191 |
value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
|
192 |
mean = value / num
|
193 |
+
mean_broadcast = torch.repeat_interleave(
|
194 |
+
mean.unsqueeze(axis), x.shape[axis], dim=axis
|
195 |
+
)
|
196 |
+
return torch.sqrt(
|
197 |
+
torch.nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1)
|
198 |
+
)
|
199 |
+
|
200 |
|
201 |
def normalize_data(data, normalize_positions=-1):
|
202 |
if normalize_positions > 0:
|
203 |
mean = torch_nanmean(data[:normalize_positions], axis=0)
|
204 |
+
std = torch_nanstd(data[:normalize_positions], axis=0) + 0.000001
|
205 |
else:
|
206 |
mean = torch_nanmean(data, axis=0)
|
207 |
+
std = torch_nanstd(data, axis=0) + 0.000001
|
208 |
data = (data - mean) / std
|
209 |
data = torch.clip(data, min=-100, max=100)
|
210 |
|
211 |
return data
|
212 |
|
213 |
+
|
214 |
def remove_outliers(X, n_sigma=4):
|
215 |
# Expects T, B, H
|
216 |
assert len(X.shape) == 3, "X must be T,B,H"
|
217 |
+
# for b in range(X.shape[1]):
|
218 |
+
# for col in range(X.shape[2]):
|
219 |
data = X
|
220 |
data_mean, data_std = torch_nanmean(data, axis=0), torch_nanstd(data, axis=0)
|
221 |
cut_off = data_std * n_sigma
|
|
|
223 |
|
224 |
data_clean = X[:].clone()
|
225 |
data_clean[torch.logical_or(data > upper, data < lower)] = np.nan
|
226 |
+
data_mean, data_std = (
|
227 |
+
torch_nanmean(data_clean, axis=0),
|
228 |
+
torch_nanstd(data_clean, axis=0),
|
229 |
+
)
|
230 |
cut_off = data_std * n_sigma
|
231 |
lower, upper = data_mean - cut_off, data_mean + cut_off
|
232 |
|
233 |
+
X = torch.maximum(-torch.log(1 + torch.abs(X)) + lower, X)
|
234 |
+
X = torch.minimum(torch.log(1 + torch.abs(X)) + upper, X)
|
235 |
+
# print(ds[1][data < lower, col], ds[1][data > upper, col], ds[1][~np.isnan(data), col].shape, data_mean, data_std)
|
236 |
return X
|
237 |
|
238 |
+
|
239 |
def bool_mask_to_att_mask(mask):
    """Convert a 0/1 mask into an additive float attention mask.

    Positions where ``mask == 0`` become ``-inf`` and positions where
    ``mask == 1`` become ``0.0``. Any other value is left as its float cast.

    Parameters:
        mask (torch.Tensor): Mask tensor compared against 0 and 1.

    Returns:
        torch.Tensor: Float tensor with the same shape as ``mask``.
    """
    att_mask = mask.float()
    att_mask = att_mask.masked_fill(mask == 0, float("-inf"))
    att_mask = att_mask.masked_fill(mask == 1, float(0.0))
    return att_mask
|
245 |
+
|
246 |
|
247 |
def print_on_master_only(is_master):
|
248 |
import builtins as __builtin__
|
|
|
258 |
|
259 |
|
260 |
def _start_nccl_group(rank):
    """Join the NCCL process group as ``rank`` and restrict printing to rank 0.

    Shared tail of both distributed launch paths of ``init_dist``.

    Returns:
        tuple: ``(True, rank, "cuda:<rank>")``.
    """
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
        timeout=datetime.timedelta(seconds=20),
        world_size=torch.cuda.device_count(),
        rank=rank,
    )
    torch.distributed.barrier()
    print_on_master_only(rank == 0)
    print(
        f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
        "only I can print, but when using print(..., force=True) it will print on all ranks."
    )
    return True, rank, f"cuda:{rank}"


def init_dist(device):
    """Initialise distributed training if the environment requests it.

    Three cases are distinguished by environment variables:
      * ``LOCAL_RANK`` is set (launched with torch.distributed.launch);
      * ``SLURM_PROCID`` is set and more than one CUDA device is visible
        (multi-GPU start via submitit);
      * otherwise no distributed setup is done.

    Parameters:
        device: Device to return unchanged in the non-distributed case.

    Returns:
        tuple: ``(using_dist, rank, device)``. In the distributed cases this
        is ``(True, rank, "cuda:<rank>")``; otherwise ``(False, 0, device)``.
    """
    print("init dist")
    env = os.environ

    if "LOCAL_RANK" in env:
        # launched with torch.distributed.launch
        rank = int(env["LOCAL_RANK"])
        print("torch.distributed.launch and my rank is", rank)
        torch.cuda.set_device(rank)
        # NOTE(review): set after torch.cuda.set_device; presumably too late to
        # affect this process's device selection - confirm intent.
        env["CUDA_VISIBLE_DEVICES"] = str(rank)
        return _start_nccl_group(rank)

    if "SLURM_PROCID" in env and torch.cuda.device_count() > 1:
        # this is for multi gpu when starting with submitit
        assert device != "cpu:0"
        rank = int(env["SLURM_PROCID"])
        env["MASTER_ADDR"] = "localhost"
        env["MASTER_PORT"] = "12355"
        torch.cuda.set_device(rank)
        env["CUDA_VISIBLE_DEVICES"] = str(rank)
        print("distributed submitit launch and my rank is", rank)
        return _start_nccl_group(rank)

    print("Not using distributed")
    # will not change any of the behavior of print, but allows putting the force=True in the print calls
    print_on_master_only(True)
    return False, 0, device
|
311 |
|
312 |
|
313 |
def check_compatibility(dl):
    """Check that a data loader is compatible with the single-output assumption.

    If ``dl`` exposes the deprecated ``num_outputs`` attribute, a deprecation
    notice is printed and the attribute is required to equal 1. Loaders
    without the attribute pass silently.

    Parameters:
        dl: Data loader object; only its optional ``num_outputs`` attribute
            is inspected.

    Raises:
        AssertionError: If ``dl.num_outputs`` exists and is not 1.
    """
    if hasattr(dl, "num_outputs"):
        print(
            "`num_outputs` for the DataLoader is deprecated. It is assumed to be 1 from now on."
        )
        # The condition was previously inverted (`!= 1`), which rejected the
        # only supported value and let unsupported ones through.
        assert dl.num_outputs == 1, (
            "We assume num_outputs to be 1. Instead of the num_outputs change your loss. "
            "We specify the number of classes in the CE loss."
        )
|
322 |
+
|
323 |
+
|
324 |
+
def pfn_normalize(
    lb=torch.tensor(float("-inf")),
    ub=torch.tensor(float("inf")),
    soft_lb=0.0,
    soft_ub=1.0,
    minimize=False,
):
    """
    Build a normalizer/denormalizer pair for learning-curve values.

    The LC-PFN curve prior assumes curves to be normalized within the range
    [0,1] and to be maximized. The returned functions map data to that
    convention and back.

    The transform has two steps:
      1. map [soft_lb, soft_ub] linearly onto [-1, 1] (where the sigmoid
         behaves approximately linearly);
      2. apply a vertically scaled/shifted sigmoid such that [lb, ub] is sent
         to [0, 1].

    Parameters:
        lb (torch.Tensor): Lower bound of the data.
        ub (torch.Tensor): Upper bound of the data.
        soft_lb (float): Soft lower bound for normalization. Default is 0.0.
        soft_ub (float): Soft upper bound for normalization. Default is 1.0.
        minimize (bool): If True, the original curve is a minimization.
            Default is False.

    Returns: Two functions for normalizing and denormalizing the data.
    """
    assert lb <= soft_lb and soft_lb < soft_ub and soft_ub <= ub

    def cinv(v):
        # Flip the value when the original objective is minimized.
        return 1 - v if minimize else v

    def lin_soft(v):
        # Linear map sending [soft_lb, soft_ub] onto [-1, 1].
        return 2 / (soft_ub - soft_lb) * (v - soft_lb) - 1

    def lin_soft_inv(w):
        return (w + 1) / 2 * (soft_ub - soft_lb) + soft_lb

    def squash(v, scale):
        # Sigmoid of the linearly rescaled value, stretched by `scale`.
        return scale / (1 + torch.exp(-lin_soft(v)))

    def unsquash(p):
        # Inverse of squash(., 1): logit followed by the inverse linear map.
        return lin_soft_inv(torch.log(p / (1 - p)))

    try:
        if torch.exp(-lin_soft(lb)) > 1e300:
            raise RuntimeError
            # otherwise overflow causes issues, treat these cases as if the lower bound was -infinite
    except RuntimeError:
        lb = torch.tensor(float("-inf"))

    lb_unbounded = torch.isinf(lb)
    ub_unbounded = torch.isinf(ub)

    if lb_unbounded and ub_unbounded:

        def normalize(x):
            return cinv(squash(x, 1))

        def denormalize(y):
            return unsquash(cinv(y))

    elif lb_unbounded:
        a = 1 + torch.exp(-lin_soft(ub))

        def normalize(x):
            return cinv(squash(x, a))

        def denormalize(y):
            return unsquash(cinv(y) / a)

    else:
        if ub_unbounded:
            a = 1 / (1 - squash(lb, 1))
            b = 1 - a
        else:
            exp_lb = torch.exp(-lin_soft(lb))
            exp_ub = torch.exp(-lin_soft(ub))
            a = (
                1 + exp_ub + exp_lb + torch.exp(-lin_soft(ub) - lin_soft(lb))
            ) / (exp_lb - exp_ub)
            b = -a / (1 + exp_lb)

        def normalize(x):
            return cinv(squash(x, a) + b)

        def denormalize(y):
            return unsquash((cinv(y) - b) / a)

    return normalize, denormalize
|
395 |
+
|
396 |
+
|
397 |
+
def get_default_normalizer():
    """Return the ``pfn_normalize`` function pair for data bounded in [0, 1].

    Hard and soft bounds are both 0 and 1, and the curve is treated as a
    maximization.
    """
    return pfn_normalize(
        lb=torch.tensor(0.0),
        ub=torch.tensor(1.0),
        soft_lb=0.0,
        soft_ub=1.0,
        minimize=False,
    )
|
406 |
+
|
407 |
+
|
408 |
+
def identity_normalizer():
    """Return a (normalize, denormalize) pair that leaves values unchanged."""

    def normalize(x):
        return x

    def denormalize(x):
        return x

    return normalize, denormalize
|
lcpfn/version.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__version__ = "0.1.3"
|
pyproject.toml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Declare the build backend explicitly. The [project] table and
# [tool.setuptools.dynamic] below need a setuptools release that understands
# PEP 621 metadata (61.0 or newer).
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "lcpfn"
description = "In-context Bayesian Learning Curve Extrapolation"
readme = {file = "readme.md", content-type = 'text/markdown'}
license = {file = "LICENSE"}
authors = [
    {name = "Steven Adriaensen", email= "adriaens@cs.uni-freiburg.de"},
    {name = "Herilalaina Rakotoarison", email = "rakotoah@cs.uni-freiburg.de"},
    {name = "Samuel Müller", email = "muellesa@cs.uni-freiburg.de"},
    {name = "Frank Hutter", email = "fh@cs.uni-freiburg.de"},
]
requires-python = ">=3.9,<3.12"
dependencies = [
    # NOTE(review): Python 3.11 is declared as supported, but torch<=1.11.0
    # presumably has no wheels for 3.11 - confirm, then either relax this pin
    # or narrow requires-python and the classifiers.
    "torch<=1.11.0",
    "numpy>=1.21.2,<2",
    "requests>=2.23.0"
]
# The version is read from lcpfn/version.py (see [tool.setuptools.dynamic]).
dynamic = ["version"]
classifiers = [
    'Intended Audience :: Science/Research',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python',
    'Topic :: Software Development',
    'Topic :: Scientific/Engineering',
    'Operating System :: Unix',
    'Operating System :: MacOS',
    'Programming Language :: Python :: 3',
    'Programming Language :: Python :: 3.9',
    'Programming Language :: Python :: 3.10',
    'Programming Language :: Python :: 3.11',
]

[project.urls]
homepage = "https://github.com/automl/lcpfn"
repository = "https://github.com/automl/lcpfn"
bugtracker = "https://github.com/automl/lcpfn/issues"

[tool.setuptools.packages.find]
include = ["lcpfn*"]

[tool.setuptools.dynamic]
version = {attr = "lcpfn.version.__version__"}
|
requirements.txt
DELETED
@@ -1,4 +0,0 @@
|
|
1 |
-
torch==1.11.0
|
2 |
-
numpy>=1.21.2
|
3 |
-
scikit-learn
|
4 |
-
# lcpfn @ git+https://github.com/automl/lcpfn.git
|
|
|
|
|
|
|
|
|
|