HareemFatima commited on
Commit
18e0da2
1 Parent(s): ec29075

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -33
app.py CHANGED
@@ -1,47 +1,32 @@
1
  import streamlit as st
2
- import sounddevice as sd
3
- import soundfile as sf
4
  from transformers import pipeline
5
 
6
  # Load the model pipeline
7
  model = pipeline("audio-classification", model="HareemFatima/distilhubert-finetuned-stutterdetection")
8
 
9
- # Define a function to map predicted labels to types of stuttering
10
- def map_label_to_stutter_type(label):
11
- if label == 0:
12
- return "nonstutter"
13
- elif label == 1:
14
- return "prolongation"
15
- elif label == 2:
16
- return "repetition"
17
- elif label == 3:
18
- return "blocks"
19
- else:
20
- return "Unknown"
21
-
22
- # Function to classify audio input and return the stutter type
23
- def classify_audio(audio_input):
24
- # Call your model pipeline to classify the audio
25
- prediction = model(audio_input)
26
- # Get the predicted label
27
- predicted_label = prediction[0]["label"]
28
- # Map the label to the corresponding stutter type
29
- stutter_type = map_label_to_stutter_type(predicted_label)
30
- return stutter_type
31
-
32
  # Streamlit app
33
  def main():
34
  st.title("Stutter Classification App")
35
  audio_input = st.audio("Capture Audio", format="audio/wav", start_recording=True, channels=1)
36
  if st.button("Stop Recording"):
37
- sd.stop()
38
- with st.spinner("Classifying..."):
39
- # Read the recorded audio file
40
- recording_path = "recording.wav"
41
- audio_data, sampling_rate = sf.read(recording_path)
42
- # Classify the audio
43
- stutter_type = classify_audio(audio_data)
44
- st.write("Predicted Stutter Type:", stutter_type)
 
 
 
 
 
 
 
 
 
 
45
 
46
  if __name__ == "__main__":
47
  main()
 
1
  import streamlit as st
 
 
2
  from transformers import pipeline
3
 
4
  # Load the model pipeline
5
  model = pipeline("audio-classification", model="HareemFatima/distilhubert-finetuned-stutterdetection")
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  # Streamlit app
8
  def main():
9
  st.title("Stutter Classification App")
10
  audio_input = st.audio("Capture Audio", format="audio/wav", start_recording=True, channels=1)
11
  if st.button("Stop Recording"):
12
+ # Assuming the recording is saved as "recording.wav"
13
+ recording_path = "recording.wav"
14
+ # Call the model pipeline to classify the audio
15
+ prediction = model(recording_path)
16
+ # Get the predicted label
17
+ predicted_label = prediction[0]["label"]
18
+ # Map the label to the corresponding stutter type
19
+ if predicted_label == 0:
20
+ stutter_type = "nonstutter"
21
+ elif predicted_label == 1:
22
+ stutter_type = "prolongation"
23
+ elif predicted_label == 2:
24
+ stutter_type = "repetition"
25
+ elif predicted_label == 3:
26
+ stutter_type = "blocks"
27
+ else:
28
+ stutter_type = "Unknown"
29
+ st.write("Predicted Stutter Type:", stutter_type)
30
 
31
  if __name__ == "__main__":
32
  main()