# GeneticDisorder / app.py
# Import required libraries
import os
import pandas as pd
import streamlit as st
import torch
# from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from transformers import DistilBertTokenizerFast, DistilBertModel
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
import requests
import json
# Configure Hugging Face API token securely
api_key = os.getenv("HF_API_KEY")
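# Note: HF_API_KEY is read from the environment (for example, a Hugging Face Spaces secret).
# It is not used by the code below, so it is only needed if calls to the hosted Inference API
# are added later.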
# Load the CSV dataset
try:
    data = pd.read_csv('genetic-Final.csv')  # Ensure the dataset filename is correct
except FileNotFoundError:
    st.error("Dataset file not found. Please upload it to this directory.")
    st.stop()
# Load DistilBERT Tokenizer and Model
# tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
# model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
# Load DistilBERT tokenizer and model (without classification layer)
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")
query = "What is fructose-1,6-bisphosphatase deficiency?"
# Tokenize input
inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
# Get model output (embeddings)
with torch.no_grad():
    outputs = model(**inputs)
# Extract embeddings (last hidden state)
embeddings = outputs.last_hidden_state.mean(dim=1) # Averaging over token embeddings
# Use the embeddings for further processing or retrieval
print(embeddings)
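# For distilbert-base-uncased the mean-pooled output above has shape (1, 768),
# i.e. one 768-dimensional vector per input sequence.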
# Preprocessing the dataset (if needed)
if 'combined_description' not in data.columns:
    data['combined_description'] = (
        data['Symptoms'].fillna('') + " " +
        data['Severity Level'].fillna('') + " " +
        data['Risk Assessment'].fillna('') + " " +
        data['Treatment Options'].fillna('') + " " +
        data['Suggested Medical Tests'].fillna('') + " " +
        data['Minimum Values for Medical Tests'].fillna('') + " " +
        data['Emergency Treatment'].fillna('')
    )
# Initialize Sentence Transformer model for RAG-based retrieval
retriever_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
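# all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings, whereas DistilBERT's hidden
# states are 768-dimensional, so query and document embeddings must come from the same model
# for cosine similarity to be meaningful (see get_relevant_info below).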
# Define a function to get embeddings using DistilBERT
def generate_embedding(description):
    if description:
        inputs = tokenizer(description, return_tensors='pt', truncation=True, padding=True, max_length=512)
        with torch.no_grad():
            outputs = model(**inputs)
        # DistilBertModel has no classification head, so return the mean-pooled last hidden state
        return outputs.last_hidden_state.mean(dim=1).numpy().flatten()
    else:
        return []
# Generate embeddings for the combined description
# Note: documents are encoded with the same SentenceTransformer used for queries in
# get_relevant_info, so that cosine similarities are computed in one vector space;
# generate_embedding above is kept as a DistilBERT-based alternative encoder.
if 'embeddings' not in data.columns:
    data['embeddings'] = data['combined_description'].apply(retriever_model.encode)
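# Streamlit re-runs this script on every interaction, so in practice the embedding pass above
# would typically be wrapped in a cached helper (e.g. a function decorated with @st.cache_data)
# to avoid re-encoding the whole dataset on each rerun.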
# Function to retrieve relevant information based on user query
def get_relevant_info(query, top_k=3):
    query_embedding = retriever_model.encode(query)
    similarities = [util.cos_sim(query_embedding, doc_emb)[0][0].item() for doc_emb in data['embeddings']]
    top_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)[:top_k]
    return data.iloc[top_indices]
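# Example (illustrative query): get_relevant_info("symptoms of cystic fibrosis", top_k=3)
# returns the 3 dataset rows whose combined descriptions are most similar to the query.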
# Function to generate a response from the retrieved context
# (DistilBertModel has no text-generation head, so instead of decoding logits directly this
#  uses an extractive question-answering pipeline built on a DistilBERT checkpoint)
qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
def generate_response(input_text, relevant_info):
    # Concatenate the relevant information as context for the model
    context = "\n".join(relevant_info['combined_description'].tolist())
    if not context.strip():
        return "No relevant information found in the dataset."
    # Extract the answer span from the retrieved context
    result = qa_pipeline(question=input_text, context=context)
    return result["answer"]
# Streamlit UI for the Chatbot
def main():
    st.title("Medical Report and Analysis Chatbot")
    st.sidebar.header("Upload Medical Report or Enter Query")
    # Text input for user queries
    user_query = st.sidebar.text_input("Type your question or query")
    # File uploader for medical report
    uploaded_file = st.sidebar.file_uploader("Upload a medical report (optional)", type=["txt", "pdf", "csv"])
    # Process the query if provided
    if user_query:
        st.write("### Query Response:")
        # Retrieve relevant information from dataset
        relevant_info = get_relevant_info(user_query)
        st.write("#### Relevant Medical Information:")
        for i, row in relevant_info.iterrows():
            st.write(f"- {row['combined_description']}")  # Adjust to show meaningful info
        # Generate a response from DistilBERT model
        response = generate_response(user_query, relevant_info)
        st.write("#### Model's Response:")
        st.write(response)
    # Process the uploaded file (if any)
    if uploaded_file:
        # Display analysis of the uploaded report file (process based on file type)
        st.write("### Uploaded Report Analysis:")
        report_text = "Extracted report content here"  # Placeholder for file processing logic
        st.write(report_text)
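        # A minimal sketch of the file processing step, assuming plain-text uploads;
        # PDFs would need an extra parser (e.g. the pypdf package), which is not used here:
        #     if uploaded_file.type == "text/plain":
        #         report_text = uploaded_file.read().decode("utf-8")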
if __name__ == "__main__":
    main()