|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
import json |
|
import os |
|
import glob |
|
import random |
|
from pathlib import Path |
|
from datetime import datetime |
|
import edge_tts |
|
import asyncio |
|
import requests |
|
import streamlit.components.v1 as components |
|
import base64 |
|
import re |
|
from xml.etree import ElementTree as ET |
|
from datasets import load_dataset |
|
|
|
|
|
# One (display name, TTS voice id) pair per selectable user. The name list,
# the voice list and the name -> voice lookup are all derived from this table
# so they cannot drift out of step with one another.
_VOICE_ROSTER = [
    ("Aria", "en-US-AriaNeural"),
    ("Guy", "en-US-GuyNeural"),
    ("Sonia", "en-GB-SoniaNeural"),
    ("Tony", "en-GB-TonyNeural"),
    ("Jenny", "en-US-JennyNeural"),
    ("Davis", "en-US-DavisNeural"),
    ("Libby", "en-GB-LibbyNeural"),
    ("Clara", "en-CA-ClaraNeural"),
    ("Liam", "en-CA-LiamNeural"),
    ("Natasha", "en-AU-NatashaNeural"),
    ("William", "en-AU-WilliamNeural"),
]

USER_NAMES = [name for name, _voice in _VOICE_ROSTER]
ENGLISH_VOICES = [voice for _name, voice in _VOICE_ROSTER]
USER_VOICES = dict(_VOICE_ROSTER)

# Number of dataset rows fetched per page.
ROWS_PER_PAGE = 100

# Directory holding the saved markdown inputs/responses; created at import
# time so later writes never fail on a missing folder.
SAVED_INPUTS_DIR = "saved_inputs"
os.makedirs(SAVED_INPUTS_DIR, exist_ok=True)
|
|
|
# Default values for every session-state key this app seeds at start-up.
SESSION_VARS = dict(
    search_history=[],
    last_voice_input="",
    transcript_history=[],
    should_rerun=False,
    search_columns=[],
    initial_search_done=False,
    arxiv_last_query="",
    dataset_loaded=False,
    current_page=0,
    data_cache=None,
    dataset_info=None,
    nps_submitted=False,
    nps_last_shown=None,
    old_val=None,
    voice_text=None,
    user_name=random.choice(USER_NAMES),
    max_items=100,
    global_voice="en-US-AriaNeural",
    last_arxiv_input=None,
)

# Seed only the keys that are missing so values already held in the session
# are never overwritten.
for state_key in SESSION_VARS:
    if state_key in st.session_state:
        continue
    st.session_state[state_key] = SESSION_VARS[state_key]
|
|
|
def create_voice_component():
    """Declare and return the custom Streamlit component "mycomponent".

    The component is declared with ``path="mycomponent"``; the returned object
    is callable and is invoked by the caller to render the component.
    """
    return components.declare_component("mycomponent", path="mycomponent")
|
|
|
def clean_for_speech(text: str) -> str:
    """Return *text* tidied up for text-to-speech.

    Newlines and ``</s>`` markers become spaces, ``#`` characters are dropped,
    parenthesised http(s) URLs are removed, and runs of whitespace collapse to
    a single space with the ends trimmed.
    """
    # Literal substitutions, applied in this order.
    for needle, substitute in (("\n", " "), ("</s>", " "), ("#", "")):
        text = text.replace(needle, substitute)

    without_links = re.sub(r"\(https?:\/\/[^\)]+\)", "", text)
    return re.sub(r"\s+", " ", without_links).strip()
|
|
|
async def edge_tts_generate_audio(text, voice="en-US-AriaNeural"):
    """Synthesize *text* to an MP3 file with edge-tts and return its filename.

    The text is first passed through ``clean_for_speech``. If nothing remains
    to speak, no file is written and ``None`` is returned. Otherwise the audio
    is saved under a ``speech_<timestamp>.mp3`` name relative to the current
    working directory and that name is returned.
    """
    speakable = clean_for_speech(text)
    if speakable.strip() == "":
        return None

    # Microsecond-resolution stamp keeps successive clips from colliding.
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    out_fn = f"speech_{stamp}.mp3"
    await edge_tts.Communicate(speakable, voice).save(out_fn)
    return out_fn
|
|
|
def speak_with_edge_tts(text, voice="en-US-AriaNeural"):
    """Synchronous wrapper around ``edge_tts_generate_audio``.

    Runs the coroutine to completion with ``asyncio.run`` and returns its
    result: the generated MP3 filename, or ``None`` when there was nothing to
    speak.
    """
    pending_audio = edge_tts_generate_audio(text, voice)
    return asyncio.run(pending_audio)
|
|
|
def play_and_download_audio(file_path):
    """Show an audio player and a download link for *file_path*.

    Does nothing when *file_path* is falsy or the file does not exist. The
    download link embeds the whole file as a base64 ``data:`` URI.
    """
    if not file_path or not os.path.exists(file_path):
        return

    st.audio(file_path)

    # Read through a context manager so the file handle is closed promptly
    # instead of being left for the garbage collector.
    with open(file_path, "rb") as audio_fh:
        encoded = base64.b64encode(audio_fh.read()).decode()

    file_name = os.path.basename(file_path)
    dl_link = f'<a href="data:audio/mpeg;base64,{encoded}" download="{file_name}">Download {file_name}</a>'
    st.markdown(dl_link, unsafe_allow_html=True)
|
|
|
def generate_filename(prefix, text):
    """Build a ``<prefix>_<timestamp>_<slug>.md`` filename.

    The slug comes from the first 50 characters of *text*: characters other
    than word characters, whitespace and hyphens are removed, the result is
    trimmed and lower-cased, and runs of hyphens/whitespace become a single
    hyphen. The timestamp has one-second resolution.
    """
    snippet = text[:50]
    slug = re.sub(r'[^\w\s-]', '', snippet).strip().lower()
    slug = re.sub(r'[-\s]+', '-', slug)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{prefix}_{stamp}_{slug}.md"
|
|
|
def save_input_as_md(user_name, text, prefix="input"):
    """Write *text* to a new markdown file in ``SAVED_INPUTS_DIR``.

    The file starts with a ``# User:`` line and a ``**Timestamp:**`` line,
    followed by a blank line and the text itself. Returns the path written,
    or ``None`` when *text* is empty or whitespace-only (nothing is saved).
    """
    if text.strip() == "":
        return

    target = os.path.join(SAVED_INPUTS_DIR, generate_filename(prefix, text))
    with open(target, 'w', encoding='utf-8') as md_file:
        stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        md_file.write(f"# User: {user_name}\n**Timestamp:** {stamp}\n\n")
        md_file.write(text)
    return target
|
|
|
def save_response_as_md(user_name, text, prefix="response"):
    """Write a response *text* to a new markdown file in ``SAVED_INPUTS_DIR``.

    The file layout is a ``# User:`` header line, a ``**Timestamp:**`` line, a
    blank line, then the text. Returns the path written; returns ``None``
    without writing anything when *text* is empty or whitespace-only.
    """
    if text.strip():
        fn = generate_filename(prefix, text)
        full_path = os.path.join(SAVED_INPUTS_DIR, fn)
        with open(full_path, 'w', encoding='utf-8') as out:
            out.writelines([
                f"# User: {user_name}\n",
                f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n",
                text,
            ])
        return full_path
    return None
|
|
|
def list_saved_inputs():
    """Return the ``*.md`` paths found in ``SAVED_INPUTS_DIR``, sorted by path."""
    pattern = os.path.join(SAVED_INPUTS_DIR, "*.md")
    found = glob.glob(pattern)
    found.sort()
    return found
|
|
|
def parse_md_file(fpath):
    """Parse a saved markdown file into ``(user, timestamp, content)``.

    The first line starting with ``# User:`` supplies the user and the first
    line starting with ``**Timestamp:**`` supplies the timestamp. Every other
    line is content: each is stripped, the lines are joined with newlines,
    and the result is stripped again. A missing header yields an empty string.

    Only the first occurrence of each header is treated as metadata, so body
    text that happens to begin with one of the markers stays in the content
    instead of overwriting the real header.
    """
    user_prefix = "# User:"
    ts_prefix = "**Timestamp:**"

    user_line = ""
    ts_line = ""
    user_seen = False
    ts_seen = False
    content_lines = []

    with open(fpath, 'r', encoding='utf-8') as f:
        for line in f:
            if not user_seen and line.startswith(user_prefix):
                user_line = line[len(user_prefix):].strip()
                user_seen = True
            elif not ts_seen and line.startswith(ts_prefix):
                ts_line = line[len(ts_prefix):].strip()
                ts_seen = True
            else:
                content_lines.append(line.strip())

    content = "\n".join(content_lines).strip()
    return user_line, ts_line, content
|
|
|
def arxiv_search(query, max_results=3):
    """Query the arXiv export API and return ``(title, short_summary)`` tuples.

    Args:
        query: Search string, sent as the ``search_query`` parameter.
        max_results: Maximum number of entries to request.

    Returns:
        A list of ``(title, summary)`` tuples, where each summary is cut to
        300 characters (with ``"..."`` appended only when it was actually
        truncated). Returns an empty list when the response status is not 200.

    Network and XML parsing errors are not caught here and propagate to the
    caller.
    """
    base_url = "http://export.arxiv.org/api/query"
    params = {
        # Pass the raw query: requests URL-encodes parameter values itself
        # (spaces become '+'). Pre-replacing spaces with '+' would make it
        # send a literal, percent-encoded plus sign instead.
        'search_query': query,
        'start': 0,
        'max_results': max_results,
    }
    response = requests.get(base_url, params=params, timeout=30)
    if response.status_code != 200:
        return []

    root = ET.fromstring(response.text)
    ns = {"a": "http://www.w3.org/2005/Atom"}

    results = []
    for entry in root.findall('a:entry', ns):
        # findtext returns the default for a missing element, so an entry
        # without a title or summary no longer raises AttributeError.
        title = (entry.findtext('a:title', default="", namespaces=ns) or "").strip()
        summary = (entry.findtext('a:summary', default="", namespaces=ns) or "").strip()
        if len(summary) > 300:
            summary_short = summary[:300] + "..."
        else:
            summary_short = summary
        results.append((title, summary_short))
    return results
|
|
|
def summarize_arxiv_results(results):
    """Format ``(title, summary)`` pairs as a numbered plain-text digest.

    Each pair becomes ``"Result <n>: <title>\\n<summary>\\n"`` (numbering
    starts at 1) and the pieces are joined with a blank-line separator. An
    empty input yields an empty string.
    """
    return "\n\n".join(
        f"Result {position}: {heading}\n{abstract}\n"
        for position, (heading, abstract) in enumerate(results, 1)
    )
|
|
|
def simple_dataset_search(query, df):
    """Return the rows of *df* containing any whitespace-separated query term.

    Matching is a case-insensitive substring test against the row's string
    cells (lower-cased) and its ``int``/``float`` cells (via ``str``), joined
    by spaces. Cells of other types are ignored. An empty frame, a blank
    query, or no matches all give an empty ``DataFrame``.
    """
    if df.empty or not query.strip():
        return pd.DataFrame()

    terms = query.lower().split()
    hits = []
    for _, record in df.iterrows():
        # Build the searchable text for this row.
        pieces = []
        for column in df.columns:
            cell = record[column]
            if isinstance(cell, str):
                pieces.append(cell.lower())
            elif isinstance(cell, (int, float)):
                pieces.append(str(cell))
        haystack = " ".join(pieces)

        # One matching term is enough to keep the row.
        for term in terms:
            if term in haystack:
                hits.append(record)
                break

    return pd.DataFrame(hits) if hits else pd.DataFrame()
|
|
|
from datasets import load_dataset |
|
|
|
@st.cache_data
def load_dataset_page(dataset_id, token, page, rows_per_page):
    """Load one page of a dataset's ``train`` split as a DataFrame.

    Args:
        dataset_id: Dataset identifier passed to ``load_dataset``.
        token: Access token passed to ``load_dataset``.
        page: Zero-based page number.
        rows_per_page: Number of rows in each page.

    Returns:
        A DataFrame built from rows ``[page * rows_per_page,
        (page + 1) * rows_per_page)`` of the split, or an empty DataFrame if
        loading fails for any reason.
    """
    try:
        start_idx = page * rows_per_page
        end_idx = start_idx + rows_per_page
        dataset = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=f'train[{start_idx}:{end_idx}]',
        )
        return pd.DataFrame(dataset)
    except Exception:
        # Best-effort: any loading error yields an empty page. Catching
        # Exception rather than using a bare except lets KeyboardInterrupt
        # and SystemExit propagate.
        # NOTE(review): because this function is wrapped in st.cache_data, the
        # empty fallback is presumably cached for these arguments too — confirm
        # whether failed loads should be retried.
        return pd.DataFrame()
|
|
|
class SimpleDatasetSearcher:
    """Thin handle on one dataset, loading it a page at a time."""

    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        """Remember the dataset id and read the token from ``DATASET_KEY``."""
        self.dataset_id = dataset_id
        # None when the environment variable is not set.
        self.token = os.getenv('DATASET_KEY')

    def load_page(self, page=0):
        """Return page number *page* (``ROWS_PER_PAGE`` rows) via ``load_dataset_page``."""
        page_frame = load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)
        return page_frame
|
|
|
def concatenate_mp3(files, output_file):
    """Write the raw bytes of each path in *files*, in order, to *output_file*.

    This is plain byte concatenation; no audio decoding or re-encoding is
    done. *output_file* is created or truncated even when *files* is empty.
    """
    with open(output_file, 'wb') as combined:
        for part_path in files:
            with open(part_path, 'rb') as part:
                payload = part.read()
            combined.write(payload)
|
|
|
def main():
    """Render the Streamlit voice chat and search UI.

    Lays out a sidebar (user picker, item limit, list of saved files), the
    custom voice input component, and four tabs: saved conversation history
    with text-to-speech playback, ArXiv search, dataset search, and settings.
    """
    # NOTE(review): the symbol prefixes on several labels below look like
    # mis-decoded emoji. They are kept exactly as found — confirm the file's
    # encoding before changing them.
    st.title("ποΈ Voice Chat & Search")

    with st.sidebar:
        # Pre-select the user already held in session state instead of always
        # falling back to the first name, which discarded the stored choice.
        current_user = st.session_state['user_name']
        default_index = USER_NAMES.index(current_user) if current_user in USER_NAMES else 0
        st.session_state['user_name'] = st.selectbox("Current User:", USER_NAMES, index=default_index)

        st.session_state['max_items'] = st.number_input(
            "Max Items per search iteration:",
            min_value=1,
            max_value=1000,
            value=st.session_state['max_items'],
        )

        st.subheader("π Saved Inputs & Responses")
        for fpath in list_saved_inputs():
            user, _ts, _content = parse_md_file(fpath)
            st.write(f"- {os.path.basename(fpath)} (User: {user})")

    # Custom component whose return value seeds the ArXiv query box.
    voice_component = create_voice_component()
    voice_val = voice_component(my_input_value="Start speaking...")

    tab1, tab2, tab3, tab4 = st.tabs(["π£οΈ Voice Chat History", "π ArXiv Search", "π Dataset Search", "βοΈ Settings"])

    with tab1:
        st.subheader("Voice Chat History")
        conversation = []
        for fpath in list_saved_inputs():
            user, ts, content = parse_md_file(fpath)
            conversation.append((user, ts, content, fpath))

        # Files are listed in filename order, which puts every "input_*" file
        # before every "response_*" file. Sort by the parsed timestamp
        # (a "YYYY-MM-DD HH:MM:SS" string, so lexicographic order is
        # chronological) to interleave inputs and responses correctly. The
        # sort is stable, so an input stays ahead of a response saved in the
        # same second.
        conversation.sort(key=lambda entry: entry[1])

        # Newest entries first in the history list.
        for i, (user, ts, content, fpath) in enumerate(reversed(conversation), start=1):
            with st.expander(f"{ts} - {user}", expanded=False):
                st.write(content)

                # The index and path keep each button key unique.
                if st.button(f"π Read Aloud {ts}-{user}", key=f"read_{i}_{fpath}"):
                    voice = USER_VOICES.get(user, "en-US-AriaNeural")
                    audio_file = speak_with_edge_tts(content, voice=voice)
                    if audio_file:
                        play_and_download_audio(audio_file)

        if st.button("π Read Conversation", key="read_conversation_all"):
            # Read oldest to newest; the list is already chronological.
            mp3_files = []
            for user, ts, content, fpath in conversation:
                voice = USER_VOICES.get(user, "en-US-AriaNeural")
                audio_file = speak_with_edge_tts(content, voice=voice)
                if audio_file:
                    mp3_files.append(audio_file)
                    st.write(f"**{user} ({ts}):**")
                    play_and_download_audio(audio_file)

            if mp3_files:
                combined_file = f"full_conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
                concatenate_mp3(mp3_files, combined_file)
                st.write("**Full Conversation Audio:**")
                play_and_download_audio(combined_file)

    with tab2:
        st.subheader("ArXiv Search")
        edited_input = st.text_area("Enter or Edit Search Query:", value=(voice_val.strip() if voice_val else ""), height=100)
        autorun = st.checkbox("β‘ Auto-Run", value=True)
        run_arxiv = st.button("π ArXiv Search", key="run_arxiv_button")

        # Run either automatically on a changed, non-empty query, or when the
        # button is pressed with a non-empty query.
        input_changed = (edited_input != st.session_state.get('old_val'))
        has_query = bool(edited_input.strip())
        should_run_arxiv = has_query and ((autorun and input_changed) or run_arxiv)

        # Skip a query identical to the last one that was searched.
        if should_run_arxiv and st.session_state['last_arxiv_input'] != edited_input:
            st.session_state['old_val'] = edited_input
            st.session_state['last_arxiv_input'] = edited_input
            save_input_as_md(st.session_state['user_name'], edited_input, prefix="input")

            with st.spinner("Searching ArXiv..."):
                results = arxiv_search(edited_input)

            if results:
                summary = summarize_arxiv_results(results)
                save_response_as_md(st.session_state['user_name'], summary, prefix="response")
                st.write(summary)

                # Speak the summary in the current user's voice.
                voice = USER_VOICES.get(st.session_state['user_name'], "en-US-AriaNeural")
                audio_file = speak_with_edge_tts(summary, voice=voice)
                if audio_file:
                    play_and_download_audio(audio_file)
            else:
                st.warning("No results found on ArXiv.")

    with tab3:
        st.subheader("Dataset Search")
        ds_searcher = SimpleDatasetSearcher()
        query = st.text_input("Enter dataset search query:")
        run_ds_search = st.button("Search Dataset", key="ds_search_button")
        num_results = st.slider("Max results:", 1, 100, 20, key="ds_max_results")

        if run_ds_search and query.strip():
            with st.spinner("Searching dataset..."):
                # Only the first page of the dataset is searched.
                df = ds_searcher.load_page(0)
                results = simple_dataset_search(query, df)

            if not results.empty:
                st.write(f"Found {len(results)} results:")
                # Display at most num_results rows; the first one is expanded.
                for i, (_, row) in enumerate(results.head(num_results).iterrows(), 1):
                    with st.expander(f"Result {i}", expanded=(i == 1)):
                        for k, v in row.items():
                            st.write(f"**{k}:** {v}")
            else:
                st.warning("No matching results found.")

    with tab4:
        st.subheader("Settings")
        if st.button("ποΈ Clear Search History", key="clear_history"):
            # Deletes every saved markdown file, not just the current user's.
            for fpath in list_saved_inputs():
                os.remove(fpath)
            st.session_state['search_history'] = []
            st.success("Search history cleared for everyone!")
            st.rerun()
|
|
|
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|