import sys
import random
from datetime import datetime

import torch
import numpy as np


class Logger(object):
    """Duplicates everything written to the given stream (stdout by default) into a log file."""
    def __init__(self, logpath, syspart=sys.stdout):
        self.terminal = syspart
        self.log = open(logpath, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        # This flush method is needed for Python 3 compatibility.
        # It handles the flush command by doing nothing.
        # You might want to specify some extra behavior here.
        pass


def log(*args):
    print(f'[{datetime.now()}]', *args)


class EMA:
    """Exponential moving average of model parameters."""
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


def sum_except_batch(x):
    return x.reshape(x.size(0), -1).sum(dim=-1)


def remove_mean(x):
    mean = torch.mean(x, dim=1, keepdim=True)
    x = x - mean
    return x


def remove_mean_with_mask(x, node_mask):
    masked_max_abs_value = (x * (1 - node_mask)).abs().sum().item()
    assert masked_max_abs_value < 1e-5, f'Error {masked_max_abs_value} too high'
    N = node_mask.sum(1, keepdims=True)
    mean = torch.sum(x, dim=1, keepdim=True) / N
    x = x - mean * node_mask
    return x


def remove_partial_mean_with_mask(x, node_mask, center_of_mass_mask):
    """Subtract the center of mass of the fragments from the coordinates of all atoms."""
    x_masked = x * center_of_mass_mask
    N = center_of_mass_mask.sum(1, keepdims=True)
    mean = torch.sum(x_masked, dim=1, keepdim=True) / N
    x = x - mean * node_mask
    return x


def assert_mean_zero(x):
    mean = torch.mean(x, dim=1, keepdim=True)
    assert mean.abs().max().item() < 1e-4


def assert_mean_zero_with_mask(x, node_mask, eps=1e-10):
    assert_correctly_masked(x, node_mask)
    largest_value = x.abs().max().item()
    error = torch.sum(x, dim=1, keepdim=True).abs().max().item()
    rel_error = error / (largest_value + eps)
    assert rel_error < 1e-2, f'Mean is not zero, relative_error {rel_error}'


def assert_partial_mean_zero_with_mask(x, node_mask, center_of_mass_mask, eps=1e-10):
    assert_correctly_masked(x, node_mask)
    x_masked = x * center_of_mass_mask
    largest_value = x_masked.abs().max().item()
    error = torch.sum(x_masked, dim=1, keepdim=True).abs().max().item()
    rel_error = error / (largest_value + eps)
    assert rel_error < 1e-2, f'Partial mean is not zero, relative_error {rel_error}'


def assert_correctly_masked(variable, node_mask):
    assert (variable * (1 - node_mask)).abs().max().item() < 1e-4, \
        'Variables not masked properly.'


def check_mask_correct(variables, node_mask):
    for variable in variables:
        if len(variable) > 0:
            assert_correctly_masked(variable, node_mask)


def center_gravity_zero_gaussian_log_likelihood(x):
    assert len(x.size()) == 3
    B, N, D = x.size()
    assert_mean_zero(x)

    # r is invariant to a basis change in the relevant hyperplane.
    r2 = sum_except_batch(x.pow(2))

    # The relevant hyperplane is (N-1) * D dimensional.
    degrees_of_freedom = (N - 1) * D

    # Normalizing constant and logpx are computed:
    log_normalizing_constant = -0.5 * degrees_of_freedom * np.log(2 * np.pi)
    log_px = -0.5 * r2 + log_normalizing_constant
    return log_px


def sample_center_gravity_zero_gaussian(size, device):
    assert len(size) == 3
    x = torch.randn(size, device=device)

    # This projection only works because Gaussian is rotation invariant around
    # zero and samples are independent!
    x_projected = remove_mean(x)
    return x_projected


def center_gravity_zero_gaussian_log_likelihood_with_mask(x, node_mask):
    assert len(x.size()) == 3
    B, N_embedded, D = x.size()
    assert_mean_zero_with_mask(x, node_mask)

    # r is invariant to a basis change in the relevant hyperplane, the masked
    # out values will have zero contribution.
    r2 = sum_except_batch(x.pow(2))

    # The relevant hyperplane is (N-1) * D dimensional.
    N = node_mask.squeeze(2).sum(1)  # N has shape [B]
    degrees_of_freedom = (N - 1) * D

    # Normalizing constant and logpx are computed:
    log_normalizing_constant = -0.5 * degrees_of_freedom * np.log(2 * np.pi)
    log_px = -0.5 * r2 + log_normalizing_constant
    return log_px


def sample_center_gravity_zero_gaussian_with_mask(size, device, node_mask):
    assert len(size) == 3
    x = torch.randn(size, device=device)
    x_masked = x * node_mask

    # This projection only works because Gaussian is rotation invariant around
    # zero and samples are independent!
    # TODO: check it
    x_projected = remove_mean_with_mask(x_masked, node_mask)
    return x_projected


def standard_gaussian_log_likelihood(x):
    # Normalizing constant and logpx are computed:
    log_px = sum_except_batch(-0.5 * x * x - 0.5 * np.log(2 * np.pi))
    return log_px


def sample_gaussian(size, device):
    x = torch.randn(size, device=device)
    return x


def standard_gaussian_log_likelihood_with_mask(x, node_mask):
    # Normalizing constant and logpx are computed:
    log_px_elementwise = -0.5 * x * x - 0.5 * np.log(2 * np.pi)
    log_px = sum_except_batch(log_px_elementwise * node_mask)
    return log_px


def sample_gaussian_with_mask(size, device, node_mask):
    x = torch.randn(size, device=device)
    x_masked = x * node_mask
    return x_masked


def concatenate_features(x, h):
    xh = torch.cat([x, h['categorical']], dim=2)
    if 'integer' in h:
        xh = torch.cat([xh, h['integer']], dim=2)
    return xh


def split_features(z, n_dims, num_classes, include_charges):
    assert z.size(2) == n_dims + num_classes + include_charges
    x = z[:, :, 0:n_dims]
    h = {'categorical': z[:, :, n_dims:n_dims + num_classes]}
    if include_charges:
        h['integer'] = z[:, :, n_dims + num_classes:n_dims + num_classes + 1]
    return x, h


# For gradient clipping
class Queue:
    def __init__(self, max_len=50):
        self.items = []
        self.max_len = max_len

    def __len__(self):
        return len(self.items)

    def add(self, item):
        self.items.insert(0, item)
        if len(self) > self.max_len:
            self.items.pop()

    def mean(self):
        return np.mean(self.items)

    def std(self):
        return np.std(self.items)


def gradient_clipping(flow, gradnorm_queue):
    # Allow gradient norm to be 150% + 2 * stdev of the recent history.
    max_grad_norm = 1.5 * gradnorm_queue.mean() + 2 * gradnorm_queue.std()

    # Clips gradient and returns the norm
    grad_norm = torch.nn.utils.clip_grad_norm_(
        flow.parameters(), max_norm=max_grad_norm, norm_type=2.0)

    if float(grad_norm) > max_grad_norm:
        gradnorm_queue.add(float(max_grad_norm))
        print(f'Clipped gradient with value {grad_norm:.1f} while allowed {max_grad_norm:.1f}')
    else:
        gradnorm_queue.add(float(grad_norm))

    return grad_norm


def disable_rdkit_logging():
    """
    Disables RDKit whiny logging.
""" import rdkit.rdBase as rkrb import rdkit.RDLogger as rkl logger = rkl.logger() logger.setLevel(rkl.ERROR) rkrb.DisableLog('rdApp.error') def set_deterministic(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False class FoundNaNException(Exception): def __init__(self, x, h): x_nan_idx = self.find_nan_idx(x) h_nan_idx = self.find_nan_idx(h) self.x_h_nan_idx = x_nan_idx & h_nan_idx self.only_x_nan_idx = x_nan_idx.difference(h_nan_idx) self.only_h_nan_idx = h_nan_idx.difference(x_nan_idx) @staticmethod def find_nan_idx(z): idx = set() for i in range(z.shape[0]): if torch.any(torch.isnan(z[i])): idx.add(i) return idx def get_batch_idx_for_animation(batch_size, batch_idx): batch_indices = [] mol_indices = [] for idx in [0, 110, 360]: if idx // batch_size == batch_idx: batch_indices.append(idx % batch_size) mol_indices.append(idx) return batch_indices, mol_indices # Rotation data augmntation def random_rotation(x): bs, n_nodes, n_dims = x.size() device = x.device angle_range = np.pi * 2 if n_dims == 2: theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi cos_theta = torch.cos(theta) sin_theta = torch.sin(theta) R_row0 = torch.cat([cos_theta, -sin_theta], dim=2) R_row1 = torch.cat([sin_theta, cos_theta], dim=2) R = torch.cat([R_row0, R_row1], dim=1) x = x.transpose(1, 2) x = torch.matmul(R, x) x = x.transpose(1, 2) elif n_dims == 3: # Build Rx Rx = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).to(device) theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi cos = torch.cos(theta) sin = torch.sin(theta) Rx[:, 1:2, 1:2] = cos Rx[:, 1:2, 2:3] = sin Rx[:, 2:3, 1:2] = - sin Rx[:, 2:3, 2:3] = cos # Build Ry Ry = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).to(device) theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi cos = torch.cos(theta) sin = torch.sin(theta) Ry[:, 0:1, 0:1] = cos Ry[:, 0:1, 2:3] = -sin Ry[:, 2:3, 0:1] = sin Ry[:, 2:3, 2:3] = cos # Build Rz Rz = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).to(device) theta = torch.rand(bs, 1, 1).to(device) * angle_range - np.pi cos = torch.cos(theta) sin = torch.sin(theta) Rz[:, 0:1, 0:1] = cos Rz[:, 0:1, 1:2] = sin Rz[:, 1:2, 0:1] = -sin Rz[:, 1:2, 1:2] = cos x = x.transpose(1, 2) x = torch.matmul(Rx, x) #x = torch.matmul(Rx.transpose(1, 2), x) x = torch.matmul(Ry, x) #x = torch.matmul(Ry.transpose(1, 2), x) x = torch.matmul(Rz, x) #x = torch.matmul(Rz.transpose(1, 2), x) x = x.transpose(1, 2) else: raise Exception("Not implemented Error") return x.contiguous()