Ma003 commited on
Commit
fc45ae2
·
1 Parent(s): 4d1a966
Files changed (1) hide show
  1. app.py +0 -107
app.py DELETED
@@ -1,107 +0,0 @@
1
- """
2
- Credit to Derek Thomas, derek@huggingface.co
3
- """
4
- import os
5
- import logging
6
- from pathlib import Path
7
- from time import perf_counter
8
-
9
- import gradio as gr
10
- from jinja2 import Environment, FileSystemLoader
11
-
12
- from backend.query_llm import generate_hf, generate_openai
13
- from backend.semantic_search import retrieve
14
-
15
-
16
- TOP_K = int(os.getenv("TOP_K", 4))
17
-
18
- proj_dir = Path(__file__).parent
19
- # Setting up the logging
20
- logging.basicConfig(level=logging.INFO)
21
- logger = logging.getLogger(__name__)
22
-
23
- # Set up the template environment with the templates directory
24
- env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
25
-
26
- # Load the templates directly from the environment
27
- template = env.get_template('template.j2')
28
- template_html = env.get_template('template_html.j2')
29
-
30
-
31
- def add_text(history, text):
32
- history = [] if history is None else history
33
- history = history + [(text, None)]
34
- return history, gr.Textbox(value="", interactive=False)
35
-
36
-
37
- def bot(history, api_kind):
38
- query = history[-1][0]
39
-
40
- if not query:
41
- raise gr.Warning("Please submit a non-empty string as a prompt")
42
-
43
- logger.info('Retrieving documents...')
44
- # Retrieve documents relevant to query
45
- document_start = perf_counter()
46
-
47
- documents = retrieve(query, TOP_K)
48
-
49
- document_time = perf_counter() - document_start
50
- logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
51
-
52
- # Create Prompt
53
- prompt = template.render(documents=documents, query=query)
54
- prompt_html = template_html.render(documents=documents, query=query)
55
-
56
- if api_kind == "HuggingFace":
57
- generate_fn = generate_hf
58
- elif api_kind == "OpenAI":
59
- generate_fn = generate_openai
60
- else:
61
- raise gr.Error(f"API {api_kind} is not supported")
62
-
63
- history[-1][1] = ""
64
- for character in generate_fn(prompt, history[:-1]):
65
- history[-1][1] = character
66
- yield history, prompt_html
67
-
68
-
69
- with gr.Blocks() as demo:
70
- chatbot = gr.Chatbot(
71
- [],
72
- elem_id="chatbot",
73
- avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
74
- 'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
75
- bubble_full_width=False,
76
- show_copy_button=True,
77
- show_share_button=True,
78
- )
79
-
80
- with gr.Row():
81
- txt = gr.Textbox(
82
- scale=3,
83
- show_label=False,
84
- placeholder="Enter text and press enter",
85
- container=False,
86
- )
87
- txt_btn = gr.Button(value="Submit text", scale=1)
88
-
89
- api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
90
-
91
- prompt_html = gr.HTML()
92
- # Turn off interactivity while generating if you click
93
- txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
94
- bot, [chatbot, api_kind], [chatbot, prompt_html])
95
-
96
- # Turn it back on
97
- txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
98
-
99
- # Turn off interactivity while generating if you hit enter
100
- txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
101
- bot, [chatbot, api_kind], [chatbot, prompt_html])
102
-
103
- # Turn it back on
104
- txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
105
-
106
- demo.queue()
107
- demo.launch(debug=True)