import math

import torch
from torch import nn

from .utils import get_batch_to_dataloader
from .utils import order_by_y, normalize_by_used_features_f
from .utils import trunc_norm_sampler_f, beta_sampler_f, gamma_sampler_f, uniform_sampler_f, zipf_sampler_f, scaled_beta_sampler_f, uniform_int_sampler_f
from utils import default_device


def unpack_dict_of_tuples(d):
    # Returns a list of dicts where dict i holds position i of every value tuple,
    # e.g. {'a': (1, 2), 'b': (3, 4)} -> [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]
    return [dict(zip(d.keys(), v)) for v in zip(*d.values())]
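# Used below by sample_meta to split {hp: (indicator, value)} maps into one
# dict of indicators and one dict of sampled values.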


class DifferentiableHyperparameter(nn.Module):
    # Sampling this module yields a hyperparameter value together with a
    # normalized indicator of that value (used as model input).
    def __init__(self, distribution, embedding_dim, device, **args):
        super().__init__()
        self.distribution = distribution
        self.embedding_dim = embedding_dim
        self.device = device
        for key in args:
            setattr(self, key, args[key])
        def get_sampler():
            # Returns (sampler_f, min, max, mean, std) for this distribution; the
            # mean/std are used to z-normalize sampled values into indicators.
            # if self.distribution == "beta":
            #     return beta_sampler_f(self.a, self.b), 0, 1
            # elif self.distribution == "gamma":
            #     return gamma_sampler_f(self.a, self.b), 0, 1
            # elif self.distribution == "beta_int":
            #     return scaled_beta_sampler_f(self.a, self.b, self.scale, self.min), self.scale + self.min, self.min, self.a / (self.a + self.b)
            if self.distribution == "uniform":
                if not hasattr(self, 'sample'):
                    # std of U(min, max) is (max - min) / sqrt(12)
                    return (uniform_sampler_f(self.min, self.max), self.min, self.max,
                            (self.max + self.min) / 2, math.sqrt(1 / 12) * (self.max - self.min))
                else:
                    return lambda: self.sample, self.min, self.max, None, None
            elif self.distribution == "uniform_int":
                return (uniform_int_sampler_f(self.min, self.max), self.min, self.max,
                        (self.max + self.min) / 2, math.sqrt(1 / 12) * (self.max - self.min))

        if self.distribution.startswith("meta"):
            self.hparams = {}

            def sample_meta(f):
                # Sample each sub-hyperparameter, which yields (indicator, value)
                # pairs, then build the meta object from the sampled values.
                indicators, passed = unpack_dict_of_tuples({hp: self.hparams[hp]() for hp in self.hparams})
                meta_passed = f(**passed)
                return indicators, meta_passed

            args_passed = {'device': device, 'embedding_dim': embedding_dim}
if self.distribution == "meta_beta": | |
## Truncated normal where std and mean are drawn randomly logarithmically scaled | |
if hasattr(self, 'b') and hasattr(self, 'k'): | |
self.hparams = {'b': lambda: (None, self.b), 'k': lambda: (None, self.k)} | |
else: | |
self.hparams = {"b": DifferentiableHyperparameter(distribution="uniform", min=self.min | |
, max=self.max, **args_passed) | |
, "k": DifferentiableHyperparameter(distribution="uniform", min=self.min | |
, max=self.max, **args_passed)} | |
def make_beta(b, k): | |
return lambda b=b, k=k: self.scale * beta_sampler_f(b, k)() | |
self.sampler = lambda make_beta=make_beta : sample_meta(make_beta) | |
elif self.distribution == "meta_trunc_norm_log_scaled": | |
# these choices are copied down below, don't change these without changing `replace_differentiable_distributions` | |
self.min_std = self.min_std if hasattr(self, 'min_std') else 0.001 | |
self.max_std = self.max_std if hasattr(self, 'max_std') else self.max_mean | |
## Truncated normal where std and mean are drawn randomly logarithmically scaled | |
if not hasattr(self, 'log_mean'): | |
self.hparams = {"log_mean": DifferentiableHyperparameter(distribution="uniform", min=math.log(self.min_mean) | |
, max=math.log(self.max_mean), **args_passed) | |
, "log_std": DifferentiableHyperparameter(distribution="uniform", min=math.log(self.min_std) | |
, max=math.log(self.max_std), **args_passed)} | |
else: | |
self.hparams = {'log_mean': lambda: (None, self.log_mean), 'log_std': lambda: (None, self.log_std)} | |
def make_trunc_norm(log_mean, log_std): | |
return ((lambda : self.lower_bound + round(trunc_norm_sampler_f(math.exp(log_mean), math.exp(log_std))())) if self.round | |
else (lambda: self.lower_bound + trunc_norm_sampler_f(math.exp(log_mean), math.exp(log_std))())) | |
self.sampler = lambda make_trunc_norm=make_trunc_norm: sample_meta(make_trunc_norm) | |
elif self.distribution == "meta_trunc_norm": | |
self.min_std = self.min_std if hasattr(self, 'min_std') else 0 | |
self.max_std = self.max_std if hasattr(self, 'max_std') else self.max_mean | |
self.hparams = {"mean": DifferentiableHyperparameter(distribution="uniform", min=self.min_mean | |
, max=self.max_mean, **args_passed) | |
, "std": DifferentiableHyperparameter(distribution="uniform", min=self.min_std | |
, max=self.max_std, **args_passed)} | |
def make_trunc_norm(mean, std): | |
return ((lambda: self.lower_bound + round( | |
trunc_norm_sampler_f(math.exp(mean), math.exp(std))())) if self.round | |
else ( | |
lambda make_trunc_norm=make_trunc_norm: self.lower_bound + trunc_norm_sampler_f(math.exp(mean), math.exp(std))())) | |
self.sampler = lambda : sample_meta(make_trunc_norm) | |
elif self.distribution == "meta_choice": | |
if hasattr(self, 'choice_1_weight'): | |
self.hparams = {f'choice_{i}_weight': lambda: (None, getattr(self, f'choice_{i}_weight')) for i in range(1, len(self.choice_values))} | |
else: | |
self.hparams = {f"choice_{i}_weight": DifferentiableHyperparameter(distribution="uniform", min=-5.0 | |
, max=6.0, **args_passed) for i in range(1, len(self.choice_values))} | |
def make_choice(**choices): | |
weights = torch.softmax(torch.tensor([1.0] + [choices[i] for i in choices], dtype=torch.float), 0) # create a tensor of weights | |
sample = torch.multinomial(weights, 1, replacement=True).numpy()[0] | |
return self.choice_values[sample] | |
self.sampler = lambda make_choice=make_choice: sample_meta(make_choice) | |
elif self.distribution == "meta_choice_mixed": | |
if hasattr(self, 'choice_1_weight'): | |
self.hparams = {f'choice_{i}_weight': lambda: (None, getattr(self, f'choice_{i}_weight')) for i in range(1, len(self.choice_values))} | |
else: | |
self.hparams = {f"choice_{i}_weight": DifferentiableHyperparameter(distribution="uniform", min=-5.0 | |
, max=6.0, **args_passed) for i in range(1, len(self.choice_values))} | |
def make_choice(**choices): | |
weights = torch.softmax(torch.tensor([1.0] + [choices[i] for i in choices], dtype=torch.float), 0) # create a tensor of weights | |
def sample(): | |
s = torch.multinomial(weights, 1, replacement=True).numpy()[0] | |
return self.choice_values[s]() | |
return lambda: sample | |
self.sampler = lambda make_choice=make_choice: sample_meta(make_choice) | |
        else:
            def return_two(x, min, max, mean, std):
                # Returns (indicator, value): the indicator is the z-normalized
                # sample handed to the model, the value is used downstream as-is.
                if mean is not None:
                    ind = (x - mean) / std  # alternative: 2 * (x - min) / (max - min) - 1
                else:
                    ind = None
                return ind, x

            # def sample_standard(sampler_f, embedding):
            #     s = torch.tensor([sampler_f()], device=self.device)
            #     return s, embedding(s)

            self.sampler_f, self.sampler_min, self.sampler_max, self.sampler_mean, self.sampler_std = get_sampler()
            self.sampler = lambda: return_two(self.sampler_f(), min=self.sampler_min, max=self.sampler_max,
                                              mean=self.sampler_mean, std=self.sampler_std)
            # self.embedding_layer = nn.Linear(1, self.embedding_dim, device=self.device)
            # self.embed = lambda x: self.embedding_layer((x - self.sampler_min) / (self.sampler_max - self.sampler_min))
            # self.sampler = lambda: sample_standard(self.sampler_f, self.embedding)

    def forward(self):
        s, s_passed = self.sampler()
        return s, s_passed
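
# Usage sketch (illustrative only, not part of the module's API surface):
#   hp = DifferentiableHyperparameter(distribution='uniform', min=0.0, max=10.0,
#                                     embedding_dim=512, device='cpu')
#   indicator, value = hp()   # z-normalized indicator and the raw sampled value
#   meta = DifferentiableHyperparameter(distribution='meta_beta', scale=0.9,
#                                       min=0.1, max=5.0, embedding_dim=512, device='cpu')
#   indicators, sampler = meta()  # indicators for b and k, plus a 0-arg value sampler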


class DifferentiableHyperparameterList(nn.Module):
    def __init__(self, hyperparameters, embedding_dim, device):
        super().__init__()
        self.device = device
        hyperparameters = {k: v for (k, v) in hyperparameters.items() if v}
        self.hyperparameters = nn.ModuleDict({hp: DifferentiableHyperparameter(embedding_dim=embedding_dim,
                                                                               name=hp,
                                                                               device=device, **hyperparameters[hp])
                                              for hp in hyperparameters})
    def get_hyperparameter_info(self):
        sampled_hyperparameters_f, sampled_hyperparameters_keys = [], []

        def append_hp(hp_key, hp_val):
            sampled_hyperparameters_keys.append(hp_key)
            # Pair of functions that z-normalize a raw hyperparameter value and
            # map a normalized value back to the true range.
            s_min, s_max, s_mean, s_std = hp_val.sampler_min, hp_val.sampler_max, hp_val.sampler_mean, hp_val.sampler_std
            sampled_hyperparameters_f.append((lambda x: (x - s_mean) / s_std, lambda y: (y * s_std) + s_mean))
            # sampled_hyperparameters_f.append((lambda x: (x - s_min) / (s_max - s_min) * 2 - 1,
            #                                   lambda y: (y + 1) / 2 * (s_max - s_min) + s_min))

        for hp in self.hyperparameters:
            hp_val = self.hyperparameters[hp]
            if hasattr(hp_val, 'hparams'):
                for hp_ in hp_val.hparams:
                    append_hp(f'{hp}_{hp_}', hp_val.hparams[hp_])
            else:
                append_hp(hp, hp_val)
        return sampled_hyperparameters_keys, sampled_hyperparameters_f
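
    # Example (illustrative): for a uniform hyperparameter on [0, 10],
    # sampler_mean = 5 and sampler_std = 10 / sqrt(12), so append_hp stores
    # (lambda x: (x - 5) / std, lambda y: y * std + 5), i.e. z-normalization
    # and its inverse.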
    def sample_parameter_object(self):
        sampled_hyperparameters, s_passed = {}, {}
        for hp in self.hyperparameters:
            sampled_hyperparameters_, s_passed_ = self.hyperparameters[hp]()
            s_passed[hp] = s_passed_
            if isinstance(sampled_hyperparameters_, dict):
                sampled_hyperparameters_ = {hp + '_' + str(key): val for key, val in sampled_hyperparameters_.items()}
                sampled_hyperparameters.update(sampled_hyperparameters_)
            else:
                sampled_hyperparameters[hp] = sampled_hyperparameters_
        # s_passed holds the values handed to the get_batch function;
        # sampled_hyperparameters holds the normalized indicators that
        # describe the sampled configuration.
        return s_passed, sampled_hyperparameters
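
# Usage sketch (illustrative; 'noise_std' is a made-up hyperparameter name):
#   hp_list = DifferentiableHyperparameterList(
#       {'noise_std': {'distribution': 'meta_trunc_norm_log_scaled',
#                      'min_mean': 0.01, 'max_mean': 0.3, 'round': False, 'lower_bound': 0.0}},
#       embedding_dim=512, device='cpu')
#   passed, indicators = hp_list.sample_parameter_object()
#   keys, fs = hp_list.get_hyperparameter_info()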


class DifferentiablePrior(torch.nn.Module):
    def __init__(self, get_batch, hyperparameters, differentiable_hyperparameters, args):
        super().__init__()
        self.h = hyperparameters
        self.args = args
        self.get_batch = get_batch
        self.differentiable_hyperparameters = DifferentiableHyperparameterList(differentiable_hyperparameters,
                                                                               embedding_dim=self.h['emsize'],
                                                                               device=self.args['device'])

    def forward(self):
        # Sample the differentiable hyperparameters, then draw a batch with them.
        sampled_hyperparameters_passed, sampled_hyperparameters_indicators = self.differentiable_hyperparameters.sample_parameter_object()
        hyperparameters = {**self.h, **sampled_hyperparameters_passed}
        x, y, y_ = self.get_batch(hyperparameters=hyperparameters, **self.args)
        return x, y, y_, sampled_hyperparameters_indicators
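
# Note: the wrapped `get_batch` callable must return an (x, y, y_) triple, and the
# sampled values in `sampled_hyperparameters_passed` override keys in `self.h`.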


# TODO: Make this a class that keeps objects
def get_batch(batch_size, seq_len, num_features, get_batch,
              device=default_device, differentiable_hyperparameters={},
              hyperparameters=None, batch_size_per_gp_sample=None, **kwargs):
    batch_size_per_gp_sample = batch_size_per_gp_sample or min(64, batch_size)
    num_models = batch_size // batch_size_per_gp_sample
    assert num_models * batch_size_per_gp_sample == batch_size, \
        f'Batch size ({batch_size}) not divisible by batch_size_per_gp_sample ({batch_size_per_gp_sample})'

    args = {'device': device, 'seq_len': seq_len, 'num_features': num_features, 'batch_size': batch_size_per_gp_sample}
    models = [DifferentiablePrior(get_batch, hyperparameters, differentiable_hyperparameters, args) for _ in range(num_models)]
    sample = [model() for model in models]

    x, y, y_, hyperparameter_dict = zip(*sample)

    if 'verbose' in hyperparameters and hyperparameters['verbose']:
        print('Hparams', hyperparameter_dict[0].keys())

    # One row of indicators per model; a column is None iff that hyperparameter
    # was fixed rather than sampled.
    hyperparameter_matrix = []
    for batch in hyperparameter_dict:
        hyperparameter_matrix.append([batch[hp] for hp in batch])

    transposed_hyperparameter_matrix = list(zip(*hyperparameter_matrix))
    assert all([all([hp is None for hp in hp_]) or all([hp is not None for hp in hp_]) for hp_ in transposed_hyperparameter_matrix]), \
        'A hyperparameter that is None for one model must be None for all models'
    # Remove columns that are all None (i.e. not sampled).
    hyperparameter_matrix = [[hp for hp in hp_ if hp is not None] for hp_ in hyperparameter_matrix]

    if len(hyperparameter_matrix[0]) > 0:
        packed_hyperparameters = torch.tensor(hyperparameter_matrix)
        packed_hyperparameters = torch.repeat_interleave(packed_hyperparameters, repeats=batch_size_per_gp_sample, dim=0).detach()
    else:
        packed_hyperparameters = None

    x, y, y_ = torch.cat(x, 1).detach(), torch.cat(y, 1).detach(), torch.cat(y_, 1).detach()
    return x, y, y_, packed_hyperparameters
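
# Usage sketch (illustrative; `inner_get_batch` stands for any prior's get_batch
# returning (x, y, y_) with the batch dimension at dim 1):
#   x, y, y_, packed = get_batch(batch_size=8, seq_len=100, num_features=5,
#                                get_batch=inner_get_batch,
#                                hyperparameters={'emsize': 512},
#                                differentiable_hyperparameters={},
#                                batch_size_per_gp_sample=4)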


DataLoader = get_batch_to_dataloader(get_batch)
DataLoader.num_outputs = 1
# DataLoader.validate = lambda: 0


def draw_random_style(dl, device):
    # Draw one batch from the dataloader and keep a single row of the
    # hyperparameter indicators as the "style" vector.
    (hp_embedding, data, targets_), targets = next(iter(dl))
    return hp_embedding.to(device)[0:1, :]


def merge_style_with_info(diff_hparams_keys, diff_hparams_f, style, transform=True):
    params = dict(zip(diff_hparams_keys, zip(diff_hparams_f, style.detach().cpu().numpy().tolist()[0])))

    def t(v):
        # v is ((normalize_f, unnormalize_f), normalized_value); optionally map
        # the normalized style entry back to its true range.
        if transform:
            return v[0][1](v[1])
        else:
            return v[1]

    return {k: t(v) for k, v in params.items()}
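
# Usage sketch (illustrative, continuing the hp_list example above):
#   keys, fs = hp_list.get_hyperparameter_info()
#   style = draw_random_style(dl, device='cpu')  # dl: a DataLoader built above
#   merge_style_with_info(keys, fs, style)  # -> {'noise_std_log_mean': ..., ...}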


import ConfigSpace.hyperparameters as CSH


def replace_differentiable_distributions(config):
    diff_config = config['differentiable_hyperparameters']
    for name, diff_hp_dict in diff_config.items():
        distribution = diff_hp_dict['distribution']
        if distribution == 'uniform':
            diff_hp_dict['sample'] = CSH.UniformFloatHyperparameter(name, diff_hp_dict['min'], diff_hp_dict['max'])
        elif distribution == 'meta_beta':
            diff_hp_dict['k'] = CSH.UniformFloatHyperparameter(name + '_k', diff_hp_dict['min'], diff_hp_dict['max'])
            diff_hp_dict['b'] = CSH.UniformFloatHyperparameter(name + '_b', diff_hp_dict['min'], diff_hp_dict['max'])
        elif distribution in ('meta_choice', 'meta_choice_mixed'):
            for i in range(1, len(diff_hp_dict['choice_values'])):
                diff_hp_dict[f'choice_{i}_weight'] = CSH.UniformFloatHyperparameter(name + f'_choice_{i}_weight', -5.0, 6.0)
        elif distribution == 'meta_trunc_norm_log_scaled':
            diff_hp_dict['log_mean'] = CSH.UniformFloatHyperparameter(name + '_log_mean', math.log(diff_hp_dict['min_mean']), math.log(diff_hp_dict['max_mean']))
            # Defaults mirror those set in DifferentiableHyperparameter above.
            min_std = diff_hp_dict['min_std'] if 'min_std' in diff_hp_dict else 0.001
            max_std = diff_hp_dict['max_std'] if 'max_std' in diff_hp_dict else diff_hp_dict['max_mean']
            diff_hp_dict['log_std'] = CSH.UniformFloatHyperparameter(name + '_log_std', math.log(min_std), math.log(max_std))
        else:
            raise ValueError(f'Unknown distribution {distribution}')
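
# Usage sketch (illustrative; the config layout is assumed):
#   config = {'differentiable_hyperparameters': {
#       'noise_std': {'distribution': 'meta_trunc_norm_log_scaled',
#                     'min_mean': 0.01, 'max_mean': 0.3}}}
#   replace_differentiable_distributions(config)
#   # Each entry now carries ConfigSpace hyperparameters (here 'log_mean' and
#   # 'log_std') that an outer HPO loop can sample and write back as fixed values.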