awacke1 committed on
Commit a2236b2 · verified · 1 Parent(s): 412ec8c

Update app.py

Files changed (1)
  1. app.py +473 -107
app.py CHANGED
@@ -3,34 +3,52 @@ import pandas as pd
  import numpy as np
  from sentence_transformers import SentenceTransformer
  from sklearn.metrics.pairwise import cosine_similarity
  import os
  from datetime import datetime
  from datasets import load_dataset

  # Initialize session state
- if 'search_history' not in st.session_state:
-     st.session_state['search_history'] = []
- if 'search_columns' not in st.session_state:
-     st.session_state['search_columns'] = []
- if 'dataset_loaded' not in st.session_state:
-     st.session_state['dataset_loaded'] = False
- if 'current_page' not in st.session_state:
-     st.session_state['current_page'] = 0
- if 'data_cache' not in st.session_state:
-     st.session_state['data_cache'] = None
- if 'dataset_info' not in st.session_state:
-     st.session_state['dataset_info'] = None
-
- ROWS_PER_PAGE = 100  # Number of rows to load at a time

  @st.cache_resource
  def get_model():
-     """Cache the model loading"""
      return SentenceTransformer('all-MiniLM-L6-v2')

  @st.cache_data
  def load_dataset_page(dataset_id, token, page, rows_per_page):
-     """Load and cache a specific page of data"""
      try:
          start_idx = page * rows_per_page
          end_idx = start_idx + rows_per_page
@@ -47,113 +65,408 @@ def load_dataset_page(dataset_id, token, page, rows_per_page):

  @st.cache_data
  def get_dataset_info(dataset_id, token):
-     """Load and cache dataset information"""
      try:
-         dataset = load_dataset(
-             dataset_id,
-             token=token,
-             streaming=True
-         )
          return dataset['train'].info
      except Exception as e:
          st.error(f"Error loading dataset info: {str(e)}")
          return None

  class FastDatasetSearcher:
      def __init__(self, dataset_id="tomg-group-umd/cinepile"):
          self.dataset_id = dataset_id
          self.text_model = get_model()
          self.token = os.environ.get('DATASET_KEY')
          if not self.token:
-             st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
              st.stop()

-         # Initialize numpy for model inputs
-         self.np = np
-
-         # Load dataset info if not already loaded
          if st.session_state['dataset_info'] is None:
              st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

      def load_page(self, page=0):
-         """Load a specific page of data using cached function"""
          return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

      def quick_search(self, query, df):
-         """Fast search on current page"""
-         if df.empty:
              return df

          try:
-             # Get columns to search (excluding numpy array columns)
              searchable_cols = []
              for col in df.columns:
                  sample_val = df[col].iloc[0]
                  if not isinstance(sample_val, (np.ndarray, bytes)):
                      searchable_cols.append(col)

-             # Prepare query
              query_lower = query.lower()
              query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]
              scores = []

-             # Process each row
              for _, row in df.iterrows():
-                 # Combine text from searchable columns
                  text_parts = []
                  for col in searchable_cols:
                      val = row[col]
                      if val is not None:
-                         if isinstance(val, (list, dict)):
-                             text_parts.append(str(val))
-                         else:
-                             text_parts.append(str(val))

                  text = ' '.join(text_parts)

-                 # Calculate scores
                  if text.strip():
-                     # Keyword matching
-                     keyword_score = text.lower().count(query_lower) / max(len(text.split()), 1)

-                     # Semantic matching
                      text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
                      semantic_score = float(cosine_similarity([query_embedding], [text_embedding])[0][0])

-                     # Combine scores
-                     combined_score = 0.5 * semantic_score + 0.5 * keyword_score
                  else:
                      combined_score = 0.0

                  scores.append(combined_score)

-             # Get top results
              results_df = df.copy()
              results_df['score'] = scores
-             return results_df.sort_values('score', ascending=False)

          except Exception as e:
              st.error(f"Search error: {str(e)}")
              return df

-         # Get top results
-         results_df = df.copy()
-         results_df['score'] = scores
-         return results_df.sort_values('score', ascending=False)

  def render_result(result):
-     """Render a single search result"""
-     # Get score from the Series
-     score = result.get('score', 0) if 'score' in result else 0
-     result_filtered = result.drop('score') if 'score' in result else result

-     # Display video if available
      if 'youtube_id' in result:
-         st.video(
-             f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
-         )

-     # Display other fields
      cols = st.columns([2, 1])
      with cols[0]:
          for key, value in result_filtered.items():
@@ -164,61 +477,114 @@ def render_result(result):
          st.metric("Relevance Score", f"{score:.2%}")

  def main():
-     st.title("🎥 Fast Video Dataset Search")

-     # Initialize search class
-     searcher = FastDatasetSearcher()

-     # Show dataset info
-     if st.session_state['dataset_info']:
-         st.sidebar.write("### Dataset Info")
-         st.sidebar.write(f"Total examples: {st.session_state['dataset_info'].splits['train'].num_examples:,}")

-         total_pages = st.session_state['dataset_info'].splits['train'].num_examples // ROWS_PER_PAGE
-         current_page = st.number_input("Page", min_value=0, max_value=total_pages, value=st.session_state['current_page'])
-     else:
-         current_page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])

-     # Load current page
-     with st.spinner(f"Loading page {current_page}..."):
-         df = searcher.load_page(current_page)

-     if df.empty:
-         st.warning("No data available for this page.")
-         return

-     # Search interface
-     col1, col2 = st.columns([3, 1])
-     with col1:
-         query = st.text_input("Search in current page:",
-                               help="Searches within currently loaded data")
-     with col2:
-         max_results = st.slider("Max results", 1, ROWS_PER_PAGE, 10)

-     if query:
-         with st.spinner("Searching..."):
-             results = searcher.quick_search(query, df)
-
-         # Display results
-         st.write(f"Found {len(results)} results on this page:")
-         for i, (_, result) in enumerate(results.head(max_results).iterrows(), 1):
-             with st.expander(f"Result {i}", expanded=i==1):
-                 render_result(result)
-
-     # Show raw data
-     with st.expander("Show Raw Data"):
-         st.dataframe(df)
-
-     # Navigation buttons
-     cols = st.columns(2)
-     with cols[0]:
-         if st.button("⬅️ Previous Page") and current_page > 0:
-             st.session_state['current_page'] = current_page - 1
-             st.rerun()
-     with cols[1]:
-         if st.button("Next Page ➡️"):
-             st.session_state['current_page'] = current_page + 1
-             st.rerun()

  if __name__ == "__main__":
      main()

@@ -3,34 +3,52 @@ import pandas as pd
  import numpy as np
  from sentence_transformers import SentenceTransformer
  from sklearn.metrics.pairwise import cosine_similarity
+ import torch
+ import json
  import os
+ import glob
+ from pathlib import Path
  from datetime import datetime
+ import edge_tts
+ import asyncio
+ import requests
+ from collections import defaultdict
+ from audio_recorder_streamlit import audio_recorder
+ import streamlit.components.v1 as components
+ from urllib.parse import quote
+ from xml.etree import ElementTree as ET
  from datasets import load_dataset

+ # 🧠 Initialize session state variables
+ SESSION_VARS = {
+     'search_history': [],             # Track search history
+     'last_voice_input': "",           # Last voice input
+     'transcript_history': [],         # Conversation history
+     'should_rerun': False,            # Trigger for UI updates
+     'search_columns': [],             # Available search columns
+     'initial_search_done': False,     # First search flag
+     'tts_voice': "en-US-AriaNeural",  # Default voice
+     'arxiv_last_query': "",           # Last ArXiv search
+     'dataset_loaded': False,          # Dataset load status
+     'current_page': 0,                # Current data page
+     'data_cache': None,               # Data cache
+     'dataset_info': None              # Dataset metadata
+ }
+
+ # Constants
+ ROWS_PER_PAGE = 100
+
  # Initialize session state
+ for var, default in SESSION_VARS.items():
+     if var not in st.session_state:
+         st.session_state[var] = default

  @st.cache_resource
  def get_model():
      return SentenceTransformer('all-MiniLM-L6-v2')

  @st.cache_data
  def load_dataset_page(dataset_id, token, page, rows_per_page):
      try:
          start_idx = page * rows_per_page
          end_idx = start_idx + rows_per_page
@@ -47,113 +65,408 @@ def load_dataset_page(dataset_id, token, page, rows_per_page):

  @st.cache_data
  def get_dataset_info(dataset_id, token):
      try:
+         dataset = load_dataset(dataset_id, token=token, streaming=True)
          return dataset['train'].info
      except Exception as e:
          st.error(f"Error loading dataset info: {str(e)}")
          return None

+ def fetch_dataset_info(dataset_id):
+     info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
+     try:
+         response = requests.get(info_url, timeout=30)
+         if response.status_code == 200:
+             return response.json()
+     except Exception as e:
+         st.warning(f"Error fetching dataset info: {e}")
+     return None
+
+ def fetch_dataset_rows(dataset_id, config="default", split="train", max_rows=100):
+     url = f"https://datasets-server.huggingface.co/first-rows?dataset={dataset_id}&config={config}&split={split}"
+     try:
+         response = requests.get(url, timeout=30)
+         if response.status_code == 200:
+             data = response.json()
+             if 'rows' in data:
+                 processed_rows = []
+                 for row_data in data['rows']:
+                     row = row_data.get('row', row_data)
+                     # Process embeddings if present
+                     for key in row:
+                         if any(term in key.lower() for term in ['embed', 'vector', 'encoding']):
+                             if isinstance(row[key], str):
+                                 try:
+                                     row[key] = [float(x.strip()) for x in row[key].strip('[]').split(',') if x.strip()]
+                                 except:
+                                     continue
+                     row['_config'] = config
+                     row['_split'] = split
+                     processed_rows.append(row)
+                 return processed_rows
+     except Exception as e:
+         st.warning(f"Error fetching rows: {e}")
+     return []
+
  class FastDatasetSearcher:
      def __init__(self, dataset_id="tomg-group-umd/cinepile"):
          self.dataset_id = dataset_id
          self.text_model = get_model()
          self.token = os.environ.get('DATASET_KEY')
          if not self.token:
+             st.error("Please set the DATASET_KEY environment variable")
              st.stop()

          if st.session_state['dataset_info'] is None:
              st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

      def load_page(self, page=0):
          return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

      def quick_search(self, query, df):
+         """Enhanced search with improved relevance filtering"""
+         if df.empty or not query.strip():
              return df

          try:
+             # Define relevance thresholds
+             MIN_KEYWORD_MATCHES = 0.1
+             MIN_SEMANTIC_SCORE = 0.3
+
+             # Get searchable columns
              searchable_cols = []
              for col in df.columns:
                  sample_val = df[col].iloc[0]
                  if not isinstance(sample_val, (np.ndarray, bytes)):
                      searchable_cols.append(col)

              query_lower = query.lower()
+             query_terms = set(query_lower.split())
              query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]
+
              scores = []
+             matched_any = []

              for _, row in df.iterrows():
                  text_parts = []
+                 row_matched = False
+
+                 # Check for direct matches
                  for col in searchable_cols:
                      val = row[col]
                      if val is not None:
+                         val_str = str(val).lower()
+                         if any(term in val_str for term in query_terms):
+                             row_matched = True
+                         text_parts.append(str(val))

                  text = ' '.join(text_parts)

                  if text.strip():
+                     # Calculate term-based keyword score
+                     text_terms = set(text.lower().split())
+                     matching_terms = query_terms.intersection(text_terms)
+                     keyword_score = len(matching_terms) / len(query_terms)

+                     # Calculate semantic score
                      text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
                      semantic_score = float(cosine_similarity([query_embedding], [text_embedding])[0][0])

+                     # Weighted combination
+                     combined_score = 0.7 * keyword_score + 0.3 * semantic_score
+
+                     # Boost exact matches
+                     if row_matched:
+                         combined_score *= 1.5
                  else:
                      combined_score = 0.0
+                     row_matched = False

                  scores.append(combined_score)
+                 matched_any.append(row_matched)

              results_df = df.copy()
              results_df['score'] = scores
+             results_df['matched'] = matched_any
+
+             # Filter relevant results
+             filtered_df = results_df[
+                 (results_df['matched']) |                    # Include direct matches
+                 (results_df['score'] > MIN_KEYWORD_MATCHES)  # Or high relevance
+             ]
+
+             return filtered_df.sort_values('score', ascending=False)

          except Exception as e:
              st.error(f"Search error: {str(e)}")
              return df
+
+ class VideoSearch:
+     def __init__(self):
+         self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
+         self.dataset_id = "omegalabsinc/omega-multimodal"
+         self.load_dataset()
+
+     def fetch_dataset_rows(self):
+         try:
+             df, configs, splits = search_dataset(
+                 self.dataset_id,
+                 "",
+                 include_configs=None,
+                 include_splits=None
+             )
+
+             if not df.empty:
+                 st.session_state['search_columns'] = [col for col in df.columns
+                                                       if col not in ['video_embed', 'description_embed', 'audio_embed']
+                                                       and not col.startswith('_')]
+                 return df
+
+             return self.load_example_data()
+
+         except Exception as e:
+             st.warning(f"Error loading videos: {e}")
+             return self.load_example_data()
+
+     def load_example_data(self):
+         example_data = [{
+             "video_id": "sample-123",
+             "youtube_id": "dQw4w9WgXcQ",
+             "description": "An example video",
+             "views": 12345,
+             "start_time": 0,
+             "end_time": 60
+         }]
+         return pd.DataFrame(example_data)
+
+     def load_dataset(self):
+         self.dataset = self.fetch_dataset_rows()
+         self.prepare_features()
+
+     def prepare_features(self):
+         try:
+             embed_cols = [col for col in self.dataset.columns
+                           if any(term in col.lower() for term in ['embed', 'vector', 'encoding'])]
+
+             embeddings = {}
+             for col in embed_cols:
+                 try:
+                     data = []
+                     for row in self.dataset[col]:
+                         if isinstance(row, str):
+                             values = [float(x.strip()) for x in row.strip('[]').split(',') if x.strip()]
+                         elif isinstance(row, list):
+                             values = row
+                         else:
+                             continue
+                         data.append(values)
+
+                     if data:
+                         embeddings[col] = np.array(data)
+                 except:
+                     continue
+
+             self.video_embeds = embeddings.get('video_embed', next(iter(embeddings.values())) if embeddings else None)
+             self.text_embeds = embeddings.get('description_embed', self.video_embeds)
+
+         except:
+             num_rows = len(self.dataset)
+             self.video_embeds = np.random.randn(num_rows, 384)
+             self.text_embeds = np.random.randn(num_rows, 384)
+
+     def search(self, query, column=None, top_k=20):
+         """Enhanced search with better relevance scoring"""
+         MIN_RELEVANCE = 0.3  # Minimum relevance threshold
+
+         query_embedding = self.text_model.encode([query])[0]
+         video_sims = cosine_similarity([query_embedding], self.video_embeds)[0]
+         text_sims = cosine_similarity([query_embedding], self.text_embeds)[0]
+         combined_sims = 0.7 * text_sims + 0.3 * video_sims  # Favor text matches
+
+         if column and column in self.dataset.columns and column != "All Fields":
+             # Direct matches in specified column
+             matches = self.dataset[column].astype(str).str.contains(query, case=False)
+             combined_sims[matches] *= 1.5  # Boost exact matches
+
+         # Filter by minimum relevance
+         relevant_indices = np.where(combined_sims >= MIN_RELEVANCE)[0]
+         if len(relevant_indices) == 0:
+             return []
+
+         top_k = min(top_k, len(relevant_indices))
+         top_indices = relevant_indices[np.argsort(combined_sims[relevant_indices])[-top_k:][::-1]]
+
+         results = []
+         for idx in top_indices:
+             result = {'relevance_score': float(combined_sims[idx])}
+             for col in self.dataset.columns:
+                 if col not in ['video_embed', 'description_embed', 'audio_embed']:
+                     result[col] = self.dataset.iloc[idx][col]
+             results.append(result)

+         return results
+
+ def search_dataset(dataset_id, search_text, include_configs=None, include_splits=None):
+     dataset_info = fetch_dataset_info(dataset_id)
+     if not dataset_info:
+         return pd.DataFrame(), [], []
+
+     configs = include_configs if include_configs else dataset_info.get('config_names', ['default'])
+     all_rows = []
+     available_splits = set()
+
+     for config in configs:
+         try:
+             splits_url = f"https://datasets-server.huggingface.co/splits?dataset={dataset_id}&config={config}"
+             splits_response = requests.get(splits_url, timeout=30)
+             if splits_response.status_code == 200:
+                 splits_data = splits_response.json()
+                 splits = [split['split'] for split in splits_data.get('splits', [])]
+                 if not splits:
+                     splits = ['train']
+
+                 if include_splits:
+                     splits = [s for s in splits if s in include_splits]
+
+                 available_splits.update(splits)
+
+                 for split in splits:
+                     rows = fetch_dataset_rows(dataset_id, config, split)
+                     for row in rows:
+                         text_content = ' '.join(str(v) for v in row.values()
+                                                 if isinstance(v, (str, int, float)))
+                         if search_text.lower() in text_content.lower():
+                             row['_matched_text'] = text_content
+                             row['_relevance_score'] = text_content.lower().count(search_text.lower())
+                             all_rows.append(row)
+         except Exception as e:
+             st.warning(f"Error processing config {config}: {e}")
+             continue
+
+     if all_rows:
+         df = pd.DataFrame(all_rows)
+         df = df.sort_values('_relevance_score', ascending=False)
+         return df, configs, list(available_splits)
+
+     return pd.DataFrame(), configs, list(available_splits)
+
+ @st.cache_resource
+ def get_speech_model():
+     return edge_tts.Communicate
+
+ async def generate_speech(text, voice=None):
+     if not text.strip():
+         return None
+     if not voice:
+         voice = st.session_state['tts_voice']
+     try:
+         communicate = get_speech_model()(text, voice)
+         audio_file = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
+         await communicate.save(audio_file)
+         return audio_file
+     except Exception as e:
+         st.error(f"Error generating speech: {e}")
+         return None
+
+ def transcribe_audio(audio_path):
+     """Placeholder for ASR implementation"""
+     return "ASR not implemented. Add your preferred speech recognition here!"
+
+ def arxiv_search(query, max_results=5):
+     base_url = "http://export.arxiv.org/api/query?"
+     search_url = base_url + f"search_query={quote(query)}&start=0&max_results={max_results}"
+     try:
+         r = requests.get(search_url)
+         if r.status_code == 200:
+             root = ET.fromstring(r.text)
+             ns = {'atom': 'http://www.w3.org/2005/Atom'}
+             entries = root.findall('atom:entry', ns)
+             results = []
+             for entry in entries:
+                 title = entry.find('atom:title', ns).text.strip()
+                 summary = entry.find('atom:summary', ns).text.strip()
+                 link = next((l.get('href') for l in entry.findall('atom:link', ns)
+                              if l.get('type') == 'text/html'), None)
+                 results.append((title, summary, link))
+             return results
+     except Exception as e:
+         st.error(f"ArXiv search error: {e}")
+     return []
+
+ def show_file_manager():
+     st.subheader("📂 File Manager")
+     col1, col2 = st.columns(2)
+
+     with col1:
+         uploaded_file = st.file_uploader("Upload File", type=['txt', 'md', 'mp3'])
+         if uploaded_file:
+             with open(uploaded_file.name, "wb") as f:
+                 f.write(uploaded_file.getvalue())
+             st.success(f"Uploaded: {uploaded_file.name}")
+             st.experimental_rerun()
+
+     with col2:
+         if st.button("🗑 Clear Files"):
+             for f in glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3"):
+                 os.remove(f)
+             st.success("All files cleared!")
+             st.experimental_rerun()
+
+     files = glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3")
+     if files:
+         st.write("### Existing Files")
+         for f in files:
+             with st.expander(f"📄 {os.path.basename(f)}"):
+                 if f.endswith('.mp3'):
+                     st.audio(f)
+                 else:
+                     with open(f, 'r', encoding='utf-8') as file:
+                         st.text_area("Content", file.read(), height=100)
+                 if st.button(f"Delete {os.path.basename(f)}", key=f"del_{f}"):
+                     os.remove(f)
+                     st.experimental_rerun()
+
+ def perform_arxiv_lookup(query, vocal_summary=True, titles_summary=True, full_audio=False):
+     results = arxiv_search(query, max_results=5)
+     if not results:
+         st.write("No results found.")
+         return
+
+     st.markdown(f"**ArXiv Results for '{query}':**")
+     for i, (title, summary, link) in enumerate(results, start=1):
+         st.markdown(f"**{i}. {title}**")
+         st.write(summary)
+         if link:
+             st.markdown(f"[View Paper]({link})")
+
+     if vocal_summary:
+         spoken_text = f"Here are ArXiv results for {query}. "
+         if titles_summary:
+             spoken_text += " Titles: " + ", ".join([res[0] for res in results])
+         else:
+             spoken_text += " " + results[0][1][:200]
+
+         audio_file = asyncio.run(generate_speech(spoken_text))
+         if audio_file:
+             st.audio(audio_file)
+
+     if full_audio:
+         full_text = ""
+         for i, (title, summary, _) in enumerate(results, start=1):
+             full_text += f"Result {i}: {title}. {summary} "
+         audio_file_full = asyncio.run(generate_speech(full_text))
+         if audio_file_full:
+             st.write("### Full Audio Summary")
+             st.audio(audio_file_full)

  def render_result(result):
+     score = result.get('relevance_score', 0)
+     result_filtered = {k: v for k, v in result.items()
+                        if k not in ['relevance_score', 'video_embed', 'description_embed', 'audio_embed']}

      if 'youtube_id' in result:
+         st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")

      cols = st.columns([2, 1])
      with cols[0]:
          for key, value in result_filtered.items():
@@ -164,61 +477,114 @@ def render_result(result):
          st.metric("Relevance Score", f"{score:.2%}")

  def main():
+     st.title("🎥 Advanced Video & Dataset Search with Voice")

+     # Initialize search
+     search = VideoSearch()

+     # Create tabs
+     tab1, tab2, tab3, tab4 = st.tabs([
+         "🔍 Search", "🎙️ Voice Input", "📚 ArXiv", "📂 Files"
+     ])
+
+     # Search Tab
+     with tab1:
+         st.subheader("Search Videos")
+         col1, col2 = st.columns([3, 1])
+         with col1:
+             query = st.text_input("Enter search query:",
+                                   value="" if st.session_state['initial_search_done'] else "aliens")
+         with col2:
+             search_column = st.selectbox("Search in:",
+                                          ["All Fields"] + st.session_state['search_columns'])
+
+         col3, col4 = st.columns(2)
+         with col3:
+             num_results = st.slider("Max results:", 1, 100, 20)
+         with col4:
+             search_button = st.button("🔍 Search")

+         if (search_button or not st.session_state['initial_search_done']) and query:
+             st.session_state['initial_search_done'] = True
+             selected_column = None if search_column == "All Fields" else search_column
+
+             with st.spinner("Searching..."):
+                 results = search.search(query, selected_column, num_results)
+
+             if results:
+                 st.session_state['search_history'].append({
+                     'query': query,
+                     'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+                     'results': results[:5]
+                 })
+
+                 st.write(f"Found {len(results)} results:")
+                 for i, result in enumerate(results, 1):
+                     with st.expander(f"Result {i}", expanded=(i==1)):
+                         render_result(result)
+             else:
+                 st.warning("No matching results found.")

+     # Voice Input Tab
+     with tab2:
+         st.subheader("Voice Search")
+         st.write("🎙️ Record your query:")
+         audio_bytes = audio_recorder()
+         if audio_bytes:
+             with st.spinner("Processing audio..."):
+                 audio_path = f"temp_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
+                 with open(audio_path, "wb") as f:
+                     f.write(audio_bytes)
+
+                 voice_query = transcribe_audio(audio_path)
+                 st.markdown("**Transcribed Text:**")
+                 st.write(voice_query)
+                 st.session_state['last_voice_input'] = voice_query
+
+                 if st.button("🔍 Search from Voice"):
+                     results = search.search(voice_query, None, 20)
+                     for i, result in enumerate(results, 1):
+                         with st.expander(f"Result {i}", expanded=(i==1)):
+                             render_result(result)
+
+                 if os.path.exists(audio_path):
+                     os.remove(audio_path)

+     # ArXiv Tab
+     with tab3:
+         st.subheader("ArXiv Search")
+         arxiv_query = st.text_input("Search ArXiv:", value=st.session_state['arxiv_last_query'])
+         vocal_summary = st.checkbox("🎙 Quick Audio Summary", value=True)
+         titles_summary = st.checkbox("🔖 Titles Only", value=True)
+         full_audio = st.checkbox("📚 Full Audio Summary", value=False)
+
+         if st.button("🔍 Search ArXiv"):
+             st.session_state['arxiv_last_query'] = arxiv_query
+             perform_arxiv_lookup(arxiv_query, vocal_summary, titles_summary, full_audio)

+     # File Manager Tab
+     with tab4:
+         show_file_manager()

+     # Sidebar
+     with st.sidebar:
+         st.subheader("⚙️ Settings & History")
+         if st.button("🗑️ Clear History"):
+             st.session_state['search_history'] = []
+             st.experimental_rerun()
+
+         st.markdown("### Recent Searches")
+         for entry in reversed(st.session_state['search_history'][-5:]):
+             with st.expander(f"{entry['timestamp']}: {entry['query']}"):
+                 for i, result in enumerate(entry['results'], 1):
+                     st.write(f"{i}. {result.get('description', '')[:100]}...")
+
+         st.markdown("### Voice Settings")
+         st.selectbox("TTS Voice:", [
+             "en-US-AriaNeural",
+             "en-US-GuyNeural",
+             "en-GB-SoniaNeural"
+         ], key="tts_voice")

  if __name__ == "__main__":
      main()
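
A standalone sketch (not part of the commit) for anyone who wants to poke at the new data path without launching Streamlit: it calls the same Hugging Face datasets-server first-rows endpoint that the added fetch_dataset_rows() helper queries, and unwraps the 'row' key the same way. The function name preview_first_rows and the script framing are illustrative assumptions, not code from app.py.

import requests

def preview_first_rows(dataset_id: str, config: str = "default", split: str = "train"):
    # Same public endpoint the commit's fetch_dataset_rows() helper uses.
    url = (
        "https://datasets-server.huggingface.co/first-rows"
        f"?dataset={dataset_id}&config={config}&split={split}"
    )
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    data = response.json()
    # Each entry in 'rows' wraps the actual record under the 'row' key.
    return [entry.get("row", entry) for entry in data.get("rows", [])]

if __name__ == "__main__":
    # Dataset id taken from the VideoSearch class added in this commit.
    rows = preview_first_rows("omegalabsinc/omega-multimodal")
    print(f"Fetched {len(rows)} rows")
    if rows:
        print("Columns:", sorted(rows[0].keys()))

Run as a plain script, this prints however many rows the endpoint returns and their column names, which is the same raw material the app's VideoSearch class indexes.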