CodeCompetitionClaudeVsGPT / backup6.app.py
awacke1's picture
Rename app.py to backup6.app.py
8bcaf1e verified
raw
history blame
25.3 kB
import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import torch
import json
import os
import glob
from pathlib import Path
from datetime import datetime
import edge_tts
import asyncio
import requests
from collections import defaultdict
from audio_recorder_streamlit import audio_recorder
import streamlit.components.v1 as components
from urllib.parse import quote
from xml.etree import ElementTree as ET
from datasets import load_dataset
# 🧠 Session-state schema: each key the app keeps in st.session_state,
# mapped to the value it starts with.
SESSION_VARS = {
    'search_history': [],              # Track search history
    'last_voice_input': "",            # Last voice input
    'transcript_history': [],          # Conversation history
    'should_rerun': False,             # Trigger for UI updates
    'search_columns': [],              # Available search columns
    'initial_search_done': False,      # First search flag
    'tts_voice': "en-US-AriaNeural",   # Default voice
    'arxiv_last_query': "",            # Last ArXiv search
    'dataset_loaded': False,           # Dataset load status
    'current_page': 0,                 # Current data page
    'data_cache': None,                # Data cache
    'dataset_info': None,              # Dataset metadata
}

# Constants
ROWS_PER_PAGE = 100  # rows requested per dataset page

# Seed the session state: keys that already exist are left untouched so
# values written earlier in the session are not overwritten.
for state_key, initial_value in SESSION_VARS.items():
    if state_key in st.session_state:
        continue
    st.session_state[state_key] = initial_value
