taks-2 / src /streamlit_app.py
jiekarl's picture
Update src/streamlit_app.py
28f0e46 verified
import os
import streamlit as st
import pandas as pd
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
import numpy as np
from tqdm import tqdm
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
# Create directories and set environment
os.makedirs(os.path.expanduser('~/.streamlit'), exist_ok=True)
os.makedirs(os.path.join(os.path.expanduser('~'), 'nltk_data'), exist_ok=True)
os.environ['NLTK_DATA'] = os.path.join(os.path.expanduser('~'), 'nltk_data')
os.environ['STREAMLIT_GLOBAL_METRICS'] = 'false'
# Download NLTK resources
nltk.download('stopwords', download_dir=os.environ['NLTK_DATA'])
nltk.download('wordnet', download_dir=os.environ['NLTK_DATA'])
# Initialize lemmatizer and stopwords
nltk.data.path.append(os.environ['NLTK_DATA'])
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))
# Text preprocessing function
def preprocess_text(text):
# Lowercase
text = text.lower()
# Remove special characters
text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
# Remove stopwords and lemmatize
tokens = text.split()
tokens = [lemmatizer.lemmatize(token) for token in tokens if token not in stop_words]
return " ".join(tokens)
# Initialize ChromaDB
def init_chroma_db(df):
client = chromadb.PersistentClient(path="./chroma_store")
try:
collection = client.get_collection(name="quotes")
st.success("Loaded existing quotes collection")
except:
st.info("Creating new quotes collection...")
collection = client.create_collection(name="quotes")
# Preprocess texts
texts = df['quote'].tolist()
processed_texts = [preprocess_text(text) for text in tqdm(texts, desc="Preprocessing")]
# Generate embeddings
embed_model = SentenceTransformer('all-MiniLM-L6-v2')
embeddings = embed_model.encode(processed_texts, show_progress_bar=True)
# Prepare metadata
metadatas = []
for i, row in tqdm(df.iterrows(), total=len(df), desc="Preparing metadata"):
meta = {"author": row['author']}
# Store tags as separate fields
for j, tag in enumerate(row['tags']):
meta[f"tag_{j}"] = tag
metadatas.append(meta)
# Add to collection
ids = [str(i) for i in range(len(texts))]
collection.add(
embeddings=embeddings.tolist(),
documents=texts, # Store original text
metadatas=metadatas,
ids=ids
)
st.success(f"Created new collection with {len(texts)} quotes")
return collection
# Retrieve quotes function
def retrieve_quotes(query, collection, k=5):
embed_model = SentenceTransformer('all-MiniLM-L6-v2')
processed_query = preprocess_text(query)
query_embedding = embed_model.encode([processed_query])[0].tolist()
results = collection.query(
query_embeddings=[query_embedding],
n_results=k
)
return [
f"\"{doc}\" β€” {meta['author']}"
for doc, meta in zip(results['documents'][0], results['metadatas'][0])
]
# Main app
def main():
st.set_page_config(page_title="Quote Finder", layout="wide")
st.title("Intelligent Quote Finder")
st.markdown("Find relevant quotes using semantic search powered by ChromaDB")
# Load data
try:
df = pd.read_json("https://huggingface.co/datasets/Abirate/english_quotes/resolve/main/quotes.jsonl", lines=True)
st.session_state.df = df
except Exception as e:
st.error(f"Error loading data: {e}")
return
# Initialize ChromaDB
if 'collection' not in st.session_state:
with st.spinner("Initializing database..."):
st.session_state.collection = init_chroma_db(df)
# Sidebar controls
with st.sidebar:
st.header("Search Parameters")
k = st.slider("Number of results", 1, 20, 5)
st.divider()
st.header("Database Info")
st.write(f"Total quotes: {len(df)}")
st.write(f"Authors: {df['author'].nunique()}")
st.write(f"Tags: {sum(len(tags) for tags in df['tags'])}")
# Main search interface
query = st.text_input("Search for quotes", placeholder="Enter a topic, emotion, or phrase...")
if st.button("Find Quotes") or query:
if not query:
st.warning("Please enter a search query")
return
with st.spinner(f"Finding quotes related to '{query}'..."):
results = retrieve_quotes(query, st.session_state.collection, k)
if results:
st.subheader(f"Top {len(results)} matching quotes:")
for i, quote in enumerate(results, 1):
st.markdown(f"#### {i}. {quote}")
else:
st.warning("No matching quotes found")
if __name__ == "__main__":
main()