import math

import torch
from torch import nn

import ConfigSpace.hyperparameters as CSH

from .utils import get_batch_to_dataloader
from utils import default_device
from .utils import order_by_y, normalize_by_used_features_f
from .utils import trunc_norm_sampler_f, beta_sampler_f, gamma_sampler_f, uniform_sampler_f, zipf_sampler_f, \
    scaled_beta_sampler_f, uniform_int_sampler_f


def unpack_dict_of_tuples(d):
    # Returns a list of dicts where dict i contains the values at tuple position i,
    # e.g. {'a': (1, 2), 'b': (3, 4)} -> [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]
    return [dict(zip(d.keys(), v)) for v in zip(*d.values())]
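
# Illustrative self-check (not part of the original module) showing the shape
# transformation performed by `unpack_dict_of_tuples`; it is used below to split
# the (indicator, passed_value) pairs returned by DifferentiableHyperparameter.
def _example_unpack_dict_of_tuples():
    out = unpack_dict_of_tuples({'a': (1, 2), 'b': (3, 4)})
    assert out == [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]
    return out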

class DifferentiableHyperparameter(nn.Module):
    # Sampling this module yields a hyperparameter value together with a
    # standardized indicator of that value.
    def __init__(self, distribution, embedding_dim, device, **args):
        super(DifferentiableHyperparameter, self).__init__()

        self.distribution = distribution
        self.embedding_dim = embedding_dim
        self.device = device
        for key in args:
            setattr(self, key, args[key])

        def get_sampler():
            # Returns (sampler_f, min, max, mean, std) for the plain distributions.
            if self.distribution == "uniform":
                if not hasattr(self, 'sample'):
                    return uniform_sampler_f(self.min, self.max), self.min, self.max, \
                        (self.max + self.min) / 2, math.sqrt(1 / 12 * (self.max - self.min) * (self.max - self.min))
                else:
                    return (lambda: self.sample), self.min, self.max, None, None
            elif self.distribution == "uniform_int":
                return uniform_int_sampler_f(self.min, self.max), self.min, self.max, \
                    (self.max + self.min) / 2, math.sqrt(1 / 12 * (self.max - self.min) * (self.max - self.min))

        if self.distribution.startswith("meta"):
            self.hparams = {}

            def sample_meta(f):
                indicators, passed = unpack_dict_of_tuples({hp: self.hparams[hp]() for hp in self.hparams})
                meta_passed = f(**passed)
                return indicators, meta_passed

            args_passed = {'device': device, 'embedding_dim': embedding_dim}
            if self.distribution == "meta_beta":
                # Beta distribution whose parameters b and k are themselves drawn
                # uniformly (or fixed, if passed in directly).
                if hasattr(self, 'b') and hasattr(self, 'k'):
                    self.hparams = {'b': lambda: (None, self.b), 'k': lambda: (None, self.k)}
                else:
                    self.hparams = {'b': DifferentiableHyperparameter(distribution="uniform", min=self.min,
                                                                      max=self.max, **args_passed),
                                    'k': DifferentiableHyperparameter(distribution="uniform", min=self.min,
                                                                      max=self.max, **args_passed)}

                def make_beta(b, k):
                    return lambda b=b, k=k: self.scale * beta_sampler_f(b, k)()

                self.sampler = lambda make_beta=make_beta: sample_meta(make_beta)
            elif self.distribution == "meta_trunc_norm_log_scaled":
                # These defaults are duplicated in `replace_differentiable_distributions`;
                # don't change one without changing the other.
                self.min_std = self.min_std if hasattr(self, 'min_std') else 0.001
                self.max_std = self.max_std if hasattr(self, 'max_std') else self.max_mean
                # Truncated normal whose mean and std are drawn log-uniformly
                # (or fixed, if passed in directly).
                if not hasattr(self, 'log_mean'):
                    self.hparams = {'log_mean': DifferentiableHyperparameter(distribution="uniform",
                                                                             min=math.log(self.min_mean),
                                                                             max=math.log(self.max_mean),
                                                                             **args_passed),
                                    'log_std': DifferentiableHyperparameter(distribution="uniform",
                                                                            min=math.log(self.min_std),
                                                                            max=math.log(self.max_std),
                                                                            **args_passed)}
                else:
                    self.hparams = {'log_mean': lambda: (None, self.log_mean),
                                    'log_std': lambda: (None, self.log_std)}

                def make_trunc_norm(log_mean, log_std):
                    return ((lambda: self.lower_bound + round(trunc_norm_sampler_f(math.exp(log_mean),
                                                                                   math.exp(log_std))()))
                            if self.round else
                            (lambda: self.lower_bound + trunc_norm_sampler_f(math.exp(log_mean),
                                                                             math.exp(log_std))()))

                self.sampler = lambda make_trunc_norm=make_trunc_norm: sample_meta(make_trunc_norm)
            elif self.distribution == "meta_trunc_norm":
                self.min_std = self.min_std if hasattr(self, 'min_std') else 0
                self.max_std = self.max_std if hasattr(self, 'max_std') else self.max_mean
                # Truncated normal whose mean and std are drawn uniformly on the
                # linear scale (unlike the log-scaled variant above), so they are
                # used directly, without exponentiation.
                self.hparams = {'mean': DifferentiableHyperparameter(distribution="uniform", min=self.min_mean,
                                                                     max=self.max_mean, **args_passed),
                                'std': DifferentiableHyperparameter(distribution="uniform", min=self.min_std,
                                                                    max=self.max_std, **args_passed)}

                def make_trunc_norm(mean, std):
                    return ((lambda: self.lower_bound + round(trunc_norm_sampler_f(mean, std)()))
                            if self.round else
                            (lambda: self.lower_bound + trunc_norm_sampler_f(mean, std)()))

                self.sampler = lambda make_trunc_norm=make_trunc_norm: sample_meta(make_trunc_norm)
            elif self.distribution == "meta_choice":
                if hasattr(self, 'choice_1_weight'):
                    # Bind i per lambda; a bare `lambda:` would capture only the
                    # final value of i from the comprehension.
                    self.hparams = {f'choice_{i}_weight': (lambda i=i: (None, getattr(self, f'choice_{i}_weight')))
                                    for i in range(1, len(self.choice_values))}
                else:
                    self.hparams = {f'choice_{i}_weight': DifferentiableHyperparameter(distribution="uniform",
                                                                                       min=-5.0, max=6.0,
                                                                                       **args_passed)
                                    for i in range(1, len(self.choice_values))}

                def make_choice(**choices):
                    # The first choice has a fixed weight of 1.0; the remaining weights
                    # are sampled, and the softmax turns them into probabilities.
                    weights = torch.softmax(torch.tensor([1.0] + [choices[i] for i in choices], dtype=torch.float), 0)
                    sample = torch.multinomial(weights, 1, replacement=True).numpy()[0]
                    return self.choice_values[sample]

                self.sampler = lambda make_choice=make_choice: sample_meta(make_choice)
            elif self.distribution == "meta_choice_mixed":
                if hasattr(self, 'choice_1_weight'):
                    self.hparams = {f'choice_{i}_weight': (lambda i=i: (None, getattr(self, f'choice_{i}_weight')))
                                    for i in range(1, len(self.choice_values))}
                else:
                    self.hparams = {f'choice_{i}_weight': DifferentiableHyperparameter(distribution="uniform",
                                                                                       min=-5.0, max=6.0,
                                                                                       **args_passed)
                                    for i in range(1, len(self.choice_values))}

                def make_choice(**choices):
                    weights = torch.softmax(torch.tensor([1.0] + [choices[i] for i in choices], dtype=torch.float), 0)

                    def sample():
                        s = torch.multinomial(weights, 1, replacement=True).numpy()[0]
                        return self.choice_values[s]()

                    # The passed value is itself a sampler, so downstream code can draw
                    # a fresh choice every time it is called.
                    return lambda: sample

                self.sampler = lambda make_choice=make_choice: sample_meta(make_choice)
        else:
            def return_two(x, min, max, mean, std):
                # Returns (indicator, value): the value is what gets passed to
                # `get_batch`; the indicator is the value standardized with the
                # distribution's mean and std (None if no mean is available).
                ind = (x - mean) / std if mean is not None else None
                return ind, x

            self.sampler_f, self.sampler_min, self.sampler_max, self.sampler_mean, self.sampler_std = get_sampler()
            self.sampler = lambda: return_two(self.sampler_f(), min=self.sampler_min, max=self.sampler_max,
                                              mean=self.sampler_mean, std=self.sampler_std)

    def forward(self):
        s, s_passed = self.sampler()
        return s, s_passed
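
# Minimal usage sketch (the argument values here are assumed, not from the
# original code): a plain "uniform" hyperparameter returns a standardized
# indicator together with the raw value that is later forwarded to `get_batch`.
def _example_uniform_hyperparameter():
    hp = DifferentiableHyperparameter(distribution="uniform", embedding_dim=512,
                                      device='cpu', min=0.0, max=1.0)
    indicator, value = hp()  # indicator = (value - mean) / std; value ~ U(0, 1)
    return indicator, value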

class DifferentiableHyperparameterList(nn.Module):
    def __init__(self, hyperparameters, embedding_dim, device):
        super().__init__()
        self.device = device
        hyperparameters = {k: v for (k, v) in hyperparameters.items() if v}
        self.hyperparameters = nn.ModuleDict({hp: DifferentiableHyperparameter(embedding_dim=embedding_dim,
                                                                               name=hp, device=device,
                                                                               **hyperparameters[hp])
                                              for hp in hyperparameters})

    def get_hyperparameter_info(self):
        sampled_hyperparameters_f, sampled_hyperparameters_keys = [], []

        def append_hp(hp_key, hp_val):
            sampled_hyperparameters_keys.append(hp_key)
            # Store a (standardize, unstandardize) pair that maps between the raw
            # hyperparameter value and its indicator.
            s_mean, s_std = hp_val.sampler_mean, hp_val.sampler_std
            sampled_hyperparameters_f.append((lambda x: (x - s_mean) / s_std,
                                              lambda y: (y * s_std) + s_mean))

        for hp in self.hyperparameters:
            hp_val = self.hyperparameters[hp]
            if hasattr(hp_val, 'hparams'):
                for hp_ in hp_val.hparams:
                    append_hp(f'{hp}_{hp_}', hp_val.hparams[hp_])
            else:
                append_hp(hp, hp_val)

        return sampled_hyperparameters_keys, sampled_hyperparameters_f

    def sample_parameter_object(self):
        sampled_hyperparameters, s_passed = {}, {}
        for hp in self.hyperparameters:
            sampled_hyperparameters_, s_passed_ = self.hyperparameters[hp]()
            s_passed[hp] = s_passed_
            if isinstance(sampled_hyperparameters_, dict):
                sampled_hyperparameters_ = {hp + '_' + str(key): val for key, val in sampled_hyperparameters_.items()}
                sampled_hyperparameters.update(sampled_hyperparameters_)
            else:
                sampled_hyperparameters[hp] = sampled_hyperparameters_
        # s_passed contains the values passed to the get_batch function;
        # sampled_hyperparameters contains the standardized indicators of the
        # sampled values, i.e. plain numbers describing the sampled objects.
        return s_passed, sampled_hyperparameters
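
# Sketch with a hypothetical configuration: `get_hyperparameter_info` yields the
# indicator names and, per indicator, a (standardize, unstandardize) function
# pair; `sample_parameter_object` draws one set of values.
def _example_hyperparameter_list():
    hps = DifferentiableHyperparameterList(
        {'noise': {'distribution': 'uniform', 'min': 0.0, 'max': 0.1}},
        embedding_dim=512, device='cpu')
    keys, transforms = hps.get_hyperparameter_info()  # keys == ['noise']
    passed, indicators = hps.sample_parameter_object()
    return keys, transforms, passed, indicators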

class DifferentiablePrior(torch.nn.Module):
    def __init__(self, get_batch, hyperparameters, differentiable_hyperparameters, args):
        super(DifferentiablePrior, self).__init__()
        self.h = hyperparameters
        self.args = args
        self.get_batch = get_batch
        self.differentiable_hyperparameters = DifferentiableHyperparameterList(differentiable_hyperparameters,
                                                                               embedding_dim=self.h['emsize'],
                                                                               device=self.args['device'])

    def forward(self):
        # Sample the differentiable hyperparameters and merge them into the fixed ones
        sampled_hyperparameters_passed, sampled_hyperparameters_indicators = \
            self.differentiable_hyperparameters.sample_parameter_object()
        hyperparameters = {**self.h, **sampled_hyperparameters_passed}
        x, y, y_ = self.get_batch(hyperparameters=hyperparameters, **self.args)
        return x, y, y_, sampled_hyperparameters_indicators
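
# Hypothetical sketch of the contract DifferentiablePrior expects from a prior's
# `get_batch`: it is called with the merged hyperparameters plus the args dict
# and returns (x, y, y_); `toy_get_batch` and all values below are made up.
def _example_differentiable_prior():
    def toy_get_batch(batch_size, seq_len, num_features, device, hyperparameters):
        x = torch.rand(seq_len, batch_size, num_features, device=device)
        y = x.mean(-1) * hyperparameters['noise']
        return x, y, y

    prior = DifferentiablePrior(toy_get_batch, {'emsize': 512},
                                {'noise': {'distribution': 'uniform', 'min': 0.0, 'max': 0.1}},
                                {'device': 'cpu', 'seq_len': 10, 'num_features': 3, 'batch_size': 4})
    x, y, y_, indicators = prior()
    return x, y, y_, indicators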

# TODO: Make this a class that keeps objects
@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, get_batch, device=default_device, differentiable_hyperparameters={},
              hyperparameters=None, batch_size_per_gp_sample=None, **kwargs):
    batch_size_per_gp_sample = batch_size_per_gp_sample or min(64, batch_size)
    num_models = batch_size // batch_size_per_gp_sample
    assert num_models * batch_size_per_gp_sample == batch_size, \
        f'Batch size ({batch_size}) not divisible by batch_size_per_gp_sample ({batch_size_per_gp_sample})'

    args = {'device': device, 'seq_len': seq_len, 'num_features': num_features, 'batch_size': batch_size_per_gp_sample}

    models = [DifferentiablePrior(get_batch, hyperparameters, differentiable_hyperparameters, args)
              for _ in range(num_models)]
    sample = [model() for model in models]

    x, y, y_, hyperparameter_dict = zip(*sample)

    if hyperparameters.get('verbose'):
        print('Hparams', hyperparameter_dict[0].keys())

    hyperparameter_matrix = []
    for batch in hyperparameter_dict:
        hyperparameter_matrix.append([batch[hp] for hp in batch])

    transposed_hyperparameter_matrix = list(zip(*hyperparameter_matrix))
    assert all([all([hp is None for hp in hp_]) or all([hp is not None for hp in hp_])
                for hp_ in transposed_hyperparameter_matrix]), \
        'If a hyperparameter is None for one model, it must be None for all models'
    # Remove columns that are all None (i.e. not sampled)
    hyperparameter_matrix = [[hp for hp in hp_ if hp is not None] for hp_ in hyperparameter_matrix]

    if len(hyperparameter_matrix[0]) > 0:
        packed_hyperparameters = torch.tensor(hyperparameter_matrix)
        packed_hyperparameters = torch.repeat_interleave(packed_hyperparameters,
                                                         repeats=batch_size_per_gp_sample, dim=0).detach()
    else:
        packed_hyperparameters = None

    x, y, y_ = torch.cat(x, 1).detach(), torch.cat(y, 1).detach(), torch.cat(y_, 1).detach()

    return x, y, y_, packed_hyperparameters


DataLoader = get_batch_to_dataloader(get_batch)
DataLoader.num_outputs = 1


def draw_random_style(dl, device):
    (hp_embedding, data, targets_), targets = next(iter(dl))
    return hp_embedding.to(device)[0:1, :]


def merge_style_with_info(diff_hparams_keys, diff_hparams_f, style, transform=True):
    params = dict(zip(diff_hparams_keys, zip(diff_hparams_f, style.detach().cpu().numpy().tolist()[0])))

    def t(v):
        # v is ((standardize_f, unstandardize_f), standardized_value); map the
        # standardized value back to its true range if requested.
        if transform:
            return v[0][1](v[1])
        else:
            return v[1]

    return {k: t(v) for k, v in params.items()}


def replace_differentiable_distributions(config):
    diff_config = config['differentiable_hyperparameters']
    for name, diff_hp_dict in diff_config.items():
        distribution = diff_hp_dict['distribution']
        if distribution == 'uniform':
            diff_hp_dict['sample'] = CSH.UniformFloatHyperparameter(name, diff_hp_dict['min'], diff_hp_dict['max'])
        elif distribution == 'meta_beta':
            diff_hp_dict['k'] = CSH.UniformFloatHyperparameter(name + '_k', diff_hp_dict['min'], diff_hp_dict['max'])
            diff_hp_dict['b'] = CSH.UniformFloatHyperparameter(name + '_b', diff_hp_dict['min'], diff_hp_dict['max'])
        elif distribution in ('meta_choice', 'meta_choice_mixed'):
            for i in range(1, len(diff_hp_dict['choice_values'])):
                diff_hp_dict[f'choice_{i}_weight'] = CSH.UniformFloatHyperparameter(name + f'_choice_{i}_weight',
                                                                                    -5.0, 6.0)
        elif distribution == 'meta_trunc_norm_log_scaled':
            diff_hp_dict['log_mean'] = CSH.UniformFloatHyperparameter(name + '_log_mean',
                                                                      math.log(diff_hp_dict['min_mean']),
                                                                      math.log(diff_hp_dict['max_mean']))
            min_std = diff_hp_dict['min_std'] if 'min_std' in diff_hp_dict else 0.001
            max_std = diff_hp_dict['max_std'] if 'max_std' in diff_hp_dict else diff_hp_dict['max_mean']
            diff_hp_dict['log_std'] = CSH.UniformFloatHyperparameter(name + '_log_std',
                                                                     math.log(min_std), math.log(max_std))
        else:
            raise ValueError(f'Unknown distribution {distribution}')
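
# Minimal usage sketch (the config layout is assumed) for
# `replace_differentiable_distributions`: it mutates the config in place,
# attaching ConfigSpace hyperparameters next to each differentiable entry.
def _example_replace_differentiable_distributions():
    config = {'differentiable_hyperparameters': {
        'noise': {'distribution': 'uniform', 'min': 0.0, 'max': 0.1}}}
    replace_differentiable_distributions(config)
    # config['differentiable_hyperparameters']['noise']['sample'] is now a
    # CSH.UniformFloatHyperparameter named 'noise'.
    return config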