"""Gradio app that classifies a stutter type and speaks a matching therapy tip."""

import gradio as gr
import numpy as np
from transformers import AutoModelForTextToWaveform, AutoProcessor, pipeline

# Load audio classification model
audio_classifier = pipeline(
    "audio-classification",
    model="HareemFatima/distilhubert-finetuned-stutterdetection",
)

# Load text-to-speech model (the processor is loaded once and shared)
tts_processor = AutoProcessor.from_pretrained("suno/bark-small")
tts_model = AutoModelForTextToWaveform.from_pretrained("suno/bark-small")

# Kept as an alias so the module-level name `processor` still exists.
processor = tts_processor

# Define therapy text for different stutter types
# (replace with your specific therapy content)
therapy_text = {
    "Normal Speech": "Your speech sounds great! Keep practicing!",
    "Blocking": "Take a deep breath and try speaking slowly. You can do it!",
    "Prolongation": "Focus on relaxing your mouth muscles and speaking smoothly.",
    # Add more stutter types and therapy text here
}

# Fallback tip used when the predicted label has no entry in `therapy_text`.
DEFAULT_THERAPY = "General therapy tip: Practice slow, relaxed speech."


def _to_pipeline_input(audio):
    """Convert a Gradio ``(sample_rate, samples)`` tuple to pipeline input.

    Anything that is not a 2-tuple (file path, bytes, dict, bare array) is
    returned unchanged so that callers passing pipeline-native inputs keep
    working.
    """
    if not (isinstance(audio, tuple) and len(audio) == 2):
        return audio

    sample_rate, samples = audio
    samples = np.asarray(samples)

    # Integer PCM is rescaled to floats in [-1, 1] before classification.
    if np.issubdtype(samples.dtype, np.integer):
        samples = samples.astype(np.float32) / np.iinfo(samples.dtype).max
    else:
        samples = samples.astype(np.float32)

    # The classification pipeline expects single-channel audio; multi-channel
    # recordings are assumed to be laid out as (samples, channels).
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    # NOTE(review): the pipeline resamples dict inputs itself when the rate
    # differs from the model's; that path presumably needs torchaudio - confirm.
    return {"raw": samples, "sampling_rate": int(sample_rate)}


def predict_and_synthesize(audio):
    """Predict the stutter type and synthesize speech with therapy text.

    Args:
        audio: Audio from the user. Either the ``(sample_rate, samples)``
            tuple produced by the Gradio microphone input, or any input the
            audio-classification pipeline accepts directly. ``None`` means
            nothing was recorded.

    Returns:
        tuple: ``(stutter_type, speech)`` where ``stutter_type`` is the
        top predicted label and ``speech`` is a ``(sample_rate, waveform)``
        tuple suitable for a Gradio audio output. When no audio was
        supplied, returns a short message and ``None`` instead.
    """
    if audio is None:
        return "No audio received. Please record a clip and try again.", None

    # Classify stuttering type using audio classification model
    prediction = audio_classifier(_to_pipeline_input(audio))
    stutter_type = prediction[0]["label"]

    # Retrieve therapy text based on predicted stutter type
    therapy = therapy_text.get(stutter_type, DEFAULT_THERAPY)

    # Generate synthesized speech with the therapy text. The whole processor
    # output is forwarded so the attention mask reaches `generate` as well.
    tts_inputs = tts_processor(therapy, return_tensors="pt")
    waveform = tts_model.generate(**tts_inputs)[0].squeeze().cpu().numpy()

    # Gradio's audio output needs the sample rate alongside the samples.
    sample_rate = tts_model.generation_config.sample_rate
    return stutter_type, (sample_rate, waveform)


# Create Gradio interface
interface = gr.Interface(
    fn=predict_and_synthesize,
    inputs="microphone",
    outputs=["text", "audio"],
    title="Stuttering Therapy Assistant",
    description=(
        "This app helps you identify stuttering types and provides "
        "personalized therapy suggestions. Upload an audio clip, and it "
        "will analyze the speech and generate audio with relevant therapy tips."
    ),
)

interface.launch(debug=False)