@st.cache_resource
def get_model():
    """Return the ``all-MiniLM-L6-v2`` sentence encoder.

    The function is wrapped in ``st.cache_resource`` so the model object is
    built once and then reused by later calls.
    """
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    return encoder
@st.cache_data
def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load one fixed-size slice of the ``train`` split as a DataFrame.

    Args:
        dataset_id: Dataset identifier passed to ``load_dataset``.
        token: Access token passed to ``load_dataset``.
        page: Zero-based page number.
        rows_per_page: Number of rows in each page.

    Returns:
        A DataFrame built from rows ``page * rows_per_page`` up to (but not
        including) the start of the next page. On any failure the error is
        shown with ``st.error`` and an empty DataFrame is returned.
    """
    try:
        first_row = page * rows_per_page
        split_slice = f'train[{first_row}:{first_row + rows_per_page}]'
        page_rows = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=split_slice,
        )
        frame = pd.DataFrame(page_rows)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        frame = pd.DataFrame()
    return frame
@st.cache_data
def get_dataset_info(dataset_id, token):
    """Return the ``info`` object of the dataset's ``train`` split.

    The dataset is opened in streaming mode. If anything fails, the error is
    shown with ``st.error`` and ``None`` is returned.
    """
    train_info = None
    try:
        streamed = load_dataset(dataset_id, token=token, streaming=True)
        train_info = streamed['train'].info
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
    return train_info
def fetch_dataset_info(dataset_id):
    """Fetch dataset metadata from the Hugging Face Hub HTTP API.

    Returns:
        The decoded JSON body when the server answers 200; ``None`` for any
        other status or when the request/decoding raises (in which case a
        warning is shown with ``st.warning``).
    """
    info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
    metadata = None
    try:
        response = requests.get(info_url, timeout=30)
        if response.status_code == 200:
            metadata = response.json()
    except Exception as e:
        st.warning(f"Error fetching dataset info: {e}")
    return metadata
def fetch_dataset_rows(dataset_id, config="default", split="train", max_rows=100):
    """Fetch the preview rows of one dataset split from the datasets-server.

    Args:
        dataset_id: Dataset identifier, e.g. ``"user/name"``.
        config: Dataset configuration name.
        split: Split name.
        max_rows: Upper bound on the number of rows returned; ``None`` means
            no bound. (Previously accepted but ignored.)

    Returns:
        A list of row dicts. Each row gains ``_config`` and ``_split`` keys,
        and string values in columns whose name contains ``embed``,
        ``vector`` or ``encoding`` are parsed into lists of floats when they
        look like ``"[1.0, 2.0, ...]"``. Returns ``[]`` on a non-200 answer,
        a body without ``rows``, or any request error (a warning is shown).
    """
    url = "https://datasets-server.huggingface.co/first-rows"
    # Let requests build the query string so ids containing '/', spaces or
    # '&' are percent-encoded instead of corrupting the URL.
    query_params = {"dataset": dataset_id, "config": config, "split": split}
    embedding_terms = ('embed', 'vector', 'encoding')

    try:
        response = requests.get(url, params=query_params, timeout=30)
        if response.status_code != 200:
            return []
        data = response.json()
        if 'rows' not in data:
            return []

        processed_rows = []
        for row_data in data['rows'][:max_rows]:
            # Rows normally arrive wrapped as {"row": {...}}; fall back to
            # the item itself when that wrapper is absent.
            row = row_data.get('row', row_data)
            for key, value in row.items():
                if not isinstance(value, str):
                    continue
                if not any(term in key.lower() for term in embedding_terms):
                    continue
                try:
                    # Replacing the value of an existing key is safe while
                    # iterating: the dict does not change size.
                    row[key] = [
                        float(part.strip())
                        for part in value.strip('[]').split(',')
                        if part.strip()
                    ]
                except ValueError:
                    # Not a numeric list after all: keep the original text.
                    pass
            row['_config'] = config
            row['_split'] = split
            processed_rows.append(row)
        return processed_rows
    except Exception as e:
        st.warning(f"Error fetching rows: {e}")
    return []
class FastDatasetSearcher:
    """Load pages of a token-protected dataset and rank rows against a query.

    Ranking combines exact token overlap with sentence-embedding similarity
    computed by the shared model returned from ``get_model()``.
    """

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Set up the searcher.

        Reads the access token from the ``DATASET_KEY`` environment variable;
        if it is missing, an error is shown and the Streamlit script is
        stopped. Dataset metadata is fetched once and kept in
        ``st.session_state['dataset_info']``.
        """
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable")
            st.stop()
        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        """Return page ``page`` (``ROWS_PER_PAGE`` rows) as a DataFrame."""
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    def quick_search(self, query, df):
        """Enhanced search with strict token matching and semantic relevance.

        Args:
            query: Free-text query.
            df: DataFrame of candidate rows.

        Returns:
            ``df`` itself when it is empty or the query is blank. Otherwise a
            copy with added ``score`` and ``matched`` columns, restricted to
            rows that share a token with the query or score above the
            threshold, sorted by descending score. On any error the message
            is shown with ``st.error`` and the unfiltered ``df`` is returned.
        """
        if df.empty or not query.strip():
            return df
        try:
            # Define stricter thresholds
            MIN_SEMANTIC_SCORE = 0.5  # Higher semantic threshold
            EXACT_MATCH_BOOST = 2.0   # Boost for exact matches

            # A column is searchable unless its first value is raw
            # array/bytes data.
            searchable_cols = [
                col for col in df.columns
                if not isinstance(df[col].iloc[0], (np.ndarray, bytes))
            ]

            # Prioritize description and matched_text fields: they are read
            # first (when present), followed by every other searchable
            # column. The order only affects the concatenated text.
            priority_fields = ['description', 'matched_text']
            ordered_fields = (
                [col for col in priority_fields if col in df.columns]
                + [col for col in searchable_cols if col not in priority_fields]
            )

            query_lower = query.lower()
            query_terms = set(query_lower.split())
            query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]

            scores = []
            matched_any = []
            for _, row in df.iterrows():
                text_parts = []
                row_matched = False
                exact_match = False
                for col in ordered_fields:
                    val = row[col]
                    if val is None:
                        continue
                    tokens = str(val).lower().split()
                    # The whole query must itself be one token of the field,
                    # so only single-word queries can be "exact" matches.
                    if query_lower in tokens:
                        exact_match = True
                    if not query_terms.isdisjoint(tokens):
                        row_matched = True
                    text_parts.append(str(val))

                text = ' '.join(text_parts)
                if text.strip():
                    # Fraction of query terms that appear as tokens.
                    text_tokens = set(text.lower().split())
                    matching_terms = query_terms.intersection(text_tokens)
                    keyword_score = len(matching_terms) / len(query_terms)

                    # Calculate semantic score
                    text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
                    semantic_score = float(
                        cosine_similarity([query_embedding], [text_embedding])[0][0]
                    )

                    # Weighted scoring with priority for exact matches
                    combined_score = 0.8 * keyword_score + 0.2 * semantic_score
                    if exact_match:
                        combined_score *= EXACT_MATCH_BOOST
                    elif row_matched:
                        combined_score *= 1.2
                else:
                    combined_score = 0.0
                    row_matched = False
                scores.append(combined_score)
                matched_any.append(row_matched)

            results_df = df.copy()
            results_df['score'] = scores
            results_df['matched'] = matched_any

            # Filter relevant results. The original compared against an
            # undefined name (MIN_KEYWORD_MATCHES), so every search raised
            # NameError and fell into the handler below; the threshold
            # defined above is used instead.
            # NOTE(review): confirm MIN_SEMANTIC_SCORE is the intended cutoff
            # for the combined score.
            filtered_df = results_df[
                (results_df['matched'])                       # Include direct matches
                | (results_df['score'] > MIN_SEMANTIC_SCORE)  # Or high relevance
            ]
            return filtered_df.sort_values('score', ascending=False)
        except Exception as e:
            st.error(f"Search error: {str(e)}")
            return df
