Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import chromadb | |
from sentence_transformers import SentenceTransformer | |
import spaces | |
# Open the on-disk Chroma store and bind the collection(s) the app searches.
# get_collection expects the collection to already exist in ./chroma.
client = chromadb.PersistentClient(path="./chroma")
# The German collection is currently disabled; only "phil_en" is served.
#collection_de = client.get_collection(name="phil_de")
collection_en = client.get_collection(name="phil_en")

# Author names offered by the filter dropdown; query_chroma matches them
# against each record's "author" metadata field.
# German author list, kept for when "phil_de" is re-enabled:
#authors_list_de = ["Ludwig Wittgenstein", "Sigmund Freud", "Marcus Aurelius", "Friedrich Nietzsche", "Epiktet", "Ernst Jünger", "Georg Christoph Lichtenberg", "Balthasar Gracian", "Hannah Arendt", "Erich Fromm", "Albert Camus"]
authors_list_en = [
    "Friedrich Nietzsche",
    "Joscha Bach",
]
# Hugging Face id of the embedding model. NOTE(review): presumably this must
# match the model used to build the Chroma collections — verify.
_EMBEDDING_MODEL_NAME = "Linq-AI-Research/Linq-Embed-Mistral"
# Lazily-created SentenceTransformer shared across calls (None until first use).
_embedding_model = None


def _load_embedding_model():
    """Return the shared SentenceTransformer, loading it on first use only.

    The original code re-instantiated this large model on every search.
    NOTE(review): on ZeroGPU hardware the @spaces.GPU call may execute in a
    worker process, in which case this cache only pays off outside ZeroGPU —
    confirm.
    """
    global _embedding_model
    if _embedding_model is None:
        _embedding_model = SentenceTransformer(
            _EMBEDDING_MODEL_NAME, use_auth_token=os.getenv("HF_TOKEN")
        )
    return _embedding_model


@spaces.GPU  # request a ZeroGPU slot for inference; a no-op on other hardware
def get_embeddings(queries, task):
    """Embed each query, prefixed with an instruction describing *task*.

    Args:
        queries: Iterable of query strings.
        task: Instruction text placed before every query as
            ``"Instruct: {task}\\nQuery: {query}"``.

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for the prompt list
        (by default a numpy array with one embedding row per query).
    """
    model = _load_embedding_model()
    prompts = [f"Instruct: {task}\nQuery: {query}" for query in queries]
    return model.encode(prompts)
def _first_row(results, key):
    """Return the first query's list for *key* from a Chroma query result, or []."""
    # Chroma nests results per query embedding; a key may also be None/empty.
    rows = results.get(key) or [[]]
    return rows[0] or []


def query_chroma(collection, embedding, authors, n_results=20):
    """Find the passages in *collection* nearest to a single query embedding.

    Args:
        collection: A Chroma collection (anything with a compatible ``query``).
        embedding: One query vector — a numpy array / tensor (anything with
            ``tolist()``) or a plain sequence of floats.
        authors: Author names to restrict the search to; ``None`` or empty
            means search all authors.
        n_results: Maximum number of passages to return (default 20).

    Returns:
        A list of dicts with keys ``id``, ``author``, ``book``, ``section``,
        ``title``, ``text`` and ``distance``, in the order Chroma returns them.
        Missing metadata fields become ``''``.
    """
    # Chroma expects plain Python lists of floats.
    vector = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
    results = collection.query(
        query_embeddings=[vector],
        n_results=n_results,
        # No filter must be None (Chroma's default); newer Chroma rejects an
        # empty dict with "Expected where to have exactly one operator".
        where={"author": {"$in": list(authors)}} if authors else None,
        include=["documents", "metadatas", "distances"],
    )
    ids = _first_row(results, "ids")
    metadatas = _first_row(results, "metadatas")
    documents = _first_row(results, "documents")
    distances = _first_row(results, "distances")

    formatted_results = []
    for id_, metadata, document_text, distance in zip(ids, metadatas, documents, distances):
        # A record stored without metadata may come back as None.
        metadata = metadata or {}
        formatted_results.append({
            "id": id_,
            "author": metadata.get('author', ''),
            "book": metadata.get('book', ''),
            "section": metadata.get('section', ''),
            "title": metadata.get('title', ''),
            "text": document_text,
            "distance": distance,
        })
    return formatted_results
