# Source: uploaded by Sumitkumar098 ("Upload 2 files", commit 7a8bc6d, verified)
import streamlit as st
import torch
import numpy as np
import pickle
import json
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
import os
# Set page config
# Browser-tab title/icon and full-width layout. This call is kept ahead of
# every other Streamlit command in the script (Streamlit has historically
# required the page config to be set first).
# The icon literal is the pill emoji; it had been mis-decoded into mojibake.
st.set_page_config(
    page_title="Drug Prediction and Polypharmacy System",
    page_icon="💊",
    layout="wide",
)
# Model class definition - must match the training model architecture
class EnhancedMedicationModel(nn.Module):
    """Multi-task classifier on top of a pretrained transformer encoder.

    The encoder's first-token (CLS) representation is passed through dropout
    and a shared dense layer, then fed to three independent heads:

    * ``medication_classifier``   - multi-label medication logits
    * ``polypharmacy_classifier`` - multi-class polypharmacy-risk logits
    * ``disease_classifier``      - multi-class disease logits

    Attribute names and the layer layout of each head determine the
    ``state_dict`` keys, so they must stay in sync with the checkpoint that
    is loaded into this model.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_medications: Output size of the medication head.
        num_polypharmacy_classes: Output size of the polypharmacy head.
        num_disease_classes: Output size of the disease head.
        dropout_rate: Dropout probability used after the encoder and inside
            every head.
    """

    def __init__(self, model_name, num_medications, num_polypharmacy_classes,
                 num_disease_classes, dropout_rate=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_rate)
        hidden_size = self.bert.config.hidden_size

        # Common representation layer shared by all tasks
        self.common_dense = nn.Linear(hidden_size, hidden_size)

        # Task-specific heads (identical layout, different output sizes)
        self.medication_classifier = self._build_head(
            hidden_size, num_medications, dropout_rate)
        self.polypharmacy_classifier = self._build_head(
            hidden_size, num_polypharmacy_classes, dropout_rate)
        self.disease_classifier = self._build_head(
            hidden_size, num_disease_classes, dropout_rate)

        # Apply weight initialization
        self._init_weights()

    @staticmethod
    def _build_head(hidden_size, num_outputs, dropout_rate):
        """Return a Linear -> ReLU -> Dropout -> Linear classification head."""
        return nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, num_outputs),
        )

    def _init_weights(self):
        """Xavier-initialise weights and zero biases of every task-layer Linear.

        Only the shared dense layer and the three heads are touched; the
        pretrained encoder in ``self.bert`` is deliberately left alone.
        """
        task_modules = (
            self.medication_classifier,
            self.polypharmacy_classifier,
            self.disease_classifier,
            self.common_dense,
        )
        for head in task_modules:
            # ``modules()`` yields the module itself plus all descendants, so
            # a bare ``nn.Linear`` and an ``nn.Sequential`` are handled by the
            # same loop and the bias that gets zeroed always belongs to the
            # layer whose weight was just initialised.
            for layer in head.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(layer.weight)
                    nn.init.zeros_(layer.bias)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and all three heads.

        Args:
            input_ids: Token-id tensor forwarded to the encoder.
            attention_mask: Attention-mask tensor forwarded to the encoder.

        Returns:
            Tuple ``(medication_logits, polypharmacy_logits, disease_logits)``
            of raw, un-normalised scores.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]  # CLS token
        pooled_output = self.dropout(pooled_output)

        # Common representation
        common_features = torch.relu(self.common_dense(pooled_output))

        medication_logits = self.medication_classifier(common_features)
        polypharmacy_logits = self.polypharmacy_classifier(common_features)
        disease_logits = self.disease_classifier(common_features)
        return medication_logits, polypharmacy_logits, disease_logits
