# NOTE: the lines below are residue from the Hugging Face Spaces web page this
# file was copied from; they are kept only as comments so the module parses.
# Spaces: Sleeping
# File size: 6,760 Bytes
# Commit hashes: 2e98c79 3bbb5f4 2e98c79 331b253 08029e5 331b253 10f043b
# 7d6132f 4295f9e 7d6132f 2ec7158 2e98c79 2a5653d 331b253 bebffa2 331b253
# 2e98c79 331b253 2e8a9c7 331b253 471f612 331b253 2e98c79 331b253 2e98c79
# 0b7e756 f14ec94 0b7e756 f14ec94 0b7e756 f14ec94 0b7e756 471f612 dde2538
# 08029e5 471f612 08029e5 f14ec94 92ed022 dde2538 331b253 dde2538 7d6132f
# ef39d37 471f612 dde2538 7d6132f 08029e5 2a5653d ef39d37 2e98c79 f14ec94
# ef39d37 7d6132f dde2538 ef39d37 2e98c79 7d6132f 2a5653d f14ec94 331b253
# 471f612 331b253 2e98c79 f14ec94
import os
import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces
# On-disk Chroma store.  Only the German collection is wired into the UI;
# the English one is currently disabled.
client = chromadb.PersistentClient(path="./chroma")
collection_de = client.get_collection(name="phil_de")
# collection_en = client.get_collection(name="phil_en")

# Choices for the author filter dropdown.  The selected names are used as an
# ``$in`` filter on the ``author`` metadata in ``query_chroma``, so they have
# to match the stored metadata values exactly.
authors_list_de = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
# authors_list_en = ["Friedrich Nietzsche", "Joscha Bach"]
@spaces.GPU
def get_embeddings(queries, task):
    """Embed each query with Linq-Embed-Mistral, prefixed by an instruction.

    Every query is wrapped as ``"Instruct: <task>\\nQuery: <query>"`` before
    encoding.  The model is downloaded/authenticated with the ``HF_TOKEN``
    environment variable (``None`` when unset).

    Returns the result of ``SentenceTransformer.encode`` for the list of
    wrapped strings (by default a NumPy array, one row per query, in input
    order).
    """
    # NOTE(review): the model is constructed on every call, so each search
    # re-initialises the full embedding model.  Loading it once at module
    # level is presumably the intended ZeroGPU pattern -- confirm before
    # changing, since @spaces.GPU controls where this function runs.
    # NOTE(review): newer sentence-transformers releases prefer ``token=``
    # over the deprecated ``use_auth_token=``.
    encoder = SentenceTransformer(
        "Linq-AI-Research/Linq-Embed-Mistral",
        use_auth_token=os.getenv("HF_TOKEN"),
    )
    instructed_inputs = [f"Instruct: {task}\nQuery: {q}" for q in queries]
    return encoder.encode(instructed_inputs)
def query_chroma(collection, embedding, authors, n_results=20):
    """Query a Chroma collection with one embedding and flatten the hits.

    Args:
        collection: Chroma collection (anything with a compatible ``query``
            method).
        embedding: The query embedding; must provide ``tolist()`` (e.g. a
            NumPy row as returned by ``SentenceTransformer.encode``).
        authors: Optional sequence of author names.  When non-empty, only
            documents whose ``author`` metadata is one of them are returned;
            when empty or ``None``, the whole collection is searched.
        n_results: Maximum number of hits to request (default 20).

    Returns:
        A list of dicts with keys ``id``, ``author``, ``book``, ``section``,
        ``title``, ``text`` and ``distance``, in the order Chroma returned
        them.  Missing metadata fields become ``""``.
    """
    results = collection.query(
        query_embeddings=[embedding.tolist()],
        n_results=n_results,
        # ``None`` is Chroma's "no filter" default; an empty dict is not, and
        # may be rejected by chromadb's ``where`` validation.
        where={"author": {"$in": authors}} if authors else None,
        include=["documents", "metadatas", "distances"],
    )

    def _first(key):
        # Chroma returns one inner list per query embedding and we sent
        # exactly one; tolerate a missing or ``None`` field.
        return (results.get(key) or [[]])[0]

    formatted_results = []
    for id_, metadata, document_text, distance in zip(
        _first("ids"), _first("metadatas"), _first("documents"), _first("distances")
    ):
        # Individual metadata entries can be ``None`` for documents stored
        # without metadata.
        metadata = metadata or {}
        formatted_results.append(
            {
                "id": id_,
                "author": metadata.get("author", ""),
                "book": metadata.get("book", ""),
                "section": metadata.get("section", ""),
                "title": metadata.get("title", ""),
                "text": document_text,
                "distance": distance,
            }
        )
    return formatted_results
# Base look: Gradio's Soft theme in an indigo/slate palette with large
# spacing, radii and text, using Helvetica/Courier with generic fallbacks.
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate",
    spacing_size="lg",
    radius_size="lg",
    text_size="lg",
    font=["Helvetica", "sans-serif"],
    font_mono=["Courier", "monospace"],
)
# Variable overrides: borderless light blocks, indigo primary buttons and
# selected checkboxes, white inputs with a thin neutral border.
theme = theme.set(
    body_text_color="*neutral_800",
    block_background_fill="*neutral_50",
    block_border_width="0px",
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
    button_primary_text_color="white",
    input_background_fill="white",
    input_border_color="*neutral_200",
    input_border_width="1px",
    checkbox_background_color_selected="*primary_600",
    checkbox_border_color_selected="*primary_600",
)
# Extra CSS passed to ``gr.Blocks(css=...)``:
# - strips padding, margin, border and focus outline from the outer
#   ``gradio-app`` wrapper and forces it to full width;
# - styles each search hit (``.custom-markdown``) as a bordered, rounded card;
# - on screens up to 768px wide, shrinks the side padding to 1px and card
#   padding to 5px, and pulls ``.accordion`` elements out by 10px per side.
# NOTE(review): confirm ``--color-background-primary`` is defined by the
# installed Gradio theme; if it is not, the card background is transparent.
custom_css = """
/* Remove outer padding, margins, and borders */
gradio-app,
gradio-app > div,
gradio-app .gradio-container {
padding: 0 !important;
margin: 0 !important;
border: none !important;
}
/* Remove any potential outlines */
gradio-app:focus,
gradio-app > div:focus,
gradio-app .gradio-container:focus {
outline: none !important;
}
/* Ensure full width */
gradio-app {
width: 100% !important;
display: block !important;
}
.custom-markdown {
border: 1px solid var(--neutral-200);
padding: 10px;
border-radius: var(--radius-lg);
background-color: var(--color-background-primary);
margin-bottom: 15px;
}
.custom-markdown p {
margin-bottom: 10px;
line-height: 1.6;
}
@media (max-width: 768px) {
gradio-app,
gradio-app > div,
gradio-app .gradio-container {
padding-left: 1px !important;
padding-right: 1px !important;
}
.custom-markdown {
padding: 5px;
}
.accordion {
margin-left: -10px;
margin-right: -10px;
}
}
"""
def _format_result_markdown(result):
    """Render one search hit as Markdown: a bold header plus the passage text.

    The header joins author, book, section and title with ", ", skipping
    fields that are empty or the placeholder "Unknown".  If no field is
    left, the header is omitted: an empty bold span ``"****"`` would render
    as a horizontal rule.
    """
    header_parts = []
    for key in ("author", "book", "section", "title"):
        value = str(result.get(key, ""))
        if value and value != "Unknown":
            header_parts.append(value)
    text = str(result.get("text", ""))
    if not header_parts:
        return text
    header = ", ".join(header_parts)
    return f"**{header}**\n\n{text}"


with gr.Blocks(theme=theme, css=custom_css) as demo:
    gr.Markdown("Geben Sie ein, wonach Sie suchen möchten (Query), trennen Sie mehrere Suchanfragen durch Semikola; filtern Sie nach Autoren (ohne Auswahl werden alle durchsucht) und klicken Sie auf **Suchen**, um zu suchen.")
    # Disabled: database selector for switching to the English collection.
    #database_inp = gr.Dropdown(label="Database", choices=["German", "English"], value="German")
    author_inp = gr.Dropdown(label="Autoren", choices=authors_list_de, multiselect=True)
    inp = gr.Textbox(label="Query", lines=3, placeholder="Wie kann ich gesund leben?; Wie kann ich mich besser konzentrieren?; Was ist der Sinn des Lebens?; ...")
    btn = gr.Button("Suchen")
    # Empty Markdown that becomes visible when a search starts.  It is also an
    # output of the search step, presumably so Gradio's progress overlay is
    # drawn on it while the query runs -- confirm.
    loading_indicator = gr.Markdown(visible=False, elem_id="loading-indicator")
    # List of (query, hits) pairs from perform_query; None until first search.
    results = gr.State()

    # Disabled: repopulate the author choices when the database changes.
    #def update_authors(database):
    #    return gr.update(choices=authors_list_de if database == "German" else authors_list_en)
    #database_inp.change(
    #    fn=lambda database: update_authors(database),
    #    inputs=[database_inp],
    #    outputs=[author_inp]
    #)

    def perform_query(queries, authors):
        """Embed each ';'-separated query and search the German collection.

        Returns ``(results_data, "")``: ``results_data`` is a list of
        ``(query, hits)`` pairs for the ``results`` state, and ``""`` is the
        new text of the loading indicator.
        """
        task = "Suche den zur Frage passenden Text"
        # Drop blank segments (e.g. from a trailing ';') so they are neither
        # embedded nor rendered as empty accordions.
        stripped = (part.strip() for part in (queries or "").split(";"))
        query_list = [part for part in stripped if part]
        if not query_list:
            # Nothing to search: skip the GPU embedding call entirely.
            return [], ""
        embeddings = get_embeddings(query_list, task)
        results_data = [
            (query, query_chroma(collection_de, embedding, authors))
            for query, embedding in zip(query_list, embeddings)
        ]
        return results_data, ""

    btn.click(
        # Show the (empty) indicator immediately, outside the queue; one
        # update sets both value and visibility instead of listing the same
        # component twice in ``outputs``.
        fn=lambda: gr.update(value="", visible=True),
        inputs=None,
        outputs=[loading_indicator],
        queue=False,
    ).then(
        perform_query,
        inputs=[inp, author_inp],
        outputs=[results, loading_indicator],
    )

    @gr.render(inputs=[results])
    def display_accordion(data):
        """Render one collapsed accordion per query, listing its hits as cards.

        ``data`` is the ``results`` state.  It is ``None`` before the first
        search, and ``gr.render`` also runs on page load, so that case must
        render nothing instead of failing.
        """
        if not data:
            return
        for query, res in data:
            with gr.Accordion(query, open=False, elem_classes="accordion"):
                for result in res:
                    with gr.Column():
                        gr.Markdown(
                            value=_format_result_markdown(result),
                            elem_classes="custom-markdown",
                        )

demo.launch(inline=False)