# Spaces:
# Build error
# Build error
import random | |
import math | |
import torch | |
from torch import nn | |
import numpy as np | |
from utils import default_device | |
from .utils import get_batch_to_dataloader | |
class GaussianNoise(nn.Module):
    """Layer that adds zero-mean Gaussian noise to its input.

    ``std`` is handed unchanged to ``torch.normal`` as the standard deviation,
    so it may be a Python float or a tensor. ``device`` is kept on the
    instance but is not used by ``forward``.
    """

    def __init__(self, std, device):
        super().__init__()
        self.device = device
        self.std = std

    def forward(self, x):
        """Return ``x`` plus freshly drawn noise with standard deviation ``self.std``."""
        zero_mean = torch.zeros_like(x)
        perturbation = torch.normal(zero_mean, self.std)
        return x + perturbation
def causes_sampler_f(num_causes):
    """Draw a mean and a standard deviation for each of ``num_causes`` causes.

    The means are i.i.d. standard-normal draws. Each standard deviation is the
    absolute value of its mean multiplied by an independent standard-normal
    draw.

    Returns:
        Tuple ``(means, std)`` of 1-D numpy arrays of length ``num_causes``.
    """
    mean_draws = np.random.normal(0, 1, num_causes)
    scale_draws = np.random.normal(0, 1, num_causes)
    return mean_draws, np.abs(scale_draws * mean_draws)
