# Spaces: Runtime error
import io

import sounddevice as sd
import soundfile as sf
import streamlit as st
from transformers import pipeline
MODEL_ID = "HareemFatima/distilhubert-finetuned-stutterdetection"


@st.cache_resource
def _load_model():
    """Build the audio-classification pipeline once and reuse it.

    Streamlit re-executes the whole script on every widget interaction;
    ``st.cache_resource`` keeps the pipeline object across those reruns so
    the model is not re-initialised each time.
    """
    return pipeline("audio-classification", model=MODEL_ID)


# Load the model pipeline (cached across Streamlit reruns).
model = _load_model()
# Stutter types indexed by the model's class id.
STUTTER_TYPES = ("nonstutter", "prolongation", "repetition", "blocks")


def map_label_to_stutter_type(label):
    """Map a predicted class label to a human-readable stutter type.

    The transformers ``audio-classification`` pipeline reports each label as
    a string taken from the model config's ``id2label`` (``"LABEL_<n>"`` when
    the config defines no custom names), so comparing it to plain ints never
    matches. This function accepts:

    * ints (or anything ``==``-comparable to an int, as before),
    * numeric strings such as ``"2"``,
    * ``"LABEL_<n>"`` strings (case-insensitive prefix),
    * names that already are one of the known stutter types.

    Args:
        label: Class id or label string produced by the pipeline.

    Returns:
        One of "nonstutter", "prolongation", "repetition", "blocks", or
        "Unknown" when the label cannot be mapped.
    """
    if isinstance(label, str):
        text = label.strip()
        if text.lower() in STUTTER_TYPES:
            return text.lower()
        if text.upper().startswith("LABEL_"):
            text = text[len("LABEL_"):]
        try:
            label = int(text)
        except ValueError:
            return "Unknown"
    # Compare with == (not indexing) so non-str values behave exactly as
    # they did with the original if/elif chain (e.g. 1.0 -> "prolongation").
    for index, stutter_type in enumerate(STUTTER_TYPES):
        if label == index:
            return stutter_type
    return "Unknown"
# Function to classify audio input and return the stutter type
def classify_audio(audio_input, sampling_rate=None):
    """Classify an audio clip and return its stutter type.

    Args:
        audio_input: Anything the transformers ``audio-classification``
            pipeline accepts: a file path, raw bytes, a
            ``{"raw": array, "sampling_rate": int}`` dict, or a waveform
            array. Multi-channel arrays are assumed to be laid out as
            (frames, channels), as returned by ``soundfile.read``, and are
            averaged down to mono.
        sampling_rate: Sampling rate of ``audio_input`` when it is a bare
            array. Without it the pipeline assumes the array is already at
            the model's sampling rate, which silently skews predictions for
            audio recorded at any other rate.

    Returns:
        The stutter type from ``map_label_to_stutter_type``, or "Unknown"
        if the pipeline returns no predictions.
    """
    waveform = audio_input
    # The pipeline only accepts single-channel arrays; average channels.
    if getattr(waveform, "ndim", 1) > 1:
        waveform = waveform.mean(axis=1)
    # Attach the real sampling rate so the pipeline can resample if needed.
    if sampling_rate is not None and not isinstance(waveform, (str, bytes, dict)):
        waveform = {"raw": waveform, "sampling_rate": sampling_rate}
    prediction = model(waveform)
    if not prediction:
        return "Unknown"
    # The pipeline returns predictions ordered by score, highest first.
    return map_label_to_stutter_type(prediction[0]["label"])
# Streamlit app
def main():
    """Render the app: record (or upload) a clip, then classify it."""
    st.title("Stutter Classification App")

    # st.audio is playback-only; browser recording needs st.audio_input
    # (st.experimental_audio_input on some releases). Older Streamlit
    # versions have neither, so fall back to uploading a WAV file.
    recorder = getattr(st, "audio_input", None) or getattr(
        st, "experimental_audio_input", None
    )
    if recorder is not None:
        audio_file = recorder("Capture Audio")
    else:
        audio_file = st.file_uploader("Upload a WAV file", type=["wav"])

    if audio_file is None:
        st.info("Record or upload an audio clip to classify it.")
        return

    # Let the user replay exactly what will be classified.
    st.audio(audio_file, format="audio/wav")

    with st.spinner("Classifying..."):
        try:
            audio_data, sampling_rate = sf.read(io.BytesIO(audio_file.getvalue()))
        except RuntimeError as err:  # soundfile's decode errors derive from RuntimeError
            st.error(f"Could not read the audio: {err}")
            return
        # soundfile returns (frames, channels) for multi-channel audio;
        # the model expects a single channel.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        stutter_type = classify_audio(
            {"raw": audio_data, "sampling_rate": sampling_rate}
        )
    st.write("Predicted Stutter Type:", stutter_type)


if __name__ == "__main__":
    main()