# Author: HareemFatima
# Last commit: 1bb5dbb "Update app.py" (verified)
import streamlit as st
import torch
from transformers import (
    AutoModelForAudioClassification,
    AutoModelForTextToWaveform,
    AutoProcessor,
    AutoTokenizer,
    pipeline,
)
# Model identifiers. NOTE(review): the stutter model id has no "<user>/"
# namespace, so it presumably resolves to a local directory shipped with the
# app — confirm, or prefix it with the Hub namespace.
_STUTTER_MODEL_ID = "distilhubert-finetuned-stutterdetection"
_TTS_MODEL_ID = "facebook/mms-tts-eng"


@st.cache_resource
def _load_models():
    """Load the stutter classifier and the text-to-speech components once.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps the loaded weights alive across reruns instead
    of reloading them from disk on each click.

    Returns:
        Tuple ``(processor, model, classifier, tts_tokenizer, tts_model)``.
        ``classifier`` is an ``audio-classification`` pipeline built from
        ``model`` and ``processor``.
    """
    clf_processor = AutoProcessor.from_pretrained(_STUTTER_MODEL_ID)
    clf_model = AutoModelForAudioClassification.from_pretrained(_STUTTER_MODEL_ID)
    # The pipeline wraps decoding/feature extraction/softmax so callers can
    # pass raw audio and get back [{"label": ..., "score": ...}, ...].
    classifier = pipeline(
        "audio-classification",
        model=clf_model,
        feature_extractor=clf_processor,
    )
    speech_tokenizer = AutoTokenizer.from_pretrained(_TTS_MODEL_ID)
    speech_model = AutoModelForTextToWaveform.from_pretrained(_TTS_MODEL_ID)
    return clf_processor, clf_model, classifier, speech_tokenizer, speech_model


# Module-level names kept for compatibility; ``audio_classification_model`` is
# the callable that ``classify_and_speak`` uses and was previously undefined.
processor, model, audio_classification_model, tts_tokenizer, tts_model = _load_models()
# Define a function to classify audio and generate speech
def classify_and_speak(audio_input):
    """Classify a speech recording, display the label, and play spoken advice.

    Args:
        audio_input: The audio to classify: raw ``bytes`` of an audio file, a
            file path, or an object with ``getvalue()`` (e.g. a Streamlit
            ``UploadedFile``), which is converted to bytes first. ``None``
            (nothing recorded yet) only shows a warning.
    """
    if audio_input is None:
        st.warning("Please record or upload audio first.")
        return
    # Streamlit widgets return file-like UploadedFile objects; the transformers
    # audio pipeline accepts raw bytes (decoded via ffmpeg), not file objects.
    if hasattr(audio_input, "getvalue"):
        audio_input = audio_input.getvalue()

    # The pipeline returns label/score dicts ordered by descending score, so
    # the first entry is the top prediction.
    classification_result = audio_classification_model(audio_input)
    if not classification_result:
        st.error("The classifier returned no prediction.")
        return
    predicted_class = classification_result[0]["label"]

    # Map predicted class to corresponding speech text, then synthesize it.
    speech_text = map_class_to_speech(predicted_class)
    waveform, sampling_rate = _synthesize_speech(speech_text)

    # Display classification result and play speech
    st.write("Predicted Stutter Type:", predicted_class)
    st.audio(waveform, sample_rate=sampling_rate)


def _synthesize_speech(text):
    """Return ``(waveform, sampling_rate)`` for *text* from the TTS model."""
    inputs = tts_tokenizer(text, return_tensors="pt")
    # facebook/mms-tts-eng is a VITS model: audio comes from a plain forward
    # pass (``.waveform``), not from ``generate``, which the original called.
    with torch.no_grad():
        waveform = tts_model(**inputs).waveform
    # Drop the batch dimension and give Streamlit a NumPy array; st.audio
    # needs the sample rate explicitly when given raw samples.
    return waveform.squeeze().cpu().numpy(), tts_model.config.sampling_rate
# Advice sentence spoken back to the user for each classifier label.
_SPEECH_BY_LABEL = {
    "nonstutter": "You are speaking fluently without any stutter.",
    "prolongation": "You are experiencing prolongation stutter. Try to relax and speak slowly.",
    "repetition": "You are experiencing repetition stutter. Focus on your breathing and try to speak smoothly.",
    "blocks": "You are experiencing block stutter. Take a deep breath and try to speak slowly and calmly.",
}
_UNKNOWN_LABEL_SPEECH = "Unknown stutter type"


def map_class_to_speech(predicted_class):
    """Return the sentence to speak for the label ``predicted_class``.

    Labels are matched exactly (case-sensitive); any label not in the table
    yields ``"Unknown stutter type"``.
    """
    return _SPEECH_BY_LABEL.get(predicted_class, _UNKNOWN_LABEL_SPEECH)
# Streamlit app
def main():
    """Render the app: capture a clip, then classify it and speak advice."""
    st.title("Stutter Classification and Therapy App")
    # ``st.audio`` only plays audio; it has no recording options, so the
    # original ``start_recording``/``channels`` call raised TypeError. Use the
    # microphone widget when this Streamlit version has it, else a WAV upload.
    # Both return an UploadedFile, or None until the user provides audio.
    record_widget = getattr(st, "audio_input", None)
    if record_widget is not None:
        audio_input = record_widget("Capture Audio")
    else:
        audio_input = st.file_uploader("Capture Audio", type=["wav"])

    if audio_input is None:
        st.info("Record or upload a WAV clip to get started.")
        return

    # The widget handles start/stop of recording itself; this button just
    # triggers classification of the captured clip.
    if st.button("Classify Recording"):
        with st.spinner("Classifying and speaking..."):
            # Pass raw bytes: the transformers audio pipeline decodes bytes,
            # not Streamlit's file-like UploadedFile.
            classify_and_speak(audio_input.getvalue())
if __name__ == "__main__":
    # Start the app only when this file is run as a script, not when imported.
    main()