def get_batch(batch_size, seq_len, num_features, hyperparameters, device=default_device, num_outputs=1, sampling='normal', **kwargs):
    """Sample a batch of synthetic datasets from a randomly initialised MLP prior.

    One MLP is built from ``hyperparameters`` and then run ``batch_size``
    times. Every run draws fresh input "causes" and yields one dataset of
    ``seq_len`` rows.

    Args:
        batch_size: Number of datasets (forward passes) in the batch.
        seq_len: Number of rows per dataset.
        num_features: Number of feature columns in ``x``.
        hyperparameters: Mapping whose entries are copied onto the MLP as
            attributes. Keys read below include ``num_layers``, ``is_causal``,
            ``num_causes``, ``prior_mlp_hidden_dim``, ``prior_mlp_activations``,
            ``noise_std``, ``init_std``, ``pre_sample_causes``,
            ``pre_sample_weights``, ``block_wise_dropout``,
            ``prior_mlp_dropout_prob``, ``in_clique``, ``y_is_effect``,
            ``sort_features`` and, optionally, ``sampling``,
            ``mix_activations`` and ``verbose``. The caller's mapping is not
            modified.
        device: Device on which the MLP and the sampled tensors are created.
        num_outputs: Number of target columns per row.
        sampling: How causes are drawn (``'normal'``, ``'mixed'`` or
            ``'uniform'``). Used only when ``hyperparameters`` has no
            ``'sampling'`` entry.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        ``(x, y, y)``:
        - ``x`` has shape ``(seq_len, batch_size, num_features)``. Its feature
          columns are randomly permuted.
        - ``y`` has shape ``(seq_len, batch_size, num_outputs)``, with dim 2
          squeezed away when ``num_outputs == 1``. ``y`` is returned twice.

    Raises:
        AssertionError: If ``num_layers < 2``.
        ValueError: If the sampling mode is not ``'normal'``, ``'mixed'`` or
            ``'uniform'``.
    """
    # Work on a shallow copy, built with the same iterate/index protocol used
    # below, so the caller's mapping is never mutated. Writing the frozen
    # activation back into it would pin the activation chosen on the first
    # call for every later call that reuses the same mapping.
    hyperparameters = {key: hyperparameters[key] for key in hyperparameters}
    if hyperparameters.get('mix_activations'):
        # NOTE(review): when this flag is set, a single activation instance is
        # sampled here and shared by every layer. The flag's name suggests
        # per-layer mixing, so confirm the condition is not inverted.
        shared_activation = hyperparameters['prior_mlp_activations']()
        hyperparameters['prior_mlp_activations'] = lambda: shared_activation

    class MLP(torch.nn.Module):
        """Randomly initialised MLP whose forward pass samples one dataset."""

        def __init__(self, hyperparameters):
            super().__init__()

            with torch.no_grad():
                # The function argument is the default. A 'sampling'
                # hyperparameter, copied in the loop below, overrides it.
                self.sampling = sampling
                for key in hyperparameters:
                    setattr(self, key, hyperparameters[key])

                assert (self.num_layers >= 2)

                if 'verbose' in hyperparameters and self.verbose:
                    print({k: hyperparameters[k] for k in ['is_causal', 'num_causes', 'prior_mlp_hidden_dim',
                                                           'num_layers', 'noise_std', 'y_is_effect',
                                                           'pre_sample_weights', 'prior_mlp_dropout_prob',
                                                           'pre_sample_causes']})

                if self.is_causal:
                    # Keep the hidden layers wide enough for forward() to pick
                    # num_outputs target nodes and num_features feature nodes.
                    self.prior_mlp_hidden_dim = max(self.prior_mlp_hidden_dim, num_outputs + 2 * num_features)
                else:
                    self.num_causes = num_features

                # This means that the mean and standard deviation of each cause is determined in advance
                if self.pre_sample_causes:
                    self.causes_mean, self.causes_std = causes_sampler_f(self.num_causes)
                    self.causes_mean = torch.tensor(self.causes_mean, device=device).unsqueeze(0).unsqueeze(0).tile(
                        (seq_len, 1, 1))
                    self.causes_std = torch.tensor(self.causes_std, device=device).unsqueeze(0).unsqueeze(0).tile(
                        (seq_len, 1, 1))

                def generate_module(layer_idx, out_dim):
                    """Return ``[Sequential(activation, Linear(hidden -> out_dim), noise)]``."""
                    # With pre_sample_weights, the noise std is drawn once per
                    # output dimension at construction time (|N(0, noise_std)|),
                    # so it is shared by every forward pass. Otherwise the
                    # scalar noise_std is used.
                    if self.pre_sample_weights:
                        noise_std = torch.abs(torch.normal(torch.zeros(size=(1, out_dim), device=device),
                                                           float(self.noise_std)))
                    else:
                        noise_std = float(self.noise_std)
                    noise = GaussianNoise(noise_std, device=device)
                    return [
                        nn.Sequential(*[self.prior_mlp_activations(),
                                        nn.Linear(self.prior_mlp_hidden_dim, out_dim),
                                        noise])
                    ]

                layers = [nn.Linear(self.num_causes, self.prior_mlp_hidden_dim, device=device)]
                layers += [module for layer_idx in range(self.num_layers - 1)
                           for module in generate_module(layer_idx, self.prior_mlp_hidden_dim)]
                if not self.is_causal:
                    layers += generate_module(-1, num_outputs)
                self.layers = nn.Sequential(*layers)

                # Initialize Model parameters. `i` counts every parameter,
                # biases included, so i == 0 is the first Linear weight.
                for i, (_, p) in enumerate(self.layers.named_parameters()):
                    if len(p.shape) != 2:  # Only apply to weight matrices and not bias
                        continue
                    if self.block_wise_dropout:
                        # Zero the matrix, then re-initialise only n_blocks
                        # diagonal blocks. Their std is scaled by
                        # 1/sqrt(keep_prob) to account for the zeroed entries.
                        nn.init.zeros_(p)
                        # TODO: N blocks should be a setting
                        n_blocks = random.randint(1, math.ceil(math.sqrt(min(p.shape[0], p.shape[1]))))
                        w, h = p.shape[0] // n_blocks, p.shape[1] // n_blocks
                        keep_prob = (n_blocks * w * h) / p.numel()
                        for block in range(n_blocks):
                            nn.init.normal_(p[w * block: w * (block + 1), h * block: h * (block + 1)],
                                            std=self.init_std / keep_prob ** (1 / 2))
                    else:
                        dropout_prob = self.prior_mlp_dropout_prob if i > 0 else 0.0  # Don't apply dropout in first layer
                        dropout_prob = min(dropout_prob, 0.99)
                        nn.init.normal_(p, std=self.init_std / (1. - dropout_prob) ** (1 / 2))
                        # Randomly zero weights; in-place on a parameter, hence inside no_grad.
                        p *= torch.bernoulli(torch.zeros_like(p) + 1. - dropout_prob)

        def forward(self):
            """Sample one dataset and return ``(x, y)``, each with a size-1 batch dim (dim 1)."""
            def sample_normal():
                if self.pre_sample_causes:
                    return torch.normal(self.causes_mean, self.causes_std.abs()).float()
                return torch.normal(0., 1., (seq_len, 1, self.num_causes), device=device).float()

            if self.sampling == 'normal':
                causes = sample_normal()
            elif self.sampling == 'mixed':
                # zipf_p is unused. Its draw is kept so the random stream is unchanged.
                zipf_p, multi_p, normal_p = random.random() * 0.66, random.random() * 0.66, random.random() * 0.66

                def sample_cause(n):
                    # Each cause column is one of:
                    # - normal,
                    # - a standardised categorical (multinomial) column,
                    # - a centred Zipf column capped at 10.
                    if random.random() > normal_p:
                        if self.pre_sample_causes:
                            return torch.normal(self.causes_mean[:, :, n], self.causes_std[:, :, n].abs()).float()
                        return torch.normal(0., 1., (seq_len, 1), device=device).float()
                    if random.random() > multi_p:
                        x = torch.multinomial(torch.rand((random.randint(2, 10))), seq_len,
                                              replacement=True).to(device).unsqueeze(-1).float()
                        return (x - torch.mean(x)) / torch.std(x)
                    x = torch.minimum(torch.tensor(np.random.zipf(2.0 + random.random() * 2, size=(seq_len)),
                                                   device=device).unsqueeze(-1).float(),
                                      torch.tensor(10.0, device=device))
                    return x - torch.mean(x)

                causes = torch.cat([sample_cause(n).unsqueeze(-1) for n in range(self.num_causes)], -1)
            elif self.sampling == 'uniform':
                causes = torch.rand((seq_len, 1, self.num_causes), device=device)
            else:
                raise ValueError(f'Sampling is set to invalid setting: {self.sampling}.')

            # Run the causes through every layer, keeping each intermediate output.
            outputs = [causes]
            for layer in self.layers:
                outputs.append(layer(outputs[-1]))
            # Drop the raw causes and the output of the first Linear layer.
            outputs = outputs[2:]

            if self.is_causal:
                ## Sample nodes from graph if model is causal
                outputs_flat = torch.cat(outputs, -1)
                if self.in_clique:
                    # A shuffled contiguous window of num_outputs + num_features nodes.
                    random_perm = random.randint(0, outputs_flat.shape[-1] - num_outputs - num_features) \
                        + torch.randperm(num_outputs + num_features, device=device)
                else:
                    # Excludes the last node, which is the y_is_effect target
                    # when num_outputs == 1. Presumably this keeps it out of
                    # the features; confirm.
                    random_perm = torch.randperm(outputs_flat.shape[-1] - 1, device=device)
                # With y_is_effect, the targets are the last num_outputs nodes.
                random_idx_y = list(range(-num_outputs, 0)) if self.y_is_effect else random_perm[0:num_outputs]
                random_idx = random_perm[num_outputs:num_outputs + num_features]
                if self.sort_features:
                    random_idx, _ = torch.sort(random_idx)
                y = outputs_flat[:, :, random_idx_y]
                x = outputs_flat[:, :, random_idx]
            else:
                y = outputs[-1][:, :, :]
                x = causes

            # Replace a degenerate (NaN-containing) sample with constant values.
            if torch.isnan(x).any() or torch.isnan(y).any():
                x[:] = 0.0
                y[:] = 1.0

            return x, y

    model = MLP(hyperparameters).to(device)

    # One forward pass per dataset. The results are detached below, so no
    # autograd graph is needed.
    with torch.no_grad():
        samples = [model() for _ in range(batch_size)]

    x, y = zip(*samples)
    y = torch.cat(y, 1).detach().squeeze(2)
    x = torch.cat(x, 1).detach()
    # Shuffle feature columns (one permutation shared by the whole batch).
    x = x[..., torch.randperm(x.shape[-1])]

    return x, y, y
def _build_dataloader():
    """Wrap ``get_batch`` via ``get_batch_to_dataloader`` and mark it as having one output."""
    loader = get_batch_to_dataloader(get_batch)
    # Set on the returned object itself; presumably read by training code to size targets -- confirm.
    loader.num_outputs = 1
    return loader


DataLoader = _build_dataloader()