Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
from sklearn.metrics.pairwise import cosine_similarity | |
import torch | |
import json | |
import os | |
import glob | |
from pathlib import Path | |
from datetime import datetime, timedelta | |
import edge_tts | |
import asyncio | |
import requests | |
from collections import defaultdict | |
import streamlit.components.v1 as components | |
from urllib.parse import quote | |
from xml.etree import ElementTree as ET | |
from datasets import load_dataset | |
import base64 | |
import re | |
# π§ Initialize session state variables | |
# Paging and ranking constants used by the dataset loader and the searcher.
ROWS_PER_PAGE = 100
MIN_SEARCH_SCORE = 0.3
EXACT_MATCH_BOOST = 2.0

# Default value for every st.session_state key this app relies on.
SESSION_VARS = {
    'search_history': [],              # past searches
    'last_voice_input': "",            # most recent voice input
    'transcript_history': [],          # conversation history
    'should_rerun': False,             # flag requesting a UI update
    'search_columns': [],              # columns offered in the "Search in" selector
    'initial_search_done': False,      # set once the first search has run
    'tts_voice': "en-US-AriaNeural",   # default text-to-speech voice
    'arxiv_last_query': "",            # last ArXiv query
    'dataset_loaded': False,           # dataset load status
    'current_page': 0,                 # current data page
    'data_cache': None,                # cached data
    'dataset_info': None,              # dataset metadata
    'nps_submitted': False,            # whether the user submitted NPS
    'nps_last_shown': None,            # when NPS was last shown
    'old_val': None,                   # previous voice input value
    'voice_text': None,                # processed voice text
}

# Seed st.session_state; keys that already exist keep their current values.
for state_key, initial_value in SESSION_VARS.items():
    if state_key in st.session_state:
        continue
    st.session_state[state_key] = initial_value
# Voice Component Setup | |
def create_voice_component():
    """Declare the custom Streamlit component used for voice input.

    Returns:
        The object returned by ``components.declare_component`` for the
        component registered as "mycomponent" with ``path="mycomponent"``.
    """
    return components.declare_component("mycomponent", path="mycomponent")
# Utility Functions | |
def clean_for_speech(text: str) -> str:
    """Normalize text before it is handed to speech synthesis.

    Newlines and ``</s>`` markers become spaces, ``#`` characters are
    dropped, parenthesized http(s) URLs are removed, and every run of
    whitespace is collapsed to a single space.

    Args:
        text: Raw text to clean.

    Returns:
        The cleaned text with leading and trailing whitespace stripped.
    """
    # Literal substitutions, applied in this order.
    for needle, substitute in (("\n", " "), ("</s>", " "), ("#", "")):
        text = text.replace(needle, substitute)
    without_links = re.sub(r"\(https?:\/\/[^\)]+\)", "", text)
    return re.sub(r"\s+", " ", without_links).strip()
async def edge_tts_generate_audio(text, voice="en-US-AriaNeural", rate=0, pitch=0):
    """Synthesize ``text`` to an MP3 file with Edge TTS.

    Args:
        text: Text to speak; it is passed through ``clean_for_speech`` first.
        voice: Voice name given to ``edge_tts.Communicate``.
        rate: Integer rate offset, sent as a signed percentage (e.g. ``+0%``).
        pitch: Integer pitch offset, sent as signed hertz (e.g. ``+0Hz``).

    Returns:
        The relative filename ``speech_<YYYYmmdd_HHMMSS>.mp3`` that was
        written, or ``None`` when the cleaned text is empty.
    """
    speech_text = clean_for_speech(text)
    if not speech_text.strip():
        return None

    communicate = edge_tts.Communicate(
        speech_text,
        voice,
        rate=f"{rate:+d}%",
        pitch=f"{pitch:+d}Hz",
    )
    # The output name is derived from the current time at one-second resolution.
    out_fn = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
    await communicate.save(out_fn)
    return out_fn
def speak_with_edge_tts(text, voice="en-US-AriaNeural", rate=0, pitch=0):
    """Run ``edge_tts_generate_audio`` to completion from synchronous code.

    Args:
        text: Text to speak.
        voice: Voice name forwarded to the generator.
        rate: Rate offset forwarded to the generator.
        pitch: Pitch offset forwarded to the generator.

    Returns:
        Whatever ``edge_tts_generate_audio`` returns: the written filename,
        or ``None`` when there was nothing to speak.
    """
    synthesis = edge_tts_generate_audio(text, voice, rate, pitch)
    return asyncio.run(synthesis)
