import os |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
from torch_geometric.nn import global_add_pool, global_max_pool |
from torch_geometric.nn import MessagePassing |
import json |
class GraphCLIP(nn.Module):
    """Molecule-graph encoder that returns L2-normalised embeddings.

    A ``GNNEncoder`` turns each graph into a feature vector, a
    ``ProjectionHead`` maps it to the embedding space, and ``forward``
    rescales every embedding to unit length.

    Args:
        graph_num_layer: number of GNN layers (passed to ``GNNEncoder`` as
            ``num_layer``).
        graph_hidden_size: width of the encoder output and of the projection.
        dropout: dropout probability used by both encoder and projection head.
        model_config: JSON-serialisable object written next to the weights by
            ``save_pretrained``.
    """

    def __init__(
        self,
        graph_num_layer,
        graph_hidden_size,
        dropout,
        model_config,
    ):
        super().__init__()
        self.model_config = model_config
        self.hidden_size = graph_hidden_size
        self.molecule_encoder = GNNEncoder(num_layer=graph_num_layer, hidden_size=graph_hidden_size, drop_ratio=dropout)
        self.molecule_projection = ProjectionHead(embedding_dim=graph_hidden_size, projection_dim=graph_hidden_size, dropout=dropout)

    def forward(self, x, edge_index, edge_attr, batch):
        """Encode a batch of graphs and return unit-norm embeddings (one row per graph)."""
        molecule_features = self.molecule_encoder(x, edge_index, edge_attr, batch)
        molecule_embeddings = self.molecule_projection(molecule_features)
        # F.normalize divides by max(norm, eps): identical to x / ||x|| for any
        # non-degenerate row, but an all-zero row stays zero instead of NaN.
        return F.normalize(molecule_embeddings, dim=-1)

    @staticmethod
    def _checkpoint_paths(directory):
        """Return the (encoder, projection, config) file paths inside ``directory``."""
        # Join the file names onto the directory instead of string-replacing
        # inside the full path, so directory names containing "model" are safe.
        return (
            os.path.join(directory, 'model.pt'),
            os.path.join(directory, 'model_proj.pt'),
            os.path.join(directory, 'model_config.json'),
        )

    def save_pretrained(self, output_dir):
        """
        Save the molecule encoder, projection models, and model_config to the output directory.

        Writes ``model.pt`` (encoder state_dict), ``model_proj.pt`` (projection
        state_dict) and ``model_config.json`` into ``output_dir``, creating the
        directory if needed.
        """
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(output_dir, exist_ok=True)
        molecule_path, proj_path, config_path = self._checkpoint_paths(output_dir)
        torch.save(self.molecule_encoder.state_dict(), molecule_path)
        torch.save(self.molecule_projection.state_dict(), proj_path)
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(self.model_config, f, indent=2)

    def disable_grads(self):
        """
        Disable gradients for all parameters in the model.
        """
        for param in self.parameters():
            param.requires_grad = False

    def init_model(self, model_path, verbose=True):
        """Load encoder and projection weights previously written by ``save_pretrained``.

        Args:
            model_path: directory containing ``model.pt`` and ``model_proj.pt``.
            verbose: if true, print the loaded modules.

        Raises:
            FileNotFoundError: if either weight file is missing.
        """
        molecule_path, proj_path, _ = self._checkpoint_paths(model_path)
        if not os.path.exists(proj_path):
            # Backward compatibility: older code derived this path with
            # str.replace over the whole path, which placed the file elsewhere
            # when a directory name contained "model".
            legacy_proj_path = molecule_path.replace('model', 'model_proj')
            if os.path.exists(legacy_proj_path):
                proj_path = legacy_proj_path

        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects, so only load checkpoints from trusted sources. These files
        # hold plain state_dicts (see save_pretrained), so weights_only=True
        # should suffice; confirm before switching.
        if not os.path.exists(molecule_path):
            raise FileNotFoundError(f"Molecule encoder file not found: {molecule_path}")
        self.molecule_encoder.load_state_dict(torch.load(molecule_path, map_location='cpu', weights_only=False))

        if not os.path.exists(proj_path):
            raise FileNotFoundError(f"Molecule projection file not found: {proj_path}")
        self.molecule_projection.load_state_dict(torch.load(proj_path, map_location='cpu', weights_only=False))

        if verbose:
            print('GraphCLIP Models initialized.')
            print('Molecule model:\n', self.molecule_encoder)
            print('Molecule projection:\n', self.molecule_projection)
