{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# materials.smi-TED - INFERENCE (Regression)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install extra packages for notebook\n",
"%pip install seaborn xgboost"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../inference')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# materials.smi-ted (smi-ted)\n",
"from smi_ted_light.load import load_smi_ted\n",
"\n",
"# Data\n",
"import torch\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"# Chemistry\n",
"from rdkit import Chem\n",
"from rdkit.Chem import PandasTools\n",
"from rdkit.Chem import Descriptors\n",
"PandasTools.RenderImagesInAllDataFrames(True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def normalize_smiles(smi, canonical=True, isomeric=False):\n",
"    \"\"\"Return the RDKit-normalized form of a SMILES string, or None on failure.\n",
"\n",
"    Args:\n",
"        smi: SMILES string to normalize.\n",
"        canonical: Passed to ``Chem.MolToSmiles`` as ``canonical``.\n",
"        isomeric: Passed to ``Chem.MolToSmiles`` as ``isomericSmiles``; the\n",
"            default of False leaves stereochemistry out of the output.\n",
"\n",
"    Returns:\n",
"        The SMILES string written by RDKit, or None when the input cannot be\n",
"        parsed or converted.\n",
"    \"\"\"\n",
"    try:\n",
"        mol = Chem.MolFromSmiles(smi)\n",
"        # MolFromSmiles signals an unparsable string by returning None, not by raising.\n",
"        if mol is None:\n",
"            return None\n",
"        return Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)\n",
"    except Exception:\n",
"        # Catch Exception rather than everything so KeyboardInterrupt and\n",
"        # SystemExit still stop a long .apply() over the dataset.\n",
"        return None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import smi-ted"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Random Seed: 12345\n",
"Using Rotation Embedding\n",
"Using Rotation Embedding\n",
"Using Rotation Embedding\n",
"Using Rotation Embedding\n",
"Using Rotation Embedding\n",
"Using Rotation Embedding\n",
"Using Rotation Embedding\n",
"Using Rotation Embedding\n",
"Using Rotation Embedding\n",
"Using Rotation Embedding\n",
"Using Rotation Embedding\n",
"Using Rotation Embedding\n",
"Vocab size: 2393\n",
"[INFERENCE MODE - smi-ted-Light]\n"
]
}
],
"source": [
"# Load the smi-ted Light checkpoint from the local inference folder.\n",
"model_smi_ted = load_smi_ted(\n",
"    folder='../inference/smi_ted_light',\n",
"    ckpt_filename='smi-ted-Light_40.pt',\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Lipophilicity Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Experiments - Data Load"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"df_train = pd.read_csv(\"../finetune/moleculenet/lipophilicity/train.csv\")\n",
"df_test = pd.read_csv(\"../finetune/moleculenet/lipophilicity/test.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SMILES canonicalization"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(3360, 3)\n"
]
},
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" smiles \n",
" y \n",
" norm_smiles \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" Nc1ncnc2c1c(COc3cccc(Cl)c3)nn2C4CCOCC4 \n",
" 0.814313 \n",
" Nc1ncnc2c1c(COc1cccc(Cl)c1)nn2C1CCOCC1 \n",
" \n",
" \n",
" 1 \n",
" COc1cc(cc2cnc(Nc3ccc(cc3)[C@@H](C)NC(=O)C)nc12... \n",
" 0.446346 \n",
" COc1cc(-c2ccncc2)cc2cnc(Nc3ccc(C(C)NC(C)=O)cc3... \n",
" \n",
" \n",
" 2 \n",
" CC(=O)Nc1ccc2ccn(c3cc(Nc4ccn(C)n4)n5ncc(C#N)c5... \n",
" 1.148828 \n",
" CC(=O)Nc1ccc2ccn(-c3cc(Nc4ccn(C)n4)n4ncc(C#N)c... \n",
" \n",
" \n",
" 3 \n",
" Oc1ccc(CCNCCS(=O)(=O)CCCOCCSc2ccccc2)c3sc(O)nc13 \n",
" 0.404532 \n",
" O=S(=O)(CCCOCCSc1ccccc1)CCNCCc1ccc(O)c2nc(O)sc12 \n",
" \n",
" \n",
" 4 \n",
" Clc1ccc2C(=O)C3=C(Nc2c1)C(=O)NN(Cc4cc5ccccc5s4... \n",
" -0.164144 \n",
" O=c1[nH]n(Cc2cc3ccccc3s2)c(=O)c2c(=O)c3ccc(Cl)... \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" smiles y \\\n",
"0 Nc1ncnc2c1c(COc3cccc(Cl)c3)nn2C4CCOCC4 0.814313 \n",
"1 COc1cc(cc2cnc(Nc3ccc(cc3)[C@@H](C)NC(=O)C)nc12... 0.446346 \n",
"2 CC(=O)Nc1ccc2ccn(c3cc(Nc4ccn(C)n4)n5ncc(C#N)c5... 1.148828 \n",
"3 Oc1ccc(CCNCCS(=O)(=O)CCCOCCSc2ccccc2)c3sc(O)nc13 0.404532 \n",
"4 Clc1ccc2C(=O)C3=C(Nc2c1)C(=O)NN(Cc4cc5ccccc5s4... -0.164144 \n",
"\n",
" norm_smiles \n",
"0 Nc1ncnc2c1c(COc1cccc(Cl)c1)nn2C1CCOCC1 \n",
"1 COc1cc(-c2ccncc2)cc2cnc(Nc3ccc(C(C)NC(C)=O)cc3... \n",
"2 CC(=O)Nc1ccc2ccn(-c3cc(Nc4ccn(C)n4)n4ncc(C#N)c... \n",
"3 O=S(=O)(CCCOCCSc1ccccc1)CCNCCc1ccc(O)c2nc(O)sc12 \n",
"4 O=c1[nH]n(Cc2cc3ccccc3s2)c(=O)c2c(=O)c3ccc(Cl)... "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Canonicalize every training SMILES; entries that fail come back as None.\n",
"df_train['norm_smiles'] = df_train['smiles'].map(normalize_smiles)\n",
"# Keep only rows with no missing value in any column.\n",
"df_train_normalized = df_train.dropna(axis=0, how='any')\n",
"print(df_train_normalized.shape)\n",
"df_train_normalized.head()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(420, 3)\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" smiles \n",
" y \n",
" norm_smiles \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" N(c1ccccc1)c2ccnc3ccccc23 \n",
" 0.488161 \n",
" c1ccc(Nc2ccnc3ccccc23)cc1 \n",
" \n",
" \n",
" 1 \n",
" Clc1ccc2Oc3ccccc3N=C(N4CCNCC4)c2c1 \n",
" 0.070017 \n",
" Clc1ccc2c(c1)C(N1CCNCC1)=Nc1ccccc1O2 \n",
" \n",
" \n",
" 2 \n",
" NC1(CCC1)c2ccc(cc2)c3ncc4cccnc4c3c5ccccc5 \n",
" -0.415030 \n",
" NC1(c2ccc(-c3ncc4cccnc4c3-c3ccccc3)cc2)CCC1 \n",
" \n",
" \n",
" 3 \n",
" OC[C@H](O)CN1C(=O)[C@@H](Cc2ccccc12)NC(=O)c3cc... \n",
" 0.897942 \n",
" O=C(NC1Cc2ccccc2N(CC(O)CO)C1=O)c1cc2cc(Cl)sc2[... \n",
" \n",
" \n",
" 4 \n",
" NS(=O)(=O)c1nc2ccccc2s1 \n",
" -0.707731 \n",
" NS(=O)(=O)c1nc2ccccc2s1 \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" smiles y \\\n",
"0 N(c1ccccc1)c2ccnc3ccccc23 0.488161 \n",
"1 Clc1ccc2Oc3ccccc3N=C(N4CCNCC4)c2c1 0.070017 \n",
"2 NC1(CCC1)c2ccc(cc2)c3ncc4cccnc4c3c5ccccc5 -0.415030 \n",
"3 OC[C@H](O)CN1C(=O)[C@@H](Cc2ccccc12)NC(=O)c3cc... 0.897942 \n",
"4 NS(=O)(=O)c1nc2ccccc2s1 -0.707731 \n",
"\n",
" norm_smiles \n",
"0 c1ccc(Nc2ccnc3ccccc23)cc1 \n",
"1 Clc1ccc2c(c1)C(N1CCNCC1)=Nc1ccccc1O2 \n",
"2 NC1(c2ccc(-c3ncc4cccnc4c3-c3ccccc3)cc2)CCC1 \n",
"3 O=C(NC1Cc2ccccc2N(CC(O)CO)C1=O)c1cc2cc(Cl)sc2[... \n",
"4 NS(=O)(=O)c1nc2ccccc2s1 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Canonicalize every test SMILES; entries that fail come back as None.\n",
"df_test['norm_smiles'] = df_test['smiles'].map(normalize_smiles)\n",
"# Keep only rows with no missing value in any column.\n",
"df_test_normalized = df_test.dropna(axis=0, how='any')\n",
"print(df_test_normalized.shape)\n",
"df_test_normalized.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Embeddings extraction "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### smi-ted embeddings extraction"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 33/33 [00:38<00:00, 1.15s/it]\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" 0 \n",
" 1 \n",
" 2 \n",
" 3 \n",
" 4 \n",
" 5 \n",
" 6 \n",
" 7 \n",
" 8 \n",
" 9 \n",
" ... \n",
" 758 \n",
" 759 \n",
" 760 \n",
" 761 \n",
" 762 \n",
" 763 \n",
" 764 \n",
" 765 \n",
" 766 \n",
" 767 \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 0.367646 \n",
" -0.504889 \n",
" 0.040485 \n",
" 0.385314 \n",
" 0.564923 \n",
" -0.684497 \n",
" 1.160397 \n",
" 0.071218 \n",
" 0.799428 \n",
" 0.181323 \n",
" ... \n",
" -1.379994 \n",
" -0.167221 \n",
" 0.104886 \n",
" 0.239571 \n",
" -0.744390 \n",
" 0.590423 \n",
" -0.808946 \n",
" 0.792584 \n",
" 0.550898 \n",
" -0.176831 \n",
" \n",
" \n",
" 1 \n",
" 0.455316 \n",
" -0.485554 \n",
" 0.062206 \n",
" 0.387994 \n",
" 0.567590 \n",
" -0.713285 \n",
" 1.144267 \n",
" -0.057046 \n",
" 0.753016 \n",
" 0.112180 \n",
" ... \n",
" -1.332142 \n",
" -0.096662 \n",
" 0.221944 \n",
" 0.327923 \n",
" -0.739358 \n",
" 0.659803 \n",
" -0.775723 \n",
" 0.745837 \n",
" 0.566330 \n",
" -0.111946 \n",
" \n",
" \n",
" 2 \n",
" 0.442309 \n",
" -0.484732 \n",
" 0.084945 \n",
" 0.384787 \n",
" 0.564752 \n",
" -0.704130 \n",
" 1.159491 \n",
" 0.021168 \n",
" 0.846539 \n",
" 0.118463 \n",
" ... \n",
" -1.324177 \n",
" -0.110403 \n",
" 0.207824 \n",
" 0.281665 \n",
" -0.780818 \n",
" 0.693484 \n",
" -0.832626 \n",
" 0.763095 \n",
" 0.532460 \n",
" -0.196708 \n",
" \n",
" \n",
" 3 \n",
" 0.527961 \n",
" -0.519151 \n",
" 0.091635 \n",
" 0.353518 \n",
" 0.421795 \n",
" -0.724220 \n",
" 1.093752 \n",
" 0.148574 \n",
" 0.804047 \n",
" 0.194627 \n",
" ... \n",
" -1.358414 \n",
" -0.111483 \n",
" 0.151692 \n",
" 0.186741 \n",
" -0.601867 \n",
" 0.641591 \n",
" -0.747422 \n",
" 0.794239 \n",
" 0.640765 \n",
" -0.239649 \n",
" \n",
" \n",
" 4 \n",
" 0.464432 \n",
" -0.511090 \n",
" 0.038785 \n",
" 0.346217 \n",
" 0.492919 \n",
" -0.619387 \n",
" 1.048157 \n",
" 0.095910 \n",
" 0.738604 \n",
" 0.119270 \n",
" ... \n",
" -1.223927 \n",
" -0.109863 \n",
" 0.151280 \n",
" 0.244834 \n",
" -0.686610 \n",
" 0.759327 \n",
" -0.756338 \n",
" 0.766427 \n",
" 0.610454 \n",
" -0.197345 \n",
" \n",
" \n",
"
\n",
"
5 rows × 768 columns
\n",
"
"
],
"text/plain": [
" 0 1 2 3 4 5 6 \\\n",
"0 0.367646 -0.504889 0.040485 0.385314 0.564923 -0.684497 1.160397 \n",
"1 0.455316 -0.485554 0.062206 0.387994 0.567590 -0.713285 1.144267 \n",
"2 0.442309 -0.484732 0.084945 0.384787 0.564752 -0.704130 1.159491 \n",
"3 0.527961 -0.519151 0.091635 0.353518 0.421795 -0.724220 1.093752 \n",
"4 0.464432 -0.511090 0.038785 0.346217 0.492919 -0.619387 1.048157 \n",
"\n",
" 7 8 9 ... 758 759 760 761 \\\n",
"0 0.071218 0.799428 0.181323 ... -1.379994 -0.167221 0.104886 0.239571 \n",
"1 -0.057046 0.753016 0.112180 ... -1.332142 -0.096662 0.221944 0.327923 \n",
"2 0.021168 0.846539 0.118463 ... -1.324177 -0.110403 0.207824 0.281665 \n",
"3 0.148574 0.804047 0.194627 ... -1.358414 -0.111483 0.151692 0.186741 \n",
"4 0.095910 0.738604 0.119270 ... -1.223927 -0.109863 0.151280 0.244834 \n",
"\n",
" 762 763 764 765 766 767 \n",
"0 -0.744390 0.590423 -0.808946 0.792584 0.550898 -0.176831 \n",
"1 -0.739358 0.659803 -0.775723 0.745837 0.566330 -0.111946 \n",
"2 -0.780818 0.693484 -0.832626 0.763095 0.532460 -0.196708 \n",
"3 -0.601867 0.641591 -0.747422 0.794239 0.640765 -0.239649 \n",
"4 -0.686610 0.759327 -0.756338 0.766427 0.610454 -0.197345 \n",
"\n",
"[5 rows x 768 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Encoding is inference only, so skip autograd bookkeeping while it runs.\n",
"with torch.no_grad():\n",
"    df_embeddings_train = model_smi_ted.encode(df_train_normalized['norm_smiles'])\n",
"df_embeddings_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 4/4 [00:05<00:00, 1.46s/it]\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" 0 \n",
" 1 \n",
" 2 \n",
" 3 \n",
" 4 \n",
" 5 \n",
" 6 \n",
" 7 \n",
" 8 \n",
" 9 \n",
" ... \n",
" 758 \n",
" 759 \n",
" 760 \n",
" 761 \n",
" 762 \n",
" 763 \n",
" 764 \n",
" 765 \n",
" 766 \n",
" 767 \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 0.392252 \n",
" -0.504846 \n",
" 0.056791 \n",
" 0.356297 \n",
" 0.475918 \n",
" -0.648899 \n",
" 1.157862 \n",
" -0.022914 \n",
" 0.703240 \n",
" 0.192023 \n",
" ... \n",
" -1.208714 \n",
" -0.094441 \n",
" 0.128845 \n",
" 0.403995 \n",
" -0.782782 \n",
" 0.541907 \n",
" -0.707272 \n",
" 0.901041 \n",
" 0.629461 \n",
" -0.020630 \n",
" \n",
" \n",
" 1 \n",
" 0.387422 \n",
" -0.481142 \n",
" 0.049675 \n",
" 0.353058 \n",
" 0.601170 \n",
" -0.646099 \n",
" 1.142392 \n",
" 0.060092 \n",
" 0.763799 \n",
" 0.110331 \n",
" ... \n",
" -1.248282 \n",
" -0.139790 \n",
" 0.075585 \n",
" 0.202242 \n",
" -0.729794 \n",
" 0.705914 \n",
" -0.771751 \n",
" 0.843173 \n",
" 0.618850 \n",
" -0.213584 \n",
" \n",
" \n",
" 2 \n",
" 0.390975 \n",
" -0.510056 \n",
" 0.070656 \n",
" 0.380695 \n",
" 0.601486 \n",
" -0.595827 \n",
" 1.182193 \n",
" 0.011085 \n",
" 0.688093 \n",
" 0.056453 \n",
" ... \n",
" -1.294595 \n",
" -0.164846 \n",
" 0.194435 \n",
" 0.240742 \n",
" -0.773443 \n",
" 0.608631 \n",
" -0.747181 \n",
" 0.791911 \n",
" 0.611874 \n",
" -0.125455 \n",
" \n",
" \n",
" 3 \n",
" 0.423924 \n",
" -0.557325 \n",
" 0.083810 \n",
" 0.328703 \n",
" 0.399589 \n",
" -0.622818 \n",
" 1.079945 \n",
" 0.097611 \n",
" 0.724030 \n",
" 0.135976 \n",
" ... \n",
" -1.412060 \n",
" -0.106541 \n",
" 0.153314 \n",
" 0.209962 \n",
" -0.699690 \n",
" 0.648061 \n",
" -0.716241 \n",
" 0.757986 \n",
" 0.615963 \n",
" -0.258693 \n",
" \n",
" \n",
" 4 \n",
" 0.335576 \n",
" -0.559591 \n",
" 0.119437 \n",
" 0.364141 \n",
" 0.375474 \n",
" -0.639833 \n",
" 1.144707 \n",
" 0.077512 \n",
" 0.791759 \n",
" 0.164201 \n",
" ... \n",
" -1.279041 \n",
" -0.186733 \n",
" 0.106963 \n",
" 0.254949 \n",
" -0.651694 \n",
" 0.594167 \n",
" -0.680426 \n",
" 0.887482 \n",
" 0.651587 \n",
" -0.144996 \n",
" \n",
" \n",
"
\n",
"
5 rows × 768 columns
\n",
"
"
],
"text/plain": [
" 0 1 2 3 4 5 6 \\\n",
"0 0.392252 -0.504846 0.056791 0.356297 0.475918 -0.648899 1.157862 \n",
"1 0.387422 -0.481142 0.049675 0.353058 0.601170 -0.646099 1.142392 \n",
"2 0.390975 -0.510056 0.070656 0.380695 0.601486 -0.595827 1.182193 \n",
"3 0.423924 -0.557325 0.083810 0.328703 0.399589 -0.622818 1.079945 \n",
"4 0.335576 -0.559591 0.119437 0.364141 0.375474 -0.639833 1.144707 \n",
"\n",
" 7 8 9 ... 758 759 760 761 \\\n",
"0 -0.022914 0.703240 0.192023 ... -1.208714 -0.094441 0.128845 0.403995 \n",
"1 0.060092 0.763799 0.110331 ... -1.248282 -0.139790 0.075585 0.202242 \n",
"2 0.011085 0.688093 0.056453 ... -1.294595 -0.164846 0.194435 0.240742 \n",
"3 0.097611 0.724030 0.135976 ... -1.412060 -0.106541 0.153314 0.209962 \n",
"4 0.077512 0.791759 0.164201 ... -1.279041 -0.186733 0.106963 0.254949 \n",
"\n",
" 762 763 764 765 766 767 \n",
"0 -0.782782 0.541907 -0.707272 0.901041 0.629461 -0.020630 \n",
"1 -0.729794 0.705914 -0.771751 0.843173 0.618850 -0.213584 \n",
"2 -0.773443 0.608631 -0.747181 0.791911 0.611874 -0.125455 \n",
"3 -0.699690 0.648061 -0.716241 0.757986 0.615963 -0.258693 \n",
"4 -0.651694 0.594167 -0.680426 0.887482 0.651587 -0.144996 \n",
"\n",
"[5 rows x 768 columns]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Encoding is inference only, so skip autograd bookkeeping while it runs.\n",
"with torch.no_grad():\n",
"    df_embeddings_test = model_smi_ted.encode(df_test_normalized['norm_smiles'])\n",
"df_embeddings_test.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Experiments - Lipophilicity prediction using smi-ted latent spaces"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### XGBoost prediction using the whole Latent Space"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from xgboost import XGBRegressor\n",
"from sklearn.metrics import mean_squared_error"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
" colsample_bylevel=None, colsample_bynode=None,\n",
" colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
" enable_categorical=False, eval_metric=None, feature_types=None,\n",
" gamma=None, grow_policy=None, importance_type=None,\n",
" interaction_constraints=None, learning_rate=0.05, max_bin=None,\n",
" max_cat_threshold=None, max_cat_to_onehot=None,\n",
" max_delta_step=None, max_depth=4, max_leaves=None,\n",
" min_child_weight=None, missing=nan, monotone_constraints=None,\n",
" multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
" num_parallel_tree=None, random_state=None, ...) In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org. XGBRegressoriFitted XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
" colsample_bylevel=None, colsample_bynode=None,\n",
" colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
" enable_categorical=False, eval_metric=None, feature_types=None,\n",
" gamma=None, grow_policy=None, importance_type=None,\n",
" interaction_constraints=None, learning_rate=0.05, max_bin=None,\n",
" max_cat_threshold=None, max_cat_to_onehot=None,\n",
" max_delta_step=None, max_depth=4, max_leaves=None,\n",
" min_child_weight=None, missing=nan, monotone_constraints=None,\n",
" multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
" num_parallel_tree=None, random_state=None, ...) "
],
"text/plain": [
"XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
" colsample_bylevel=None, colsample_bynode=None,\n",
" colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
" enable_categorical=False, eval_metric=None, feature_types=None,\n",
" gamma=None, grow_policy=None, importance_type=None,\n",
" interaction_constraints=None, learning_rate=0.05, max_bin=None,\n",
" max_cat_threshold=None, max_cat_to_onehot=None,\n",
" max_delta_step=None, max_depth=4, max_leaves=None,\n",
" min_child_weight=None, missing=nan, monotone_constraints=None,\n",
" multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
" num_parallel_tree=None, random_state=None, ...)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Gradient-boosted regressor fitted on the full smi-ted embedding of each training molecule.\n",
"xgb_predict = XGBRegressor(\n",
"    n_estimators=2000,\n",
"    learning_rate=0.05,\n",
"    max_depth=4,\n",
")\n",
"xgb_predict.fit(df_embeddings_train, df_train_normalized['y'])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# get XGBoost predictions\n",
"y_pred = xgb_predict.predict(df_embeddings_test)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RMSE Score: 0.6485\n"
]
}
],
"source": [
"rmse = np.sqrt(mean_squared_error(df_test_normalized[\"y\"], y_pred))\n",
"print(f\"RMSE Score: {rmse:.4f}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}