lol040604lol committed
Commit 40e49d9 · verified · 1 Parent(s): 5471b51

Update app.py

Files changed (1):
  app.py  +36 -263
app.py CHANGED
@@ -8,14 +8,11 @@ import whisper
 import torch
 import requests
 from datetime import datetime
-from sklearn.linear_model import LinearRegression
 import numpy as np
 from dotenv import load_dotenv
 import os
-import librosa
-import tempfile
 import soundfile as sf
-from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
+import tempfile
 
 # Load .env file
 load_dotenv()
@@ -27,7 +24,7 @@ model_path = os.getenv('MODEL_PATH')
 # Load product and objection data
 @st.cache_resource
 def load_data():
-    product_data = pd.read_csv("product_data.csv")  # Use relative path
+    product_data = pd.read_csv("product_data.csv")
     objections_data = pd.read_csv("objections_data.csv")
     return product_data, objections_data
 
@@ -39,26 +36,12 @@ objections = objections_data['objection'].tolist()
 responses = objections_data['response'].tolist()
 
 # Initialize models
-
-RTC_CONFIG = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
-
 def initialize_models():
-    """
-    Initializes the SentenceTransformer and Whisper models.
-
-    Returns:
-        model: SentenceTransformer instance for embeddings.
-        whisper_model: Whisper model instance for audio transcription.
-    """
     try:
-        # Check for GPU availability
         device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f"Using device: {device}")
 
-        # Load SentenceTransformer model
         model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
-
-        # Load Whisper model
        whisper_model = whisper.load_model("base", device=device)
 
         print("Models successfully initialized.")
@@ -67,7 +50,6 @@ def initialize_models():
         print(f"Error initializing models: {e}")
         raise
 
-# Initialize and store models
 model, whisper_model = initialize_models()
 
 # Create embeddings and FAISS indices
@@ -85,256 +67,47 @@ def create_indices():
     return product_index, objection_index
 
 product_index, objection_index = create_indices()
