"""Streamlit app: cardiovascular disease prediction with a fine-tuned BioMedLM.

Patient data can be supplied through a form, as a free-text phrase, or as
audio (uploaded or recorded) that is transcribed and then parsed like the
free-text phrase.
"""

import io
import logging
import re

import speech_recognition as sr
import streamlit as st
import torch
from audio_recorder_streamlit import audio_recorder
from peft import PeftModel
from pydub import AudioSegment
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

# Base model and the LoRA/PEFT adapter that is applied on top of it.
BASE_MODEL = "stanford-crfm/BioMedLM"  # or the path you used originally
ADAPTER_ID = "Tufan1/BioMedLM-Cardio-Fold10-CPU"  # HF model ID

# Dictionaries to decode the numeric codes used for user inputs.
gender_map = {1: "Female", 2: "Male"}
cholesterol_map = {1: "Normal", 2: "Elevated", 3: "Peak"}
glucose_map = {1: "Normal", 2: "High", 3: "Extreme"}
binary_map = {0: "No", 1: "Yes"}

# Patterns used by extract_details_from_text, compiled once at import time.
_AGE_RE = re.compile(r'(\d+)\s*year', re.IGNORECASE)
_HEIGHT_RE = re.compile(r'(\d+)\s*cm', re.IGNORECASE)
_WEIGHT_RE = re.compile(r'(\d+)\s*kg', re.IGNORECASE)
# Accepts "BP 120/80" (the original format) as well as the forms a speech
# transcript tends to produce, e.g. "blood pressure is 120 over 80".
_BP_RE = re.compile(
    r'(?:BP|blood\s+pressure)\s*(?:is|of|:)?\s*(\d+)\s*(?:/|over)\s*(\d+)',
    re.IGNORECASE,
)
# Word boundaries matter here: "woman" and "human" contain "man", and
# "female" contains "male".
_MALE_RE = re.compile(r'\b(?:male|man|boy|gentleman)\b')
_FEMALE_RE = re.compile(r'\b(?:female|woman|girl|lady)\b')
_SMOKE_RE = re.compile(r'\bsmok\w*')
_ALCOHOL_RE = re.compile(r'\balcohol\w*')
_ACTIVE_RE = re.compile(r'\b(?:exercis\w*|active)\b')
# A negation word directly before a keyword, optionally separated from it by
# one word and/or "or" ("does not smoke", "non-smoker", "not physically
# active", "no smoking or alcohol").
_NEGATION_BEFORE_RE = re.compile(
    r"(?:\bno|\bnot|\bnon|\bnever|\bwithout|\bdenies|n't)[\s-]+"
    r"(?:\w+\s+)?(?:or\s+)?$"
)


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the tokenizer and the adapter-augmented model once per process.

    Streamlit re-executes the whole script on every widget interaction;
    caching the loaded objects avoids reloading the model each time.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is on CPU and in eval mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,  # Reduces memory spikes
        device_map="cpu",  # Force CPU loading
    )
    # Load the adapter directly from the Hub.
    loaded_model = PeftModel.from_pretrained(
        base_model,
        ADAPTER_ID,
        device_map="cpu",
        adapter_name="cardio_adapter",
    )
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


def get_prediction(age, gender, height, weight, ap_hi, ap_lo,
                   cholesterol, glucose, smoke, alco, active):
    """Ask the fine-tuned LLM for a diagnosis of one patient record.

    Args:
        age: Age in years.
        gender: Key of ``gender_map`` (1 or 2).
        height: Height in cm.
        weight: Weight in kg.
        ap_hi: Systolic blood pressure in mmHg.
        ap_lo: Diastolic blood pressure in mmHg.
        cholesterol: Key of ``cholesterol_map`` (1-3).
        glucose: Key of ``glucose_map`` (1-3).
        smoke: Key of ``binary_map`` (0 or 1).
        alco: Key of ``binary_map`` (0 or 1).
        active: Key of ``binary_map`` (0 or 1).

    Returns:
        The text the model generated after the ``Diagnosis:`` marker,
        stripped of surrounding whitespace.

    Raises:
        KeyError: If a coded argument is not a key of its lookup map.
    """
    # NOTE(review): the line layout of this prompt was reconstructed; it must
    # match the format the adapter was fine-tuned on - confirm.
    input_text = f"""Patient Record:
- Age: {age} years
- Gender: {gender_map[gender]}
- Height: {height} cm
- Weight: {weight} kg
- Systolic BP: {ap_hi} mmHg
- Diastolic BP: {ap_lo} mmHg
- Cholesterol Level: {cholesterol_map[cholesterol]}
- Glucose Level: {glucose_map[glucose]}
- Smokes: {binary_map[smoke]}
- Alcohol Intake: {binary_map[alco]}
- Physically Active: {binary_map[active]}

Diagnosis:"""

    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")
    model.eval()
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=4)
    logger.debug("Raw output: %s", outputs)

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only what follows the marker.
    return decoded.split("Diagnosis:")[-1].strip()


