DiffLinker / src /linker_size.py
igashov
DiffLinker code
95ba5bc
raw
history blame
No virus
2.94 kB
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from src.egnn import GCL
class DistributionNodes:
    """Empirical categorical distribution over molecule sizes (node counts).

    Built from a histogram mapping node-count -> observed frequency. Supports
    sampling node counts and evaluating log-probabilities of observed counts.
    """

    def __init__(self, histogram):
        """
        Args:
            histogram: mapping from number of nodes (int) to its observed
                count/weight; iteration order fixes the categorical indices.
        """
        self.n_nodes = []
        prob = []
        self.keys = {}  # node-count -> categorical index (inverse of n_nodes)
        for i, nodes in enumerate(histogram):
            self.n_nodes.append(nodes)
            self.keys[nodes] = i
            prob.append(histogram[nodes])
        self.n_nodes = torch.tensor(self.n_nodes)

        # Normalize raw counts into a probability vector.
        prob = np.array(prob)
        prob = prob/np.sum(prob)
        self.prob = torch.from_numpy(prob).float()

        # Fix: entropy is H[N] = -sum_i p_i * log(p_i); the original printed
        # the negated value while labelling it as the entropy. The 1e-30
        # offset guards log(0) for zero-probability bins.
        entropy = -torch.sum(self.prob * torch.log(self.prob + 1e-30))
        print("Entropy of n_nodes: H[N]", entropy.item())

        self.m = Categorical(torch.tensor(prob))

    def sample(self, n_samples=1):
        """Draw `n_samples` node counts; returns a 1-D tensor of sizes."""
        idx = self.m.sample((n_samples,))
        return self.n_nodes[idx]

    def log_prob(self, batch_n_nodes):
        """Return the log-probability of each node count in `batch_n_nodes`.

        Args:
            batch_n_nodes: 1-D tensor of node counts; every value must be a
                key present in the training histogram (else KeyError).
        """
        assert len(batch_n_nodes.size()) == 1
        idcs = [self.keys[i.item()] for i in batch_n_nodes]
        idcs = torch.tensor(idcs).to(batch_n_nodes.device)
        log_p = torch.log(self.prob + 1e-30)
        log_p = log_p.to(batch_n_nodes.device)
        log_probs = log_p[idcs]
        return log_probs
class SizeGNN(nn.Module):
    """Stack of GCL message-passing layers producing per-node output features.

    Embeds input node features to `hidden_nf`, applies `n_layers` GCL layers
    conditioned on pairwise distances, and projects to `out_node_nf`.

    NOTE(review): depends on the project-local `GCL` layer from src.egnn;
    its masking/edge-attr semantics are assumed, not visible here.
    """

    def __init__(self, in_node_nf, hidden_nf, out_node_nf, n_layers, normalization, device='cpu'):
        """
        Args:
            in_node_nf: input node feature dimension.
            hidden_nf: hidden feature dimension shared by all GCL layers.
            out_node_nf: output node feature dimension.
            n_layers: total number of GCL layers (the first is `gcl1`).
            normalization: normalization mode forwarded to each GCL.
            device: torch device the module is moved to.
        """
        super(SizeGNN, self).__init__()
        self.hidden_nf = hidden_nf
        self.out_node_nf = out_node_nf
        self.device = device

        self.embedding_in = nn.Linear(in_node_nf, self.hidden_nf)
        # The first layer is kept as a separate `gcl1` attribute (rather than
        # folded into `gcl_layers`) to preserve checkpoint/state-dict
        # compatibility with the original layout.
        self.gcl1 = self._make_gcl(normalization)
        self.gcl_layers = nn.ModuleList(
            [self._make_gcl(normalization) for _ in range(n_layers - 1)]
        )
        self.embedding_out = nn.Linear(self.hidden_nf, self.out_node_nf)
        self.to(self.device)

    def _make_gcl(self, normalization):
        """Build one hidden->hidden GCL with the shared hyper-parameters.

        Factors out the constructor call that was previously duplicated
        verbatim for `gcl1` and every layer in `gcl_layers`.
        """
        return GCL(
            input_nf=self.hidden_nf,
            output_nf=self.hidden_nf,
            hidden_nf=self.hidden_nf,
            normalization_factor=1,
            aggregation_method='sum',
            edges_in_d=1,
            activation=nn.ReLU(),
            attention=False,
            normalization=normalization,
        )

    def forward(self, h, edges, distances, node_mask, edge_mask):
        """Run embedding -> GCL stack -> output projection; returns node features."""
        h = self.embedding_in(h)
        h, _ = self.gcl1(h, edges, edge_attr=distances, node_mask=node_mask, edge_mask=edge_mask)
        for gcl in self.gcl_layers:
            h, _ = gcl(h, edges, edge_attr=distances, node_mask=node_mask, edge_mask=edge_mask)
        h = self.embedding_out(h)
        return h