Spaces:
Sleeping
Sleeping
# Import required libraries
import functools
import heapq
import json
import os

import pandas as pd
import requests
import streamlit as st
import torch
# from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from transformers import DistilBertTokenizerFast, DistilBertModel
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
# Read the Hugging Face API token from the environment rather than hard-coding
# it in source; the value is None when HF_API_KEY is not set.
api_key = os.environ.get("HF_API_KEY")
# Load the CSV dataset that backs retrieval.
try:
    data = pd.read_csv('genetic-Final.csv')  # Ensure the dataset filename is correct
except FileNotFoundError:
    st.error("Dataset file not found. Please upload it to this directory.")
    # Every later step (preprocessing, embedding, retrieval) reads `data`;
    # without it the script would crash with a NameError. Stop this run after
    # showing the message instead.
    st.stop()
# Load the DistilBERT tokenizer and base encoder (no classification head).
# The encoder's outputs expose hidden states (`last_hidden_state`), not logits.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

# Start-up smoke test: encode a sample query and print its pooled embedding.
query = "What is fructose-1,6-bisphosphatase deficiency?"
# Tokenize input into PyTorch tensors.
inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
# Inference only: disable autograd so no computation graph is recorded.
with torch.no_grad():
    outputs = model(**inputs)
# Average the last hidden state over the token axis (dim=1) to get one
# vector per input sequence.
embeddings = outputs.last_hidden_state.mean(dim=1)
print(embeddings)
# Preprocessing the dataset (if needed): build one free-text description per
# row by joining the clinical columns below with single spaces.
if 'combined_description' not in data.columns:
    description_columns = [
        'Symptoms',
        'Severity Level',
        'Risk Assessment',
        'Treatment Options',
        'Suggested Medical Tests',
        'Minimum Values for Medical Tests',
        'Emergency Treatment',
    ]
    # astype(str) lets the join work even if some of these columns are
    # numeric; the original "+" chain raised TypeError when mixing numbers
    # and strings. Missing values become empty strings, as before.
    data['combined_description'] = (
        data[description_columns].fillna('').astype(str).apply(' '.join, axis=1)
    )
# Sentence Transformer used for RAG-based retrieval (cosine similarity search
# over the dataset's descriptions).
retriever_model = SentenceTransformer(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"
)
# Define a function to embed a dataset description for retrieval. It uses the
# same Sentence Transformer as the query side (`retriever_model`), so document
# and query vectors live in one embedding space with matching dimensions and
# can be compared with cosine similarity.
def generate_embedding(description):
    """Return the retrieval embedding for *description*.

    Args:
        description: Text to embed. Empty strings and non-string values
            (e.g. NaN from a CSV-provided column) are treated as empty.

    Returns:
        The ``retriever_model.encode`` result for non-empty text, otherwise
        an empty list (the same empty marker the original returned).
    """
    # The original ran DistilBertModel and read `.logits`, which that model
    # does not produce (AttributeError). Its 768-dim output would also not
    # match the retriever's query vectors.
    if isinstance(description, str) and description:
        return retriever_model.encode(description)
    return []
# Embed every combined description once at start-up (skipped if the dataset
# already carries an 'embeddings' column).
if 'embeddings' not in data.columns:
    data['embeddings'] = data['combined_description'].map(generate_embedding)
# Function to retrieve relevant information based on user query
def get_relevant_info(query, top_k=3):
    """Return the ``top_k`` rows of ``data`` most similar to *query*.

    The query is encoded with ``retriever_model`` and scored by cosine
    similarity against each row's entry in ``data['embeddings']``.

    Args:
        query: Free-text user query.
        top_k: Maximum number of rows to return; values <= 0 return no rows.

    Returns:
        A positional slice of ``data`` ordered from most to least similar.
        Ties keep dataset order.
    """
    query_embedding = retriever_model.encode(query)
    similarities = [
        # An empty embedding (a blank description) cannot be scored:
        # cos_sim would fail on it, so rank such rows last instead.
        util.cos_sim(query_embedding, doc_emb)[0][0].item() if len(doc_emb) else float('-inf')
        for doc_emb in data['embeddings']
    ]
    # nlargest(n, ...) equals sorted(..., reverse=True)[:n] (stable on ties)
    # for n >= 0. For n < 0 it returns [], whereas the old [:top_k] slice
    # would have silently dropped rows from the end.
    top_indices = heapq.nlargest(top_k, range(len(similarities)), key=similarities.__getitem__)
    return data.iloc[top_indices]
# Function to generate a response with a DistilBERT question-answering model.
@functools.lru_cache(maxsize=1)
def _get_qa_pipeline():
    """Build the extractive QA pipeline once and reuse it on later calls."""
    return pipeline("question-answering", model="distilbert-base-cased-distilled-squad")


def generate_response(input_text, relevant_info):
    """Answer *input_text* using the retrieved rows as context.

    The ``combined_description`` values of *relevant_info* are joined with
    newlines into one context passage. A DistilBERT SQuAD model then extracts
    the answer span from that passage.

    Args:
        input_text: The user's question.
        relevant_info: DataFrame with a ``combined_description`` column,
            e.g. the result of ``get_relevant_info``.

    Returns:
        The extracted answer text, or a fallback message when the question
        or the context is empty.
    """
    # The original fed the text to the bare DistilBertModel and read
    # `.logits` (AttributeError: that model has no logits). Decoding a
    # single argmax id could never produce an answer anyway.
    descriptions = relevant_info['combined_description'].dropna().astype(str).tolist()
    context = "\n".join(descriptions)
    if not input_text or not input_text.strip() or not context.strip():
        return "No relevant information was found to answer this query."
    result = _get_qa_pipeline()(question=input_text, context=context)
    return result["answer"]
# Streamlit UI for the Chatbot
def main():
    """Render the chatbot page.

    The sidebar collects a free-text query and an optional report upload
    (txt, pdf or csv). For a non-empty query, the page shows the retrieved
    dataset descriptions followed by the generated response. For an uploaded
    file, it shows a placeholder analysis section.
    """
    st.title("Medical Report and Analysis Chatbot")

    sidebar = st.sidebar
    sidebar.header("Upload Medical Report or Enter Query")
    user_query = sidebar.text_input("Type your question or query")
    uploaded_file = sidebar.file_uploader(
        "Upload a medical report (optional)", type=["txt", "pdf", "csv"]
    )

    if user_query:
        st.write("### Query Response:")
        matches = get_relevant_info(user_query)
        st.write("#### Relevant Medical Information:")
        for description in matches['combined_description']:
            st.write(f"- {description}")  # Adjust to show meaningful info
        answer = generate_response(user_query, matches)
        st.write("#### Model's Response:")
        st.write(answer)

    if uploaded_file:
        st.write("### Uploaded Report Analysis:")
        # Placeholder: the uploaded file's contents are not parsed yet.
        report_text = "Extracted report content here"
        st.write(report_text)


if __name__ == "__main__":
    main()