class GNNEncoder(nn.Module):
    """GIN-based graph encoder with a shared virtual node and residual connections.

    Each layer adds the current per-graph virtual-node vector to every node,
    applies ``GINConv`` and ``LayerNorm`` (plus GELU on all but the last
    layer), dropout and a residual connection. Between layers the virtual node
    is updated from a max-pool over that layer's input. The final node states
    are sum-pooled into one vector per graph.

    Args:
        num_layer: number of message-passing layers; must be at least 2.
        hidden_size: width of node, edge and virtual-node embeddings.
        drop_ratio: dropout probability.

    Raises:
        ValueError: if ``num_layer < 2``.
    """

    def __init__(self, num_layer, hidden_size, drop_ratio):
        super(GNNEncoder, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # Node inputs are integer indices in [0, 118) embedded here.
        self.atom_encoder = nn.Embedding(118, hidden_size)
        # A single learned virtual-node vector, initialised to zeros.
        self.virtualnode_embedding = nn.Embedding(1, hidden_size)
        nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        # One virtual-node MLP between consecutive layers (num_layer - 1 total).
        self.mlp_virtualnode_list = nn.ModuleList()
        for layer in range(num_layer):
            self.convs.append(GINConv(hidden_size, drop_ratio))
            self.norms.append(nn.LayerNorm(hidden_size, elementwise_affine=True))
            if layer < num_layer - 1:
                self.mlp_virtualnode_list.append(nn.Sequential(
                    nn.Linear(hidden_size, 4 * hidden_size),
                    nn.LayerNorm(4 * hidden_size),
                    nn.GELU(),
                    nn.Dropout(drop_ratio),
                    nn.Linear(4 * hidden_size, hidden_size),
                ))

    def initialize_weights(self):
        """Xavier-uniform init for every nn.Linear weight and zero its bias.

        Not called by ``__init__``; callers must invoke it explicitly.
        """
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, x, edge_index, edge_attr, batch):
        """Encode a batch of graphs into one vector per graph.

        Args:
            x: integer node indices for ``atom_encoder``.
            edge_index: connectivity passed to each ``GINConv``.
            edge_attr: integer edge-type indices passed to each ``GINConv``.
            batch: graph id of every node. The graph count is taken as
                ``batch[-1] + 1``, so ids are assumed sorted and non-empty.

        Returns:
            Sum-pooled final node states, one row per graph.
        """
        num_graphs = batch[-1].item() + 1
        # Every graph starts from the same virtual-node row (index 0). Allocate
        # the index directly with the target dtype/device rather than creating
        # a float CPU tensor and converting it twice.
        vn_index = torch.zeros(num_graphs, dtype=edge_index.dtype, device=edge_index.device)
        virtualnode_embedding = self.virtualnode_embedding(vn_index)

        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):
            is_last = layer == self.num_layer - 1
            # Broadcast each graph's virtual node onto its nodes. The stored
            # layer input is overwritten, so the residual and the virtual-node
            # pooling below both see the augmented input.
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.norms[layer](h)
            if not is_last:
                h = F.gelu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)
            h = h + h_list[layer]  # residual connection
            h_list.append(h)
            if not is_last:
                # Virtual-node update pools this layer's input, not its output.
                virtual_pool = global_max_pool(h_list[layer], batch)
                virtualnode_embedding = virtualnode_embedding + F.dropout(
                    self.mlp_virtualnode_list[layer](virtual_pool),
                    self.drop_ratio,
                    training=self.training,
                )

        return global_add_pool(h_list[-1], batch)
class GINConv(MessagePassing):
    """Edge-aware GIN convolution with a learnable self-weight.

    Computes ``mlp((1 + eps) * x_i + sum_j gelu(x_j + bond_encoder(e_ji)))``
    with sum aggregation over incoming edges.

    Args:
        hidden_size (int): width of node features, edge embeddings and output.
        drop_ratio (float): dropout probability inside the update MLP.
    """

    def __init__(self, hidden_size, drop_ratio):
        super().__init__(aggr="add")
        expanded = 4 * hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, expanded),
            nn.LayerNorm(expanded),
            nn.GELU(),
            nn.Dropout(drop_ratio),
            nn.Linear(expanded, hidden_size),
        )
        # Learnable scalar weighting the node's own features; starts at 0.
        self.eps = torch.nn.Parameter(torch.zeros(1))
        # Embeds integer edge-type indices (5 slots) into the feature space.
        self.bond_encoder = nn.Embedding(5, hidden_size)

    def forward(self, x, edge_index, edge_attr):
        """Apply one convolution step; returns updated node features."""
        self_term = (1 + self.eps) * x
        edge_embedding = self.bond_encoder(edge_attr)
        neighbour_sum = self.propagate(edge_index, x=x, edge_attr=edge_embedding)
        return self.mlp(self_term + neighbour_sum)

    # NOTE: parameter names below (x_j, edge_attr, aggr_out) are inspected by
    # MessagePassing and must not be renamed.
    def message(self, x_j, edge_attr):
        return F.gelu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
class ProjectionHead(nn.Module):
    """Two-layer MLP: Linear -> LayerNorm -> activation -> Dropout -> Linear.

    Args:
        embedding_dim: input feature width.
        projection_dim: output width; falls back to ``embedding_dim`` when falsy.
        dropout: dropout probability applied after the activation.
        act_layer: activation module class, instantiated with no arguments.
        hidden_features: hidden width; falls back to ``embedding_dim`` when falsy.
        bias: whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim,
        dropout,
        act_layer=nn.GELU,
        hidden_features=None,
        bias=True
    ):
        super().__init__()
        out_dim = projection_dim if projection_dim else embedding_dim
        hidden_dim = hidden_features if hidden_features else embedding_dim
        # Attribute names double as state_dict keys; renaming them would break
        # loading previously saved weights.
        self.fc1 = nn.Linear(embedding_dim, hidden_dim, bias=bias)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.act = act_layer()
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, out_dim, bias=bias)

    def forward(self, x):
        """Project ``x`` along its last dimension."""
        for stage in (self.fc1, self.norm1, self.act, self.drop1):
            x = stage(x)
        return self.fc2(x)