"""Gradio semantic-search demo over the ``phil_de`` Chroma collection.

The user enters a query, optionally restricts the search to some authors,
and receives up to ``max_textboxes`` passages. Each passage has a "Flag"
button that logs the (query, passage, id) triple through ``gr.CSVLogger``.
"""

import os
from functools import lru_cache

import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces

MODEL_NAME = "Linq-AI-Research/Linq-Embed-Mistral"
DEFAULT_TASK = "Given a question, retrieve passages that answer the question"
DEFAULT_NUM_RESULTS = 10

# Number of pre-built (markdown, flag button, id) result slots in the UI.
max_textboxes = 30


@lru_cache(maxsize=1)
def _load_model():
    """Build the embedding model once per process instead of once per query."""
    # NOTE(review): if the GPU worker is a fresh process on every call the
    # cache simply never hits, which matches the previous behaviour.
    return SentenceTransformer(MODEL_NAME, use_auth_token=os.getenv("HF_TOKEN"))


@spaces.GPU
def get_embeddings(query, task=DEFAULT_TASK):
    """Embed ``query`` using the instruction-style prompt.

    Args:
        query: The user's search text.
        task: Instruction placed in front of the query. Falls back to
            ``DEFAULT_TASK`` when empty or ``None``.

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for a one-element
        batch (one embedding row for the single prompt).
    """
    instruction = task or DEFAULT_TASK
    prompt = f"Instruct: {instruction}\nQuery: {query}"
    return _load_model().encode([prompt])


# Initialize a persistent Chroma client and retrieve collection
client = chromadb.PersistentClient(path="./chroma")
collection = client.get_collection(name="phil_de")

authors_list = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]


def _as_embedding_rows(embeddings):
    """Return ``embeddings`` as a plain list of float lists (one per query).

    Accepts a single vector or a batch, given either as nested sequences or
    as any object exposing ``tolist()`` (e.g. the array returned by
    ``encode``).
    """
    if hasattr(embeddings, "tolist"):
        rows = embeddings.tolist()
    else:
        rows = [
            item.tolist() if hasattr(item, "tolist") else item
            for item in embeddings
        ]
    # A flat vector is a batch of one.
    if rows and not isinstance(rows[0], (list, tuple)):
        rows = [rows]
    return [list(row) for row in rows]


def query_chroma(embeddings, authors, num_results=10):
    """Query the collection and flatten the hits for the first query vector.

    Args:
        embeddings: One embedding or a batch of embeddings for the query.
        authors: Author names to restrict the search to; falsy means no
            author filter.
        num_results: Maximum number of hits to request.

    Returns:
        A list of dicts with keys ``id``, ``author``, ``book``, ``section``,
        ``title``, ``text`` and ``distance``; or ``{"error": message}`` if
        the query raised.
    """
    try:
        query_kwargs = {
            "query_embeddings": _as_embedding_rows(embeddings),
            "n_results": num_results,
            "include": ["documents", "metadatas", "distances"],
        }
        # Only send a filter when there is one: an empty ``where`` dict is
        # not accepted as a valid filter by every Chroma version.
        if authors:
            query_kwargs["where"] = {"author": {"$in": list(authors)}}

        results = collection.query(**query_kwargs)

        # Results are grouped per query vector; only the first is used.
        ids = results.get('ids', [[]])[0]
        metadatas = results.get('metadatas', [[]])[0]
        documents = results.get('documents', [[]])[0]
        distances = results.get('distances', [[]])[0]

        formatted_results = []
        for id_, metadata, document_text, distance in zip(
            ids, metadatas, documents, distances
        ):
            metadata = metadata or {}  # tolerate entries without metadata
            formatted_results.append({
                "id": id_,
                "author": metadata.get('author', 'Unknown author'),
                "book": metadata.get('book', 'Unknown book'),
                "section": metadata.get('section', 'Unknown section'),
                "title": metadata.get('title', 'Untitled'),
                "text": document_text,
                "distance": distance,
            })
        return formatted_results
    except Exception as e:
        # UI boundary: report the failure to the caller instead of crashing.
        return {"error": str(e)}


