thenativefox commited on
Commit
93c49cb
1 Parent(s): 1c860fb

fix async issues

Browse files
Files changed (4) hide show
  1. app.py +49 -25
  2. backend/query_llm.py +22 -2
  3. backend/semantic_search.py +0 -1
  4. requirements.txt +1 -1
app.py CHANGED
@@ -12,8 +12,11 @@ from jinja2 import Environment, FileSystemLoader
12
  from backend.query_llm import generate_hf, generate_openai
13
  from backend.semantic_search import retrieve
14
 
 
 
15
 
16
  TOP_K = int(os.getenv("TOP_K", 4))
 
17
 
18
  proj_dir = Path(__file__).parent
19
  # Setting up the logging
@@ -29,13 +32,17 @@ template_html = env.get_template('template_html.j2')
29
 
30
 
31
  def add_text(history, text):
 
32
  history = [] if history is None else history
33
  history = history + [(text, None)]
 
34
  return history, gr.Textbox(value="", interactive=False)
35
 
36
 
37
  def bot(history, api_kind):
 
38
  query = history[-1][0]
 
39
 
40
  if not query:
41
  raise gr.Warning("Please submit a non-empty string as a prompt")
@@ -52,13 +59,14 @@ def bot(history, api_kind):
52
  # Create Prompt
53
  prompt = template.render(documents=documents, query=query)
54
  prompt_html = template_html.render(documents=documents, query=query)
 
55
 
56
  if api_kind == "HuggingFace":
57
- generate_fn = generate_hf
58
  elif api_kind == "OpenAI":
59
- generate_fn = generate_openai
60
  else:
61
- raise gr.Error(f"API {api_kind} is not supported")
62
 
63
  history[-1][1] = ""
64
  for character in generate_fn(prompt, history[:-1]):
@@ -68,40 +76,56 @@ def bot(history, api_kind):
68
 
69
  with gr.Blocks() as demo:
70
  chatbot = gr.Chatbot(
71
- [],
72
- elem_id="chatbot",
73
- avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
74
- 'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
75
- bubble_full_width=False,
76
- show_copy_button=True,
77
- show_share_button=True,
78
- )
79
 
80
  with gr.Row():
81
  txt = gr.Textbox(
82
- scale=3,
83
- show_label=False,
84
- placeholder="Enter text and press enter",
85
- container=False,
86
- )
87
  txt_btn = gr.Button(value="Submit text", scale=1)
88
 
89
  api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
90
 
91
  prompt_html = gr.HTML()
92
- # Turn off interactivity while generating if you click
93
- txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
94
- bot, [chatbot, api_kind], [chatbot, prompt_html])
95
 
96
- # Turn it back on
 
 
 
 
 
 
 
 
 
 
 
97
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
98
 
99
  # Turn off interactivity while generating if you hit enter
100
- txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
101
- bot, [chatbot, api_kind], [chatbot, prompt_html])
102
-
103
- # Turn it back on
 
 
 
 
 
 
 
104
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
105
 
106
  demo.queue()
107
- demo.launch(debug=True)
 
 
12
  from backend.query_llm import generate_hf, generate_openai
13
  from backend.semantic_search import retrieve
14
 
15
+ from dotenv import load_dotenv
16
+ load_dotenv()
17
 
18
  TOP_K = int(os.getenv("TOP_K", 4))
19
+ HF_TOKEN = os.getenv("HF_TOKEN")
20
 
21
  proj_dir = Path(__file__).parent
22
  # Setting up the logging
 
32
 
33
 
34
  def add_text(history, text):
35
+ logger.info(f'Adding text: {text}')
36
  history = [] if history is None else history
37
  history = history + [(text, None)]
38
+ logger.info(f'Updated history: {history}')
39
  return history, gr.Textbox(value="", interactive=False)
40
 
41
 
42
  def bot(history, api_kind):
43
+ logger.info(f'Bot function called with history: {history} and api_kind: {api_kind}')
44
  query = history[-1][0]
45
+ logger.info(f'Query: {query}')
46
 
47
  if not query:
48
  raise gr.Warning("Please submit a non-empty string as a prompt")
 
