derek-thomas (HF staff) committed
Commit 8b15eea
1 Parent(s): 1089f86

Add gradio app!
app.py ADDED
@@ -0,0 +1,93 @@
+ import logging
+ from pathlib import Path
+ from time import perf_counter
+
+ import gradio as gr
+ from jinja2 import Environment, FileSystemLoader
+
+ from backend.query_llm import generate
+ from backend.semantic_search import retriever
+
+ proj_dir = Path(__file__).parent
+ # Setting up the logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Set up the template environment with the templates directory
+ env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
+
+ # Load the templates directly from the environment
+ template = env.get_template('template.j2')
+ template_html = env.get_template('template_html.j2')
+
+
+ def add_text(history, text):
+     history = [] if history is None else history
+     history = history + [(text, None)]
+     return history, gr.Textbox(value="", interactive=False)
+
+
+ def bot(history, system_prompt=""):
+     top_k = 3
+     query = history[-1][0]
+
+     logger.warning('Retrieving documents...')
+     # Retrieve documents relevant to the query
+     document_start = perf_counter()
+     documents = retriever(query, top_k=top_k)
+     document_time = perf_counter() - document_start
+     logger.warning(f'Finished retrieving documents in {round(document_time, 2)} seconds...')
+
+     # Create the prompt (plain text for the model, HTML for display)
+     prompt = template.render(documents=documents, query=query)
+     prompt_html = template_html.render(documents=documents, query=query)
+     logger.warning(prompt)
+
+     history[-1][1] = ""
+     for character in generate(prompt):
+         history[-1][1] = character
+         yield history, prompt_html
+
+
+ with gr.Blocks() as demo:
+     with gr.Tab("Application"):
+         chatbot = gr.Chatbot(
+                 [],
+                 elem_id="chatbot",
+                 avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
+                                'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
+                 bubble_full_width=False,
+                 show_copy_button=True,
+                 show_share_button=True,
+                 )
+
+         with gr.Row():
+             txt = gr.Textbox(
+                     scale=3,
+                     show_label=False,
+                     placeholder="Enter text and press enter",
+                     container=False,
+                     )
+             txt_btn = gr.Button(value="Submit text", scale=1)
+
+         prompt_html = gr.HTML()
+         # Turn off interactivity while generating if you click the button
+         txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
+                 bot, chatbot, [chatbot, prompt_html])
+
+         # Turn it back on
+         txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
+
+         # Turn off interactivity while generating if you hit enter
+         txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
+                 bot, chatbot, [chatbot, prompt_html])
+
+         # Turn it back on
+         txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
+
+         gr.Examples(["What is the capital of China? I think it's Shanghai.",
+                      'Why is the sky blue?',
+                      "Who won the men's World Cup in 2014?"], txt)
+
+ demo.queue()
+ demo.launch(debug=True)
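
For anyone adapting this file, here is a minimal, self-contained sketch of the same Gradio event-chaining pattern used above: disable the textbox, stream the bot reply by yielding partial histories, then re-enable the textbox. The echo bot below is a stand-in for the retriever + LLM and is not part of this repo.

```python
import time

import gradio as gr


def add_text(history, text):
    history = (history or []) + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)


def dummy_bot(history):
    # Stand-in for bot(): stream the reply back word by word.
    reply = "You said: " + history[-1][0]
    history[-1][1] = ""
    for word in reply.split():
        history[-1][1] += word + " "
        time.sleep(0.05)
        yield history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot([])
    txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter")
    # Same chain as app.py: add_text -> streamed bot -> re-enable the textbox.
    txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        dummy_bot, chatbot, chatbot
    ).then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

demo.queue()

if __name__ == "__main__":
    demo.launch()
```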
backend/query_llm.py ADDED
@@ -0,0 +1,21 @@
+ import requests
+ from os import getenv
+
+
+ API_URL = getenv('API_URL')
+ BEARER = getenv('BEARER')
+
+
+ headers = {
+     "Authorization": f"Bearer {BEARER}",
+     "Content-Type": "application/json"
+ }
+
+ def call_jais(payload):
+     response = requests.post(API_URL, headers=headers, json=payload)
+     return response.json()
+
+ def generate(prompt: str):
+     payload = {'inputs': '', 'prompt': prompt}
+     response = call_jais(payload)
+     return response
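
Note that bot() in app.py iterates over generate()'s return value and writes each yielded item straight into the chat message, which suggests generate() is meant to produce progressively longer text, whereas here it returns the raw endpoint JSON. A hedged sketch of a streaming variant, placed alongside call_jais above, is shown below; the 'generated_text' key is purely an assumption about the endpoint's response shape, which this commit does not document.

```python
def generate_streaming(prompt: str, chunk_size: int = 8):
    """Sketch only: yield progressively longer prefixes of the generated text
    so the Gradio chatbot can render it incrementally."""
    response = call_jais({'inputs': '', 'prompt': prompt})
    # Assumed response shape: {'generated_text': '...'} -- adjust to the real endpoint.
    text = response.get('generated_text', '') if isinstance(response, dict) else str(response)
    for i in range(chunk_size, len(text) + chunk_size, chunk_size):
        yield text[:i]
```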
backend/semantic_search.py ADDED
@@ -0,0 +1,45 @@
+ import logging
+ from pathlib import Path
+ import time
+
+ import lancedb
+ from sentence_transformers import SentenceTransformer
+
+
+ # Setting up the logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Start the timer for connecting to LanceDB
+ start_time = time.perf_counter()
+
+ proj_dir = Path(__file__).parents[1]
+
+ # Connect to LanceDB and open the pre-built table
+ db = lancedb.connect(proj_dir / "lancedb")
+ tbl = db.open_table('arabic-wiki')
+ lancedb_loading_time = time.perf_counter() - start_time
+ logger.info(f"Time taken to load LanceDB: {lancedb_loading_time:.6f} seconds")
+
+ # Start the timer for loading the embedding model
+ start_time = time.perf_counter()
+
+ name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+ st_model = SentenceTransformer(name, device='cuda')
+
+ # Used for both indexing and querying
+ def embed_func(query):
+     return st_model.encode(query)
+
+ def vector_search(query_vector, top_k):
+     return tbl.search(query_vector).limit(top_k).to_list()
+
+ def retriever(query, top_k=3):
+     query_vector = embed_func(query)
+     documents = vector_search(query_vector, top_k)
+     return documents
+
+
+ # Log the time taken to load the embedding model
+ retriever_loading_time = time.perf_counter() - start_time
+ logger.info(f"Time taken to load the embedding model: {retriever_loading_time:.6f} seconds")
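
retriever() assumes a LanceDB table named 'arabic-wiki' already exists under lancedb/, and the templates read doc.content from each returned row, so every row presumably carries a 'vector' column plus a 'content' field. The table itself is not built in this commit; a hypothetical indexing sketch (the column names and sample passages are assumptions) might look like this:

```python
import lancedb
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
passages = ["نص المقال الأول", "نص المقال الثاني"]  # example passages only

db = lancedb.connect("lancedb")
db.create_table(
    "arabic-wiki",
    data=[{"vector": model.encode(p).tolist(), "content": p} for p in passages],
    mode="overwrite",  # replace any existing table with the same name
)
```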
templates/template.j2 ADDED
@@ -0,0 +1,10 @@
+ ### Instruction: Use the following unique documents in the Context section to answer the Query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
+ ### Context
+ {% for doc in documents %}
+ ---
+ {{ doc.content }}
+ {% endfor %}
+ ---
+ [|AI|]:
+ ### Query: [|Human|] {{query}}
+ ### Response: [|AI|]
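
To inspect the exact prompt string the model receives, the template can be rendered with mock documents. A small sketch follows (run from the project root; the documents below are made up and only mimic the dicts returned by retriever()):

```python
from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("templates"))
template = env.get_template("template.j2")

mock_docs = [{"content": "Paris is the capital of France."},
             {"content": "Berlin is the capital of Germany."}]
print(template.render(documents=mock_docs, query="What is the capital of France?"))
```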
templates/template_html.j2 ADDED
@@ -0,0 +1,47 @@
+ <h2>Prompt</h2>
+ Below is the prompt that is given to the model. <hr>
+ <h2>Instruction:</h2>
+ <span style="color: #FF00FF;">Use the following unique documents in the Context section to answer the Query at the end. If you don't know the answer, just say that you don't know, <span style="color: #FF00FF; font-weight: bold;">don't try to make up an answer.</span></span><br>
+ <h2>Context</h2>
+ {% for doc in documents %}
+ <details class="doc-box">
+   <summary>
+     <b>Doc {{ loop.index }}:</b> <span class="doc-short">{{ doc.content[:100] }}...</span>
+   </summary>
+   <div class="doc-full">{{ doc.content }}</div>
+ </details>
+ {% endfor %}
+
+ <h2>Query</h2> <span style="color: #801616;">{{ query }}</span>
+
+ <style>
+   .doc-box {
+     padding: 10px;
+     margin-top: 10px;
+     background-color: #48a3ff;
+     border: none;
+   }
+   .doc-short, .doc-full {
+     color: white;
+   }
+   summary::-webkit-details-marker {
+     color: white;
+   }
+ </style>
+
+ <script>
+   document.addEventListener("DOMContentLoaded", function() {
+     const detailsElements = document.querySelectorAll('.doc-box');
+
+     detailsElements.forEach(detail => {
+       detail.addEventListener('toggle', function() {
+         const docShort = this.querySelector('.doc-short');
+         if (this.open) {
+           docShort.style.display = 'none';
+         } else {
+           docShort.style.display = 'inline';
+         }
+       });
+     });
+   });
+ </script>