AlexanderKazakov committed
Commit 34b78ab
1 Parent(s): eeafaaa

add cross-encoder and HF API LLM

gradio_app/app.py CHANGED
@@ -13,7 +13,8 @@ import markdown
 from jinja2 import Environment, FileSystemLoader
 
 from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
-from gradio_app.backend.query_llm import generate_hf, generate_openai, construct_openai_messages
+from gradio_app.backend.cross_encoder import rerank_with_cross_encoder
+from gradio_app.backend.query_llm import *
 from gradio_app.backend.semantic_search import table, embedder
 
 from settings import *
@@ -45,42 +46,52 @@ def add_text(history, text):
     return history, gr.Textbox(value="", interactive=False)
 
 
-def bot(history, api_kind):
-    top_k_rank = 5
-    thresh_dist = 1.2
+def bot(history, llm, cross_enc):
     history[-1][1] = ""
     query = history[-1][0]
 
     if not query:
-        gr.Warning("Please submit a non-empty string as a prompt")
-        raise ValueError("Empty string was submitted")
+        raise gr.Error("Empty string was submitted")
 
     logger.info('Retrieving documents...')
-    # Retrieve documents relevant to query
-    document_start = perf_counter()
+    gr.Info('Start documents retrieval ...')
+    time = perf_counter()
 
     query_vec = embedder.embed(query)[0]
-    documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank).to_list()
+    documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
+    documents = documents.limit(TOP_K_RANK).to_list()
+    thresh_dist = thresh_distances[EMBED_NAME]
     thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
     documents = [d for d in documents if d['_distance'] <= thresh_dist]
     documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
 
-    document_time = perf_counter() - document_start
-    logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
+    time = perf_counter() - time
+    logger.info(f'Finished Retrieving documents in {round(time, 2)} seconds...')
 
+    logger.info('Reranking documents...')
+    gr.Info('Start documents reranking ...')
+    time = perf_counter()
+
+    documents = rerank_with_cross_encoder(cross_enc, documents, query)
+
+    time = perf_counter() - time
+    logger.info(f'Finished Reranking documents in {round(time, 2)} seconds...')
+
+    msg_constructor = get_message_constructor(llm)
     while len(documents) != 0:
        context = context_template.render(documents=documents)
        documents_html = [markdown.markdown(d) for d in documents]
        context_html = context_html_template.render(documents=documents_html)
-        messages = construct_openai_messages(context, history)
-        num_tokens = num_tokens_from_messages(messages, LLM_NAME)
-        if num_tokens + 512 < context_lengths[LLM_NAME]:
+        messages = msg_constructor(context, history)
+        num_tokens = num_tokens_from_messages(messages, 'gpt-3.5-turbo')  # todo: for HF models this is an approximation
+        if num_tokens + 512 < context_lengths[llm]:
            break
        documents.pop()
    else:
        raise gr.Error('Model context length exceeded, reload the page')
 
-    for part in generate_openai(messages):
+    llm_gen = get_llm_generator(llm)
+    for part in llm_gen(messages):
        history[-1][1] += part
        yield history, context_html
    else:
@@ -110,7 +121,25 @@ with gr.Blocks() as demo:
     )
     txt_btn = gr.Button(value="Submit text", scale=1)
 
-    api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="OpenAI", label='Backend')
+    llm_name = gr.Radio(
+        choices=[
+            "gpt-3.5-turbo",
+            "mistralai/Mistral-7B-Instruct-v0.1",
+            "GeneZC/MiniChat-3B",
+        ],
+        value="gpt-3.5-turbo",
+        label='LLM'
+    )
+
+    cross_enc_name = gr.Radio(
+        choices=[
+            None,
+            "cross-encoder/ms-marco-TinyBERT-L-2-v2",
+            "cross-encoder/ms-marco-MiniLM-L-12-v2",
+        ],
+        value=None,
+        label='Cross-Encoder'
+    )
 
     # Examples
     gr.Examples(examples, input_textbox)
@@ -122,7 +151,7 @@ with gr.Blocks() as demo:
     txt_msg = txt_btn.click(
         add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False
     ).then(
-        bot, [chatbot, api_kind], [chatbot, context_html]
+        bot, [chatbot, llm_name, cross_enc_name], [chatbot, context_html]
     )
 
     # Turn it back on
@@ -130,7 +159,7 @@ with gr.Blocks() as demo:
 
     # Turn off interactivity while generating if you hit enter
     txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
-        bot, [chatbot, api_kind], [chatbot, context_html])
+        bot, [chatbot, llm_name, cross_enc_name], [chatbot, context_html])
 
     # Turn it back on
     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
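
The new bot() pipeline reads: embed the query, fetch TOP_K_RANK candidates from the LanceDB table, prune them by the embedder-specific distance threshold, rerank with the selected cross-encoder, then drop the lowest-ranked documents until the prompt fits the model's context window. The trimming relies on Python's while/else, where the else branch fires only if the loop ends without break. A standalone sketch of that budget loop with an illustrative whitespace token counter; fit_documents and its parameter names are hypothetical, not from this repo:

def fit_documents(documents, count_tokens, context_length, reserved=512):
    # documents are ranked best-first: keep dropping the tail until the
    # rendered prompt plus the generation reserve fits the context window
    while len(documents) != 0:
        prompt = '\n\n'.join(documents)
        if count_tokens(prompt) + reserved < context_length:
            break
        documents.pop()
    else:
        # the loop exhausted all documents without ever fitting
        raise RuntimeError('Model context length exceeded')
    return documents


docs = fit_documents(['a ' * 100, 'b ' * 100], lambda s: len(s.split()), context_length=700)
assert docs == ['a ' * 100]  # the lower-ranked document was dropped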
gradio_app/backend/ChatGptInteractor.py CHANGED
@@ -1,3 +1,4 @@
+import logging
 import time
 
 import tiktoken
@@ -9,6 +10,10 @@ with open('data/openaikey.txt') as f:
 openai.api_key = OPENAI_KEY
 
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
 def num_tokens_from_messages(messages, model):
     """
     Return the number of tokens used by a list of messages.
@@ -17,7 +22,7 @@ def num_tokens_from_messages(messages, model):
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:
-        print("Warning: model not found. Using cl100k_base encoding.")
+        logger.info("Warning: model not found. Using cl100k_base encoding.")
         encoding = tiktoken.get_encoding("cl100k_base")
     if model in {
         "gpt-3.5-turbo-0613",
@@ -33,10 +38,10 @@ def num_tokens_from_messages(messages, model):
         tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
         tokens_per_name = -1  # if there's a name, the role is omitted
     elif "gpt-3.5-turbo" in model:
-        # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
+        # logger.info("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
         return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
     elif "gpt-4" in model:
-        # print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+        # logger.info("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
         return num_tokens_from_messages(messages, model="gpt-4-0613")
     else:
         raise NotImplementedError(
@@ -54,8 +59,11 @@ def num_tokens_from_messages(messages, model):
 
 
 class ChatGptInteractor:
-    def __init__(self, model_name='gpt-3.5-turbo'):
+    def __init__(self, model_name='gpt-3.5-turbo', max_tokens=None, temperature=None, stream=False):
         self.model_name = model_name
+        self.max_tokens = max_tokens
+        self.temperature = temperature
+        self.stream = stream
         self.tokenizer = tiktoken.encoding_for_model(self.model_name)
 
     def chat_completion_simple(
@@ -63,15 +71,9 @@ class ChatGptInteractor:
             *,
             user_text,
             system_text=None,
-            max_tokens=None,
-            temperature=None,
-            stream=False,
     ):
         return self.chat_completion(
             self._construct_messages_simple(user_text, system_text),
-            max_tokens=max_tokens,
-            temperature=temperature,
-            stream=stream,
         )
 
     def count_tokens_simple(self, *, user_text, system_text=None):
@@ -91,27 +93,17 @@ class ChatGptInteractor:
         })
         return messages
 
-    def chat_completion(
-            self,
-            messages,
-            max_tokens=None,
-            temperature=None,
-            stream=False,
-    ):
-        print(f'Sending request to {self.model_name} stream={stream} ...')
+    def chat_completion(self, messages):
+        logger.info(f'Sending request to {self.model_name} stream={self.stream} ...')
         t1 = time.time()
-        completion = self._request(
-            model=self.model_name,
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            stream=stream,
-        )
-        if stream:
-            return completion
+        completion = self._request(messages)
+
+        if self.stream:
+            return self._generator(completion)
+
         t2 = time.time()
         usage = completion['usage']
-        print(
+        logger.info(
             f'Received response: {usage["prompt_tokens"]} in + {usage["completion_tokens"]} out'
             f' = {usage["total_tokens"]} total tokens. Time: {t2 - t1:3.1f} seconds'
         )
@@ -121,14 +113,23 @@
     def get_stream_text(stream_part):
         return stream_part['choices'][0]['delta'].get('content', '')
 
+    @staticmethod
+    def _generator(completion):
+        for part in completion:
+            yield ChatGptInteractor.get_stream_text(part)
+
     def count_tokens(self, messages):
         return num_tokens_from_messages(messages, self.model_name)
 
-    def _request(self, *args, **kwargs):
+    def _request(self, messages):
         for _ in range(5):
             try:
                 completion = openai.ChatCompletion.create(
-                    *args, **kwargs,
+                    messages=messages,
+                    model=self.model_name,
+                    max_tokens=self.max_tokens,
+                    temperature=self.temperature,
+                    stream=self.stream,
                     request_timeout=100.0,
                 )
                 return completion
@@ -164,7 +165,8 @@ if __name__ == '__main__':
     print(cgi.chat_completion_simple(user_text=ut, system_text=st))
     print('---')
 
-    for part in cgi.chat_completion_simple(user_text=ut, system_text=st, stream=True):
-        print(cgi.get_stream_text(part), end='')
+    cgi = ChatGptInteractor(stream=True)
+    for part in cgi.chat_completion_simple(user_text=ut, system_text=st):
+        print(part, end='')
     print('\n---')

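With this refactor, sampling options are fixed when the interactor is constructed, and chat_completion() returns either the full completion object or, when stream=True, a generator of plain text chunks (via the new _generator helper). A minimal usage sketch, assuming a valid key in data/openaikey.txt and the pre-1.0 openai client this module targets:

from gradio_app.backend.ChatGptInteractor import ChatGptInteractor

# stream=True is set once at construction; every call then yields text chunks
cgi = ChatGptInteractor(model_name='gpt-3.5-turbo', max_tokens=512, temperature=0, stream=True)
messages = [{'role': 'user', 'content': 'Say hello in one word.'}]
for chunk in cgi.chat_completion(messages):
    print(chunk, end='')
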
gradio_app/backend/HuggingfaceGenerator.py ADDED
@@ -0,0 +1,44 @@
+import logging
+
+from huggingface_hub import InferenceClient
+from transformers import AutoTokenizer
+
+with open('data/hftoken.txt') as f:
+    HF_TOKEN = f.read().strip()
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+# noinspection PyTypeChecker
+class HuggingfaceGenerator:
+    def __init__(
+            self, model_name,
+            temperature: float = 0.9, max_new_tokens: int = 512,
+            top_p: float = None, repetition_penalty: float = None,
+            stream: bool = True,
+    ):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.hf_client = InferenceClient(model_name, token=HF_TOKEN)
+        self.stream = stream
+
+        self.generate_kwargs = {
+            'temperature': max(temperature, 0.1),
+            'max_new_tokens': max_new_tokens,
+            'top_p': top_p,
+            'repetition_penalty': repetition_penalty,
+            'do_sample': True,
+            'seed': 42,
+        }
+
+    def generate(self, messages):
+        formatted_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
+
+        logger.info(f'Start HuggingFace generation, model {self.hf_client.model} ...')
+        stream = self.hf_client.text_generation(
+            formatted_prompt, **self.generate_kwargs,
+            stream=self.stream, details=True, return_full_text=not self.stream
+        )
+
+        for response in stream:
+            yield response.token.text
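
A usage sketch for the new generator, assuming a valid API token in data/hftoken.txt and a model whose tokenizer ships a chat template (both of the models wired up in app.py do):

from gradio_app.backend.HuggingfaceGenerator import HuggingfaceGenerator

hfg = HuggingfaceGenerator(
    model_name='mistralai/Mistral-7B-Instruct-v0.1',
    temperature=0,  # clamped to 0.1 internally, since do_sample=True requires t > 0
    max_new_tokens=64,
)
messages = [{'role': 'user', 'content': 'What is a cross-encoder?'}]
for token_text in hfg.generate(messages):
    print(token_text, end='')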
gradio_app/backend/cross_encoder.py ADDED
@@ -0,0 +1,32 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+from settings import *
+
+
+cross_encoder = None
+cross_enc_tokenizer = None
+
+
+@torch.no_grad()
+def rerank_with_cross_encoder(cross_enc_name, documents, query):
+    if cross_enc_name is None or len(documents) <= 1:
+        return documents
+
+    global cross_encoder, cross_enc_tokenizer
+    if cross_encoder is None or cross_encoder.name_or_path != cross_enc_name:
+        cross_encoder = AutoModelForSequenceClassification.from_pretrained(cross_enc_name)
+        cross_encoder.eval()
+        cross_enc_tokenizer = AutoTokenizer.from_pretrained(cross_enc_name)
+
+    features = cross_enc_tokenizer(
+        [query] * len(documents), documents, padding=True, truncation=True, return_tensors="pt"
+    )
+    scores = cross_encoder(**features).logits.squeeze()
+    ranks = torch.argsort(scores, descending=True)
+    documents = [documents[i] for i in ranks[:TOP_K_RERANK]]
+    return documents
+
+
+
+
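
A usage sketch for the reranker. The model and tokenizer are lazily loaded into module-level globals on first use and swapped out when a different cross-encoder name arrives; passing None bypasses reranking entirely. TOP_K_RERANK comes from settings.py; the document strings below are illustrative:

from gradio_app.backend.cross_encoder import rerank_with_cross_encoder

query = "How does a cross-encoder rank documents?"
docs = [
    "Bananas are rich in potassium.",
    "Cross-encoders score a (query, document) pair jointly with full attention.",
    "LanceDB is an embedded vector database.",
]

assert rerank_with_cross_encoder(None, docs, query) == docs  # no-op without a model

reranked = rerank_with_cross_encoder("cross-encoder/ms-marco-TinyBERT-L-2-v2", docs, query)
print(reranked[0])  # expected: the sentence about cross-encoders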
gradio_app/backend/query_llm.py CHANGED
@@ -1,102 +1,30 @@
-import gradio as gr
-
-from typing import Any, Dict, Generator, List
-
-# from huggingface_hub import InferenceClient
-# from transformers import AutoTokenizer
 from jinja2 import Environment, FileSystemLoader
 
-from settings import *
 from gradio_app.backend.ChatGptInteractor import *
+from gradio_app.backend.HuggingfaceGenerator import HuggingfaceGenerator
 
-
-# tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
-# HF_TOKEN = None
-# hf_client = InferenceClient(LLM_NAME, token=HF_TOKEN)
-
-
-def format_prompt(message: str, api_kind: str):
-    """
-    Formats the given message using a chat template.
-
-    Args:
-        message (str): The user message to be formatted.
-
-    Returns:
-        str: Formatted message after applying the chat template.
-    """
-
-    # Create a list of message dictionaries with role and content
-    messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
-
-    if api_kind == "openai":
-        return messages
-    elif api_kind == "hf":
-        return tokenizer.apply_chat_template(messages, tokenize=False)
-    else:
-        raise ValueError("API is not supported")
-
-
-def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
-                top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
-    """
-    Generate a sequence of tokens based on a given prompt and history using Mistral client.
-
-    Args:
-        prompt (str): The initial prompt for the text generation.
-        history (str): Context or history for the text generation.
-        temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
-        max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
-        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
-        repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
-
-    Returns:
-        Generator[str, None, str]: A generator yielding chunks of generated text.
-        Returns a final string if an error occurs.
-    """
-
-    temperature = max(float(temperature), 1e-2)  # Ensure temperature isn't too low
-    top_p = float(top_p)
-
-    generate_kwargs = {
-        'temperature': temperature,
-        'max_new_tokens': max_new_tokens,
-        'top_p': top_p,
-        'repetition_penalty': repetition_penalty,
-        'do_sample': True,
-        'seed': 42,
-    }
-
-    formatted_prompt = format_prompt(prompt, "hf")
-
-    try:
-        stream = hf_client.text_generation(formatted_prompt, **generate_kwargs,
-                                           stream=True, details=True, return_full_text=False)
-        output = ""
-        for response in stream:
-            output += response.token.text
-            yield output
-
-    except Exception as e:
-        if "Too Many Requests" in str(e):
-            print("ERROR: Too many requests on Mistral client")
-            gr.Warning("Unfortunately Mistral is unable to process")
-            return "Unfortunately, I am not able to process your request now."
-        elif "Authorization header is invalid" in str(e):
-            print("Authetification error:", str(e))
-            gr.Warning("Authentication error: HF token was either not provided or incorrect")
-            return "Authentication error"
-        else:
-            print("Unhandled Exception:", str(e))
-            gr.Warning("Unfortunately Mistral is unable to process")
-            return "I do not know what happened, but I couldn't understand you."
-
 
 env = Environment(loader=FileSystemLoader('gradio_app/templates'))
 context_template = env.get_template('context_template.j2')
 start_system_message = context_template.render(documents=[])
 
 
+def construct_mistral_messages(context, history):
+    messages = []
+    for q, a in history:
+        if len(a) == 0:  # the last message
+            q = context + f'\n\nQuery:\n\n{q}'
+        messages.append({
+            "role": "user",
+            "content": q,
+        })
+        if len(a) != 0:  # some of the previous LLM answers
+            messages.append({
+                "role": "assistant",
+                "content": a,
+            })
+    return messages
+
+
 def construct_openai_messages(context, history):
     messages = [
         {
@@ -122,64 +50,32 @@ def construct_openai_messages(context, history):
     return messages
 
 
-def generate_openai(messages):
-    cgi = ChatGptInteractor(model_name=LLM_NAME)
-    for part in cgi.chat_completion(messages, max_tokens=512, temperature=0, stream=True):
-        yield cgi.get_stream_text(part)
-
-
-def _generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
-                     top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
-    """
-    Generate a sequence of tokens based on a given prompt and history using Mistral client.
-
-    Args:
-        prompt (str): The initial prompt for the text generation.
-        history (str): Context or history for the text generation.
-        temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
-        max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
-        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
-        repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
-
-    Returns:
-        Generator[str, None, str]: A generator yielding chunks of generated text.
-        Returns a final string if an error occurs.
-    """
-
-    temperature = max(float(temperature), 1e-2)  # Ensure temperature isn't too low
-    top_p = float(top_p)
-
-    generate_kwargs = {
-        'temperature': temperature,
-        'max_tokens': max_new_tokens,
-        'top_p': top_p,
-        'frequency_penalty': max(-2., min(repetition_penalty, 2.)),
-    }
-
-    formatted_prompt = format_prompt(prompt, "openai")
-
-    try:
-        stream = openai.ChatCompletion.create(
-            model=LLM_NAME,
-            messages=formatted_prompt,
-            **generate_kwargs,
-            stream=True
-        )
-        output = ""
-        for chunk in stream:
-            output += chunk.choices[0].delta.get("content", "")
-            yield output
-
-    except Exception as e:
-        if "Too Many Requests" in str(e):
-            print("ERROR: Too many requests on OpenAI client")
-            gr.Warning("Unfortunately OpenAI is unable to process")
-            return "Unfortunately, I am not able to process your request now."
-        elif "You didn't provide an API key" in str(e):
-            print("Authetification error:", str(e))
-            gr.Warning("Authentication error: OpenAI key was either not provided or incorrect")
-            return "Authentication error"
-        else:
-            print("Unhandled Exception:", str(e))
-            gr.Warning("Unfortunately OpenAI is unable to process")
-            return "I do not know what happened, but I couldn't understand you."
+def get_message_constructor(llm_name):
+    if llm_name == 'gpt-3.5-turbo':
+        return construct_openai_messages
+    if llm_name in ['mistralai/Mistral-7B-Instruct-v0.1', "GeneZC/MiniChat-3B"]:
+        return construct_mistral_messages
+    raise ValueError('Unknown LLM name')
+
+
+def get_llm_generator(llm_name):
+    if llm_name == 'gpt-3.5-turbo':
+        cgi = ChatGptInteractor(
+            model_name=llm_name, max_tokens=512, temperature=0, stream=True
+        )
+        return cgi.chat_completion
+    if llm_name == 'mistralai/Mistral-7B-Instruct-v0.1':
+        hfg = HuggingfaceGenerator(
+            model_name=llm_name, temperature=0, max_new_tokens=512,
+        )
+        return hfg.generate
+
+    if llm_name == "GeneZC/MiniChat-3B":
+        hfg = HuggingfaceGenerator(
+            model_name=llm_name, temperature=0, max_new_tokens=250, stream=False,
+        )
+        return hfg.generate
+    raise ValueError('Unknown LLM name')

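The two factories give app.py a single dispatch point per model: one picks the message format (OpenAI chat dicts vs. the alternating user/assistant list that Mistral-style chat templates expect), the other a pre-configured generator. A sketch of how they compose, mirroring what bot() does and assuming the HF token file is in place; the history and context values are illustrative:

from gradio_app.backend.query_llm import get_message_constructor, get_llm_generator

llm = 'mistralai/Mistral-7B-Instruct-v0.1'
history = [['What is LanceDB?', '']]  # gradio-style [user, assistant] pairs; '' marks the pending answer
context = 'Context:\n\nLanceDB is an embedded vector database.'

messages = get_message_constructor(llm)(context, history)  # -> construct_mistral_messages
llm_gen = get_llm_generator(llm)                           # -> HuggingfaceGenerator(...).generate
for part in llm_gen(messages):
    print(part, end='')
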
settings.py CHANGED
@@ -5,11 +5,11 @@ VECTOR_COLUMN_NAME = "embedding"
 TEXT_COLUMN_NAME = "text"
 DOCUMENT_PATH_COLUMN_NAME = "document_path"
 
-# LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
-LLM_NAME = "gpt-3.5-turbo"
 # EMBED_NAME = "sentence-transformers/all-MiniLM-L6-v2"
 EMBED_NAME = "text-embedding-ada-002"
 
+TOP_K_RANK = 50
+TOP_K_RERANK = 5
 
 emb_sizes = {
     "sentence-transformers/all-MiniLM-L6-v2": 384,
@@ -17,8 +17,14 @@ emb_sizes = {
     "text-embedding-ada-002": 1536,
 }
 
+thresh_distances = {
+    "sentence-transformers/all-MiniLM-L6-v2": 1.2,
+    "text-embedding-ada-002": 0.5,
+}
+
 context_lengths = {
     "mistralai/Mistral-7B-Instruct-v0.1": 4096,
+    "GeneZC/MiniChat-3B": 4096,
     "gpt-3.5-turbo": 4096,
     "sentence-transformers/all-MiniLM-L6-v2": 128,
     "thenlper/gte-large": 512,
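
These settings pair up in bot(): TOP_K_RANK candidates come back from the vector search, the per-embedder distance threshold prunes them, and the cross-encoder keeps at most TOP_K_RERANK. A sketch of the threshold step in isolation, with made-up distances; note the max() guard, which keeps at least the single closest hit even when every distance exceeds the threshold:

from settings import thresh_distances, EMBED_NAME

hits = [  # e.g. from table.search(...).limit(TOP_K_RANK).to_list()
    {'text': 'a', '_distance': 0.31},
    {'text': 'b', '_distance': 0.48},
    {'text': 'c', '_distance': 0.93},
]

thresh = thresh_distances[EMBED_NAME]  # 0.5 for text-embedding-ada-002
thresh = max(thresh, min(d['_distance'] for d in hits))
kept = [d['text'] for d in hits if d['_distance'] <= thresh]
assert kept == ['a', 'b']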