59
  # Create Prompt
60
  prompt = template.render(documents=documents, query=query)
61
  prompt_html = template_html.render(documents=documents, query=query)
62
+ logger.info(f'Prompt created: {prompt}')
63
 
64
  if api_kind == "HuggingFace":
65
+ generate_fn = generate_hf
66
  elif api_kind == "OpenAI":
67
+ generate_fn = generate_openai
68
  else:
69
+ raise gr.Error(f"API {api_kind} is not supported")
70
 
71
  history[-1][1] = ""
72
  for character in generate_fn(prompt, history[:-1]):
 
76
 
77
  with gr.Blocks() as demo:
78
  chatbot = gr.Chatbot(
79
+ [],
80
+ elem_id="chatbot",
81
+ avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
82
+ 'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
83
+ bubble_full_width=False,
84
+ show_copy_button=True,
85
+ show_share_button=True,
86
+ )
87
 
88
  with gr.Row():
89
  txt = gr.Textbox(
90
+ scale=3,
91
+ show_label=False,
92
+ placeholder="Enter text and press enter",
93
+ container=False,
94
+ )
95
  txt_btn = gr.Button(value="Submit text", scale=1)
96
 
97
  api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
98
 
99
  prompt_html = gr.HTML()
 
 
 
100
 
101
+ # Turn off interactivity while generating if you click
102
+ txt_msg = txt_btn.click(
103
+ fn=add_text,
104
+ inputs=[chatbot, txt],
105
+ outputs=[chatbot, txt],
106
+ queue=False
107
+ ).then(
108
+ fn=bot,
109
+ inputs=[chatbot, api_kind],
110
+ outputs=[chatbot, prompt_html],
111
+ queue=False
112
+ )
113
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
114
 
115
  # Turn off interactivity while generating if you hit enter
116
+ txt_msg = txt.submit(
117
+ fn=add_text,
118
+ inputs=[chatbot, txt],
119
+ outputs=[chatbot, txt],
120
+ queue=False
121
+ ).then(
122
+ fn=bot,
123
+ inputs=[chatbot, api_kind],
124
+ outputs=[chatbot, prompt_html],
125
+ queue=False
126
+ )
127
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
128
 
129
  demo.queue()
130
+ logger.info('Launching Gradio app...')
131
+ demo.launch(debug=True)
backend/query_llm.py CHANGED
@@ -1,23 +1,27 @@
1
  import openai
2
  import gradio as gr
3
  import os
 
4
 
5
  from typing import Any, Dict, Generator, List
6
 
7
  from huggingface_hub import InferenceClient
8
  from transformers import AutoTokenizer
9
 
 
 
10
 
11
  OPENAI_KEY = os.getenv("OPENAI_API_KEY")
12
  HF_TOKEN = os.getenv("HF_TOKEN")
13
  HF_MODEL = os.getenv("HF_MODEL")
14
  OPENAI_MODEL = os.getenv("OPENAI_MODEL")
 
15
 
16
  HF_CLIENT = InferenceClient(
17
  os.getenv("HF_MODEL"),
18
  token=HF_TOKEN
19
  )
20
- OAI_CLIENT = openai.Client(api_key=OPENAI_KEY)
21
  TOKENIZER = AutoTokenizer.from_pretrained(HF_MODEL)
22
 
23
  HF_GENERATE_KWARGS = {
@@ -81,10 +85,17 @@ def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
81
  details=True,
82
  return_full_text=False
83
  )
 
84
  output = ""
 
85
  for response in stream:
86
  output += response.token.text
 
 
87
  yield output
 
 
 
88
 
89
  except Exception as e:
90
  if "Too Many Requests" in str(e):
@@ -109,6 +120,14 @@ def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
109
  formatted_prompt = format_prompt(prompt, "openai")
110
 
111
  try:
 
 
 
 
 
 
 
 
