#!/usr/bin/env python3
"""
Standalone inference script for Pyrosage TMP AttentiveFP Model
Usage: python inference.py "SMILES_STRING"
"""

import sys
import torch
from torch_geometric.nn import AttentiveFP
from rdkit import Chem
from torch_geometric.data import Data


def smiles_to_data(smiles):
    """Convert SMILES string to PyG Data object with enhanced features"""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Enhanced atom features (10 dimensions)
    atom_features = []
    for atom in mol.GetAtoms():
        features = [
            atom.GetAtomicNum(),
            atom.GetTotalDegree(),
            atom.GetFormalCharge(),
            atom.GetTotalNumHs(),
            atom.GetNumRadicalElectrons(),
            int(atom.GetIsAromatic()),
            int(atom.IsInRing()),
            # Hybridization as one-hot (3 dimensions)
            int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP),
            int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP2),
            int(atom.GetHybridization() == Chem.rdchem.HybridizationType.SP3)
        ]
        atom_features.append(features)

    x = torch.tensor(atom_features, dtype=torch.float)

    # Enhanced bond features (6 dimensions)
    edges_list = []
    edge_features = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edges_list.extend([[i, j], [j, i]])

        features = [
            # Bond type as one-hot (4 dimensions)
            int(bond.GetBondType() == Chem.rdchem.BondType.SINGLE),
            int(bond.GetBondType() == Chem.rdchem.BondType.DOUBLE),
            int(bond.GetBondType() == Chem.rdchem.BondType.TRIPLE),
            int(bond.GetBondType() == Chem.rdchem.BondType.AROMATIC),
            # Additional features (2 dimensions)
            int(bond.GetIsConjugated()),
            int(bond.IsInRing())
        ]
        edge_features.extend([features, features])

    if not edges_list:
        return None

    edge_index = torch.tensor(edges_list, dtype=torch.long).t()
    edge_attr = torch.tensor(edge_features, dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


def load_model():
    """Load the AttentiveFP model"""
    model_dict = torch.load('pytorch_model.pt', map_location='cpu')
    state_dict = model_dict['model_state_dict']
    hyperparams = model_dict['hyperparameters']

    model = AttentiveFP(
        in_channels=10,  # Enhanced atom features
        hidden_channels=hyperparams["hidden_channels"],
        out_channels=1,
        edge_dim=6,  # Enhanced bond features
        num_layers=hyperparams["num_layers"],
        num_timesteps=hyperparams["num_timesteps"],
        dropout=hyperparams["dropout"],
    )

    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict(model, smiles):
    """Make prediction for a SMILES string"""
    data = smiles_to_data(smiles)
    if data is None:
        return None
    
    batch = torch.zeros(data.num_nodes, dtype=torch.long)
    with torch.no_grad():
        output = model(data.x, data.edge_index, data.edge_attr, batch)
        return output.item()
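

# A minimal sketch, not part of the original script: scoring several SMILES in a
# single forward pass by batching the graphs with torch_geometric's Batch helper.
# The name predict_many is illustrative; it reuses smiles_to_data() and the loaded
# model exactly as predict() does. Invalid or bond-less SMILES are skipped, so the
# returned list may be shorter than the input list.
def predict_many(model, smiles_list):
    """Batch inference sketch for a list of SMILES strings."""
    from torch_geometric.data import Batch

    graphs = [g for g in (smiles_to_data(s) for s in smiles_list) if g is not None]
    if not graphs:
        return []
    batch = Batch.from_data_list(graphs)
    with torch.no_grad():
        out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
    return out.view(-1).tolist()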


def main():
    if len(sys.argv) != 2:
        print("Usage: python inference.py 'SMILES_STRING'")
        print("Example: python inference.py 'CC(=O)OC1=CC=CC=C1C(=O)O'")
        sys.exit(1)
    
    smiles = sys.argv[1]
    print(f"Loading TMP AttentiveFP model...")
    
    try:
        model = load_model()
        print(f"Making prediction for: {smiles}")
        
        prediction = predict(model, smiles)
        if prediction is not None:
            print(f'Regression result: {prediction:.4f}')
        else:
            print("Error: Could not process SMILES string")
            
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
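

# Usage note (sketch, not from the original file): load_model() reads
# pytorch_model.pt from the current working directory, so run the script from the
# directory containing the weights, e.g.
#   python inference.py "CC(=O)OC1=CC=CC=C1C(=O)O"
# The helpers above can also be imported, e.g.
#   model = load_model()
#   print(predict_many(model, ["CCO", "c1ccccc1"]))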