File size: 6,842 Bytes
2e98c79
 
 
 
3bbb5f4
2e98c79
331b253
dde2538
331b253
c18b9d6
 
331b253
10f043b
7d6132f
4295f9e
7d6132f
 
2ec7158
2e98c79
2a5653d
331b253
 
e59705e
331b253
 
 
2e98c79
331b253
 
 
 
2e8a9c7
331b253
 
 
 
e59705e
 
 
 
331b253
 
 
 
2e98c79
331b253
2e98c79
e59705e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dde2538
 
e59705e
7d6132f
e59705e
92ed022
 
dde2538
 
331b253
dde2538
 
 
 
 
7d6132f
 
 
dde2538
7d6132f
dde2538
 
2a5653d
 
 
 
e59705e
2e98c79
 
e59705e
 
 
 
 
7d6132f
dde2538
e59705e
2e98c79
7d6132f
2a5653d
 
 
e59705e
331b253
 
e59705e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331b253
2e98c79
e59705e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import os
import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces

# On-disk Chroma store next to the app; only the English collection is
# opened at the moment (the German one is kept below, disabled).
client = chromadb.PersistentClient(path="./chroma")
#collection_de = client.get_collection(name="phil_de")
collection_en = client.get_collection(name="phil_en")

# Choices for the "Authors" filter in the UI. query_chroma matches these
# against each record's "author" metadata field.
#authors_list_de = ["Epikur", "Ludwig Wittgenstein", "Sigmund Freud", "Marcus Aurelius", "Friedrich Nietzsche", "Epiktet", "Ernst Jünger", "Georg Christoph Lichtenberg", "Balthasar Gracian", "Hannah Arendt", "Erich Fromm", "Albert Camus"]
authors_list_en = [
    "Friedrich Nietzsche",
    "Joscha Bach",
    "Hannah Arendt",
    "Albert Camus",
    "Mark Fisher",
]

# Shared embedding model, created on first use by _get_embedding_model().
# The original code rebuilt the SentenceTransformer (a full weight load) on
# every search request.
_embedding_model = None


def _get_embedding_model():
    """Return the shared Linq-Embed-Mistral SentenceTransformer, loading it once."""
    global _embedding_model
    if _embedding_model is None:
        # NOTE(review): under ZeroGPU, @spaces.GPU functions may execute in a
        # separate worker process; if so this cache only lives as long as that
        # worker. The spaces docs recommend module-level loading — confirm.
        _embedding_model = SentenceTransformer(
            "Linq-AI-Research/Linq-Embed-Mistral",
            use_auth_token=os.getenv("HF_TOKEN"),
        )
    return _embedding_model


@spaces.GPU
def get_embeddings(queries, task):
    """Embed each query string, prefixed with an instruction.

    Every query is wrapped as ``"Instruct: {task}\\nQuery: {query}"`` before
    encoding.

    Args:
        queries: List of query strings.
        task: Instruction text placed after ``Instruct:`` in every prompt.

    Returns:
        The result of ``model.encode`` on the prompts. By default this is a
        NumPy array with one embedding row per query, in input order.
    """
    model = _get_embedding_model()
    prompts = [f"Instruct: {task}\nQuery: {query}" for query in queries]
    return model.encode(prompts)

