MonkeyDLLLLLLuffy committed on
Commit
a26b24a
·
verified ·
1 Parent(s): a33bb2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -46
app.py CHANGED
@@ -1,22 +1,25 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
4
  import torchaudio
5
  import os
6
  import re
7
- from difflib import SequenceMatcher
8
  import numpy as np
9
 
 
 
 
 
10
  # Device setup
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
- # Load Whisper model with adjusted parameters for better memory handling
14
  MODEL_NAME = "alvanlii/whisper-small-cantonese"
15
  language = "zh"
16
- pipe = pipeline(
17
  task="automatic-speech-recognition",
18
  model=MODEL_NAME,
19
- chunk_length_s=30, # Reduce chunk size for better memory handling
20
  device=device,
21
  generate_kwargs={
22
  "no_repeat_ngram_size": 3,
@@ -24,14 +27,20 @@ pipe = pipeline(
24
  "temperature": 0.7,
25
  "top_p": 0.97,
26
  "top_k": 40,
27
- "max_new_tokens": 400, # Reduced from 500 to avoid exceeding 448
28
- "do_sample": True # Required for `top_p` and `top_k` to take effect
29
  }
30
  )
31
- pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
 
 
32
 
33
- # Similarity check to remove repeated phrases
34
  def remove_repeated_phrases(text):
 
 
 
 
35
  sentences = re.split(r'(?<=[。!?])', text)
36
  cleaned_sentences = []
37
  for sentence in sentences:
@@ -39,53 +48,68 @@ def remove_repeated_phrases(text):
39
  cleaned_sentences.append(sentence.strip())
40
  return " ".join(cleaned_sentences)
41
 
 
42
  def remove_punctuation(text):
43
  return re.sub(r'[^\w\s]', '', text)
44
 
 
45
  def transcribe_audio(audio_path):
46
  waveform, sample_rate = torchaudio.load(audio_path)
47
 
48
- # Convert stereo to mono (if needed)
49
- if waveform.shape[0] > 1: # More than 1 channel
50
- waveform = torch.mean(waveform, dim=0, keepdim=True) # Average the channels
51
-
52
- waveform = waveform.squeeze(0).numpy() # Convert to NumPy (1D array)
53
 
 
54
  duration = waveform.shape[0] / sample_rate
 
 
55
  if duration > 60:
56
- chunk_size = sample_rate * 55 # 55 seconds
57
- step_size = sample_rate * 50 # 50-second step, so consecutive 55-second chunks overlap by 5 seconds
58
  results = []
59
-
60
  for start in range(0, waveform.shape[0], step_size):
61
  chunk = waveform[start:start + chunk_size]
62
  if chunk.shape[0] == 0:
63
  break
64
- transcript = pipe({"sampling_rate": sample_rate, "raw": chunk})["text"]
65
  results.append(remove_punctuation(transcript))
66
-
67
  return remove_punctuation(remove_repeated_phrases(" ".join(results)))
 
 
 
 
 
 
 
 
 
 
68
 
69
- return remove_punctuation(remove_repeated_phrases(pipe({"sampling_rate": sample_rate, "raw": waveform})["text"]))
70
-
71
- # Sentiment analysis model
72
- sentiment_pipe = pipeline("text-classification", model="MonkeyDLLLLLLuffy/CustomModel-multilingual-sentiment-analysis-enhanced", device=device)
73
-
74
- # Rate sentiment with batch processing
75
  def rate_quality(text):
76
  chunks = [text[i:i+512] for i in range(0, len(text), 512)]
77
  results = sentiment_pipe(chunks, batch_size=4)
78
 
79
- label_map = {"Very Negative": "Very Poor", "Negative": "Poor", "Neutral": "Neutral", "Positive": "Good", "Very Positive": "Very Good"}
 
 
 
 
 
 
80
  processed_results = [label_map.get(res["label"], "Unknown") for res in results]
81
 
 
82
  return max(set(processed_results), key=processed_results.count)
83
 
84
- # Streamlit main interface
 
 
85
  def main():
86
  st.set_page_config(page_title="Customer Service Analyzer", page_icon="🎙️")
87
 
88
- # Business-oriented CSS styling
89
  st.markdown("""
90
  <style>
91
  .header {
@@ -107,28 +131,67 @@ def main():
107
  </div>
108
  """, unsafe_allow_html=True)
109
 
110
- uploaded_file = st.file_uploader("📤 Please upload your Cantonese customer service audio file", type=["wav", "mp3", "flac"])
 
 
 
 
 
 
111
 
112
- if uploaded_file is not None:
113
- temp_audio_path = "uploaded_audio.wav"
114
- with open(temp_audio_path, "wb") as f:
115
- f.write(uploaded_file.getbuffer())
 
116
 
 
 
117
  st.audio(uploaded_file, format="audio/wav")
118
 
119
- with st.spinner('🔄 Processing your audio, please wait...'):
120
- transcript = transcribe_audio(temp_audio_path)
121
- quality_rating = rate_quality(transcript)
122
-
123
- st.write("**Transcript:**", transcript)
124
- st.write("**Sentiment Analysis Result:**", quality_rating)
125
-
126
- result_text = f"Transcript:\n{transcript}\n\nSentiment Analysis Result: {quality_rating}"
127
- st.download_button(label="📥 Download Analysis Report", data=result_text, file_name="analysis_report.txt")
128
-
129
- st.markdown("❓If you encounter any issues, please contact customer support: 📧 **example@hellotoby.com**")
130
-
131
- os.remove(temp_audio_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  if __name__ == "__main__":
134
  main()
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import pipeline
4
  import torchaudio
5
  import os
6
  import re
 
7
  import numpy as np
8
 
9
+ # -----------------------------
10
+ # 1) Model loading and utility functions
11
+ # -----------------------------
12
+
13
  # Device setup
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ # Load Whisper model for Cantonese ASR
17
  MODEL_NAME = "alvanlii/whisper-small-cantonese"
18
  language = "zh"
19
+ asr_pipe = pipeline(
20
  task="automatic-speech-recognition",
21
  model=MODEL_NAME,
22
+ chunk_length_s=30, # Adjust chunk size for memory handling
23
  device=device,
24
  generate_kwargs={
25
  "no_repeat_ngram_size": 3,
 
27
  "temperature": 0.7,
28
  "top_p": 0.97,
29
  "top_k": 40,
30
+ "max_new_tokens": 400,
31
+ "do_sample": True
32
  }
33
  )
34
+ asr_pipe.model.config.forced_decoder_ids = asr_pipe.tokenizer.get_decoder_prompt_ids(
35
+ language=language, task="transcribe"
36
+ )
37
 
38
+ # Remove repeated sentences that are highly similar
39
  def remove_repeated_phrases(text):
40
+ def is_similar(a, b):
41
+ from difflib import SequenceMatcher
42
+ return SequenceMatcher(None, a, b).ratio() > 0.9
43
+
44
  sentences = re.split(r'(?<=[。!?])', text)
45
  cleaned_sentences = []
46
  for sentence in sentences:
 
48
  cleaned_sentences.append(sentence.strip())
49
  return " ".join(cleaned_sentences)
50
 
51
+ # Remove punctuation from text
52
  def remove_punctuation(text):
53
  return re.sub(r'[^\w\s]', '', text)
54
 
55
+ # Transcribe the audio using Whisper
56
  def transcribe_audio(audio_path):
57
  waveform, sample_rate = torchaudio.load(audio_path)
58
 
59
+ # Convert multi-channel audio to mono if necessary
60
+ if waveform.shape[0] > 1:
61
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
 
 
62
 
63
+ waveform = waveform.squeeze(0).numpy()
64
  duration = waveform.shape[0] / sample_rate
65
+
66
+ # For audio longer than 60 seconds, process in overlapping chunks
67
  if duration > 60:
68
+ chunk_size = sample_rate * 55
69
+ step_size = sample_rate * 50
70
  results = []
 
71
  for start in range(0, waveform.shape[0], step_size):
72
  chunk = waveform[start:start + chunk_size]
73
  if chunk.shape[0] == 0:
74
  break
75
+ transcript = asr_pipe({"sampling_rate": sample_rate, "raw": chunk})["text"]
76
  results.append(remove_punctuation(transcript))
 
77
  return remove_punctuation(remove_repeated_phrases(" ".join(results)))
78
+ else:
79
+ transcript = asr_pipe({"sampling_rate": sample_rate, "raw": waveform})["text"]
80
+ return remove_punctuation(remove_repeated_phrases(transcript))
81
+
82
+ # Load sentiment analysis model
83
+ sentiment_pipe = pipeline(
84
+ "text-classification",
85
+ model="MonkeyDLLLLLLuffy/CustomModel-multilingual-sentiment-analysis-enhanced",
86
+ device=device
87
+ )
88
 
89
+ # Perform sentiment analysis in chunks (max 512 characters each)
 
 
 
 
 
90
  def rate_quality(text):
91
  chunks = [text[i:i+512] for i in range(0, len(text), 512)]
92
  results = sentiment_pipe(chunks, batch_size=4)
93
 
94
+ label_map = {
95
+ "Very Negative": "Very Poor",
96
+ "Negative": "Poor",
97
+ "Neutral": "Neutral",
98
+ "Positive": "Good",
99
+ "Very Positive": "Very Good"
100
+ }
101
  processed_results = [label_map.get(res["label"], "Unknown") for res in results]
102
 
103
+ # Use majority voting to determine the final sentiment
104
  return max(set(processed_results), key=processed_results.count)
105
 
106
+ # -----------------------------
107
+ # 2) Main Streamlit application
108
+ # -----------------------------
109
  def main():
110
  st.set_page_config(page_title="Customer Service Analyzer", page_icon="🎙️")
111
 
112
+ # Custom CSS styling
113
  st.markdown("""
114
  <style>
115
  .header {
 
131
  </div>
132
  """, unsafe_allow_html=True)
133
 
134
+ # Initialize session state to store results
135
+ if "transcript" not in st.session_state:
136
+ st.session_state["transcript"] = ""
137
+ if "quality_rating" not in st.session_state:
138
+ st.session_state["quality_rating"] = ""
139
+ if "uploaded_filename" not in st.session_state:
140
+ st.session_state["uploaded_filename"] = ""
141
 
142
+ # File uploader
143
+ uploaded_file = st.file_uploader(
144
+ "📤 Please upload your Cantonese customer service audio file",
145
+ type=["wav", "mp3", "flac"]
146
+ )
147
 
148
+ if uploaded_file is not None:
149
+ # Display audio player
150
  st.audio(uploaded_file, format="audio/wav")
151
 
152
+ # Only run the model again if a new file is uploaded
153
+ if st.session_state["uploaded_filename"] != uploaded_file.name:
154
+ st.session_state["uploaded_filename"] = uploaded_file.name
155
+
156
+ # Save uploaded file to a temporary path
157
+ temp_audio_path = "uploaded_audio.wav"
158
+ with open(temp_audio_path, "wb") as f:
159
+ f.write(uploaded_file.getbuffer())
160
+
161
+ # Process the audio
162
+ with st.spinner('🔄 Processing your audio, please wait...'):
163
+ transcript = transcribe_audio(temp_audio_path)
164
+ quality_rating = rate_quality(transcript)
165
+
166
+ # Store results in session state
167
+ st.session_state["transcript"] = transcript
168
+ st.session_state["quality_rating"] = quality_rating
169
+
170
+ # Remove the temporary file
171
+ if os.path.exists(temp_audio_path):
172
+ os.remove(temp_audio_path)
173
+
174
+ # Display results if available
175
+ if st.session_state["transcript"]:
176
+ st.write("**Transcript:**", st.session_state["transcript"])
177
+ st.write("**Sentiment Analysis Result:**", st.session_state["quality_rating"])
178
+
179
+ # Prepare download content
180
+ result_text = (
181
+ f"Transcript:\n{st.session_state['transcript']}\n\n"
182
+ f"Sentiment Analysis Result: {st.session_state['quality_rating']}"
183
+ )
184
+ # Download button for the analysis report
185
+ st.download_button(
186
+ label="📥 Download Analysis Report",
187
+ data=result_text,
188
+ file_name="analysis_report.txt"
189
+ )
190
+
191
+ st.markdown(
192
+ "❓If you encounter any issues, please contact customer support: "
193
+ "📧 **example@hellotoby.com**"
194
+ )
195
 
196
  if __name__ == "__main__":
197
  main()