testGAT / app.py
QQ2S3R's picture
Update app.py
30807b6 verified
import os
import zipfile
import logging
import torch
import torch.nn.functional as F
import numpy as np
from io import BytesIO
from PIL import Image
import gradio as gr
from torch_geometric.data import Data as PyGData
import matplotlib
matplotlib.use('Agg')
from rdkit import Chem
from rdkit.Chem import Draw, AllChem, MolFromSmiles
# ----------------------------
# Logging & GPU Configuration
# ----------------------------
# Emit INFO-and-above records with timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Export the PyTorch CUDA allocator setting for this process and record it.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:128"})
logger.info("Set GPU memory optimization: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128")
# ----------------------------
# Unzip Model Files if Needed
# ----------------------------
# Extract the bundled checkpoints on first start. load_models() reads them from
# "models/<file>", so a checkpoint already present there counts as unpacked;
# the bare filename is still honoured for archives that unpack straight into
# the working directory.
if not any(
    os.path.exists(candidate)
    for candidate in (
        "best_model-B-6000-185.pth",
        "models/best_model-B-6000-185.pth",
    )
):
    logger.info("Unzipping model archive...")
    try:
        with zipfile.ZipFile("models.zip", 'r') as z:
            z.extractall(".")
        logger.info("Model archive unzipped successfully.")
    except Exception as e:
        # Log for the operator, then abort start-up: nothing works without models.
        logger.error(f"Failed to unzip models.zip: {e}")
        raise
# ----------------------------
# Import Model Utilities
# ----------------------------
# model_utils supplies the network class plus the graph-building and
# visualisation helpers used below; a failed import is logged and re-raised.
try:
    from model_utils import (
        EnhancedGAT,
        smiles_to_graph,
        visualize_single_molecule,
    )
except ImportError as e:
    logger.error(f"Failed to import model_utils: {e}")
    raise
else:
    logger.info("Imported model_utils successfully.")
# ----------------------------
# Device Setup
# ----------------------------
# Prefer CUDA when torch reports it as available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

# Name the first GPU in the log when one is in use.
if device.type == "cuda":
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
# ----------------------------
# Model Loading
# ----------------------------
def load_models():
    """Load the Elastic, Plastic and Brittle checkpoints into EnhancedGAT models.

    Each checkpoint is read from ``models/``. When one is missing,
    ``models.zip`` is extracted into the working directory and the path is
    checked again before loading.

    Returns:
        dict mapping "Elastic", "Plastic" and "Brittle" to a model that has
        been put in eval mode and moved to the module-level ``device``.

    Raises:
        FileNotFoundError: a checkpoint is missing and ``models.zip`` is
            absent or does not provide it.
    """
    # Best-effort allow-listing of the numpy scalar type for torch's
    # unpickler. torch builds without add_safe_globals (or numpy builds
    # without this module path) must not prevent start-up, since loading
    # below does not use weights_only=True.
    try:
        from torch.serialization import add_safe_globals
        import numpy.core.multiarray
        add_safe_globals([numpy.core.multiarray.scalar])
    except (ImportError, AttributeError) as e:
        logger.info(f"Skipping numpy safe-globals registration: {e}")

    # name -> (checkpoint path, number of output logits)
    specs = {
        "Elastic": ("models/best_model-E-500-68.pth", 2),
        "Plastic": ("models/best_model-P-5000-180.pth", 2),
        "Brittle": ("models/best_model-B-6000-185.pth", 2),
    }
    models = {}
    for name, (path, out_dim) in specs.items():
        if not os.path.exists(path):
            if not os.path.exists("models.zip"):
                raise FileNotFoundError(f"Missing model file: {path}")
            logger.info("Extracting models.zip...")
            with zipfile.ZipFile("models.zip", 'r') as z:
                z.extractall(".")
            # The archive may not contain this checkpoint; fail with a clear
            # message instead of a later error from torch.load.
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing model file: {path}")

        model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=out_dim, num_heads=8)
        # NOTE: weights_only=False performs a full unpickle, which can execute
        # arbitrary code. Only load checkpoints shipped with the application.
        try:
            state = torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            # Presumably a torch release whose torch.load has no
            # weights_only parameter; retry without it.
            state = torch.load(path, map_location=device)
        # Accept both a wrapped checkpoint and a bare state dict.
        state_dict = state.get("model_state_dict", state)
        model.load_state_dict(state_dict)
        model.eval().to(device)
        models[name] = model
        logger.info(f"{name} model loaded successfully.")
    return models


