File size: 7,050 Bytes
8b15eea
3cb8374
8b15eea
 
 
 
 
5f75644
8b15eea
91f49a8
8b15eea
 
 
 
 
 
 
 
 
 
 
 
 
 
df1aa0b
4e966cd
df1aa0b
8de88bd
0665e63
3cb8374
 
8de88bd
df1aa0b
8b15eea
 
 
 
 
 
8de88bd
df1aa0b
8b15eea
 
 
 
 
8de88bd
7bda49c
8de88bd
 
 
 
 
 
8b15eea
 
df1aa0b
 
 
3cb8374
8b15eea
 
3cb8374
df1aa0b
 
 
 
 
 
 
 
8b15eea
 
 
7bda49c
 
 
8b15eea
5d0067c
 
 
 
 
 
 
 
 
 
1a7a096
 
5d0067c
 
8b15eea
 
5d0067c
 
8de88bd
8b15eea
 
 
 
 
 
 
 
 
 
 
 
 
 
cf11d3f
8b15eea
 
 
 
8de88bd
8b15eea
8de88bd
8b15eea
 
 
 
 
 
 
 
 
 
 
 
 
8de88bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cb8374
 
8de88bd
 
 
 
 
 
 
 
 
 
 
 
8b15eea
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import logging
from functools import partial
from pathlib import Path
from time import perf_counter

import gradio as gr
from jinja2 import Environment, FileSystemLoader
from transformers import AutoTokenizer

from backend.query_llm import check_endpoint_status, generate
from backend.semantic_search import retriever

# Directory that contains this file; project resources are resolved relative to it
proj_dir = Path(__file__).parent

# Logging: configure the root logger at INFO and create this module's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Jinja2 environment that loads templates from the project's `templates` directory
templates_dir = proj_dir / 'templates'
env = Environment(loader=FileSystemLoader(templates_dir))

# Two templates rendered from the same (documents, query) inputs:
# `template` produces the prompt text, `template_html` the HTML shown in the UI
template = env.get_template('template.j2')
template_html = env.get_template('template_html.j2')

# Tokenizer used to measure the rendered prompt's length in tokens
tokenizer = AutoTokenizer.from_pretrained('derek-thomas/jais-13b-chat-hf')

# Example queries offered in the UI (Arabic):
#   "Who were the two sides of the naval Battle of Actium?"
#   "Why is the sky blue?"
#   "Who won the men's World Cup in 2014?"
examples = [
    'من كان طرفي معركة اكتيوم البحرية؟',
    'لم السماء زرقاء؟',
    "من فاز بكأس العالم للرجال في عام 2014؟",
]


def add_text(history, text):
    """Append the user's message to the chat history and clear/lock the input box.

    Args:
        history: Current chatbot value, a list of ``[user, bot]`` pairs, or
            ``None`` when there is no conversation yet.
        text: The message the user just submitted.

    Returns:
        A ``(history, textbox)`` tuple: a new history list whose last entry is
        ``[text, None]`` (the reply slot is still empty), and a ``gr.Textbox``
        with an empty value and ``interactive=False``.
    """
    # Build a new list so the caller's list is never mutated. The pair is a
    # list rather than a tuple because `bot` fills in the reply slot later and
    # a tuple would not allow item assignment.
    updated_history = [*(history or []), [text, None]]
    return updated_history, gr.Textbox(value="", interactive=False)


# `generate` output is split on this marker and only the text after its last
# occurrence is kept
_RESPONSE_MARKER = '### Response: [|AI|]'
# Largest rendered prompt, in tokens, that is sent to `generate` while there
# are still documents that can be dropped
_MAX_PROMPT_TOKENS = 2048


