herilalaina committed
Commit b62776c
1 Parent(s): 3866017

update lcpfn

lcpfn/__init__.py CHANGED
@@ -1,53 +1,80 @@
 import os, sys
 
 sys.path.insert(0, os.path.dirname(__file__))
 
 
-model_path = 'trained_models'
 
 
 def prepare_models():
     pfns4bo_dir = os.path.dirname(__file__)
-    model_names = ['pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt',
-                   'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt']
 
     for name in model_names:
         weights_path = os.path.join(pfns4bo_dir, model_path, name)
-        compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + '.gz')
         if not os.path.exists(weights_path):
             if not os.path.exists(compressed_weights_path):
                 print("Downloading", os.path.abspath(compressed_weights_path))
                 import requests
-                url = f'https://github.com/automl/lcpfn/raw/main/lcpfn/trained_models/{name + ".gz"}'
                 r = requests.get(url, allow_redirects=True)
                 os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
-                with open(compressed_weights_path, 'wb') as f:
                     f.write(r.content)
             if os.path.exists(compressed_weights_path):
                 print("Unzipping", name)
                 os.system(f"gzip -dk {compressed_weights_path}")
             else:
                 print("Failed to find", compressed_weights_path)
-                print("Make sure you have an internet connection to download the model automatically..")
         if os.path.exists(weights_path):
             print("Successfully located model at", weights_path)
 
 
 model_dict = {
-    'EMSIZE512_NLAYERS12_NBUCKETS1000': os.path.join(os.path.dirname(__file__), model_path,
-                                                     'pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt'),
-    'EMSIZE512_NLAYERS6_NBUCKETS1000': os.path.join(os.path.dirname(__file__), model_path,
-                                                    'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt'),
 }
 
 
 def __getattr__(name):
     if name in model_dict:
         if not os.path.exists(model_dict[name]):
-            print("Can't find", os.path.abspath(model_dict[name]), "thus unzipping/downloading models now.")
             print("This might take a while..")
             prepare_models()
         return model_dict[name]
     raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
 
 from lcpfn.model import LCPFN
 from lcpfn.train_lcpfn import train_lcpfn
-from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func
 
 import os, sys
+
 sys.path.insert(0, os.path.dirname(__file__))
 
 
+model_path = "trained_models"
+
 
 def prepare_models():
     pfns4bo_dir = os.path.dirname(__file__)
+    model_names = [
+        "pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt",
+        "pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt",
+    ]
 
     for name in model_names:
         weights_path = os.path.join(pfns4bo_dir, model_path, name)
+        compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + ".gz")
         if not os.path.exists(weights_path):
             if not os.path.exists(compressed_weights_path):
                 print("Downloading", os.path.abspath(compressed_weights_path))
                 import requests
+
+                url = f'https://ml.informatik.uni-freiburg.de/research-artifacts/lcpfn/{name + ".gz"}'
                 r = requests.get(url, allow_redirects=True)
                 os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
+                with open(compressed_weights_path, "wb") as f:
                     f.write(r.content)
             if os.path.exists(compressed_weights_path):
                 print("Unzipping", name)
                 os.system(f"gzip -dk {compressed_weights_path}")
             else:
                 print("Failed to find", compressed_weights_path)
+                print(
+                    "Make sure you have an internet connection to download the model automatically.."
+                )
         if os.path.exists(weights_path):
             print("Successfully located model at", weights_path)
 
 
 model_dict = {
+    "EMSIZE512_NLAYERS12_NBUCKETS1000": os.path.join(
+        os.path.dirname(__file__),
+        model_path,
+        "pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt",
+    ),
+    "EMSIZE512_NLAYERS6_NBUCKETS1000": os.path.join(
+        os.path.dirname(__file__),
+        model_path,
+        "pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt",
+    ),
 }
 
 
 def __getattr__(name):
     if name in model_dict:
         if not os.path.exists(model_dict[name]):
+            print(
+                "Can't find",
+                os.path.abspath(model_dict[name]),
+                "thus unzipping/downloading models now.",
+            )
             print("This might take a while..")
             prepare_models()
         return model_dict[name]
     raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
 
+
+from .version import __version__
 from lcpfn.model import LCPFN
 from lcpfn.train_lcpfn import train_lcpfn
+from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func
+
+__all__ = [
+    "LCPFN",
+    "train_lcpfn",
+    "sample_from_prior",
+    "create_get_batch_func",
+    "__version__",
+]
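Review note: the `__getattr__` hook makes each checkpoint path a lazy module attribute, so the first access triggers `prepare_models()`. A minimal sketch, assuming the package is importable and the new download URL is reachable:

import lcpfn

# First access downloads and unzips the checkpoint if it is not cached yet,
# then returns the absolute path of the .pt file.
weights_path = lcpfn.EMSIZE512_NLAYERS12_NBUCKETS1000
print(weights_path)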
lcpfn/bar_distribution.py CHANGED
@@ -3,19 +3,25 @@ from torch import nn
 
 
 class BarDistribution(nn.Module):
-    def __init__(self, borders: torch.Tensor, smoothing=.0): # here borders should start with min and end with max, where all values lie in (min,max) and are sorted
         # sorted list of borders
         super().__init__()
         assert len(borders.shape) == 1
-        #self.borders = borders
-        self.register_buffer('borders', borders)
-        self.register_buffer('smoothing', torch.tensor(smoothing))
-        #self.bucket_widths = self.borders[1:] - self.borders[:-1]
-        self.register_buffer('bucket_widths', self.borders[1:] - self.borders[:-1])
         full_width = self.bucket_widths.sum()
         border_order = torch.argsort(borders)
-        assert (full_width - (self.borders[-1] - self.borders[0])).abs() < 1e-4, f'diff: {full_width - (self.borders[-1] - self.borders[0])}'
-        assert (border_order == torch.arange(len(borders)).to(border_order.device)).all(), "Please provide sorted borders!"
         self.num_bars = len(borders) - 1
 
     def map_to_bucket_idx(self, y):
@@ -24,28 +30,35 @@ class BarDistribution(nn.Module):
         target_sample[y == self.borders[-1]] = self.num_bars - 1
         return target_sample
 
-    def forward(self, logits, y): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
         target_sample = self.map_to_bucket_idx(y)
-        assert (target_sample >= 0).all() and (target_sample < self.num_bars).all(), f'y {y} not in support set for borders (min_y, max_y) {self.borders}'
-        assert logits.shape[-1] == self.num_bars, f'{logits.shape[-1]} vs {self.num_bars}'
 
         bucket_log_probs = torch.log_softmax(logits, -1)
         scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
-        #print(bucket_log_probs, logits.shape)
 
-        nll_loss = -scaled_bucket_log_probs.gather(-1,target_sample.unsqueeze(-1)).squeeze(-1)
 
         smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
-        smoothing = self.smoothing if self.training else 0.
-        loss = (1. - smoothing) * nll_loss + smoothing * smooth_loss
         return loss
 
     def mean(self, logits):
-        bucket_means = self.borders[:-1] + self.bucket_widths/2
         p = torch.softmax(logits, -1)
         return p @ bucket_means
 
-
     def icdf(self, logits, left_prob):
         """
         Implementation of the quantile function
@@ -55,22 +68,32 @@ class BarDistribution(nn.Module):
         """
         probs = logits.softmax(-1)
         cumprobs = torch.cumsum(probs, -1)
-        idx = torch.searchsorted(cumprobs, left_prob * torch.ones(*cumprobs.shape[:-1], 1, device = probs.device))\
-            .squeeze(-1).clamp(0, cumprobs.shape[-1] - 1) # this might not do the right for outliers
         cumprobs = torch.cat(
             [torch.zeros(*cumprobs.shape[:-1], 1, device=logits.device), cumprobs], -1
         )
 
         rest_prob = left_prob - cumprobs.gather(-1, idx[..., None]).squeeze(-1)
         left_border = self.borders[idx]
-        right_border = self.borders[idx+1]
-        return left_border + (right_border - left_border) * rest_prob / probs.gather(-1, idx[..., None]).squeeze(-1)
-
-    def quantile(self, logits, center_prob=.682):
-        side_probs = (1.-center_prob)/2
-        return torch.stack((self.icdf(logits, side_probs), self.icdf(logits, 1.-side_probs)),-1)
 
-    def ucb(self, logits, best_f, rest_prob=(1-.682)/2, maximize=True):
         """
         UCB utility. Rest Prob is the amount of utility above (below) the confidence interval that is ignored.
         Higher rest_prob is equivalent to lower beta in the standard GP-UCB formulation.
@@ -90,23 +113,41 @@ class BarDistribution(nn.Module):
 
     def mode(self, logits):
         mode_inds = logits.argmax(-1)
-        bucket_means = self.borders[:-1] + self.bucket_widths/2
         return bucket_means[mode_inds]
 
-    def ei(self, logits, best_f, maximize=True): # logits: evaluation_points x batch x feature_dim
-        bucket_means = self.borders[:-1] + self.bucket_widths/2
         if maximize:
             bucket_contributions = torch.tensor(
-                [max((bucket_max + max(bucket_min, best_f)) / 2 - best_f,0) for
-                 bucket_min, bucket_max, bucket_mean in zip(self.borders[:-1], self.borders[1:], bucket_means)], dtype=logits.dtype, device=logits.device)
         else:
             bucket_contributions = torch.tensor(
-                [-min((min(bucket_max,best_f) + bucket_min) / 2 - best_f,0) for # min on max instead of max on min, and compare min < instead of max >
-                 bucket_min, bucket_max, bucket_mean in zip(self.borders[:-1], self.borders[1:], bucket_means)], dtype=logits.dtype, device=logits.device)
         p = torch.softmax(logits, -1)
         return p @ bucket_contributions
 
-    def pi(self, logits, best_f, maximize=True):# logits: evaluation_points x batch x feature_dim
         """
         Acquisition Function: Probability of Improvement
         :param logits: as returned by Transformer
@@ -117,10 +158,9 @@ class BarDistribution(nn.Module):
         assert maximize is True
         p = torch.softmax(logits, -1)
         border_widths = self.borders[1:] - self.borders[:-1]
-        factor = 1. - ((best_f - self.borders[:-1]) / border_widths).clamp(0., 1.)
         return (p * factor).sum(-1)
 
-
     def mean_of_square(self, logits):
         """
         Computes E[x^2].
@@ -128,7 +168,11 @@ class BarDistribution(nn.Module):
         """
         left_borders = self.borders[:-1]
         right_borders = self.borders[1:]
-        bucket_mean_of_square = (left_borders.square() + right_borders.square() + left_borders*right_borders)/3.
         p = torch.softmax(logits, -1)
         return p @ bucket_mean_of_square
 
@@ -138,54 +182,74 @@ class BarDistribution(nn.Module):
 
 class FullSupportBarDistribution(BarDistribution):
     @staticmethod
-    def halfnormal_with_p_weight_before(range_max,p=.5):
-        s = range_max / torch.distributions.HalfNormal(torch.tensor(1.)).icdf(torch.tensor(p))
         return torch.distributions.HalfNormal(s)
 
-    def forward(self, logits, y): # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
         assert self.num_bars > 1
         target_sample = self.map_to_bucket_idx(y)
-        target_sample.clamp_(0,self.num_bars-1)
         assert logits.shape[-1] == self.num_bars
 
         bucket_log_probs = torch.log_softmax(logits, -1)
         scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
-        #print(bucket_log_probs, logits.shape)
-        log_probs = scaled_bucket_log_probs.gather(-1,target_sample.unsqueeze(-1)).squeeze(-1)
-
-        side_normals = (self.halfnormal_with_p_weight_before(self.bucket_widths[0]), self.halfnormal_with_p_weight_before(self.bucket_widths[-1]))
-
 
         # TODO look over it again
-        log_probs[target_sample == 0] += side_normals[0].log_prob((self.borders[1]-y[target_sample == 0]).clamp(min=.00000001)) + torch.log(self.bucket_widths[0])
-        log_probs[target_sample == self.num_bars-1] += side_normals[1].log_prob(y[target_sample == self.num_bars-1]-self.borders[-2]) + torch.log(self.bucket_widths[-1])
 
         nll_loss = -log_probs
 
         smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
-        smoothing = self.smoothing if self.training else 0.
-        loss = (1. - smoothing) * nll_loss + smoothing * smooth_loss
-
         return loss
 
     def mean(self, logits):
         bucket_means = self.borders[:-1] + self.bucket_widths / 2
         p = torch.softmax(logits, -1)
-        side_normals = (self.halfnormal_with_p_weight_before(self.bucket_widths[0]),
-                        self.halfnormal_with_p_weight_before(self.bucket_widths[-1]))
         bucket_means[0] = -side_normals[0].mean + self.borders[1]
         bucket_means[-1] = side_normals[1].mean + self.borders[-2]
         return p @ bucket_means
 
 
-
-def get_bucket_limits_(num_outputs:int, full_range:tuple=None, ys:torch.Tensor=None, verbose:bool=False):
     assert (ys is not None) or (full_range is not None)
     if ys is not None:
         ys = ys.flatten()
-        if len(ys) % num_outputs: ys = ys[:-(len(ys) % num_outputs)]
-        print(f'Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys.')
         ys_per_bucket = len(ys) // num_outputs
         if full_range is None:
             full_range = (ys.min(), ys.max())
@@ -193,17 +257,34 @@ def get_bucket_limits_(num_outputs:int, full_range:tuple=None, ys:torch.Tensor=N
         assert full_range[0] <= ys.min() and full_range[1] >= ys.max()
         full_range = torch.tensor(full_range)
         ys_sorted, ys_order = ys.sort(0)
-        bucket_limits = (ys_sorted[ys_per_bucket-1::ys_per_bucket][:-1]+ys_sorted[ys_per_bucket::ys_per_bucket])/2
         if verbose:
-            print(f'Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys.')
             print(full_range)
-        bucket_limits = torch.cat([full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)],0)
 
     else:
         class_width = (full_range[1] - full_range[0]) / num_outputs
-        bucket_limits = torch.cat([full_range[0] + torch.arange(num_outputs).float()*class_width, torch.tensor(full_range[1]).unsqueeze(0)], 0)
 
-    assert len(bucket_limits) - 1 == num_outputs and full_range[0] == bucket_limits[0] and full_range[-1] == bucket_limits[-1]
     return bucket_limits
 
@@ -266,4 +347,3 @@ def get_bucket_limits(
     ), f"{full_range[-1]} != {bucket_limits[-1]}"
 
     return bucket_limits
-
 
 
 
 class BarDistribution(nn.Module):
+    def __init__(
+        self, borders: torch.Tensor, smoothing=0.0
+    ):  # here borders should start with min and end with max, where all values lie in (min,max) and are sorted
         # sorted list of borders
         super().__init__()
         assert len(borders.shape) == 1
+        # self.borders = borders
+        self.register_buffer("borders", borders)
+        self.register_buffer("smoothing", torch.tensor(smoothing))
+        # self.bucket_widths = self.borders[1:] - self.borders[:-1]
+        self.register_buffer("bucket_widths", self.borders[1:] - self.borders[:-1])
         full_width = self.bucket_widths.sum()
         border_order = torch.argsort(borders)
+        assert (
+            full_width - (self.borders[-1] - self.borders[0])
+        ).abs() < 1e-4, f"diff: {full_width - (self.borders[-1] - self.borders[0])}"
+        assert (
+            border_order == torch.arange(len(borders)).to(border_order.device)
+        ).all(), "Please provide sorted borders!"
         self.num_bars = len(borders) - 1
 
     def map_to_bucket_idx(self, y):
 
         target_sample[y == self.borders[-1]] = self.num_bars - 1
         return target_sample
 
+    def forward(
+        self, logits, y
+    ):  # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
         target_sample = self.map_to_bucket_idx(y)
+        assert (target_sample >= 0).all() and (
+            target_sample < self.num_bars
+        ).all(), f"y {y} not in support set for borders (min_y, max_y) {self.borders}"
+        assert (
+            logits.shape[-1] == self.num_bars
+        ), f"{logits.shape[-1]} vs {self.num_bars}"
 
         bucket_log_probs = torch.log_softmax(logits, -1)
         scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
+        # print(bucket_log_probs, logits.shape)
 
+        nll_loss = -scaled_bucket_log_probs.gather(
+            -1, target_sample.unsqueeze(-1)
+        ).squeeze(-1)
 
         smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
+        smoothing = self.smoothing if self.training else 0.0
+        loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss
         return loss
 
     def mean(self, logits):
+        bucket_means = self.borders[:-1] + self.bucket_widths / 2
         p = torch.softmax(logits, -1)
         return p @ bucket_means
 
     def icdf(self, logits, left_prob):
         """
         Implementation of the quantile function
 
         """
         probs = logits.softmax(-1)
         cumprobs = torch.cumsum(probs, -1)
+        idx = (
+            torch.searchsorted(
+                cumprobs,
+                left_prob * torch.ones(*cumprobs.shape[:-1], 1, device=probs.device),
+            )
+            .squeeze(-1)
+            .clamp(0, cumprobs.shape[-1] - 1)
+        )  # this might not do the right for outliers
         cumprobs = torch.cat(
             [torch.zeros(*cumprobs.shape[:-1], 1, device=logits.device), cumprobs], -1
         )
 
         rest_prob = left_prob - cumprobs.gather(-1, idx[..., None]).squeeze(-1)
         left_border = self.borders[idx]
+        right_border = self.borders[idx + 1]
+        return left_border + (right_border - left_border) * rest_prob / probs.gather(
+            -1, idx[..., None]
+        ).squeeze(-1)
+
+    def quantile(self, logits, center_prob=0.682):
+        side_probs = (1.0 - center_prob) / 2
+        return torch.stack(
+            (self.icdf(logits, side_probs), self.icdf(logits, 1.0 - side_probs)), -1
+        )
 
+    def ucb(self, logits, best_f, rest_prob=(1 - 0.682) / 2, maximize=True):
         """
         UCB utility. Rest Prob is the amount of utility above (below) the confidence interval that is ignored.
         Higher rest_prob is equivalent to lower beta in the standard GP-UCB formulation.
 
     def mode(self, logits):
         mode_inds = logits.argmax(-1)
+        bucket_means = self.borders[:-1] + self.bucket_widths / 2
         return bucket_means[mode_inds]
 
+    def ei(
+        self, logits, best_f, maximize=True
+    ):  # logits: evaluation_points x batch x feature_dim
+        bucket_means = self.borders[:-1] + self.bucket_widths / 2
         if maximize:
             bucket_contributions = torch.tensor(
+                [
+                    max((bucket_max + max(bucket_min, best_f)) / 2 - best_f, 0)
+                    for bucket_min, bucket_max, bucket_mean in zip(
+                        self.borders[:-1], self.borders[1:], bucket_means
+                    )
+                ],
+                dtype=logits.dtype,
+                device=logits.device,
+            )
         else:
             bucket_contributions = torch.tensor(
+                [
+                    -min((min(bucket_max, best_f) + bucket_min) / 2 - best_f, 0)
+                    for bucket_min, bucket_max, bucket_mean in zip(  # min on max instead of max on min, and compare min < instead of max >
+                        self.borders[:-1], self.borders[1:], bucket_means
+                    )
+                ],
+                dtype=logits.dtype,
+                device=logits.device,
+            )
         p = torch.softmax(logits, -1)
         return p @ bucket_contributions
 
+    def pi(
+        self, logits, best_f, maximize=True
+    ):  # logits: evaluation_points x batch x feature_dim
         """
         Acquisition Function: Probability of Improvement
         :param logits: as returned by Transformer
 
         assert maximize is True
         p = torch.softmax(logits, -1)
         border_widths = self.borders[1:] - self.borders[:-1]
+        factor = 1.0 - ((best_f - self.borders[:-1]) / border_widths).clamp(0.0, 1.0)
         return (p * factor).sum(-1)
 
     def mean_of_square(self, logits):
         """
         Computes E[x^2].
 
         """
         left_borders = self.borders[:-1]
         right_borders = self.borders[1:]
+        bucket_mean_of_square = (
+            left_borders.square()
+            + right_borders.square()
+            + left_borders * right_borders
+        ) / 3.0
         p = torch.softmax(logits, -1)
         return p @ bucket_mean_of_square
 
 
 class FullSupportBarDistribution(BarDistribution):
     @staticmethod
+    def halfnormal_with_p_weight_before(range_max, p=0.5):
+        s = range_max / torch.distributions.HalfNormal(torch.tensor(1.0)).icdf(
+            torch.tensor(p)
+        )
         return torch.distributions.HalfNormal(s)
 
+    def forward(
+        self, logits, y
+    ):  # gives the negative log density (the _loss_), y: T x B, logits: T x B x self.num_bars
         assert self.num_bars > 1
         target_sample = self.map_to_bucket_idx(y)
+        target_sample.clamp_(0, self.num_bars - 1)
         assert logits.shape[-1] == self.num_bars
 
         bucket_log_probs = torch.log_softmax(logits, -1)
         scaled_bucket_log_probs = bucket_log_probs - torch.log(self.bucket_widths)
+        # print(bucket_log_probs, logits.shape)
+        log_probs = scaled_bucket_log_probs.gather(
+            -1, target_sample.unsqueeze(-1)
+        ).squeeze(-1)
+
+        side_normals = (
+            self.halfnormal_with_p_weight_before(self.bucket_widths[0]),
+            self.halfnormal_with_p_weight_before(self.bucket_widths[-1]),
+        )
 
         # TODO look over it again
+        log_probs[target_sample == 0] += side_normals[0].log_prob(
+            (self.borders[1] - y[target_sample == 0]).clamp(min=0.00000001)
+        ) + torch.log(self.bucket_widths[0])
+        log_probs[target_sample == self.num_bars - 1] += side_normals[1].log_prob(
+            y[target_sample == self.num_bars - 1] - self.borders[-2]
+        ) + torch.log(self.bucket_widths[-1])
 
         nll_loss = -log_probs
 
         smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)
+        smoothing = self.smoothing if self.training else 0.0
+        loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss
 
         return loss
 
     def mean(self, logits):
         bucket_means = self.borders[:-1] + self.bucket_widths / 2
         p = torch.softmax(logits, -1)
+        side_normals = (
+            self.halfnormal_with_p_weight_before(self.bucket_widths[0]),
+            self.halfnormal_with_p_weight_before(self.bucket_widths[-1]),
+        )
         bucket_means[0] = -side_normals[0].mean + self.borders[1]
         bucket_means[-1] = side_normals[1].mean + self.borders[-2]
         return p @ bucket_means
 
 
+def get_bucket_limits_(
+    num_outputs: int,
+    full_range: tuple = None,
+    ys: torch.Tensor = None,
+    verbose: bool = False,
+):
     assert (ys is not None) or (full_range is not None)
     if ys is not None:
         ys = ys.flatten()
+        if len(ys) % num_outputs:
+            ys = ys[: -(len(ys) % num_outputs)]
+            print(
+                f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys."
+            )
         ys_per_bucket = len(ys) // num_outputs
         if full_range is None:
             full_range = (ys.min(), ys.max())
 
         assert full_range[0] <= ys.min() and full_range[1] >= ys.max()
         full_range = torch.tensor(full_range)
         ys_sorted, ys_order = ys.sort(0)
+        bucket_limits = (
+            ys_sorted[ys_per_bucket - 1 :: ys_per_bucket][:-1]
+            + ys_sorted[ys_per_bucket::ys_per_bucket]
+        ) / 2
         if verbose:
+            print(
+                f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {len(ys) % num_outputs} ys."
+            )
             print(full_range)
+        bucket_limits = torch.cat(
+            [full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0
+        )
 
     else:
         class_width = (full_range[1] - full_range[0]) / num_outputs
+        bucket_limits = torch.cat(
+            [
+                full_range[0] + torch.arange(num_outputs).float() * class_width,
+                torch.tensor(full_range[1]).unsqueeze(0),
+            ],
+            0,
+        )
 
+    assert (
+        len(bucket_limits) - 1 == num_outputs
+        and full_range[0] == bucket_limits[0]
+        and full_range[-1] == bucket_limits[-1]
+    )
     return bucket_limits
 
 
     ), f"{full_range[-1]} != {bucket_limits[-1]}"
 
     return bucket_limits
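For orientation, a small sketch of driving `BarDistribution` by hand; the borders, shapes, and values below are illustrative assumptions, not part of this commit:

import torch
from lcpfn.bar_distribution import BarDistribution

borders = torch.linspace(0.0, 1.0, 11)  # 10 equal-width buckets on [0, 1]
dist = BarDistribution(borders)

logits = torch.randn(5, 2, 10)  # T x B x num_bars, matching the forward() comment
y = torch.rand(5, 2)            # targets inside (0, 1)

nll = dist(logits, y)           # per-position negative log density
mean = dist.mean(logits)        # T x B expected values
bands = dist.quantile(logits)   # T x B x 2 central ~68% interval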
 
lcpfn/decoders.py CHANGED
@@ -2,6 +2,14 @@ import torch
 from torch import nn
 import random
 
 
 class ScaledDecoder(nn.Module):
     def __init__(self, ninp, nhid, nout):
@@ -11,20 +19,24 @@ class ScaledDecoder(nn.Module):
         self.linear2 = nn.Linear(nhid, 10)
 
     def forward(self, x):
-        #return torch.cat([self.linear1(x), self.linear2(x)], -1)
         x = self.linear(x)
-        x = nn.GELU()(x)
-        temps = self.linear2(x).softmax(-1) @ torch.tensor([1.,1.4,1.7,2.,5.,10.,20.,40.,80.,160.], device=x.device)
-        if random.random() > .99:
-            print(temps.shape,temps[:,:2])
         return self.linear1(x) / temps.unsqueeze(-1)
 
 class FixedScaledDecoder(nn.Module):
     def __init__(self, ninp, nhid, nout):
         super().__init__()
-        self.mapper = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout))
-        self.T = nn.Parameter(torch.ones(10000)/10000)
 
     def forward(self, x):
-        return self.mapper(x)/self.T.sum()
-
 
 from torch import nn
 import random
 
+from torch import Tensor
+import torch.nn.functional as F
+
+
+class GELU(nn.Module):
+    def forward(self, input: Tensor) -> Tensor:
+        return F.gelu(input)
+
 
 class ScaledDecoder(nn.Module):
     def __init__(self, ninp, nhid, nout):
 
         self.linear2 = nn.Linear(nhid, 10)
 
     def forward(self, x):
+        # return torch.cat([self.linear1(x), self.linear2(x)], -1)
         x = self.linear(x)
+        x = GELU()(x)
+        temps = self.linear2(x).softmax(-1) @ torch.tensor(
+            [1.0, 1.4, 1.7, 2.0, 5.0, 10.0, 20.0, 40.0, 80.0, 160.0], device=x.device
+        )
+        if random.random() > 0.99:
+            print(temps.shape, temps[:, :2])
         return self.linear1(x) / temps.unsqueeze(-1)
 
+
 class FixedScaledDecoder(nn.Module):
     def __init__(self, ninp, nhid, nout):
         super().__init__()
+        self.mapper = nn.Sequential(
+            nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout)
+        )
+        self.T = nn.Parameter(torch.ones(10000) / 10000)
 
     def forward(self, x):
+        return self.mapper(x) / self.T.sum()
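A quick shape check for the reformatted decoder; the sizes are arbitrary and only assume the constructor signature shown above:

import torch
from lcpfn.decoders import FixedScaledDecoder

dec = FixedScaledDecoder(ninp=512, nhid=1024, nout=1000)
out = dec(torch.randn(8, 512))  # logits divided by the learned temperature sum
print(out.shape)                # torch.Size([8, 1000])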
 
lcpfn/domhan_prior.py CHANGED
@@ -58,7 +58,10 @@ def prior_weights(
 
 def sample_from_prior(rng, seq_len=100):
     return sample_prior_comb(
-        rng=rng, seq_len=seq_len, components=["pow3", "ilog2", "janoschek"], distribution="peaked"
     )
 
 
@@ -103,7 +106,7 @@ def sample_prior_comb(
         f_priors = {
             "pow3": uniform_prior_pow3,
             "ilog2": uniform_prior_ilog2,
-            "janoschek": uniform_prior_janoschek
         }
     else:
         raise NotImplemented()
@@ -153,6 +156,7 @@ def generate_prior_dataset(n, prior=sample_prior_comb, seed=42):
 def create_get_batch_func(prior):
     return partial(get_batch_domhan, prior=prior)
 
 # function producing batches for PFN training
 def get_batch_domhan(
     batch_size,
@@ -192,4 +196,4 @@ def get_batch_domhan(
     y_target = y_target.float()
     y_noisy = y_noisy.float()
 
-    return x, y_noisy, y_target
 
 def sample_from_prior(rng, seq_len=100):
     return sample_prior_comb(
+        rng=rng,
+        seq_len=seq_len,
+        components=["pow3", "ilog2", "janoschek"],
+        distribution="peaked",
     )
 
 
         f_priors = {
             "pow3": uniform_prior_pow3,
             "ilog2": uniform_prior_ilog2,
+            "janoschek": uniform_prior_janoschek,
         }
     else:
         raise NotImplemented()
 
 def create_get_batch_func(prior):
     return partial(get_batch_domhan, prior=prior)
 
+
 # function producing batches for PFN training
 def get_batch_domhan(
     batch_size,
 
     y_target = y_target.float()
     y_noisy = y_noisy.float()
 
+    return x, y_noisy, y_target
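For context, `sample_from_prior` draws a synthetic learning curve from the pow3/ilog2/janoschek mixture. A hedged sketch (the seed is arbitrary, and the exact return structure should be checked against `sample_prior_comb` before unpacking it):

import numpy as np
from lcpfn.domhan_prior import sample_from_prior

rng = np.random.RandomState(0)
sample = sample_from_prior(rng, seq_len=100)
print(type(sample))  # inspect before assuming a particular (curve, params) layout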
lcpfn/encoders.py CHANGED
@@ -18,34 +18,45 @@ class StyleEncoder(nn.Module):
 
 
 class _PositionalEncoding(nn.Module):
-    def __init__(self, d_model, dropout=0.):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout)
         self.d_model = d_model
-        self.device_test_tensor = nn.Parameter(torch.tensor(1.))
 
-    def forward(self, x):# T x B x num_features
-        assert self.d_model % x.shape[-1]*2 == 0
         d_per_feature = self.d_model // x.shape[-1]
         pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
-        #position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
         interval_size = 10
-        div_term = (1./interval_size) * 2*math.pi*torch.exp(torch.arange(0, d_per_feature, 2, device=self.device_test_tensor.device).float()*math.log(math.sqrt(2)))
-        #print(div_term/2/math.pi)
         pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
         pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
-        return self.dropout(pe).view(x.shape[0],x.shape[1],self.d_model)
 
 
 Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)
 
 class EmbeddingEncoder(nn.Module):
     def __init__(self, num_features, em_size, num_embs=100):
         super().__init__()
         self.num_embs = num_embs
         self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
-        self.init_weights(.1)
-        self.min_max = (-2,+2)
 
     @property
     def width(self):
@@ -60,7 +71,9 @@ class EmbeddingEncoder(nn.Module):
 
     def forward(self, x):  # T x B x num_features
         x_idxs = self.discretize(x)
-        x_idxs += torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
         # print(x_idxs,self.embeddings.weight.shape)
         return self.embeddings(x_idxs).mean(-2)
 
@@ -72,7 +85,7 @@ class Normalize(nn.Module):
         self.std = std
 
     def forward(self, x):
-        return (x-self.mean)/self.std
 
 
 def get_normalized_uniform_encoder(encoder_creator):
@@ -83,13 +96,16 @@ def get_normalized_uniform_encoder(encoder_creator):
     :param encoder:
     :return:
     """
-    return lambda in_dim, out_dim: nn.Sequential(Normalize(.5, math.sqrt(1/12)), encoder_creator(in_dim, out_dim))
 
 
 Linear = nn.Linear
-MLP = lambda num_features, emsize: nn.Sequential(nn.Linear(num_features+1,emsize*2),
-                                                 nn.ReLU(),
-                                                 nn.Linear(emsize*2,emsize))
 
 class NanHandlingEncoder(nn.Module):
     def __init__(self, num_features, emsize, keep_nans=True):
@@ -101,10 +117,17 @@ class NanHandlingEncoder(nn.Module):
 
     def forward(self, x):
         if self.keep_nans:
-            x = torch.cat([torch.nan_to_num(x, nan=0.0), normalize_data(torch.isnan(x) * -1
-                                                                        + torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1
-                                                                        + torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2
-                                                                        )], -1)
         else:
             x = torch.nan_to_num(x, nan=0.0)
         return self.layer(x)
@@ -124,24 +147,28 @@ class Linear(nn.Linear):
 class Conv(nn.Module):
     def __init__(self, input_size, emsize):
         super().__init__()
-        self.convs = torch.nn.ModuleList([nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)])
-        self.linear = nn.Linear(64,emsize)
 
     def forward(self, x):
         size = math.isqrt(x.shape[-1])
-        assert size*size == x.shape[-1]
         x = x.reshape(*x.shape[:-1], 1, size, size)
         for conv in self.convs:
             if x.shape[-1] < 4:
                 break
             x = conv(x)
             x.relu_()
-        x = nn.AdaptiveAvgPool2d((1,1))(x).squeeze(-1).squeeze(-1)
         return self.linear(x)
 
 
 class CanEmb(nn.Embedding):
-    def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
         assert embedding_dim % num_features == 0
         embedding_dim = embedding_dim // num_features
         super().__init__(num_embeddings, embedding_dim, *args, **kwargs)
@@ -158,4 +185,6 @@ def get_Canonical(num_classes):
 
 
 def get_Embedding(num_embs_per_feature=100):
-    return lambda num_features, emsize: EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)
 
 
 class _PositionalEncoding(nn.Module):
+    def __init__(self, d_model, dropout=0.0):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout)
         self.d_model = d_model
+        self.device_test_tensor = nn.Parameter(torch.tensor(1.0))
 
+    def forward(self, x):  # T x B x num_features
+        assert self.d_model % x.shape[-1] * 2 == 0
         d_per_feature = self.d_model // x.shape[-1]
         pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
+        # position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
         interval_size = 10
+        div_term = (
+            (1.0 / interval_size)
+            * 2
+            * math.pi
+            * torch.exp(
+                torch.arange(
+                    0, d_per_feature, 2, device=self.device_test_tensor.device
+                ).float()
+                * math.log(math.sqrt(2))
+            )
+        )
+        # print(div_term/2/math.pi)
         pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
         pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
+        return self.dropout(pe).view(x.shape[0], x.shape[1], self.d_model)
 
 
 Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)
 
+
 class EmbeddingEncoder(nn.Module):
     def __init__(self, num_features, em_size, num_embs=100):
         super().__init__()
         self.num_embs = num_embs
         self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
+        self.init_weights(0.1)
+        self.min_max = (-2, +2)
 
     @property
     def width(self):
 
     def forward(self, x):  # T x B x num_features
         x_idxs = self.discretize(x)
+        x_idxs += (
+            torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
+        )
         # print(x_idxs,self.embeddings.weight.shape)
         return self.embeddings(x_idxs).mean(-2)
 
         self.std = std
 
     def forward(self, x):
+        return (x - self.mean) / self.std
 
 
 def get_normalized_uniform_encoder(encoder_creator):
 
     :param encoder:
     :return:
     """
+    return lambda in_dim, out_dim: nn.Sequential(
+        Normalize(0.5, math.sqrt(1 / 12)), encoder_creator(in_dim, out_dim)
+    )
 
 
 Linear = nn.Linear
+MLP = lambda num_features, emsize: nn.Sequential(
+    nn.Linear(num_features + 1, emsize * 2), nn.ReLU(), nn.Linear(emsize * 2, emsize)
+)
+
 
 class NanHandlingEncoder(nn.Module):
     def __init__(self, num_features, emsize, keep_nans=True):
 
     def forward(self, x):
         if self.keep_nans:
+            x = torch.cat(
+                [
+                    torch.nan_to_num(x, nan=0.0),
+                    normalize_data(
+                        torch.isnan(x) * -1
+                        + torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1
+                        + torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2
+                    ),
+                ],
+                -1,
+            )
         else:
             x = torch.nan_to_num(x, nan=0.0)
         return self.layer(x)
 
 class Conv(nn.Module):
     def __init__(self, input_size, emsize):
         super().__init__()
+        self.convs = torch.nn.ModuleList(
+            [nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)]
+        )
+        self.linear = nn.Linear(64, emsize)
 
     def forward(self, x):
         size = math.isqrt(x.shape[-1])
+        assert size * size == x.shape[-1]
         x = x.reshape(*x.shape[:-1], 1, size, size)
         for conv in self.convs:
             if x.shape[-1] < 4:
                 break
             x = conv(x)
             x.relu_()
+        x = nn.AdaptiveAvgPool2d((1, 1))(x).squeeze(-1).squeeze(-1)
         return self.linear(x)
 
 
 class CanEmb(nn.Embedding):
+    def __init__(
+        self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs
+    ):
         assert embedding_dim % num_features == 0
         embedding_dim = embedding_dim // num_features
         super().__init__(num_embeddings, embedding_dim, *args, **kwargs)
 
 
 def get_Embedding(num_embs_per_feature=100):
+    return lambda num_features, emsize: EmbeddingEncoder(
+        num_features, emsize, num_embs=num_embs_per_feature
+    )
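As a usage sketch for the encoder factories (dimensions are illustrative assumptions):

import torch
from lcpfn import encoders

# Shift/scale uniform [0, 1] inputs to zero mean and unit variance before encoding.
make_encoder = encoders.get_normalized_uniform_encoder(encoders.Linear)
enc = make_encoder(1, 512)        # in_dim=1 feature, out_dim=512 embedding
out = enc(torch.rand(100, 8, 1))  # T x B x num_features -> T x B x 512
print(out.shape)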
lcpfn/initializers.py CHANGED
@@ -1,9 +1,11 @@
 import torch
 from torch import nn
 
 def get_NormalInitializer(std):
     def initializer(m):
         if isinstance(m, nn.Linear):
             nn.init.normal_(m.weight, 0, std)
             nn.init.normal_(m.bias, 0, std)
-    return initializer
 import torch
 from torch import nn
 
+
 def get_NormalInitializer(std):
     def initializer(m):
         if isinstance(m, nn.Linear):
             nn.init.normal_(m.weight, 0, std)
             nn.init.normal_(m.bias, 0, std)
+
+    return initializer
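The factory is meant for `nn.Module.apply`; a minimal sketch (the std value is an arbitrary choice):

import torch.nn as nn
from lcpfn.initializers import get_NormalInitializer

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
model.apply(get_NormalInitializer(std=0.02))  # re-draws every nn.Linear weight and bias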
lcpfn/layer.py CHANGED
@@ -36,15 +36,28 @@ class TransformerEncoderLayer(nn.Module):
     >>> src = torch.rand(32, 10, 512)
     >>> out = encoder_layer(src)
     """
-    __constants__ = ['batch_first']
 
-    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
-                 layer_norm_eps=1e-5, batch_first=False, pre_norm=False,
-                 device=None, dtype=None, recompute_attn=False) -> None:
-        factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
-                                            **factory_kwargs)
         # Implementation of Feedforward model
         self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
         self.dropout = Dropout(dropout)
@@ -60,11 +73,16 @@ class TransformerEncoderLayer(nn.Module):
         self.activation = _get_activation_fn(activation)
 
     def __setstate__(self, state):
-        if 'activation' not in state:
-            state['activation'] = F.relu
         super().__setstate__(state)
 
-    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
         r"""Pass the input through the encoder layer.
 
         Args:
@@ -90,26 +108,61 @@ class TransformerEncoderLayer(nn.Module):
             num_train_tokens = trainset_src_mask.shape[0]
 
             global_tokens_src = src_[:num_global_tokens]
-            train_tokens_src = src_[num_global_tokens:num_global_tokens+num_train_tokens]
-            global_and_train_tokens_src = src_[:num_global_tokens+num_train_tokens]
-            eval_tokens_src = src_[num_global_tokens+num_train_tokens:]
-
-            attn = partial(checkpoint, self.self_attn) if self.recompute_attn else self.self_attn
-
-            global_tokens_src2 = attn(global_tokens_src, global_and_train_tokens_src, global_and_train_tokens_src, None, True, global_src_mask)[0]
-            train_tokens_src2 = attn(train_tokens_src, global_tokens_src, global_tokens_src, None, True, trainset_src_mask)[0]
-            eval_tokens_src2 = attn(eval_tokens_src, src_, src_,
-                                    None, True, valset_src_mask)[0]
-
-            src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0)
 
         else:
             if self.recompute_attn:
-                src2 = checkpoint(self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask)[0]
             else:
-                src2 = self.self_attn(src_, src_, src_, attn_mask=src_mask,
-                                      key_padding_mask=src_key_padding_mask)[0]
         src = src + self.dropout1(src2)
         if not self.pre_norm:
             src = self.norm1(src)
@@ -123,4 +176,4 @@ class TransformerEncoderLayer(nn.Module):
 
         if not self.pre_norm:
             src = self.norm2(src)
-        return src
     >>> src = torch.rand(32, 10, 512)
     >>> out = encoder_layer(src)
     """
 
+    __constants__ = ["batch_first"]
+
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        layer_norm_eps=1e-5,
+        batch_first=False,
+        pre_norm=False,
+        device=None,
+        dtype=None,
+        recompute_attn=False,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
+        self.self_attn = MultiheadAttention(
+            d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs
+        )
         # Implementation of Feedforward model
         self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
         self.dropout = Dropout(dropout)
 
         self.activation = _get_activation_fn(activation)
 
     def __setstate__(self, state):
+        if "activation" not in state:
+            state["activation"] = F.relu
         super().__setstate__(state)
 
+    def forward(
+        self,
+        src: Tensor,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
         r"""Pass the input through the encoder layer.
 
         Args:
 
             num_train_tokens = trainset_src_mask.shape[0]
 
             global_tokens_src = src_[:num_global_tokens]
+            train_tokens_src = src_[
+                num_global_tokens : num_global_tokens + num_train_tokens
+            ]
+            global_and_train_tokens_src = src_[: num_global_tokens + num_train_tokens]
+            eval_tokens_src = src_[num_global_tokens + num_train_tokens :]
+
+            attn = (
+                partial(checkpoint, self.self_attn)
+                if self.recompute_attn
+                else self.self_attn
+            )
+
+            global_tokens_src2 = attn(
+                global_tokens_src,
+                global_and_train_tokens_src,
+                global_and_train_tokens_src,
+                None,
+                True,
+                global_src_mask,
+            )[0]
+            train_tokens_src2 = attn(
+                train_tokens_src,
+                global_tokens_src,
+                global_tokens_src,
+                None,
+                True,
+                trainset_src_mask,
+            )[0]
+            eval_tokens_src2 = attn(
+                eval_tokens_src, src_, src_, None, True, valset_src_mask
+            )[0]
+
+            src2 = torch.cat(
+                [global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0
+            )
 
         else:
             if self.recompute_attn:
+                src2 = checkpoint(
+                    self.self_attn,
+                    src_,
+                    src_,
+                    src_,
+                    src_key_padding_mask,
+                    True,
+                    src_mask,
+                )[0]
             else:
+                src2 = self.self_attn(
+                    src_,
+                    src_,
+                    src_,
+                    attn_mask=src_mask,
+                    key_padding_mask=src_key_padding_mask,
+                )[0]
         src = src + self.dropout1(src2)
         if not self.pre_norm:
             src = self.norm1(src)
 
         if not self.pre_norm:
             src = self.norm2(src)
+        return src
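For the ordinary (non-tuple-mask) path, the layer behaves like the stock PyTorch encoder layer; a hedged smoke test, assuming a None mask is accepted as the else branch suggests:

import torch
from lcpfn.layer import TransformerEncoderLayer

layer = TransformerEncoderLayer(d_model=512, nhead=8)
src = torch.rand(10, 32, 512)  # seq_len x batch x d_model (batch_first=False)
out = layer(src)               # same shape as src
print(out.shape)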
lcpfn/model.py CHANGED
@@ -1,29 +1,56 @@
 import torch
 import lcpfn
 
 
 class LCPFN(torch.nn.Module):
     def __init__(self, model_name="EMSIZE512_NLAYERS12_NBUCKETS1000"):
         super(LCPFN, self).__init__()
-        self.model = torch.load(getattr(lcpfn, model_name) if model_name in lcpfn.model_dict else model_name)
         self.model.eval()
 
     @torch.no_grad()
-    def predict_mean(self, x_train, y_train, x_test):
-        logits = self(x_train=x_train, y_train=y_train, x_test=x_test)
-        return self.model.criterion.mean(logits)
 
     @torch.no_grad()
-    def predict_quantiles(self, x_train, y_train, x_test, qs):
-        logits = self(x_train=x_train, y_train=y_train, x_test=x_test)
-        return torch.cat([self.model.criterion.icdf(logits, q) for q in qs], dim=1)
 
     @torch.no_grad()
     def nll_loss(self, x_train, y_train, x_test, y_test):
         logits = self(x_train=x_train, y_train=y_train, x_test=x_test)
         return self.model.criterion(logits, y_test)
 
     def forward(self, x_train, y_train, x_test):
         single_eval_pos = x_train.shape[0]
         x = torch.cat([x_train, x_test], dim=0).unsqueeze(1)
         y = y_train.unsqueeze(1)
-        return self.model((x, y), single_eval_pos=single_eval_pos)
 import torch
 import lcpfn
+import warnings
+from lcpfn import utils
+
 
 class LCPFN(torch.nn.Module):
     def __init__(self, model_name="EMSIZE512_NLAYERS12_NBUCKETS1000"):
         super(LCPFN, self).__init__()
+        self.model = torch.load(
+            getattr(lcpfn, model_name) if model_name in lcpfn.model_dict else model_name
+        )
         self.model.eval()
 
+    def check_input(self, x_train, x_test, y_train, y_test=None):
+        if torch.any(x_train < 0) or torch.any(x_test < 0):
+            # reject negative x values
+            raise Exception("x values should be non-negative")
+        if torch.any((0 > y_train) | (y_train > 1)) or (
+            y_test is not None and torch.any((0 > y_test) | (y_test > 1))
+        ):
+            # reject y values outside [0, 1] (mirrors the y_train check; a chained
+            # comparison such as "0 < y_test < 1" is not valid on tensors)
+            raise Exception(
+                "y values should be in the range [0,1]. Please set normalizer_kwargs accordingly."
+            )
+
     @torch.no_grad()
+    def predict_mean(
+        self, x_train, y_train, x_test, normalizer=utils.identity_normalizer()
+    ):
+        y_train_norm = normalizer[0](y_train)
+        logits = self(x_train=x_train, y_train=y_train_norm, x_test=x_test)
+        return normalizer[1](self.model.criterion.mean(logits))
 
     @torch.no_grad()
+    def predict_quantiles(
+        self, x_train, y_train, x_test, qs, normalizer=utils.identity_normalizer()
+    ):
+        y_train_norm = normalizer[0](y_train)
+        logits = self(x_train=x_train, y_train=y_train_norm, x_test=x_test)
+        return normalizer[1](
+            torch.cat([self.model.criterion.icdf(logits, q) for q in qs], dim=1)
+        )
 
     @torch.no_grad()
     def nll_loss(self, x_train, y_train, x_test, y_test):
+        # TODO add normalizer_kwargs
         logits = self(x_train=x_train, y_train=y_train, x_test=x_test)
         return self.model.criterion(logits, y_test)
 
     def forward(self, x_train, y_train, x_test):
+        self.check_input(x_train, x_test, y_train)
         single_eval_pos = x_train.shape[0]
         x = torch.cat([x_train, x_test], dim=0).unsqueeze(1)
         y = y_train.unsqueeze(1)
+        return self.model((x, y), single_eval_pos=single_eval_pos)
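End to end, the new normalizer-aware API can be exercised as below; a sketch assuming the default checkpoint downloads successfully and following the T x 1 input convention of forward():

import torch
from lcpfn import LCPFN

model = LCPFN()  # loads EMSIZE512_NLAYERS12_NBUCKETS1000 by default

x_train = torch.arange(1, 11).unsqueeze(1).float()  # epochs 1..10 observed
y_train = torch.rand(10)                            # metric values already in [0, 1]
x_test = torch.arange(11, 51).unsqueeze(1).float()  # extrapolate to epoch 50

mean = model.predict_mean(x_train, y_train, x_test)
bands = model.predict_quantiles(x_train, y_train, x_test, qs=[0.1, 0.9])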
lcpfn/positional_encodings.py CHANGED
@@ -15,7 +15,7 @@ class NoPositionalEncoding(nn.Module):
         pass
 
     def forward(self, x):
-        return x #* math.sqrt(x.shape[-1])
 
 
 class PositionalEncoding(nn.Module):
@@ -23,14 +23,16 @@ class PositionalEncoding(nn.Module):
         super(PositionalEncoding, self).__init__()
         pe = torch.zeros(max_len, d_model)
         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0).transpose(0, 1)
-        self.register_buffer('pe', pe)
 
     def forward(self, x):
-        x = self.pe[:x.size(0), :] + x  # * math.sqrt(x.shape[-1])
         return x
 
 
@@ -38,33 +40,39 @@ class LearnedPositionalEncoding(nn.Module):
     def __init__(self, d_model, max_len=5000):
         super(LearnedPositionalEncoding, self).__init__()
         self.max_seq_len = max_len
-        #self.positional_embeddings = nn.Embedding(max_len, d_model)
         self.positional_embeddings = nn.Parameter(torch.empty(max_len, d_model))
-        nn.init.normal_(self.positional_embeddings, mean=0, std=d_model ** -0.5)
 
     def forward(self, x):
         seq_len, bs, d_model = x.shape
-        assert seq_len <= len(self.positional_embeddings), 'seq_len can be at most max_len.'
         pos_emb = self.positional_embeddings[:seq_len]
-        return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x #* math.sqrt(x.shape[-1])
 
 
 class PairedScrambledPositionalEncodings(LearnedPositionalEncoding):
     # TODO check whether it is a problem to use the same perm. for full batch
     def forward(self, x):
         seq_len, bs, d_model = x.shape
-        assert seq_len <= len(self.positional_embeddings), 'seq_len can be at most max_len.'
-        assert len(self.positional_embeddings) % 2 == 0, 'Please specify an even max_len.'
-
-        paired_embs = self.positional_embeddings.view(len(self.positional_embeddings), -1, 2)
-        pos_emb = paired_embs[torch.randperm(len(paired_embs))].view(*self.positional_embeddings.shape)[:seq_len]
-
-        return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x #* math.sqrt(x.shape[-1])
         pass
 
     def forward(self, x):
+        return x  # * math.sqrt(x.shape[-1])
 
 
 class PositionalEncoding(nn.Module):
 
         super(PositionalEncoding, self).__init__()
         pe = torch.zeros(max_len, d_model)
         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
+        )
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer("pe", pe)
 
     def forward(self, x):
+        x = self.pe[: x.size(0), :] + x  # * math.sqrt(x.shape[-1])
         return x
 
 
     def __init__(self, d_model, max_len=5000):
         super(LearnedPositionalEncoding, self).__init__()
         self.max_seq_len = max_len
+        # self.positional_embeddings = nn.Embedding(max_len, d_model)
         self.positional_embeddings = nn.Parameter(torch.empty(max_len, d_model))
+        nn.init.normal_(self.positional_embeddings, mean=0, std=d_model**-0.5)
 
     def forward(self, x):
         seq_len, bs, d_model = x.shape
+        assert seq_len <= len(
+            self.positional_embeddings
+        ), "seq_len can be at most max_len."
         pos_emb = self.positional_embeddings[:seq_len]
+        return (
+            pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x
+        )  # * math.sqrt(x.shape[-1])
 
 
 class PairedScrambledPositionalEncodings(LearnedPositionalEncoding):
     # TODO check whether it is a problem to use the same perm. for full batch
     def forward(self, x):
         seq_len, bs, d_model = x.shape
+        assert seq_len <= len(
+            self.positional_embeddings
+        ), "seq_len can be at most max_len."
+        assert (
+            len(self.positional_embeddings) % 2 == 0
+        ), "Please specify an even max_len."
+
+        paired_embs = self.positional_embeddings.view(
+            len(self.positional_embeddings), -1, 2
+        )
+        pos_emb = paired_embs[torch.randperm(len(paired_embs))].view(
+            *self.positional_embeddings.shape
+        )[:seq_len]
+
+        return (
+            pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x
+        )  # * math.sqrt(x.shape[-1])
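A small shape check for the sinusoidal variant, assuming the usual (d_model, max_len) constructor implied by the __init__ body above:

import torch
from lcpfn.positional_encodings import PositionalEncoding

pe = PositionalEncoding(d_model=512, max_len=5000)
x = torch.zeros(10, 32, 512)  # seq_len x batch x d_model
print(pe(x).shape)            # shape is unchanged; positions are added in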
lcpfn/train.py CHANGED
@@ -1,12 +1,7 @@
-import os
 import itertools
-import argparse
 import time
-import datetime
-import yaml
 from contextlib import nullcontext
 
-import pickle
 import torch
 from torch import nn
 
@@ -14,18 +9,11 @@ from lcpfn import utils
 from lcpfn.transformer import TransformerModel
 from lcpfn.bar_distribution import (
     BarDistribution,
-    FullSupportBarDistribution,
-    get_bucket_limits,
 )
 from lcpfn.utils import (
     get_cosine_schedule_with_warmup,
     get_openai_lr,
-    StoreDictKeyPair,
-    get_weighted_single_eval_pos_sampler,
-    get_uniform_single_eval_pos_sampler,
 )
-from lcpfn import priors
-from lcpfn import encoders
 from lcpfn import positional_encodings
 from lcpfn.utils import init_dist
 from torch.cuda.amp import autocast, GradScaler
@@ -294,7 +282,6 @@ def train(
     list_losses = []
     try:
         for epoch in range(1, epochs + 1) if epochs is not None else itertools.count(1):
-
             epoch_start_time = time.time()
             (
                 total_loss,
@@ -347,256 +334,3 @@ def train(
     torch.save(model.to("cpu"), output_path)
     print("Checkpoint stored at ", output_path)
     return total_loss, total_positional_losses, model.to("cpu"), dl
-
-
-def _parse_args(config_parser, parser):
-    # Do we have a config file to parse?
-    args_config, remaining = config_parser.parse_known_args()
-    if args_config.config:
-        with open(args_config.config, "r") as f:
-            cfg = yaml.safe_load(f)
-        parser.set_defaults(**cfg)
-
-    # The main arg parser parses the rest of the args, the usual
-    # defaults will have been overridden if config file specified.
-    args = parser.parse_args(remaining)
-
-    # Cache the args as a text string to save them in the output dir later
-    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
-    return args, args_text
-
-
-if __name__ == "__main__":
-    config_parser = argparse.ArgumentParser(
-        description="Only used as a first parser for the config file path."
-    )
-    config_parser.add_argument("--config")
-    parser = argparse.ArgumentParser()
-    parser.add_argument("prior")
-    parser.add_argument("--loss_function", default="barnll")
-    # Optional Arg's for `--loss_function barnll`
-    parser.add_argument(
-        "--min_y",
-        type=float,
-        help="barnll can only model y in strict ranges, this is the minimum y can take.",
-    )
-    parser.add_argument(
-        "--max_y",
-        type=float,
-        help="barnll can only model y in strict ranges, this is the maximum y can take.",
-    )
-    parser.add_argument("--num_buckets", default=100, type=int)
-    # parser.add_argument('--num_features', default=None, type=int, help='Specify depending on the prior.')
-    parser.add_argument(
-        "--extra_prior_kwargs_dict",
-        default={},
-        dest="extra_prior_kwargs_dict",
-        action=StoreDictKeyPair,
-        nargs="+",
-        metavar="KEY=VAL",
-        help="Specify depending on the prior.",
-    )
-    parser.add_argument(
-        "--encoder", default="linear", type=str, help="Specify depending on the prior."
-    )
-    parser.add_argument(
-        "--y_encoder",
-        default="linear",
-        type=str,
-        help="Specify depending on the prior. You should specify this if you do not fuse x and y.",
-    )
-    parser.add_argument(
-        "--pos_encoder",
-        default="none",
-        type=str,
-        help="Specify depending on the prior.",
-    )
-    parser.add_argument("--bptt", default=10, type=int)
-    parser.add_argument("--epochs", default=200, type=int)
-    parser.add_argument("--warmup_epochs", default=50, type=int)
-    parser.add_argument("--validation_period", default=10, type=int)
-    parser.add_argument(
-        "--permutation_invariant_max_eval_pos",
-        default=None,
-        type=int,
-        help="Set this to an int to ",
-    )
-    parser.add_argument(
-        "--permutation_invariant_sampling",
-        default="weighted",
-        help="Only relevant if --permutation_invariant_max_eval_pos is set.",
-    )
-    parser.add_argument("--train_mixed_precision", action="store_true")
-
-    # these can likely be mostly left at defaults
-    parser.add_argument(
-        "--emsize", default=512, type=int
-    )  # sometimes even larger is better e.g. 1024
-    parser.add_argument("--nlayers", default=6, type=int)
-    parser.add_argument("--nhid", default=None, type=int)  # 2*emsize is the default
-    parser.add_argument(
-        "--nhead", default=4, type=int
-    )  # nhead = emsize / 64 in the original paper
-    parser.add_argument("--dropout", default=0.0, type=float)
-    parser.add_argument("--steps_per_epoch", default=10, type=int)
-    parser.add_argument("--batch_size", default=1000, type=int)
-    parser.add_argument(
-        "--lr", "--learning_rate", default=0.001, type=float
-    )  # try also .0003, .0001, go lower with lower batch size
-    parser.add_argument("--gpu_device", default="cuda", type=str)
-
-    # for model checkpointing
-    parser.add_argument(
-        "--checkpoint_file",
-        help="absolute or relative-to-the-project-rootdir path to the file storing the state dicts.",
-        default=None,
-        type=str,
-    )
-    parser.add_argument("--saving_period", default=10, type=str)
-
-    args, _ = _parse_args(config_parser, parser)
-
-    if args.nhid is None:
-        args.nhid = 2 * args.emsize
-
-    prior = args.__dict__.pop("prior")
-
-    if prior == "gp":
-        prior = priors.fast_gp.DataLoader
-    elif prior == "ridge":
-        prior = priors.ridge.DataLoader
-    elif prior == "stroke":
-        prior = priors.stroke.DataLoader
-    elif prior == "mix_gp":
-        prior = priors.fast_gp_mix.DataLoader
-    else:
-        raise NotImplementedError(f"Prior == {prior}.")
-
-    loss_function = args.__dict__.pop("loss_function")
-
-    criterion = nn.GaussianNLLLoss(reduction="none", full=True)
-    classificiation_criterion = nn.CrossEntropyLoss(reduction="none")
-    num_buckets = args.__dict__.pop("num_buckets")
-    max_y = args.__dict__.pop("max_y")
-    min_y = args.__dict__.pop("min_y")
-    # criterion = nn.MSELoss(reduction='none')
-
-    device = args.gpu_device if torch.cuda.is_available() else "cpu:0"
-
-    def get_y_sample():
-        args.__dict__["extra_prior_kwargs_dict"]["eval_pos_seq_len_sampler"] = lambda: (
-            args.bptt,
-            args.bptt,
-        )
-        dl = prior(
-            num_steps=1,
-            batch_size=args.batch_size * args.steps_per_epoch,
-            seq_len=args.bptt,
-            device=device,
-            **args.extra_prior_kwargs_dict,
-        )
-        args.__dict__["extra_prior_kwargs_dict"].pop("eval_pos_seq_len_sampler")
-
-        y_sample = next(iter(dl))[-2]
-        print(
-            f"Creating Bar distribution with borders from y sample of size {y_sample.numel()}"
-        )
-        return y_sample
-
-    if loss_function == "ce":
-        criterion = nn.CrossEntropyLoss(reduction="none")
-    elif loss_function == "gaussnll":
-        criterion = nn.GaussianNLLLoss(reduction="none", full=True)
-    elif loss_function == "mse":
-        criterion = nn.MSELoss(reduction="none")
-    elif loss_function == "barnll":
-        criterion = BarDistribution(
-            borders=get_bucket_limits(num_buckets, full_range=(min_y, max_y))
-        )
-    elif loss_function == "adaptivebarnll":
-        borders = get_bucket_limits(
-            num_buckets, ys=get_y_sample(), full_range=(min_y, max_y)
-        )
-        criterion = BarDistribution(borders=borders)
-    elif loss_function == "adaptivefullsupportbarnll":
-        assert (
-            min_y is None and max_y is None
-        ), "Please do not specify `min_y` and `max_y` with `unboundedadaptivebarnll`."
-        borders = get_bucket_limits(num_buckets, ys=get_y_sample())
-        criterion = FullSupportBarDistribution(borders=borders)
-    else:
-        raise NotImplementedError(f"loss_function == {loss_function}.")
-
-    encoder = args.__dict__.pop("encoder")
-    y_encoder = args.__dict__.pop("y_encoder")
-
-    def get_encoder_generator(encoder):
-        if encoder == "linear":
-            encoder_generator = encoders.Linear
-        elif encoder == "mlp":
-            encoder_generator = encoders.MLP
-        elif encoder == "positional":
-            encoder_generator = encoders.Positional
-        else:
-            raise NotImplementedError(f"A {encoder} encoder is not valid.")
-        return encoder_generator
-
-    encoder_generator = get_encoder_generator(encoder)
-    y_encoder_generator = get_encoder_generator(y_encoder)
-
-    pos_encoder = args.__dict__.pop("pos_encoder")
-
-    if pos_encoder == "none":
-        pos_encoder_generator = None
-    elif pos_encoder == "sinus":
-        pos_encoder_generator = positional_encodings.PositionalEncoding
-    elif pos_encoder == "learned":
-        pos_encoder_generator = positional_encodings.LearnedPositionalEncoding
-    elif pos_encoder == "paired_scrambled_learned":
-        pos_encoder_generator = positional_encodings.PairedScrambledPositionalEncodings
-    else:
-        raise NotImplementedError(f"pos_encoer == {pos_encoder} is not valid.")
-
-    permutation_invariant_max_eval_pos = args.__dict__.pop(
-        "permutation_invariant_max_eval_pos"
-    )
-    permutation_invariant_sampling = args.__dict__.pop("permutation_invariant_sampling")
-    if permutation_invariant_max_eval_pos is not None:
-        if permutation_invariant_sampling == "weighted":
-            get_sampler = get_weighted_single_eval_pos_sampler
-        elif permutation_invariant_sampling == "uniform":
-            get_sampler = get_uniform_single_eval_pos_sampler
-        else:
-            raise ValueError()
-        args.__dict__["single_eval_pos_gen"] = get_sampler(
-            permutation_invariant_max_eval_pos
-        )
-
-    print("ARGS for `train`:", args.__dict__)
-
-    if args.__dict__["checkpoint_file"] is not None:
-        rootdir = os.path.dirname(os.path.realpath(__file__))
-        args.__dict__["checkpoint_file"] = os.path.join(
-            rootdir, args.__dict__["checkpoint_file"]
-        )
-
-        if os.path.exists(args.__dict__["checkpoint_file"]):
-            state_dicts = torch.load(args.__dict__["checkpoint_file"])
-            args.__dict__["load_weights_from_this_state_dict"] = state_dicts[
-                "model_state_dict"
-            ]
-            args.__dict__["load_optimizer_from_this_state_dict"] = state_dicts[
589
- "optimizer_state_dict"
590
- ]
591
- else:
592
- args.__dict__["load_weights_from_this_state_dict"] = None
593
- args.__dict__["load_optimizer_from_this_state_dict"] = None
594
-
595
- train(
596
- prior,
597
- criterion,
598
- encoder_generator,
599
- y_encoder_generator=y_encoder_generator,
600
- pos_encoder_generator=pos_encoder_generator,
601
- **args.__dict__,
602
- )
 
 
1
  import itertools
 
2
  import time
 
 
3
  from contextlib import nullcontext
4
 
 
5
  import torch
6
  from torch import nn
7
 
 
9
  from lcpfn.transformer import TransformerModel
10
  from lcpfn.bar_distribution import (
11
  BarDistribution,
 
 
12
  )
13
  from lcpfn.utils import (
14
  get_cosine_schedule_with_warmup,
15
  get_openai_lr,
 
 
 
16
  )
 
 
17
  from lcpfn import positional_encodings
18
  from lcpfn.utils import init_dist
19
  from torch.cuda.amp import autocast, GradScaler
 
282
  list_losses = []
283
  try:
284
  for epoch in range(1, epochs + 1) if epochs is not None else itertools.count(1):
 
285
  epoch_start_time = time.time()
286
  (
287
  total_loss,
 
334
  torch.save(model.to("cpu"), output_path)
335
  print("Checkpoint stored at ", output_path)
336
  return total_loss, total_positional_losses, model.to("cpu"), dl
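
Note: the `adaptivebarnll` branch above first draws a sample of targets from the prior and then places bucket borders over it. A minimal, self-contained sketch of that flow, assuming `get_bucket_limits` keeps the call signature used above (first positional argument is the number of buckets):

    import torch
    from lcpfn.bar_distribution import BarDistribution, get_bucket_limits

    # stand-in for a sample of prior targets (above: y_sample = next(iter(dl))[-2])
    y_sample = torch.rand(100_000)

    # borders adapted to the empirical distribution of y_sample, restricted to [0, 1]
    borders = get_bucket_limits(1000, ys=y_sample, full_range=(0.0, 1.0))
    criterion = BarDistribution(borders=borders)  # 1001 borders delimit 1000 buckets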
 
lcpfn/train_lcpfn.py CHANGED
@@ -2,9 +2,11 @@ import math
2
 
3
  from torch import nn
4
 
5
- from lcpfn import bar_distribution, encoders, priors, train
6
  from lcpfn import utils
7
 
 
 
8
 
9
  def train_lcpfn(
10
  get_batch_func,
@@ -12,7 +14,7 @@ def train_lcpfn(
12
  emsize: int = 512,
13
  nlayers: int = 12,
14
  num_borders: int = 1000,
15
- lr: float = 0.001,
16
  batch_size: int = 100,
17
  epochs: int = 1000,
18
  ):
@@ -25,7 +27,7 @@ def train_lcpfn(
25
  emsize (int, optional): The size of the embedding layer. Defaults to 512.
26
  nlayers (int, optional): The number of layers in the model. Defaults to 12.
27
  num_borders (int, optional): The number of borders to use. Defaults to 1000.
28
- lr (float, optional): The learning rate for the optimizer. Defaults to 0.001.
29
  batch_size (int, optional): The batch size for training. Defaults to 100.
30
  epochs (int, optional): The number of epochs to train for. Defaults to 1000.
31
 
@@ -36,7 +38,7 @@ def train_lcpfn(
36
  hps = {}
37
 
38
  # PFN training hyperparameters
39
- dataloader = priors.utils.get_batch_to_dataloader(get_batch_func) # type: ignore
40
 
41
  num_features = 1
42
 
@@ -82,7 +84,9 @@ def train_lcpfn(
82
  epochs=epochs,
83
  lr=lr,
84
  bptt=seq_len,
85
- single_eval_pos_gen=utils.get_uniform_single_eval_pos_sampler(seq_len, min_len=1),
 
 
86
  aggregate_k_gradients=1,
87
  nhid=(emsize * 2),
88
  steps_per_epoch=100,
 
2
 
3
  from torch import nn
4
 
5
+ from lcpfn import bar_distribution, encoders, train
6
  from lcpfn import utils
7
 
8
+ from lcpfn.priors import utils as putils
9
+
10
 
11
  def train_lcpfn(
12
  get_batch_func,
 
14
  emsize: int = 512,
15
  nlayers: int = 12,
16
  num_borders: int = 1000,
17
+ lr: float = 0.0001,
18
  batch_size: int = 100,
19
  epochs: int = 1000,
20
  ):
 
27
  emsize (int, optional): The size of the embedding layer. Defaults to 512.
28
  nlayers (int, optional): The number of layers in the model. Defaults to 12.
29
  num_borders (int, optional): The number of borders to use. Defaults to 1000.
30
+ lr (float, optional): The learning rate for the optimizer. Defaults to 0.0001.
31
  batch_size (int, optional): The batch size for training. Defaults to 100.
32
  epochs (int, optional): The number of epochs to train for. Defaults to 1000.
33
 
 
38
  hps = {}
39
 
40
  # PFN training hyperparameters
41
+ dataloader = putils.get_batch_to_dataloader(get_batch_func) # type: ignore
42
 
43
  num_features = 1
44
 
 
84
  epochs=epochs,
85
  lr=lr,
86
  bptt=seq_len,
87
+ single_eval_pos_gen=utils.get_uniform_single_eval_pos_sampler(
88
+ seq_len, min_len=1
89
+ ),
90
  aggregate_k_gradients=1,
91
  nhid=(emsize * 2),
92
  steps_per_epoch=100,
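
Note: `train_lcpfn` only needs a `get_batch_func`; everything else has defaults. A rough usage sketch with a toy stand-in prior — the (x, y, target_y) return convention follows the PFN `get_batch` interface wrapped by `get_batch_to_dataloader` above, while the exact keyword arguments passed through are an assumption:

    import torch
    from lcpfn.train_lcpfn import train_lcpfn

    def get_batch_func(batch_size, seq_len, num_features, device="cpu", **kwargs):
        # toy curve prior: saturating exponentials with a little noise
        t = torch.arange(1, seq_len + 1, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        rate = 0.5 * torch.rand(batch_size)                               # one rate per curve
        y = 1 - torch.exp(-rate.unsqueeze(0) * t)                         # (seq_len, batch_size)
        y = (y + 0.01 * torch.randn_like(y)).clamp(0.0, 1.0)
        x = t.unsqueeze(-1).repeat(1, batch_size, 1)                      # (seq_len, batch_size, 1)
        return x.to(device), y.to(device), y.to(device)

    # tiny run for illustration; real training keeps the defaults above
    result = train_lcpfn(get_batch_func, emsize=256, nlayers=3, num_borders=100, epochs=2)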
lcpfn/transformer.py CHANGED
@@ -4,35 +4,74 @@ from typing import Optional
4
  import torch
5
  import torch.nn as nn
6
  from torch import Tensor
 
7
  from torch.nn import Module, TransformerEncoder
8
 
9
  from lcpfn.layer import TransformerEncoderLayer, _get_activation_fn
10
  from lcpfn.utils import SeqBN, bool_mask_to_att_mask
11
 
12
 
 
 
 
 
13
 
14
  class TransformerModel(nn.Module):
15
- def __init__(self, encoder, n_out, ninp, nhead, nhid, nlayers, dropout=0.0, style_encoder=None, y_encoder=None,
16
- pos_encoder=None, decoder=None, input_normalization=False, init_method=None, pre_norm=False,
17
- activation='gelu', recompute_attn=False, num_global_att_tokens=0, full_attention=False,
18
- all_layers_same_init=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  super().__init__()
20
- self.model_type = 'Transformer'
21
- encoder_layer_creator = lambda: TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation=activation,
22
- pre_norm=pre_norm, recompute_attn=recompute_attn)
23
- self.transformer_encoder = TransformerEncoder(encoder_layer_creator(), nlayers)\
24
- if all_layers_same_init else TransformerEncoderDiffInit(encoder_layer_creator, nlayers)
 
 
 
 
 
 
 
 
 
 
25
  self.ninp = ninp
26
  self.encoder = encoder
27
  self.y_encoder = y_encoder
28
  self.pos_encoder = pos_encoder
29
- self.decoder = decoder(ninp, nhid, n_out) if decoder is not None else nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out))
 
 
 
 
30
  self.input_ln = SeqBN(ninp) if input_normalization else None
31
  self.style_encoder = style_encoder
32
  self.init_method = init_method
33
  if num_global_att_tokens is not None:
34
  assert not full_attention
35
- self.global_att_embeddings = nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
 
 
36
  self.full_attention = full_attention
37
 
38
  self.n_out = n_out
@@ -47,37 +86,49 @@ class TransformerModel(nn.Module):
47
 
48
  @staticmethod
49
  def generate_D_q_matrix(sz, query_size):
50
- train_size = sz-query_size
51
- mask = torch.zeros(sz,sz) == 0
52
- mask[:,train_size:].zero_()
53
  mask |= torch.eye(sz) == 1
54
  return bool_mask_to_att_mask(mask)
55
 
56
  @staticmethod
57
- def generate_global_att_query_matrix(num_global_att_tokens, seq_len, num_query_tokens):
 
 
58
  train_size = seq_len + num_global_att_tokens - num_query_tokens
59
  sz = seq_len + num_global_att_tokens
60
  mask = torch.zeros(num_query_tokens, sz) == 0
61
- mask[:,train_size:].zero_()
62
- mask[:,train_size:] |= torch.eye(num_query_tokens) == 1
63
  return bool_mask_to_att_mask(mask)
64
 
65
  @staticmethod
66
- def generate_global_att_trainset_matrix(num_global_att_tokens, seq_len, num_query_tokens):
 
 
67
  train_size = seq_len + num_global_att_tokens - num_query_tokens
68
  trainset_size = seq_len - num_query_tokens
69
  mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
70
- #mask[:,num_global_att_tokens:].zero_()
71
- #mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1
72
  return bool_mask_to_att_mask(mask)
73
 
74
  @staticmethod
75
- def generate_global_att_globaltokens_matrix(num_global_att_tokens, seq_len, num_query_tokens):
76
- mask = torch.zeros(num_global_att_tokens, num_global_att_tokens+seq_len-num_query_tokens) == 0
 
 
 
 
 
 
 
 
77
  return bool_mask_to_att_mask(mask)
78
 
79
  def init_weights(self):
80
- initrange = 1.
81
  # if isinstance(self.encoder,EmbeddingEncoder):
82
  # self.encoder.weight.data.uniform_(-initrange, initrange)
83
  # self.decoder.bias.data.zero_()
@@ -87,41 +138,74 @@ class TransformerModel(nn.Module):
87
  for layer in self.transformer_encoder.layers:
88
  nn.init.zeros_(layer.linear2.weight)
89
  nn.init.zeros_(layer.linear2.bias)
90
- attns = layer.self_attn if isinstance(layer.self_attn, nn.ModuleList) else [layer.self_attn]
 
 
 
 
91
  for attn in attns:
92
  nn.init.zeros_(attn.out_proj.weight)
93
  nn.init.zeros_(attn.out_proj.bias)
94
 
95
  def forward(self, src, src_mask=None, single_eval_pos=None):
96
- assert isinstance(src, tuple), 'inputs (src) have to be given as (x,y) or (style,x,y) tuple'
 
 
97
 
98
- if len(src) == 2: # (x,y) and no style
99
  src = (None,) + src
100
 
101
  style_src, style_src_size = (src[0], (0 if (src[0] is None) else 1))
102
- if src_mask is not None: assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
 
103
  if src_mask is None:
104
  x_src = src[1]
105
  if self.global_att_embeddings is None:
106
  full_len = len(x_src) + style_src_size
107
  if self.full_attention:
108
- src_mask = bool_mask_to_att_mask(torch.ones((full_len, full_len), dtype=torch.bool)).to(x_src.device)
 
 
109
  else:
110
- src_mask = self.generate_D_q_matrix(len(x_src) + style_src_size, len(x_src) + style_src_size -single_eval_pos).to(x_src.device)
 
 
 
111
  else:
112
- src_mask_args = (self.global_att_embeddings.num_embeddings,
113
- len(x_src) + style_src_size,
114
- len(x_src) + style_src_size - single_eval_pos)
115
- src_mask = (self.generate_global_att_globaltokens_matrix(*src_mask_args).to(x_src.device),
116
- self.generate_global_att_trainset_matrix(*src_mask_args).to(x_src.device),
117
- self.generate_global_att_query_matrix(*src_mask_args).to(x_src.device))
 
 
 
 
 
 
 
 
 
 
118
 
119
  style_src, x_src, y_src = src
120
  x_src = self.encoder(x_src)
121
- y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
122
- style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else torch.tensor([], device=x_src.device)
123
- global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
124
- self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)
 
 
 
 
 
 
 
 
 
 
 
125
  train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
126
  src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
127
 
@@ -134,16 +218,29 @@ class TransformerModel(nn.Module):
134
  # If we have style input, drop its output
135
  output = self.transformer_encoder(src, src_mask)[style_src_size:]
136
  output = self.decoder(output)
137
- return output[single_eval_pos+(self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]
 
 
 
 
 
 
 
138
 
139
  @torch.no_grad()
140
  def init_from_small_model(self, small_model):
141
- assert isinstance(self.decoder, nn.Linear) and isinstance(self.encoder, (nn.Linear, nn.Sequential)) \
142
- and isinstance(self.y_encoder, (nn.Linear, nn.Sequential))
 
 
 
143
 
144
  def set_encoder_weights(my_encoder, small_model_encoder):
145
- my_encoder_linear, small_encoder_linear = (my_encoder, small_model_encoder) \
146
- if isinstance(my_encoder, nn.Linear) else (my_encoder[-1], small_model_encoder[-1])
 
 
 
147
  small_in_dim = small_encoder_linear.out_features
148
  my_encoder_linear.weight.zero_()
149
  my_encoder_linear.bias.zero_()
@@ -158,7 +255,9 @@ class TransformerModel(nn.Module):
158
  self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
159
  self.decoder.bias = small_model.decoder.bias
160
 
161
- for my_layer, small_layer in zip(self.transformer_encoder.layers, small_model.transformer_encoder.layers):
 
 
162
  small_hid_dim = small_layer.linear1.out_features
163
  my_in_dim = my_layer.linear1.in_features
164
 
@@ -166,23 +265,36 @@ class TransformerModel(nn.Module):
166
  my_in_proj_w = my_layer.self_attn.in_proj_weight
167
  small_in_proj_w = small_layer.self_attn.in_proj_weight
168
 
169
- my_in_proj_w.view(3, my_in_dim, my_in_dim)[:, :small_in_dim, :small_in_dim] = small_in_proj_w.view(3,
170
- small_in_dim,
171
- small_in_dim)
172
- my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:,
173
- :small_in_dim] = small_layer.self_attn.in_proj_bias.view(3, small_in_dim)
174
-
175
- my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = small_layer.self_attn.out_proj.weight
176
- my_layer.self_attn.out_proj.bias[:small_in_dim] = small_layer.self_attn.out_proj.bias
177
-
178
- my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = small_layer.linear1.weight
 
 
 
 
 
 
 
179
  my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias
180
 
181
- my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = small_layer.linear2.weight
 
 
182
  my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias
183
 
184
- my_layer.norm1.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight
185
- my_layer.norm2.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight
 
 
 
 
186
 
187
  my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
188
  my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias
@@ -196,15 +308,23 @@ class TransformerEncoderDiffInit(Module):
196
  num_layers: the number of sub-encoder-layers in the encoder (required).
197
  norm: the layer normalization component (optional).
198
  """
199
- __constants__ = ['norm']
 
200
 
201
  def __init__(self, encoder_layer_creator, num_layers, norm=None):
202
  super().__init__()
203
- self.layers = nn.ModuleList([encoder_layer_creator() for _ in range(num_layers)])
 
 
204
  self.num_layers = num_layers
205
  self.norm = norm
206
 
207
- def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
 
 
 
 
 
208
  r"""Pass the input through the encoder layers in turn.
209
 
210
  Args:
@@ -218,7 +338,9 @@ class TransformerEncoderDiffInit(Module):
218
  output = src
219
 
220
  for mod in self.layers:
221
- output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
 
 
222
 
223
  if self.norm is not None:
224
  output = self.norm(output)
 
4
  import torch
5
  import torch.nn as nn
6
  from torch import Tensor
7
+ import torch.nn.functional as F
8
  from torch.nn import Module, TransformerEncoder
9
 
10
  from lcpfn.layer import TransformerEncoderLayer, _get_activation_fn
11
  from lcpfn.utils import SeqBN, bool_mask_to_att_mask
12
 
13
 
14
+ class GELU(nn.Module):
15
+ def forward(self, input: Tensor) -> Tensor:
16
+ return F.gelu(input)
17
+
18
 
19
  class TransformerModel(nn.Module):
20
+ def __init__(
21
+ self,
22
+ encoder,
23
+ n_out,
24
+ ninp,
25
+ nhead,
26
+ nhid,
27
+ nlayers,
28
+ dropout=0.0,
29
+ style_encoder=None,
30
+ y_encoder=None,
31
+ pos_encoder=None,
32
+ decoder=None,
33
+ input_normalization=False,
34
+ init_method=None,
35
+ pre_norm=False,
36
+ activation="gelu",
37
+ recompute_attn=False,
38
+ num_global_att_tokens=0,
39
+ full_attention=False,
40
+ all_layers_same_init=True,
41
+ ):
42
  super().__init__()
43
+ self.model_type = "Transformer"
44
+ encoder_layer_creator = lambda: TransformerEncoderLayer(
45
+ ninp,
46
+ nhead,
47
+ nhid,
48
+ dropout,
49
+ activation=activation,
50
+ pre_norm=pre_norm,
51
+ recompute_attn=recompute_attn,
52
+ )
53
+ self.transformer_encoder = (
54
+ TransformerEncoder(encoder_layer_creator(), nlayers)
55
+ if all_layers_same_init
56
+ else TransformerEncoderDiffInit(encoder_layer_creator, nlayers)
57
+ )
58
  self.ninp = ninp
59
  self.encoder = encoder
60
  self.y_encoder = y_encoder
61
  self.pos_encoder = pos_encoder
62
+ self.decoder = (
63
+ decoder(ninp, nhid, n_out)
64
+ if decoder is not None
65
+ else nn.Sequential(nn.Linear(ninp, nhid), GELU(), nn.Linear(nhid, n_out))
66
+ )
67
  self.input_ln = SeqBN(ninp) if input_normalization else None
68
  self.style_encoder = style_encoder
69
  self.init_method = init_method
70
  if num_global_att_tokens is not None:
71
  assert not full_attention
72
+ self.global_att_embeddings = (
73
+ nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
74
+ )
75
  self.full_attention = full_attention
76
 
77
  self.n_out = n_out
 
86
 
87
  @staticmethod
88
  def generate_D_q_matrix(sz, query_size):
89
+ train_size = sz - query_size
90
+ mask = torch.zeros(sz, sz) == 0
91
+ mask[:, train_size:].zero_()
92
  mask |= torch.eye(sz) == 1
93
  return bool_mask_to_att_mask(mask)
94
 
95
  @staticmethod
96
+ def generate_global_att_query_matrix(
97
+ num_global_att_tokens, seq_len, num_query_tokens
98
+ ):
99
  train_size = seq_len + num_global_att_tokens - num_query_tokens
100
  sz = seq_len + num_global_att_tokens
101
  mask = torch.zeros(num_query_tokens, sz) == 0
102
+ mask[:, train_size:].zero_()
103
+ mask[:, train_size:] |= torch.eye(num_query_tokens) == 1
104
  return bool_mask_to_att_mask(mask)
105
 
106
  @staticmethod
107
+ def generate_global_att_trainset_matrix(
108
+ num_global_att_tokens, seq_len, num_query_tokens
109
+ ):
110
  train_size = seq_len + num_global_att_tokens - num_query_tokens
111
  trainset_size = seq_len - num_query_tokens
112
  mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
113
+ # mask[:,num_global_att_tokens:].zero_()
114
+ # mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1
115
  return bool_mask_to_att_mask(mask)
116
 
117
  @staticmethod
118
+ def generate_global_att_globaltokens_matrix(
119
+ num_global_att_tokens, seq_len, num_query_tokens
120
+ ):
121
+ mask = (
122
+ torch.zeros(
123
+ num_global_att_tokens,
124
+ num_global_att_tokens + seq_len - num_query_tokens,
125
+ )
126
+ == 0
127
+ )
128
  return bool_mask_to_att_mask(mask)
129
 
130
  def init_weights(self):
131
+ initrange = 1.0
132
  # if isinstance(self.encoder,EmbeddingEncoder):
133
  # self.encoder.weight.data.uniform_(-initrange, initrange)
134
  # self.decoder.bias.data.zero_()
 
138
  for layer in self.transformer_encoder.layers:
139
  nn.init.zeros_(layer.linear2.weight)
140
  nn.init.zeros_(layer.linear2.bias)
141
+ attns = (
142
+ layer.self_attn
143
+ if isinstance(layer.self_attn, nn.ModuleList)
144
+ else [layer.self_attn]
145
+ )
146
  for attn in attns:
147
  nn.init.zeros_(attn.out_proj.weight)
148
  nn.init.zeros_(attn.out_proj.bias)
149
 
150
  def forward(self, src, src_mask=None, single_eval_pos=None):
151
+ assert isinstance(
152
+ src, tuple
153
+ ), "inputs (src) have to be given as (x,y) or (style,x,y) tuple"
154
 
155
+ if len(src) == 2: # (x,y) and no style
156
  src = (None,) + src
157
 
158
  style_src, style_src_size = (src[0], (0 if (src[0] is None) else 1))
159
+ if src_mask is not None:
160
+ assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
161
  if src_mask is None:
162
  x_src = src[1]
163
  if self.global_att_embeddings is None:
164
  full_len = len(x_src) + style_src_size
165
  if self.full_attention:
166
+ src_mask = bool_mask_to_att_mask(
167
+ torch.ones((full_len, full_len), dtype=torch.bool)
168
+ ).to(x_src.device)
169
  else:
170
+ src_mask = self.generate_D_q_matrix(
171
+ len(x_src) + style_src_size,
172
+ len(x_src) + style_src_size - single_eval_pos,
173
+ ).to(x_src.device)
174
  else:
175
+ src_mask_args = (
176
+ self.global_att_embeddings.num_embeddings,
177
+ len(x_src) + style_src_size,
178
+ len(x_src) + style_src_size - single_eval_pos,
179
+ )
180
+ src_mask = (
181
+ self.generate_global_att_globaltokens_matrix(*src_mask_args).to(
182
+ x_src.device
183
+ ),
184
+ self.generate_global_att_trainset_matrix(*src_mask_args).to(
185
+ x_src.device
186
+ ),
187
+ self.generate_global_att_query_matrix(*src_mask_args).to(
188
+ x_src.device
189
+ ),
190
+ )
191
 
192
  style_src, x_src, y_src = src
193
  x_src = self.encoder(x_src)
194
+ y_src = self.y_encoder(
195
+ y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src
196
+ )
197
+ style_src = (
198
+ self.style_encoder(style_src).unsqueeze(0)
199
+ if self.style_encoder
200
+ else torch.tensor([], device=x_src.device)
201
+ )
202
+ global_src = (
203
+ torch.tensor([], device=x_src.device)
204
+ if self.global_att_embeddings is None
205
+ else self.global_att_embeddings.weight.unsqueeze(1).repeat(
206
+ 1, x_src.shape[1], 1
207
+ )
208
+ )
209
  train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
210
  src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
211
 
 
218
  # If we have style input, drop its output
219
  output = self.transformer_encoder(src, src_mask)[style_src_size:]
220
  output = self.decoder(output)
221
+ return output[
222
+ single_eval_pos
223
+ + (
224
+ self.global_att_embeddings.num_embeddings
225
+ if self.global_att_embeddings
226
+ else 0
227
+ ) :
228
+ ]
229
 
230
  @torch.no_grad()
231
  def init_from_small_model(self, small_model):
232
+ assert (
233
+ isinstance(self.decoder, nn.Linear)
234
+ and isinstance(self.encoder, (nn.Linear, nn.Sequential))
235
+ and isinstance(self.y_encoder, (nn.Linear, nn.Sequential))
236
+ )
237
 
238
  def set_encoder_weights(my_encoder, small_model_encoder):
239
+ my_encoder_linear, small_encoder_linear = (
240
+ (my_encoder, small_model_encoder)
241
+ if isinstance(my_encoder, nn.Linear)
242
+ else (my_encoder[-1], small_model_encoder[-1])
243
+ )
244
  small_in_dim = small_encoder_linear.out_features
245
  my_encoder_linear.weight.zero_()
246
  my_encoder_linear.bias.zero_()
 
255
  self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
256
  self.decoder.bias = small_model.decoder.bias
257
 
258
+ for my_layer, small_layer in zip(
259
+ self.transformer_encoder.layers, small_model.transformer_encoder.layers
260
+ ):
261
  small_hid_dim = small_layer.linear1.out_features
262
  my_in_dim = my_layer.linear1.in_features
263
 
 
265
  my_in_proj_w = my_layer.self_attn.in_proj_weight
266
  small_in_proj_w = small_layer.self_attn.in_proj_weight
267
 
268
+ my_in_proj_w.view(3, my_in_dim, my_in_dim)[
269
+ :, :small_in_dim, :small_in_dim
270
+ ] = small_in_proj_w.view(3, small_in_dim, small_in_dim)
271
+ my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:, :small_in_dim] = (
272
+ small_layer.self_attn.in_proj_bias.view(3, small_in_dim)
273
+ )
274
+
275
+ my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = (
276
+ small_layer.self_attn.out_proj.weight
277
+ )
278
+ my_layer.self_attn.out_proj.bias[:small_in_dim] = (
279
+ small_layer.self_attn.out_proj.bias
280
+ )
281
+
282
+ my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = (
283
+ small_layer.linear1.weight
284
+ )
285
  my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias
286
 
287
+ my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = (
288
+ small_layer.linear2.weight
289
+ )
290
  my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias
291
 
292
+ my_layer.norm1.weight[:small_in_dim] = (
293
+ math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight
294
+ )
295
+ my_layer.norm2.weight[:small_in_dim] = (
296
+ math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight
297
+ )
298
 
299
  my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
300
  my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias
 
308
  num_layers: the number of sub-encoder-layers in the encoder (required).
309
  norm: the layer normalization component (optional).
310
  """
311
+
312
+ __constants__ = ["norm"]
313
 
314
  def __init__(self, encoder_layer_creator, num_layers, norm=None):
315
  super().__init__()
316
+ self.layers = nn.ModuleList(
317
+ [encoder_layer_creator() for _ in range(num_layers)]
318
+ )
319
  self.num_layers = num_layers
320
  self.norm = norm
321
 
322
+ def forward(
323
+ self,
324
+ src: Tensor,
325
+ mask: Optional[Tensor] = None,
326
+ src_key_padding_mask: Optional[Tensor] = None,
327
+ ) -> Tensor:
328
  r"""Pass the input through the encoder layers in turn.
329
 
330
  Args:
 
338
  output = src
339
 
340
  for mod in self.layers:
341
+ output = mod(
342
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
343
+ )
344
 
345
  if self.norm is not None:
346
  output = self.norm(output)
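
Note: with the constructor reformatted above, a hedged instantiation sketch may help. `encoders.Linear` is used here with the (num_features, emsize) signature implied by the encoder generators in train.py, and the (x, y) tuple plus `single_eval_pos` follow the contract asserted in `forward`:

    import torch
    from lcpfn import encoders
    from lcpfn.transformer import TransformerModel

    emsize, n_out = 512, 1000
    model = TransformerModel(
        encoder=encoders.Linear(1, emsize),    # encodes x (1 feature) to emsize
        n_out=n_out,                           # e.g. one logit per BarDistribution bucket
        ninp=emsize,
        nhead=4,
        nhid=2 * emsize,
        nlayers=6,
        y_encoder=encoders.Linear(1, emsize),  # y is unsqueezed to (..., 1) in forward
    )
    x = torch.rand(20, 8, 1)                   # (seq_len, batch, features)
    y = torch.rand(20, 8)
    out = model((x, y), single_eval_pos=10)    # predicts positions 10..19
    print(out.shape)                           # expected: (10, 8, n_out)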
lcpfn/utils.py CHANGED
@@ -9,9 +9,12 @@ from torch import nn
9
  from torch.optim.lr_scheduler import LambdaLR
10
  import numpy as np
11
 
 
12
  # copied from huggingface
13
- def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
14
- """ Create a schedule with a learning rate that decreases following the
 
 
15
  values of the cosine function between 0 and `pi * cycles` after a warmup
16
  period during which it increases linearly between 0 and 1.
17
  """
@@ -19,13 +22,20 @@ def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_st
19
  def lr_lambda(current_step):
20
  if current_step < num_warmup_steps:
21
  return float(current_step) / float(max(1, num_warmup_steps))
22
- progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
23
- return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
 
 
 
 
24
 
25
  return LambdaLR(optimizer, lr_lambda, last_epoch)
26
 
 
27
  # copied from huggingface
28
- def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
 
 
29
  """
30
  Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
31
  a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
@@ -48,7 +58,9 @@ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_st
48
  if current_step < num_warmup_steps:
49
  return float(current_step) / float(max(1, num_warmup_steps))
50
  return max(
51
- 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
 
 
52
  )
53
 
54
  return LambdaLR(optimizer, lr_lambda, last_epoch)
@@ -65,7 +77,9 @@ def get_weighted_single_eval_pos_sampler(max_len):
65
  where p <= `max_len`. At most `max_len` - 1 examples are shown to the Transformer.
66
  :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
67
  """
68
- return lambda: random.choices(range(max_len), [1 / (max_len - i) for i in range(max_len)])[0]
 
 
69
 
70
 
71
  def get_uniform_single_eval_pos_sampler(max_len, min_len=0):
@@ -95,19 +109,22 @@ def set_locals_in_self(locals):
95
  Especially useful right at the beginning of `__init__`.
96
  :param locals: `locals()`
97
  """
98
- self = locals['self']
99
  for var_name, val in locals.items():
100
- if var_name != 'self': setattr(self, var_name, val)
 
101
 
102
 
103
- default_device = 'cuda:0' if torch.cuda.is_available() else 'cpu:0'
104
 
105
 
106
  # Copied from StackOverflow, but we do an eval on the values additionally
107
  class StoreDictKeyPair(argparse.Action):
108
  def __init__(self, option_strings, dest, nargs=None, **kwargs):
109
  self._nargs = nargs
110
- super(StoreDictKeyPair, self).__init__(option_strings, dest, nargs=nargs, **kwargs)
 
 
111
 
112
  def __call__(self, parser, namespace, values, option_string=None):
113
  my_dict = {}
@@ -120,16 +137,20 @@ class StoreDictKeyPair(argparse.Action):
120
  setattr(namespace, self.dest, my_dict)
121
  print("dict values: {}".format(my_dict))
122
 
 
123
  def get_nan_value(v, set_value_to_nan=0.0):
124
  if random.random() < set_value_to_nan:
125
  return v
126
  else:
127
  return random.choice([-999, 0, 1, 999])
128
 
 
129
  def to_ranking(data):
130
- x = (data >= data.unsqueeze(-3))
131
  x = x.sum(0)
132
  return x
 
 
133
  # TODO: Is there a better way to do this?
134
  # 1. Cmparing to unique elements: When all values are different we still get quadratic blowup
135
  # 2. Argsort(Argsort()) returns ranking, but with duplicate values there is an ordering which is problematic
@@ -137,49 +158,64 @@ def to_ranking(data):
137
  def to_ranking_low_mem(data):
138
  x = torch.zeros_like(data)
139
  for col in range(data.shape[-1]):
140
- x_ = (data[:, :, col] >= data[:, :, col].unsqueeze(-2))
141
  x_ = x_.sum(0)
142
  x[:, :, col] = x_
143
  return x
144
 
 
145
  def nan_handling_missing_for_unknown_reason_value(set_value_to_nan=0.0):
146
- return get_nan_value(float('nan'), set_value_to_nan)
 
147
 
148
  def nan_handling_missing_for_no_reason_value(set_value_to_nan=0.0):
149
- return get_nan_value(float('-inf'), set_value_to_nan)
 
150
 
151
  def nan_handling_missing_for_a_reason_value(set_value_to_nan=0.0):
152
- return get_nan_value(float('inf'), set_value_to_nan)
 
153
 
154
  def torch_nanmean(x, axis=0):
155
- num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(axis=axis)
 
 
156
  value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
157
  return value / num
158
 
 
159
  def torch_nanstd(x, axis=0):
160
- num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(axis=axis)
 
 
161
  value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
162
  mean = value / num
163
- mean_broadcast = torch.repeat_interleave(mean.unsqueeze(axis), x.shape[axis], dim=axis)
164
- return torch.sqrt(torch.nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1))
 
 
 
 
 
165
 
166
  def normalize_data(data, normalize_positions=-1):
167
  if normalize_positions > 0:
168
  mean = torch_nanmean(data[:normalize_positions], axis=0)
169
- std = torch_nanstd(data[:normalize_positions], axis=0) + .000001
170
  else:
171
  mean = torch_nanmean(data, axis=0)
172
- std = torch_nanstd(data, axis=0) + .000001
173
  data = (data - mean) / std
174
  data = torch.clip(data, min=-100, max=100)
175
 
176
  return data
177
 
 
178
  def remove_outliers(X, n_sigma=4):
179
  # Expects T, B, H
180
  assert len(X.shape) == 3, "X must be T,B,H"
181
- #for b in range(X.shape[1]):
182
- #for col in range(X.shape[2]):
183
  data = X
184
  data_mean, data_std = torch_nanmean(data, axis=0), torch_nanstd(data, axis=0)
185
  cut_off = data_std * n_sigma
@@ -187,17 +223,26 @@ def remove_outliers(X, n_sigma=4):
187
 
188
  data_clean = X[:].clone()
189
  data_clean[torch.logical_or(data > upper, data < lower)] = np.nan
190
- data_mean, data_std = torch_nanmean(data_clean, axis=0), torch_nanstd(data_clean, axis=0)
 
 
 
191
  cut_off = data_std * n_sigma
192
  lower, upper = data_mean - cut_off, data_mean + cut_off
193
 
194
- X = torch.maximum(-torch.log(1+torch.abs(X)) + lower, X)
195
- X = torch.minimum(torch.log(1+torch.abs(X)) + upper, X)
196
- # print(ds[1][data < lower, col], ds[1][data > upper, col], ds[1][~np.isnan(data), col].shape, data_mean, data_std)
197
  return X
198
 
 
199
  def bool_mask_to_att_mask(mask):
200
- return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
 
 
 
 
 
201
 
202
  def print_on_master_only(is_master):
203
  import builtins as __builtin__
@@ -213,46 +258,152 @@ def print_on_master_only(is_master):
213
 
214
 
215
  def init_dist(device):
216
- print('init dist')
217
- if 'LOCAL_RANK' in os.environ:
218
  # launched with torch.distributed.launch
219
  rank = int(os.environ["LOCAL_RANK"])
220
- print('torch.distributed.launch and my rank is', rank)
221
  torch.cuda.set_device(rank)
222
- os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
223
- torch.distributed.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=20),
224
- world_size=torch.cuda.device_count(), rank=rank)
 
 
 
 
 
225
  torch.distributed.barrier()
226
  print_on_master_only(rank == 0)
227
- print(f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
228
- "only I can print, but when using print(..., force=True) it will print on all ranks.")
229
- return True, rank, f'cuda:{rank}'
230
- elif 'SLURM_PROCID' in os.environ and torch.cuda.device_count() > 1:
 
 
231
  # this is for multi gpu when starting with submitit
232
- assert device != 'cpu:0'
233
- rank = int(os.environ['SLURM_PROCID'])
234
- os.environ['MASTER_ADDR'] = 'localhost'
235
- os.environ['MASTER_PORT'] = '12355'
236
  torch.cuda.set_device(rank)
237
- os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
238
- print('distributed submitit launch and my rank is', rank)
239
- torch.distributed.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=20),
240
- world_size=torch.cuda.device_count(), rank=rank)
 
 
 
 
 
241
  torch.distributed.barrier()
242
  print_on_master_only(rank == 0)
243
- print(f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
244
- "only I can print, but when using print(..., force=True) it will print on all ranks.")
 
 
245
 
246
- return True, rank, f'cuda:{rank}'
247
  else:
248
- print('Not using distributed')
249
  # will not change any of the behavior of print, but allows putting the force=True in the print calls
250
  print_on_master_only(True)
251
  return False, 0, device
252
 
253
 
254
  def check_compatibility(dl):
255
- if hasattr(dl, 'num_outputs'):
256
- print('`num_outputs` for the DataLoader is deprecated. It is assumed to be 1 from now on.')
257
- assert dl.num_outputs != 1, "We assume num_outputs to be 1. Instead of the num_ouputs change your loss." \
258
- "We specify the number of classes in the CE loss."
 
 
 
 
 
 
 
 
 
 
 
 
9
  from torch.optim.lr_scheduler import LambdaLR
10
  import numpy as np
11
 
12
+
13
  # copied from huggingface
14
+ def get_cosine_schedule_with_warmup(
15
+ optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
16
+ ):
17
+ """Create a schedule with a learning rate that decreases following the
18
  values of the cosine function between 0 and `pi * cycles` after a warmup
19
  period during which it increases linearly between 0 and 1.
20
  """
 
22
  def lr_lambda(current_step):
23
  if current_step < num_warmup_steps:
24
  return float(current_step) / float(max(1, num_warmup_steps))
25
+ progress = float(current_step - num_warmup_steps) / float(
26
+ max(1, num_training_steps - num_warmup_steps)
27
+ )
28
+ return max(
29
+ 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
30
+ )
31
 
32
  return LambdaLR(optimizer, lr_lambda, last_epoch)
33
 
34
+
35
  # copied from huggingface
36
+ def get_linear_schedule_with_warmup(
37
+ optimizer, num_warmup_steps, num_training_steps, last_epoch=-1
38
+ ):
39
  """
40
  Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
41
  a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
 
58
  if current_step < num_warmup_steps:
59
  return float(current_step) / float(max(1, num_warmup_steps))
60
  return max(
61
+ 0.0,
62
+ float(num_training_steps - current_step)
63
+ / float(max(1, num_training_steps - num_warmup_steps)),
64
  )
65
 
66
  return LambdaLR(optimizer, lr_lambda, last_epoch)
 
77
  where p <= `max_len`. At most `max_len` - 1 examples are shown to the Transformer.
78
  :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
79
  """
80
+ return lambda: random.choices(
81
+ range(max_len), [1 / (max_len - i) for i in range(max_len)]
82
+ )[0]
83
 
84
 
85
  def get_uniform_single_eval_pos_sampler(max_len, min_len=0):
 
109
  Especially useful right at the beginning of `__init__`.
110
  :param locals: `locals()`
111
  """
112
+ self = locals["self"]
113
  for var_name, val in locals.items():
114
+ if var_name != "self":
115
+ setattr(self, var_name, val)
116
 
117
 
118
+ default_device = "cuda:0" if torch.cuda.is_available() else "cpu:0"
119
 
120
 
121
  # Copied from StackOverflow, but we do an eval on the values additionally
122
  class StoreDictKeyPair(argparse.Action):
123
  def __init__(self, option_strings, dest, nargs=None, **kwargs):
124
  self._nargs = nargs
125
+ super(StoreDictKeyPair, self).__init__(
126
+ option_strings, dest, nargs=nargs, **kwargs
127
+ )
128
 
129
  def __call__(self, parser, namespace, values, option_string=None):
130
  my_dict = {}
 
137
  setattr(namespace, self.dest, my_dict)
138
  print("dict values: {}".format(my_dict))
139
 
140
+
141
  def get_nan_value(v, set_value_to_nan=0.0):
142
  if random.random() < set_value_to_nan:
143
  return v
144
  else:
145
  return random.choice([-999, 0, 1, 999])
146
 
147
+
148
  def to_ranking(data):
149
+ x = data >= data.unsqueeze(-3)
150
  x = x.sum(0)
151
  return x
152
+
153
+
154
  # TODO: Is there a better way to do this?
155
  # 1. Comparing to unique elements: When all values are different we still get quadratic blowup
156
  # 2. Argsort(Argsort()) returns ranking, but with duplicate values there is an ordering which is problematic
 
158
  def to_ranking_low_mem(data):
159
  x = torch.zeros_like(data)
160
  for col in range(data.shape[-1]):
161
+ x_ = data[:, :, col] >= data[:, :, col].unsqueeze(-2)
162
  x_ = x_.sum(0)
163
  x[:, :, col] = x_
164
  return x
165
 
166
+
167
  def nan_handling_missing_for_unknown_reason_value(set_value_to_nan=0.0):
168
+ return get_nan_value(float("nan"), set_value_to_nan)
169
+
170
 
171
  def nan_handling_missing_for_no_reason_value(set_value_to_nan=0.0):
172
+ return get_nan_value(float("-inf"), set_value_to_nan)
173
+
174
 
175
  def nan_handling_missing_for_a_reason_value(set_value_to_nan=0.0):
176
+ return get_nan_value(float("inf"), set_value_to_nan)
177
+
178
 
179
  def torch_nanmean(x, axis=0):
180
+ num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(
181
+ axis=axis
182
+ )
183
  value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
184
  return value / num
185
 
186
+
187
  def torch_nanstd(x, axis=0):
188
+ num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum(
189
+ axis=axis
190
+ )
191
  value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum(axis=axis)
192
  mean = value / num
193
+ mean_broadcast = torch.repeat_interleave(
194
+ mean.unsqueeze(axis), x.shape[axis], dim=axis
195
+ )
196
+ return torch.sqrt(
197
+ torch.nansum(torch.square(mean_broadcast - x), axis=axis) / (num - 1)
198
+ )
199
+
200
 
201
  def normalize_data(data, normalize_positions=-1):
202
  if normalize_positions > 0:
203
  mean = torch_nanmean(data[:normalize_positions], axis=0)
204
+ std = torch_nanstd(data[:normalize_positions], axis=0) + 0.000001
205
  else:
206
  mean = torch_nanmean(data, axis=0)
207
+ std = torch_nanstd(data, axis=0) + 0.000001
208
  data = (data - mean) / std
209
  data = torch.clip(data, min=-100, max=100)
210
 
211
  return data
212
 
213
+
214
  def remove_outliers(X, n_sigma=4):
215
  # Expects T, B, H
216
  assert len(X.shape) == 3, "X must be T,B,H"
217
+ # for b in range(X.shape[1]):
218
+ # for col in range(X.shape[2]):
219
  data = X
220
  data_mean, data_std = torch_nanmean(data, axis=0), torch_nanstd(data, axis=0)
221
  cut_off = data_std * n_sigma
 
223
 
224
  data_clean = X[:].clone()
225
  data_clean[torch.logical_or(data > upper, data < lower)] = np.nan
226
+ data_mean, data_std = (
227
+ torch_nanmean(data_clean, axis=0),
228
+ torch_nanstd(data_clean, axis=0),
229
+ )
230
  cut_off = data_std * n_sigma
231
  lower, upper = data_mean - cut_off, data_mean + cut_off
232
 
233
+ X = torch.maximum(-torch.log(1 + torch.abs(X)) + lower, X)
234
+ X = torch.minimum(torch.log(1 + torch.abs(X)) + upper, X)
235
+ # print(ds[1][data < lower, col], ds[1][data > upper, col], ds[1][~np.isnan(data), col].shape, data_mean, data_std)
236
  return X
237
 
238
+
239
  def bool_mask_to_att_mask(mask):
240
+ return (
241
+ mask.float()
242
+ .masked_fill(mask == 0, float("-inf"))
243
+ .masked_fill(mask == 1, float(0.0))
244
+ )
245
+
246
 
247
  def print_on_master_only(is_master):
248
  import builtins as __builtin__
 
258
 
259
 
260
  def init_dist(device):
261
+ print("init dist")
262
+ if "LOCAL_RANK" in os.environ:
263
  # launched with torch.distributed.launch
264
  rank = int(os.environ["LOCAL_RANK"])
265
+ print("torch.distributed.launch and my rank is", rank)
266
  torch.cuda.set_device(rank)
267
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
268
+ torch.distributed.init_process_group(
269
+ backend="nccl",
270
+ init_method="env://",
271
+ timeout=datetime.timedelta(seconds=20),
272
+ world_size=torch.cuda.device_count(),
273
+ rank=rank,
274
+ )
275
  torch.distributed.barrier()
276
  print_on_master_only(rank == 0)
277
+ print(
278
+ f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
279
+ "only I can print, but when using print(..., force=True) it will print on all ranks."
280
+ )
281
+ return True, rank, f"cuda:{rank}"
282
+ elif "SLURM_PROCID" in os.environ and torch.cuda.device_count() > 1:
283
  # this is for multi gpu when starting with submitit
284
+ assert device != "cpu:0"
285
+ rank = int(os.environ["SLURM_PROCID"])
286
+ os.environ["MASTER_ADDR"] = "localhost"
287
+ os.environ["MASTER_PORT"] = "12355"
288
  torch.cuda.set_device(rank)
289
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
290
+ print("distributed submitit launch and my rank is", rank)
291
+ torch.distributed.init_process_group(
292
+ backend="nccl",
293
+ init_method="env://",
294
+ timeout=datetime.timedelta(seconds=20),
295
+ world_size=torch.cuda.device_count(),
296
+ rank=rank,
297
+ )
298
  torch.distributed.barrier()
299
  print_on_master_only(rank == 0)
300
+ print(
301
+ f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
302
+ "only I can print, but when using print(..., force=True) it will print on all ranks."
303
+ )
304
 
305
+ return True, rank, f"cuda:{rank}"
306
  else:
307
+ print("Not using distributed")
308
  # will not change any of the behavior of print, but allows putting the force=True in the print calls
309
  print_on_master_only(True)
310
  return False, 0, device
311
 
312
 
313
  def check_compatibility(dl):
314
+ if hasattr(dl, "num_outputs"):
315
+ print(
316
+ "`num_outputs` for the DataLoader is deprecated. It is assumed to be 1 from now on."
317
+ )
318
+ assert dl.num_outputs != 1, (
319
+ "We assume num_outputs to be 1. Instead of the num_ouputs change your loss."
320
+ "We specify the number of classes in the CE loss."
321
+ )
322
+
323
+
324
+ def pfn_normalize(
325
+ lb=torch.tensor(float("-inf")),
326
+ ub=torch.tensor(float("inf")),
327
+ soft_lb=0.0,
328
+ soft_ub=1.0,
329
+ minimize=False,
330
+ ):
331
+ """
332
+ The LC-PFN curve prior assumes curves that are normalized to the range [0,1] and meant to be maximized.
333
+ This function returns a pair of functions to normalize and denormalize data to fit this assumption.
334
+
335
+ Parameters:
336
+ lb (torch.Tensor): Lower bound of the data.
337
+ ub (torch.Tensor): Upper bound of the data.
338
+ soft_lb (float): Soft lower bound for normalization. Default is 0.0.
339
+ soft_ub (float): Soft upper bound for normalization. Default is 1.0.
340
+ minimize (bool): If True, the original curve is a minimization (lower is better). Default is False.
341
+
342
+ Returns: Two functions for normalizing and denormalizing the data.
343
+ """
344
+ assert lb <= soft_lb and soft_lb < soft_ub and soft_ub <= ub
345
+ # step 1: linearly transform [soft_lb,soft_ub] --> [-1,1] (where the sigmoid behaves approx. linearly)
346
+ # 2.0/(soft_ub - soft_lb)*(x - soft_lb) - 1.0
347
+ # step 2: apply a vertically scaled/shifted sigmoid such that [lb,ub] --> [0,1]
348
+
349
+ def cinv(x):
350
+ return 1 - x if minimize else x
351
+
352
+ def lin_soft(x):
353
+ return 2 / (soft_ub - soft_lb) * (x - soft_lb) - 1
354
+
355
+ def lin_soft_inv(y):
356
+ return (y + 1) / 2 * (soft_ub - soft_lb) + soft_lb
357
+
358
+ try:
359
+ if torch.exp(-lin_soft(lb)) > 1e300:
360
+ raise RuntimeError
361
+ # otherwise overflow causes issues, treat these cases as if the lower bound was -infinite
362
+ # print(f"WARNING: {lb} --> NINF to avoid overflows ({np.exp(-lin_soft(lb))})")
363
+ except RuntimeError:
364
+ lb = torch.tensor(float("-inf"))
365
+ if torch.isinf(lb) and torch.isinf(ub):
366
+ return lambda x: cinv(
367
+ 1 / (1 + torch.exp(-lin_soft(x)))
368
+ ), lambda y: lin_soft_inv(torch.log(cinv(y) / (1 - cinv(y))))
369
+ elif torch.isinf(lb):
370
+ a = 1 + torch.exp(-lin_soft(ub))
371
+ return lambda x: cinv(
372
+ a / (1 + torch.exp(-lin_soft(x)))
373
+ ), lambda y: lin_soft_inv(torch.log((cinv(y) / a) / (1 - (cinv(y) / a))))
374
+ elif torch.isinf(ub):
375
+ a = 1 / (1 - 1 / (1 + torch.exp(-lin_soft(lb))))
376
+ b = 1 - a
377
+ return lambda x: cinv(
378
+ a / (1 + torch.exp(-lin_soft(x))) + b
379
+ ), lambda y: lin_soft_inv(
380
+ torch.log(((cinv(y) - b) / a) / (1 - ((cinv(y) - b) / a)))
381
+ )
382
+ else:
383
+ a = (
384
+ 1
385
+ + torch.exp(-lin_soft(ub))
386
+ + torch.exp(-lin_soft(lb))
387
+ + torch.exp(-lin_soft(ub) - lin_soft(lb))
388
+ ) / (torch.exp(-lin_soft(lb)) - torch.exp(-lin_soft(ub)))
389
+ b = -a / (1 + torch.exp(-lin_soft(lb)))
390
+ return lambda x: cinv(
391
+ a / (1 + torch.exp(-lin_soft(x))) + b
392
+ ), lambda y: lin_soft_inv(
393
+ torch.log(((cinv(y) - b) / a) / (1 - ((cinv(y) - b) / a)))
394
+ )
395
+
396
+
397
+ def get_default_normalizer():
398
+ default_normalizer_kwargs = {
399
+ "lb": torch.tensor(0.0),
400
+ "ub": torch.tensor(1.0),
401
+ "soft_lb": 0.0,
402
+ "soft_ub": 1.0,
403
+ "minimize": False,
404
+ }
405
+ return pfn_normalize(**default_normalizer_kwargs)
406
+
407
+
408
+ def identity_normalizer():
409
+ return lambda x: x, lambda x: x
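
Note: `pfn_normalize` returns a (normalize, denormalize) pair. A short round-trip sketch for an unbounded loss curve that is to be minimized (the bounds here are illustrative):

    import torch
    from lcpfn.utils import pfn_normalize

    # losses live in [0, inf) and lower is better, so flip and squash them into [0, 1]
    normalize, denormalize = pfn_normalize(
        lb=torch.tensor(0.0), ub=torch.tensor(float("inf")), minimize=True
    )
    losses = torch.tensor([2.3, 1.1, 0.4, 0.2])
    curve = normalize(losses)       # increasing, inside [0, 1], as the prior assumes
    recovered = denormalize(curve)  # round-trips back to the original losses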
lcpfn/version.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.1.3"
pyproject.toml ADDED
@@ -0,0 +1,42 @@
1
+ [project]
2
+ name = "lcpfn"
3
+ description = "In-context Bayesian Learning Curve Extrapolation"
4
+ readme = {file = "readme.md", content-type = 'text/markdown'}
5
+ license = {file = "LICENSE"}
6
+ authors = [
7
+ {name = "Steven Adriaensen", email= "adriaens@cs.uni-freiburg.de"},
8
+ {name = "Herilalaina Rakotoarison", email = "rakotoah@cs.uni-freiburg.de"},
9
+ {name = "Samuel Müller", email = "muellesa@cs.uni-freiburg.de"},
10
+ {name = "Frank Hutter", email = "fh@cs.uni-freiburg.de"},
11
+ ]
12
+ requires-python = ">=3.9,<3.12"
13
+ dependencies = [
14
+ "torch<=1.11.0",
15
+ "numpy>=1.21.2,<2",
16
+ "requests>=2.23.0"
17
+ ]
18
+ dynamic = ["version"]
19
+ classifiers = [
20
+ 'Intended Audience :: Science/Research',
21
+ 'License :: OSI Approved :: MIT License',
22
+ 'Programming Language :: Python',
23
+ 'Topic :: Software Development',
24
+ 'Topic :: Scientific/Engineering',
25
+ 'Operating System :: Unix',
26
+ 'Operating System :: MacOS',
27
+ 'Programming Language :: Python :: 3',
28
+ 'Programming Language :: Python :: 3.9',
29
+ 'Programming Language :: Python :: 3.10',
30
+ 'Programming Language :: Python :: 3.11',
31
+ ]
32
+
33
+ [project.urls]
34
+ homepage = "https://github.com/automl/lcpfn"
35
+ repository = "https://github.com/automl/lcpfn"
36
+ bugtracker = "https://github.com/automl/lcpfn/issues"
37
+
38
+ [tool.setuptools.packages.find]
39
+ include = ["lcpfn*"]
40
+
41
+ [tool.setuptools.dynamic]
42
+ version = {attr = "lcpfn.version.__version__"}
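
Note: with `dynamic = ["version"]` resolved from `lcpfn.version.__version__`, the installed distribution and the module agree on a single version string. A quick sanity check after a local `pip install .`:

    from lcpfn.version import __version__
    print(__version__)  # "0.1.3"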
requirements.txt DELETED
@@ -1,4 +0,0 @@
1
- torch==1.11.0
2
- numpy>=1.21.2
3
- scikit-learn
4
- # lcpfn @ git+https://github.com/automl/lcpfn.git