import random
import math
import torch
from torch import nn
import numpy as np
from utils import default_device
from .utils import get_batch_to_dataloader


class GaussianNoise(nn.Module):
    """Additive Gaussian noise layer; `std` can be a scalar or a per-dimension tensor."""
    def __init__(self, std, device):
        super().__init__()
        self.std = std
        self.device = device

    def forward(self, x):
        return x + torch.normal(torch.zeros_like(x), self.std)


def causes_sampler_f(num_causes):
    # Sample a mean and a (mean-scaled) standard deviation for each cause variable.
    means = np.random.normal(0, 1, (num_causes))
    std = np.abs(np.random.normal(0, 1, (num_causes)) * means)
    return means, std


def get_batch(batch_size, seq_len, num_features, hyperparameters, device=default_device, num_outputs=1, sampling='normal', **kwargs):
    if ('mix_activations' in hyperparameters) and hyperparameters['mix_activations']:
        # Draw the activation module once so the same instance is returned on every call.
        s = hyperparameters['prior_mlp_activations']()
        hyperparameters['prior_mlp_activations'] = lambda: s

    class MLP(torch.nn.Module):
        def __init__(self, hyperparameters):
            super(MLP, self).__init__()

            with torch.no_grad():
                # Expose every hyperparameter as an attribute on the module.
                for key in hyperparameters:
                    setattr(self, key, hyperparameters[key])

                assert self.num_layers >= 2

                if 'verbose' in hyperparameters and self.verbose:
                    print({k: hyperparameters[k] for k in ['is_causal', 'num_causes', 'prior_mlp_hidden_dim',
                                                           'num_layers', 'noise_std', 'y_is_effect', 'pre_sample_weights',
                                                           'prior_mlp_dropout_prob', 'pre_sample_causes']})

                if self.is_causal:
                    self.prior_mlp_hidden_dim = max(self.prior_mlp_hidden_dim, num_outputs + 2 * num_features)
                else:
                    self.num_causes = num_features

                # This means that the mean and standard deviation of each cause is determined in advance.
                if self.pre_sample_causes:
                    self.causes_mean, self.causes_std = causes_sampler_f(self.num_causes)
                    self.causes_mean = torch.tensor(self.causes_mean, device=device).unsqueeze(0).unsqueeze(0).tile(
                        (seq_len, 1, 1))
                    self.causes_std = torch.tensor(self.causes_std, device=device).unsqueeze(0).unsqueeze(0).tile(
                        (seq_len, 1, 1))

                def generate_module(layer_idx, out_dim):
                    # Determine the std of each noise term at initialization, so that it is shared across runs.
                    # torch.abs(torch.normal(torch.zeros((out_dim)), self.noise_std)) - Change std for each dimension?
                    noise = (GaussianNoise(torch.abs(torch.normal(torch.zeros(size=(1, out_dim), device=device), float(self.noise_std))), device=device)
                             if self.pre_sample_weights else GaussianNoise(float(self.noise_std), device=device))
                    return [
                        nn.Sequential(*[self.prior_mlp_activations(),
                                        nn.Linear(self.prior_mlp_hidden_dim, out_dim),
                                        noise])
                    ]

                self.layers = [nn.Linear(self.num_causes, self.prior_mlp_hidden_dim, device=device)]
                self.layers += [module for layer_idx in range(self.num_layers - 1)
                                for module in generate_module(layer_idx, self.prior_mlp_hidden_dim)]
                if not self.is_causal:
                    self.layers += generate_module(-1, num_outputs)
                self.layers = nn.Sequential(*self.layers)

                # Initialize model parameters
                for i, (n, p) in enumerate(self.layers.named_parameters()):
                    if self.block_wise_dropout:
                        if len(p.shape) == 2:  # Only apply to weight matrices and not bias
                            nn.init.zeros_(p)
                            # TODO: N blocks should be a setting
                            n_blocks = random.randint(1, math.ceil(math.sqrt(min(p.shape[0], p.shape[1]))))
                            w, h = p.shape[0] // n_blocks, p.shape[1] // n_blocks
                            keep_prob = (n_blocks * w * h) / p.numel()
                            for block in range(0, n_blocks):
                                nn.init.normal_(p[w * block: w * (block + 1), h * block: h * (block + 1)],
                                                std=self.init_std / keep_prob ** (1 / 2))
                    else:
                        if len(p.shape) == 2:  # Only apply to weight matrices and not bias
                            dropout_prob = self.prior_mlp_dropout_prob if i > 0 else 0.0  # Don't apply dropout in the first layer
                            dropout_prob = min(dropout_prob, 0.99)
                            nn.init.normal_(p, std=self.init_std / (1. - dropout_prob) ** (1 / 2))
                            p *= torch.bernoulli(torch.zeros_like(p) + 1. - dropout_prob)

        def forward(self):
            def sample_normal():
                if self.pre_sample_causes:
                    causes = torch.normal(self.causes_mean, self.causes_std.abs()).float()
                else:
                    causes = torch.normal(0., 1., (seq_len, 1, self.num_causes), device=device).float()
                return causes

            if self.sampling == 'normal':
                causes = sample_normal()
            elif self.sampling == 'mixed':
                zipf_p, multi_p, normal_p = random.random() * 0.66, random.random() * 0.66, random.random() * 0.66

                def sample_cause(n):
                    if random.random() > normal_p:
                        if self.pre_sample_causes:
                            return torch.normal(self.causes_mean[:, :, n], self.causes_std[:, :, n].abs()).float()
                        else:
                            return torch.normal(0., 1., (seq_len, 1), device=device).float()
                    elif random.random() > multi_p:
                        x = torch.multinomial(torch.rand((random.randint(2, 10))), seq_len, replacement=True).to(device).unsqueeze(-1).float()
                        x = (x - torch.mean(x)) / torch.std(x)
                        return x
                    else:
                        x = torch.minimum(torch.tensor(np.random.zipf(2.0 + random.random() * 2, size=(seq_len)),
                                                       device=device).unsqueeze(-1).float(), torch.tensor(10.0, device=device))
                        return x - torch.mean(x)

                causes = torch.cat([sample_cause(n).unsqueeze(-1) for n in range(self.num_causes)], -1)
            elif self.sampling == 'uniform':
                causes = torch.rand((seq_len, 1, self.num_causes), device=device)
            else:
                raise ValueError(f'Sampling is set to invalid setting: {self.sampling}.')

            # Propagate the causes through the MLP, keeping every intermediate activation.
            outputs = [causes]
            for layer in self.layers:
                outputs.append(layer(outputs[-1]))
            outputs = outputs[2:]

            if self.is_causal:
                ## Sample nodes from the graph if the model is causal
                outputs_flat = torch.cat(outputs, -1)

                if self.in_clique:
                    random_perm = random.randint(0, outputs_flat.shape[-1] - num_outputs - num_features) + torch.randperm(num_outputs + num_features, device=device)
                else:
                    random_perm = torch.randperm(outputs_flat.shape[-1] - 1, device=device)

                random_idx_y = list(range(-num_outputs, 0)) if self.y_is_effect else random_perm[0:num_outputs]
                random_idx = random_perm[num_outputs:num_outputs + num_features]

                if self.sort_features:
                    random_idx, _ = torch.sort(random_idx)

                y = outputs_flat[:, :, random_idx_y]
                x = outputs_flat[:, :, random_idx]
            else:
                y = outputs[-1][:, :, :]
                x = causes

            # Guard against numerical blow-ups: replace the whole sample if any NaNs appeared.
            if bool(torch.any(torch.isnan(x)).detach().cpu().numpy()) or bool(torch.any(torch.isnan(y)).detach().cpu().numpy()):
                x[:] = 0.0
                y[:] = 1.0

            return x, y

    model = MLP(hyperparameters).to(device)

    # Each forward pass yields one (x, y) sequence; draw one per batch element.
    sample = [model() for _ in range(0, batch_size)]

    x, y = zip(*sample)
    y = torch.cat(y, 1).detach().squeeze(2)
    x = torch.cat(x, 1).detach()
    x = x[..., torch.randperm(x.shape[-1])]

    return x, y, y

DataLoader = get_batch_to_dataloader(get_batch)
DataLoader.num_outputs = 1
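

if __name__ == '__main__':
    # Hypothetical usage sketch (not part of the original module): shows how get_batch might be
    # called directly. The hyperparameter values below are illustrative assumptions, not defaults
    # from the repository; only keys actually read by MLP above are set. Because of the relative
    # import, this block only runs when the file is executed as a module inside its package
    # (e.g. `python -m <package>.mlp`).
    example_hyperparameters = {
        'is_causal': True,
        'num_causes': 4,
        'prior_mlp_hidden_dim': 16,
        'num_layers': 3,
        'noise_std': 0.1,
        'y_is_effect': True,
        'pre_sample_weights': False,
        'prior_mlp_dropout_prob': 0.1,
        'pre_sample_causes': True,
        'block_wise_dropout': False,
        'init_std': 1.0,
        'sampling': 'normal',
        'in_clique': False,
        'sort_features': True,
        'mix_activations': False,
        'verbose': False,
        'prior_mlp_activations': lambda: nn.Tanh(),
    }
    x, y, target_y = get_batch(batch_size=8, seq_len=100, num_features=5,
                               hyperparameters=example_hyperparameters, device='cpu')
    # Expected shapes: x is (seq_len, batch_size, num_features), y is (seq_len, batch_size).
    print(x.shape, y.shape)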