abdullahmubeen10 committed
Commit 7978b65 · verified · 1 Parent(s): e6f3c3d

Update Demo.py

Files changed (1)
Demo.py +132 -132
Demo.py CHANGED
@@ -1,133 +1,133 @@
- import streamlit as st
- import sparknlp
- import os
- import pandas as pd
- import librosa
-
- from sparknlp.base import *
- from sparknlp.common import *
- from sparknlp.annotator import *
- from pyspark.ml import Pipeline
- from sparknlp.pretrained import PretrainedPipeline
- from pyspark.sql.types import *
- import pyspark.sql.functions as F
-
- # Page configuration
- st.set_page_config(
-     layout="wide",
-     initial_sidebar_state="auto"
- )
-
- # Custom CSS for styling
- st.markdown("""
-     <style>
-         .main-title {
-             font-size: 36px;
-             color: #4A90E2;
-             font-weight: bold;
-             text-align: center;
-         }
-         .section {
-             background-color: #f9f9f9;
-             padding: 10px;
-             border-radius: 10px;
-             margin-top: 10px;
-         }
-         .section p, .section ul {
-             color: #666666;
-         }
-     </style>
- """, unsafe_allow_html=True)
-
- @st.cache_resource
- def init_spark():
-     """Initialize Spark NLP."""
-     return sparknlp.start()
-
- @st.cache_resource
- def create_pipeline(model):
-     """Create a Spark NLP pipeline for audio processing."""
-     audio_assembler = AudioAssembler() \
-         .setInputCol("audio_content") \
-         .setOutputCol("audio_assembler")
-
-     speech_to_text = WhisperForCTC \
-         .pretrained(model) \
-         .setInputCols("audio_assembler") \
-         .setOutputCol("text")
-
-     pipeline = Pipeline(stages=[
-         audio_assembler,
-         speech_to_text
-     ])
-     return pipeline
-
- def fit_data(pipeline, fed_data):
-     """Fit the data into the pipeline and return the transcription."""
-     data, sampling_rate = librosa.load(fed_data, sr=16000)
-     data = data.tolist()
-     spark_df = spark.createDataFrame([[data]], ["audio_content"])
-
-     model = pipeline.fit(spark_df)
-     lp = LightPipeline(model)
-     lp_result = lp.fullAnnotate(data)[0]
-     return lp_result
-
- def save_uploadedfile(uploadedfile, path):
-     """Save the uploaded file to the specified path."""
-     filepath = os.path.join(path, uploadedfile.name)
-     with open(filepath, "wb") as f:
-         if hasattr(uploadedfile, 'getbuffer'):
-             f.write(uploadedfile.getbuffer())
-         else:
-             f.write(uploadedfile.read())
-
- # Sidebar content
- model_list = ["asr_whisper_small_english"]
- model = st.sidebar.selectbox(
-     "Choose the pretrained model",
-     model_list,
-     help="For more info about the models visit: https://sparknlp.org/models"
- )
-
- # Main content
- st.markdown('<div class="main-title">Speech Recognition With WhisperForCTC</div>', unsafe_allow_html=True)
- st.markdown('<div class="section"><p>This demo transcribes audio files into texts using the <code>WhisperForCTC</code> Annotator and advanced speech recognition models.</p></div>', unsafe_allow_html=True)
-
- # Reference notebook link in sidebar
- st.sidebar.markdown('Reference notebook:')
- st.sidebar.markdown("""
-     <a href="https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/audio/whisper/Automatic_Speech_Recognition_Whisper_(WhisperForCTC).ipynb">
-         <img src="https://colab.research.google.com/assets/colab-badge.svg" style="zoom: 1.3" alt="Open In Colab"/>
-     </a>
- """, unsafe_allow_html=True)
-
- # Load examples
- AUDIO_FILE_PATH = "inputs"
- audio_files = sorted(os.listdir(AUDIO_FILE_PATH))
-
- selected_audio = st.selectbox("Select an audio", audio_files)
-
- # Creating a simplified Python list of audio file types
- audio_file_types = ["mp3", "flac", "wav", "aac", "ogg", "aiff", "wma", "m4a", "ape", "dsf", "dff", "midi", "mid", "opus", "amr"]
- uploadedfile = st.file_uploader("Try it for yourself!", type=audio_file_types)
-
- if uploadedfile:
-     selected_audio = f"{AUDIO_FILE_PATH}/{uploadedfile.name}"
-     save_uploadedfile(uploadedfile, AUDIO_FILE_PATH)
- elif selected_audio:
-     selected_audio = f"{AUDIO_FILE_PATH}/{selected_audio}"
-
- # Audio playback and transcription
- st.subheader("Play Audio")
-
- with open(selected_audio, 'rb') as audio_file:
-     audio_bytes = audio_file.read()
-     st.audio(audio_bytes)
-
- spark = init_spark()
- pipeline = create_pipeline(model)
- output = fit_data(pipeline, selected_audio)
-
- st.subheader(f"Transcription:")
+ import streamlit as st
+ import sparknlp
+ import os
+ import pandas as pd
+ import librosa
+
+ from sparknlp.base import *
+ from sparknlp.common import *
+ from sparknlp.annotator import *
+ from pyspark.ml import Pipeline
+ from sparknlp.pretrained import PretrainedPipeline
+ from pyspark.sql.types import *
+ import pyspark.sql.functions as F
+
+ # Page configuration
+ st.set_page_config(
+     layout="wide",
+     initial_sidebar_state="auto"
+ )
+
+ # Custom CSS for styling
+ st.markdown("""
+     <style>
+         .main-title {
+             font-size: 36px;
+             color: #4A90E2;
+             font-weight: bold;
+             text-align: center;
+         }
+         .section {
+             background-color: #f9f9f9;
+             padding: 10px;
+             border-radius: 10px;
+             margin-top: 10px;
+         }
+         .section p, .section ul {
+             color: #666666;
+         }
+     </style>
+ """, unsafe_allow_html=True)
+
+ @st.cache_resource
+ def init_spark():
+     """Initialize Spark NLP."""
+     return sparknlp.start()
+
+ @st.cache_resource
+ def create_pipeline(model):
+     """Create a Spark NLP pipeline for audio processing."""
+     audioAssembler = AudioAssembler() \
+         .setInputCol("audio_content") \
+         .setOutputCol("audio_assembler")
+
+
+     speechToText = WhisperForCTC.pretrained("asr_whisper_small_english", "en") \
+         .setInputCols(["audio_assembler"]) \
+         .setOutputCol("text")
+
+     pipeline = Pipeline(stages=[
+         audioAssembler,
+         speechToText
+     ])
+     return pipeline
+
+ def fit_data(pipeline, fed_data):
+     """Fit the data into the pipeline and return the transcription."""
+     data, sampling_rate = librosa.load(fed_data, sr=16000)
+     data = data.tolist()
+     spark_df = spark.createDataFrame([[data]], ["audio_content"])
+
+     model = pipeline.fit(spark_df)
+     lp = LightPipeline(model)
+     lp_result = lp.fullAnnotate(data)[0]
+     return lp_result
+
+ def save_uploadedfile(uploadedfile, path):
+     """Save the uploaded file to the specified path."""
+     filepath = os.path.join(path, uploadedfile.name)
+     with open(filepath, "wb") as f:
+         if hasattr(uploadedfile, 'getbuffer'):
+             f.write(uploadedfile.getbuffer())
+         else:
+             f.write(uploadedfile.read())
+
+ # Sidebar content
+ model_list = ["asr_whisper_small_english"]
+ model = st.sidebar.selectbox(
+     "Choose the pretrained model",
+     model_list,
+     help="For more info about the models visit: https://sparknlp.org/models"
+ )
+
+ # Main content
+ st.markdown('<div class="main-title">Speech Recognition With WhisperForCTC</div>', unsafe_allow_html=True)
+ st.markdown('<div class="section"><p>This demo transcribes audio files into texts using the <code>WhisperForCTC</code> Annotator and advanced speech recognition models.</p></div>', unsafe_allow_html=True)
+
+ # Reference notebook link in sidebar
+ st.sidebar.markdown('Reference notebook:')
+ st.sidebar.markdown("""
+     <a href="https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/audio/whisper/Automatic_Speech_Recognition_Whisper_(WhisperForCTC).ipynb">
+         <img src="https://colab.research.google.com/assets/colab-badge.svg" style="zoom: 1.3" alt="Open In Colab"/>
+     </a>
+ """, unsafe_allow_html=True)
+
+ # Load examples
+ AUDIO_FILE_PATH = "inputs"
+ audio_files = sorted(os.listdir(AUDIO_FILE_PATH))
+
+ selected_audio = st.selectbox("Select an audio", audio_files)
+
+ # Creating a simplified Python list of audio file types
+ audio_file_types = ["mp3", "flac", "wav", "aac", "ogg", "aiff", "wma", "m4a", "ape", "dsf", "dff", "midi", "mid", "opus", "amr"]
+ uploadedfile = st.file_uploader("Try it for yourself!", type=audio_file_types)
+
+ if uploadedfile:
+     selected_audio = f"{AUDIO_FILE_PATH}/{uploadedfile.name}"
+     save_uploadedfile(uploadedfile, AUDIO_FILE_PATH)
+ elif selected_audio:
+     selected_audio = f"{AUDIO_FILE_PATH}/{selected_audio}"
+
+ # Audio playback and transcription
+ st.subheader("Play Audio")
+
+ with open(selected_audio, 'rb') as audio_file:
+     audio_bytes = audio_file.read()
+     st.audio(audio_bytes)
+
+ spark = init_spark()
+ pipeline = create_pipeline(model)
+ output = fit_data(pipeline, selected_audio)
+
+ st.subheader(f"Transcription:")
  st.markdown(f"{(output['text'][0].result).title()}")
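
Although the diff counts +132 -132 (every line touched, most plausibly whitespace differences), the substantive change is confined to create_pipeline: the stage variables were renamed (audio_assembler to audioAssembler, speech_to_text to speechToText) and the WhisperForCTC stage now pins the checkpoint and language via pretrained("asr_whisper_small_english", "en") instead of forwarding the sidebar selection through pretrained(model). The model parameter is still accepted but no longer read, which is harmless here since model_list contains only that one entry. A minimal standalone sketch of the updated stanza, assuming Spark NLP is installed and the pretrained weights can be downloaded:

import sparknlp
from sparknlp.base import AudioAssembler
from sparknlp.annotator import WhisperForCTC
from pyspark.ml import Pipeline

spark = sparknlp.start()

# Assemble raw audio samples from the "audio_content" column.
audioAssembler = AudioAssembler() \
    .setInputCol("audio_content") \
    .setOutputCol("audio_assembler")

# Pinned checkpoint name plus explicit "en" language tag, as in this commit.
speechToText = WhisperForCTC.pretrained("asr_whisper_small_english", "en") \
    .setInputCols(["audio_assembler"]) \
    .setOutputCol("text")

pipeline = Pipeline(stages=[audioAssembler, speechToText])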