File size: 5,170 Bytes
728b290
 
 
 
e0c3387
 
 
728b290
 
 
 
 
6feb2e4
6d55797
728b290
6feb2e4
6d55797
7e5c1c8
6d55797
 
728b290
6feb2e4
e0c3387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728b290
6feb2e4
6d55797
 
 
 
 
 
 
 
 
 
f69047d
7e5c1c8
 
 
6feb2e4
 
 
 
 
 
 
 
7ed7e20
6feb2e4
7e5c1c8
6feb2e4
7e5c1c8
6feb2e4
7e5c1c8
 
 
 
 
 
6feb2e4
f69047d
6feb2e4
f69047d
 
7ed7e20
6feb2e4
 
 
 
 
 
 
7ed7e20
7e5c1c8
728b290
7e5c1c8
 
728b290
 
7e5c1c8
728b290
7e5c1c8
728b290
 
 
 
6feb2e4
7e5c1c8
6feb2e4
 
 
7e5c1c8
6feb2e4
728b290
6feb2e4
7e5c1c8
 
 
728b290
7e5c1c8
728b290
6feb2e4
728b290
6feb2e4
728b290
 
 
6d55797
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Import required libraries
import json
import os

import pandas as pd
import requests
import streamlit as st
import torch
# from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from transformers import DistilBertModel
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util

# Configure Hugging Face API token securely (read from the environment, never hard-coded)
api_key = os.getenv("HF_API_KEY")

# Load the CSV dataset. Everything below depends on `data`, so halt the app
# on failure instead of continuing into a NameError at first use.
try:
    data = pd.read_csv('genetic-Final.csv')  # Ensure the dataset filename is correct
except FileNotFoundError:
    st.error("Dataset file not found. Please upload it to this directory.")
    st.stop()  # stop the Streamlit script run here; `data` was never bound

# Load the DistilBERT tokenizer and the base encoder (no classification head):
# we only need hidden-state embeddings, not class logits.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

# --- One-off sanity check: embed a sample query and print the vector ---
query = "What is fructose-1,6-bisphosphatase deficiency?"

# Tokenize input
inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)

# Inference only — no gradients needed
with torch.no_grad():
    outputs = model(**inputs)

# Mean-pool the last hidden state over the token axis -> one vector per input
embeddings = outputs.last_hidden_state.mean(dim=1)

# Use the embeddings for further processing or retrieval
print(embeddings)


# Build one searchable text field per row (skipped if already present).
if 'combined_description' not in data.columns:
    _text_columns = [
        'Symptoms',
        'Severity Level',
        'Risk Assessment',
        'Treatment Options',
        'Suggested Medical Tests',
        'Minimum Values for Medical Tests',
        'Emergency Treatment',
    ]
    # Missing values become empty strings; fields are joined with single
    # spaces, exactly as chained "+ ' ' +" concatenation would produce.
    data['combined_description'] = data[_text_columns].fillna('').agg(' '.join, axis=1)

# Sentence-transformer used for RAG-style retrieval over the dataset.
retriever_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Define a function to get embeddings using DistilBERT
def generate_embedding(description):
    """Return a 1-D numpy embedding for `description`, or [] for empty input.

    The module-level `model` is a bare DistilBertModel encoder, whose output
    has no `.logits` attribute — it exposes `last_hidden_state`. We mean-pool
    over the token axis, matching the module-level demo above.
    """
    if not description:
        return []
    inputs = tokenizer(description, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():  # inference only — avoids building a graph per row
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).numpy().flatten()

# Generate embeddings for the combined description
# NOTE(review): computed eagerly at import time for every row and stored as
# arrays in an object column; recomputed on each Streamlit rerun unless the
# column already exists — consider st.cache_data for large datasets.
if 'embeddings' not in data.columns:
    data['embeddings'] = data['combined_description'].apply(generate_embedding)

# Function to retrieve relevant information based on user query
def get_relevant_info(query, top_k=3):
    """Return the `top_k` dataset rows most similar to `query`.

    Both query and documents are encoded with the same SentenceTransformer so
    the cosine similarity is computed in one embedding space. (Comparing the
    384-d MiniLM query vector against the 768-d DistilBERT vectors stored in
    data['embeddings'] would fail on the dimension mismatch.)
    """
    query_embedding = retriever_model.encode(query)
    doc_embeddings = retriever_model.encode(data['combined_description'].tolist())
    # One similarity score per document row.
    scores = [float(s) for s in util.cos_sim(query_embedding, doc_embeddings)[0]]
    top_indices = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:top_k]
    return data.iloc[top_indices]

# Function to generate an answer grounded in the retrieved rows
def generate_response(input_text, relevant_info):
    """Return an extractive answer: the retrieved passage most similar to the query.

    DistilBERT is an encoder, not a generative model — decoding the argmax of
    its output as a token id (the previous approach) yields a single
    meaningless token. Until a generative model is wired in, answer
    extractively by ranking the retrieved passages against the query.
    """
    context_lines = relevant_info['combined_description'].tolist()
    if not context_lines:
        return "No relevant information found for this query."
    query_emb = retriever_model.encode(input_text)
    ctx_embs = retriever_model.encode(context_lines)
    scores = util.cos_sim(query_emb, ctx_embs)[0]
    # Best-matching passage becomes the response shown to the user.
    return context_lines[int(scores.argmax())]

# Streamlit UI for the Chatbot
def main():
    """Render the chatbot page: a query box, optional report upload, and results."""
    st.title("Medical Report and Analysis Chatbot")
    st.sidebar.header("Upload Medical Report or Enter Query")

    # Sidebar inputs: free-text question plus an optional report file.
    question = st.sidebar.text_input("Type your question or query")
    report_file = st.sidebar.file_uploader(
        "Upload a medical report (optional)", type=["txt", "pdf", "csv"]
    )

    if question:
        st.write("### Query Response:")

        # Look up the dataset rows most relevant to the question.
        matches = get_relevant_info(question)
        st.write("#### Relevant Medical Information:")
        for _, match in matches.iterrows():
            st.write(f"- {match['combined_description']}")  # Adjust to show meaningful info

        # Produce a model response grounded in those rows.
        answer = generate_response(question, matches)
        st.write("#### Model's Response:")
        st.write(answer)

    if report_file:
        # Placeholder analysis of the uploaded report (real parsing TBD).
        st.write("### Uploaded Report Analysis:")
        report_text = "Extracted report content here"  # Placeholder for file processing logic
        st.write(report_text)

if __name__ == "__main__":
    main()