import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
import glob
from datetime import datetime
import edge_tts
import asyncio
import requests
from audio_recorder_streamlit import audio_recorder
from urllib.parse import quote
from xml.etree import ElementTree as ET

if 'search_history' not in st.session_state:
    st.session_state['search_history'] = []
if 'last_voice_input' not in st.session_state:
    st.session_state['last_voice_input'] = ""
if 'transcript_history' not in st.session_state:
    st.session_state['transcript_history'] = []
if 'should_rerun' not in st.session_state:
    st.session_state['should_rerun'] = False
if 'search_columns' not in st.session_state:
    st.session_state['search_columns'] = []
if 'initial_search_done' not in st.session_state:
    st.session_state['initial_search_done'] = False
if 'tts_voice' not in st.session_state:
    st.session_state['tts_voice'] = "en-US-AriaNeural"
if 'arxiv_last_query' not in st.session_state:
    st.session_state['arxiv_last_query'] = ""


def fetch_dataset_info(dataset_id):
    """Fetch dataset information, including all available configs and splits."""
    info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
    try:
        response = requests.get(info_url, timeout=30)
        if response.status_code == 200:
            return response.json()
    except Exception as e:
        st.warning(f"Error fetching dataset info: {e}")
    return None

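# Quick sketch of calling the helper above. The exact response keys are an
# assumption based on the public HF API and vary by dataset; this code only
# relies on 'config_names' (with a 'default' fallback):
#
#   info = fetch_dataset_info("omegalabsinc/omega-multimodal")
#   info.get('config_names')  # -> list of configs, when the dataset exposes them

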
def fetch_dataset_rows(dataset_id, config="default", split="train", max_rows=100):
    """Fetch rows from a specific config and split of a dataset."""
    url = (f"https://datasets-server.huggingface.co/first-rows"
           f"?dataset={dataset_id}&config={config}&split={split}")
    try:
        response = requests.get(url, timeout=30)
        if response.status_code == 200:
            data = response.json()
            if 'rows' in data:
                processed_rows = []
                for row_data in data['rows'][:max_rows]:
                    row = row_data.get('row', row_data)

                    # Parse string-encoded embedding fields into lists of floats.
                    for key in row:
                        if any(term in key.lower() for term in ['embed', 'vector', 'encoding']):
                            if isinstance(row[key], str):
                                try:
                                    row[key] = [float(x.strip())
                                                for x in row[key].strip('[]').split(',')
                                                if x.strip()]
                                except (ValueError, TypeError):
                                    continue

                    row['_config'] = config
                    row['_split'] = split
                    processed_rows.append(row)
                return processed_rows
    except Exception as e:
        st.warning(f"Error fetching rows for {config}/{split}: {e}")
    return []

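# Usage sketch for fetch_dataset_rows (network access assumed). The
# datasets-server /first-rows endpoint itself caps the response at roughly
# 100 rows, so max_rows can only trim further. Each entry arrives wrapped as
# {"row_idx": ..., "row": {...}}, which is why the code unwraps
# row_data.get('row', row_data) above:
#
#   rows = fetch_dataset_rows("omegalabsinc/omega-multimodal", "default", "train", max_rows=5)
#   rows[0]['_config'], rows[0]['_split']   # -> ('default', 'train')

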
def search_dataset(dataset_id, search_text, include_configs=None, include_splits=None):
    """
    Search across all configurations and splits of a dataset.

    Args:
        dataset_id (str): The Hugging Face dataset ID
        search_text (str): Text to search for in descriptions and queries
        include_configs (list): Specific configs to search, or None for all
        include_splits (list): Specific splits to search, or None for all

    Returns:
        tuple: (DataFrame of results, list of available configs, list of available splits)
    """
    dataset_info = fetch_dataset_info(dataset_id)
    if not dataset_info:
        return pd.DataFrame(), [], []

    configs = include_configs if include_configs else dataset_info.get('config_names', ['default'])
    all_rows = []
    available_splits = set()

    for config in configs:
        try:
            splits_url = (f"https://datasets-server.huggingface.co/splits"
                          f"?dataset={dataset_id}&config={config}")
            splits_response = requests.get(splits_url, timeout=30)
            splits = ['train']  # fallback when the splits endpoint is unavailable
            if splits_response.status_code == 200:
                splits_data = splits_response.json()
                splits = [split['split'] for split in splits_data.get('splits', [])]
                if not splits:
                    splits = ['train']

            if include_splits:
                splits = [s for s in splits if s in include_splits]

            available_splits.update(splits)

            for split in splits:
                rows = fetch_dataset_rows(dataset_id, config, split)
                for row in rows:
                    # Simple relevance: substring match over all scalar fields,
                    # scored by occurrence count.
                    text_content = ' '.join(str(v) for v in row.values()
                                            if isinstance(v, (str, int, float)))
                    if search_text.lower() in text_content.lower():
                        row['_matched_text'] = text_content
                        row['_relevance_score'] = text_content.lower().count(search_text.lower())
                        all_rows.append(row)

        except Exception as e:
            st.warning(f"Error processing config {config}: {e}")
            continue

    if all_rows:
        df = pd.DataFrame(all_rows)
        # reset_index so positional result numbering in the UI stays correct
        # after sorting by relevance.
        df = df.sort_values('_relevance_score', ascending=False).reset_index(drop=True)
        return df, configs, list(available_splits)

    return pd.DataFrame(), configs, list(available_splits)

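# Usage sketch for search_dataset (network access assumed). Note that an empty
# query matches every fetched row, since "" is a substring of any text; that is
# how VideoSearch below bulk-loads its dataset:
#
#   df, configs, splits = search_dataset("omegalabsinc/omega-multimodal", "ancient")
#   df[['_config', '_split', '_relevance_score']].head()

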
class VideoSearch:
    def __init__(self):
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.dataset_id = "omegalabsinc/omega-multimodal"
        self.load_dataset()

    def fetch_dataset_rows(self):
        """Fetch the dataset via the generic search helper."""
        try:
            # An empty search string matches every row, so this doubles as a bulk load.
            df, configs, splits = search_dataset(
                self.dataset_id,
                "",
                include_configs=None,
                include_splits=None
            )

            if not df.empty:
                st.session_state['search_columns'] = [
                    col for col in df.columns
                    if col not in ['video_embed', 'description_embed', 'audio_embed']
                    and not col.startswith('_')
                ]
                return df

            return self.load_example_data()

        except Exception as e:
            st.warning(f"Error loading dataset: {e}")
            return self.load_example_data()

    def load_example_data(self):
        """Load a single example row as a fallback when the dataset is unreachable."""
        example_data = [
            {
                "video_id": "cd21da96-fcca-4c94-a60f-0b1e4e1e29fc",
                "youtube_id": "IO-vwtyicn4",
                "description": "This video shows a close-up of an ancient text carved into a surface.",
                "views": 45489,
                "start_time": 1452,
                "end_time": 1458,
                # Truncated sample embeddings; prepare_features() discards
                # vectors whose dimension does not match the text model.
                "video_embed": [0.014160037972033024, -0.003111184574663639, -0.016604168340563774],
                "description_embed": [-0.05835828185081482, 0.02589797042310238, 0.11952091753482819],
            }
        ]
        return pd.DataFrame(example_data)

    def prepare_features(self):
        """Prepare embedding matrices, detecting embedding-like columns by name."""
        model_dim = self.text_model.get_sentence_embedding_dimension()
        num_rows = len(self.dataset)
        try:
            embed_cols = [col for col in self.dataset.columns
                          if any(term in col.lower() for term in ['embed', 'vector', 'encoding'])]

            embeddings = {}
            for col in embed_cols:
                try:
                    data = []
                    for row in self.dataset[col]:
                        if isinstance(row, str):
                            values = [float(x.strip()) for x in row.strip('[]').split(',') if x.strip()]
                        elif isinstance(row, list):
                            values = row
                        else:
                            continue
                        data.append(values)

                    # Keep a column only if every row yielded a vector of the
                    # model's dimension; anything else would break
                    # cosine_similarity against the query embedding.
                    if len(data) == num_rows and all(len(v) == model_dim for v in data):
                        embeddings[col] = np.array(data)
                except (ValueError, TypeError):
                    continue

            if 'video_embed' in embeddings:
                self.video_embeds = embeddings['video_embed']
            else:
                self.video_embeds = next(iter(embeddings.values()))

            if 'description_embed' in embeddings:
                self.text_embeds = embeddings['description_embed']
            else:
                self.text_embeds = self.video_embeds

        except Exception:
            # No usable embedding columns: fall back to random vectors so the
            # app still runs (scores will not be meaningful).
            self.video_embeds = np.random.randn(num_rows, model_dim)
            self.text_embeds = np.random.randn(num_rows, model_dim)

    def load_dataset(self):
        self.dataset = self.fetch_dataset_rows()
        self.prepare_features()

    def search(self, query, column=None, top_k=20):
        # Hybrid scoring: cosine similarity of the query embedding against the
        # video and text embeddings, weighted equally.
        query_embedding = self.text_model.encode([query])[0]
        video_sims = cosine_similarity([query_embedding], self.video_embeds)[0]
        text_sims = cosine_similarity([query_embedding], self.text_embeds)[0]
        combined_sims = 0.5 * video_sims + 0.5 * text_sims

        # Down-weight rows whose selected column does not contain the query
        # text (regex=False so queries with special characters are safe).
        if column and column in self.dataset.columns and column != "All Fields":
            mask = self.dataset[column].astype(str).str.contains(query, case=False, regex=False)
            combined_sims[~mask] *= 0.5

        top_k = min(top_k, 100)
        top_indices = np.argsort(combined_sims)[-top_k:][::-1]

        results = []
        for idx in top_indices:
            result = {'relevance_score': float(combined_sims[idx])}
            for col in self.dataset.columns:
                if col not in ['video_embed', 'description_embed', 'audio_embed']:
                    result[col] = self.dataset.iloc[idx][col]
            results.append(result)

        return results


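# Usage sketch for VideoSearch (downloads the MiniLM model on first run):
#
#   searcher = VideoSearch()
#   for hit in searcher.search("ancient text", column=None, top_k=5):
#       print(f"{hit['relevance_score']:.3f}  {hit.get('description', '')[:60]}")

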
@st.cache_resource
def get_speech_model():
    return edge_tts.Communicate


async def generate_speech(text, voice=None):
    if not text.strip():
        return None
    if not voice:
        voice = st.session_state['tts_voice']
    try:
        communicate = get_speech_model()(text, voice)
        audio_file = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
        await communicate.save(audio_file)
        return audio_file
    except Exception as e:
        st.error(f"Error generating speech: {e}")
        return None


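# edge-tts synthesizes speech via Microsoft's online TTS service, which is why
# generate_speech is a coroutine; the synchronous Streamlit callbacks in this
# script bridge it with asyncio.run(), e.g.:
#
#   mp3_path = asyncio.run(generate_speech("Hello there", voice="en-US-AriaNeural"))

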
def transcribe_audio(audio_path):
    """Placeholder for ASR transcription."""
    return "ASR not implemented. Integrate a local model or another service here."


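# One way to fill in the placeholder above: a minimal sketch assuming the
# optional `openai-whisper` package is installed (it is not a dependency of
# this app):
#
#   import whisper
#   _asr_model = whisper.load_model("base")
#
#   def transcribe_audio(audio_path):
#       return _asr_model.transcribe(audio_path)["text"]

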
def show_file_manager():
    """Display the file manager interface."""
    st.subheader("📁 File Manager")
    col1, col2 = st.columns(2)
    with col1:
        uploaded_file = st.file_uploader("Upload File", type=['txt', 'md', 'mp3'])
        if uploaded_file:
            with open(uploaded_file.name, "wb") as f:
                f.write(uploaded_file.getvalue())
            st.success(f"Uploaded: {uploaded_file.name}")
            st.rerun()

    with col2:
        if st.button("🗑️ Clear All Files"):
            for f in glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3"):
                os.remove(f)
            st.success("All files cleared!")
            st.rerun()

    files = glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3")
    if files:
        st.write("### Existing Files")
        for f in files:
            with st.expander(f"📄 {os.path.basename(f)}"):
                if f.endswith('.mp3'):
                    st.audio(f)
                else:
                    with open(f, 'r', encoding='utf-8') as file:
                        st.text_area("Content", file.read(), height=100)
                if st.button(f"Delete {os.path.basename(f)}", key=f"del_{f}"):
                    os.remove(f)
                    st.rerun()


def arxiv_search(query, max_results=5):
    """Run a simple arXiv query against the public export API and return the top results."""
    base_url = "http://export.arxiv.org/api/query?"
    search_url = base_url + f"search_query={quote(query)}&start=0&max_results={max_results}"
    r = requests.get(search_url, timeout=30)
    if r.status_code == 200:
        root = ET.fromstring(r.text)
        ns = {'atom': 'http://www.w3.org/2005/Atom'}
        entries = root.findall('atom:entry', ns)
        results = []
        for entry in entries:
            title = entry.find('atom:title', ns).text.strip()
            summary = entry.find('atom:summary', ns).text.strip()
            link = None
            for l in entry.findall('atom:link', ns):
                if l.get('type') == 'text/html':
                    link = l.get('href')
                    break
            results.append((title, summary, link))
        return results
    return []


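# The arXiv API serves an Atom feed: each <entry> carries <title>, <summary>,
# and several <link> elements, of which the type="text/html" link points at
# the abstract page. Example request URL:
#   http://export.arxiv.org/api/query?search_query=attention&start=0&max_results=5

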
def perform_arxiv_lookup(q, vocal_summary=True, titles_summary=True, full_audio=False):
    results = arxiv_search(q, max_results=5)
    if not results:
        st.write("No arXiv results found.")
        return

    st.markdown(f"**arXiv Search Results for '{q}':**")
    for i, (title, summary, link) in enumerate(results, start=1):
        st.markdown(f"**{i}. {title}**")
        st.write(summary)
        if link:
            st.markdown(f"[View Paper]({link})")

    if vocal_summary:
        spoken_text = f"Here are some arXiv results for {q}. "
        if titles_summary:
            spoken_text += " Titles: " + ", ".join([res[0] for res in results])
        else:
            # No titles requested: read the start of the first abstract instead.
            spoken_text += " " + results[0][1][:200]

        audio_file = asyncio.run(generate_speech(spoken_text))
        if audio_file:
            st.audio(audio_file)

    if full_audio:
        # Concatenate every title and abstract into one long narration.
        full_text = ""
        for i, (title, summary, _) in enumerate(results, start=1):
            full_text += f"Result {i}: {title}. {summary} "
        audio_file_full = asyncio.run(generate_speech(full_text))
        if audio_file_full:
            st.write("### Full Audio")
            st.audio(audio_file_full)


def main():
    st.title("🎥 Video & arXiv Search with Voice (No OpenAI/Anthropic)")

    search = VideoSearch()

    tab1, tab2, tab3, tab4, tab5 = st.tabs(
        ["🔍 Search", "🎙️ Voice Input", "📚 arXiv", "📂 Files", "🔎 Advanced Search"]
    )

    with tab1:
        st.subheader("Search Videos")
        col1, col2 = st.columns([3, 1])
        with col1:
            query = st.text_input("Enter your search query:",
                                  value="ancient" if not st.session_state['initial_search_done'] else "")
        with col2:
            search_column = st.selectbox("Search in field:",
                                         ["All Fields"] + st.session_state['search_columns'])

        col3, col4 = st.columns(2)
        with col3:
            num_results = st.slider("Number of results:", 1, 100, 20)
        with col4:
            search_button = st.button("🔍 Search")

        if (search_button or not st.session_state['initial_search_done']) and query:
            st.session_state['initial_search_done'] = True
            selected_column = None if search_column == "All Fields" else search_column
            with st.spinner("Searching..."):
                results = search.search(query, selected_column, num_results)

            st.session_state['search_history'].append({
                'query': query,
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                'results': results[:5]
            })

            for i, result in enumerate(results, 1):
                with st.expander(f"Result {i}: {result['description'][:100]}...", expanded=(i == 1)):
                    cols = st.columns([2, 1])
                    with cols[0]:
                        st.markdown("**Description:**")
                        st.write(result['description'])
                        st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")
                        st.markdown(f"**Views:** {result['views']:,}")

                    with cols[1]:
                        st.markdown(f"**Relevance Score:** {result['relevance_score']:.2%}")
                        if result.get('youtube_id'):
                            st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result['start_time']}")

                    if st.button("🔊 Audio Summary", key=f"audio_{i}"):
                        summary = f"Video summary: {result['description'][:200]}"
                        audio_file = asyncio.run(generate_speech(summary))
                        if audio_file:
                            st.audio(audio_file)

    with tab2:
        st.subheader("Voice Input")
        st.write("🎙️ Record your voice:")
        audio_bytes = audio_recorder()
        if audio_bytes:
            audio_path = f"temp_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
            with open(audio_path, "wb") as f:
                f.write(audio_bytes)
            st.success("Audio recorded successfully!")

            voice_query = transcribe_audio(audio_path)
            st.markdown("**Transcribed Text:**")
            st.write(voice_query)
            st.session_state['last_voice_input'] = voice_query

            if st.button("🔍 Search from Voice"):
                results = search.search(voice_query, None, 20)
                for i, result in enumerate(results, 1):
                    with st.expander(f"Result {i}", expanded=(i == 1)):
                        st.write(result['description'])
                        if result.get('youtube_id'):
                            st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")

            if os.path.exists(audio_path):
                os.remove(audio_path)

    with tab3:
        st.subheader("arXiv Search")
        q = st.text_input("Enter your arXiv search query:", value=st.session_state['arxiv_last_query'])
        vocal_summary = st.checkbox("🔊 Short Audio Summary", value=True)
        titles_summary = st.checkbox("📝 Titles Only", value=True)
        full_audio = st.checkbox("📚 Full Audio Results", value=False)

        if st.button("🔍 arXiv Search"):
            st.session_state['arxiv_last_query'] = q
            perform_arxiv_lookup(q, vocal_summary=vocal_summary,
                                 titles_summary=titles_summary, full_audio=full_audio)

    with tab4:
        show_file_manager()

    with tab5:
        st.subheader("Advanced Dataset Search")

        dataset_id = st.text_input("Dataset ID:", value="omegalabsinc/omega-multimodal")

        col1, col2 = st.columns([2, 1])
        with col1:
            search_text = st.text_input("Search text:",
                                        placeholder="Enter text to search across all fields")

        if dataset_id:
            dataset_info = fetch_dataset_info(dataset_id)
            if dataset_info:
                configs = dataset_info.get('config_names', ['default'])
                with col2:
                    selected_configs = st.multiselect(
                        "Configurations:",
                        options=configs,
                        default=['default'] if 'default' in configs else None
                    )

                # Collect the union of splits across the selected configs.
                selected_splits = []
                if selected_configs:
                    all_splits = set()
                    for config in selected_configs:
                        splits_url = (f"https://datasets-server.huggingface.co/splits"
                                      f"?dataset={dataset_id}&config={config}")
                        try:
                            response = requests.get(splits_url, timeout=30)
                            if response.status_code == 200:
                                splits_data = response.json()
                                all_splits.update(split['split'] for split in splits_data.get('splits', []))
                        except Exception as e:
                            st.warning(f"Error fetching splits for {config}: {e}")

                    selected_splits = st.multiselect(
                        "Splits:",
                        options=list(all_splits),
                        default=['train'] if 'train' in all_splits else None
                    )

                if st.button("🔍 Search Dataset"):
                    with st.spinner("Searching dataset..."):
                        results_df, _, _ = search_dataset(
                            dataset_id,
                            search_text,
                            include_configs=selected_configs,
                            include_splits=selected_splits
                        )

                    if not results_df.empty:
                        st.write(f"Found {len(results_df)} results")

                        for idx, row in results_df.iterrows():
                            with st.expander(
                                f"Result {idx + 1} (Config: {row['_config']}, "
                                f"Split: {row['_split']}, Score: {row['_relevance_score']})"
                            ):
                                # Show every non-internal, non-embedding field.
                                for col in row.index:
                                    if not col.startswith('_') and not any(
                                        term in col.lower()
                                        for term in ['embed', 'vector', 'encoding']
                                    ):
                                        st.write(f"**{col}:** {row[col]}")

                                if 'youtube_id' in row and pd.notna(row['youtube_id']):
                                    st.video(
                                        f"https://youtube.com/watch?v={row['youtube_id']}"
                                        f"&t={row.get('start_time', 0)}"
                                    )
                    else:
                        st.warning("No results found.")
            else:
                st.error("Unable to fetch dataset information.")

    with st.sidebar:
        st.subheader("⚙️ Settings & History")
        if st.button("🗑️ Clear History"):
            st.session_state['search_history'] = []
            st.rerun()

        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    st.write(f"{i}. {result['description'][:100]}...")

        st.markdown("### Voice Settings")
        st.selectbox("TTS Voice:",
                     ["en-US-AriaNeural", "en-US-GuyNeural", "en-GB-SoniaNeural"],
                     key="tts_voice")


if __name__ == "__main__":
    main()