Spaces:
Sleeping
Sleeping
import sys | |
import random | |
from datetime import datetime | |
import torch | |
import numpy as np | |
class Logger(object):
    """Tee-style writer: forwards every message to a stream and appends it to a file.

    Can be used as a context manager so that the log file is closed on exit.

    Args:
        logpath: Path of the log file; it is opened in append mode.
        syspart: Stream that receives a copy of every message
            (defaults to the ``sys.stdout`` object bound at import time).
    """

    def __init__(self, logpath, syspart=sys.stdout):
        self.terminal = syspart
        self.log = open(logpath, "a")

    def write(self, message):
        """Write ``message`` to the terminal stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush right away so the file on disk stays current after every write.
        self.log.flush()

    def flush(self):
        """Flush both underlying streams (needed for file-like compatibility)."""
        self.terminal.flush()
        # The file may already be closed, e.g. when flush() is called at shutdown.
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is left open. Safe to call twice."""
        if not self.log.closed:
            self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Returning False lets any exception from the with-body propagate.
        return False
def log(*args):
    """Print ``args`` preceded by the current ``datetime.now()`` value in square brackets."""
    timestamp = datetime.now()
    print(f'[{timestamp}]', *args)
class EMA:
    """Exponential moving average helper with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend the parameters of ``current_model`` into ``ma_model`` in place.

        Parameters are paired positionally through ``.parameters()`` of both
        models; each averaged parameter's ``.data`` is replaced by the result
        of :meth:`update_average`.
        """
        param_pairs = zip(current_model.parameters(), ma_model.parameters())
        for source, target in param_pairs:
            target.data = self.update_average(target.data, source.data)

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` when ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new
def sum_except_batch(x):
    """Sum ``x`` over everything except dim 0, giving one value per leading index."""
    batch_size = x.size(0)
    flattened = x.reshape(batch_size, -1)
    return flattened.sum(dim=-1)
def remove_mean(x):
    """Return ``x`` with its mean along dim 1 subtracted (the mean is broadcast back)."""
    centroid = x.mean(dim=1, keepdim=True)
    return x - centroid
def remove_mean_with_mask(x, node_mask):
    """Subtract the masked mean along dim 1 from the entries selected by ``node_mask``.

    ``x`` must already be zero wherever ``node_mask`` is zero; this is checked
    by asserting that the summed absolute value outside the mask is below 1e-5.
    The mean is the sum over dim 1 divided by the mask count along dim 1.
    Presumably ``node_mask`` is a 0/1 tensor broadcastable against ``x``
    (e.g. [B, N, 1]) — confirm against callers.

    Raises:
        AssertionError: if ``x`` has non-zero content outside the mask.
    """
    outside_mask = x * (1 - node_mask)
    leaked_total = outside_mask.abs().sum().item()
    assert leaked_total < 1e-5, f'Error {leaked_total} too high'

    node_count = node_mask.sum(1, keepdim=True)
    centroid = x.sum(dim=1, keepdim=True) / node_count
    # Multiply by the mask so masked-out positions stay exactly as they were.
    return x - centroid * node_mask
def remove_partial_mean_with_mask(x, node_mask, center_of_mass_mask):
    """
    Subtract the center of mass of the fragments from the coordinates of all atoms.

    The center is the mean along dim 1 of the entries selected by
    ``center_of_mass_mask``; it is subtracted from every entry selected by
    ``node_mask``.
    """
    selected = x * center_of_mass_mask
    selected_count = center_of_mass_mask.sum(1, keepdim=True)
    center = selected.sum(dim=1, keepdim=True) / selected_count
    return x - center * node_mask
def assert_mean_zero(x):
    """Assert that the mean of ``x`` along dim 1 is below 1e-4 in absolute value everywhere."""
    largest_abs_mean = x.mean(dim=1, keepdim=True).abs().max().item()
    assert largest_abs_mean < 1e-4
def assert_mean_zero_with_mask(x, node_mask, eps=1e-10):
    """Assert that ``x`` is properly masked and sums to (relatively) zero along dim 1.

    The largest absolute sum along dim 1 is divided by the largest absolute
    entry of ``x`` (plus ``eps`` to avoid division by zero); that ratio must
    stay below 1e-2.

    Raises:
        AssertionError: if the masking check or the relative-error check fails.
    """
    assert_correctly_masked(x, node_mask)
    scale = x.abs().max().item()
    abs_sum = x.sum(dim=1, keepdim=True).abs().max().item()
    rel_error = abs_sum / (scale + eps)
    assert rel_error < 1e-2, f'Mean is not zero, relative_error {rel_error}'
def assert_partial_mean_zero_with_mask(x, node_mask, center_of_mass_mask, eps=1e-10):
    """Assert that the entries of ``x`` selected by ``center_of_mass_mask`` sum to ~0 along dim 1.

    ``x`` is first checked against ``node_mask``. The largest absolute sum of
    the selected entries along dim 1, relative to their largest absolute value
    (plus ``eps``), must stay below 1e-2.

    Raises:
        AssertionError: if the masking check or the relative-error check fails.
    """
    assert_correctly_masked(x, node_mask)
    selected = x * center_of_mass_mask
    scale = selected.abs().max().item()
    abs_sum = selected.sum(dim=1, keepdim=True).abs().max().item()
    rel_error = abs_sum / (scale + eps)
    assert rel_error < 1e-2, f'Partial mean is not zero, relative_error {rel_error}'
def assert_correctly_masked(variable, node_mask):
    """Assert that ``variable`` is (numerically) zero wherever ``node_mask`` is zero.

    Raises:
        AssertionError: if any entry outside the mask reaches 1e-4 in absolute value.
    """
    leaked = variable * (1 - node_mask)
    max_leak = leaked.abs().max().item()
    assert max_leak < 1e-4, 'Variables not masked properly.'
def check_mask_correct(variables, node_mask):
    """Run ``assert_correctly_masked`` on every non-empty entry of ``variables``."""
    for variable in variables:
        # Entries of length zero are skipped.
        if len(variable) == 0:
            continue
        assert_correctly_masked(variable, node_mask)
def center_gravity_zero_gaussian_log_likelihood(x):
    """Log-density of a standard Gaussian restricted to the zero-mean hyperplane.

    ``x`` must be 3-dimensional ([B, N, D]) with zero mean along dim 1.
    Returns one log-likelihood value per batch element.
    """
    assert len(x.size()) == 3
    _, n_nodes, n_dims = x.size()
    assert_mean_zero(x)

    # The squared radius is invariant to a basis change within the hyperplane.
    squared_radius = sum_except_batch(x.pow(2))

    # The zero-mean hyperplane has (N-1) * D dimensions.
    degrees_of_freedom = (n_nodes - 1) * n_dims

    log_normalizing_constant = -0.5 * degrees_of_freedom * np.log(2 * np.pi)
    return -0.5 * squared_radius + log_normalizing_constant
def sample_center_gravity_zero_gaussian(size, device):
    """Draw standard-normal noise of 3-dimensional ``size`` and remove its mean along dim 1."""
    assert len(size) == 3
    noise = torch.randn(size, device=device)
    # Projecting by mean removal is valid only because the Gaussian is
    # rotation invariant around zero and the samples are independent.
    return remove_mean(noise)
def center_gravity_zero_gaussian_log_likelihood_with_mask(x, node_mask):
    """Masked variant of the zero-center-of-gravity Gaussian log-likelihood.

    ``x`` must be 3-dimensional, correctly masked, and sum to zero along dim 1.
    The node count per sample is taken from ``node_mask`` (squeezed on dim 2,
    so it is expected to be [B, N, 1]). Returns one value per batch element.
    """
    assert len(x.size()) == 3
    n_dims = x.size(2)
    assert_mean_zero_with_mask(x, node_mask)

    # The squared radius is invariant to a basis change within the hyperplane;
    # masked-out values are zero and therefore contribute nothing.
    squared_radius = sum_except_batch(x.pow(2))

    # The zero-mean hyperplane has (N-1) * D dimensions, with N counted per sample.
    node_count = node_mask.squeeze(2).sum(1)  # shape [B]
    degrees_of_freedom = (node_count - 1) * n_dims

    log_normalizing_constant = -0.5 * degrees_of_freedom * np.log(2 * np.pi)
    return -0.5 * squared_radius + log_normalizing_constant
def sample_center_gravity_zero_gaussian_with_mask(size, device, node_mask):
    """Draw masked standard-normal noise of 3-dimensional ``size`` and remove its masked mean."""
    assert len(size) == 3
    noise = torch.randn(size, device=device)
    masked_noise = noise * node_mask
    # Projecting by mean removal is valid only because the Gaussian is
    # rotation invariant around zero and the samples are independent.
    # TODO: check it
    return remove_mean_with_mask(masked_noise, node_mask)
def standard_gaussian_log_likelihood(x):
    """Standard-normal log-density of ``x``, summed over all non-batch dimensions."""
    # Per-element log-density: -x^2 / 2 - log(2*pi) / 2.
    log_px_elementwise = -0.5 * x * x - 0.5 * np.log(2 * np.pi)
    return sum_except_batch(log_px_elementwise)
def sample_gaussian(size, device):
    """Return a standard-normal tensor of shape ``size`` created on ``device``."""
    return torch.randn(size, device=device)
def standard_gaussian_log_likelihood_with_mask(x, node_mask):
    """Standard-normal log-density of ``x``, summed per batch element over unmasked entries."""
    # Per-element log-density: -x^2 / 2 - log(2*pi) / 2.
    elementwise = -0.5 * x * x - 0.5 * np.log(2 * np.pi)
    # Zero out masked positions before summing.
    masked_elementwise = elementwise * node_mask
    return sum_except_batch(masked_elementwise)
def sample_gaussian_with_mask(size, device, node_mask):
    """Return standard-normal noise of shape ``size`` on ``device``, multiplied by ``node_mask``."""
    noise = torch.randn(size, device=device)
    return noise * node_mask
def concatenate_features(x, h):
    """Concatenate ``x`` with ``h['categorical']`` and, if present, ``h['integer']`` along dim 2."""
    parts = [x, h['categorical']]
    if 'integer' in h:
        parts.append(h['integer'])
    return torch.cat(parts, dim=2)
def split_features(z, n_dims, num_classes, include_charges):
    """Split ``z`` along dim 2 into coordinates and a feature dict.

    Returns:
        ``(x, h)`` where ``x`` is the first ``n_dims`` channels and ``h`` holds
        ``'categorical'`` (the next ``num_classes`` channels) plus, when
        ``include_charges`` is truthy, ``'integer'`` (the single channel after those).

    Raises:
        AssertionError: if ``z.size(2)`` differs from
            ``n_dims + num_classes + include_charges``.
    """
    expected_width = n_dims + num_classes + include_charges
    assert z.size(2) == expected_width

    categorical_end = n_dims + num_classes
    x = z[:, :, :n_dims]
    h = {'categorical': z[:, :, n_dims:categorical_end]}
    if include_charges:
        h['integer'] = z[:, :, categorical_end:categorical_end + 1]
    return x, h
# For gradient clipping | |
class Queue:
    """Bounded history of the most recent values, newest first.

    Once more than ``max_len`` items are stored, the oldest one is dropped.
    """

    def __init__(self, max_len=50):
        self.items = []
        self.max_len = max_len

    def __len__(self):
        return len(self.items)

    def add(self, item):
        """Insert ``item`` at the front and evict the oldest entry if over capacity."""
        self.items.insert(0, item)
        if len(self.items) > self.max_len:
            # The oldest entry sits at the end of the list.
            del self.items[-1]

    def mean(self):
        """Return ``np.mean`` of the stored items."""
        return np.mean(self.items)

    def std(self):
        """Return ``np.std`` of the stored items."""
        return np.std(self.items)
def gradient_clipping(flow, gradnorm_queue):
    """Clip the gradients of ``flow`` adaptively, based on recent gradient norms.

    The allowed norm is 150% of the mean plus two standard deviations of the
    values in ``gradnorm_queue``. The observed norm (capped at the allowed
    norm) is then added to the queue.

    Returns:
        The total gradient norm as reported by ``clip_grad_norm_``.
    """
    max_grad_norm = 1.5 * gradnorm_queue.mean() + 2 * gradnorm_queue.std()

    # clip_grad_norm_ clips the gradients in place and returns the total norm.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        flow.parameters(), max_norm=max_grad_norm, norm_type=2.0)

    was_clipped = float(grad_norm) > max_grad_norm
    # Record the capped value so a single spike does not inflate the history.
    recorded = float(max_grad_norm) if was_clipped else float(grad_norm)
    gradnorm_queue.add(recorded)

    if was_clipped:
        print(f'Clipped gradient with value {grad_norm:.1f} while allowed {max_grad_norm:.1f}')
    return grad_norm
def disable_rdkit_logging():
    """
    Silence RDKit logging: set the RDKit logger level to ERROR and disable
    the 'rdApp.error' log channel.
    """
    # Imported here so RDKit is only needed when this function is called.
    import rdkit.rdBase as rd_base
    import rdkit.RDLogger as rd_logger

    rdkit_logger = rd_logger.logger()
    rdkit_logger.setLevel(rd_logger.ERROR)
    rd_base.DisableLog('rdApp.error')
def set_deterministic(seed):
    """Seed the Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    CUDA generators are seeded only when CUDA is available.
    """
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # NOTE(review): these flags are set unconditionally here; confirm that they
    # were not meant to sit inside the CUDA-availability check.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
class FoundNaNException(Exception):
    """Raised when NaNs are found; records which batch indices are affected.

    Attributes (all sets of indices along dim 0):
        x_h_nan_idx: indices where both ``x`` and ``h`` contain NaNs.
        only_x_nan_idx: indices where only ``x`` contains NaNs.
        only_h_nan_idx: indices where only ``h`` contains NaNs.
    """

    def __init__(self, x, h):
        x_nan_idx = self.find_nan_idx(x)
        h_nan_idx = self.find_nan_idx(h)

        self.x_h_nan_idx = x_nan_idx & h_nan_idx
        self.only_x_nan_idx = x_nan_idx.difference(h_nan_idx)
        self.only_h_nan_idx = h_nan_idx.difference(x_nan_idx)

    # Must be a staticmethod: it takes no ``self`` but is called through the
    # instance in __init__, which would otherwise pass two positional arguments.
    @staticmethod
    def find_nan_idx(z):
        """Return the set of indices ``i`` for which ``z[i]`` contains at least one NaN."""
        return {i for i in range(z.shape[0]) if torch.any(torch.isnan(z[i]))}
def get_batch_idx_for_animation(batch_size, batch_idx, indices=(0, 110, 360)):
    """Find which of the selected global sample indices fall into a given batch.

    Args:
        batch_size: Number of samples per batch.
        batch_idx: Index of the batch being processed.
        indices: Global sample indices to look for. Defaults to the previously
            hard-coded selection ``(0, 110, 360)``.

    Returns:
        ``(batch_indices, mol_indices)``: positions inside the batch of the
        selected samples, and the matching global indices, in the order given
        by ``indices``.
    """
    batch_indices = []
    mol_indices = []
    for idx in indices:
        # A global index idx lives in batch idx // batch_size at offset idx % batch_size.
        owner_batch, offset = divmod(idx, batch_size)
        if owner_batch == batch_idx:
            batch_indices.append(offset)
            mol_indices.append(idx)
    return batch_indices, mol_indices
# Rotation data augmentation
def _random_angles(bs, device):
    """Draw one angle per batch element, uniform in [-pi, pi), with shape [bs, 1, 1]."""
    return torch.rand(bs, 1, 1).to(device) * (np.pi * 2) - np.pi


def _random_plane_rotation(bs, device, i, j):
    """Build a batch [bs, 3, 3] of random rotations acting in the (i, j) coordinate plane."""
    theta = _random_angles(bs, device).reshape(bs)
    cos, sin = torch.cos(theta), torch.sin(theta)
    R = torch.eye(3, device=device).unsqueeze(0).repeat(bs, 1, 1)
    R[:, i, i] = cos
    R[:, i, j] = sin
    R[:, j, i] = -sin
    R[:, j, j] = cos
    return R


def random_rotation(x):
    """Apply an independent random rotation to every sample of a batch of points.

    Args:
        x: Tensor of shape [batch, n_nodes, n_dims] with ``n_dims`` equal to 2 or 3.

    Returns:
        A contiguous tensor of the same shape. In 2D each sample is rotated by
        one random angle; in 3D by random rotations about the x, y and z axes,
        applied in that order.

    Raises:
        NotImplementedError: if ``n_dims`` is neither 2 nor 3.
    """
    bs, n_nodes, n_dims = x.size()
    device = x.device

    if n_dims == 2:
        theta = _random_angles(bs, device)
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)
        R_row0 = torch.cat([cos_theta, -sin_theta], dim=2)
        R_row1 = torch.cat([sin_theta, cos_theta], dim=2)
        rotations = [torch.cat([R_row0, R_row1], dim=1)]
    elif n_dims == 3:
        # Angles are drawn in the order x, y, z so that seeded runs consume the
        # random stream in a fixed order.
        rotations = [
            _random_plane_rotation(bs, device, 1, 2),  # about the x axis
            _random_plane_rotation(bs, device, 2, 0),  # about the y axis
            _random_plane_rotation(bs, device, 0, 1),  # about the z axis
        ]
    else:
        raise NotImplementedError(
            f"random_rotation supports only 2 or 3 dimensions, got {n_dims}")

    # Work on [bs, n_dims, n_nodes] so each rotation is a plain batched matmul.
    x = x.transpose(1, 2)
    for R in rotations:
        # Rotations are built in float32; match the input dtype so that e.g.
        # float64 coordinates do not fail in matmul.
        if x.is_floating_point():
            R = R.to(x.dtype)
        x = torch.matmul(R, x)
    x = x.transpose(1, 2)
    return x.contiguous()