112
  stream = OAI_CLIENT.chat.completions.create(
113
  model=os.getenv("OPENAI_MODEL"),
114
  messages=formatted_prompt,
@@ -122,9 +141,10 @@ def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
122
  yield output
123
 
124
  except Exception as e:
 
125
  if "Too Many Requests" in str(e):
126
  raise gr.Error("ERROR: Too many requests on OpenAI client")
127
  elif "You didn't provide an API key" in str(e):
128
  raise gr.Error("Authentication error: OpenAI key was either not provided or incorrect")
129
  else:
130
- raise gr.Error(f"Unhandled Exception: {str(e)}")
 
1
  import openai
2
  import gradio as gr
3
  import os
4
+ import logging
5
 
6
  from typing import Any, Dict, Generator, List
7
 
8
  from huggingface_hub import InferenceClient
9
  from transformers import AutoTokenizer
10
 
11
+ from dotenv import load_dotenv
12
+ load_dotenv()
13
 
14
  OPENAI_KEY = os.getenv("OPENAI_API_KEY")
15
  HF_TOKEN = os.getenv("HF_TOKEN")
16
  HF_MODEL = os.getenv("HF_MODEL")
17
  OPENAI_MODEL = os.getenv("OPENAI_MODEL")
18
+ OAI_CLIENT = openai.Client(api_key=OPENAI_KEY)
19
 
20
  HF_CLIENT = InferenceClient(
21
  os.getenv("HF_MODEL"),
22
  token=HF_TOKEN
23
  )
24
+ openai.api_key = OPENAI_KEY
25
  TOKENIZER = AutoTokenizer.from_pretrained(HF_MODEL)
26
 
27
  HF_GENERATE_KWARGS = {
 
85
  details=True,
86
  return_full_text=False
87
  )
88
+
89
  output = ""
90
+ final_output = []
91
  for response in stream:
92
  output += response.token.text
93
+ final_output.append(response.token.text)
94
+ logging.info(f"Current output: {output}")
95
  yield output
96
+
97
+ # Print the final output
98
+ logging.info(f"Final output: {''.join(final_output)}")
99
 
100
  except Exception as e:
101
  if "Too Many Requests" in str(e):
 
120
  formatted_prompt = format_prompt(prompt, "openai")
121
 
122
  try:
123
+ # response = OAI_CLIENT.chat.completions.create(
124
+ # model=os.getenv("OPENAI_MODEL"),
125
+ # messages=formatted_prompt,
126
+ # **OAI_GENERATE_KWARGS
127
+ # )
128
+ # logging.info("SIMPLE OUTPUT")
129
+ # logging.info(response.choices[0].message.content)
130
+
131
  stream = OAI_CLIENT.chat.completions.create(
132
  model=os.getenv("OPENAI_MODEL"),
133
  messages=formatted_prompt,
 
141
  yield output
142
 
143
  except Exception as e:
144
+ logging.error(f"Exception during OpenAI generation: {str(e)}")
145
  if "Too Many Requests" in str(e):
146
  raise gr.Error("ERROR: Too many requests on OpenAI client")
147
  elif "You didn't provide an API key" in str(e):
148
  raise gr.Error("Authentication error: OpenAI key was either not provided or incorrect")
149
  else:
150
+ raise gr.Error(f"Unhandled Exception: {str(e)}")
backend/semantic_search.py CHANGED
@@ -50,7 +50,6 @@ retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
50
 
51
  def get_table_name():
52
  emb_model = os.getenv("EMB_MODEL")
53
- print(emb_model)
54
  if emb_model == "sentence-transformers/all-MiniLM-L6-v2":
55
  return MODEL1_STRATEGY1
56
  elif emb_model == "BAAI/bge-large-en-v1.5":
 
50
 
51
  def get_table_name():
52
  emb_model = os.getenv("EMB_MODEL")
 
53
  if emb_model == "sentence-transformers/all-MiniLM-L6-v2":
54
  return MODEL1_STRATEGY1
55
  elif emb_model == "BAAI/bge-large-en-v1.5":
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  lancedb==0.8.2
2
- openai==1.31.1
3
  langchain==0.2.5
4
  tiktoken
5
  sentence-transformers==3.0.0
 
1
  lancedb==0.8.2
2
+ openai==1.35.3
3
  langchain==0.2.5
4
  tiktoken
5
  sentence-transformers==3.0.0