# Load all three classifiers once at import time; predict_all reads this dict.
models = load_models()
# ----------------------------
# Prediction Function
# ----------------------------
def predict_all(smiles: str):
    """Classify one molecule with the Elastic, Plastic and Brittle models.

    The SMILES string is converted to a graph by ``smiles_to_graph`` and fed
    to each model. The positive-class probability is compared against a
    per-model threshold (0.5 for Elastic and Brittle, 0.3 for Plastic) to
    obtain a 0/1 label.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        A flat 6-tuple ``(elastic_text, elastic_img, plastic_text,
        plastic_img, brittle_text, brittle_img)``. Each text is
        ``"<name>: <label>"`` and each image is a PIL image, or ``None`` when
        the visualiser returns no buffer. For empty or whitespace-only input
        the placeholder ``("Enter SMILES", None, "", None, "", None)`` is
        returned, the same one the built-molecule predict button uses.
    """
    # Nothing to predict on: answer with the UI placeholder instead of
    # failing inside graph construction.
    if not smiles or not smiles.strip():
        return ("Enter SMILES", None, "", None, "", None)

    atom_feats, (rows, cols, edge_attr), _ = smiles_to_graph(smiles)
    x = torch.tensor(atom_feats, dtype=torch.float)
    edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.long)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
    # A single-graph batch: every node belongs to graph 0.
    data = PyGData(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        smiles=[smiles],
        batch=torch.zeros(x.size(0), dtype=torch.long),
    )
    # The models live on `device`, so the forward pass needs its inputs there
    # as well. A clone is moved so the visualiser still receives the object it
    # was given before (it is handed `device` and places data itself).
    model_input = data.clone().to(device)

    thresholds = {"Elastic": 0.5, "Plastic": 0.3, "Brittle": 0.5}
    outputs = []
    for name in ("Elastic", "Plastic", "Brittle"):
        model = models[name]
        with torch.no_grad():
            logits = model(model_input)
        if logits.dim() == 1 or logits.size(1) == 1:
            # Single-logit head: sigmoid gives the positive probability.
            prob = torch.sigmoid(logits).item()
        else:
            # Two-logit head: column 1 is taken as the positive class.
            prob = F.softmax(logits, dim=1)[0, 1].item()
        label = int(prob >= thresholds[name])

        buf, _ = visualize_single_molecule(model, data, device, name)
        img = Image.open(buf) if buf else None
        outputs.append((f"{name}: {label}", img))

    # Flatten [(text, img), ...] into the six values the UI outputs expect.
    return tuple(item for pair in outputs for item in pair)
# ----------------------------
# Molecule Builder Utilities
# ----------------------------
# Element symbols offered in the "Atom Type" dropdown of the molecule builder.
ATOM_TYPES = [
    "C", "N", "O", "S", "P",
    "F", "Cl", "Br", "I", "H",
]
# Bond orders offered in the "Bond Type" dropdown; generate_smiles maps these
# labels to RDKit bond types.
BOND_TYPES = [
    "Single",
    "Double",
    "Triple",
]
def init_molecule():
    """Return a fresh, empty molecule record with no atoms and no bonds."""
    return dict(atoms=[], bonds=[])
def add_atom(mol, atom_type):
    """Append an atom of *atom_type* to *mol* in place and return *mol*.

    The new atom's id is the number of atoms already present.
    """
    atoms = mol["atoms"]
    next_id = len(atoms)
    atoms.append(dict(id=next_id, type=atom_type))
    return mol
