AlexanderKazakov committed
Commit 8b1c859
Parent: 360f505

improve interface and cut documents to fit the context length

gradio_app/app.py CHANGED
@@ -11,8 +11,9 @@ from time import perf_counter
 import gradio as gr
 from jinja2 import Environment, FileSystemLoader
 
-from backend.query_llm import generate_hf, generate_openai
-from backend.semantic_search import table, embedder
+from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
+from gradio_app.backend.query_llm import generate_hf, generate_openai, construct_openai_messages
+from gradio_app.backend.semantic_search import table, embedder
 
 from settings import *
 
@@ -24,23 +25,29 @@ logger = logging.getLogger(__name__)
 env = Environment(loader=FileSystemLoader('gradio_app/templates'))
 
 # Load the templates directly from the environment
-prompt_template = env.get_template('prompt_template.j2')
-template_html = env.get_template('template_html.j2')
+context_template = env.get_template('context_template.j2')
+context_html_template = env.get_template('context_html_template.j2')
 
 # Examples
-examples = ['What is the capital of China?',
-            'Why is the sky blue?',
-            'Who won the mens world cup in 2014?', ]
+examples = [
+    'What is BERT?',
+    'Tell me about BERT deep learning model',
+    'What is the capital of China?',
+    'Why is the sky blue?',
+    'Who won the mens world cup in 2014?',
+]
 
 
 def add_text(history, text):
     history = [] if history is None else history
-    history = history + [(text, None)]
+    history = history + [(text, "")]
     return history, gr.Textbox(value="", interactive=False)
 
 
 def bot(history, api_kind):
-    top_k_rank = 4
+    top_k_rank = 5
+    thresh_dist = 1.2
+    history[-1][1] = ""
     query = history[-1][0]
 
     if not query:
@@ -53,71 +60,78 @@ def bot(history, api_kind):
 
     query_vec = embedder.encode(query)
     documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank).to_list()
+    thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
+    documents = [d for d in documents if d['_distance'] <= thresh_dist]
    documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
 
     document_time = perf_counter() - document_start
     logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
 
-    # Create Prompt
-    prompt = prompt_template.render(documents=documents, query=query)
-    prompt_html = template_html.render(documents=documents, query=query)
-
-    if api_kind == "HuggingFace":
-        generate_fn = generate_hf
-    elif api_kind == "OpenAI":
-        generate_fn = generate_openai
-    elif api_kind is None:
-        gr.Warning("API name was not provided")
-        raise ValueError("API name was not provided")
+    while len(documents) != 0:
+        context = context_template.render(documents=documents)
+        context_html = context_html_template.render(documents=documents)
+        messages = construct_openai_messages(context, history)
+        num_tokens = num_tokens_from_messages(messages, OPENAI_LLM_NAME)
+        if num_tokens + 512 < context_lengths[OPENAI_LLM_NAME]:
+            break
+        documents.pop()
     else:
-        gr.Warning(f"API {api_kind} is not supported")
-        raise ValueError(f"API {api_kind} is not supported")
+        raise gr.Error('Model context length exceeded, reload the page')
 
-    history[-1][1] = ""
-    for character in generate_fn(prompt, history[:-1]):
-        history[-1][1] = character
-        yield history, prompt_html
+    for part in generate_openai(messages):
+        history[-1][1] += part
+        yield history, context_html
+    else:
+        print('Finished generation stream.')
 
 
 with gr.Blocks() as demo:
-    chatbot = gr.Chatbot(
-        [],
-        elem_id="chatbot",
-        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
-                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
-        bubble_full_width=False,
-        show_copy_button=True,
-        show_share_button=True,
-    )
-
     with gr.Row():
-        txt = gr.Textbox(
-            scale=3,
-            show_label=False,
-            placeholder="Enter text and press enter",
-            container=False,
-        )
-        txt_btn = gr.Button(value="Submit text", scale=1)
-
-    api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
-
-    prompt_html = gr.HTML()
+        with gr.Column():
+            chatbot = gr.Chatbot(
+                [],
+                elem_id="chatbot",
+                avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
+                               'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
+                bubble_full_width=False,
+                show_copy_button=True,
+                show_share_button=True,
+                height=600,
+            )
+
+            with gr.Row():
+                input_textbox = gr.Textbox(
+                    scale=3,
+                    show_label=False,
+                    placeholder="Enter text and press enter",
+                    container=False,
+                )
+                txt_btn = gr.Button(value="Submit text", scale=1)
+
+            api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="OpenAI", label='Backend')
+
+            # Examples
+            gr.Examples(examples, input_textbox)
+
+        with gr.Column():
+            context_html = gr.HTML()
 
     # Turn off interactivity while generating if you click
-    txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-        bot, [chatbot, api_kind], [chatbot, prompt_html])
+    txt_msg = txt_btn.click(
+        add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False
+    ).then(
+        bot, [chatbot, api_kind], [chatbot, context_html]
+    )
 
     # Turn it back on
-    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
+    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
 
     # Turn off interactivity while generating if you hit enter
-    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-        bot, [chatbot, api_kind], [chatbot, prompt_html])
+    txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
+        bot, [chatbot, api_kind], [chatbot, context_html])
 
     # Turn it back on
-    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
-
-    # Examples
-    gr.Examples(examples, txt)
+    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
 
 demo.queue()
 demo.launch(debug=True)
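The heart of this commit is the token-budget loop in `bot()`: retrieved documents are popped off the end of the ranked list until the rendered messages, plus a 512-token reserve for the completion (matching `max_tokens=512` in `generate_openai`), fit the model's context window; the `while ... else` clause fires only if every document gets dropped. A minimal standalone sketch of the same pattern — `trim_to_budget` and the toy token counter are illustrative names, not part of the app:

```python
# Sketch of the document-trimming loop from bot(), with a caller-supplied
# count_tokens() standing in for template rendering + num_tokens_from_messages().

def trim_to_budget(documents, count_tokens, context_length, reserve=512):
    """Drop lowest-ranked documents until the prompt fits the context window."""
    documents = list(documents)
    while len(documents) != 0:
        if count_tokens(documents) + reserve < context_length:
            break  # the remaining documents fit
        documents.pop()  # drop the last (lowest-ranked) document
    else:
        # no break: even with all documents removed, the prompt is too long
        raise RuntimeError('Model context length exceeded')
    return documents


# Toy counter: ~1 token per 4 characters plus a fixed chat overhead.
docs = ["a" * 8000, "b" * 4000, "c" * 2000]
kept = trim_to_budget(docs, lambda ds: 100 + sum(len(d) // 4 for d in ds), 4096)
print(len(kept))  # 2 -- the last document is dropped, then the prompt fits
```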
gradio_app/backend/ChatGptInteractor.py ADDED
@@ -0,0 +1,170 @@
+import time
+
+import tiktoken
+import openai
+
+
+with open('data/openaikey.txt') as f:
+    OPENAI_KEY = f.read().strip()
+openai.api_key = OPENAI_KEY
+
+
+def num_tokens_from_messages(messages, model):
+    """
+    Return the number of tokens used by a list of messages.
+    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    """
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        print("Warning: model not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model in {
+        "gpt-3.5-turbo-0613",
+        "gpt-3.5-turbo-16k-0613",
+        "gpt-4-0314",
+        "gpt-4-32k-0314",
+        "gpt-4-0613",
+        "gpt-4-32k-0613",
+    }:
+        tokens_per_message = 3
+        tokens_per_name = 1
+    elif model == "gpt-3.5-turbo-0301":
+        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        tokens_per_name = -1  # if there's a name, the role is omitted
+    elif "gpt-3.5-turbo" in model:
+        # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
+        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
+    elif "gpt-4" in model:
+        # print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+        return num_tokens_from_messages(messages, model="gpt-4-0613")
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
+        )
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for key, value in message.items():
+            num_tokens += len(encoding.encode(value, disallowed_special=()))
+            if key == "name":
+                num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return num_tokens
+
+
+class ChatGptInteractor:
+    def __init__(self, model_name='gpt-3.5-turbo'):
+        self.model_name = model_name
+        self.tokenizer = tiktoken.encoding_for_model(self.model_name)
+
+    def chat_completion_simple(
+            self,
+            *,
+            user_text,
+            system_text=None,
+            max_tokens=None,
+            temperature=None,
+            stream=False,
+    ):
+        return self.chat_completion(
+            self._construct_messages_simple(user_text, system_text),
+            max_tokens=max_tokens,
+            temperature=temperature,
+            stream=stream,
+        )
+
+    def count_tokens_simple(self, *, user_text, system_text=None):
+        return self.count_tokens(self._construct_messages_simple(user_text, system_text))
+
+    @staticmethod
+    def _construct_messages_simple(user_text, system_text=None):
+        messages = []
+        if system_text is not None:
+            messages.append({
+                "role": "system",
+                "content": system_text
+            })
+        messages.append({
+            "role": "user",
+            "content": user_text
+        })
+        return messages
+
+    def chat_completion(
+            self,
+            messages,
+            max_tokens=None,
+            temperature=None,
+            stream=False,
+    ):
+        print(f'Sending request to {self.model_name} stream={stream} ...')
+        t1 = time.time()
+        completion = self._request(
+            model=self.model_name,
+            messages=messages,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            stream=stream,
+        )
+        if stream:
+            return completion
+        t2 = time.time()
+        usage = completion['usage']
+        print(
+            f'Received response: {usage["prompt_tokens"]} in + {usage["completion_tokens"]} out'
+            f' = {usage["total_tokens"]} total tokens. Time: {t2 - t1:3.1f} seconds'
+        )
+        return completion.choices[0].message['content']
+
+    @staticmethod
+    def get_stream_text(stream_part):
+        return stream_part['choices'][0]['delta'].get('content', '')
+
+    def count_tokens(self, messages):
+        return num_tokens_from_messages(messages, self.model_name)
+
+    def _request(self, *args, **kwargs):
+        for _ in range(5):
+            try:
+                completion = openai.ChatCompletion.create(
+                    *args, **kwargs,
+                    request_timeout=100.0,
+                )
+                return completion
+            except (openai.error.Timeout, openai.error.ServiceUnavailableError):
+                continue
+        raise RuntimeError('Failed to connect to OpenAI (timeout error)')
+
+
+if __name__ == '__main__':
+    cgi = ChatGptInteractor()
+
+    for txt in [
+        "Hello World!",
+        "Hello",
+        " World!",
+        "World!",
+        "World",
+        "!",
+        " ",
+        "  ",
+        "   ",
+        "    ",
+        "\n",
+        "\n\t",
+    ]:
+        print(f'`{txt}` | {cgi.tokenizer.encode(txt)}')
+
+    st = 'You are a helpful assistant and an experienced programmer, ' \
+         'answering questions exactly in two rhymed sentences'
+    ut = 'Explain the principle of recursion in programming'
+    print('Count tokens:', cgi.count_tokens_simple(user_text=ut, system_text=st))
+
+    print(cgi.chat_completion_simple(user_text=ut, system_text=st))
+    print('---')
+
+    for part in cgi.chat_completion_simple(user_text=ut, system_text=st, stream=True):
+        print(cgi.get_stream_text(part), end='')
+    print('\n---')
+
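A minimal usage sketch for the new class (it assumes the pre-1.0 `openai` package and the key file `data/openaikey.txt`, exactly as the module itself does):

```python
from gradio_app.backend.ChatGptInteractor import ChatGptInteractor, num_tokens_from_messages

cgi = ChatGptInteractor(model_name='gpt-3.5-turbo')
messages = [{"role": "user", "content": "What is BERT?"}]

# Count prompt tokens the same way bot() does before sending anything.
print('prompt tokens:', num_tokens_from_messages(messages, 'gpt-3.5-turbo'))

# Non-streaming call returns the completed message content as a string.
print(cgi.chat_completion(messages, max_tokens=64))

# Streaming call returns the raw stream; get_stream_text() pulls the
# text delta out of each chunk.
for part in cgi.chat_completion(messages, max_tokens=64, stream=True):
    print(cgi.get_stream_text(part), end='')
```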
gradio_app/backend/query_llm.py CHANGED
@@ -1,22 +1,20 @@
-import openai
 import gradio as gr
 
-from os import getenv
 from typing import Any, Dict, Generator, List
 
 from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer
+from jinja2 import Environment, FileSystemLoader
 
 from settings import *
+from gradio_app.backend.ChatGptInteractor import *
 
 
-tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
+tokenizer = AutoTokenizer.from_pretrained(HF_LLM_NAME)
 
-OPENAI_KEY = getenv("OPENAI_API_KEY")
-HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
-
+HF_TOKEN = None
 
-hf_client = InferenceClient(LLM_NAME, token=HF_TOKEN)
+hf_client = InferenceClient(HF_LLM_NAME, token=HF_TOKEN)
 
 
 def format_prompt(message: str, api_kind: str):
@@ -42,7 +40,7 @@ def format_prompt(message: str, api_kind: str):
 
 
 def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
-                top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
+                top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -69,13 +67,13 @@ def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tok
         'repetition_penalty': repetition_penalty,
         'do_sample': True,
         'seed': 42,
-    }
-
+    }
+
     formatted_prompt = format_prompt(prompt, "hf")
 
     try:
         stream = hf_client.text_generation(formatted_prompt, **generate_kwargs,
-                                           stream=True, details=True, return_full_text=False)
+                                           stream=True, details=True, return_full_text=False)
         output = ""
         for response in stream:
             output += response.token.text
@@ -96,8 +94,44 @@ def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tok
         return "I do not know what happened, but I couldn't understand you."
 
 
-def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
-                    top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
+env = Environment(loader=FileSystemLoader('gradio_app/templates'))
+context_template = env.get_template('context_template.j2')
+start_system_message = context_template.render(documents=[])
+
+
+def construct_openai_messages(context, history):
+    messages = [
+        {
+            "role": "system",
+            "content": start_system_message,
+        },
+    ]
+    for q, a in history:
+        if len(a) == 0:  # the last message
+            messages.append({
+                "role": "system",
+                "content": context,
+            })
+        messages.append({
+            "role": "user",
+            "content": q,
+        })
+        if len(a) != 0:  # some of the previous LLM answers
+            messages.append({
+                "role": "assistant",
+                "content": a,
+            })
+    return messages
+
+
+def generate_openai(messages):
+    cgi = ChatGptInteractor(model_name=OPENAI_LLM_NAME)
+    for part in cgi.chat_completion(messages, max_tokens=512, temperature=0, stream=True):
+        yield cgi.get_stream_text(part)
+
+
+def _generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
+                     top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -116,21 +150,23 @@ def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new
 
     temperature = max(float(temperature), 1e-2)  # Ensure temperature isn't too low
     top_p = float(top_p)
-
+
     generate_kwargs = {
         'temperature': temperature,
         'max_tokens': max_new_tokens,
         'top_p': top_p,
         'frequency_penalty': max(-2., min(repetition_penalty, 2.)),
-    }
+    }
 
     formatted_prompt = format_prompt(prompt, "openai")
 
     try:
-        stream = openai.ChatCompletion.create(model="gpt-3.5-turbo-0301",
-                                              messages=formatted_prompt,
-                                              **generate_kwargs,
-                                              stream=True)
+        stream = openai.ChatCompletion.create(
+            model=OPENAI_LLM_NAME,
+            messages=formatted_prompt,
+            **generate_kwargs,
+            stream=True
+        )
         output = ""
         for chunk in stream:
             output += chunk.choices[0].delta.get("content", "")
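The shape of `construct_openai_messages` output is worth noting: finished turns pass through as plain user/assistant pairs, while the retrieval context is injected as an extra system message immediately before the still-unanswered final user turn. A sketch for a two-turn history (the answer strings are placeholders):

```python
history = [
    ("What is BERT?", "BERT is a bidirectional transformer encoder ..."),  # finished turn
    ("Who created it?", ""),  # current turn: empty answer marks it as "the last message"
]

# construct_openai_messages(context, history) then returns:
# [
#     {"role": "system",    "content": start_system_message},
#     {"role": "user",      "content": "What is BERT?"},
#     {"role": "assistant", "content": "BERT is a bidirectional transformer encoder ..."},
#     {"role": "system",    "content": context},  # retrieved documents, injected late
#     {"role": "user",      "content": "Who created it?"},
# ]
```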
gradio_app/templates/context_html_template.j2 ADDED
@@ -0,0 +1,95 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Information Page</title>
+    <link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap">
+    <link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&display=swap">
+    <style>
+        * {
+            font-family: "Source Sans Pro";
+        }
+
+        .instructions > * {
+            color: #111 !important;
+        }
+
+        details.doc-box * {
+            color: #111 !important;
+        }
+
+        .dark {
+            background: #111;
+            color: white;
+        }
+
+        .doc-box {
+            padding: 10px;
+            margin-top: 10px;
+            background-color: #baecc2;
+            border-radius: 6px;
+            color: #111 !important;
+            max-width: 700px;
+            box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
+        }
+
+        .doc-full {
+            margin: 10px 14px;
+            line-height: 1.6rem;
+        }
+
+        .instructions {
+            color: #111 !important;
+            background: #b7bdfd;
+            display: block;
+            border-radius: 6px;
+            padding: 6px 10px;
+            line-height: 1.6rem;
+            max-width: 700px;
+            box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
+        }
+
+        .query {
+            color: #111 !important;
+            background: #ffbcbc;
+            display: block;
+            border-radius: 6px;
+            padding: 6px 10px;
+            line-height: 1.6rem;
+            max-width: 700px;
+            box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
+        }
+    </style>
+</head>
+<body>
+<div class="prose svelte-1ybaih5" id="context_html">
+    <h2>Context:</h2>
+    {% for doc in documents %}
+    <details class="doc-box">
+        <summary>
+            <b>Doc {{ loop.index }}:</b> <span class="doc-short">{{ doc[:1000] }}...</span>
+        </summary>
+        <div class="doc-full">{{ doc }}</div>
+    </details>
+    {% endfor %}
+</div>
+
+<script>
+    document.addEventListener("DOMContentLoaded", function() {
+        const detailsElements = document.querySelectorAll('.doc-box');
+
+        detailsElements.forEach(detail => {
+            detail.addEventListener('toggle', function() {
+                const docShort = this.querySelector('.doc-short');
+                if (this.open) {
+                    docShort.style.display = 'none';
+                } else {
+                    docShort.style.display = 'inline';
+                }
+            });
+        });
+    });
+</script>
+</body>
+</html>
gradio_app/templates/context_template.j2 ADDED
@@ -0,0 +1,20 @@
+You are a helpful assistant.
+
+You answer questions based only on the provided information.
+
+If there is no relevant information in the context, just say "No relevant information".
+
+You must not make up an answer! Use only provided context!
+
+In each answer, you must provide a precise citation from the given context in double quotes.
+
+Citation is mandatory in the answer!
+
+Context:
+
+{% for doc in documents %}
+---
+
+{{ doc }}
+
+{% endfor %}
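Note that `query_llm.py` renders this same template with `documents=[]` to produce `start_system_message`: the `{% for %}` body never runs, so the result is just the instruction block ending at the bare `Context:` header. A quick sanity check (sketch):

```python
from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader('gradio_app/templates'))
context_template = env.get_template('context_template.j2')

# With no documents, only the instructions above the loop are emitted;
# this string seeds every conversation as the first system message.
print(context_template.render(documents=[]))
```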
settings.py CHANGED
@@ -5,4 +5,10 @@ LANCEDB_DIRECTORY = "data/lancedb"
 LANCEDB_TABLE_NAME = "table"
 VECTOR_COLUMN_NAME = "embedding"
 TEXT_COLUMN_NAME = "text"
-LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
+HF_LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
+OPENAI_LLM_NAME = "gpt-3.5-turbo"
+
+context_lengths = {
+    "mistralai/Mistral-7B-Instruct-v0.1": 4096,
+    "gpt-3.5-turbo": 4096,
+}
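`context_lengths` stores the full window (prompt plus completion) for each model; `bot()` reserves 512 tokens of it for the reply, mirroring `max_tokens=512` in `generate_openai`. The check it backs, with illustrative numbers:

```python
from settings import OPENAI_LLM_NAME, context_lengths

num_tokens = 3500  # illustrative prompt size from num_tokens_from_messages()
fits = num_tokens + 512 < context_lengths[OPENAI_LLM_NAME]
print(fits)  # True: 4012 < 4096 leaves room for a 512-token completion
```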