class VideoSearch:
    """Semantic search over rows of the ``omegalabsinc/omega-multimodal`` dataset.

    Rows are fetched through ``search_dataset``; a single built-in example
    row is used when nothing can be loaded. Embedding columns found in the
    data are turned into matrices that ``search`` compares with the query
    embedding.
    """

    # Raw embedding columns: used for scoring, never shown to the user.
    EMBEDDING_COLUMNS = ('video_embed', 'description_embed', 'audio_embed')
    # Substrings that mark a column as holding embedding vectors.
    EMBEDDING_TERMS = ('embed', 'vector', 'encoding')
    # Width of the random stand-in matrices used when no embeddings exist.
    FALLBACK_EMBED_DIM = 384

    def __init__(self):
        """Attach the shared text encoder and load the dataset."""
        # Use the cached encoder instead of constructing a new model for
        # every instance.
        self.text_model = get_model()
        self.dataset_id = "omegalabsinc/omega-multimodal"
        self.load_dataset()

    def fetch_dataset_rows(self):
        """Return the dataset rows as a DataFrame.

        Also records the displayable column names in
        ``st.session_state['search_columns']``. Falls back to
        ``load_example_data()`` when no rows come back or loading raises.
        """
        try:
            df, configs, splits = search_dataset(
                self.dataset_id,
                "",
                include_configs=None,
                include_splits=None
            )
            if not df.empty:
                st.session_state['search_columns'] = [
                    col for col in df.columns
                    if col not in self.EMBEDDING_COLUMNS and not col.startswith('_')
                ]
                return df
            return self.load_example_data()
        except Exception as e:
            st.warning(f"Error loading videos: {e}")
            return self.load_example_data()

    def load_example_data(self):
        """Return a one-row DataFrame used when the real dataset is unavailable."""
        example_data = [{
            "video_id": "sample-123",
            "youtube_id": "dQw4w9WgXcQ",
            "description": "An example video",
            "views": 12345,
            "start_time": 0,
            "end_time": 60
        }]
        return pd.DataFrame(example_data)

    def load_dataset(self):
        """Fetch the rows and build the embedding matrices."""
        self.dataset = self.fetch_dataset_rows()
        self.prepare_features()

    def _parse_embedding_column(self, cells):
        """Return a 2-D float matrix with one row per cell, or ``None``.

        ``None`` is returned when any cell has no usable vector or the
        vectors do not form a rectangular numeric matrix.
        """
        vectors = []
        try:
            for cell in cells:
                if isinstance(cell, str):
                    vectors.append([
                        float(part.strip())
                        for part in cell.strip('[]').split(',')
                        if part.strip()
                    ])
                elif isinstance(cell, (list, tuple, np.ndarray)):
                    vectors.append(cell)
                else:
                    # Skipping this row would shift every later row, so the
                    # matrix would no longer line up with the dataset.
                    return None
            matrix = np.array(vectors, dtype=float)
        except (ValueError, TypeError):
            return None
        if matrix.ndim != 2 or matrix.shape[0] == 0:
            return None
        return matrix

    def prepare_features(self):
        """Populate ``self.video_embeds`` and ``self.text_embeds``.

        ``video_embeds`` is the ``video_embed`` column when available,
        otherwise the first usable embedding column. ``text_embeds`` is
        ``description_embed`` when available, otherwise ``video_embeds``.
        When no usable embedding column exists (for instance with the
        example data), both become random matrices with one row per dataset
        row so that ``search`` never receives ``None``.
        """
        num_rows = len(self.dataset)
        embeddings = {}
        for col in self.dataset.columns:
            if not any(term in str(col).lower() for term in self.EMBEDDING_TERMS):
                continue
            matrix = self._parse_embedding_column(self.dataset[col])
            # Only keep matrices that align row-for-row with the dataset.
            if matrix is not None and matrix.shape[0] == num_rows:
                embeddings[col] = matrix

        video_embeds = embeddings.get('video_embed')
        if video_embeds is None and embeddings:
            video_embeds = next(iter(embeddings.values()))

        if video_embeds is None:
            self.video_embeds = np.random.randn(num_rows, self.FALLBACK_EMBED_DIM)
            self.text_embeds = np.random.randn(num_rows, self.FALLBACK_EMBED_DIM)
            return

        self.video_embeds = video_embeds
        self.text_embeds = embeddings.get('description_embed', video_embeds)

    def search(self, query, column=None, top_k=20):
        """Enhanced search with better relevance scoring.

        Args:
            query: Free-text query.
            column: Optional column name; rows whose value in that column
                contains the query (case-insensitive, literal match) get
                their score multiplied by 1.5. ``None`` or ``"All Fields"``
                disables the boost.
            top_k: Maximum number of results.

        Returns:
            A list of dicts, best first, each holding ``relevance_score``
            plus every non-embedding column of the row. Empty when nothing
            reaches the minimum relevance.
        """
        MIN_RELEVANCE = 0.3  # Minimum relevance threshold

        # With top_k == 0 the slice below would be [-0:], i.e. everything.
        if len(self.dataset) == 0 or top_k <= 0:
            return []

        query_embedding = self.text_model.encode([query])[0]
        video_sims = cosine_similarity([query_embedding], self.video_embeds)[0]
        text_sims = cosine_similarity([query_embedding], self.text_embeds)[0]
        combined_sims = 0.7 * text_sims + 0.3 * video_sims  # Favor text matches

        if column and column in self.dataset.columns and column != "All Fields":
            # regex=False: the query is user text, not a pattern, so
            # characters such as "(" or "+" must match literally.
            matches = (
                self.dataset[column]
                .astype(str)
                .str.contains(query, case=False, regex=False)
                .to_numpy(dtype=bool)
            )
            combined_sims[matches] *= 1.5  # Boost exact matches

        # Filter by minimum relevance
        relevant_indices = np.where(combined_sims >= MIN_RELEVANCE)[0]
        if len(relevant_indices) == 0:
            return []

        top_k = min(top_k, len(relevant_indices))
        order = np.argsort(combined_sims[relevant_indices])[-top_k:][::-1]
        top_indices = relevant_indices[order]

        visible_cols = [
            col for col in self.dataset.columns
            if col not in self.EMBEDDING_COLUMNS
        ]
        results = []
        for idx in top_indices:
            row = self.dataset.iloc[idx]  # fetch the row once, not per column
            result = {'relevance_score': float(combined_sims[idx])}
            for col in visible_cols:
                result[col] = row[col]
            results.append(result)
        return results
