"""Streamlit app that classifies stuttering in recorded speech and speaks feedback.

A fine-tuned DistilHuBERT audio-classification checkpoint predicts the stutter
type of the captured clip, and the MMS English text-to-speech checkpoint reads
back a short feedback sentence for that type.
"""

import streamlit as st
import torch
from transformers import (
    AutoModelForTextToWaveform,
    AutoProcessor,
    AutoTokenizer,
    pipeline,
)

CLASSIFIER_CHECKPOINT = "HareemFatima/distilhubert-finetuned-stutterdetection"
TTS_CHECKPOINT = "facebook/mms-tts-eng"


@st.cache_resource
def load_classifier():
    """Return the stutter-detection audio-classification pipeline.

    Cached with ``st.cache_resource`` so the weights are loaded once per server
    process instead of on every Streamlit rerun of this script.
    """
    return pipeline("audio-classification", model=CLASSIFIER_CHECKPOINT)


@st.cache_resource
def load_tts():
    """Return ``(tokenizer, model)`` for the MMS English text-to-speech checkpoint."""
    tokenizer = AutoTokenizer.from_pretrained(TTS_CHECKPOINT)
    tts_model = AutoModelForTextToWaveform.from_pretrained(TTS_CHECKPOINT)
    return tokenizer, tts_model


def map_class_to_speech(predicted_class):
    """Return the spoken feedback sentence for a predicted stutter class.

    Args:
        predicted_class: Label produced by the classifier. Matching ignores
            case and surrounding whitespace.

    Returns:
        The feedback sentence, or ``"Unknown stutter type"`` when the label is
        not in the table.
    """
    # NOTE(review): these keys are assumed to match the checkpoint's id2label
    # names; confirm against the model's config.json.
    speech_texts = {
        "nonstutter": "You are speaking fluently without any stutter.",
        "prolongation": "You are experiencing prolongation stutter. Try to relax and speak slowly.",
        "repetition": "You are experiencing repetition stutter. Focus on your breathing and try to speak smoothly.",
        "blocks": "You are experiencing block stutter. Take a deep breath and try to speak slowly and calmly.",
    }
    key = str(predicted_class).strip().lower()
    return speech_texts.get(key, "Unknown stutter type")


def synthesize_speech(text):
    """Synthesize *text* with the MMS TTS model.

    Returns:
        ``(audio, sample_rate)`` where ``audio`` is a 1-D NumPy array of samples
        and ``sample_rate`` comes from the model config.
    """
    tokenizer, tts_model = load_tts()
    inputs = tokenizer(text, return_tensors="pt")
    # The documented usage for the MMS-TTS (VITS) checkpoints is a plain
    # forward pass whose output carries ``.waveform``; ``generate`` is not
    # the inference entry point for this model.
    with torch.no_grad():
        waveform = tts_model(**inputs).waveform
    # waveform is (batch, samples); we synthesized a single sentence.
    audio = waveform[0].cpu().numpy()
    return audio, tts_model.config.sampling_rate


def classify_and_speak(audio_input):
    """Classify a captured clip, show the predicted stutter type, and play feedback.

    Args:
        audio_input: The captured audio: an uploaded-file object (anything with
            ``getvalue()``), raw encoded audio ``bytes``, or a file path. Raw
            bytes are decoded by the transformers pipeline via ffmpeg, so ffmpeg
            must be installed.
    """
    if audio_input is None:
        st.warning("No audio captured yet.")
        return
    # Streamlit widgets return an UploadedFile; the pipeline wants bytes/path.
    if hasattr(audio_input, "getvalue"):
        audio_input = audio_input.getvalue()

    classification_result = load_classifier()(audio_input)
    if not classification_result:
        st.error("The classifier returned no prediction.")
        return
    # The pipeline returns its top-k labels ordered by descending score.
    predicted_class = classification_result[0]["label"]

    speech_text = map_class_to_speech(predicted_class)
    audio, sample_rate = synthesize_speech(speech_text)

    st.write("Predicted Stutter Type:", predicted_class)
    # A raw NumPy waveform needs an explicit sample rate to be playable.
    st.audio(audio, sample_rate=sample_rate)


def main():
    """Render the app: capture audio, then classify it and speak feedback on request."""
    st.title("Stutter Classification and Therapy App")

    # st.audio only *plays* audio; it cannot record. Newer Streamlit releases
    # provide st.audio_input for microphone capture; older ones fall back to
    # an upload widget.
    audio_recorder = getattr(st, "audio_input", None)
    if audio_recorder is not None:
        audio_input = audio_recorder("Capture Audio")
    else:
        audio_input = st.file_uploader(
            "Capture Audio", type=["wav", "mp3", "flac", "ogg", "m4a"]
        )

    # Only run the (slow) models once audio exists and the user asks for it.
    if audio_input is not None and st.button("Classify and Speak"):
        with st.spinner("Classifying and speaking..."):
            classify_and_speak(audio_input)


if __name__ == "__main__":
    main()