{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "d31b58d0-132d-4a98-b199-c3b1d2ed9eb5", "metadata": {}, "outputs": [], "source": [ "# Setup: load the trial-space table, the embedding model, the precomputed trial-space\n", "# embeddings, and the RoBERTa checker pipeline (models and embeddings are placed on CUDA).\n", "import gradio as gr\n", "import pandas as pd\n", "import torch\n", "import torch.nn.functional as F\n", "from sentence_transformers import SentenceTransformer\n", "from safetensors import safe_open\n", "from transformers import pipeline, AutoTokenizer\n", "\n", "# Load trial spaces data\n", "trial_spaces = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')\n", "\n", "# Load embedding model\n", "# NOTE: trust_remote_code=True executes custom Python shipped with the model; only load trusted checkpoints.\n", "embedding_model = SentenceTransformer('reranker_round2.model', trust_remote_code=True, device='cuda')\n", "\n", "# Load precomputed trial space embeddings\n", "with safe_open(\"trial_space_embeddings.safetensors\", framework=\"pt\", device=0) as f:\n", " trial_space_embeddings = f.get_tensor(\"space_embeddings\")\n", "\n", "# Load checker pipeline\n", "tokenizer = AutoTokenizer.from_pretrained(\"roberta-large\")\n", "checker_pipe = pipeline('text-classification', './roberta-checker', tokenizer=tokenizer, \n", " truncation=True, padding='max_length', max_length=512, device='cuda')\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "36d48a31-8514-4b0d-84a9-5fccc7ec7227", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "