HonestAnnie committed on
Commit
7d6132f
1 Parent(s): 2e8a9c7

bunch of stuff

Browse files
Files changed (1) hide show
  1. app.py +31 -49
app.py CHANGED
@@ -5,10 +5,10 @@ from sentence_transformers import SentenceTransformer
5
  import spaces
6
 
7
  @spaces.GPU
8
- def get_embeddings(query, task):
9
  model = SentenceTransformer("Linq-AI-Research/Linq-Embed-Mistral", use_auth_token=os.getenv("HF_TOKEN"))
10
- prompt = f"Instruct: {task}\nQuery: {query}"
11
- query_embeddings = model.encode([prompt])
12
  return query_embeddings
13
 
14
  # Initialize a persistent Chroma client and retrieve collection
@@ -18,7 +18,7 @@ collection_en = client.get_collection(name="phil_en")
18
  authors_list_de = ["Ludwig Wittgenstein", "Sigmund Freud", "Marcus Aurelius", "Friedrich Nietzsche", "Epiktet", "Ernst Jünger", "Georg Christoph Lichtenberg", "Balthasar Gracian", "Hannah Arendt", "Erich Fromm", "Albert Camus"]
19
  authors_list_en = ["Friedrich Nietzsche", "Joscha Bach"]
20
 
21
- def query_chroma(collection, embeddings, authors, num_results=10):
22
  try:
23
  where_filter = {"author": {"$in": authors}} if authors else {}
24
 
@@ -26,7 +26,7 @@ def query_chroma(collection, embeddings, authors, num_results=10):
26
 
27
  results = collection.query(
28
  query_embeddings=[embeddings_list],
29
- n_results=num_results,
30
  where=where_filter,
31
  include=["documents", "metadatas", "distances"]
32
  )
@@ -53,59 +53,41 @@ def query_chroma(collection, embeddings, authors, num_results=10):
53
  except Exception as e:
54
  return {"error": str(e)}
55
 
56
-
57
- # Main function
58
- def perform_query(query, authors, num_results, database):
59
- task = "Given a question, retrieve passages that answer the question"
60
- embeddings = get_embeddings(query, task)
61
- collection = collection_de if database == "German" else collection_en
62
- results = query_chroma(collection, embeddings, authors, num_results)
63
-
64
- if "error" in results:
65
- return [gr.update(visible=True, value=f"Error: {results['error']}") for _ in range(max_textboxes * 2)]
66
-
67
- updates = []
68
- for res in results:
69
- markdown_content = f"**{res['author']}, {res['book']}**\n\n{res['text']}"
70
- updates.append(gr.update(visible=True, value=markdown_content))
71
-
72
- updates += [gr.update(visible=False)] * (max_textboxes - len(results))
73
-
74
- return updates
75
-
76
  def update_authors(database):
77
  return gr.update(choices=authors_list_de if database == "German" else authors_list_en)
78
 
79
- # Gradio interface
80
- max_textboxes = 30
81
 
82
- with gr.Blocks(css=".custom-markdown { border: 1px solid #ccc; padding: 10px; border-radius: 5px; }") as demo:
83
- gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search. Click **Flag** if a result is relevant to the query and interesting to you.")
84
- with gr.Row():
85
- with gr.Column():
86
- database_inp = gr.Dropdown(label="Database", choices=["English", "German"], value="German")
87
- inp = gr.Textbox(label="query", placeholder="Enter question...")
88
- author_inp = gr.Dropdown(label="authors", choices=authors_list_de, multiselect=True)
89
- num_results_inp = gr.Number(label="number of results", value=10, step=1, minimum=1, maximum=max_textboxes)
90
- btn = gr.Button("Search")
91
 
92
- components = []
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- for _ in range(max_textboxes):
95
- with gr.Column() as col:
96
- text_out = gr.Markdown(visible=False, elem_classes="custom-markdown")
97
- components.append(text_out)
 
 
98
 
99
  btn.click(
100
- fn=perform_query,
101
- inputs=[inp, author_inp, num_results_inp, database_inp],
102
- outputs=components
103
  )
104
-
105
  database_inp.change(
106
- fn=update_authors,
107
- inputs=database_inp,
108
- outputs=author_inp
109
  )
110
 
111
- demo.launch()
 
5
  import spaces
6
 
7
  @spaces.GPU
8
+ def get_embeddings(queries, task):
9
  model = SentenceTransformer("Linq-AI-Research/Linq-Embed-Mistral", use_auth_token=os.getenv("HF_TOKEN"))
10
+ prompts = [f"Instruct: {task}\nQuery: {query}" for query in queries]
11
+ query_embeddings = model.encode(prompts)
12
  return query_embeddings
13
 
14
  # Initialize a persistent Chroma client and retrieve collection
 
18
  authors_list_de = ["Ludwig Wittgenstein", "Sigmund Freud", "Marcus Aurelius", "Friedrich Nietzsche", "Epiktet", "Ernst Jünger", "Georg Christoph Lichtenberg", "Balthasar Gracian", "Hannah Arendt", "Erich Fromm", "Albert Camus"]
19
  authors_list_en = ["Friedrich Nietzsche", "Joscha Bach"]
20
 
21
+ def query_chroma(collection, embeddings, authors):
22
  try:
23
  where_filter = {"author": {"$in": authors}} if authors else {}
24
 
 
26
 
27
  results = collection.query(
28
  query_embeddings=[embeddings_list],
29
+ n_results=10,
30
  where=where_filter,
31
  include=["documents", "metadatas", "distances"]
32
  )
 
53
  except Exception as e:
54
  return {"error": str(e)}
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def update_authors(database):
57
  return gr.update(choices=authors_list_de if database == "German" else authors_list_en)
58
 
 
 
59
 
 
 
 
 
 
 
 
 
 
60
 
61
+ with gr.Blocks() as demo:
62
+ gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search.")
63
+ database_inp = gr.Dropdown(label="Database", choices=["English", "German"], value="German")
64
+ author_inp = gr.Dropdown(label="Authors", choices=authors_list_de, multiselect=True)
65
+ inp = gr.Textbox(label="Query", placeholder="Enter questions separated by semicolons...")
66
+ btn = gr.Button("Search")
67
+
68
+ def perform_query(queries, authors, database):
69
+ queries = queries.split(';')
70
+ task = "Given a question, retrieve passages that answer the question"
71
+ embeddings = get_embeddings(queries, task)
72
+ collection = collection_de if database == "German" else collection_en
73
 
74
+ results = [query_chroma(collection, embedding, authors) for embedding in embeddings]
75
+
76
+ for query, result in zip(queries, results):
77
+ with gr.Accordion(query):
78
+ markdown_contents = "\n".join(f"**{res['author']}, {res['book']}**\n\n{res['text']}" for res in result)
79
+ gr.Markdown(value=markdown_contents)
80
 
81
  btn.click(
82
+ perform_query,
83
+ inputs=[inp, author_inp, database_inp],
84
+ outputs=[]
85
  )
86
+
87
  database_inp.change(
88
+ fn=lambda database: update_authors(database),
89
+ inputs=[database_inp],
90
+ outputs=[author_inp]
91
  )
92
 
93
+ demo.launch()