File size: 8,409 Bytes
c067bf9
e7bec26
 
4678109
17cfcdd
4678109
 
 
014aebe
4678109
e7bec26
17cfcdd
c067bf9
55b99fb
4cbf465
dd07134
58a2772
 
 
 
 
 
 
 
55b99fb
58a2772
 
55b99fb
58a2772
55b99fb
58a2772
 
 
c067bf9
58a2772
 
8744263
4678109
 
 
d692a69
4678109
 
 
e7bec26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6f1011
e7bec26
 
 
bf8ca3d
e7bec26
 
 
4678109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129775b
4678109
 
 
 
 
 
e7bec26
 
4678109
 
 
 
 
 
 
 
e7bec26
4678109
 
 
 
 
 
 
0c3c269
 
 
 
 
0d99ea8
0c3c269
 
 
0d99ea8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c3c269
 
9ccd224
0d99ea8
 
f1bba0e
 
0c3c269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d99ea8
 
 
f1bba0e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
from peft import PeftModel
from pydub import AudioSegment
import speech_recognition as sr
import io
from audio_recorder_streamlit import audio_recorder  # Add at the top with other imports

# Load model and tokenizer from local fine-tuned directory
# Define base and adapter model paths
BASE_MODEL = "stanford-crfm/BioMedLM"  # or the path you used originally
ADAPTER_ID = "Tufan1/BioMedLM-Cardio-Fold10-CPU"  # HF model ID
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) # tokenizer comes from the base model, not the adapter

# Load base model with safe settings.
# NOTE: this runs at import time and downloads from the HF Hub on first run.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,  # Reduces memory spikes
    device_map="cpu"  # Force CPU loading
)

# Load adapter directly from Hub and wrap the base model with it (PEFT/LoRA).
model = PeftModel.from_pretrained(
    base_model,
    ADAPTER_ID,
    device_map="cpu",
    adapter_name="cardio_adapter"
)


# Force CPU-safe model loading
# NOTE(review): older alternative loading path kept for reference only.
#base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float32)
#model = PeftModel.from_pretrained(base_model, ADAPTER_PATH, device_map=None).to("cpu")


# Lookup tables translating the cardio dataset's numeric codes into the
# human-readable labels used when building the LLM prompt.
gender_map = dict([(1, "Female"), (2, "Male")])
cholesterol_map = dict(enumerate(("Normal", "Elevated", "Peak"), start=1))
glucose_map = dict(enumerate(("Normal", "High", "Extreme"), start=1))
binary_map = dict(enumerate(("No", "Yes")))

# Function to predict diagnosis using the LLM
def get_prediction(age, gender, height, weight, ap_hi, ap_lo,
                   cholesterol, glucose, smoke, alco, active):
    """Build a structured patient-record prompt and generate a diagnosis.

    The coded arguments (gender, cholesterol, glucose, smoke, alco, active)
    are decoded through the module-level maps; the remaining arguments are
    interpolated verbatim. Returns the text the model generated after the
    final "Diagnosis:" marker.
    """
    input_text = f"""Patient Record:
- Age: {age} years
- Gender: {gender_map[gender]}
- Height: {height} cm
- Weight: {weight} kg
- Systolic BP: {ap_hi} mmHg
- Diastolic BP: {ap_lo} mmHg
- Cholesterol Level: {cholesterol_map[cholesterol]}
- Glucose Level: {glucose_map[glucose]}
- Smokes: {binary_map[smoke]}
- Alcohol Intake: {binary_map[alco]}
- Physically Active: {binary_map[active]}

Diagnosis:"""

    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")
    model.eval()  # inference mode: disables dropout etc.
    with torch.no_grad():  # no gradients needed for generation
        # Diagnosis label is short, so a handful of new tokens suffices.
        outputs = model.generate(**inputs, max_new_tokens=4)
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Generation echoes the prompt; keep only what follows the marker.
    # (Removed leftover debug print of the raw output tensor.)
    diagnosis = decoded.split("Diagnosis:")[-1].strip()
    return diagnosis

# Function to extract patient features from a phrase or transcribed audio
def extract_details_from_text(text):
    age = int(re.search(r'(\d+)\s*year', text).group(1)) if re.search(r'(\d+)\s*year', text) else None
    gender = 2 if "man" in text.lower() else (1 if "female" in text.lower() else None)
    height = int(re.search(r'(\d+)\s*cm', text).group(1)) if re.search(r'(\d+)\s*cm', text) else None
    weight = int(re.search(r'(\d+)\s*kg', text).group(1)) if re.search(r'(\d+)\s*kg', text) else None
    bp_match = re.search(r'BP\s*(\d+)[/](\d+)', text)
    ap_hi, ap_lo = (int(bp_match.group(1)), int(bp_match.group(2))) if bp_match else (None, None)
    cholesterol = 3 if "peak" in text.lower() else 2 if "elevated" in text.lower() else 1
    glucose = 3 if "extreme" in text.lower() else 2 if "high" in text.lower() else 1
    smoke = 1 if "smoke" in text.lower() else 0
    alco = 1 if "alcohol" in text.lower() else 0
    active = 1 if "exercise" in text.lower() or "active" in text.lower() else 0
    return age, gender, height, weight, ap_hi, ap_lo, cholesterol, glucose, smoke, alco, active

