"""Streamlit chatbot that answers questions from a genetic-disorder CSV dataset.

Pipeline:
1. Each row of ``genetic-Final.csv`` is flattened into one text field,
   ``combined_description``.
2. Rows are embedded with a Sentence-Transformers model. The user's query is
   embedded with the same model, and the most similar rows are retrieved by
   cosine similarity.
3. An extractive DistilBERT question-answering model selects the answer span
   from the retrieved rows.
"""

# Import required libraries
import json
import os

import pandas as pd
import requests
import streamlit as st
from sentence_transformers import SentenceTransformer, util
from transformers import DistilBertModel, DistilBertTokenizerFast, pipeline

# NOTE(review): json, requests, DistilBertModel and DistilBertTokenizerFast are
# not used by the code below.

# Configure Hugging Face API token securely (read from the environment rather
# than hard-coded). NOTE(review): not used by the code below.
api_key = os.getenv("HF_API_KEY")

DATASET_PATH = "genetic-Final.csv"
RETRIEVER_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# DistilBERT fine-tuned on SQuAD. It is encoder-only with a span-extraction
# head, so it "answers" by selecting text from the context, not by generating.
QA_MODEL_NAME = "distilbert-base-cased-distilled-squad"

# Columns joined (space-separated) into ``combined_description``.
DESCRIPTION_COLUMNS = [
    "Symptoms",
    "Severity Level",
    "Risk Assessment",
    "Treatment Options",
    "Suggested Medical Tests",
    "Minimum Values for Medical Tests",
    "Emergency Treatment",
]


@st.cache_resource
def _load_retriever():
    """Load the sentence-embedding model once and reuse it across reruns."""
    return SentenceTransformer(RETRIEVER_MODEL_NAME)


@st.cache_resource
def _load_qa_pipeline():
    """Load the extractive question-answering pipeline once and reuse it."""
    return pipeline("question-answering", model=QA_MODEL_NAME)


@st.cache_data
def _embed_texts(texts):
    """Embed a tuple of strings with the retriever, cached across reruns.

    Returns the numpy output of ``SentenceTransformer.encode``: one embedding
    row per input string.
    """
    return _load_retriever().encode(list(texts))


def _build_combined_description(frame):
    """Return the space-joined text of DESCRIPTION_COLUMNS for every row."""
    # astype(str) guards against non-string (e.g. numeric) columns, for which
    # plain ``+`` string concatenation would raise a TypeError.
    columns = frame[DESCRIPTION_COLUMNS].fillna("").astype(str)
    return columns.iloc[:, 0].str.cat(columns.iloc[:, 1:], sep=" ")


# Load the CSV dataset. Without it there is nothing to retrieve from, so halt
# this run instead of failing later with a NameError on ``data``.
try:
    data = pd.read_csv(DATASET_PATH)
except FileNotFoundError:
    st.error("Dataset file not found. Please upload it to this directory.")
    st.stop()

# Preprocessing: build the retrieval text unless the CSV already provides it,
# then normalize it so NaN cells cannot break encoding or joining later.
if "combined_description" not in data.columns:
    data["combined_description"] = _build_combined_description(data)
data["combined_description"] = data["combined_description"].fillna("").astype(str)

# Initialize Sentence Transformer model for RAG-based retrieval.
retriever_model = _load_retriever()

# Document embeddings come from the SAME model used for queries, so their
# dimensions match and cosine similarity between them is meaningful.
doc_embeddings = _embed_texts(tuple(data["combined_description"]))
data["embeddings"] = list(doc_embeddings)


def generate_embedding(description):
    """Embed a single description with the retrieval model.

    Uses ``retriever_model`` so the vector is comparable with the query
    embeddings produced in ``get_relevant_info``.

    Returns:
        The embedding as a 1-D numpy array, or ``[]`` for an empty/None
        description.
    """
    if not description:
        return []
    return retriever_model.encode(description)


def get_relevant_info(query, top_k=3):
    """Return the ``top_k`` dataset rows most similar to ``query``.

    Rows are ranked by cosine similarity between the query embedding and
    ``doc_embeddings``, most similar first. Fewer rows are returned when the
    dataset has fewer than ``top_k``; an empty frame when it is empty or when
    ``top_k`` is not positive.
    """
    if data.empty or top_k <= 0:
        return data.iloc[0:0]
    query_embedding = retriever_model.encode(query)
    hits = util.semantic_search(query_embedding, doc_embeddings, top_k=top_k)[0]
    return data.iloc[[hit["corpus_id"] for hit in hits]]


def generate_response(input_text, relevant_info):
    """Answer ``input_text`` using the retrieved rows as context.

    DistilBERT is encoder-only and cannot generate free text. The answer is
    therefore the context span that an extractive QA model (DistilBERT
    fine-tuned on SQuAD) scores highest.

    Args:
        input_text: The user's question.
        relevant_info: DataFrame returned by ``get_relevant_info``.

    Returns:
        The extracted answer span, or a fallback message when there is no
        question or no context to search.
    """
    question = (input_text or "").strip()
    context = "\n".join(relevant_info["combined_description"].astype(str))
    if not question or not context.strip():
        return "No relevant information was found for this query."
    result = _load_qa_pipeline()(question=question, context=context)
    return result["answer"]


def _show_uploaded_report(uploaded_file):
    """Render an uploaded report: raw text for .txt, a table for .csv.

    PDF parsing is not implemented (no PDF library is imported), so PDFs
    only get a notice.
    """
    name = uploaded_file.name.lower()
    if name.endswith(".txt"):
        st.text(uploaded_file.getvalue().decode("utf-8", errors="replace"))
    elif name.endswith(".csv"):
        try:
            st.dataframe(pd.read_csv(uploaded_file))
        except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
            st.error(f"Could not parse the uploaded CSV: {err}")
    else:
        # TODO: extract text from PDFs (requires adding a PDF library).
        st.info("PDF text extraction is not implemented yet.")


# Streamlit UI for the Chatbot
def main():
    """Render the chatbot UI, then answer the current query and/or report."""
    st.title("Medical Report and Analysis Chatbot")
    st.sidebar.header("Upload Medical Report or Enter Query")

    # Text input for user queries
    user_query = st.sidebar.text_input("Type your question or query")
    # File uploader for medical report
    uploaded_file = st.sidebar.file_uploader(
        "Upload a medical report (optional)", type=["txt", "pdf", "csv"]
    )

    # Whitespace-only input is treated as "no query".
    if user_query and user_query.strip():
        st.write("### Query Response:")

        # Retrieve relevant information from dataset
        relevant_info = get_relevant_info(user_query)
        st.write("#### Relevant Medical Information:")
        for description in relevant_info["combined_description"]:
            st.write(f"- {description}")

        with st.spinner("Searching the retrieved context for an answer..."):
            response = generate_response(user_query, relevant_info)
        st.write("#### Model's Response:")
        st.write(response)

    if uploaded_file is not None:
        st.write("### Uploaded Report Analysis:")
        _show_uploaded_report(uploaded_file)


if __name__ == "__main__":
    main()