def search_dataset(dataset_id, search_text, include_configs=None, include_splits=None):
    """Collect preview rows of a dataset that contain ``search_text``.

    For every configuration the available splits are looked up, the preview
    rows of each split are fetched with ``fetch_dataset_rows``, and rows
    whose string/int/float values (joined by spaces) contain ``search_text``
    case-insensitively are kept.

    Args:
        dataset_id: Dataset identifier.
        search_text: Substring to look for.
        include_configs: Optional list of configurations to use instead of
            the ones named in the dataset metadata.
        include_splits: Optional list restricting which splits are read.

    Returns:
        ``(df, configs, splits)``. ``df`` holds the matching rows with added
        ``_matched_text`` and ``_relevance_score`` (occurrence count)
        columns, sorted by descending score; it is empty when nothing
        matched. When the dataset metadata cannot be fetched the result is
        ``(empty DataFrame, [], [])``.
    """
    dataset_info = fetch_dataset_info(dataset_id)
    if not dataset_info:
        return pd.DataFrame(), [], []

    configs = include_configs or dataset_info.get('config_names', ['default'])
    matched_rows = []
    seen_splits = set()

    for config in configs:
        try:
            splits_url = f"https://datasets-server.huggingface.co/splits?dataset={dataset_id}&config={config}"
            splits_response = requests.get(splits_url, timeout=30)
            if splits_response.status_code != 200:
                continue

            listed = splits_response.json().get('splits', [])
            splits = [item['split'] for item in listed] or ['train']
            if include_splits:
                splits = [name for name in splits if name in include_splits]
            seen_splits.update(splits)

            for split in splits:
                for row in fetch_dataset_rows(dataset_id, config, split):
                    scalar_values = [
                        str(value) for value in row.values()
                        if isinstance(value, (str, int, float))
                    ]
                    text_content = ' '.join(scalar_values)
                    needle = search_text.lower()
                    haystack = text_content.lower()
                    if needle not in haystack:
                        continue
                    row['_matched_text'] = text_content
                    row['_relevance_score'] = haystack.count(needle)
                    matched_rows.append(row)
        except Exception as e:
            st.warning(f"Error processing config {config}: {e}")
            continue

    if not matched_rows:
        return pd.DataFrame(), configs, list(seen_splits)

    df = pd.DataFrame(matched_rows).sort_values('_relevance_score', ascending=False)
    return df, configs, list(seen_splits)
@st.cache_resource
def get_speech_model():
    """Return the ``edge_tts.Communicate`` class (not an instance).

    Callers instantiate it themselves with the text and voice to speak.
    """
    communicate_cls = edge_tts.Communicate
    return communicate_cls
async def generate_speech(text, voice=None):
    """Synthesize ``text`` to an MP3 file and return the file name.

    Args:
        text: Text to speak. Blank text produces no file.
        voice: Voice name; when falsy, ``st.session_state['tts_voice']`` is
            used.

    Returns:
        The name of the written file (``speech_<YYYYmmdd_HHMMSS>.mp3`` in
        the working directory), or ``None`` when the text is blank or
        synthesis fails (the error is shown with ``st.error``).
    """
    if not text.strip():
        return None

    chosen_voice = voice or st.session_state['tts_voice']
    try:
        communicate = get_speech_model()(text, chosen_voice)
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        audio_file = f"speech_{stamp}.mp3"
        await communicate.save(audio_file)
    except Exception as e:
        st.error(f"Error generating speech: {e}")
        return None
    else:
        return audio_file
def transcribe_audio(audio_path):
    """Placeholder for ASR: ignores ``audio_path`` and returns a fixed notice."""
    placeholder_message = "ASR not implemented. Add your preferred speech recognition here!"
    return placeholder_message
def arxiv_search(query, max_results=5):
    """Query the arXiv Atom API and return the matching papers.

    Args:
        query: Search expression; it is percent-encoded before being sent.
        max_results: Maximum number of entries requested.

    Returns:
        A list of ``(title, summary, link)`` tuples, where ``link`` is the
        ``href`` of the entry's ``text/html`` link or ``None``. Returns
        ``[]`` on a non-200 answer or on any request/parsing error (the
        error is shown with ``st.error``).
    """
    base_url = "http://export.arxiv.org/api/query?"
    search_url = base_url + f"search_query={quote(query)}&start=0&max_results={max_results}"
    ns = {'atom': 'http://www.w3.org/2005/Atom'}
    try:
        # Same 30 s timeout as every other request in this file, so a
        # stalled connection cannot hang the app.
        r = requests.get(search_url, timeout=30)
        if r.status_code != 200:
            return []

        root = ET.fromstring(r.text)
        results = []
        for entry in root.findall('atom:entry', ns):
            # findtext() yields the default for a missing element, so one
            # incomplete entry no longer discards the whole result list.
            title = (entry.findtext('atom:title', default='', namespaces=ns) or '').strip()
            summary = (entry.findtext('atom:summary', default='', namespaces=ns) or '').strip()
            link = next(
                (link_el.get('href') for link_el in entry.findall('atom:link', ns)
                 if link_el.get('type') == 'text/html'),
                None,
            )
            results.append((title, summary, link))
        return results
    except Exception as e:
        st.error(f"ArXiv search error: {e}")
    return []
def show_file_manager():
    """Render the file manager: upload, list, preview and delete files.

    Operates on ``*.txt``, ``*.md`` and ``*.mp3`` files in the current
    working directory. "Clear Files" removes every such file there.
    """
    # st.experimental_rerun is deprecated in favour of st.rerun; use
    # whichever this Streamlit version provides.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    managed_patterns = ("*.txt", "*.md", "*.mp3")

    def list_managed_files():
        """Return every file in the working directory matching the patterns."""
        return [path for pattern in managed_patterns for path in glob.glob(pattern)]

    st.subheader("πŸ“‚ File Manager")
    col1, col2 = st.columns(2)
    with col1:
        uploaded_file = st.file_uploader("Upload File", type=['txt', 'md', 'mp3'])
        if uploaded_file:
            # Keep only the final path component so an upload name cannot
            # point outside the working directory.
            safe_name = os.path.basename(uploaded_file.name)
            payload = uploaded_file.getvalue()
            # The uploader keeps returning the same file on every rerun.
            # Saving and rerunning unconditionally would loop forever (and
            # would resurrect a file the user just deleted), so remember
            # which upload has already been written.
            upload_id = getattr(uploaded_file, 'file_id', None) or getattr(uploaded_file, 'id', None)
            upload_marker = (upload_id, safe_name, len(payload))
            if safe_name and st.session_state.get('_last_saved_upload') != upload_marker:
                with open(safe_name, "wb") as f:
                    f.write(payload)
                st.session_state['_last_saved_upload'] = upload_marker
                st.success(f"Uploaded: {safe_name}")
                rerun()
    with col2:
        if st.button("πŸ—‘ Clear Files"):
            for f in list_managed_files():
                os.remove(f)
            st.success("All files cleared!")
            rerun()

    files = list_managed_files()
    if files:
        st.write("### Existing Files")
        for f in files:
            with st.expander(f"πŸ“„ {os.path.basename(f)}"):
                if f.endswith('.mp3'):
                    st.audio(f)
                else:
                    # errors='replace': a file that is not valid UTF-8 is
                    # still previewed instead of crashing the page.
                    with open(f, 'r', encoding='utf-8', errors='replace') as file:
                        content = file.read()
                    # Two files with identical content would otherwise
                    # produce two identical keyless widgets, which Streamlit
                    # rejects. The content hash keeps the preview in sync
                    # when a file is overwritten.
                    st.text_area("Content", content, height=100,
                                 key=f"content_{f}_{hash(content)}")
                if st.button(f"Delete {os.path.basename(f)}", key=f"del_{f}"):
                    os.remove(f)
                    rerun()