# Streamlit UI
# Page chrome; set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="Cardiovascular Disease Predictor", layout="centered")
st.title("🫀 Cardiovascular Disease Predictor (LLM Powered)")
st.markdown("This tool uses a fine-tuned BioMedLM model to predict cardiovascular conditions from structured, text, or voice input.")

# Top-level mode switch; each option is handled by one branch below.
input_mode = st.radio("Choose input method:", ["Manual Input", "Text Phrase", "Audio Upload"])

if input_mode == "Manual Input":
    # Structured form. Selectbox/radio options are (label, code) tuples:
    # format_func shows the label, and the trailing [1] extracts the numeric
    # code that get_prediction (and the decode maps) expect.
    age = st.number_input("Age (years)", min_value=1, max_value=120)
    gender = st.selectbox("Gender", [("Female", 1), ("Male", 2)], format_func=lambda x: x[0])[1]
    height = st.number_input("Height (cm)", min_value=50, max_value=250)
    weight = st.number_input("Weight (kg)", min_value=10, max_value=200)
    ap_hi = st.number_input("Systolic BP", min_value=80, max_value=250)
    ap_lo = st.number_input("Diastolic BP", min_value=40, max_value=150)
    cholesterol = st.selectbox("Cholesterol", [("Normal", 1), ("Elevated", 2), ("Peak", 3)], format_func=lambda x: x[0])[1]
    glucose = st.selectbox("Glucose", [("Normal", 1), ("High", 2), ("Extreme", 3)], format_func=lambda x: x[0])[1]
    smoke = st.radio("Smoker?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
    alco = st.radio("Alcohol Intake?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]
    active = st.radio("Physically Active?", [("No", 0), ("Yes", 1)], format_func=lambda x: x[0])[1]

    # Run the LLM only on explicit request (generation is slow on CPU).
    if st.button("Predict Diagnosis"):
        diagnosis = get_prediction(age, gender, height, weight, ap_hi, ap_lo,
                                   cholesterol, glucose, smoke, alco, active)
        st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")

elif input_mode == "Text Phrase":
    # Free-text path: regex extraction feeds the same prediction function.
    phrase = st.text_area("Enter patient details in natural language:", height=200)
    if st.button("Extract & Predict"):
        try:
            values = extract_details_from_text(phrase)
            # Only age/gender/height/weight/BP can be None; categorical
            # fields always have defaults, so this gates on the former.
            if all(v is not None for v in values):
                diagnosis = get_prediction(*values)
                st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
            else:
                st.warning("Couldn't extract all fields from the text. Please revise.")
        except Exception as e:
            # Surface extraction/generation failures in the UI rather than crash.
            st.error(f"Error: {e}")


elif input_mode == "Audio Upload":
    # Two sub-modes: transcribe an uploaded file or a live browser recording.
    audio_input_mode = st.radio("Choose audio input type:", ["Upload Audio File", "Record Audio"])

    if audio_input_mode == "Upload Audio File":
        uploaded_file = st.file_uploader("Upload audio file (WAV, MP3, M4A, MPEG)", type=["wav", "mp3", "m4a", "mpeg"])
        
        if uploaded_file:
            st.audio(uploaded_file, format='audio/wav')
            # pydub decodes whichever container was uploaded (needs ffmpeg).
            audio = AudioSegment.from_file(uploaded_file)
            
            if audio is not None and len(audio) > 0:
                # Re-encode to WAV in memory so SpeechRecognition can read it.
                wav_io = io.BytesIO()
                audio.export(wav_io, format="wav")
                wav_io.seek(0)

                recognizer = sr.Recognizer()
                with sr.AudioFile(wav_io) as source:
                    audio_data = recognizer.record(source)

                try:
                    # Google Web Speech API; requires network access.
                    text = recognizer.recognize_google(audio_data)
                    st.markdown(f"**Transcribed Text:** _{text}_")
                    values = extract_details_from_text(text)
                    if all(v is not None for v in values):
                        diagnosis = get_prediction(*values)
                        st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
                    else:
                        st.warning("Could not extract complete information from audio.")
                except Exception as e:
                    st.error(f"Audio processing error: {e}")
            else:
                st.error("Uploaded audio file is empty or not valid.")

    elif audio_input_mode == "Record Audio":
        audio = audio_recorder("Click to record", "Recording...")

        if audio is not None and len(audio) > 0:  # Check if audio is not None and has length
            st.audio(audio, format="audio/wav")  # Directly use audio as it is already a bytes object
            # NOTE(review): assumes audio_recorder returns WAV-encoded bytes —
            # sr.AudioFile would reject any other encoding; confirm.
            wav_io = io.BytesIO(audio)

            recognizer = sr.Recognizer()
            with sr.AudioFile(wav_io) as source:
                audio_data = recognizer.record(source)

            try:
                # Same transcribe → extract → predict pipeline as the upload path.
                text = recognizer.recognize_google(audio_data)
                st.markdown(f"**Transcribed Text:** _{text}_")
                values = extract_details_from_text(text)
                if all(v is not None for v in values):
                    diagnosis = get_prediction(*values)
                    st.success(f"🩺 **Predicted Diagnosis:** {diagnosis}")
                else:
                    st.warning("Could not extract complete information from recorded audio.")
            except Exception as e:
                st.error(f"Recording processing error: {e}")
        else:
            st.error("No audio recorded or audio is empty.")