-# Configuration for WebRTC
-RTC_CONFIG = RTCConfiguration({"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]})
-
-from io import BytesIO
-import soundfile as sf
 
-def process_audio(audio_data, sample_rate=16000):
+# Process recorded audio file
+def process_audio_file(uploaded_file):
     try:
-        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
-            sf.write(temp_audio_file.name, audio_data, samplerate=sample_rate)
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
+            temp_audio.write(uploaded_file.read())
+            temp_audio_path = temp_audio.name
 
-        # Transcribe audio using Whisper
-        transcription = whisper_model.transcribe(temp_audio_file.name)["text"]
-
-        os.unlink(temp_audio_file.name)  # Cleanup temporary file
+        transcription = whisper_model.transcribe(temp_audio_path)["text"]
+        os.unlink(temp_audio_path)  # Cleanup temporary file
         return transcription
     except Exception as e:
-        st.error(f"Error processing audio: {e}")
+        st.error(f"Error processing audio file: {e}")
         return None
 
-
-class AudioProcessor:
-    def __init__(self, whisper_model):
-        self.whisper_model = whisper_model
-        self.transcription = ""
-
-    def process_audio(self, audio_data):
-        try:
-            # Convert raw audio bytes to a WAV file
-            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
-                sf.write(temp_audio_file.name, audio_data, samplerate=16000)
-
-            # Transcribe audio using Whisper
-            transcription = self.whisper_model.transcribe(temp_audio_file.name)["text"]
-            os.unlink(temp_audio_file.name)  # Clean up the temporary file
-            return transcription
-        except Exception as e:
-            st.error(f"Error processing audio: {e}")
-            return None
-audio_processor = AudioProcessor(whisper_model)
-
-import av
-from streamlit_webrtc import webrtc_streamer, WebRtcMode
-
-def audio_callback(frame: av.AudioFrame):
-    audio = frame.to_ndarray()
-    transcription = process_audio(audio)
-    if transcription:
-        st.session_state.transcription = transcription
-
-webrtc_streamer(
-    key="speech-to-text",
-    mode=WebRtcMode.SENDRECV,
-    rtc_configuration=RTC_CONFIG,
-    media_stream_constraints={"audio": True, "video": False},
-    async_processing=True,
-    audio_processor_factory=audio_callback
-)
-
-
-# Initialize audio stream
-# Hugging Face API for sentiment analysis
-API_URL = "https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english"
-API_KEY = api_key
-headers = {"Authorization": f"Bearer {API_KEY}"}
-
-def analyze_sentiment(text):
-    payload = {"inputs": text}
-    response = requests.post(API_URL, headers=headers, json=payload)
-    if response.status_code == 200:
-        result = response.json()
-        sentiments = result[0]
-        if len(sentiments) > 0:
-            best_sentiment = max(sentiments, key=lambda x: x['score'])
-            return best_sentiment
-        else:
-            return {"label": "ERROR", "score": 0.0}
-    return {"label": "ERROR", "score": 0.0}
-
-
-def recommend_products(query):
-    query_embedding = model.encode([query])
-    distances, indices = product_index.search(query_embedding, 3)
-    return [(product_titles[i], product_descriptions[i]) for i in indices[0]]
-
-def handle_objection(query):
-    query_embedding = model.encode([query])
-    distances, indices = objection_index.search(query_embedding, 1)
-    idx = indices[0][0]
-    return objections[idx], responses[idx]
-
-# Function to save session data to a JSON file
-def save_session_data(session_data):
-    with open("session_data.json", "w") as f:
-        json.dump(session_data, f)
-
 # Streamlit UI
-st.title("Real-Time Product Recommendation & Sentiment Analysis")
-
-session_data = {
-    "interactions": [],
-    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-}
-
-if st.button("Stop Listening"):
-    st.info("Processing session data...")
-
-    # Load session data from the JSON file
-    with open("session_data.json", "r") as f:
-        session_data = json.load(f)
-
-    # Import and analyze data using dashboard.py
-    from dashboard import analyze_data  # Ensure dashboard.py is in the same directory
-    analysis_results = analyze_data(session_data)
-
-    # Initialize variables for the summary
-    summary_data = []
-    all_recommendations = []
-    objection_summary = []
-
-    # Process each interaction in the session data
-    for i, interaction in enumerate(session_data["interactions"]):
-        # Extract relevant data
-        customer_transcription = interaction['transcription']
-        sentiment_label = interaction['sentiment']['label']
-        product_recommendations = [rec[1] for rec in interaction['product_recommendations']]
-        objection = interaction.get('objection_handling', None)
-
-        # Build the narrative for each interaction
-        if i == 0:
-            # Start of the call
-            summary_data.append(f"When the call started, the customer was {sentiment_label}. ")
-        else:
-            # Progression of the conversation
-            summary_data.append(f"Then, the customer's tone shifted to {sentiment_label}. ")
-
-        # Add product recommendations to the summary
-        summary_data.append(f"We provided recommendations: {', '.join(product_recommendations)}. ")
-
-        # Add objection handling if applicable
-        #if objection:
-        #    summary_data.append(f"Objection: {objection['objection']}. Response: {objection['response']}. ")
-
-        # Collect all recommendations and objections for analysis
-        #all_recommendations.extend(product_recommendations)
-        #if objection:
-        #    objection_summary.append(f"Objection: {objection['objection']}, Response: {objection['response']}")
-    # Combine the summary data into one long narrative
-    narrative_summary = " ".join(summary_data)
-    overall_sentiment = (
-        "Overall sentiment trends are depicted in the Call Summary Table and Sentiment Trends graph. "
-        "Explore Sentiment Predictions below to anticipate the customer's future interests."
-    )
-
-    # Debug: Verify narrative summary
-    print("Narrative Summary:", narrative_summary)
-    print("All Recommendations:", all_recommendations)
-    print("Objection Summary:", objection_summary)
+st.title("Recorded Audio Product Recommendation & Sentiment Analysis")
 
-    # Load BART model and tokenizer for summarization
-    from transformers import BartForConditionalGeneration, BartTokenizer
+uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "m4a"])
 
