awacke1 commited on
Commit
2e6063c
β€’
1 Parent(s): 7ec5b58

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +471 -0
app.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sentence_transformers import SentenceTransformer
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ import torch
7
+ import json
8
+ import os
9
+ import glob
10
+ from pathlib import Path
11
+ from datetime import datetime, timedelta
12
+ import edge_tts
13
+ import asyncio
14
+ import requests
15
+ from collections import defaultdict
16
+ import streamlit.components.v1 as components
17
+ from urllib.parse import quote
18
+ from xml.etree import ElementTree as ET
19
+ from datasets import load_dataset
20
+ import base64
21
+ import re
22
+
23
+ # 🧠 Initialize session state variables
24
+ SESSION_VARS = {
25
+ 'search_history': [], # Track search history
26
+ 'last_voice_input': "", # Last voice input
27
+ 'transcript_history': [], # Conversation history
28
+ 'should_rerun': False, # Trigger for UI updates
29
+ 'search_columns': [], # Available search columns
30
+ 'initial_search_done': False, # First search flag
31
+ 'tts_voice': "en-US-AriaNeural", # Default voice
32
+ 'arxiv_last_query': "", # Last ArXiv search
33
+ 'dataset_loaded': False, # Dataset load status
34
+ 'current_page': 0, # Current data page
35
+ 'data_cache': None, # Data cache
36
+ 'dataset_info': None, # Dataset metadata
37
+ 'nps_submitted': False, # Track if user submitted NPS
38
+ 'nps_last_shown': None, # When NPS was last shown
39
+ 'old_val': None, # Previous voice input value
40
+ 'voice_text': None # Processed voice text
41
+ }
42
+
43
+ # Constants
44
+ ROWS_PER_PAGE = 100
45
+ MIN_SEARCH_SCORE = 0.3
46
+ EXACT_MATCH_BOOST = 2.0
47
+
48
+ # Initialize session state
49
+ for var, default in SESSION_VARS.items():
50
+ if var not in st.session_state:
51
+ st.session_state[var] = default
52
+
53
+ # Voice Component Setup
54
+ def create_voice_component():
55
+ """Create the voice input component"""
56
+ mycomponent = components.declare_component(
57
+ "mycomponent",
58
+ path="mycomponent"
59
+ )
60
+ return mycomponent
61
+
62
+ # Utility Functions
63
+ def clean_for_speech(text: str) -> str:
64
+ """Clean text for speech synthesis"""
65
+ text = text.replace("\n", " ")
66
+ text = text.replace("</s>", " ")
67
+ text = text.replace("#", "")
68
+ text = re.sub(r"\(https?:\/\/[^\)]+\)", "", text)
69
+ text = re.sub(r"\s+", " ", text).strip()
70
+ return text
71
+
72
+ async def edge_tts_generate_audio(text, voice="en-US-AriaNeural", rate=0, pitch=0):
73
+ """Generate audio using Edge TTS"""
74
+ text = clean_for_speech(text)
75
+ if not text.strip():
76
+ return None
77
+ rate_str = f"{rate:+d}%"
78
+ pitch_str = f"{pitch:+d}Hz"
79
+ communicate = edge_tts.Communicate(text, voice, rate=rate_str, pitch=pitch_str)
80
+ out_fn = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
81
+ await communicate.save(out_fn)
82
+ return out_fn
83
+
84
+ def speak_with_edge_tts(text, voice="en-US-AriaNeural", rate=0, pitch=0):
85
+ """Wrapper for edge TTS generation"""
86
+ return asyncio.run(edge_tts_generate_audio(text, voice, rate, pitch))
87
+
88
+ def play_and_download_audio(file_path):
89
+ """Play and provide download link for audio"""
90
+ if file_path and os.path.exists(file_path):
91
+ st.audio(file_path)
92
+ dl_link = f'<a href="data:audio/mpeg;base64,{base64.b64encode(open(file_path,"rb").read()).decode()}" download="{os.path.basename(file_path)}">Download {os.path.basename(file_path)}</a>'
93
+ st.markdown(dl_link, unsafe_allow_html=True)
94
+
95
+ @st.cache_resource
96
+ def get_model():
97
+ """Get sentence transformer model"""
98
+ return SentenceTransformer('all-MiniLM-L6-v2')
99
+
100
+ @st.cache_data
101
+ def load_dataset_page(dataset_id, token, page, rows_per_page):
102
+ """Load dataset page with caching"""
103
+ try:
104
+ start_idx = page * rows_per_page
105
+ end_idx = start_idx + rows_per_page
106
+ dataset = load_dataset(
107
+ dataset_id,
108
+ token=token,
109
+ streaming=False,
110
+ split=f'train[{start_idx}:{end_idx}]'
111
+ )
112
+ return pd.DataFrame(dataset)
113
+ except Exception as e:
114
+ st.error(f"Error loading page {page}: {str(e)}")
115
+ return pd.DataFrame()
116
+
117
+ @st.cache_data
118
+ def get_dataset_info(dataset_id, token):
119
+ """Get dataset info with caching"""
120
+ try:
121
+ dataset = load_dataset(dataset_id, token=token, streaming=True)
122
+ return dataset['train'].info
123
+ except Exception as e:
124
+ st.error(f"Error loading dataset info: {str(e)}")
125
+ return None
126
+
127
+ def fetch_dataset_info(dataset_id):
128
+ """Fetch dataset information"""
129
+ info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
130
+ try:
131
+ response = requests.get(info_url, timeout=30)
132
+ if response.status_code == 200:
133
+ return response.json()
134
+ except Exception as e:
135
+ st.warning(f"Error fetching dataset info: {e}")
136
+ return None
137
+
138
+ def generate_filename(text):
139
+ """Generate unique filename from text"""
140
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
141
+ safe_text = re.sub(r'[^\w\s-]', '', text[:50]).strip().lower()
142
+ safe_text = re.sub(r'[-\s]+', '-', safe_text)
143
+ return f"{timestamp}_{safe_text}"
144
+
145
+ def render_result(result):
146
+ """Render a single search result"""
147
+ score = result.get('relevance_score', 0)
148
+ result_filtered = {k: v for k, v in result.items()
149
+ if k not in ['relevance_score', 'video_embed', 'description_embed', 'audio_embed']}
150
+
151
+ if 'youtube_id' in result:
152
+ st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")
153
+
154
+ cols = st.columns([2, 1])
155
+ with cols[0]:
156
+ text_content = []
157
+ for key, value in result_filtered.items():
158
+ if isinstance(value, (str, int, float)):
159
+ st.write(f"**{key}:** {value}")
160
+ if isinstance(value, str) and len(value.strip()) > 0:
161
+ text_content.append(f"{key}: {value}")
162
+
163
+ with cols[1]:
164
+ st.metric("Relevance", f"{score:.2%}")
165
+
166
+ voices = {
167
+ "Aria (US Female)": "en-US-AriaNeural",
168
+ "Guy (US Male)": "en-US-GuyNeural",
169
+ "Sonia (UK Female)": "en-GB-SoniaNeural",
170
+ "Tony (UK Male)": "en-GB-TonyNeural"
171
+ }
172
+
173
+ selected_voice = st.selectbox(
174
+ "Voice:",
175
+ list(voices.keys()),
176
+ key=f"voice_{result.get('video_id', '')}"
177
+ )
178
+
179
+ if st.button("πŸ”Š Read", key=f"read_{result.get('video_id', '')}"):
180
+ text_to_read = ". ".join(text_content)
181
+ audio_file = speak_with_edge_tts(text_to_read, voices[selected_voice])
182
+ if audio_file:
183
+ play_and_download_audio(audio_file)
184
+
185
+ class FastDatasetSearcher:
186
+ """Fast dataset search with semantic and token matching"""
187
+
188
+ def __init__(self, dataset_id="tomg-group-umd/cinepile"):
189
+ self.dataset_id = dataset_id
190
+ self.text_model = get_model()
191
+ self.token = os.environ.get('DATASET_KEY')
192
+ if not self.token:
193
+ st.error("Please set the DATASET_KEY environment variable")
194
+ st.stop()
195
+
196
+ if st.session_state['dataset_info'] is None:
197
+ st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)
198
+
199
+ def load_page(self, page=0):
200
+ """Load a specific page of data"""
201
+ return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)
202
+
203
+ def quick_search(self, query, df):
204
+ """Perform quick search with semantic similarity"""
205
+ if df.empty or not query.strip():
206
+ return df
207
+
208
+ try:
209
+ searchable_cols = []
210
+ for col in df.columns:
211
+ sample_val = df[col].iloc[0]
212
+ if not isinstance(sample_val, (np.ndarray, bytes)):
213
+ searchable_cols.append(col)
214
+
215
+ query_lower = query.lower()
216
+ query_terms = set(query_lower.split())
217
+ query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]
218
+
219
+ scores = []
220
+ matched_any = []
221
+
222
+ for _, row in df.iterrows():
223
+ text_parts = []
224
+ row_matched = False
225
+ exact_match = False
226
+
227
+ priority_fields = ['description', 'matched_text']
228
+ other_fields = [col for col in searchable_cols if col not in priority_fields]
229
+
230
+ for col in priority_fields:
231
+ if col in row:
232
+ val = row[col]
233
+ if val is not None:
234
+ val_str = str(val).lower()
235
+ if query_lower in val_str.split():
236
+ exact_match = True
237
+ if any(term in val_str.split() for term in query_terms):
238
+ row_matched = True
239
+ text_parts.append(str(val))
240
+
241
+ for col in other_fields:
242
+ val = row[col]
243
+ if val is not None:
244
+ val_str = str(val).lower()
245
+ if query_lower in val_str.split():
246
+ exact_match = True
247
+ if any(term in val_str.split() for term in query_terms):
248
+ row_matched = True
249
+ text_parts.append(str(val))
250
+
251
+ text = ' '.join(text_parts)
252
+
253
+ if text.strip():
254
+ text_tokens = set(text.lower().split())
255
+ matching_terms = query_terms.intersection(text_tokens)
256
+ keyword_score = len(matching_terms) / len(query_terms)
257
+
258
+ text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
259
+ semantic_score = float(cosine_similarity([query_embedding], [text_embedding])[0][0])
260
+
261
+ combined_score = 0.7 * keyword_score + 0.3 * semantic_score
262
+
263
+ if exact_match:
264
+ combined_score *= EXACT_MATCH_BOOST
265
+ elif row_matched:
266
+ combined_score *= 1.2
267
+ else:
268
+ combined_score = 0.0
269
+ row_matched = False
270
+
271
+ scores.append(combined_score)
272
+ matched_any.append(row_matched)
273
+
274
+ results_df = df.copy()
275
+ results_df['score'] = scores
276
+ results_df['matched'] = matched_any
277
+
278
+ filtered_df = results_df[
279
+ (results_df['matched']) |
280
+ (results_df['score'] > MIN_SEARCH_SCORE)
281
+ ]
282
+
283
+ return filtered_df.sort_values('score', ascending=False)
284
+
285
+ except Exception as e:
286
+ st.error(f"Search error: {str(e)}")
287
+ return df
288
+
289
+ def main():
290
+ st.title("πŸŽ₯ Smart Video & Voice Search")
291
+
292
+ # Initialize components
293
+ voice_component = create_voice_component()
294
+ search = FastDatasetSearcher()
295
+
296
+ # Voice input at top level
297
+ voice_val = voice_component(my_input_value="Start speaking...")
298
+
299
+ # Show voice input if detected
300
+ if voice_val:
301
+ voice_text = str(voice_val).strip()
302
+ edited_input = st.text_area("✏️ Edit Voice Input:", value=voice_text, height=100)
303
+
304
+ run_option = st.selectbox("Select Search Type:",
305
+ ["Quick Search", "Deep Search", "Voice Summary"])
306
+
307
+ col1, col2 = st.columns(2)
308
+ with col1:
309
+ autorun = st.checkbox("⚑ Auto-Run", value=False)
310
+ with col2:
311
+ full_audio = st.checkbox("πŸ”Š Full Audio", value=False)
312
+
313
+ input_changed = (voice_text != st.session_state.get('old_val'))
314
+
315
+ if autorun and input_changed:
316
+ st.session_state['old_val'] = voice_text
317
+ with st.spinner("Processing voice input..."):
318
+ if run_option == "Quick Search":
319
+ results = search.quick_search(edited_input, search.load_page())
320
+ for i, result in enumerate(results.iterrows(), 1):
321
+ with st.expander(f"Result {i}", expanded=(i==1)):
322
+ render_result(result[1])
323
+
324
+ elif run_option == "Deep Search":
325
+ with st.spinner("Performing deep search..."):
326
+ results = []
327
+ for page in range(3): # Search first 3 pages
328
+ df = search.load_page(page)
329
+ results.extend(search.quick_search(edited_input, df).iterrows())
330
+
331
+ for i, result in enumerate(results, 1):
332
+ with st.expander(f"Result {i}", expanded=(i==1)):
333
+ render_result(result[1])
334
+
335
+ elif run_option == "Voice Summary":
336
+ audio_file = speak_with_edge_tts(edited_input)
337
+ if audio_file:
338
+ play_and_download_audio(audio_file)
339
+
340
+ elif st.button("πŸ” Search", key="voice_input_search"):
341
+ st.session_state['old_val'] = voice_text
342
+ with st.spinner("Processing..."):
343
+ results = search.quick_search(edited_input, search.load_page())
344
+ for i, result in enumerate(results.iterrows(), 1):
345
+ with st.expander(f"Result {i}", expanded=(i==1)):
346
+ render_result(result[1])
347
+
348
+ # Create main tabs
349
+ tab1, tab2, tab3, tab4 = st.tabs([
350
+ "πŸ” Search", "πŸŽ™οΈ Voice", "πŸ’Ύ History", "βš™οΈ Settings"
351
+ ])
352
+
353
+ with tab1:
354
+ st.subheader("πŸ” Search")
355
+ col1, col2 = st.columns([3, 1])
356
+ with col1:
357
+ query = st.text_input("Enter search query:",
358
+ value="" if st.session_state['initial_search_done'] else "")
359
+ with col2:
360
+ search_column = st.selectbox("Search in:",
361
+ ["All Fields"] + st.session_state['search_columns'])
362
+
363
+ col3, col4 = st.columns(2)
364
+ with col3:
365
+ num_results = st.slider("Max results:", 1, 100, 20)
366
+ with col4:
367
+ search_button = st.button("πŸ” Search", key="main_search_button")
368
+
369
+ if (search_button or not st.session_state['initial_search_done']) and query:
370
+ st.session_state['initial_search_done'] = True
371
+ selected_column = None if search_column == "All Fields" else search_column
372
+
373
+ with st.spinner("Searching..."):
374
+ df = search.load_page()
375
+ results = search.quick_search(query, df)
376
+
377
+ if len(results) > 0:
378
+ st.session_state['search_history'].append({
379
+ 'query': query,
380
+ 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
381
+ 'results': results[:5]
382
+ })
383
+
384
+ st.write(f"Found {len(results)} results:")
385
+ for i, (_, result) in enumerate(results.iterrows(), 1):
386
+ if i > num_results:
387
+ break
388
+ with st.expander(f"Result {i}", expanded=(i==1)):
389
+ render_result(result)
390
+ else:
391
+ st.warning("No matching results found.")
392
+
393
+ with tab2:
394
+ st.subheader("πŸŽ™οΈ Voice Input")
395
+ st.write("Use the voice input above to start speaking, or record a new message:")
396
+
397
+ col1, col2 = st.columns(2)
398
+ with col1:
399
+ if st.button("πŸŽ™οΈ Start New Recording", key="start_recording_button"):
400
+ st.session_state['recording'] = True
401
+ st.experimental_rerun()
402
+ with col2:
403
+ if st.button("πŸ›‘ Stop Recording", key="stop_recording_button"):
404
+ st.session_state['recording'] = False
405
+ st.experimental_rerun()
406
+
407
+ if st.session_state.get('recording', False):
408
+ voice_component = create_voice_component()
409
+ new_val = voice_component(my_input_value="Recording...")
410
+ if new_val:
411
+ st.text_area("Recorded Text:", value=new_val, height=100)
412
+ if st.button("πŸ” Search with Recording", key="recording_search_button"):
413
+ with st.spinner("Processing recording..."):
414
+ df = search.load_page()
415
+ results = search.quick_search(new_val, df)
416
+ for i, (_, result) in enumerate(results.iterrows(), 1):
417
+ with st.expander(f"Result {i}", expanded=(i==1)):
418
+ render_result(result)
419
+
420
+ with tab3:
421
+ st.subheader("πŸ’Ύ Search History")
422
+ if not st.session_state['search_history']:
423
+ st.info("No search history yet. Try searching for something!")
424
+ else:
425
+ for entry in reversed(st.session_state['search_history']):
426
+ with st.expander(f"πŸ•’ {entry['timestamp']} - {entry['query']}", expanded=False):
427
+ for i, result in enumerate(entry['results'], 1):
428
+ st.write(f"**Result {i}:**")
429
+ if isinstance(result, pd.Series):
430
+ render_result(result)
431
+ else:
432
+ st.write(result)
433
+
434
+ with tab4:
435
+ st.subheader("βš™οΈ Settings")
436
+ st.write("Voice Settings:")
437
+ default_voice = st.selectbox(
438
+ "Default Voice:",
439
+ [
440
+ "en-US-AriaNeural",
441
+ "en-US-GuyNeural",
442
+ "en-GB-SoniaNeural",
443
+ "en-GB-TonyNeural"
444
+ ],
445
+ index=0,
446
+ key="default_voice_setting"
447
+ )
448
+
449
+ st.write("Search Settings:")
450
+ st.slider("Minimum Search Score:", 0.0, 1.0, MIN_SEARCH_SCORE, 0.1, key="min_search_score")
451
+ st.slider("Exact Match Boost:", 1.0, 3.0, EXACT_MATCH_BOOST, 0.1, key="exact_match_boost")
452
+
453
+ if st.button("πŸ—‘οΈ Clear Search History", key="clear_history_button"):
454
+ st.session_state['search_history'] = []
455
+ st.success("Search history cleared!")
456
+ st.experimental_rerun()
457
+
458
+ # Sidebar with metrics
459
+ with st.sidebar:
460
+ st.subheader("πŸ“Š Search Metrics")
461
+ total_searches = len(st.session_state['search_history'])
462
+ st.metric("Total Searches", total_searches)
463
+
464
+ if total_searches > 0:
465
+ recent_searches = st.session_state['search_history'][-5:]
466
+ st.write("Recent Searches:")
467
+ for entry in reversed(recent_searches):
468
+ st.write(f"πŸ” {entry['query']}")
469
+
470
+ if __name__ == "__main__":
471
+ main()