def add_bond(mol, a1_sel, a2_sel, b_type):
    """Add a bond between two selected atoms of *mol* in place.

    Args:
        mol: molecule record with "atoms" and "bonds" lists.
        a1_sel: dropdown label of the first atom, formatted "<id>: <type>".
        a2_sel: dropdown label of the second atom, same format.
        b_type: bond type label stored on the bond (e.g. "Single").

    Returns:
        *mol* itself. It is returned unchanged when a selection is missing,
        both selections name the same atom, either id is not an atom of
        *mol*, or the pair is already bonded (in either direction).
    """
    if not a1_sel or not a2_sel:
        return mol

    # Labels look like "3: C"; the integer before the colon is the atom id.
    i1, i2 = int(a1_sel.split(":")[0]), int(a2_sel.split(":")[0])

    # A bond from an atom to itself can never be built into a molecule and
    # would make every later SMILES generation fail.
    if i1 == i2:
        return mol

    # Ignore stale selections that refer to atoms no longer in the molecule.
    known_ids = {a["id"] for a in mol["atoms"]}
    if i1 not in known_ids or i2 not in known_ids:
        return mol

    # Bonds are undirected: compare as unordered pairs.
    pair = {i1, i2}
    if any({b["atom1"], b["atom2"]} == pair for b in mol["bonds"]):
        return mol

    mol["bonds"].append({"atom1": i1, "atom2": i2, "type": b_type})
    return mol
def generate_smiles(mol):
    """Build an RDKit molecule from *mol* and return its SMILES string.

    Atoms are added in list order, then bonds; the molecule is sanitized
    before conversion. Any failure is logged and reported as an empty string.
    """
    try:
        order_by_label = {
            "Single": Chem.BondType.SINGLE,
            "Double": Chem.BondType.DOUBLE,
            "Triple": Chem.BondType.TRIPLE,
        }
        editable = Chem.RWMol()
        # Builder atom id -> index assigned by RDKit when the atom is added.
        index_of = {
            atom["id"]: editable.AddAtom(Chem.Atom(atom["type"]))
            for atom in mol["atoms"]
        }
        for bond in mol["bonds"]:
            editable.AddBond(
                index_of[bond["atom1"]],
                index_of[bond["atom2"]],
                order_by_label[bond["type"]],
            )
        editable.UpdatePropertyCache()
        Chem.SanitizeMol(editable)
        return Chem.MolToSmiles(editable)
    except Exception as e:
        logger.error(f"SMILES generation failed: {e}")
        return ""
def visualize_molecule(mol):
    """Render *mol* as a 300x300 PIL image.

    Returns None when no SMILES can be generated for the molecule or RDKit
    cannot parse the generated SMILES.
    """
    smi = generate_smiles(mol)
    parsed = MolFromSmiles(smi) if smi else None
    if parsed is None:
        return None
    AllChem.Compute2DCoords(parsed)
    return Draw.MolToImage(parsed, size=(300, 300))
def update_atom_dropdowns(mol):
    """Return two Gradio updates that list every atom as "<id>: <type>".

    Both updates carry the same choices and clear the current selection.
    """
    labels = []
    for atom in mol["atoms"]:
        labels.append(f"{atom['id']}: {atom['type']}")
    first = gr.update(choices=labels, value=None)
    second = gr.update(choices=labels, value=None)
    return first, second
def update_atoms_list(mol):
    """Return the atoms of *mol* as ``[id, type]`` rows for the atom table."""
    rows = []
    for atom in mol["atoms"]:
        rows.append([atom["id"], atom["type"]])
    return rows
def update_bonds_list(mol):
    """Return the bonds of *mol* as rows for the bond table.

    Each row is ``["<id1>: <type1>", "<id2>: <type2>", bond_type]``. An atom
    id that is not present in ``mol["atoms"]`` is shown with type ``"?"`` so
    that one inconsistent bond cannot break the whole table refresh.
    """
    # Index atom types by id once instead of rescanning the list per bond.
    # setdefault keeps the first atom seen for an id.
    type_of = {}
    for atom in mol["atoms"]:
        type_of.setdefault(atom["id"], atom["type"])

    return [
        [
            f"{b['atom1']}: {type_of.get(b['atom1'], '?')}",
            f"{b['atom2']}: {type_of.get(b['atom2'], '?')}",
            b["type"],
        ]
        for b in mol["bonds"]
    ]