def perform_arxiv_lookup(query, vocal_summary=True, titles_summary=True, full_audio=False):
    """Run an arXiv search, render the results, and optionally speak them.

    Args:
        query: Search text passed to ``arxiv_search`` (at most 5 results).
        vocal_summary: When true, play a short spoken summary.
        titles_summary: With ``vocal_summary``, speak all titles; otherwise
            speak the first 200 characters of the first summary.
        full_audio: When true, also play a reading of every title and
            summary.
    """
    results = arxiv_search(query, max_results=5)
    if not results:
        st.write("No results found.")
        return

    st.markdown(f"**ArXiv Results for '{query}':**")
    for position, paper in enumerate(results, start=1):
        paper_title, paper_summary, paper_link = paper
        st.markdown(f"**{position}. {paper_title}**")
        st.write(paper_summary)
        if link := paper_link:
            st.markdown(f"[View Paper]({link})")

    if vocal_summary:
        intro = f"Here are ArXiv results for {query}. "
        if titles_summary:
            detail = " Titles: " + ", ".join(title for title, _, _ in results)
        else:
            first_summary = results[0][1]
            detail = " " + first_summary[:200]
        audio_file = asyncio.run(generate_speech(intro + detail))
        if audio_file:
            st.audio(audio_file)

    if full_audio:
        full_text = "".join(
            f"Result {position}: {title}. {summary} "
            for position, (title, summary, _) in enumerate(results, start=1)
        )
        audio_file_full = asyncio.run(generate_speech(full_text))
        if audio_file_full:
            st.write("### Full Audio Summary")
            st.audio(audio_file_full)
def render_result(result):
    """Render a search result with voice selection and TTS options.

    Shows the YouTube video when the result has a ``youtube_id``, lists the
    scalar fields in the left column, and puts the relevance score plus a
    voice picker and read-aloud button in the right column.
    """
    hidden_keys = ('relevance_score', 'video_embed', 'description_embed', 'audio_embed')
    score = result.get('relevance_score', 0)
    display_fields = {
        name: value for name, value in result.items() if name not in hidden_keys
    }

    if 'youtube_id' in result:
        st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")

    details_col, controls_col = st.columns([2, 1])

    with details_col:
        text_content = []  # Non-empty string fields, read aloud on request
        for name, value in display_fields.items():
            if not isinstance(value, (str, int, float)):
                continue
            st.write(f"**{name}:** {value}")
            if isinstance(value, str) and value.strip():
                text_content.append(f"{name}: {value}")

    with controls_col:
        st.metric("Relevance Score", f"{score:.2%}")

        # Voice selection for TTS: label shown to the user -> edge-tts voice
        voices = {
            "Aria (US Female)": "en-US-AriaNeural",
            "Guy (US Male)": "en-US-GuyNeural",
            "Sonia (UK Female)": "en-GB-SoniaNeural",
            "Tony (UK Male)": "en-GB-TonyNeural",
            "Jenny (US Female)": "en-US-JennyNeural"
        }
        widget_suffix = result.get('video_id', '')
        selected_voice = st.selectbox(
            "Select Voice",
            list(voices.keys()),
            key=f"voice_{widget_suffix}"
        )
        read_clicked = st.button("πŸ”Š Read Description", key=f"read_{widget_suffix}")
        if read_clicked:
            text_to_read = ". ".join(text_content)
            audio_file = asyncio.run(generate_speech(text_to_read, voices[selected_voice]))
            if audio_file:
                st.audio(audio_file)
