diff --git "a/notebooks/protac_degradation_predictor.ipynb" "b/notebooks/protac_degradation_predictor.ipynb" --- "a/notebooks/protac_degradation_predictor.ipynb" +++ "b/notebooks/protac_degradation_predictor.ipynb" @@ -539,7 +539,46 @@ "cell_type": "code", "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading embeddings from https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/embeddings/UP000005640_9606/per-protein.h5\n" + ] + }, + { + "ename": "URLError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mSSLCertVerificationError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m/opt/conda/lib/python3.10/urllib/request.py:1348\u001b[0m, in \u001b[0;36mAbstractHTTPHandler.do_open\u001b[0;34m(self, http_class, req, **http_conn_args)\u001b[0m\n\u001b[1;32m 1347\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1348\u001b[0m \u001b[43mh\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreq\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreq\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreq\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1349\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreq\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhas_header\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mTransfer-encoding\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1350\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mOSError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err: \u001b[38;5;66;03m# timeout error\u001b[39;00m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/http/client.py:1282\u001b[0m, in \u001b[0;36mHTTPConnection.request\u001b[0;34m(self, method, url, body, headers, encode_chunked)\u001b[0m\n\u001b[1;32m 1281\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Send a complete request to the server.\"\"\"\u001b[39;00m\n\u001b[0;32m-> 1282\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/http/client.py:1328\u001b[0m, in \u001b[0;36mHTTPConnection._send_request\u001b[0;34m(self, method, url, body, headers, encode_chunked)\u001b[0m\n\u001b[1;32m 1327\u001b[0m body \u001b[38;5;241m=\u001b[39m _encode(body, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbody\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m-> 1328\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mendheaders\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencode_chunked\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/http/client.py:1277\u001b[0m, in \u001b[0;36mHTTPConnection.endheaders\u001b[0;34m(self, message_body, encode_chunked)\u001b[0m\n\u001b[1;32m 1276\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m CannotSendHeader()\n\u001b[0;32m-> 1277\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_output\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmessage_body\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencode_chunked\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencode_chunked\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/http/client.py:1037\u001b[0m, in \u001b[0;36mHTTPConnection._send_output\u001b[0;34m(self, message_body, encode_chunked)\u001b[0m\n\u001b[1;32m 1036\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_buffer[:]\n\u001b[0;32m-> 1037\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1039\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m message_body \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1040\u001b[0m \n\u001b[1;32m 1041\u001b[0m \u001b[38;5;66;03m# create a consistent interface to message_body\u001b[39;00m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/http/client.py:975\u001b[0m, in \u001b[0;36mHTTPConnection.send\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 974\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mauto_open:\n\u001b[0;32m--> 975\u001b[0m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnect\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 976\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/http/client.py:1454\u001b[0m, in \u001b[0;36mHTTPSConnection.connect\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1452\u001b[0m server_hostname \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhost\n\u001b[0;32m-> 1454\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msock \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_context\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrap_socket\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msock\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1455\u001b[0m \u001b[43m \u001b[49m\u001b[43mserver_hostname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mserver_hostname\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/ssl.py:513\u001b[0m, in \u001b[0;36mSSLContext.wrap_socket\u001b[0;34m(self, sock, server_side, do_handshake_on_connect, suppress_ragged_eofs, server_hostname, session)\u001b[0m\n\u001b[1;32m 507\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrap_socket\u001b[39m(\u001b[38;5;28mself\u001b[39m, sock, server_side\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 508\u001b[0m do_handshake_on_connect\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 509\u001b[0m suppress_ragged_eofs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 510\u001b[0m server_hostname\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, session\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 511\u001b[0m \u001b[38;5;66;03m# SSLSocket class handles server_hostname 
encoding before it calls\u001b[39;00m\n\u001b[1;32m 512\u001b[0m \u001b[38;5;66;03m# ctx._wrap_socket()\u001b[39;00m\n\u001b[0;32m--> 513\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msslsocket_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_create\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 514\u001b[0m \u001b[43m \u001b[49m\u001b[43msock\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msock\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 515\u001b[0m \u001b[43m \u001b[49m\u001b[43mserver_side\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mserver_side\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 516\u001b[0m \u001b[43m \u001b[49m\u001b[43mdo_handshake_on_connect\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdo_handshake_on_connect\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 517\u001b[0m \u001b[43m \u001b[49m\u001b[43msuppress_ragged_eofs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msuppress_ragged_eofs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 518\u001b[0m \u001b[43m \u001b[49m\u001b[43mserver_hostname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mserver_hostname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 519\u001b[0m \u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 520\u001b[0m \u001b[43m \u001b[49m\u001b[43msession\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msession\u001b[49m\n\u001b[1;32m 521\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/ssl.py:1071\u001b[0m, in \u001b[0;36mSSLSocket._create\u001b[0;34m(cls, sock, server_side, do_handshake_on_connect, suppress_ragged_eofs, server_hostname, context, session)\u001b[0m\n\u001b[1;32m 1070\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m 
\u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdo_handshake_on_connect should not be specified for non-blocking sockets\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 1071\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdo_handshake\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1072\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mOSError\u001b[39;00m, \u001b[38;5;167;01mValueError\u001b[39;00m):\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/ssl.py:1342\u001b[0m, in \u001b[0;36mSSLSocket.do_handshake\u001b[0;34m(self, block)\u001b[0m\n\u001b[1;32m 1341\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msettimeout(\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m-> 1342\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sslobj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdo_handshake\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1343\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n", + "\u001b[0;31mSSLCertVerificationError\u001b[0m: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self signed certificate in certificate chain (_ssl.c:997)", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mURLError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 9\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mexists(embeddings_path):\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Download the file\u001b[39;00m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDownloading embeddings from 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdownload_link\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 9\u001b[0m \u001b[43murllib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequest\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43murlretrieve\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdownload_link\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43membeddings_path\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/urllib/request.py:241\u001b[0m, in \u001b[0;36murlretrieve\u001b[0;34m(url, filename, reporthook, data)\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 225\u001b[0m \u001b[38;5;124;03mRetrieve a URL into a temporary location on disk.\u001b[39;00m\n\u001b[1;32m 226\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;124;03mdata file as well as the resulting HTTPMessage object.\u001b[39;00m\n\u001b[1;32m 238\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 239\u001b[0m url_type, path \u001b[38;5;241m=\u001b[39m _splittype(url)\n\u001b[0;32m--> 241\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m contextlib\u001b[38;5;241m.\u001b[39mclosing(\u001b[43murlopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m) \u001b[38;5;28;01mas\u001b[39;00m fp:\n\u001b[1;32m 242\u001b[0m headers \u001b[38;5;241m=\u001b[39m fp\u001b[38;5;241m.\u001b[39minfo()\n\u001b[1;32m 244\u001b[0m \u001b[38;5;66;03m# Just return the local path and the \"headers\" for file://\u001b[39;00m\n\u001b[1;32m 245\u001b[0m \u001b[38;5;66;03m# URLs. 
No sense in performing a copy unless requested.\u001b[39;00m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/urllib/request.py:216\u001b[0m, in \u001b[0;36murlopen\u001b[0;34m(url, data, timeout, cafile, capath, cadefault, context)\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 215\u001b[0m opener \u001b[38;5;241m=\u001b[39m _opener\n\u001b[0;32m--> 216\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mopener\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/urllib/request.py:519\u001b[0m, in \u001b[0;36mOpenerDirector.open\u001b[0;34m(self, fullurl, data, timeout)\u001b[0m\n\u001b[1;32m 516\u001b[0m req \u001b[38;5;241m=\u001b[39m meth(req)\n\u001b[1;32m 518\u001b[0m sys\u001b[38;5;241m.\u001b[39maudit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124murllib.Request\u001b[39m\u001b[38;5;124m'\u001b[39m, req\u001b[38;5;241m.\u001b[39mfull_url, req\u001b[38;5;241m.\u001b[39mdata, req\u001b[38;5;241m.\u001b[39mheaders, req\u001b[38;5;241m.\u001b[39mget_method())\n\u001b[0;32m--> 519\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 521\u001b[0m \u001b[38;5;66;03m# post-process response\u001b[39;00m\n\u001b[1;32m 522\u001b[0m meth_name \u001b[38;5;241m=\u001b[39m protocol\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_response\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/urllib/request.py:536\u001b[0m, in \u001b[0;36mOpenerDirector._open\u001b[0;34m(self, req, 
data)\u001b[0m\n\u001b[1;32m 533\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n\u001b[1;32m 535\u001b[0m protocol \u001b[38;5;241m=\u001b[39m req\u001b[38;5;241m.\u001b[39mtype\n\u001b[0;32m--> 536\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_chain\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhandle_open\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprotocol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprotocol\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\n\u001b[1;32m 537\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m_open\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreq\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result:\n\u001b[1;32m 539\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/urllib/request.py:496\u001b[0m, in \u001b[0;36mOpenerDirector._call_chain\u001b[0;34m(self, chain, kind, meth_name, *args)\u001b[0m\n\u001b[1;32m 494\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m handler \u001b[38;5;129;01min\u001b[39;00m handlers:\n\u001b[1;32m 495\u001b[0m func \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(handler, meth_name)\n\u001b[0;32m--> 496\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 497\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 498\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/urllib/request.py:1391\u001b[0m, in 
\u001b[0;36mHTTPSHandler.https_open\u001b[0;34m(self, req)\u001b[0m\n\u001b[1;32m 1390\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mhttps_open\u001b[39m(\u001b[38;5;28mself\u001b[39m, req):\n\u001b[0;32m-> 1391\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdo_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhttp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclient\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mHTTPSConnection\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreq\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1392\u001b[0m \u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_context\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcheck_hostname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_check_hostname\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/urllib/request.py:1351\u001b[0m, in \u001b[0;36mAbstractHTTPHandler.do_open\u001b[0;34m(self, http_class, req, **http_conn_args)\u001b[0m\n\u001b[1;32m 1348\u001b[0m h\u001b[38;5;241m.\u001b[39mrequest(req\u001b[38;5;241m.\u001b[39mget_method(), req\u001b[38;5;241m.\u001b[39mselector, req\u001b[38;5;241m.\u001b[39mdata, headers,\n\u001b[1;32m 1349\u001b[0m encode_chunked\u001b[38;5;241m=\u001b[39mreq\u001b[38;5;241m.\u001b[39mhas_header(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTransfer-encoding\u001b[39m\u001b[38;5;124m'\u001b[39m))\n\u001b[1;32m 1350\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mOSError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err: \u001b[38;5;66;03m# timeout error\u001b[39;00m\n\u001b[0;32m-> 1351\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m URLError(err)\n\u001b[1;32m 1352\u001b[0m r \u001b[38;5;241m=\u001b[39m 
h\u001b[38;5;241m.\u001b[39mgetresponse()\n\u001b[1;32m 1353\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m:\n", + "\u001b[0;31mURLError\u001b[0m: " + ] + } + ], "source": [ "import os\n", "import urllib.request\n", @@ -554,7 +593,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -567,7 +606,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c400551d3f144a80b0afee76ea6df334", + "model_id": "caecdff147b64f319e5f295a8f92dddc", "version_major": 2, "version_minor": 0 }, @@ -632,7 +671,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -654,7 +693,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -674,13 +713,13 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "417807b6e51542c28782297037ef86c7", + "model_id": "321da9b6e6d746ee868684c5bae33426", "version_major": 2, "version_minor": 0 }, @@ -741,13 +780,13 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "00b924157b074714a11e7c951cb3c8cb", + "model_id": "e89c67067474404798507a1d6eb92295", "version_major": 2, "version_minor": 0 }, @@ -772,7 +811,7 @@ "Name: Avg Tanimoto, dtype: float64" ] }, - "execution_count": 14, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -799,36 +838,25 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "# Plot the distribution of the average Tanimoto similarity\n", - "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", + "# # Plot the distribution of the average Tanimoto similarity\n", + "# import seaborn as sns\n", + "# import matplotlib.pyplot as plt\n", "\n", - "sns.histplot(protac_df['Avg Tanimoto'], bins=50)\n", - "plt.xlabel('Average Tanimoto similarity')\n", - "plt.ylabel('Count')\n", - "plt.title('Distribution of average Tanimoto similarity')\n", - "plt.grid(axis='y', alpha=0.5)\n", - "plt.show()" + "# sns.histplot(protac_df['Avg Tanimoto'], bins=50)\n", + "# plt.xlabel('Average Tanimoto similarity')\n", + "# plt.ylabel('Count')\n", + "# plt.title('Distribution of average Tanimoto similarity')\n", + "# plt.grid(axis='y', alpha=0.5)\n", + "# plt.show()" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -844,7 +872,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -861,7 +889,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -870,7 +898,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -939,7 +967,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -950,7 +978,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -978,7 +1006,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -1037,63 +1065,63 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ - "import 
if RUN_DIMENSIONALITY_REDUCTION:
    # Imports, helper definition and call all sit behind the flag, so nothing in
    # this cell executes when dimensionality reduction is switched off.
    import warnings
    import seaborn as sns
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import StandardScaler

    def plot_pca(protac_data, protac_labels):
        """Draw one 2D PCA scatter plot of the PROTAC features per available label.

        The features are standardized, projected onto two principal components
        (fixed ``random_state=42``), and plotted once for every candidate label
        that is present in ``protac_labels``; absent labels are skipped.

        Args:
            protac_data (np.ndarray): The PROTAC data.
            protac_labels (dict): The labels for the PROTAC data.
        """
        standardized = StandardScaler().fit_transform(protac_data)
        components = PCA(n_components=2, random_state=42).fit_transform(standardized)

        # Candidate colorings ('Smiles' was commented out of the original list).
        candidate_labels = (
            active_col,
            'Active - OR',
            'Cell Line Identifier',
            'E3 Ligase',
            'Uniprot',
            'Treatment Time (h)',
            'DC50 (nM)',
            'Dmax (%)',
        )
        # Only these colorings keep their legend; every other plot drops it.
        labels_with_legend = (active_col, 'E3 Ligase')

        for label in (name for name in candidate_labels if name in protac_labels):
            points = pd.DataFrame({
                'PCA 1': components[:, 0],
                'PCA 2': components[:, 1],
                label: protac_labels[label],
            }).drop_duplicates()

            # All warnings raised while drawing are silenced.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                sns.scatterplot(data=points, x='PCA 1', y='PCA 2', hue=label)
                if label not in labels_with_legend:
                    plt.legend().remove()
                else:
                    plt.legend(title=f'{label}:')
                plt.title(f'PCA embedding coloring for "{label}"')
                plt.grid(axis='both', alpha=0.5)
                plt.show()

    plot_pca(protac_data, protac_labels)
if RUN_DIMENSIONALITY_REDUCTION:
    # Run a 3-component PCA on protac_data and plot it colored by protac_labels.
    import warnings
    import seaborn as sns
    import matplotlib.pyplot as plt
    # Kept from the original cell: Matplotlib releases before 3.2 need this import
    # to register the '3d' projection used below.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    from sklearn.decomposition import PCA
    # Imported here too, so this cell no longer depends on the 2D PCA cell
    # having been executed first.
    from sklearn.preprocessing import StandardScaler

    def plot_pca_3d(protac_data, protac_labels):
        """Draw one 3D PCA scatter plot of the PROTAC features per available label.

        The features are standardized, projected onto three principal components
        (fixed ``random_state=42``), and plotted once per candidate label, with
        one scatter series per distinct label value. Labels that are missing
        from ``protac_labels`` are skipped, matching ``plot_pca``.

        Args:
            protac_data (np.ndarray): The PROTAC data.
            protac_labels (dict): The labels for the PROTAC data.
        """
        pca = PCA(n_components=3, random_state=42)
        scaler = StandardScaler()
        pca_data = pca.fit_transform(scaler.fit_transform(protac_data))

        # Candidate colorings ('Smiles' was commented out of the original list).
        labels_to_plot = [
            active_col,
            'Cell Line Identifier',
            'E3 Ligase',
            'Uniprot',
            'Treatment Time (h)',
        ]
        for label in labels_to_plot:
            # Skip colorings the caller did not provide instead of raising KeyError.
            if label not in protac_labels:
                continue
            pca_embeddings = pd.DataFrame({
                'PCA 1': pca_data[:, 0],
                'PCA 2': pca_data[:, 1],
                'PCA 3': pca_data[:, 2],
                label: protac_labels[label],
            }).drop_duplicates()

            # All warnings raised while drawing are silenced.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                fig = plt.figure()
                ax = fig.add_subplot(111, projection='3d')
                for value in pd.Series(protac_labels[label]).unique():
                    # Select the rows of this class once and reuse them for x/y/z.
                    subset = pca_embeddings[pca_embeddings[label] == value]
                    ax.scatter(
                        subset['PCA 1'],
                        subset['PCA 2'],
                        subset['PCA 3'],
                        label=value,
                    )

                # Legend for the activity and E3 ligase colorings only; for the
                # other labels no legend is created in the first place.
                if label in [active_col, 'E3 Ligase']:
                    plt.legend(title=f'{label}:')
                plt.title(f'PCA embedding coloring for "{label}"')
                plt.grid(axis='both', alpha=0.5)
                plt.show()

    plot_pca_3d(protac_data, protac_labels)
# --- Cell: train/test split and dataset construction -------------------------
from sklearn.model_selection import train_test_split

# Only rows with a known activity label take part in the split.
labeled_df = protac_df[protac_df[active_col].notna()]
train_df, test_df = train_test_split(labeled_df, test_size=0.2, random_state=42)


def _make_dataset(df):
    """Wrap ``df`` in a PROTAC_Dataset built from the notebook-level lookups."""
    return PROTAC_Dataset(
        df,
        protein_embeddings,
        cell2embedding,
        smiles2fp,
        active_label=active_col,
        use_smote=False,
    )


# The scaler is fitted on the training split only and then reused, unchanged,
# for the test split.
train_ds = _make_dataset(train_df)
scaler = train_ds.fit_scaling(use_single_scaler=True)
train_ds.apply_scaling(scaler, use_single_scaler=True)

test_ds = _make_dataset(test_df)
test_ds.apply_scaling(scaler, use_single_scaler=True)

X_train, y_train = train_ds.get_numpy_arrays()
X_test, y_test = test_ds.get_numpy_arrays()

# --- Cell: metric collections (the classifier cell between these two cells is
# --- not part of this view) ---------------------------------------------------
import torch
import torch.nn as nn
from torchmetrics import (
    Accuracy,
    AUROC,
    Precision,
    Recall,
    F1Score,
    MetricCollection,
)

stages = ['train_metrics', 'val_metrics', 'test_metrics']


def _binary_metric_collection(prefix):
    """Build a fresh set of binary classification metrics reported under ``prefix``."""
    return MetricCollection({
        'acc': Accuracy(task='binary'),
        'roc_auc': AUROC(task='binary'),
        'precision': Precision(task='binary'),
        'recall': Recall(task='binary'),
        'f1_score': F1Score(task='binary'),
        'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
        'hp_metric': Accuracy(task='binary'),
    }, prefix=prefix)


# One independent collection per stage; 'train_metrics' reports as 'train_…', etc.
metrics = nn.ModuleDict({
    stage: _binary_metric_collection(stage.replace('metrics', ''))
    for stage in stages
})
"\n", + "y_pred = torch.tensor(logreg.predict_proba(X_test)[:, 1])\n", + "y_true = torch.tensor(y_test)\n", + "metrics['test_metrics'].update(y_pred, y_true)\n", + "# Print the metrics\n", + "for k, v in metrics['test_metrics'].compute().items():\n", + " print(f'{k}: {v:.3f}')" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train_acc: 1.000\n", + "train_f1_score: 1.000\n", + "train_hp_metric: 1.000\n", + "train_opt_score: 2.000\n", + "train_precision: 1.000\n", + "train_recall: 1.000\n", + "train_roc_auc: 1.000\n", + "val_acc: 0.744\n", + "val_f1_score: 0.758\n", + "val_hp_metric: 0.744\n", + "val_opt_score: 1.502\n", + "val_precision: 0.726\n", + "val_recall: 0.793\n", + "val_roc_auc: 0.768\n" + ] + } + ], + "source": [ + "from typing import Literal, List, Tuple, Optional\n", + "from sklearn.base import ClassifierMixin\n", + "\n", + "# Generic function to fit and evaluate a classifier model (given as argument),\n", + "# on train and val sets (and optionally a test set) given as dataframes\n", + "def train_sklearn_model(\n", + " clf: ClassifierMixin,\n", + " train_df: pd.DataFrame,\n", + " val_df: pd.DataFrame,\n", + " test_df: Optional[pd.DataFrame] = None,\n", + " active_label: str = 'Active',\n", + " use_single_scaler: bool = True,\n", + ") -> Tuple[ClassifierMixin, nn.ModuleDict]:\n", + " \"\"\" Train a classifier model on train and val sets and evaluate it on a test set.\n", + "\n", + " Args:\n", + " clf: The classifier model to train and evaluate.\n", + " train_df (pd.DataFrame): The training set.\n", + " val_df (pd.DataFrame): The validation set.\n", + " test_df (Optional[pd.DataFrame]): The test set.\n", + "\n", + " Returns:\n", + " Tuple[ClassifierMixin, nn.ModuleDict]: The trained model and the metrics.\n", + " \"\"\"\n", + " # Initialize the datasets\n", + " train_ds = PROTAC_Dataset(\n", + " train_df,\n", + " protein_embeddings,\n", + " 
cell2embedding,\n", + " smiles2fp,\n", + " active_label=active_label,\n", + " use_smote=False,\n", + " )\n", + " scaler = train_ds.fit_scaling(use_single_scaler=use_single_scaler)\n", + " train_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)\n", + " val_ds = PROTAC_Dataset(\n", + " val_df,\n", + " protein_embeddings,\n", + " cell2embedding,\n", + " smiles2fp,\n", + " active_label=active_label,\n", + " use_smote=False,\n", + " )\n", + " val_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)\n", + " if test_df is not None:\n", + " test_ds = PROTAC_Dataset(\n", + " test_df,\n", + " protein_embeddings,\n", + " cell2embedding,\n", + " smiles2fp,\n", + " active_label=active_label,\n", + " use_smote=False,\n", + " )\n", + " test_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)\n", + "\n", + " # Get the numpy arrays\n", + " X_train, y_train = train_ds.get_numpy_arrays()\n", + " X_val, y_val = val_ds.get_numpy_arrays()\n", + " if test_df is not None:\n", + " X_test, y_test = test_ds.get_numpy_arrays()\n", + "\n", + " # Train the model\n", + " clf.fit(X_train, y_train)\n", + " # Define the metrics as a module dict\n", + " stages = ['train_metrics', 'val_metrics', 'test_metrics']\n", + " metrics = nn.ModuleDict({s: MetricCollection({\n", + " 'acc': Accuracy(task='binary'),\n", + " 'roc_auc': AUROC(task='binary'),\n", + " 'precision': Precision(task='binary'),\n", + " 'recall': Recall(task='binary'),\n", + " 'f1_score': F1Score(task='binary'),\n", + " 'opt_score': Accuracy(task='binary') + F1Score(task='binary'),\n", + " 'hp_metric': Accuracy(task='binary'),\n", + " }, prefix=s.replace('metrics', '')) for s in stages})\n", + "\n", + " # Get the predictions\n", + " metrics_out = {}\n", + "\n", + " y_pred = torch.tensor(clf.predict_proba(X_train)[:, 1])\n", + " y_true = torch.tensor(y_train)\n", + " metrics['train_metrics'].update(y_pred, y_true)\n", + " metrics_out.update(metrics['train_metrics'].compute())\n", + "\n", + " y_pred = 
torch.tensor(clf.predict_proba(X_val)[:, 1])\n", + " y_true = torch.tensor(y_val)\n", + " metrics['val_metrics'].update(y_pred, y_true)\n", + " metrics_out.update(metrics['val_metrics'].compute())\n", + "\n", + " if test_df is not None:\n", + " y_pred = torch.tensor(clf.predict_proba(X_test)[:, 1])\n", + " y_true = torch.tensor(y_test)\n", + " metrics['test_metrics'].update(y_pred, y_true)\n", + " metrics_out.update(metrics['test_metrics'].compute())\n", + "\n", + " return clf, metrics_out\n", + "\n", + "# Train the logistic regression model\n", + "logreg = LogisticRegression(\n", + " penalty=None, #'l2',\n", + " max_iter=2000,\n", + " solver='lbfgs',\n", + " # dual=True, # True when n_features > n_samples\n", + " # C=0.05,\n", + " # tol=1e-5,\n", + " random_state=42,\n", + " # n_jobs=-1,\n", + ")\n", + "logreg, metrics = train_sklearn_model(logreg, train_df, test_df, active_label=active_col)\n", + "\n", + "for k, v in metrics.items():\n", + " print(f'{k}: {v:.3f}')" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1662,7 +1878,7 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ @@ -1675,7 +1891,7 @@ }, { "cell_type": "code", - "execution_count": 136, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -1685,7 +1901,7 @@ }, { "cell_type": "code", - "execution_count": 200, + "execution_count": 42, "metadata": {}, "outputs": [ { @@ -1705,7 +1921,7 @@ "Name: Dmax (%), Length: 812, dtype: float64" ] }, - "execution_count": 200, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } @@ -1716,7 +1932,7 @@ }, { "cell_type": "code", - "execution_count": 329, + "execution_count": 47, "metadata": {}, "outputs": [ { @@ -1730,14 +1946,15 @@ ] }, { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" + "ename": "NameError", + "evalue": "name 'sns' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[47], line 88\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mNumber of training points: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(y_train)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m (after resampling)\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 87\u001b[0m \u001b[38;5;66;03m# Plot y_train\u001b[39;00m\n\u001b[0;32m---> 88\u001b[0m \u001b[43msns\u001b[49m\u001b[38;5;241m.\u001b[39mhistplot(y_train, kde\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 89\u001b[0m plt\u001b[38;5;241m.\u001b[39mtitle(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDmax (\u001b[39m\u001b[38;5;124m%\u001b[39m\u001b[38;5;124m) distribution\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 90\u001b[0m plt\u001b[38;5;241m.\u001b[39mxlabel(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDmax (\u001b[39m\u001b[38;5;124m%\u001b[39m\u001b[38;5;124m)\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'sns' is not defined" + ] } ], "source": [ @@ -1838,25 +2055,15 @@ }, { "cell_type": "code", - "execution_count": 335, + "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "R^2 score: 0.948\n" + "R^2 score: 0.953\n" ] - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" } ], "source": [ @@ -1898,7 +2105,7 @@ }, { "cell_type": "code", - "execution_count": 327, + "execution_count": 46, "metadata": {}, "outputs": [ { @@ -1909,14 +2116,15 @@ ] }, { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" + "ename": "NameError", + "evalue": "name 'plt' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[46], line 57\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;66;03m# Plot the loss curve\u001b[39;00m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m solver \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlbfgs\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m---> 57\u001b[0m \u001b[43mplt\u001b[49m\u001b[38;5;241m.\u001b[39mplot(dc50_regr\u001b[38;5;241m.\u001b[39mloss_curve_, label\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 58\u001b[0m plt\u001b[38;5;241m.\u001b[39mplot(dc50_regr\u001b[38;5;241m.\u001b[39mvalidation_scores_, label\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mR^2\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 59\u001b[0m plt\u001b[38;5;241m.\u001b[39mgrid(axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mboth\u001b[39m\u001b[38;5;124m'\u001b[39m, alpha\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.5\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'plt' is not defined" + ] } ], "source": [ @@ -1986,7 +2194,7 @@ }, { "cell_type": "code", - "execution_count": 337, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -5121,7 +5329,7 @@ }, { "cell_type": "code", - "execution_count": 357, + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ @@ -5139,7 +5347,7 @@ }, { "cell_type": "code", - "execution_count": 360, + "execution_count": 49, "metadata": {}, "outputs": [ { @@ -5184,7 +5392,7 @@ }, { "cell_type": "code", - "execution_count": 361, + "execution_count": 50, "metadata": {}, "outputs": [ { @@ 
-5231,7 +5439,7 @@ }, { "cell_type": "code", - "execution_count": 362, + "execution_count": 57, "metadata": {}, "outputs": [ { @@ -5239,13 +5447,13 @@ "output_type": "stream", "text": [ "Number of unique groups: 71\n", - "Number of entries in the test set: 170 (19.8%)\n", + "Number of entries in the test set: 171 (20.0%)\n", "Active/inactive PROTACs in the test set:\n", - "False 0.56\n", - "True 0.44\n", + "False 0.53\n", + "True 0.47\n", "Name: Active (Dmax 0.6, pDC50 6.0), dtype: float64\n", "Number of SMILES leaking in the train_val_df dataset: 0\n", - "Number of Uniprot leaking in the train_val_df dataset: 32\n" + "Number of Uniprot leaking in the train_val_df dataset: 25\n" ] } ], @@ -5266,13 +5474,17 @@ "# plt.grid(axis='y', alpha=0.5)\n", "# plt.show()\n", "\n", + "# Get the groups ordered by the average tanimoto distance (highest to lowest)\n", + "# NOTE: This will put the \"less similar\" PROTACs in the test set\n", + "tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index\n", + "\n", "test_df = []\n", "# For each group, get the number of active and inactive entries. 
Then, add those\n", "# entries to the test_df if: 1) the test_df lenght + the group entries is less\n", "# 20% of the active_df lenght, and 2) the percentage of True and False entries\n", "# in the active_col in test_df is roughly 50%.\n", "# Start the loop from the groups containing the smallest number of entries.\n", - "for group in reversed(active_df['Tanimoto Group'].value_counts().index):\n", + "for group in tanimoto_groups:\n", " group_df = active_df[active_df['Tanimoto Group'] == group]\n", " if test_df == []:\n", " test_df.append(group_df)\n", @@ -5292,8 +5504,8 @@ " if num_entries_test + num_entries < 0.1 * len(active_df):\n", " test_df.append(group_df)\n", " continue\n", - " if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:\n", - " if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:\n", + " if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.55:\n", + " if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.55:\n", " test_df.append(group_df)\n", "test_df = pd.concat(test_df)\n", "# Save to global dictionary of test indeces\n", @@ -5311,6 +5523,22 @@ "print(f'Number of Uniprot leaking in the train_val_df dataset: {len(uniprot_leak)}')" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "Number of unique groups: 71\n", + "Number of entries in the test set: 170 (19.8%)\n", + "Active/inactive PROTACs in the test set:\n", + "False 0.56\n", + "True 0.44\n", + "Name: Active (Dmax 0.6, pDC50 6.0), dtype: float64\n", + "Number of SMILES leaking in the train_val_df dataset: 0\n", + "Number of Uniprot leaking in the train_val_df dataset: 32\n", + "```" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -12932,7 +13160,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -12946,7 +13174,7 @@ "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.0" + "version": "3.10.8" } }, "nbformat": 4,