DiffLinker / src /linker_size.py
igashov's picture
updated code
88b37fb
raw
history blame
2.81 kB
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from src.egnn import GCL
class DistributionNodes:
    """Categorical distribution over graph sizes (node counts) built from a histogram.

    ``histogram`` is iterated and indexed as ``histogram[nodes]`` for each
    ``nodes`` it yields, i.e. a mapping ``{node_count: frequency}``. The
    frequencies are normalised to probabilities.

    Attributes:
        n_nodes: tensor of node counts, in histogram iteration order.
        keys: dict mapping each node count to its position in ``n_nodes``.
        prob: float32 tensor of normalised probabilities aligned with ``n_nodes``.
        m: ``Categorical`` over positions in ``n_nodes``, used by ``sample``.
    """

    def __init__(self, histogram):
        """Build the distribution from ``histogram``.

        Raises:
            ValueError: if ``histogram`` is empty or its counts do not sum to a
                positive value (normalisation would otherwise yield NaN/empty
                probabilities and an obscure error from ``Categorical``).
        """
        n_nodes = []
        prob = []
        self.keys = {}
        for i, nodes in enumerate(histogram):
            n_nodes.append(nodes)
            self.keys[nodes] = i
            prob.append(histogram[nodes])

        prob = np.array(prob)
        total = np.sum(prob)
        # `not total > 0` also rejects a NaN total.
        if prob.size == 0 or not total > 0:
            raise ValueError("histogram must be non-empty with a positive total count")
        prob = prob / total

        self.n_nodes = torch.tensor(n_nodes)
        self.prob = torch.from_numpy(prob).float()
        # Sampling uses the float64 probabilities; `self.prob` is the float32
        # copy used by `log_prob`.
        self.m = Categorical(torch.tensor(prob))

    def sample(self, n_samples=1):
        """Draw ``n_samples`` node counts; returns a 1-D tensor of length ``n_samples``."""
        idx = self.m.sample((n_samples,))
        return self.n_nodes[idx]

    def log_prob(self, batch_n_nodes):
        """Return the log-probability of each node count in ``batch_n_nodes``.

        Args:
            batch_n_nodes: 1-D tensor of node counts; every value must be a
                node count present in the histogram.

        Returns:
            Float32 tensor with one entry per input element, on
            ``batch_n_nodes.device``. An empty input yields an empty tensor.

        Raises:
            ValueError: if ``batch_n_nodes`` is not 1-D.
            KeyError: if a node count is not present in the histogram.
        """
        if batch_n_nodes.dim() != 1:
            raise ValueError(
                f"batch_n_nodes must be 1-D, got shape {tuple(batch_n_nodes.shape)}"
            )
        device = batch_n_nodes.device
        # Explicit long dtype so an empty batch still forms a valid index tensor.
        idcs = torch.tensor(
            [self.keys[i.item()] for i in batch_n_nodes], dtype=torch.long, device=device
        )
        # The tiny epsilon keeps log() finite for zero-probability bins.
        log_p = torch.log(self.prob + 1e-30).to(device)
        return log_p[idcs]
class SizeGNN(nn.Module):
    """Graph network mapping per-node features to per-node outputs.

    Pipeline: linear input embedding -> ``gcl1`` -> ``n_layers - 1`` further
    GCL layers (held in ``gcl_layers``) -> linear output head. Every GCL is
    built with the same configuration: ``hidden_nf`` in/out/hidden width,
    ``normalization_factor=1``, ``aggregation_method='sum'``, ``edges_in_d=1``,
    a ReLU activation, no attention, and the caller's ``normalization``.
    """

    def __init__(self, in_node_nf, hidden_nf, out_node_nf, n_layers, normalization, device='cpu'):
        super().__init__()
        self.hidden_nf = hidden_nf
        self.out_node_nf = out_node_nf
        self.device = device

        def make_gcl():
            # Each layer gets its own ReLU module instance.
            return GCL(
                input_nf=hidden_nf,
                output_nf=hidden_nf,
                hidden_nf=hidden_nf,
                normalization_factor=1,
                aggregation_method='sum',
                edges_in_d=1,
                activation=nn.ReLU(),
                attention=False,
                normalization=normalization,
            )

        # Submodules are created in a fixed order: input embedding, first GCL,
        # remaining GCLs, output head. Changing it would alter parameter
        # initialisation under a fixed seed and parameter registration order.
        self.embedding_in = nn.Linear(in_node_nf, hidden_nf)
        self.gcl1 = make_gcl()
        self.gcl_layers = nn.ModuleList([make_gcl() for _ in range(n_layers - 1)])
        self.embedding_out = nn.Linear(hidden_nf, out_node_nf)
        self.to(device)

    def forward(self, h, edges, distances, node_mask, edge_mask):
        """Embed ``h``, run every GCL layer, and project to the output width.

        Args:
            h: node features; last dimension must be ``in_node_nf``.
            edges: edge structure passed unchanged to every GCL.
            distances: passed to every GCL as ``edge_attr``.
            node_mask: passed to every GCL.
            edge_mask: passed to every GCL.

        Returns:
            Output of ``embedding_out`` (last dimension ``out_node_nf``).
        """
        h = self.embedding_in(h)
        for layer in (self.gcl1, *self.gcl_layers):
            # GCL returns a pair; only the updated node features are kept.
            h, _ = layer(h, edges, edge_attr=distances, node_mask=node_mask, edge_mask=edge_mask)
        return self.embedding_out(h)