@st.cache_resource
def load_model_and_resources(model_dir='streamlit_model'):
    """Load the model and its companion artifacts (cached for performance).

    Args:
        model_dir: Directory holding ``model_config.json``,
            ``model_state_dict.pt``, ``label_encoders.pkl`` and
            ``lookup_data.pkl``. Defaults to ``'streamlit_model'``; a relative
            path is resolved against the current working directory.

    Returns:
        dict with the keys ``'model'`` (in eval mode, moved to ``'device'``),
        ``'tokenizer'``, ``'mlb'``, ``'le_risk'``, ``'le_disease'``,
        ``'lookup_data'`` and ``'device'``.

    Raises:
        FileNotFoundError: If any of the four artifact files is missing.
        KeyError: If the config or encoder files lack an expected key.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model configuration
    config_path = os.path.join(model_dir, 'model_config.json')
    with open(config_path, 'r', encoding='utf-8') as f:
        model_config = json.load(f)

    # Initialize tokenizer and model architecture
    model_name = model_config['model_name']
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EnhancedMedicationModel(
        model_name=model_name,
        num_medications=model_config['num_medications'],
        num_polypharmacy_classes=model_config['num_polypharmacy_classes'],
        num_disease_classes=model_config['num_disease_classes'],
        dropout_rate=0.3,
    )

    # Load trained weights; map_location lets a GPU-saved checkpoint load on CPU
    weights_path = os.path.join(model_dir, 'model_state_dict.pt')
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model = model.to(device)
    model.eval()  # disable dropout for inference

    # NOTE(security): pickle.load can execute arbitrary code. These files must
    # come from a trusted source (i.e. the training pipeline), never from users.
    with open(os.path.join(model_dir, 'label_encoders.pkl'), 'rb') as f:
        encoders = pickle.load(f)
    with open(os.path.join(model_dir, 'lookup_data.pkl'), 'rb') as f:
        lookup_data = pickle.load(f)

    return {
        'model': model,
        'tokenizer': tokenizer,
        'mlb': encoders['mlb'],
        'le_risk': encoders['le_risk'],
        'le_disease': encoders['le_disease'],
        'lookup_data': lookup_data,
        'device': device,
    }
def predict_patient_health_profile(patient_data, resources, *, threshold=0.5,
                                   max_medications=3, top_k=5):
    """Predict the health profile for one patient.

    Args:
        patient_data: dict with the keys ``'name'``, ``'age'``, ``'gender'``,
            ``'blood_group'``, ``'weight'``, ``'symptoms'`` and ``'severity'``.
        resources: dict with the keys ``'model'``, ``'tokenizer'``, ``'mlb'``,
            ``'le_risk'``, ``'le_disease'``, ``'lookup_data'`` and ``'device'``.
        threshold: Minimum sigmoid probability for a medication to count as
            recommended.
        max_medications: Maximum number of recommended medications returned.
        top_k: Number of entries in ``'medication_probabilities'``.

    Returns:
        dict with the keys ``'patient_name'``, ``'predicted_disease'``,
        ``'disease_causes'``, ``'disease_prevention'``, ``'medications'``
        (list of dicts, highest confidence first), ``'polypharmacy_risk'``,
        ``'polypharmacy_recommendation'``, ``'personalized_health_tips'`` and
        ``'medication_probabilities'`` (the ``top_k`` most probable
        medications mapped to their probability).
    """
    model = resources['model']
    tokenizer = resources['tokenizer']
    mlb = resources['mlb']
    le_risk = resources['le_risk']
    le_disease = resources['le_disease']
    lookup_data = resources['lookup_data']
    device = resources['device']

    # Create text input (format presumably mirrors the training prompts -
    # verify against the training pipeline before changing it)
    text_input = (
        f"Patient age {patient_data['age']}, gender {patient_data['gender']}, "
        f"blood group {patient_data['blood_group']}, weight {patient_data['weight']}kg. "
        f"SYMPTOMS: {patient_data['symptoms']}. "
        f"SEVERITY: {patient_data['severity']}."
    )

    # Tokenize
    encoding = tokenizer(
        text_input,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Get predictions (single-item batch, hence index 0 / .item())
    with torch.no_grad():
        medication_logits, polypharmacy_logits, disease_logits = model(input_ids, attention_mask)
        medication_probs = torch.sigmoid(medication_logits)[0].cpu().numpy()
        risk_index = torch.argmax(polypharmacy_logits, dim=1).item()
        disease_index = torch.argmax(disease_logits, dim=1).item()

    predicted_risk = le_risk.classes_[risk_index]
    predicted_disease = le_disease.classes_[disease_index]

    # Rank every medication by probability, most likely first
    med_prob_dict = dict(zip(mlb.classes_, medication_probs))
    sorted_meds = sorted(med_prob_dict.items(), key=lambda item: item[1], reverse=True)

    # Recommend the most confident medications above the threshold. Ranking
    # before truncating ensures a high-confidence medication is never dropped
    # in favour of a weaker one that merely sorts earlier in the encoder.
    recommended = [(med, prob) for med, prob in sorted_meds if prob > threshold]
    med_results = [
        {
            'medication': med,
            'dosage': 'Consult doctor',
            'frequency': 'Consult doctor',
            'instruction': 'Consult doctor',
            'duration': 'As prescribed',
            'confidence': float(prob),
        }
        for med, prob in recommended[:max_medications]
    ]

    # Get disease information
    disease_causes = lookup_data['disease_causes_dict'].get(predicted_disease, "Unknown causes")
    disease_prevention = lookup_data['disease_prevention_dict'].get(
        predicted_disease, "Consult healthcare provider"
    )

    # Get polypharmacy recommendation
    polypharmacy_recommendation = lookup_data['polypharmacy_recommendation_dict'].get(
        predicted_risk, "Consult healthcare provider"
    )

    # Get personalized health tip, keyed by (disease, age decade, gender)
    age_decade = (patient_data['age'] // 10) * 10
    health_tip_key = (predicted_disease, age_decade, patient_data['gender'])
    personalized_health_tip = lookup_data['health_tips_dict'].get(
        health_tip_key, "Maintain a balanced diet and regular exercise routine."
    )

    # Return comprehensive results
    return {
        'patient_name': patient_data['name'],
        'predicted_disease': predicted_disease,
        'disease_causes': disease_causes,
        'disease_prevention': disease_prevention,
        'medications': med_results,
        'polypharmacy_risk': predicted_risk,
        'polypharmacy_recommendation': polypharmacy_recommendation,
        'personalized_health_tips': personalized_health_tip,
        'medication_probabilities': {med: float(prob) for med, prob in sorted_meds[:top_k]},
    }
# Symptoms offered in the multiselect widget.
_COMMON_SYMPTOMS = [
    "Headache", "Fever", "Fatigue", "Nausea", "Cough",
    "Sore throat", "Shortness of breath", "Chest pain",
    "Dizziness", "Abdominal pain", "Vomiting", "Diarrhea",
    "Muscle ache", "Joint pain", "Rash", "Loss of appetite",
]

# Severity label shown in the UI -> numeric score embedded in the model input.
_SEVERITY_LEVELS = {
    "Very Mild": 1,
    "Mild": 2,
    "Moderate": 3,
    "Severe": 4,
    "Very Severe": 5,
}

# Display colour per risk label; any other label is rendered in red.
_RISK_COLORS = {"Low": "green", "Medium": "orange"}


def _collect_patient_info():
    """Render the patient-information widgets and return their values as a dict."""
    st.subheader("Patient Information")
    return {
        'name': st.text_input("Patient Name", value="John Doe"),
        'age': st.number_input("Age", min_value=1, max_value=120, value=45),
        'gender': st.selectbox("Gender", options=["Male", "Female", "Other"]),
        'blood_group': st.selectbox(
            "Blood Group",
            options=["A+", "A-", "B+", "B-", "AB+", "AB-", "O+", "O-"],
        ),
        'weight': st.number_input(
            "Weight (kg)", min_value=1.0, max_value=300.0, value=70.0, step=0.1
        ),
    }


def _collect_symptoms():
    """Render the symptom widgets and return the de-duplicated symptom list."""
    st.subheader("Symptoms Information")
    chosen = st.multiselect(
        "Select Symptoms",
        options=_COMMON_SYMPTOMS,
        default=["Headache", "Fever", "Fatigue"],
    )
    custom_symptom = st.text_input("Add other symptom (if not in list)").strip()

    symptoms = list(chosen)  # copy: do not mutate the widget's return value
    if custom_symptom:
        symptoms.append(custom_symptom)
    # A custom entry equal to a selected symptom would otherwise create two
    # identical severity widgets; dict.fromkeys drops repeats, keeps order.
    return list(dict.fromkeys(symptoms))


def _collect_severities(symptoms):
    """Render one severity selector per symptom; return {symptom: score}."""
    st.subheader("Symptom Severity")
    severity_by_symptom = {}
    if not symptoms:
        return severity_by_symptom

    labels = list(_SEVERITY_LEVELS)
    columns = st.columns(2)
    for index, symptom in enumerate(symptoms):
        with columns[index % 2]:  # alternate between the two columns
            choice = st.selectbox(
                symptom,
                options=labels,
                index=1,  # default to "Mild"
                key=f"severity_{symptom}",  # explicit, unique widget identity
            )
        severity_by_symptom[symptom] = _SEVERITY_LEVELS[choice]
    return severity_by_symptom


def _render_results(prediction):
    """Display the dict returned by ``predict_patient_health_profile``."""
    st.subheader(f"🔍 Health Profile Analysis Results for {prediction['patient_name']}")
    disease_col, medication_col, risk_col = st.columns([1, 1, 1])

    # Column 1: Disease information
    with disease_col:
        st.markdown("### 🦠 Disease Prediction")
        st.markdown(f"**Predicted Disease**: {prediction['predicted_disease']}")
        with st.expander("Disease Causes"):
            st.write(prediction['disease_causes'])
        with st.expander("Prevention Methods"):
            st.write(prediction['disease_prevention'])

    # Column 2: Medication recommendations
    with medication_col:
        st.markdown("### 💊 Medication Recommendations")
        if not prediction['medications']:
            st.info("No medication exceeded the confidence threshold.")
        for position, med in enumerate(prediction['medications'], start=1):
            st.markdown(
                f"**{position}. {med['medication']}** (Confidence: {med['confidence']:.2f})"
            )
            st.markdown("\n".join([
                f"- **Dosage:** {med['dosage']}",
                f"- **Frequency:** {med['frequency']}",
                f"- **Instructions:** {med['instruction']}",
                f"- **Duration:** {med['duration']}",
            ]))
            st.divider()

    # Column 3: Risk assessment and health tips
    with risk_col:
        st.markdown("### ⚠️ Polypharmacy Assessment")
        risk = prediction['polypharmacy_risk']
        risk_color = _RISK_COLORS.get(risk, "red")
        st.markdown(
            f"**Risk Level**: <span style='color:{risk_color};font-weight:bold;'>{risk}</span>",
            unsafe_allow_html=True,
        )
        st.markdown(f"**Recommendation**: {prediction['polypharmacy_recommendation']}")
        st.markdown("### 🌿 Personalized Health Tips")
        st.info(prediction['personalized_health_tips'])

    # Medication probabilities as text plus progress bars
    st.subheader("Medication Confidence Scores")
    for med_name, med_prob in prediction['medication_probabilities'].items():
        st.text(f"{med_name}: {med_prob:.2f}")
        st.progress(med_prob)


def main():
    """Streamlit entry point: collect patient input, predict, show results."""
    st.title("🏥 Drug Prediction and Polypharmacy System")
    st.markdown("Enter patient information to receive medication recommendations, disease prediction, and polypharmacy risk assessment.")

    # Loading is handled separately so that the "model files" hint is shown
    # only when loading is what actually failed.
    try:
        with st.spinner("Loading medical model and resources..."):
            resources = load_model_and_resources()
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        st.error("Please make sure all model files are correctly placed in the 'streamlit_model' directory")
        return

    try:
        # Two columns for the input form
        info_col, symptom_col = st.columns(2)
        with info_col:
            patient_data = _collect_patient_info()
        with symptom_col:
            selected_symptoms = _collect_symptoms()

        severity_dict = _collect_severities(selected_symptoms)

        # String formats expected by the model's text prompt
        patient_data['symptoms'] = "; ".join(selected_symptoms)
        patient_data['severity'] = "; ".join(
            f"{symptom}:{score}" for symptom, score in severity_dict.items()
        )

        # Submit button
        if not st.button("Generate Health Profile", type="primary"):
            return
        if not selected_symptoms:
            st.warning("Please select or enter at least one symptom before generating a health profile.")
            return

        with st.spinner("Analyzing patient data and generating health profile..."):
            prediction = predict_patient_health_profile(patient_data, resources)
        _render_results(prediction)
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the app.
        st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()