Spaces:
Sleeping
Sleeping
File size: 8,409 Bytes
c067bf9 e7bec26 4678109 17cfcdd 4678109 014aebe 4678109 e7bec26 17cfcdd c067bf9 55b99fb 4cbf465 dd07134 58a2772 55b99fb 58a2772 55b99fb 58a2772 55b99fb 58a2772 c067bf9 58a2772 8744263 4678109 d692a69 4678109 e7bec26 c6f1011 e7bec26 bf8ca3d e7bec26 4678109 129775b 4678109 e7bec26 4678109 e7bec26 4678109 0c3c269 0d99ea8 0c3c269 0d99ea8 0c3c269 9ccd224 0d99ea8 f1bba0e 0c3c269 0d99ea8 f1bba0e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
from peft import PeftModel
from pydub import AudioSegment
import speech_recognition as sr
import io
from audio_recorder_streamlit import audio_recorder # Add at the top with other imports
# Base checkpoint and the fine-tuned PEFT adapter that is applied on top of it.
BASE_MODEL = "stanford-crfm/BioMedLM"  # or the path you used originally
ADAPTER_ID = "Tufan1/BioMedLM-Cardio-Fold10-CPU"  # HF model ID


# Streamlit re-executes this script from top to bottom on every widget
# interaction.  Caching the loaded objects as a shared resource means the
# checkpoint and adapter are loaded once per server process instead of on
# every rerun.  The spinner is disabled so that no Streamlit element is
# emitted before ``st.set_page_config`` runs further down the script.
@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer():
    """Load the tokenizer, the base model and the adapter-wrapped model on CPU.

    Returns:
        A ``(tokenizer, base_model, model)`` tuple, where ``model`` is the
        ``PeftModel`` built from ``base_model`` and ``ADAPTER_ID``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

    # Load base model with CPU-safe settings.
    loaded_base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,  # Reduces memory spikes
        device_map="cpu",  # Force CPU loading
    )

    # Load adapter directly from the Hub.
    loaded_model = PeftModel.from_pretrained(
        loaded_base_model,
        ADAPTER_ID,
        device_map="cpu",
        adapter_name="cardio_adapter",
    )
    return loaded_tokenizer, loaded_base_model, loaded_model


tokenizer, base_model, model = _load_model_and_tokenizer()
# Lookup tables translating the integer codes collected from the user into
# the labels that are written into the model prompt.
gender_map = {
    1: "Female",
    2: "Male",
}
cholesterol_map = {
    1: "Normal",
    2: "Elevated",
    3: "Peak",
}
glucose_map = {
    1: "Normal",
    2: "High",
    3: "Extreme",
}
# Shared by the three yes/no lifestyle flags (smoke, alcohol, activity).
binary_map = {
    0: "No",
    1: "Yes",
}
def get_prediction(age, gender, height, weight, ap_hi, ap_lo,
                   cholesterol, glucose, smoke, alco, active):
    """Predict a diagnosis for one patient record using the LLM.

    The coded inputs are rendered into a textual patient record ending in
    ``Diagnosis:`` and the model generates up to four new tokens after it.

    Args:
        age: Age in years.
        gender: Key of ``gender_map``.
        height: Height in cm.
        weight: Weight in kg.
        ap_hi: Systolic blood pressure in mmHg.
        ap_lo: Diastolic blood pressure in mmHg.
        cholesterol: Key of ``cholesterol_map``.
        glucose: Key of ``glucose_map``.
        smoke: Key of ``binary_map`` (smoker flag).
        alco: Key of ``binary_map`` (alcohol intake flag).
        active: Key of ``binary_map`` (physical activity flag).

    Returns:
        The generated continuation with surrounding whitespace stripped.

    Raises:
        KeyError: If a coded value is not a key of its lookup table.
    """
    # NOTE(review): the prompt lines are flush-left exactly as in the visible
    # source; confirm this matches the format used during fine-tuning.
    input_text = "\n".join([
        "Patient Record:",
        f"- Age: {age} years",
        f"- Gender: {gender_map[gender]}",
        f"- Height: {height} cm",
        f"- Weight: {weight} kg",
        f"- Systolic BP: {ap_hi} mmHg",
        f"- Diastolic BP: {ap_lo} mmHg",
        f"- Cholesterol Level: {cholesterol_map[cholesterol]}",
        f"- Glucose Level: {glucose_map[glucose]}",
        f"- Smokes: {binary_map[smoke]}",
        f"- Alcohol Intake: {binary_map[alco]}",
        f"- Physically Active: {binary_map[active]}",
        "Diagnosis:",
    ])

    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")
    prompt_length = inputs["input_ids"].shape[1]

    model.eval()
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=4)

    # The generated sequence starts with the prompt tokens; decode only what
    # the model appended rather than re-decoding the prompt and splitting the
    # text on "Diagnosis:".
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# Matches a negation that ends the text immediately before a keyword,
# optionally separated from it by one common filler word, e.g.
# "does not ", "non-", "doesn't drink ", "not a ", "not physically ".
_NEGATION_BEFORE_KEYWORD = re.compile(
    r"(?:\b(?:no|not|non|never|without)\b|n't)[\s-]+"
    r"(?:(?:a|an|any|to|drink|drinks|drinking|consume|consumes|"
    r"physically|currently)\s+)?$"
)


def _first_int(pattern, text):
    """Return group 1 of the first match of *pattern* as an int, else None."""
    match = re.search(pattern, text)
    return int(match.group(1)) if match else None


def _is_affirmed(keyword_pattern, text):
    """Return 1 if *keyword_pattern* occurs in *text* without a negation
    directly in front of it, otherwise 0."""
    for match in re.finditer(keyword_pattern, text):
        if not _NEGATION_BEFORE_KEYWORD.search(text[:match.start()]):
            return 1
    return 0


def extract_details_from_text(text):
    """Extract patient features from a phrase or transcribed audio.

    Matching is case-insensitive and keyword based.  Whole-word matching is
    used so that, for example, "woman" is not read as "man", "inactive" is not
    read as "active" and "speak" is not read as "peak".

    Args:
        text: Free-form description of the patient.

    Returns:
        A tuple ``(age, gender, height, weight, ap_hi, ap_lo, cholesterol,
        glucose, smoke, alco, active)``.  ``age``, ``gender``, ``height``,
        ``weight``, ``ap_hi`` and ``ap_lo`` are ``None`` when they cannot be
        found; ``cholesterol`` and ``glucose`` default to 1; the three
        lifestyle flags default to 0.

    Note:
        The cholesterol and glucose keywords are searched in the whole text
        and are not tied to the nearby noun, so "high cholesterol" still sets
        the glucose level to 2.
    """
    lowered = text.lower()

    # "45 years", "45 year old" and the hyphenated "45-year-old" all match.
    age = _first_int(r"(\d+)[\s-]*year", lowered)

    # Female terms are checked first and on word boundaries, because "woman"
    # contains "man" and "female" contains "male".
    if re.search(r"\b(?:female|woman)\b", lowered):
        gender = 1
    elif re.search(r"\b(?:male|man)\b", lowered):
        gender = 2
    else:
        gender = None

    height = _first_int(r"(\d+)\s*(?:cm|centimet)", lowered)
    weight = _first_int(r"(\d+)\s*(?:kg|kilo)", lowered)

    # Accepts "BP 120/80", "bp: 120 / 80" and "blood pressure 120 over 80".
    bp_match = re.search(
        r"\b(?:bp|blood\s+pressure)\D{0,20}?(\d+)\s*(?:/|over)\s*(\d+)",
        lowered,
    )
    if bp_match:
        ap_hi, ap_lo = int(bp_match.group(1)), int(bp_match.group(2))
    else:
        ap_hi, ap_lo = None, None

    if re.search(r"\bpeak\b", lowered):
        cholesterol = 3
    elif re.search(r"\belevated\b", lowered):
        cholesterol = 2
    else:
        cholesterol = 1

    if re.search(r"\bextreme", lowered):
        glucose = 3
    elif re.search(r"\bhigh\b", lowered):
        glucose = 2
    else:
        glucose = 1

    # Lifestyle flags ignore directly negated mentions such as
    # "does not smoke", "non-smoker" or "no alcohol".
    smoke = _is_affirmed(r"\bsmoke\w*", lowered)
    alco = _is_affirmed(r"\balcohol\w*", lowered)
    active = _is_affirmed(r"\b(?:exercise|active)\w*", lowered)

    return (age, gender, height, weight, ap_hi, ap_lo,
            cholesterol, glucose, smoke, alco, active)
# Streamlit UI
def _show_prediction_from_text(text, incomplete_message):
    """Extract features from *text* and display the diagnosis.

    Shows *incomplete_message* as a warning instead when any extracted field
    is ``None``.  Exceptions propagate to the caller, which reports them with
    a message specific to the input mode.
    """
    values = extract_details_from_text(text)
    if any(v is None for v in values):
        st.warning(incomplete_message)
        return
    diagnosis = get_prediction(*values)
    st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")


def _show_prediction_from_wav(wav_io, incomplete_message):
    """Transcribe WAV data from *wav_io*, echo the transcript and predict.

    *wav_io* is a file-like object containing WAV audio.  Exceptions other
    than unintelligible speech propagate to the caller.
    """
    recognizer = sr.Recognizer()
    with sr.AudioFile(wav_io) as source:
        audio_data = recognizer.record(source)
    try:
        text = recognizer.recognize_google(audio_data)
    except sr.UnknownValueError:
        # This exception carries no message, so reporting it through the
        # generic handler would show an empty error to the user.
        st.warning("Speech could not be understood. Please try again with a clearer recording.")
        return
    st.markdown(f"**Transcribed Text:** _{text}_")
    _show_prediction_from_text(text, incomplete_message)


st.set_page_config(page_title="Cardiovascular Disease Predictor", layout="centered")
st.title("🫀 Cardiovascular Disease Predictor (LLM Powered)")
st.markdown("This tool uses a fine-tuned BioMedLM model to predict cardiovascular conditions from structured, text, or voice input.")

input_mode = st.radio("Choose input method:", ["Manual Input", "Text Phrase", "Audio Upload"])

if input_mode == "Manual Input":
    # The coded options and their labels come from the shared lookup tables,
    # so the widgets cannot drift from the labels used in the model prompt.
    age = st.number_input("Age (years)", min_value=1, max_value=120)
    gender = st.selectbox("Gender", list(gender_map), format_func=gender_map.get)
    height = st.number_input("Height (cm)", min_value=50, max_value=250)
    weight = st.number_input("Weight (kg)", min_value=10, max_value=200)
    ap_hi = st.number_input("Systolic BP", min_value=80, max_value=250)
    ap_lo = st.number_input("Diastolic BP", min_value=40, max_value=150)
    cholesterol = st.selectbox("Cholesterol", list(cholesterol_map), format_func=cholesterol_map.get)
    glucose = st.selectbox("Glucose", list(glucose_map), format_func=glucose_map.get)
    smoke = st.radio("Smoker?", list(binary_map), format_func=binary_map.get)
    alco = st.radio("Alcohol Intake?", list(binary_map), format_func=binary_map.get)
    active = st.radio("Physically Active?", list(binary_map), format_func=binary_map.get)

    if st.button("Predict Diagnosis"):
        diagnosis = get_prediction(age, gender, height, weight, ap_hi, ap_lo,
                                   cholesterol, glucose, smoke, alco, active)
        st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")

elif input_mode == "Text Phrase":
    phrase = st.text_area("Enter patient details in natural language:", height=200)
    if st.button("Extract & Predict"):
        try:
            _show_prediction_from_text(
                phrase,
                "Couldn't extract all fields from the text. Please revise.",
            )
        except Exception as e:  # UI boundary: report instead of a traceback
            st.error(f"Error: {e}")

elif input_mode == "Audio Upload":
    audio_input_mode = st.radio("Choose audio input type:", ["Upload Audio File", "Record Audio"])

    if audio_input_mode == "Upload Audio File":
        uploaded_file = st.file_uploader("Upload audio file (WAV, MP3, M4A, MPEG)", type=["wav", "mp3", "m4a", "mpeg"])
        if uploaded_file:
            st.audio(uploaded_file, format='audio/wav')
            # Decoding is inside the try block as well, so an unreadable file
            # is reported as an error message rather than an uncaught
            # exception.
            try:
                # Rewind in case an earlier reader advanced the stream.
                uploaded_file.seek(0)
                audio = AudioSegment.from_file(uploaded_file)
                if audio is not None and len(audio) > 0:
                    # Re-encode to WAV in memory before transcription.
                    wav_io = io.BytesIO()
                    audio.export(wav_io, format="wav")
                    wav_io.seek(0)
                    _show_prediction_from_wav(
                        wav_io,
                        "Could not extract complete information from audio.",
                    )
                else:
                    st.error("Uploaded audio file is empty or not valid.")
            except Exception as e:  # UI boundary: report instead of a traceback
                st.error(f"Audio processing error: {e}")

    elif audio_input_mode == "Record Audio":
        # NOTE(review): the second positional argument may not be a caption
        # parameter of audio_recorder -- confirm against the component's
        # signature.  The call is left exactly as it was.
        audio = audio_recorder("Click to record", "Recording...")
        if audio is not None and len(audio) > 0:  # Check if audio is not None and has length
            st.audio(audio, format="audio/wav")  # Directly use audio as it is already a bytes object
            try:
                _show_prediction_from_wav(
                    io.BytesIO(audio),
                    "Could not extract complete information from recorded audio.",
                )
            except Exception as e:  # UI boundary: report instead of a traceback
                st.error(f"Recording processing error: {e}")
        else:
            st.error("No audio recorded or audio is empty.")
|