def _first_int(pattern, text):
    """Return the first capture group of *pattern* in *text* as int, or None."""
    match = pattern.search(text)
    return int(match.group(1)) if match else None


def _is_affirmed(lowered, keyword_re):
    """Return True if *keyword_re* occurs in *lowered* without a negation.

    An occurrence counts as negated when a negation word stands directly in
    front of it (see ``_NEGATION_BEFORE_RE``). One non-negated occurrence is
    enough to return True.
    """
    return any(
        not _NEGATION_BEFORE_RE.search(lowered[:match.start()])
        for match in keyword_re.finditer(lowered)
    )


def extract_details_from_text(text):
    """Extract patient features from a phrase or transcribed audio.

    Numeric fields are read with regular expressions; categorical fields are
    detected by keyword. Fields that cannot be found are returned as ``None``
    (age, gender, height, weight, blood pressure) or fall back to their
    "normal"/"no" code (cholesterol, glucose, smoke, alco, active).

    Args:
        text: Free-text patient description.

    Returns:
        A tuple ``(age, gender, height, weight, ap_hi, ap_lo, cholesterol,
        glucose, smoke, alco, active)`` in the argument order of
        ``get_prediction``.
    """
    lowered = text.lower()

    age = _first_int(_AGE_RE, text)
    height = _first_int(_HEIGHT_RE, text)
    weight = _first_int(_WEIGHT_RE, text)

    if _MALE_RE.search(lowered):
        gender = 2
    elif _FEMALE_RE.search(lowered):
        gender = 1
    else:
        gender = None

    bp_match = _BP_RE.search(text)
    if bp_match:
        ap_hi, ap_lo = int(bp_match.group(1)), int(bp_match.group(2))
    else:
        ap_hi, ap_lo = None, None

    # Keyword lookups against the label vocabulary. These are not tied to the
    # word "cholesterol"/"glucose", so e.g. "high" anywhere selects High
    # glucose.
    if "peak" in lowered:
        cholesterol = 3
    elif "elevated" in lowered:
        cholesterol = 2
    else:
        cholesterol = 1

    if "extreme" in lowered:
        glucose = 3
    elif "high" in lowered:
        glucose = 2
    else:
        glucose = 1

    smoke = 1 if _is_affirmed(lowered, _SMOKE_RE) else 0
    alco = 1 if _is_affirmed(lowered, _ALCOHOL_RE) else 0
    active = 1 if _is_affirmed(lowered, _ACTIVE_RE) else 0

    return (age, gender, height, weight, ap_hi, ap_lo,
            cholesterol, glucose, smoke, alco, active)


def _show_diagnosis(diagnosis):
    """Render a predicted diagnosis in the success style."""
    st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")


def _predict_from_text(text, incomplete_message):
    """Parse *text*, then show a diagnosis or *incomplete_message*."""
    values = extract_details_from_text(text)
    if all(v is not None for v in values):
        _show_diagnosis(get_prediction(*values))
    else:
        st.warning(incomplete_message)


def _transcribe_and_predict(wav_io, incomplete_message, error_prefix):
    """Transcribe WAV data, show the transcript and predict from it.

    Args:
        wav_io: Binary file-like object containing WAV audio.
        incomplete_message: Warning shown when fields are missing.
        error_prefix: Prefix of the error shown when any step fails.
    """
    # Broad catch on purpose: this is the UI boundary, and decoding,
    # the recognition service and the model can all fail in different ways.
    try:
        recognizer = sr.Recognizer()
        with sr.AudioFile(wav_io) as source:
            audio_data = recognizer.record(source)
        text = recognizer.recognize_google(audio_data)
        st.markdown(f"**Transcribed Text:** _{text}_")
        _predict_from_text(text, incomplete_message)
    except Exception as e:
        st.error(f"{error_prefix}: {e}")


