{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# materials.smi-TED - INFERENCE (Regression)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install extra packages for notebook\n", "%pip install seaborn xgboost" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append('../inference')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# materials.smi-ted (smi-ted)\n", "from smi_ted_light.load import load_smi_ted\n", "\n", "# Data\n", "import torch\n", "import pandas as pd\n", "import numpy as np\n", "\n", "# Chemistry\n", "from rdkit import Chem\n", "from rdkit.Chem import PandasTools\n", "from rdkit.Chem import Descriptors\n", "PandasTools.RenderImagesInAllDataFrames(True)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# function to canonicalize SMILES\n", "def normalize_smiles(smi, canonical=True, isomeric=False):\n", "    \"\"\"Return the RDKit-normalized SMILES for `smi`, or None if it cannot be parsed.\n", "\n", "    With the default isomeric=False, stereochemistry is left out of the output.\n", "    Unparsable entries map to None so they can be dropped later, e.g. with `dropna()`.\n", "    \"\"\"\n", "    try:\n", "        mol = Chem.MolFromSmiles(smi)\n", "        if mol is None:  # RDKit returns None for SMILES it cannot parse\n", "            return None\n", "        return Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)\n", "    except Exception:  # e.g. non-string input; unlike a bare `except`, KeyboardInterrupt still propagates\n", "        return None" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import smi-ted" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random Seed: 12345\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Vocab size: 2393\n", "[INFERENCE MODE - smi-ted-Light]\n" ] } ], "source": [ "model_smi_ted = 
load_smi_ted(\n", "    folder='../inference/smi_ted_light',\n", "    ckpt_filename='smi-ted-Light_40.pt'\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Lipophilicity Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Experiments - Data Load" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# train/test splits are read from the same MoleculeNet lipophilicity directory\n", "DATA_DIR = \"../finetune/moleculenet/lipophilicity\"\n", "df_train = pd.read_csv(f\"{DATA_DIR}/train.csv\")\n", "df_test = pd.read_csv(f\"{DATA_DIR}/test.csv\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SMILES canonicalization" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(3360, 3)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
smilesynorm_smiles
0Nc1ncnc2c1c(COc3cccc(Cl)c3)nn2C4CCOCC40.814313Nc1ncnc2c1c(COc1cccc(Cl)c1)nn2C1CCOCC1
1COc1cc(cc2cnc(Nc3ccc(cc3)[C@@H](C)NC(=O)C)nc12...0.446346COc1cc(-c2ccncc2)cc2cnc(Nc3ccc(C(C)NC(C)=O)cc3...
2CC(=O)Nc1ccc2ccn(c3cc(Nc4ccn(C)n4)n5ncc(C#N)c5...1.148828CC(=O)Nc1ccc2ccn(-c3cc(Nc4ccn(C)n4)n4ncc(C#N)c...
3Oc1ccc(CCNCCS(=O)(=O)CCCOCCSc2ccccc2)c3sc(O)nc130.404532O=S(=O)(CCCOCCSc1ccccc1)CCNCCc1ccc(O)c2nc(O)sc12
4Clc1ccc2C(=O)C3=C(Nc2c1)C(=O)NN(Cc4cc5ccccc5s4...-0.164144O=c1[nH]n(Cc2cc3ccccc3s2)c(=O)c2c(=O)c3ccc(Cl)...
\n", "
" ], "text/plain": [ " smiles y \\\n", "0 Nc1ncnc2c1c(COc3cccc(Cl)c3)nn2C4CCOCC4 0.814313 \n", "1 COc1cc(cc2cnc(Nc3ccc(cc3)[C@@H](C)NC(=O)C)nc12... 0.446346 \n", "2 CC(=O)Nc1ccc2ccn(c3cc(Nc4ccn(C)n4)n5ncc(C#N)c5... 1.148828 \n", "3 Oc1ccc(CCNCCS(=O)(=O)CCCOCCSc2ccccc2)c3sc(O)nc13 0.404532 \n", "4 Clc1ccc2C(=O)C3=C(Nc2c1)C(=O)NN(Cc4cc5ccccc5s4... -0.164144 \n", "\n", " norm_smiles \n", "0 Nc1ncnc2c1c(COc1cccc(Cl)c1)nn2C1CCOCC1 \n", "1 COc1cc(-c2ccncc2)cc2cnc(Nc3ccc(C(C)NC(C)=O)cc3... \n", "2 CC(=O)Nc1ccc2ccn(-c3cc(Nc4ccn(C)n4)n4ncc(C#N)c... \n", "3 O=S(=O)(CCCOCCSc1ccccc1)CCNCCc1ccc(O)c2nc(O)sc12 \n", "4 O=c1[nH]n(Cc2cc3ccccc3s2)c(=O)c2c(=O)c3ccc(Cl)... " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train['norm_smiles'] = [normalize_smiles(smi) for smi in df_train['smiles']]\n", "df_train_normalized = df_train.dropna()\n", "print(df_train_normalized.shape)\n", "df_train_normalized.head()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(420, 3)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
smilesynorm_smiles
0N(c1ccccc1)c2ccnc3ccccc230.488161c1ccc(Nc2ccnc3ccccc23)cc1
1Clc1ccc2Oc3ccccc3N=C(N4CCNCC4)c2c10.070017Clc1ccc2c(c1)C(N1CCNCC1)=Nc1ccccc1O2
2NC1(CCC1)c2ccc(cc2)c3ncc4cccnc4c3c5ccccc5-0.415030NC1(c2ccc(-c3ncc4cccnc4c3-c3ccccc3)cc2)CCC1
3OC[C@H](O)CN1C(=O)[C@@H](Cc2ccccc12)NC(=O)c3cc...0.897942O=C(NC1Cc2ccccc2N(CC(O)CO)C1=O)c1cc2cc(Cl)sc2[...
4NS(=O)(=O)c1nc2ccccc2s1-0.707731NS(=O)(=O)c1nc2ccccc2s1
\n", "
" ], "text/plain": [ " smiles y \\\n", "0 N(c1ccccc1)c2ccnc3ccccc23 0.488161 \n", "1 Clc1ccc2Oc3ccccc3N=C(N4CCNCC4)c2c1 0.070017 \n", "2 NC1(CCC1)c2ccc(cc2)c3ncc4cccnc4c3c5ccccc5 -0.415030 \n", "3 OC[C@H](O)CN1C(=O)[C@@H](Cc2ccccc12)NC(=O)c3cc... 0.897942 \n", "4 NS(=O)(=O)c1nc2ccccc2s1 -0.707731 \n", "\n", " norm_smiles \n", "0 c1ccc(Nc2ccnc3ccccc23)cc1 \n", "1 Clc1ccc2c(c1)C(N1CCNCC1)=Nc1ccccc1O2 \n", "2 NC1(c2ccc(-c3ncc4cccnc4c3-c3ccccc3)cc2)CCC1 \n", "3 O=C(NC1Cc2ccccc2N(CC(O)CO)C1=O)c1cc2cc(Cl)sc2[... \n", "4 NS(=O)(=O)c1nc2ccccc2s1 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_test['norm_smiles'] = [normalize_smiles(smi) for smi in df_test['smiles']]\n", "df_test_normalized = df_test.dropna()\n", "print(df_test_normalized.shape)\n", "df_test_normalized.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embeddings extraction " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### smi-ted embeddings extraction" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 33/33 [00:38<00:00, 1.15s/it]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...758759760761762763764765766767
00.367646-0.5048890.0404850.3853140.564923-0.6844971.1603970.0712180.7994280.181323...-1.379994-0.1672210.1048860.239571-0.7443900.590423-0.8089460.7925840.550898-0.176831
10.455316-0.4855540.0622060.3879940.567590-0.7132851.144267-0.0570460.7530160.112180...-1.332142-0.0966620.2219440.327923-0.7393580.659803-0.7757230.7458370.566330-0.111946
20.442309-0.4847320.0849450.3847870.564752-0.7041301.1594910.0211680.8465390.118463...-1.324177-0.1104030.2078240.281665-0.7808180.693484-0.8326260.7630950.532460-0.196708
30.527961-0.5191510.0916350.3535180.421795-0.7242201.0937520.1485740.8040470.194627...-1.358414-0.1114830.1516920.186741-0.6018670.641591-0.7474220.7942390.640765-0.239649
40.464432-0.5110900.0387850.3462170.492919-0.6193871.0481570.0959100.7386040.119270...-1.223927-0.1098630.1512800.244834-0.6866100.759327-0.7563380.7664270.610454-0.197345
\n", "

5 rows × 768 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 \\\n", "0 0.367646 -0.504889 0.040485 0.385314 0.564923 -0.684497 1.160397 \n", "1 0.455316 -0.485554 0.062206 0.387994 0.567590 -0.713285 1.144267 \n", "2 0.442309 -0.484732 0.084945 0.384787 0.564752 -0.704130 1.159491 \n", "3 0.527961 -0.519151 0.091635 0.353518 0.421795 -0.724220 1.093752 \n", "4 0.464432 -0.511090 0.038785 0.346217 0.492919 -0.619387 1.048157 \n", "\n", " 7 8 9 ... 758 759 760 761 \\\n", "0 0.071218 0.799428 0.181323 ... -1.379994 -0.167221 0.104886 0.239571 \n", "1 -0.057046 0.753016 0.112180 ... -1.332142 -0.096662 0.221944 0.327923 \n", "2 0.021168 0.846539 0.118463 ... -1.324177 -0.110403 0.207824 0.281665 \n", "3 0.148574 0.804047 0.194627 ... -1.358414 -0.111483 0.151692 0.186741 \n", "4 0.095910 0.738604 0.119270 ... -1.223927 -0.109863 0.151280 0.244834 \n", "\n", " 762 763 764 765 766 767 \n", "0 -0.744390 0.590423 -0.808946 0.792584 0.550898 -0.176831 \n", "1 -0.739358 0.659803 -0.775723 0.745837 0.566330 -0.111946 \n", "2 -0.780818 0.693484 -0.832626 0.763095 0.532460 -0.196708 \n", "3 -0.601867 0.641591 -0.747422 0.794239 0.640765 -0.239649 \n", "4 -0.686610 0.759327 -0.756338 0.766427 0.610454 -0.197345 \n", "\n", "[5 rows x 768 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# gradient tracking is disabled while encoding (inference only)\n", "with torch.no_grad():\n", "    df_embeddings_train = model_smi_ted.encode(\n", "        df_train_normalized['norm_smiles']\n", "    )\n", "df_embeddings_train.head()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:05<00:00, 1.46s/it]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...758759760761762763764765766767
00.392252-0.5048460.0567910.3562970.475918-0.6488991.157862-0.0229140.7032400.192023...-1.208714-0.0944410.1288450.403995-0.7827820.541907-0.7072720.9010410.629461-0.020630
10.387422-0.4811420.0496750.3530580.601170-0.6460991.1423920.0600920.7637990.110331...-1.248282-0.1397900.0755850.202242-0.7297940.705914-0.7717510.8431730.618850-0.213584
20.390975-0.5100560.0706560.3806950.601486-0.5958271.1821930.0110850.6880930.056453...-1.294595-0.1648460.1944350.240742-0.7734430.608631-0.7471810.7919110.611874-0.125455
30.423924-0.5573250.0838100.3287030.399589-0.6228181.0799450.0976110.7240300.135976...-1.412060-0.1065410.1533140.209962-0.6996900.648061-0.7162410.7579860.615963-0.258693
40.335576-0.5595910.1194370.3641410.375474-0.6398331.1447070.0775120.7917590.164201...-1.279041-0.1867330.1069630.254949-0.6516940.594167-0.6804260.8874820.651587-0.144996
\n", "

5 rows × 768 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 \\\n", "0 0.392252 -0.504846 0.056791 0.356297 0.475918 -0.648899 1.157862 \n", "1 0.387422 -0.481142 0.049675 0.353058 0.601170 -0.646099 1.142392 \n", "2 0.390975 -0.510056 0.070656 0.380695 0.601486 -0.595827 1.182193 \n", "3 0.423924 -0.557325 0.083810 0.328703 0.399589 -0.622818 1.079945 \n", "4 0.335576 -0.559591 0.119437 0.364141 0.375474 -0.639833 1.144707 \n", "\n", " 7 8 9 ... 758 759 760 761 \\\n", "0 -0.022914 0.703240 0.192023 ... -1.208714 -0.094441 0.128845 0.403995 \n", "1 0.060092 0.763799 0.110331 ... -1.248282 -0.139790 0.075585 0.202242 \n", "2 0.011085 0.688093 0.056453 ... -1.294595 -0.164846 0.194435 0.240742 \n", "3 0.097611 0.724030 0.135976 ... -1.412060 -0.106541 0.153314 0.209962 \n", "4 0.077512 0.791759 0.164201 ... -1.279041 -0.186733 0.106963 0.254949 \n", "\n", " 762 763 764 765 766 767 \n", "0 -0.782782 0.541907 -0.707272 0.901041 0.629461 -0.020630 \n", "1 -0.729794 0.705914 -0.771751 0.843173 0.618850 -0.213584 \n", "2 -0.773443 0.608631 -0.747181 0.791911 0.611874 -0.125455 \n", "3 -0.699690 0.648061 -0.716241 0.757986 0.615963 -0.258693 \n", "4 -0.651694 0.594167 -0.680426 0.887482 0.651587 -0.144996 \n", "\n", "[5 rows x 768 columns]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# gradient tracking is disabled while encoding (inference only)\n", "with torch.no_grad():\n", "    df_embeddings_test = model_smi_ted.encode(\n", "        df_test_normalized['norm_smiles']\n", "    )\n", "df_embeddings_test.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Experiments - Lipophilicity prediction using smi-ted latent spaces" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### XGBoost prediction using the whole Latent Space" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import mean_squared_error\n", "from xgboost import XGBRegressor" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
       "             colsample_bylevel=None, colsample_bynode=None,\n",
       "             colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
       "             enable_categorical=False, eval_metric=None, feature_types=None,\n",
       "             gamma=None, grow_policy=None, importance_type=None,\n",
       "             interaction_constraints=None, learning_rate=0.05, max_bin=None,\n",
       "             max_cat_threshold=None, max_cat_to_onehot=None,\n",
       "             max_delta_step=None, max_depth=4, max_leaves=None,\n",
       "             min_child_weight=None, missing=nan, monotone_constraints=None,\n",
       "             multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
       "             num_parallel_tree=None, random_state=None, ...)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "XGBRegressor(base_score=None, booster=None, callbacks=None,\n", " colsample_bylevel=None, colsample_bynode=None,\n", " colsample_bytree=None, device=None, early_stopping_rounds=None,\n", " enable_categorical=False, eval_metric=None, feature_types=None,\n", " gamma=None, grow_policy=None, importance_type=None,\n", " interaction_constraints=None, learning_rate=0.05, max_bin=None,\n", " max_cat_threshold=None, max_cat_to_onehot=None,\n", " max_delta_step=None, max_depth=4, max_leaves=None,\n", " min_child_weight=None, missing=nan, monotone_constraints=None,\n", " multi_strategy=None, n_estimators=2000, n_jobs=None,\n", " num_parallel_tree=None, random_state=None, ...)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xgb_predict = XGBRegressor(\n", "    n_estimators=2000,\n", "    learning_rate=0.05,\n", "    max_depth=4,\n", ")\n", "xgb_predict.fit(df_embeddings_train, df_train_normalized['y'])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# predict test-set targets from the smi-ted test embeddings with the fitted model\n", "y_pred = xgb_predict.predict(df_embeddings_test)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RMSE Score: 0.6485\n" ] } ], "source": [ "rmse = np.sqrt(\n", "    mean_squared_error(df_test_normalized[\"y\"], y_pred)\n", ")\n", "print(f\"RMSE Score: {rmse:.4f}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }