HareemFatima's picture
Update app.py
d239c53 verified
raw
history blame
No virus
1.58 kB
import streamlit as st
import sounddevice as sd
import soundfile as sf
from transformers import pipeline
# Build the Hugging Face audio-classification pipeline once at import time so
# every classification request reuses the same loaded model.
model = pipeline(
    task="audio-classification",
    model="HareemFatima/distilhubert-finetuned-stutterdetection",
)
# Define a function to map predicted labels to types of stuttering
def map_label_to_stutter_type(label):
    """Translate a predicted class label into a stutter-type name.

    Args:
        label: The predicted class. Accepted forms are
            * an integer class id (``0``..``3``),
            * a string holding that id, optionally prefixed with ``LABEL_``
              (e.g. ``"2"`` or ``"LABEL_2"``),
            * one of the stutter-type names themselves (case-insensitive).

    Returns:
        ``"nonstutter"``, ``"prolongation"``, ``"repetition"`` or ``"blocks"``;
        ``"Unknown"`` for anything that cannot be resolved.
    """
    # Position in this tuple is the class id.
    stutter_types = ("nonstutter", "prolongation", "repetition", "blocks")

    if isinstance(label, str):
        # String labels need normalising first: comparing them directly with
        # the integer ids below could never match.
        text = label.strip().lower()
        if text in stutter_types:
            return text
        if text.startswith("label_"):
            text = text[len("label_"):]
        if not text.isdecimal():
            return "Unknown"
        label = int(text)

    # Plain equality keeps the original semantics for ints (and anything else
    # that compares equal to a class id).
    for class_id, stutter_type in enumerate(stutter_types):
        if label == class_id:
            return stutter_type
    return "Unknown"
# Run the classifier on an audio input and report the stutter type
def classify_audio(audio_input):
    """Classify *audio_input* and return the matching stutter-type name.

    The input is handed unchanged to the module-level ``model`` pipeline. The
    ``"label"`` of the first prediction it returns is then translated by
    ``map_label_to_stutter_type``.
    """
    # NOTE(review): the first entry is presumably the highest-scoring
    # prediction — confirm against the pipeline's output ordering.
    first_prediction = model(audio_input)[0]
    return map_label_to_stutter_type(first_prediction["label"])
# Streamlit app
def main():
    """Render the UI: record from the microphone, then classify the clip.

    Pressing the button records a mono clip of the chosen length through
    ``sounddevice``, saves it as ``recording.wav``, plays it back, and shows
    the stutter type returned by ``classify_audio``. If the audio device
    cannot be used, an error message is displayed instead.
    """
    # NOTE(review): 16 kHz is assumed to be the sampling rate the model
    # expects — confirm against the model's feature extractor.
    sample_rate = 16000
    recording_path = "recording.wav"

    st.title("Stutter Classification App")
    duration = st.slider(
        "Recording length (seconds)", min_value=1, max_value=30, value=5
    )

    # Streamlit reruns the script on every interaction; do nothing until the
    # user actually asks for a recording.
    if not st.button("Record and Classify"):
        return

    try:
        with st.spinner("Recording..."):
            recording = sd.rec(
                int(duration * sample_rate),
                samplerate=sample_rate,
                channels=1,
                dtype="float32",
            )
            # sd.rec returns immediately; block until the buffer is filled.
            sd.wait()
    except sd.PortAudioError as err:
        st.error(f"Could not record audio: {err}")
        return

    # Persist the clip so it can be played back and re-read for classification.
    sf.write(recording_path, recording, sample_rate)
    st.audio(recording_path, format="audio/wav")

    with st.spinner("Classifying..."):
        # Read the recorded audio file
        audio_data, _ = sf.read(recording_path, dtype="float32")
        # Classify the audio
        stutter_type = classify_audio(audio_data)
    st.write("Predicted Stutter Type:", stutter_type)
# Start the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()