def _render_manual_input():
    """Render the structured form and predict on button press."""
    age = st.number_input("Age (years)", min_value=1, max_value=120)
    gender = st.selectbox("Gender", [("Female", 1), ("Male", 2)],
                          format_func=lambda x: x[0])[1]
    height = st.number_input("Height (cm)", min_value=50, max_value=250)
    weight = st.number_input("Weight (kg)", min_value=10, max_value=200)
    ap_hi = st.number_input("Systolic BP", min_value=80, max_value=250)
    ap_lo = st.number_input("Diastolic BP", min_value=40, max_value=150)
    cholesterol = st.selectbox(
        "Cholesterol", [("Normal", 1), ("Elevated", 2), ("Peak", 3)],
        format_func=lambda x: x[0])[1]
    glucose = st.selectbox(
        "Glucose", [("Normal", 1), ("High", 2), ("Extreme", 3)],
        format_func=lambda x: x[0])[1]
    smoke = st.radio("Smoker?", [("No", 0), ("Yes", 1)],
                     format_func=lambda x: x[0])[1]
    alco = st.radio("Alcohol Intake?", [("No", 0), ("Yes", 1)],
                    format_func=lambda x: x[0])[1]
    active = st.radio("Physically Active?", [("No", 0), ("Yes", 1)],
                      format_func=lambda x: x[0])[1]

    if st.button("Predict Diagnosis"):
        _show_diagnosis(get_prediction(
            age, gender, height, weight, ap_hi, ap_lo,
            cholesterol, glucose, smoke, alco, active))


def _render_text_phrase():
    """Render the free-text box and predict on button press."""
    phrase = st.text_area("Enter patient details in natural language:",
                          height=200)
    if st.button("Extract & Predict"):
        try:
            _predict_from_text(
                phrase,
                "Couldn't extract all fields from the text. Please revise.")
        except Exception as e:
            st.error(f"Error: {e}")


def _render_audio_upload():
    """Render the file uploader, convert the upload to WAV and predict."""
    uploaded_file = st.file_uploader(
        "Upload audio file (WAV, MP3, M4A, MPEG)",
        type=["wav", "mp3", "m4a", "mpeg"])
    if not uploaded_file:
        return

    # Play back with the upload's real MIME type rather than assuming WAV.
    st.audio(uploaded_file, format=uploaded_file.type or 'audio/wav')

    try:
        uploaded_file.seek(0)  # the preview above may have consumed the stream
        audio = AudioSegment.from_file(uploaded_file)
    except Exception as e:
        st.error(f"Audio processing error: {e}")
        return

    if len(audio) == 0:
        st.error("Uploaded audio file is empty or not valid.")
        return

    try:
        # speech_recognition reads WAV, so re-encode whatever was uploaded.
        wav_io = io.BytesIO()
        audio.export(wav_io, format="wav")
        wav_io.seek(0)
    except Exception as e:
        st.error(f"Audio processing error: {e}")
        return

    _transcribe_and_predict(
        wav_io,
        "Could not extract complete information from audio.",
        "Audio processing error")


def _render_audio_record():
    """Render the in-browser recorder and predict from the recording."""
    audio = audio_recorder("Click to record", "Recording...")
    if audio is not None and len(audio) > 0:
        st.audio(audio, format="audio/wav")
        # The recorder returns WAV bytes, so no conversion is needed.
        _transcribe_and_predict(
            io.BytesIO(audio),
            "Could not extract complete information from recorded audio.",
            "Recording processing error")
    else:
        st.error("No audio recorded or audio is empty.")


# Streamlit UI. set_page_config has to be the first Streamlit command, so it
# runs before the (cached) model load.
st.set_page_config(page_title="Cardiovascular Disease Predictor",
                   layout="centered")

tokenizer, model = _load_model_and_tokenizer()

st.title("🫀 Cardiovascular Disease Predictor (LLM Powered)")
st.markdown("This tool uses a fine-tuned BioMedLM model to predict cardiovascular conditions from structured, text, or voice input.")

input_mode = st.radio("Choose input method:",
                      ["Manual Input", "Text Phrase", "Audio Upload"])

if input_mode == "Manual Input":
    _render_manual_input()
elif input_mode == "Text Phrase":
    _render_text_phrase()
elif input_mode == "Audio Upload":
    audio_input_mode = st.radio("Choose audio input type:",
                                ["Upload Audio File", "Record Audio"])
    if audio_input_mode == "Upload Audio File":
        _render_audio_upload()
    elif audio_input_mode == "Record Audio":
        _render_audio_record()