# ----------------------------
# Gradio Interface
# ----------------------------
# NOTE(review): the source lost its indentation; the nesting below is
# reconstructed from how the components group together - confirm the layout.
with gr.Blocks(title="CrystalGAT", css="""
.gradio-container {max-width:800px; margin:auto}
.gr-button {margin:0.2em}
""") as demo:
    gr.Markdown("## CrystalGAT \nEnter a SMILES string or build a molecule to predict Elastic, Plastic, and Brittle classes with attention visualization.")

    # Tab 1: predict directly from a typed SMILES string.
    with gr.Tab("SMILES Input"):
        smi_in = gr.Textbox(label="SMILES", placeholder="e.g. CCO")
        predict1 = gr.Button("Predict")

    # Tab 2: assemble a molecule atom by atom, then predict on it.
    with gr.Tab("Manual Molecule Construction"):
        state = gr.State(init_molecule())
        status = gr.Textbox(label="Status", interactive=False, value="Start by adding atoms")
        with gr.Row():
            # Left column: atom creation and the atom table.
            with gr.Column():
                atom_type = gr.Dropdown(label="Atom Type", choices=ATOM_TYPES, value="C")
                add_a = gr.Button("Add Atom")
                atom_tbl = gr.Dataframe(
                    headers=["ID","Type"],
                    datatype=["number","str"],
                    interactive=False,
                )
            # Right column: bond creation and the bond table.
            with gr.Column():
                a1 = gr.Dropdown(label="Atom 1", choices=[], value=None)
                a2 = gr.Dropdown(label="Atom 2", choices=[], value=None)
                bond_type = gr.Dropdown(label="Bond Type", choices=BOND_TYPES, value="Single")
                add_b = gr.Button("Add Bond")
                bond_tbl = gr.Dataframe(
                    headers=["Atom1","Atom2","Type"],
                    datatype=["str","str","str"],
                    interactive=False,
                )
        with gr.Row():
            clear = gr.Button("Clear All")
            make = gr.Button("Generate SMILES")
        smi_out = gr.Textbox(label="SMILES Output", interactive=False)
        mol_img = gr.Image(type="pil", label="Molecule Preview")
        predict2 = gr.Button("Predict on Built Molecule")

    # Shared result area: one text + attention image row per property.
    with gr.Row():
        e_txt = gr.Text(label="Elastic")
        e_img = gr.Image(type="pil", label="Elastic Attention")
    with gr.Row():
        p_txt = gr.Text(label="Plastic")
        p_img = gr.Image(type="pil", label="Plastic Attention")
    with gr.Row():
        b_txt = gr.Text(label="Brittle")
        b_img = gr.Image(type="pil", label="Brittle Attention")

    # Event bindings. Each chain runs its steps in order after the click.
    predict1.click(
        fn=predict_all,
        inputs=smi_in,
        outputs=[e_txt, e_img, p_txt, p_img, b_txt, b_img],
    )

    # Add atom -> refresh atom table -> refresh both atom dropdowns -> status.
    (
        add_a.click(fn=add_atom, inputs=[state, atom_type], outputs=state)
        .then(fn=update_atoms_list, inputs=state, outputs=atom_tbl)
        .then(fn=update_atom_dropdowns, inputs=state, outputs=[a1, a2])
        .then(fn=lambda: "Atom added.", outputs=status)
    )

    # Add bond -> refresh bond table -> status.
    (
        add_b.click(fn=add_bond, inputs=[state, a1, a2, bond_type], outputs=state)
        .then(fn=update_bonds_list, inputs=state, outputs=bond_tbl)
        .then(fn=lambda: "Bond added/updated.", outputs=status)
    )

    # Reset the molecule -> empty both tables -> empty both dropdowns -> status.
    (
        clear.click(fn=init_molecule, outputs=state)
        .then(fn=lambda: ([], []), outputs=[atom_tbl, bond_tbl])
        .then(
            fn=lambda: (gr.update(choices=[], value=None), gr.update(choices=[], value=None)),
            outputs=[a1, a2],
        )
        .then(fn=lambda: "Cleared all.", outputs=status)
    )

    # Build SMILES from the current molecule -> render preview -> status.
    (
        make.click(fn=generate_smiles, inputs=state, outputs=smi_out)
        .then(fn=visualize_molecule, inputs=state, outputs=mol_img)
        .then(fn=lambda: "Molecule generated.", outputs=status)
    )

    # Predict on the generated SMILES; show a placeholder when there is none.
    predict2.click(
        fn=lambda s: predict_all(s) if s else ("Enter SMILES", None, "", None, "", None),
        inputs=smi_out,
        outputs=[e_txt, e_img, p_txt, p_img, b_txt, b_img],
    )
# Start the Gradio server only when executed as a script: listen on every
# interface, port 7860.
if __name__ == "__main__":
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )