# TB / app.py
# Uploaded by Zemedkun ("Upload 4 files"), commit 7963704 (verified).
# Hugging Face file-view metadata: raw | history | blame; no virus detected; 6.26 kB.
import io

import numpy as np
import onnxruntime as rt
import sounddevice as sd
import soundfile as sf
import streamlit as st
from scipy.io.wavfile import write
# Load the ONNX model once per server process. Streamlit re-executes this whole
# script on every widget interaction, so an uncached InferenceSession would be
# rebuilt (and the model file re-read) on each rerun.
model_path = 'model.onnx'  # Replace with the actual path to your model

# Older Streamlit releases have no st.cache_resource; fall back to no caching there.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model(path):
    """Create an ONNX Runtime inference session for the model file at ``path``."""
    return rt.InferenceSession(path)


model_inference = _load_model(model_path)
def preprocess_audio(audio_data, target_length=22050):
    """Normalise a waveform and shape it into the model's audio input tensor.

    Multi-channel input of shape ``(frames, channels)`` is averaged down to mono.
    The clip is cut to its first ``target_length`` samples, standardised to zero
    mean / unit variance (NaNs are ignored when computing the statistics), and
    zero-padded at the end up to ``target_length`` samples.

    Args:
        audio_data: 1-D waveform, or 2-D ``(frames, channels)`` array.
        target_length: Number of samples in the output; defaults to 22050.

    Returns:
        ``float32`` array of shape ``(1, target_length, 1)``.
    """
    samples = np.asarray(audio_data)
    if samples.ndim > 1:
        # Down-mix to mono: np.pad below would otherwise pad every axis of a
        # multi-channel clip instead of just the time axis.
        samples = samples.reshape(samples.shape[0], -1).mean(axis=1)
    # Truncate before padding: a clip longer than target_length (e.g. a 5 s
    # recording at 22050 Hz) would give np.pad a negative width and raise.
    samples = samples[:target_length]
    if samples.size:
        mu = np.nanmean(samples)
        std = np.nanstd(samples)
        if not np.isfinite(std) or std == 0:
            # Silent/constant clip: avoid dividing by zero and producing NaN/inf.
            std = 1.0
        samples = (samples - mu) / std
    samples = np.pad(samples, (0, target_length - samples.size), 'constant')
    return samples.reshape(1, -1, 1).astype(np.float32)
def preprocess_clinical_data(
    age, height, weight, reported_cough_dur, heart_rate, temperature,
    sex, tb_prior, tb_prior_Pul, tb_prior_Extrapul, tb_prior_Unknown,
    hemoptysis, weight_loss, smoke_lweek, fever, night_sweats,
):
    """Assemble the clinical feature row used as the model's second input.

    The six numeric measurements come first, unchanged. Each categorical answer
    is then one-hot encoded against its fixed list of options, in the order
    listed below; an answer matching none of its options encodes as all zeros.

    Returns:
        ``float32`` array of shape ``(1, 27)``: 6 numeric columns followed by
        21 one-hot columns.
    """
    yes_no = ('No', 'Yes')
    # (answer, options) pairs; the option order fixes the feature column order.
    categorical_fields = (
        (sex, ('Female', 'Male')),
        (tb_prior, ('No', 'Not sure', 'Yes')),
        (tb_prior_Pul, yes_no),
        (tb_prior_Extrapul, yes_no),
        (tb_prior_Unknown, yes_no),
        (hemoptysis, yes_no),
        (weight_loss, yes_no),
        (smoke_lweek, yes_no),
        (fever, yes_no),
        (night_sweats, yes_no),
    )
    features = [age, height, weight, reported_cough_dur, heart_rate, temperature]
    for answer, options in categorical_fields:
        features.extend(1 if answer == option else 0 for option in options)
    return np.array(features).astype(np.float32).reshape(1, -1)
def _capture_cough_audio(choice):
    """Render the recording or upload controls for ``choice`` and return WAV bytes.

    Audio is tracked per browser session (``st.session_state`` for recordings,
    the file uploader's own state for uploads), so a prediction never uses audio
    from another session or from the other input mode. A copy is still written
    to ``audio_file.wav`` as before.

    Returns:
        The WAV file contents as ``bytes``, or ``None`` if this session has not
        yet recorded or uploaded audio for the selected option.
    """
    audio_file_path = "audio_file.wav"
    if choice == "Record Cough Sound":
        st.write("**Record Cough Sound**")
        st.write("Press the button below to start recording:")
        start_recording = st.button("Start Recording")
        if start_recording:
            duration = 5  # Set the default recording duration to 5 seconds
            st.write("Recording started...")
            sample_rate = 22050
            # NOTE(review): sounddevice records from the audio input of the machine
            # running the Streamlit server, not from the visitor's browser; a hosted
            # server may have no input device at all.
            recording = sd.rec(int(duration * sample_rate), samplerate=sample_rate, channels=1)
            sd.wait()
            # Encode in memory so the bytes can be kept for this session.
            buffer = io.BytesIO()
            write(buffer, sample_rate, recording)
            audio_bytes = buffer.getvalue()
            with open(audio_file_path, "wb") as f:
                f.write(audio_bytes)
            # st.button is only True on the rerun right after the click, so keep the
            # recording in session state for the later "Predict" rerun.
            st.session_state["recorded_audio"] = audio_bytes
            st.success("Recording saved as 'audio_file.wav'.")
            st.info("Please proceed to the next step.")
        return st.session_state.get("recorded_audio")

    if choice == "Upload Cough Sound":
        st.write("**Upload Cough Sound**")
        uploaded_file = st.file_uploader("Upload Cough Sound (WAV file)", type=["wav"])
        if uploaded_file is None:
            return None
        audio_bytes = uploaded_file.getvalue()
        with open(audio_file_path, "wb") as f:
            f.write(audio_bytes)
        st.success("File uploaded successfully.")
        st.info("Please proceed to the next step.")
        return audio_bytes

    return None


def _collect_clinical_inputs():
    """Render the clinical-information widgets and return their current values.

    The returned tuple is in the positional order of ``preprocess_clinical_data``.
    """
    st.write('**Step 2: Enter Clinical Information**')
    age = st.slider('Age', 1, 100, 30)
    height = st.slider('Height (cm)', 100, 300, 170)
    weight = st.slider('Weight (kg)', 20, 200, 70)
    reported_cough_dur = st.slider('Reported Cough Duration (days)', 1, 100, 10)
    heart_rate = st.slider('Heart Rate (bpm)', 50, 200, 80)
    temperature = st.slider('Body Temperature (°C)', 35.0, 40.0, 37.0)
    sex = st.radio('Sex', ('Male', 'Female'))
    tb_prior = st.radio('TB Prior', ('No', 'Not sure', 'Yes'))
    tb_prior_Pul = st.radio('TB Prior Pul', ('No', 'Yes'))
    tb_prior_Extrapul = st.radio('TB Prior Extrapul', ('No', 'Yes'))
    tb_prior_Unknown = st.radio('TB Prior Unknown', ('No', 'Yes'))
    hemoptysis = st.radio('Hemoptysis', ('No', 'Yes'))
    weight_loss = st.radio('Weight Loss', ('No', 'Yes'))
    smoke_lweek = st.radio('Smoke Lweek', ('No', 'Yes'))
    fever = st.radio('Fever', ('No', 'Yes'))
    night_sweats = st.radio('Night Sweats', ('No', 'Yes'))
    return (age, height, weight, reported_cough_dur, heart_rate, temperature,
            sex, tb_prior, tb_prior_Pul, tb_prior_Extrapul, tb_prior_Unknown,
            hemoptysis, weight_loss, smoke_lweek, fever, night_sweats)


def main():
    """Render the TB cough-analysis page and run a prediction on request.

    Step 1 obtains a cough sound (recorded or uploaded), step 2 collects clinical
    answers, and the "Predict" button feeds both through the ONNX model, showing
    the first output value and a TB / no-TB message using a 0.5 threshold.
    """
    st.title('TB Cough Sound Analysis')
    # Sidebar selector between the two ways of providing a cough sound.
    tabs = ["Record Cough Sound", "Upload Cough Sound"]
    choice = st.sidebar.selectbox("Choose Option", tabs)
    audio_bytes = _capture_cough_audio(choice)
    clinical_inputs = _collect_clinical_inputs()

    if not st.button('Predict'):
        return
    if audio_bytes is None:
        # The shared audio_file.wav is deliberately not read here: it may not exist,
        # or it may hold audio from another session or the other input mode.
        st.warning("Please record or upload a cough sound before predicting.")
        return
    try:
        # NOTE(review): the sample rate is not checked or resampled; in-app
        # recordings are captured at 22050 Hz but uploads may differ.
        raw_values, _rate = sf.read(io.BytesIO(audio_bytes))
        audio_data = preprocess_audio(raw_values)
    except (RuntimeError, ValueError) as err:
        # soundfile raises a RuntimeError subclass for unreadable audio; numpy
        # padding/reshaping in preprocessing raises ValueError on unusable clips.
        st.error(f"Could not process the cough sound: {err}")
        return
    clinical_data = preprocess_clinical_data(*clinical_inputs)

    # NOTE(review): assumes the model's first input is the audio tensor and its
    # second the clinical features — confirm against the exported ONNX graph.
    model_inputs = model_inference.get_inputs()
    label_name = model_inference.get_outputs()[0].name
    onnx_pred = model_inference.run(
        [label_name],
        {model_inputs[0].name: audio_data, model_inputs[1].name: clinical_data},
    )
    result = onnx_pred[0]
    st.write(f"**Prediction:** {result[0][0]}")
    if result[0][0] >= 0.5:
        st.write("Tuberculosis (TB) is found based on the audio data and clinical information.")
    else:
        st.write("No Tuberculosis (TB) is found based on the audio data and clinical information.")


if __name__ == "__main__":
    main()