# !pip install streamlit
# !pip install transformers
# !pip install torch
# !pip install scikit-learn
"""Streamlit app that finds clinical trials similar to a given NCT ID.

Similarity is the cosine similarity between precomputed embeddings stored in
``./biobert_embeddings.pt``; row ``i`` of that tensor is assumed to describe
row ``i`` of ``./filtered_combined.xlsx``.
"""
import io
import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer

# Device for the BioBERT model; read as a global by the model helpers below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DATA_FILE = "./filtered_combined.xlsx"
EMBEDDINGS_FILE = "./biobert_embeddings.pt"
# Replace with the actual paths to the pre-rendered plot images.
TSNE_IMAGE_PATH = "model/tsne_visualization.png"
COSINE_MATRIX_IMAGE_PATH = "model/cosine_similarity.png"

# Sample trials shown to users so they have valid NCT IDs to try.
TEST_TRIALS = [
    ("NCT00385736", "Efficacy and Safety of Adalimumab in Subjects With Moderately to Severely Acute Ulcerative Colitis"),
    ("NCT00386607", "A Safety and Tolerability Study of the Combination of Aliskiren/Valsartan in Patients With High Blood Pressure, Followed by Long-term Safety and Tolerability of Aliskiren, Valsartan and Hydrochlorothiazide."),
    ("NCT03518073", "A Study of LY3303560 in Participants With Early Symptomatic Alzheimer's Disease"),
]

XLSX_MIME = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"


@st.cache_resource
def load_model_and_tokenizer():
    """Load the BioBERT tokenizer and model once and reuse them across reruns.

    Returns:
        ``(tokenizer, model)``, with the model moved to ``device``.
    """
    model_name = "dmis-lab/biobert-base-cased-v1.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    return tokenizer, model