-    def load_bart_model():
-        model_name = "facebook/bart-large-cnn"  # Pre-trained BART model for summarization
-        model = BartForConditionalGeneration.from_pretrained(model_name)
-        tokenizer = BartTokenizer.from_pretrained(model_name)
-        return model, tokenizer
-
-    model, tokenizer = load_bart_model()
-
-    # Function to generate summary using BART
-    def generate_summary(text, model, tokenizer):
-        inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=1024, truncation=True)
-        summary_ids = model.generate(inputs, max_length=200, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
-        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-        return summary
-
-    # Generate the post-call summary
-    call_summary = generate_summary(narrative_summary, model, tokenizer)
-
-    # Display the post-call summary
-    st.subheader("Post-Call Summary")
-    st.write(f"Session Timestamp: {analysis_results['timestamp']}")
-    st.write("Debug Narrative Summary:", narrative_summary)
-    # Display the summary table
-
-    st.subheader("Call Summary Table")
-    st.dataframe(analysis_results["summary_table"])
-
-    # Display the sentiment trends chart
-    st.subheader("Sentiment Trends")
-    st.pyplot(analysis_results["sentiment_chart"])
-
-    # Display product recommendation trends
-    st.subheader("Top Product Recommendations")
-    st.pyplot(analysis_results["recommendation_chart"])
-
-    # Display sentiment predictions (if available)
-    st.subheader("Sentiment Predictions")
-    sentiment_predictions = analysis_results["sentiment_predictions"]
-    if isinstance(sentiment_predictions, str):  # Handle insufficient data case
-        st.write(sentiment_predictions)
-    else:
-        st.line_chart(sentiment_predictions)
-
-    # Display the word cloud for transcription topics
-    st.subheader("Transcription Word Cloud")
-    st.pyplot(analysis_results["wordcloud"])
-
-    # Display actionable recommendations
-    st.subheader("Actionable Recommendations")
-    for recommendation in analysis_results["actionable_recommendations"]:
-        st.write(f"- {recommendation}")
-
-if "session_data" not in st.session_state:
-    st.session_state["session_data"] = {"interactions": [], "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
-
-session_data = st.session_state["session_data"]
-
-import streamlit as st
-
-import tempfile
-import os
-import soundfile as sf
-import json
-
-import tempfile
-import os
-import soundfile as sf
-import json
-
-if "listening" not in st.session_state:
-    st.session_state.listening = False
-if "transcription" not in st.session_state:
-    st.session_state.transcription = ""
-
-if st.button("Start Listening"):
-    st.session_state.listening = True
-
-if st.session_state.transcription:
-    st.subheader("Transcribed Text:")
-    st.write(st.session_state.transcription)
-
-if st.session_state.listening:
-    st.info("Listening... Speak into the microphone.")
-    webrtc_streamer(
-        key="example",
-        mode=WebRtcMode.SENDRECV,
-        rtc_configuration=RTC_CONFIG,
-        media_stream_constraints={"audio": True, "video": False},
-        audio_frame_callback=audio_callback,
-        async_processing=True,
-    )
+if uploaded_file is not None:
+    st.audio(uploaded_file, format='audio/wav')
+    transcription = process_audio_file(uploaded_file)
+    if transcription:
+        st.subheader("Transcription:")
+        st.write(transcription)
+
+        sentiment_result = requests.post(
+            "https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english",
+            headers={"Authorization": f"Bearer {api_key}"},
+            json={"inputs": transcription}
+        ).json()
+
+        if sentiment_result:
+            st.subheader("Sentiment Analysis:")
+            st.write(sentiment_result[0])
+
+        recommendations = model.encode([transcription])
+        distances, indices = product_index.search(recommendations, 3)
+        recommended_products = [(product_titles[i], product_descriptions[i]) for i in indices[0]]
+
+        st.subheader("Product Recommendations:")
+        for title, desc in recommended_products:
+            st.write(f"**{title}**: {desc}")