def query_chroma(collection, embedding, authors, n_results=20):
    """Run a single nearest-neighbour query against a Chroma collection.

    Args:
        collection: Chroma collection (anything exposing ``query(...)``).
        embedding: Query vector. This is either an object with ``tolist()``
            (e.g. a NumPy row) or a plain sequence of floats.
        authors: Optional iterable of author names. When non-empty, only
            records whose ``author`` metadata is one of these are returned.
            When falsy, no filter is applied.
        n_results: Maximum number of hits to request (default 20).

    Returns:
        A list of dicts, one per hit, in the order Chroma returned them. Each
        dict has the keys ``id``, ``author``, ``book``, ``section``,
        ``title``, ``text`` and ``distance``. Missing metadata fields become
        ``''``.
    """
    vector = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
    # None is Chroma's default "no filter". An empty dict may be rejected by
    # the where-clause validation of some Chroma versions.
    where = {"author": {"$in": list(authors)}} if authors else None

    results = collection.query(
        query_embeddings=[vector],
        n_results=n_results,
        where=where,
        include=["documents", "metadatas", "distances"],
    )

    def first_row(key):
        # Results hold one row per query embedding; exactly one was sent.
        # Tolerate a missing/None/empty container instead of raising.
        rows = results.get(key) or [[]]
        return rows[0] or []

    formatted_results = []
    for id_, metadata, document_text, distance in zip(
        first_row("ids"),
        first_row("metadatas"),
        first_row("documents"),
        first_row("distances"),
    ):
        metadata = metadata or {}  # metadata entries may be None
        formatted_results.append({
            "id": id_,
            "author": metadata.get('author', ''),
            "book": metadata.get('book', ''),
            "section": metadata.get('section', ''),
            "title": metadata.get('title', ''),
            "text": document_text,
            "distance": distance,
        })

    return formatted_results

# App theme: Gradio's "Soft" preset with indigo accents on slate neutrals and
# large spacing/radius/text, then overridden to flat (borderless) blocks,
# white inputs with a thin border, and indigo primary buttons/checkboxes.
_THEME_BASE = dict(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate",
    spacing_size="lg",
    radius_size="lg",
    text_size="lg",
    font=["Helvetica", "sans-serif"],
    font_mono=["Courier", "monospace"],
)

_THEME_OVERRIDES = dict(
    # Text and block surfaces
    body_text_color="*neutral_800",
    block_background_fill="*neutral_50",
    block_border_width="0px",
    # Primary button
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
    button_primary_text_color="white",
    # Inputs
    input_background_fill="white",
    input_border_color="*neutral_200",
    input_border_width="1px",
    # Selected checkboxes
    checkbox_background_color_selected="*primary_600",
    checkbox_border_color_selected="*primary_600",
)

theme = gr.themes.Soft(**_THEME_BASE).set(**_THEME_OVERRIDES)

# Stylesheet handed to gr.Blocks(css=...). It does four things:
# - strips padding/margins/borders/outlines from the outer <gradio-app>
#   container and forces it to full width;
# - renders each search hit (elem_classes="custom-markdown") as a bordered,
#   rounded card;
# - on screens <= 768px wide, shrinks container and card padding;
# - on those screens, pulls ".accordion" elements 10px outward on each side.
custom_css = """
/* Remove outer padding, margins, and borders */
gradio-app,
gradio-app > div,
gradio-app .gradio-container {
    padding: 0 !important;
    margin: 0 !important;
    border: none !important;
}

/* Remove any potential outlines */
gradio-app:focus,
gradio-app > div:focus,
gradio-app .gradio-container:focus {
    outline: none !important;
}

/* Ensure full width */
gradio-app {
    width: 100% !important;
    display: block !important;
}

.custom-markdown { 
    border: 1px solid var(--neutral-200); 
    padding: 10px; 
    border-radius: var(--radius-lg);
    background-color: var(--color-background-primary);
    margin-bottom: 15px;
}
.custom-markdown p {
    margin-bottom: 10px;
    line-height: 1.6;
}

@media (max-width: 768px) {
    gradio-app, 
    gradio-app > div,
    gradio-app .gradio-container {
        padding-left: 1px !important;
        padding-right: 1px !important;
    }
    .custom-markdown {
        padding: 5px;
    }
    .accordion {
        margin-left: -10px;
        margin-right: -10px;
    }
}
"""

def _format_result_markdown(result):
    """Render one search hit (a dict as built by query_chroma) as Markdown.

    The ``author``, ``book``, ``section`` and ``title`` fields are joined
    with ", " into a bold header, followed by a blank line and the passage
    text. Empty fields, fields that are None, and the placeholder value
    "Unknown" are skipped. If no header field remains, only the text is
    returned, so no stray "****" is rendered.
    """
    header_parts = []
    for key in ("author", "book", "section", "title"):
        value = result.get(key)
        value = "" if value is None else str(value)  # avoid rendering "None"
        if value and value != "Unknown":
            header_parts.append(value)

    text = result.get("text")
    text = "" if text is None else str(text)

    if not header_parts:
        return text
    header = ", ".join(header_parts)
    return f"**{header}**\n\n{text}"


with gr.Blocks(theme=theme, css=custom_css) as demo:
    gr.Markdown("Enter one or more queries, divide them with semicola; filter authors (default is all), click **Search** to search.")
    #database_inp = gr.Dropdown(label="Database", choices=["German", "English"], value="German")
    author_inp = gr.Dropdown(label="Authors", choices=authors_list_en, multiselect=True)
    inp = gr.Textbox(label="Query", lines=3, placeholder="How can I live a healthy life?; How can I improve my ability to focus?; What is the meaning of life?; ...")
    btn = gr.Button("Search")
    loading_indicator = gr.Markdown(visible=False, elem_id="loading-indicator")
    # List of (query, hits) pairs written by perform_query; None until the first search.
    results = gr.State()

    # Disabled German/English database switch:
    #def update_authors(database):
    #    return gr.update(choices=authors_list_de if database == "German" else authors_list_en)

    #database_inp.change(
    #    fn=lambda database: update_authors(database),
    #    inputs=[database_inp],
    #    outputs=[author_inp]
    #)

    def perform_query(queries, authors, database="English"):
        """Embed each ';'-separated query and search the collection for each one.

        Args:
            queries: Raw textbox content. It is split on ';', each piece is
                stripped, and blank pieces (e.g. from a trailing ';') are
                dropped.
            authors: Selected author names from the dropdown (falsy = all).
            database: Currently unused; only the English collection is
                searched. It defaults to "English" because the database
                dropdown is disabled and the click handler passes only two
                inputs. A required third argument would make every search
                fail with a TypeError.

        Returns:
            A tuple ``(results_data, indicator_update)``:
            - ``results_data`` is a list of ``(query, hits)`` pairs for the
              ``results`` State.
            - ``indicator_update`` hides the loading indicator again.
        """
        task = "Given a question, retrieve passages that answer the question"
        hide_indicator = gr.update(value="", visible=False)

        query_list = [query.strip() for query in (queries or "").split(';')]
        query_list = [query for query in query_list if query]
        if not query_list:
            # Nothing to search; skip the (GPU-backed) embedding call entirely.
            return [], hide_indicator

        embeddings = get_embeddings(query_list, task)
        #collection = collection_de if database == "German" else collection_en
        collection = collection_en
        results_data = [
            (query, query_chroma(collection, embedding, authors))
            for query, embedding in zip(query_list, embeddings)
        ]
        return results_data, hide_indicator

    # Show a status line right away, then run the search. perform_query hides
    # the indicator again when it finishes.
    btn.click(
        fn=lambda: gr.update(value="Searching...", visible=True),
        inputs=None,
        outputs=[loading_indicator],
        queue=False,
    ).then(
        perform_query,
        inputs=[inp, author_inp],
        outputs=[results, loading_indicator],
    )

    @gr.render(inputs=[results])
    def display_accordion(data):
        """Build one collapsed accordion per query, listing its hits as Markdown cards."""
        # The State starts as None and the render function can run before any
        # search has happened, so don't iterate over None.
        if not data:
            return
        for query, res in data:
            with gr.Accordion(query, open=False, elem_classes="accordion"):
                for result in res:
                    with gr.Column():
                        gr.Markdown(
                            value=_format_result_markdown(result),
                            elem_classes="custom-markdown",
                        )

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. from tests or tooling) does not launch it.
if __name__ == "__main__":
    demo.launch(inline=False)