def generate_single_embedding(text, tokenizer, model):
    """Embed ``text`` using the hidden state of its first token.

    Returns:
        NumPy array of shape ``(1, hidden_size)``.
    """
    model.eval()
    with torch.no_grad():
        encoding = tokenizer(
            text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Keep the batch dimension: the model expects (batch, seq_len) inputs,
        # and the [:, 0, :] slice below needs a 3-D hidden-state tensor.
        encoding = {key: val.to(device) for key, val in encoding.items()}
        output = model(**encoding)
    return output.last_hidden_state[:, 0, :].cpu().numpy()


@st.cache_data
def load_data_and_embeddings():
    """Load the trials table and the precomputed embeddings.

    Returns:
        ``(df, embeddings)``. ``df`` gains a ``Combined_Text`` column, which is
        ``Combined Column`` with NaN replaced by "".
    """
    df = pd.read_excel(DATA_FILE)
    df["Combined_Text"] = df["Combined Column"].fillna("")
    # map_location lets embeddings saved on a GPU machine load on CPU-only
    # hosts; the similarity search below runs on CPU anyway.
    embeddings = torch.load(EMBEDDINGS_FILE, map_location="cpu")
    return df, embeddings


def get_similar_trials(query_embedding, embeddings, top_n=10, exclude_index=None):
    """Rank stored trials by cosine similarity to ``query_embedding``.

    Args:
        query_embedding: 2-D tensor, one row per query.
        embeddings: 2-D tensor, one row per stored trial.
        top_n: number of neighbours to return per query.
        exclude_index: row of ``embeddings`` to leave out of the ranking
            (e.g. the query trial itself). When ``None``, the single best
            match is dropped instead, which assumes the query is one of the
            stored rows.

    Returns:
        ``(similar_indices, similarities)``. ``similar_indices`` has one row per
        query, ordered from most to least similar, holding
        ``min(top_n, n_trials - 1)`` indices. ``similarities`` is the full
        query-by-trial cosine-similarity matrix.
    """
    query_matrix = query_embedding.cpu().detach().numpy()
    trial_matrix = embeddings.cpu().detach().numpy()
    similarities = cosine_similarity(query_matrix, trial_matrix)

    if exclude_index is None:
        similar_indices = similarities.argsort(axis=1)[:, -top_n - 1:-1][:, ::-1]
    else:
        # Mask the excluded row explicitly rather than assuming it ranks first;
        # with tied scores the query itself could otherwise leak into results.
        masked = similarities.copy()
        masked[:, exclude_index] = -np.inf
        keep = max(0, min(top_n, masked.shape[1] - 1))
        similar_indices = np.argsort(-masked, axis=1, kind="stable")[:, :keep]
    return similar_indices, similarities


def _find_row_position(df, nct_id):
    """Return the positional row of ``nct_id`` in ``df["nct_id"]``, or None.

    If the ID appears more than once, the last occurrence is used.
    """
    positions = np.flatnonzero((df["nct_id"] == nct_id).to_numpy())
    return int(positions[-1]) if positions.size else None


def _to_excel_bytes(frame):
    """Serialise ``frame`` (without its index) to in-memory .xlsx bytes."""
    buffer = io.BytesIO()
    frame.to_excel(buffer, index=False)
    return buffer.getvalue()


def _render_similar_trials(df, embeddings, query_idx, top_n):
    """Show the ``top_n`` trials most similar to row ``query_idx`` plus a download button."""
    query_embedding = embeddings[query_idx].unsqueeze(0)
    similar_indices, similarities = get_similar_trials(
        query_embedding, embeddings, top_n=top_n, exclude_index=query_idx
    )
    positions = similar_indices[0]
    similar_trials = df.iloc[positions].copy()
    similar_trials["Similarity Score"] = similarities[0, positions]

    st.write("### Top Similar Clinical Trials:")
    st.dataframe(similar_trials[["nct_id", "Study Title", "Similarity Score"]])

    # Build the file in memory: a fixed on-disk path would be shared (and
    # overwritten) by concurrent sessions and left behind on the server.
    export = similar_trials.drop(columns=["Combined_Text", "Combined Column"], errors="ignore")
    st.download_button(
        "Download Results as Excel",
        _to_excel_bytes(export),
        file_name="similar_trials_results.xlsx",
        mime=XLSX_MIME,
    )


def _handle_search(df, embeddings, nct_id, top_n):
    """Validate the entered NCT ID and render its similar trials or an error."""
    if not nct_id:
        st.error("Please provide a valid input.")
        return
    query_idx = _find_row_position(df, nct_id)
    if query_idx is None:
        st.error(f"NCT ID {nct_id} not found in the dataset.")
        return
    _render_similar_trials(df, embeddings, query_idx, top_n)


def _show_image(path, caption):
    """Display the image at ``path``, or a warning if the file is missing."""
    if not os.path.exists(path):
        st.warning(f"Image not found: {path}")
        return
    # Context manager closes the file handle once Streamlit has encoded it.
    with Image.open(path) as image:
        st.image(image, caption=caption)


def _render_visualizations():
    """Render the static t-SNE and cosine-similarity plots and the reference link."""
    st.title("Visualizations")

    st.subheader("t-SNE Plot")
    st.write(
        "t-SNE (t-Distributed Stochastic Neighbor Embedding) is used to visualize high-dimensional embeddings in a lower-dimensional space, helping to identify clusters or patterns in the data."
    )
    _show_image(TSNE_IMAGE_PATH, "t-SNE Plot")

    st.markdown("---")

    st.subheader("Cosine Similarity Matrix")
    st.write(
        "The cosine similarity matrix shows the similarity scores between different clinical trial embeddings, where higher scores indicate more similar trials."
    )
    _show_image(COSINE_MATRIX_IMAGE_PATH, "Cosine Similarity Matrix")

    st.markdown(
        "### Reference\n"
        "For more information, visit the [code files uploaded on Hugging Face](https://huggingface.co/spaces/yashgupta1512/nest/tree/main/model)."
    )


def main():
    """Render the Clinical Trials Similarity Finder page."""
    # Warm the cached model. Only generate_single_embedding (free-text search)
    # uses it, and the UI currently exposes search by NCT ID only.
    load_model_and_tokenizer()
    df, embeddings = load_data_and_embeddings()
    st.write("Model and Tokenizer Loaded Successfully!")

    st.title("Clinical Trials Similarity Finder")
    st.write("Find the most similar clinical trials using BioBERT embeddings.")

    st.write("Use the following NCT_IDs for testing the project.")
    st.dataframe(pd.DataFrame(TEST_TRIALS, columns=["NCT ID", "Study Title"]))

    # Strip whitespace that often comes along with pasted IDs.
    nct_id = st.text_input("Enter NCT ID:", placeholder="e.g., NCT00385736").strip()
    top_n = st.slider("Number of similar trials to retrieve:", min_value=1, max_value=20, value=10)

    if st.button("Find Similar Trials"):
        _handle_search(df, embeddings, nct_id, top_n)

    _render_visualizations()


if __name__ == "__main__":
    main()