def bot(history, hyde=False):
    """Answer the latest user message using retrieved documents.

    Steps: retrieve documents for the query (optionally via a generated
    hypothetical document when ``hyde`` is true), render the prompt, drop
    documents from the end until the prompt fits in ``_MAX_PROMPT_TOKENS``,
    call ``generate`` and store the answer in the last history entry.

    Args:
        history: Chat history as a list of ``(user, bot)`` pairs; the last
            pair holds the query to answer at index 0.
        hyde: When true, first ask ``generate`` for a Wikipedia-style intro
            paragraph answering the query and search with that text instead
            of the raw query.

    Returns:
        A ``(history, prompt_html)`` tuple: the same history list with the
        last pair replaced by ``[query, answer]``, and the HTML rendering of
        the prompt built from the documents that were actually used.

    Raises:
        IndexError: If ``history`` is empty.
    """
    top_k = 5
    query = history[-1][0]

    logger.warning('Retrieving documents...')
    # Retrieve documents relevant to the query; the timing includes the extra
    # generation call when HyDE is enabled
    document_start = perf_counter()
    if hyde:
        hyde_prompt = f"Write a wikipedia article intro paragraph to answer this query: {query}"
        search_text = generate(hyde_prompt).split(_RESPONSE_MARKER)[-1]
        logger.warning(search_text)
    else:
        search_text = query
    documents = retriever(search_text, top_k=top_k)
    document_time = perf_counter() - document_start
    logger.warning('Finished Retrieving documents in %s seconds...', round(document_time, 2))

    def count_tokens(text):
        """Return the number of tokens the tokenizer produces for *text*."""
        return len(tokenizer.encode(text))

    # Render the prompt, then drop documents from the end until it fits.
    # Stop when no documents are left: popping from an empty list would raise
    # IndexError and the prompt cannot be shortened any further this way.
    prompt = template.render(documents=documents, query=query)
    token_count = count_tokens(prompt)
    while token_count > _MAX_PROMPT_TOKENS and documents:
        documents.pop()
        prompt = template.render(documents=documents, query=query)
        token_count = count_tokens(prompt)
    if token_count > _MAX_PROMPT_TOKENS:
        logger.warning('Prompt is still %s tokens (limit %s) with no documents left to drop',
                       token_count, _MAX_PROMPT_TOKENS)

    # Rendered after truncation so the HTML matches the prompt that is sent
    prompt_html = template_html.render(documents=documents, query=query)

    response = generate(prompt)
    # Replace the whole pair instead of assigning to index 1 so this also
    # works when the pair is an immutable tuple
    history[-1] = [query, response.split(_RESPONSE_MARKER)[-1]]
    return history, prompt_html

# Markdown introduction rendered at the top of the page via gr.Markdown
intro_md = """
# Arabic RAG
This is a project to demonstrate Retrieval Augmented Generation (RAG) in Arabic and English. It uses
[Arabic Wikipedia](https://ar.wikipedia.org/wiki) as a base to answer questions you have.
A retriever ([sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2/discussions/8))
 will find the articles relevant to your query and include them in a prompt so the reader ([core42/jais-13b-chat](https://huggingface.co/core42/jais-13b-chat))
 can then answer your questions on it.

You can see the prompt clearly displayed below the chatbot to understand what is going to the LLM.

# Read this if you get an error
I'm using Inference Endpoint's Scale to Zero to save money on GPUs. If the status shows it's not "Running", send a
chat to wake it up. You will get a `500 error` and it will take ~7 min to wake up.
"""

# Avatars shared by both chat tabs, in (user, bot) order
avatar_images = ('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                 'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg')

with gr.Blocks() as demo:
    gr.Markdown(intro_md)
    # The value is a callable, so with `every=1` Gradio re-runs it every second
    endpoint_status = gr.Textbox(check_endpoint_status, label="Inference Endpoint Status", every=1)

    with gr.Tab("Arabic-RAG"):
        chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                avatar_images=avatar_images,
                bubble_full_width=False,
                show_copy_button=True,
                show_share_button=True,
                )

        with gr.Row():
            txt = gr.Textbox(
                    scale=3,
                    show_label=False,
                    placeholder="Enter query in Arabic or English and press enter",
                    container=False,
                    )
            txt_btn = gr.Button(value="Submit text", scale=1)

        gr.Examples(examples, txt)
        prompt_html = gr.HTML()

        # Clicking the button and pressing enter run the same three-step chain
        for trigger in (txt_btn.click, txt.submit):
            # 1) add the message and turn off interactivity while generating,
            # 2) retrieve documents and generate the answer
            txt_msg = trigger(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
                    bot, chatbot, [chatbot, prompt_html])
            # 3) turn interactivity back on
            txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    with gr.Tab("Arabic-RAG + HyDE"):
        hyde_chatbot = gr.Chatbot(
                [],
                # Distinct id: the other tab's chatbot already uses "chatbot"
                # and element ids must be unique within the page
                elem_id="hyde_chatbot",
                avatar_images=avatar_images,
                bubble_full_width=False,
                show_copy_button=True,
                show_share_button=True,
                )

        with gr.Row():
            hyde_txt = gr.Textbox(
                    scale=3,
                    show_label=False,
                    placeholder="Enter text and press enter",
                    container=False,
                    )
            hyde_txt_btn = gr.Button(value="Submit text", scale=1)

        gr.Examples(examples, hyde_txt)
        hyde_prompt_html = gr.HTML()

        # Same chain as the first tab, with `bot` running in HyDE mode
        for trigger in (hyde_txt_btn.click, hyde_txt.submit):
            hyde_txt_msg = trigger(add_text, [hyde_chatbot, hyde_txt], [hyde_chatbot, hyde_txt],
                                   queue=False).then(
                    partial(bot, hyde=True), [hyde_chatbot], [hyde_chatbot, hyde_prompt_html])
            # Turn interactivity back on
            hyde_txt_msg.then(lambda: gr.Textbox(interactive=True), None, [hyde_txt], queue=False)

demo.queue()
demo.launch(debug=True)