from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt
import numpy as np

# === 6. Query + Visualization ===
# Encode one query SMILES, fetch its nearest neighbours from the FAISS index,
# and draw the query next to the valid hits in a single row of images.
#
# NOTE(review): this cell uses `model`, `faiss`, `index` and `smiles_list`,
# none of which are defined here -- confirm earlier cells create them so the
# notebook survives Restart Kernel -> Run All.

query_smiles = ["O=C1/C=C\\C=C2/N1C[C@@H]3CNC[C@H]2C3"]  # Your query
k = 10                   # number of neighbours requested from the index
image_size = (300, 300)  # pixel size of each rendered molecule

# Encode & normalize query.
# normalize_L2 is applied in place (its return value is not used), so hand it
# a C-contiguous float32 matrix; this is a no-op if the array already is one.
query_emb = model.encode(query_smiles, convert_to_numpy=True)
query_emb = np.ascontiguousarray(query_emb, dtype=np.float32)
faiss.normalize_L2(query_emb)

distances, indices = index.search(query_emb, k)

# Collect molecules and labels
mols = []
labels = []

# Add query molecule (with label "Query")
query_mol = Chem.MolFromSmiles(query_smiles[0])
if query_mol is None:
    print("⚠️ Invalid query SMILES!")
else:
    mols.append(query_mol)
    labels.append("Query")

# Add top-k results.
# Presumably the scores are similarities (normalized vectors searched with an
# inner-product index) -- confirm against the cell that builds `index`.
n_missing = 0
for rank, (idx, sim_score) in enumerate(zip(indices[0], distances[0]), start=1):
    idx = int(idx)
    if idx < 0:
        # FAISS fills unused result slots with -1; indexing smiles_list with it
        # would silently return the *last* molecule instead of a real hit.
        n_missing += 1
        continue

    smi = smiles_list[idx]
    mol = Chem.MolFromSmiles(smi)
    if mol is not None:
        mols.append(mol)
        labels.append(f"Sim: {sim_score:.3f}")
    else:
        print(f"⚠️ Invalid SMILES in result #{rank}: {smi}")

if n_missing:
    print(f"Index returned only {k - n_missing} of the {k} requested neighbours.")

# Plot
n_mols = len(mols)
if n_mols == 0:
    print("No valid molecules to display.")
else:
    # Grid: 1 row, n_mols columns. squeeze=False always yields a 2-D array of
    # Axes, so the single-molecule case needs no special handling.
    fig, axes = plt.subplots(
        1, n_mols, figsize=(4 * n_mols, 4), squeeze=False
    )
    axes = axes.ravel()

    # Render each molecule into its own panel
    for ax, mol, label in zip(axes, mols, labels):
        ax.imshow(Draw.MolToImage(mol, size=image_size))
        ax.set_title(label, fontsize=12)
        ax.axis("off")

    fig.tight_layout()
    plt.show()