def _clamp_result_count(num_results):
    """Coerce the UI value to an int within ``[1, max_textboxes]``."""
    try:
        count = int(num_results)
    except (TypeError, ValueError, OverflowError):
        count = DEFAULT_NUM_RESULTS
    return max(1, min(count, max_textboxes))


def _error_updates(message):
    """Show ``message`` in the first result slot and hide everything else."""
    updates = [
        gr.update(visible=True, value=f"Error: {message}"),
        gr.update(visible=False),
        gr.update(visible=False, value=""),
    ]
    updates.extend(
        gr.update(visible=False) for _ in range(3 * (max_textboxes - 1))
    )
    return updates


# Main function
def perform_query(query, authors, num_results):
    """Run a search and build the updates for all result slots.

    Args:
        query: Search text from the query textbox.
        authors: Selected author names (may be empty or ``None``).
        num_results: Requested number of hits; coerced to an int and clamped
            to ``[1, max_textboxes]``.

    Returns:
        Exactly ``3 * max_textboxes`` updates, ordered per slot as
        (markdown, flag button, id textbox). Unused slots are hidden.
    """
    limit = _clamp_result_count(num_results)
    embeddings = get_embeddings(query, DEFAULT_TASK)
    results = query_chroma(embeddings, authors, limit)

    if isinstance(results, dict):
        return _error_updates(results.get("error", "unknown error"))

    # Never emit more updates than there are output components.
    results = results[:max_textboxes]

    updates = []
    for index, res in enumerate(results):
        markdown_content = f"**{res['author']}, {res['book']}, Distance: {res['distance']}**\n\n{res['text']}"
        updates.append(gr.update(visible=True, value=markdown_content))
        updates.append(gr.update(visible=True, value="Flag", elem_id=f"flag-{index}"))
        updates.append(gr.update(visible=False, value=res['id']))  # Hide the ID textbox

    updates.extend(
        gr.update(visible=False)
        for _ in range(3 * (max_textboxes - len(results)))
    )
    return updates


# Initialize the CSVLogger callback for flagging
callback = gr.CSVLogger()


def flag_output(query, output_text, output_id):
    """Log one flagged result (query, passage text, passage id)."""
    callback.flag([query, output_text, output_id])


# Gradio interface
with gr.Blocks(css=".custom-markdown { border: 1px solid #ccc; padding: 10px; border-radius: 5px; }") as demo:
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search. Click **Flag** if a result is relevant to the query and interesting to you. Try reranking the results.")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="query", placeholder="Enter thought...")
            author_inp = gr.Dropdown(label="authors", choices=authors_list, multiselect=True)
            num_results_inp = gr.Number(label="number of results", value=10, step=1, minimum=1, maximum=max_textboxes)
            btn = gr.Button("Search")

    components = []
    textboxes = []
    flag_buttons = []
    ids = []

    # Pre-build every result slot hidden; perform_query reveals the used ones.
    for _ in range(max_textboxes):
        with gr.Column():
            text_out = gr.Markdown(visible=False, elem_classes="custom-markdown")
            flag_btn = gr.Button(value="Flag", visible=False)
            id_out = gr.Textbox(visible=False)
            components.extend([text_out, flag_btn, id_out])
            textboxes.append(text_out)
            flag_buttons.append(flag_btn)
            ids.append(id_out)

    # flag_output always logs exactly three values, so register exactly one
    # (query, text, id) triple to keep the CSV header and rows aligned.
    callback.setup([inp, textboxes[0], ids[0]], "flagged_data_points")

    btn.click(
        fn=perform_query,
        inputs=[inp, author_inp, num_results_inp],
        outputs=components
    )

    for flag_btn, text_out, id_out in zip(flag_buttons, textboxes, ids):
        flag_btn.click(
            fn=flag_output,
            inputs=[inp, text_out, id_out],
            outputs=[],
            preprocess=False
        )

demo.launch()