def play_and_download_audio(file_path):
    """Show an audio player and a download link for an audio file.

    Nothing is rendered when ``file_path`` is empty/``None`` or does not
    exist on disk.

    Args:
        file_path: Path of the audio file to play.

    Returns:
        None.
    """
    if not file_path or not os.path.exists(file_path):
        return

    st.audio(file_path)

    # Read through a context manager so the file handle is always closed.
    with open(file_path, "rb") as audio_file:
        encoded_audio = base64.b64encode(audio_file.read()).decode()

    file_name = os.path.basename(file_path)
    # The file content is embedded in the link itself as a base64 data URI.
    dl_link = (
        f'<a href="data:audio/mpeg;base64,{encoded_audio}" '
        f'download="{file_name}">Download {file_name}</a>'
    )
    st.markdown(dl_link, unsafe_allow_html=True)
# Streamlit re-executes the script on every interaction; caching the model as a
# resource avoids rebuilding the SentenceTransformer on each rerun.
@st.cache_resource
def get_model():
    """Return the 'all-MiniLM-L6-v2' sentence-transformer model.

    Returns:
        A ``SentenceTransformer`` instance. Because of ``st.cache_resource``
        the same instance is reused instead of being constructed per call.
    """
    return SentenceTransformer('all-MiniLM-L6-v2')
def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load one page of the dataset's train split as a DataFrame.

    Args:
        dataset_id: Dataset identifier passed to ``load_dataset``.
        token: Access token passed to ``load_dataset``.
        page: Zero-based page number.
        rows_per_page: Number of rows in each page.

    Returns:
        A ``pandas.DataFrame`` holding rows
        ``[page * rows_per_page, (page + 1) * rows_per_page)`` of the train
        split. On any error the message is shown with ``st.error`` and an
        empty DataFrame is returned.
    """
    try:
        first_row = page * rows_per_page
        stop_row = first_row + rows_per_page
        page_rows = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=f'train[{first_row}:{stop_row}]',
        )
        return pd.DataFrame(page_rows)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        return pd.DataFrame()
def get_dataset_info(dataset_id, token):
    """Return the ``info`` object of the dataset's train split.

    The dataset is opened with ``streaming=True``.

    Args:
        dataset_id: Dataset identifier passed to ``load_dataset``.
        token: Access token passed to ``load_dataset``.

    Returns:
        The ``info`` attribute of the ``'train'`` split, or ``None`` after
        reporting the failure with ``st.error``.
    """
    try:
        streamed = load_dataset(dataset_id, token=token, streaming=True)
        train_split = streamed['train']
        return train_split.info
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
        return None
def fetch_dataset_info(dataset_id):
    """Fetch dataset metadata from the Hugging Face datasets API.

    Args:
        dataset_id: Dataset identifier appended to the API URL.

    Returns:
        The decoded JSON body when the response status is 200; otherwise
        ``None``. Exceptions raised while requesting or decoding are reported
        with ``st.warning`` and also yield ``None``.
    """
    info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
    try:
        response = requests.get(info_url, timeout=30)
        if response.status_code != 200:
            # Non-200 responses are treated as "no info" without a warning.
            return None
        return response.json()
    except Exception as e:
        st.warning(f"Error fetching dataset info: {e}")
        return None
def generate_filename(text):
    """Build a timestamp-prefixed, filesystem-friendly name from ``text``.

    Only the first 50 characters of ``text`` are used. Characters other than
    word characters, whitespace and hyphens are removed, the result is
    stripped and lower-cased, and runs of hyphens/whitespace become a single
    hyphen.

    Args:
        text: Source text for the name.

    Returns:
        ``"<YYYYmmdd_HHMMSS>_<slug>"``; the timestamp has one-second
        resolution.
    """
    prefix = datetime.now().strftime("%Y%m%d_%H%M%S")
    slug = text[:50]
    slug = re.sub(r'[^\w\s-]', '', slug).strip().lower()
    slug = re.sub(r'[-\s]+', '-', slug)
    return f"{prefix}_{slug}"
def render_result(result):
    """Render one search result: optional video, its fields, and read-aloud controls.

    Args:
        result: A mapping-like row (for example a ``pandas.Series`` produced by
            ``DataFrame.iterrows`` or a plain dict) that supports ``.get``,
            ``.items`` and ``in``.

    Returns:
        None. Everything is written to the Streamlit page.
    """
    # quick_search stores its ranking under 'score'; 'relevance_score' still
    # takes precedence for rows that provide it.
    score = result.get('relevance_score', result.get('score', 0))

    hidden_keys = ('relevance_score', 'video_embed', 'description_embed', 'audio_embed')
    result_filtered = {k: v for k, v in result.items() if k not in hidden_keys}

    if 'youtube_id' in result:
        st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")

    # Streamlit requires distinct widget keys. 'video_id' alone collides when
    # the column is missing or repeated, so the row label (Series.name) is
    # appended when the row has one.
    video_id = result.get('video_id', '')
    row_label = getattr(result, 'name', None)
    widget_suffix = f"{video_id}" if row_label is None else f"{video_id}_{row_label}"

    cols = st.columns([2, 1])

    with cols[0]:
        text_content = []
        for key, value in result_filtered.items():
            if isinstance(value, (str, int, float)):
                st.write(f"**{key}:** {value}")
                # Only non-blank strings are collected for reading aloud.
                if isinstance(value, str) and len(value.strip()) > 0:
                    text_content.append(f"{key}: {value}")

    with cols[1]:
        st.metric("Relevance", f"{score:.2%}")

        voices = {
            "Aria (US Female)": "en-US-AriaNeural",
            "Guy (US Male)": "en-US-GuyNeural",
            "Sonia (UK Female)": "en-GB-SoniaNeural",
            "Tony (UK Male)": "en-GB-TonyNeural"
        }
        selected_voice = st.selectbox(
            "Voice:",
            list(voices.keys()),
            key=f"voice_{widget_suffix}"
        )

        if st.button("π Read", key=f"read_{widget_suffix}"):
            text_to_read = ". ".join(text_content)
            audio_file = speak_with_edge_tts(text_to_read, voices[selected_voice])
            if audio_file:
                play_and_download_audio(audio_file)
class FastDatasetSearcher:
    """Search one page of a dataset by keyword overlap plus embedding similarity."""

    # Fields whose text is examined first and placed first in the row text.
    _PRIORITY_FIELDS = ('description', 'matched_text')

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Set up the model, the access token and the cached dataset info.

        Args:
            dataset_id: Identifier of the dataset to search.

        The token is read from the ``DATASET_KEY`` environment variable. When
        it is missing an error is shown and the script is halted with
        ``st.stop()``.
        """
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable")
            st.stop()
        # Dataset info is fetched once and kept in session state.
        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Load page ``page`` (``ROWS_PER_PAGE`` rows) of the dataset as a DataFrame."""
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    @staticmethod
    def _contains_phrase(tokens, phrase_tokens):
        """Return True when ``phrase_tokens`` occurs as a contiguous run in ``tokens``."""
        width = len(phrase_tokens)
        if width == 0 or width > len(tokens):
            return False
        return any(
            tokens[start:start + width] == phrase_tokens
            for start in range(len(tokens) - width + 1)
        )

    def quick_search(self, query, df):
        """Score and filter the rows of ``df`` against ``query``.

        Each row's score is ``0.7 * keyword_overlap + 0.3 * cosine_similarity``
        between the query embedding and the embedding of the row's text. The
        score is multiplied by the exact-match boost when the whole query
        appears as a token phrase in a field, or by 1.2 when any query term
        appears as a token.

        Args:
            query: Search text.
            df: DataFrame of rows to search.

        Returns:
            ``df`` unchanged when it is empty or the query is blank. Otherwise
            a copy with added ``score`` and ``matched`` columns, limited to
            rows that matched a term or scored above the minimum, sorted by
            score descending. If anything fails, the error is shown with
            ``st.error`` and ``df`` is returned unchanged.
        """
        if df.empty or not query.strip():
            return df

        try:
            # Use the Settings sliders when present, else the module defaults.
            min_score = st.session_state.get('min_search_score', MIN_SEARCH_SCORE)
            exact_boost = st.session_state.get('exact_match_boost', EXACT_MATCH_BOOST)

            # A column is searchable unless its first value is an array or bytes.
            searchable_cols = [
                col for col in df.columns
                if not isinstance(df[col].iloc[0], (np.ndarray, bytes))
            ]
            # Field order is the same for every row, so build it once.
            ordered_fields = [col for col in self._PRIORITY_FIELDS if col in df.columns]
            ordered_fields += [col for col in searchable_cols if col not in self._PRIORITY_FIELDS]

            phrase_tokens = query.lower().split()
            query_terms = set(phrase_tokens)
            query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]

            scores = []
            matched_any = []

            for _, row in df.iterrows():
                text_parts = []
                row_matched = False
                exact_match = False

                for col in ordered_fields:
                    val = row[col]
                    if val is None:
                        continue
                    tokens = str(val).lower().split()
                    # Comparing the whole query to single tokens can never
                    # match a multi-word query, so match it as a token phrase.
                    if self._contains_phrase(tokens, phrase_tokens):
                        exact_match = True
                    if not query_terms.isdisjoint(tokens):
                        row_matched = True
                    text_parts.append(str(val))

                text = ' '.join(text_parts)

                if text.strip():
                    text_tokens = set(text.lower().split())
                    keyword_score = len(query_terms & text_tokens) / len(query_terms)

                    text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
                    semantic_score = float(cosine_similarity([query_embedding], [text_embedding])[0][0])

                    combined_score = 0.7 * keyword_score + 0.3 * semantic_score

                    if exact_match:
                        combined_score *= exact_boost
                    elif row_matched:
                        combined_score *= 1.2
                else:
                    combined_score = 0.0
                    row_matched = False

                scores.append(combined_score)
                matched_any.append(row_matched)

            results_df = df.copy()
            results_df['score'] = scores
            results_df['matched'] = matched_any

            keep = results_df['matched'] | (results_df['score'] > min_score)
            return results_df[keep].sort_values('score', ascending=False)

        except Exception as e:
            # Best effort: report the problem and fall back to the unfiltered rows.
            st.error(f"Search error: {str(e)}")
            return df
def _render_result_rows(rows, limit=None):
    """Render ``(label, row)`` pairs in numbered expanders, at most ``limit`` of them."""
    for i, (_, row) in enumerate(rows, 1):
        if limit is not None and i > limit:
            break
        with st.expander(f"Result {i}", expanded=(i == 1)):
            render_result(row)


def _history_rows(results, limit=5):
    """Return the first ``limit`` rows of ``results`` as dicts of str/int/float fields."""
    return [
        {key: value for key, value in row.items() if isinstance(value, (str, int, float))}
        for _, row in results.head(limit).iterrows()
    ]


def _rerun_app():
    """Rerun the script with whichever rerun function this Streamlit version has."""
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def main():
    """Build the page: voice input, the four tabs, and the sidebar metrics."""
    st.title("π₯ Smart Video & Voice Search")

    # Initialize components
    voice_component = create_voice_component()
    search = FastDatasetSearcher()

    # Voice input at top level
    voice_val = voice_component(my_input_value="Start speaking...")

    # Show voice input if detected
    if voice_val:
        voice_text = str(voice_val).strip()
        edited_input = st.text_area("βοΈ Edit Voice Input:", value=voice_text, height=100)

        run_option = st.selectbox("Select Search Type:",
                                  ["Quick Search", "Deep Search", "Voice Summary"])

        col1, col2 = st.columns(2)
        with col1:
            autorun = st.checkbox("β‘ Auto-Run", value=False)
        with col2:
            # The value of this checkbox is not used anywhere in this function.
            st.checkbox("π Full Audio", value=False)

        input_changed = (voice_text != st.session_state.get('old_val'))

        if autorun and input_changed:
            st.session_state['old_val'] = voice_text
            with st.spinner("Processing voice input..."):
                if run_option == "Quick Search":
                    results = search.quick_search(edited_input, search.load_page())
                    _render_result_rows(results.iterrows())

                elif run_option == "Deep Search":
                    with st.spinner("Performing deep search..."):
                        results = []
                        for page in range(3):  # Search first 3 pages
                            df = search.load_page(page)
                            results.extend(search.quick_search(edited_input, df).iterrows())

                        _render_result_rows(results)

                elif run_option == "Voice Summary":
                    audio_file = speak_with_edge_tts(edited_input)
                    if audio_file:
                        play_and_download_audio(audio_file)

        elif st.button("π Search", key="voice_input_search"):
            st.session_state['old_val'] = voice_text
            with st.spinner("Processing..."):
                results = search.quick_search(edited_input, search.load_page())
                _render_result_rows(results.iterrows())

    # Create main tabs
    tab1, tab2, tab3, tab4 = st.tabs([
        "π Search", "ποΈ Voice", "πΎ History", "βοΈ Settings"
    ])

    with tab1:
        st.subheader("π Search")
        col1, col2 = st.columns([3, 1])
        with col1:
            query = st.text_input("Enter search query:", value="")
        with col2:
            # The selection is displayed but not applied to the search below.
            st.selectbox("Search in:",
                         ["All Fields"] + st.session_state['search_columns'])

        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Max results:", 1, 100, 20)
        with col4:
            search_button = st.button("π Search", key="main_search_button")

        if (search_button or not st.session_state['initial_search_done']) and query:
            st.session_state['initial_search_done'] = True

            with st.spinner("Searching..."):
                df = search.load_page()
                results = search.quick_search(query, df)

                if len(results) > 0:
                    # Store plain row dicts: iterating a DataFrame slice later
                    # would yield column labels instead of rows.
                    st.session_state['search_history'].append({
                        'query': query,
                        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        'results': _history_rows(results)
                    })

                    st.write(f"Found {len(results)} results:")
                    _render_result_rows(results.iterrows(), limit=num_results)
                else:
                    st.warning("No matching results found.")

    with tab2:
        st.subheader("ποΈ Voice Input")
        st.write("Use the voice input above to start speaking, or record a new message:")

        col1, col2 = st.columns(2)
        with col1:
            if st.button("ποΈ Start New Recording", key="start_recording_button"):
                st.session_state['recording'] = True
                _rerun_app()
        with col2:
            if st.button("π Stop Recording", key="stop_recording_button"):
                st.session_state['recording'] = False
                _rerun_app()

        if st.session_state.get('recording', False):
            # Reuse the component declared at the top of this function.
            new_val = voice_component(my_input_value="Recording...")
            if new_val:
                st.text_area("Recorded Text:", value=new_val, height=100)
                if st.button("π Search with Recording", key="recording_search_button"):
                    with st.spinner("Processing recording..."):
                        df = search.load_page()
                        results = search.quick_search(new_val, df)
                        _render_result_rows(results.iterrows())

    with tab3:
        st.subheader("πΎ Search History")
        if not st.session_state['search_history']:
            st.info("No search history yet. Try searching for something!")
        else:
            for entry in reversed(st.session_state['search_history']):
                with st.expander(f"π {entry['timestamp']} - {entry['query']}", expanded=False):
                    for i, result in enumerate(entry['results'], 1):
                        st.write(f"**Result {i}:**")
                        # History rows are plain dicts and are shown without
                        # widgets, so their keys cannot clash with the keys
                        # render_result uses for live results.
                        st.write(result)

    with tab4:
        st.subheader("βοΈ Settings")
        st.write("Voice Settings:")
        st.selectbox(
            "Default Voice:",
            [
                "en-US-AriaNeural",
                "en-US-GuyNeural",
                "en-GB-SoniaNeural",
                "en-GB-TonyNeural"
            ],
            index=0,
            key="default_voice_setting"
        )

        st.write("Search Settings:")
        # Slider values are kept in st.session_state under the given keys.
        st.slider("Minimum Search Score:", 0.0, 1.0, MIN_SEARCH_SCORE, 0.1, key="min_search_score")
        st.slider("Exact Match Boost:", 1.0, 3.0, EXACT_MATCH_BOOST, 0.1, key="exact_match_boost")

        if st.button("ποΈ Clear Search History", key="clear_history_button"):
            st.session_state['search_history'] = []
            st.success("Search history cleared!")
            _rerun_app()

    # Sidebar with metrics
    with st.sidebar:
        st.subheader("π Search Metrics")
        total_searches = len(st.session_state['search_history'])
        st.metric("Total Searches", total_searches)

        if total_searches > 0:
            recent_searches = st.session_state['search_history'][-5:]
            st.write("Recent Searches:")
            for entry in reversed(recent_searches):
                st.write(f"π {entry['query']}")
# Call main() only when this file is executed as the main script, not when it is imported.
if __name__ == "__main__":
    main()