def main():
    """Build the Streamlit page: search, voice input, ArXiv and file tabs, plus sidebar."""
    # st.experimental_rerun is deprecated in favour of st.rerun; use
    # whichever this Streamlit version provides.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun

    st.title("πŸŽ₯ Advanced Video & Dataset Search with Voice")

    # Initialize search. Streamlit re-executes this function on every
    # interaction; keeping the searcher in the session avoids reloading the
    # model and re-downloading the dataset each time.
    if '_video_search' not in st.session_state:
        st.session_state['_video_search'] = VideoSearch()
    search = st.session_state['_video_search']

    # Create tabs
    tab1, tab2, tab3, tab4 = st.tabs([
        "πŸ” Search", "πŸŽ™οΈ Voice Input", "πŸ“š ArXiv", "πŸ“‚ Files"
    ])

    # Search Tab
    with tab1:
        st.subheader("Search Videos")
        col1, col2 = st.columns([3, 1])
        with col1:
            query = st.text_input("Enter search query:",
                                  value="" if st.session_state['initial_search_done'] else "aliens")
        with col2:
            search_column = st.selectbox("Search in:",
                                         ["All Fields"] + st.session_state['search_columns'])
        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Max results:", 1, 100, 20)
        with col4:
            search_button = st.button("πŸ” Search")

        # Search on click, and once automatically on the first run.
        if (search_button or not st.session_state['initial_search_done']) and query:
            st.session_state['initial_search_done'] = True
            selected_column = None if search_column == "All Fields" else search_column
            with st.spinner("Searching..."):
                results = search.search(query, selected_column, num_results)
            if results:
                st.session_state['search_history'].append({
                    'query': query,
                    'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    'results': results[:5]
                })
                st.write(f"Found {len(results)} results:")
                for i, result in enumerate(results, 1):
                    with st.expander(f"Result {i}", expanded=(i == 1)):
                        render_result(result)
            else:
                st.warning("No matching results found.")

    # Voice Input Tab
    with tab2:
        st.subheader("Voice Search")
        st.write("πŸŽ™οΈ Record your query:")
        audio_bytes = audio_recorder()
        if audio_bytes:
            with st.spinner("Processing audio..."):
                audio_path = f"temp_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
                try:
                    with open(audio_path, "wb") as f:
                        f.write(audio_bytes)
                    voice_query = transcribe_audio(audio_path)
                finally:
                    # The recording is only needed for transcription; remove
                    # it even if transcription raises.
                    if os.path.exists(audio_path):
                        os.remove(audio_path)
            st.markdown("**Transcribed Text:**")
            st.write(voice_query)
            st.session_state['last_voice_input'] = voice_query
            if st.button("πŸ” Search from Voice"):
                results = search.search(voice_query, None, 20)
                if not results:
                    # Same feedback as the text search tab.
                    st.warning("No matching results found.")
                for i, result in enumerate(results, 1):
                    with st.expander(f"Result {i}", expanded=(i == 1)):
                        render_result(result)

    # ArXiv Tab
    with tab3:
        st.subheader("ArXiv Search")
        arxiv_query = st.text_input("Search ArXiv:", value=st.session_state['arxiv_last_query'])
        vocal_summary = st.checkbox("πŸŽ™ Quick Audio Summary", value=True)
        titles_summary = st.checkbox("πŸ”– Titles Only", value=True)
        full_audio = st.checkbox("πŸ“š Full Audio Summary", value=False)
        if st.button("πŸ” Search ArXiv"):
            st.session_state['arxiv_last_query'] = arxiv_query
            perform_arxiv_lookup(arxiv_query, vocal_summary, titles_summary, full_audio)

    # File Manager Tab
    with tab4:
        show_file_manager()

    # Sidebar
    with st.sidebar:
        st.subheader("βš™οΈ Settings & History")
        if st.button("πŸ—‘οΈ Clear History"):
            st.session_state['search_history'] = []
            rerun()

        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    # The description may be missing or not a string;
                    # coerce before slicing.
                    preview = str(result.get('description') or '')[:100]
                    st.write(f"{i}. {preview}...")

        st.markdown("### Voice Settings")
        # key="tts_voice" binds the choice to the session entry that
        # generate_speech reads as its default voice.
        st.selectbox("TTS Voice:", [
            "en-US-AriaNeural",
            "en-US-GuyNeural",
            "en-GB-SoniaNeural"
        ], key="tts_voice")


if __name__ == "__main__":
    main()