|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import torch |
|
import json |
|
import os |
|
import glob |
|
from pathlib import Path |
|
from datetime import datetime, timedelta |
|
import edge_tts |
|
import asyncio |
|
import requests |
|
from collections import defaultdict |
|
import streamlit.components.v1 as components |
|
from urllib.parse import quote |
|
from xml.etree import ElementTree as ET |
|
from datasets import load_dataset |
|
import base64 |
|
import re |
|
|
|
|
|
# Default value for every key this script seeds into st.session_state.
SESSION_VARS = {
    "search_history": [],
    "last_voice_input": "",
    "transcript_history": [],
    "should_rerun": False,
    "search_columns": [],
    "initial_search_done": False,
    "tts_voice": "en-US-AriaNeural",
    "arxiv_last_query": "",
    "dataset_loaded": False,
    "current_page": 0,
    "data_cache": None,
    "dataset_info": None,
    "nps_submitted": False,
    "nps_last_shown": None,
    "old_val": None,
    "voice_text": None,
}
|
|
|
|
|
# Rows requested per dataset page by FastDatasetSearcher.load_page.
ROWS_PER_PAGE: int = 100

# quick_search keeps a row with no keyword hit only if its score exceeds this.
MIN_SEARCH_SCORE: float = 0.3

# quick_search multiplies the combined score of exact-match rows by this.
EXACT_MATCH_BOOST: float = 2.0
|
|
|
|
|
# Seed st.session_state with defaults, leaving keys that already exist alone.
for state_key in SESSION_VARS:
    if state_key in st.session_state:
        continue
    st.session_state[state_key] = SESSION_VARS[state_key]
|
|
|
|
|
def create_voice_component():
    """Declare the "mycomponent" custom component and return its callable.

    The component assets are looked up under the relative path "mycomponent".
    """
    return components.declare_component("mycomponent", path="mycomponent")
|
|
|
|
|
def clean_for_speech(text: str) -> str:
    """Return *text* stripped of markup that should not be read aloud.

    Newlines and "</s>" markers become spaces, "#" characters are dropped,
    parenthesised http(s) URLs are removed, and runs of whitespace are
    collapsed to a single space with the ends trimmed.
    """
    # Literal substitutions, applied in the same order as before.
    for needle, replacement in (("\n", " "), ("</s>", " "), ("#", "")):
        text = text.replace(needle, replacement)

    without_links = re.sub(r"\(https?:\/\/[^\)]+\)", "", text)
    return re.sub(r"\s+", " ", without_links).strip()
|
|
|
async def edge_tts_generate_audio(text, voice="en-US-AriaNeural", rate=0, pitch=0):
    """Synthesize *text* to an MP3 file with Edge TTS and return its path.

    Args:
        text: Raw text; it is passed through clean_for_speech first.
        voice: Edge TTS voice name.
        rate: Speaking-rate offset, sent as a signed percentage.
        pitch: Pitch offset, sent as a signed Hz value.

    Returns:
        The name of the MP3 file written to the current working directory,
        or None when nothing is left to speak after cleaning.
    """
    text = clean_for_speech(text)
    if not text.strip():
        return None

    # The "+d" format spec rejects floats, so coerce before formatting.
    rate_str = f"{int(rate):+d}%"
    pitch_str = f"{int(pitch):+d}Hz"
    communicate = edge_tts.Communicate(text, voice, rate=rate_str, pitch=pitch_str)

    # Microsecond resolution keeps two requests made within the same second
    # from writing to (and overwriting) the same file.
    out_fn = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.mp3"
    await communicate.save(out_fn)
    return out_fn
|
|
|
def speak_with_edge_tts(text, voice="en-US-AriaNeural", rate=0, pitch=0):
    """Run edge_tts_generate_audio to completion and return its result.

    This is the synchronous entry point: it drives the coroutine with
    asyncio.run and hands back whatever the coroutine returned.
    """
    pending = edge_tts_generate_audio(text, voice, rate, pitch)
    return asyncio.run(pending)
|
|
|
def play_and_download_audio(file_path):
    """Show an audio player and a download link for *file_path*.

    Does nothing when *file_path* is empty/None or does not exist on disk.
    The download link embeds the whole file as a base64 data URI.
    """
    if not file_path or not os.path.exists(file_path):
        return

    st.audio(file_path)

    # Read through a context manager so the file handle is always closed.
    with open(file_path, "rb") as audio_file:
        encoded = base64.b64encode(audio_file.read()).decode()

    file_name = os.path.basename(file_path)
    dl_link = (
        f'<a href="data:audio/mpeg;base64,{encoded}" '
        f'download="{file_name}">Download {file_name}</a>'
    )
    st.markdown(dl_link, unsafe_allow_html=True)
|
|
|
@st.cache_resource
def get_model():
    """Build the 'all-MiniLM-L6-v2' SentenceTransformer.

    Wrapped in st.cache_resource, so the constructed model object is cached
    by Streamlit rather than rebuilt on each call.
    """
    model = SentenceTransformer('all-MiniLM-L6-v2')
    return model
|
|
|
@st.cache_data
def _load_dataset_page_cached(dataset_id, token, page, rows_per_page):
    """Fetch one slice of the train split as a DataFrame; raises on failure."""
    start_idx = page * rows_per_page
    end_idx = start_idx + rows_per_page
    dataset = load_dataset(
        dataset_id,
        token=token,
        streaming=False,
        split=f'train[{start_idx}:{end_idx}]'
    )
    return pd.DataFrame(dataset)


def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load one page of the dataset's train split, with caching.

    Args:
        dataset_id: Dataset identifier passed to load_dataset.
        token: Access token passed to load_dataset.
        page: Zero-based page number.
        rows_per_page: Number of rows in each page.

    Returns:
        A DataFrame holding rows [page * rows_per_page, (page + 1) *
        rows_per_page) of the train split, or an empty DataFrame when loading
        fails (the error is reported with st.error).
    """
    # Error handling lives outside the cached helper on purpose: a failure
    # must not be stored as the cached value for this page, otherwise a
    # transient error would keep returning an empty page on later calls.
    try:
        return _load_dataset_page_cached(dataset_id, token, page, rows_per_page)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        return pd.DataFrame()


# Keep the cache-invalidation hook reachable under the public name.
load_dataset_page.clear = _load_dataset_page_cached.clear
|
|
|
@st.cache_data
def _get_dataset_info_cached(dataset_id, token):
    """Open the dataset in streaming mode and return the train split's info."""
    dataset = load_dataset(dataset_id, token=token, streaming=True)
    return dataset['train'].info


def get_dataset_info(dataset_id, token):
    """Return the train split's info object for *dataset_id*, with caching.

    Args:
        dataset_id: Dataset identifier passed to load_dataset.
        token: Access token passed to load_dataset.

    Returns:
        The `.info` attribute of the streamed train split, or None when
        loading fails (the error is reported with st.error).
    """
    # Error handling lives outside the cached helper on purpose: a failure
    # must not be stored as the cached value, otherwise a transient error
    # would keep returning None on later calls.
    try:
        return _get_dataset_info_cached(dataset_id, token)
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
        return None


# Keep the cache-invalidation hook reachable under the public name.
get_dataset_info.clear = _get_dataset_info_cached.clear
|
|
|
def fetch_dataset_info(dataset_id):
    """Query the Hugging Face datasets API for *dataset_id*.

    Returns the decoded JSON body when the response status is 200. Any other
    status yields None; any exception is reported with st.warning and also
    yields None.
    """
    info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
    payload = None
    try:
        response = requests.get(info_url, timeout=30)
        if response.status_code == 200:
            payload = response.json()
    except Exception as e:
        st.warning(f"Error fetching dataset info: {e}")
    return payload
|
|
|
def generate_filename(text):
    """Build a "<timestamp>_<slug>" name from the first 50 characters of *text*.

    The slug drops every character that is not a word character, whitespace
    or hyphen, is trimmed and lower-cased, and has each run of hyphens and
    whitespace collapsed into a single hyphen. The timestamp is the current
    local time formatted as YYYYMMDD_HHMMSS.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    slug = re.sub(r'[^\w\s-]', '', text[:50])
    slug = slug.strip().lower()
    slug = re.sub(r'[-\s]+', '-', slug)

    return "_".join((stamp, slug))
|
|
|
def render_result(result):
    """Render a single search result with a relevance metric and read-aloud controls.

    Args:
        result: A mapping-like row (for example a pandas Series produced by
            DataFrame.iterrows()) supporting .get(), .items() and ``in``.

    The relevance shown is taken from 'relevance_score' when present and
    otherwise from 'score', which is the column name
    FastDatasetSearcher.quick_search writes. Embedding columns and the
    bookkeeping columns added by the search are not listed as fields.
    """
    # Previously only 'relevance_score' was read, which quick_search never
    # sets, so the metric always showed 0.00%.
    score = result.get('relevance_score', result.get('score', 0))

    hidden_keys = {
        'relevance_score', 'score', 'matched',
        'video_embed', 'description_embed', 'audio_embed',
    }
    result_filtered = {k: v for k, v in result.items() if k not in hidden_keys}

    if 'youtube_id' in result:
        st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")

    cols = st.columns([2, 1])
    with cols[0]:
        # Non-empty string fields are collected so they can be read aloud.
        text_content = []
        for key, value in result_filtered.items():
            if isinstance(value, (str, int, float)):
                st.write(f"**{key}:** {value}")
                if isinstance(value, str) and len(value.strip()) > 0:
                    text_content.append(f"{key}: {value}")

    with cols[1]:
        st.metric("Relevance", f"{score:.2%}")

        voices = {
            "Aria (US Female)": "en-US-AriaNeural",
            "Guy (US Male)": "en-US-GuyNeural",
            "Sonia (UK Female)": "en-GB-SoniaNeural",
            "Tony (UK Male)": "en-GB-TonyNeural"
        }

        # The row label (Series.name) is appended so that rows without a
        # 'video_id' do not all end up sharing one widget key.
        widget_id = f"{result.get('video_id', '')}_{getattr(result, 'name', '')}"

        selected_voice = st.selectbox(
            "Voice:",
            list(voices.keys()),
            key=f"voice_{widget_id}"
        )

        if st.button("π Read", key=f"read_{widget_id}"):
            text_to_read = ". ".join(text_content)
            audio_file = speak_with_edge_tts(text_to_read, voices[selected_voice])
            if audio_file:
                play_and_download_audio(audio_file)
|
|
|
class FastDatasetSearcher:
    """Search pages of a dataset by keyword overlap and embedding similarity."""

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Set up the model and access token, and load dataset info once.

        Reads the token from the DATASET_KEY environment variable; when it is
        missing an error is shown and the Streamlit run is stopped. Dataset
        info is fetched only while st.session_state['dataset_info'] is None.
        """
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable")
            st.stop()

        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Return page number *page* (ROWS_PER_PAGE rows) as a DataFrame."""
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    def quick_search(self, query, df):
        """Rank the rows of *df* against *query*.

        Each row's text is the space-joined string form of its 'description'
        and 'matched_text' columns (when present) followed by every other
        column whose first value is not an ndarray or bytes. A row's score is
        0.7 * (fraction of query words found among the row's words)
        + 0.3 * (cosine similarity of query and row embeddings), multiplied by
        EXACT_MATCH_BOOST when the whole query appears as a run of consecutive
        words in one field, or by 1.2 when at least one query word appears.

        Args:
            query: Search text.
            df: Page of data to search.

        Returns:
            A copy of the rows that matched a query word or scored above
            MIN_SEARCH_SCORE, with added 'score' and 'matched' columns,
            sorted by descending score. *df* itself is returned unchanged
            when it is empty, the query is blank, or the search raises (the
            error is reported with st.error).
        """
        if df.empty or not query.strip():
            return df

        try:
            # Columns whose first value is an array or raw bytes are skipped.
            searchable_cols = [
                col for col in df.columns
                if not isinstance(df[col].iloc[0], (np.ndarray, bytes))
            ]
            priority_fields = ['description', 'matched_text']
            # Priority fields come first in the row text; computed once
            # instead of once per row.
            ordered_cols = [col for col in priority_fields if col in df.columns]
            ordered_cols += [col for col in searchable_cols if col not in priority_fields]

            normalized_query = ' '.join(query.lower().split())
            query_terms = set(normalized_query.split())
            # Space padding makes the substring test below match whole words
            # only. For a one-word query this equals the former token
            # membership test; for multi-word or padded queries, which could
            # never equal a single token, it enables a phrase match.
            query_phrase = f" {normalized_query} "
            query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]

            row_texts = []
            row_tokens = []
            exact_flags = []
            for _, row in df.iterrows():
                parts = []
                tokens_seen = set()
                exact_match = False
                for col in ordered_cols:
                    val = row[col]
                    if val is None:
                        continue
                    val_text = str(val)
                    tokens = val_text.lower().split()
                    if query_phrase in f" {' '.join(tokens)} ":
                        exact_match = True
                    tokens_seen.update(tokens)
                    parts.append(val_text)
                row_texts.append(' '.join(parts))
                row_tokens.append(tokens_seen)
                exact_flags.append(exact_match)

            scores = [0.0] * len(row_texts)
            matched_any = [False] * len(row_texts)

            # Rows with no text keep score 0.0 and matched False; all other
            # rows are embedded in one batched model call.
            scored_rows = [i for i, text in enumerate(row_texts) if text.strip()]
            if scored_rows:
                embeddings = self.text_model.encode(
                    [row_texts[i] for i in scored_rows],
                    show_progress_bar=False
                )
                similarities = cosine_similarity([query_embedding], embeddings)[0]

                for i, similarity in zip(scored_rows, similarities):
                    matching_terms = query_terms.intersection(row_tokens[i])
                    keyword_score = len(matching_terms) / len(query_terms)
                    combined_score = 0.7 * keyword_score + 0.3 * float(similarity)

                    row_matched = bool(matching_terms)
                    if exact_flags[i]:
                        combined_score *= EXACT_MATCH_BOOST
                    elif row_matched:
                        combined_score *= 1.2

                    scores[i] = combined_score
                    matched_any[i] = row_matched

            results_df = df.copy()
            results_df['score'] = scores
            results_df['matched'] = matched_any

            filtered_df = results_df[
                (results_df['matched']) |
                (results_df['score'] > MIN_SEARCH_SCORE)
            ]

            return filtered_df.sort_values('score', ascending=False)

        except Exception as e:
            # UI boundary: report the problem and fall back to the raw page.
            st.error(f"Search error: {str(e)}")
            return df
|
|
|
def _rerun():
    """Restart the script run with whichever rerun API this Streamlit offers."""
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _render_results(rows, limit=None):
    """Render (label, row) pairs in numbered expanders, at most *limit* of them."""
    for i, (_, row) in enumerate(rows, 1):
        if limit is not None and i > limit:
            break
        with st.expander(f"Result {i}", expanded=(i == 1)):
            render_result(row)


def _render_history_row(row):
    """Write a stored result's str/int/float fields as plain text.

    render_result is deliberately not used here: it creates widgets keyed by
    the row's video_id, and the same row can already be on the page in the
    Search tab or in another history entry.
    """
    for key, value in row.items():
        if isinstance(value, (str, int, float)):
            st.write(f"**{key}:** {value}")


def main():
    """Build the page: voice input, the four tabs, and the sidebar metrics.

    NOTE(review): search results are drawn only during the script run that
    triggered the search, so interacting with a widget inside a result starts
    a new run in which that result list is no longer drawn. Keeping results
    in st.session_state would address this; left unchanged here.
    """
    st.title("π₯ Smart Video & Voice Search")

    voice_component = create_voice_component()
    search = FastDatasetSearcher()

    voice_val = voice_component(my_input_value="Start speaking...")

    if voice_val:
        voice_text = str(voice_val).strip()
        edited_input = st.text_area("βοΈ Edit Voice Input:", value=voice_text, height=100)

        run_option = st.selectbox("Select Search Type:",
                                  ["Quick Search", "Deep Search", "Voice Summary"])

        col1, col2 = st.columns(2)
        with col1:
            autorun = st.checkbox("β‘ Auto-Run", value=False)
        with col2:
            # The checkbox is shown but its value is not consumed anywhere.
            st.checkbox("π Full Audio", value=False)

        input_changed = (voice_text != st.session_state.get('old_val'))

        if autorun and input_changed:
            st.session_state['old_val'] = voice_text
            with st.spinner("Processing voice input..."):
                if run_option == "Quick Search":
                    results = search.quick_search(edited_input, search.load_page())
                    _render_results(results.iterrows())

                elif run_option == "Deep Search":
                    with st.spinner("Performing deep search..."):
                        results = []
                        for page in range(3):
                            df = search.load_page(page)
                            results.extend(search.quick_search(edited_input, df).iterrows())
                        # Each page comes back ranked on its own; re-rank the
                        # merged list so the best hit overall is listed first.
                        results.sort(key=lambda item: item[1].get('score', 0), reverse=True)

                    _render_results(results)

                elif run_option == "Voice Summary":
                    audio_file = speak_with_edge_tts(edited_input)
                    if audio_file:
                        play_and_download_audio(audio_file)

        # Explicit key: the Search tab has a button with the same label, and
        # two key-less buttons with identical labels collide.
        elif st.button("π Search", key="voice_search_button"):
            st.session_state['old_val'] = voice_text
            with st.spinner("Processing..."):
                results = search.quick_search(edited_input, search.load_page())
                _render_results(results.iterrows())

    tab1, tab2, tab3, tab4 = st.tabs([
        "π Search", "ποΈ Voice", "πΎ History", "βοΈ Settings"
    ])

    with tab1:
        st.subheader("π Search")
        col1, col2 = st.columns([3, 1])
        with col1:
            query = st.text_input("Enter search query:", value="")
        with col2:
            # The selection is shown but not applied to the search.
            st.selectbox("Search in:",
                         ["All Fields"] + st.session_state['search_columns'])

        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Max results:", 1, 100, 20)
        with col4:
            search_button = st.button("π Search", key="tab_search_button")

        if (search_button or not st.session_state['initial_search_done']) and query:
            st.session_state['initial_search_done'] = True

            with st.spinner("Searching..."):
                df = search.load_page()
                results = search.quick_search(query, df)

                if len(results) > 0:
                    st.session_state['search_history'].append({
                        'query': query,
                        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        'results': results[:5]
                    })

                    st.write(f"Found {len(results)} results:")
                    _render_results(results.iterrows(), limit=num_results)
                else:
                    st.warning("No matching results found.")

    with tab2:
        st.subheader("ποΈ Voice Input")
        st.write("Use the voice input above to start speaking, or record a new message:")

        col1, col2 = st.columns(2)
        with col1:
            if st.button("ποΈ Start New Recording"):
                st.session_state['recording'] = True
                _rerun()
        with col2:
            if st.button("π Stop Recording"):
                st.session_state['recording'] = False
                _rerun()

        if st.session_state.get('recording', False):
            new_val = voice_component(my_input_value="Recording...")
            if new_val:
                st.text_area("Recorded Text:", value=new_val, height=100)
                if st.button("π Search with Recording"):
                    with st.spinner("Processing recording..."):
                        df = search.load_page()
                        results = search.quick_search(new_val, df)
                        _render_results(results.iterrows())

    with tab3:
        st.subheader("πΎ Search History")
        if not st.session_state['search_history']:
            st.info("No search history yet. Try searching for something!")
        else:
            for entry in reversed(st.session_state['search_history']):
                with st.expander(f"π {entry['timestamp']} - {entry['query']}", expanded=False):
                    stored = entry['results']
                    # History entries hold a DataFrame slice; iterating a
                    # DataFrame directly yields column names, so walk rows.
                    if isinstance(stored, pd.DataFrame):
                        stored = [row for _, row in stored.iterrows()]
                    for i, result in enumerate(stored, 1):
                        st.write(f"**Result {i}:**")
                        if isinstance(result, pd.Series):
                            _render_history_row(result)
                        else:
                            st.write(result)

    with tab4:
        st.subheader("βοΈ Settings")
        st.write("Voice Settings:")
        # The chosen value stays available as st.session_state["default_voice"].
        st.selectbox(
            "Default Voice:",
            [
                "en-US-AriaNeural",
                "en-US-GuyNeural",
                "en-GB-SoniaNeural",
                "en-GB-TonyNeural"
            ],
            index=0,
            key="default_voice"
        )

        st.write("Search Settings:")
        # These sliders are display-only: their values are not fed back into
        # MIN_SEARCH_SCORE or EXACT_MATCH_BOOST.
        st.slider("Minimum Search Score:", 0.0, 1.0, MIN_SEARCH_SCORE, 0.1)
        st.slider("Exact Match Boost:", 1.0, 3.0, EXACT_MATCH_BOOST, 0.1)

        if st.button("ποΈ Clear Search History"):
            st.session_state['search_history'] = []
            st.success("Search history cleared!")
            _rerun()

    with st.sidebar:
        st.subheader("π Search Metrics")
        total_searches = len(st.session_state['search_history'])
        st.metric("Total Searches", total_searches)

        if total_searches > 0:
            recent_searches = st.session_state['search_history'][-5:]
            st.write("Recent Searches:")
            for entry in reversed(recent_searches):
                st.write(f"π {entry['query']}")


if __name__ == "__main__":
    main()