# App theme: gradio's Soft preset with indigo accents on slate neutrals and
# large spacing/radius/text, plus a handful of component colour overrides.
_theme_preset = {
    "primary_hue": "indigo",
    "secondary_hue": "slate",
    "neutral_hue": "slate",
    "spacing_size": "lg",
    "radius_size": "lg",
    "text_size": "lg",
    "font": ["Helvetica", "sans-serif"],
    "font_mono": ["Courier", "monospace"],
}
# Values starting with "*" reference the theme's own palette (e.g. *primary_600).
_theme_overrides = {
    "body_text_color": "*neutral_800",
    "block_background_fill": "*neutral_50",
    "block_border_width": "0px",
    "button_primary_background_fill": "*primary_600",
    "button_primary_background_fill_hover": "*primary_700",
    "button_primary_text_color": "white",
    "input_background_fill": "white",
    "input_border_color": "*neutral_200",
    "input_border_width": "1px",
    "checkbox_background_color_selected": "*primary_600",
    "checkbox_border_color_selected": "*primary_600",
}
theme = gr.themes.Soft(**_theme_preset).set(**_theme_overrides)
# Page-level CSS passed to gr.Blocks. It does three things:
# - strips the outer gradio-app padding, margins, borders and focus outlines;
# - styles every search hit (elem_classes="custom-markdown") as a bordered card;
# - tightens spacing on screens up to 768px wide, including the result
#   accordions (elem_classes="accordion").
# Colours, radii and borders come from the theme's CSS variables
# (--neutral-200, --radius-lg, --background-fill-primary). The previous
# --color-background-primary is not a variable the gr.themes system defines,
# so the card's background-color declaration was silently dropped.
custom_css = """
/* Remove outer padding, margins, and borders */
gradio-app,
gradio-app > div,
gradio-app .gradio-container {
    padding: 0 !important;
    margin: 0 !important;
    border: none !important;
}
/* Remove any potential outlines */
gradio-app:focus,
gradio-app > div:focus,
gradio-app .gradio-container:focus {
    outline: none !important;
}
/* Ensure full width */
gradio-app {
    width: 100% !important;
    display: block !important;
}
.custom-markdown {
    border: 1px solid var(--neutral-200);
    padding: 10px;
    border-radius: var(--radius-lg);
    background-color: var(--background-fill-primary);
    margin-bottom: 15px;
}
.custom-markdown p {
    margin-bottom: 10px;
    line-height: 1.6;
}
@media (max-width: 768px) {
    gradio-app,
    gradio-app > div,
    gradio-app .gradio-container {
        padding-left: 1px !important;
        padding-right: 1px !important;
    }
    .custom-markdown {
        padding: 5px;
    }
    .accordion {
        margin-left: -10px;
        margin-right: -10px;
    }
}
"""
def _format_result(result):
    """Build the Markdown card body shown for one search hit.

    *result* is one dict as produced by ``query_chroma``. The known metadata
    fields (author, book, section, title) are joined into a bold header.
    Fields that are empty or equal to the placeholder "Unknown" are skipped.
    """
    header_parts = []
    for field in ("author", "book", "section", "title"):
        value = str(result.get(field, ''))
        if value and value != "Unknown":
            header_parts.append(value)
    text = str(result.get('text', ''))
    if not header_parts:
        # An empty "**" pair would produce "****", which Markdown renders as a
        # horizontal rule; show just the passage instead.
        return text
    return f"**{', '.join(header_parts)}**\n\n{text}"


with gr.Blocks(theme=theme, css=custom_css) as demo:
    gr.Markdown("Enter one or more queries, divide them with semicolons; filter authors (default is all), click **Search** to search.")
    # German/English database selector, disabled while only the English collection is served:
    #database_inp = gr.Dropdown(label="Database", choices=["German", "English"], value="German")
    author_inp = gr.Dropdown(label="Authors", choices=authors_list_en, multiselect=True)
    inp = gr.Textbox(label="Query", lines=3, placeholder="How can I live a healthy life?; How can I improve my ability to focus?; What is the meaning of life?; ...")
    btn = gr.Button("Search")
    # Empty Markdown that is only visible while a search is running.
    loading_indicator = gr.Markdown(visible=False, elem_id="loading-indicator")
    # Latest search output: list of (query, hits) pairs written by
    # perform_query and rendered by display_accordion. Starts as None.
    results = gr.State()

    # Author-list switching for the disabled database selector:
    #def update_authors(database):
    #    return gr.update(choices=authors_list_de if database == "German" else authors_list_en)
    #database_inp.change(
    #    fn=lambda database: update_authors(database),
    #    inputs=[database_inp],
    #    outputs=[author_inp]
    #)

    def perform_query(queries, authors, database="English"):
        """Run every ';'-separated query against the vector store.

        Args:
            queries: Raw textbox contents; individual queries are separated by ';'.
            authors: Selected author names, or None/empty for no author filter.
            database: Unused while only the English collection is served. It
                defaults to "English" because the click handler passes only
                the query and author inputs (the database selector is disabled).

        Returns:
            ``(results_data, indicator_update)``: ``results_data`` is a list of
            ``(query, hits)`` pairs, and ``indicator_update`` hides the loading
            indicator again.
        """
        task = "Given a question, retrieve passages that answer the question"
        # Drop blanks so "a; b;" or an empty box does not embed an empty query.
        query_list = [query.strip() for query in (queries or "").split(';')]
        query_list = [query for query in query_list if query]
        hide_indicator = gr.update(value="", visible=False)
        if not query_list:
            return [], hide_indicator
        embeddings = get_embeddings(query_list, task)
        #collection = collection_de if database == "German" else collection_en
        collection = collection_en
        results_data = [
            (query, query_chroma(collection, embedding, authors))
            for query, embedding in zip(query_list, embeddings)
        ]
        return results_data, hide_indicator

    btn.click(
        # Reveal the (empty) indicator immediately, unqueued. Presumably this
        # lets Gradio's pending-state animation show on it while
        # perform_query runs. A component must appear only once in outputs.
        fn=lambda: gr.update(value="", visible=True),
        inputs=None,
        outputs=loading_indicator,
        queue=False
    ).then(
        perform_query,
        inputs=[inp, author_inp],
        outputs=[results, loading_indicator]
    )

    # Re-rendered whenever `results` changes. Without gr.render these
    # dynamically created components were never shown.
    # NOTE(review): requires a Gradio release that provides gr.render — confirm
    # the Space's pinned version.
    @gr.render(inputs=[results])
    def display_accordion(data):
        """Show one collapsed accordion per query, holding one card per hit."""
        if not data:
            # Initial state (None) or no valid queries: render nothing.
            return
        for query, res in data:
            with gr.Accordion(query, open=False, elem_classes="accordion"):
                for result in res:
                    with gr.Column():
                        gr.Markdown(value=_format_result(result), elem_classes="custom-markdown")

demo.launch(inline=False)