{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "d31b58d0-132d-4a98-b199-c3b1d2ed9eb5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/klkehl/miniconda3/envs/vllm/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.62s/it]\n" ] } ], "source": [ "import gradio as gr\n", "import pandas as pd\n", "import torch\n", "import torch.nn.functional as F\n", "from sentence_transformers import SentenceTransformer\n", "from safetensors import safe_open\n", "from transformers import pipeline, AutoTokenizer\n", "\n", "# Load trial spaces data\n", "trial_spaces = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')\n", "\n", "# Load embedding model\n", "embedding_model = SentenceTransformer('reranker_round2.model', trust_remote_code=True, device='cuda')\n", "\n", "# Load precomputed trial space embeddings\n", "with safe_open(\"trial_space_embeddings.safetensors\", framework=\"pt\", device=0) as f:\n", " trial_space_embeddings = f.get_tensor(\"space_embeddings\")\n", "\n", "# Load checker pipeline\n", "tokenizer = AutoTokenizer.from_pretrained(\"roberta-large\")\n", "checker_pipe = pipeline('text-classification', './roberta-checker', tokenizer=tokenizer, \n", " truncation=True, padding='max_length', max_length=512, device='cuda')\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "36d48a31-8514-4b0d-84a9-5fccc7ec7227", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import gradio as gr\n", "import pandas as pd\n", "import torch\n", "import torch.nn.functional as F\n", "from sentence_transformers import SentenceTransformer\n", "from safetensors import safe_open\n", "from transformers import pipeline, AutoTokenizer\n", "\n", "# We assume the following objects have already been loaded:\n", "# trial_spaces (DataFrame), embedding_model (SentenceTransformer),\n", "# trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)\n", "\n", "def match_clinical_trials(patient_summary: str):\n", " # Encode patient summary\n", " patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)\n", " \n", " # Compute similarities\n", " similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)\n", " \n", " # Pull top 10\n", " sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)\n", " top_indices = sorted_indices[0:10].cpu().numpy()\n", " \n", " relevant_spaces = trial_spaces.iloc[top_indices].this_space\n", " relevant_nctid = trial_spaces.iloc[top_indices].nct_id\n", " relevant_title = trial_spaces.iloc[top_indices].title\n", " relevant_brief_summary = trial_spaces.iloc[top_indices].brief_summary\n", " relevant_eligibility_criteria = trial_spaces.iloc[top_indices].eligibility_criteria\n", "\n", " analysis = pd.DataFrame({\n", " 'patient_summary': patient_summary, \n", " 'this_space': relevant_spaces,\n", " 'nct_id': relevant_nctid, \n", " 'trial_title': relevant_title,\n", " 'trial_brief_summary': relevant_brief_summary, \n", " 'trial_eligibility_criteria': relevant_eligibility_criteria\n", " }).reset_index(drop=True)\n", " \n", " analysis['pt_trial_pair'] = analysis['this_space'] + \"\\nNow here is the patient summary:\" + analysis['patient_summary']\n", " \n", " # Run checker pipeline\n", " classifier_results = checker_pipe(analysis.pt_trial_pair.tolist())\n", " analysis['trial_checker_result'] = [x['label'] for x in classifier_results]\n", " analysis['trial_checker_score'] = [x['score'] for x in classifier_results]\n", " \n", " # Return a subset of columns that are most relevant\n", " return analysis[[\n", " 'nct_id', \n", " 'trial_title', \n", " 'trial_brief_summary', \n", " 'trial_eligibility_criteria', \n", " 'trial_checker_result', \n", " 'trial_checker_score'\n", " ]]\n", "\n", "custom_css = \"\"\"\n", "#input_box textarea {\n", " width: 600px !important;\n", " height: 250px !important;\n", "}\n", "\n", "#output_df table {\n", " width: 100% !important;\n", " table-layout: auto !important;\n", " border-collapse: collapse !important;\n", "}\n", "\n", "#output_df table td, #output_df table th {\n", " min-width: 100px;\n", " overflow: hidden;\n", " text-overflow: ellipsis;\n", " white-space: nowrap;\n", " border: 1px solid #ccc;\n", " padding: 4px;\n", "}\n", "\"\"\"\n", "\n", "# JavaScript for enabling colResizable\n", "js_script = \"\"\"\n", "\n", "\n", "\n", "\"\"\"\n", "\n", "with gr.Blocks(css=custom_css) as demo:\n", " gr.HTML(\"

Clinical Trial Matcher

\")\n", " patient_summary_input = gr.Textbox(label=\"Enter Patient Summary\", elem_id=\"input_box\")\n", " submit_btn = gr.Button(\"Find Matches\")\n", " output_df = gr.DataFrame(\n", " headers=[\n", " \"nct_id\", \n", " \"trial_title\", \n", " \"trial_brief_summary\", \n", " \"trial_eligibility_criteria\", \n", " \"trial_checker_result\", \n", " \"trial_checker_score\"\n", " ], \n", " elem_id=\"output_df\"\n", " )\n", "\n", " submit_btn.click(fn=match_clinical_trials, \n", " inputs=patient_summary_input, \n", " outputs=output_df)\n", " \n", " gr.HTML(js_script)\n", "\n", "demo.launch()\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "80ba3cd2-6a76-44d0-b5f4-6d3debd510ff", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Closing server running on port: 7860\n" ] } ], "source": [ "demo.close()" ] }, { "cell_type": "code", "execution_count": null, "id": "5e43df71-6f06-48d2-8dce-b1f27ab40d6c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }