import os
import re

import chromadb
import nltk
import pandas as pd
import streamlit as st
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
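
# Keep Streamlit config and NLTK data inside the user's home directory so the
# app can run in environments without system-wide write access.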
os.makedirs(os.path.expanduser('~/.streamlit'), exist_ok=True)
os.makedirs(os.path.join(os.path.expanduser('~'), 'nltk_data'), exist_ok=True)
os.environ['NLTK_DATA'] = os.path.join(os.path.expanduser('~'), 'nltk_data')
os.environ['STREAMLIT_GLOBAL_METRICS'] = 'false'
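
# Fetch the corpora the preprocessor needs; nltk.download skips the download
# when the data is already present.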
nltk.download('stopwords', download_dir=os.environ['NLTK_DATA'], quiet=True)
nltk.download('wordnet', download_dir=os.environ['NLTK_DATA'], quiet=True)

nltk.data.path.append(os.environ['NLTK_DATA'])
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))


def preprocess_text(text):
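    """Lowercase, strip punctuation, drop stopwords, and lemmatize."""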
    text = text.lower()
    text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    tokens = text.split()
    tokens = [lemmatizer.lemmatize(token) for token in tokens if token not in stop_words]
    return " ".join(tokens)


def init_chroma_db(df):
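    """Open the persistent quote collection, building it on first run."""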
    client = chromadb.PersistentClient(path="./chroma_store")

    try:
        collection = client.get_collection(name="quotes")
        st.success("Loaded existing quotes collection")
    except Exception:
        st.info("Creating new quotes collection...")
        collection = client.create_collection(name="quotes")
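
        # First run: embed every quote once and index it in the new collection.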
        texts = df['quote'].tolist()
        processed_texts = [preprocess_text(text) for text in tqdm(texts, desc="Preprocessing")]

        embed_model = SentenceTransformer('all-MiniLM-L6-v2')
        embeddings = embed_model.encode(processed_texts, show_progress_bar=True)
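
        # Chroma metadata values must be scalars (str/int/float/bool), so each
        # quote's tag list is flattened into numbered keys: tag_0, tag_1, ...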
        metadatas = []
        for i, row in tqdm(df.iterrows(), total=len(df), desc="Preparing metadata"):
            meta = {"author": row['author']}
            for j, tag in enumerate(row['tags']):
                meta[f"tag_{j}"] = tag
            metadatas.append(meta)

        ids = [str(i) for i in range(len(texts))]
        collection.add(
            embeddings=embeddings.tolist(),
            documents=texts,
            metadatas=metadatas,
            ids=ids
        )
        st.success(f"Created new collection with {len(texts)} quotes")

    return collection


def retrieve_quotes(query, collection, k=5):
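    """Embed the query the same way the documents were indexed, then search."""
    # Note: this reloads the embedding model on every call; wrapping the
    # SentenceTransformer constructor in a helper decorated with
    # @st.cache_resource would avoid the repeated startup cost.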
    embed_model = SentenceTransformer('all-MiniLM-L6-v2')
    processed_query = preprocess_text(query)
    query_embedding = embed_model.encode([processed_query])[0].tolist()

    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=k
    )

    return [
        f"\"{doc}\" — {meta['author']}"
        for doc, meta in zip(results['documents'][0], results['metadatas'][0])
    ]


def main():
    st.set_page_config(page_title="Quote Finder", layout="wide")

    st.title("Intelligent Quote Finder")
    st.markdown("Find relevant quotes using semantic search powered by ChromaDB")
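
    # The dataset is re-downloaded on every script rerun; wrapping this load in
    # a function decorated with @st.cache_data would avoid repeated fetches.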
    try:
        df = pd.read_json(
            "https://huggingface.co/datasets/Abirate/english_quotes/resolve/main/quotes.jsonl",
            lines=True
        )
        st.session_state.df = df
    except Exception as e:
        st.error(f"Error loading data: {e}")
        return
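
    # Build the Chroma collection at most once per session; st.session_state
    # survives Streamlit's script reruns.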
    if 'collection' not in st.session_state:
        with st.spinner("Initializing database..."):
            st.session_state.collection = init_chroma_db(df)

    with st.sidebar:
        st.header("Search Parameters")
        k = st.slider("Number of results", 1, 20, 5)
        st.divider()
        st.header("Database Info")
        st.write(f"Total quotes: {len(df)}")
        st.write(f"Authors: {df['author'].nunique()}")
        st.write(f"Total tags: {sum(len(tags) for tags in df['tags'])}")

    query = st.text_input("Search for quotes", placeholder="Enter a topic, emotion, or phrase...")

    if st.button("Find Quotes") or query:
        if not query:
            st.warning("Please enter a search query")
            return

        with st.spinner(f"Finding quotes related to '{query}'..."):
            results = retrieve_quotes(query, st.session_state.collection, k)

        if results:
            st.subheader(f"Top {len(results)} matching quotes:")
            for i, quote in enumerate(results, 1):
                st.markdown(f"#### {i}. {quote}")
        else:
            st.warning("No matching quotes found")


if __name__ == "__main__":
    main()