NCTCMumbai committed on
Commit
a095bcc
1 Parent(s): 8a239f7

Update backend/query_llm.py

Files changed (1)
  1. backend/query_llm.py +122 -0
backend/query_llm.py CHANGED
@@ -8,6 +8,12 @@ from typing import Any, Dict, Generator, List
 
 from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer
+import google.generativeai as genai
+import os
+import PIL.Image
+import gradio as gr
+#from gradio_multimodalchatbot import MultimodalChatbot
+from gradio.data_classes import FileData
 
 #tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
 tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
@@ -18,6 +24,10 @@ temperature = 0.5
 top_p = 0.7
 repetition_penalty = 1.2
 
+
+# Fetch the Google API key from the environment and configure the Gemini client.
+GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
+genai.configure(api_key=GOOGLE_API_KEY)
 OPENAI_KEY = getenv("OPENAI_API_KEY")
 HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
 
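A note on the configuration block just added: os.environ.get('GOOGLE_API_KEY') returns None when the variable is unset, and genai.configure(api_key=None) does not fail until the first request. A minimal fail-fast guard, shown as an editorial sketch rather than part of this commit:

    import os
    import google.generativeai as genai

    # Fail at startup when the key is missing, instead of at the first
    # generate_content() call deep inside a request handler.
    GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
    if not GOOGLE_API_KEY:
        raise RuntimeError("GOOGLE_API_KEY is not set; export it before launching the app.")
    genai.configure(api_key=GOOGLE_API_KEY)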
@@ -160,3 +170,115 @@ def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new
         print("Unhandled Exception:", str(e))
         gr.Warning("Unfortunately OpenAI is unable to process")
         return "I do not know what happened, but I couldn't understand you."
+
+
+def generate_gemini(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 4000,
+                    top_p: float = 0.95, repetition_penalty: float = 1.0):
+
+    # For better security practices, retrieve sensitive information like API keys from environment variables.
+    api_key = os.environ.get("GOOGLE_API_KEY")
+    genai.configure(api_key=api_key)
+
+    # Initialize the Gemini model
+    model = genai.GenerativeModel('gemini-pro')
+    #chat = model.start_chat(history=[])
+
+    generation_config = genai.GenerationConfig(
+        candidate_count=1,
+        max_output_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+    )
+
+    formatted_prompt = format_prompt(prompt, "gemini")
+
+    try:
+        stream = model.generate_content(formatted_prompt,
+                                        generation_config=generation_config,
+                                        stream=True)
+        output = ""
+        for response in stream:
+            output += response.text
+            yield output
+
+    except Exception as e:
+        if "Too Many Requests" in str(e):
+            print("ERROR: Too many requests on Gemini client")
+            gr.Warning("Unfortunately Gemini is unable to process")
+            yield "Unfortunately, I am not able to process your request now."
+        elif "API key" in str(e):
+            print("Authentication error:", str(e))
+            gr.Warning("Authentication error: Google API key was either not provided or incorrect")
+            yield "Authentication error"
+        else:
+            print("Unhandled Exception:", str(e))
+            gr.Warning("Unfortunately Gemini is unable to process")
+            yield "I do not know what happened, but I couldn't understand you."
+
+
+# def gemini(input, file, chatbot=[]):
+#     """
+#     Function to handle gemini model and gemini vision model interactions.
+#     Parameters:
+#         input (str): The input text.
+#         file (File): An optional file object for image processing.
+#         chatbot (list): A list to keep track of chatbot interactions.
+#     Returns:
+#         tuple: Updated chatbot interaction list, an empty string, and None.
+#     """
+
+#     messages = []
+#     print(chatbot)
+
+#     # Process previous chatbot messages if present
+#     if len(chatbot) != 0:
+#         for messages_dict in chatbot:
+#             user_text = messages_dict[0]['text']
+#             bot_text = messages_dict[1]['text']
+#             messages.extend([
+#                 {'role': 'user', 'parts': [user_text]},
+#                 {'role': 'model', 'parts': [bot_text]}
+#             ])
+#         messages.append({'role': 'user', 'parts': [input]})
+#     else:
+#         messages.append({'role': 'user', 'parts': [input]})
+
+#     try:
+#         response = model.generate_content(messages)
+#         gemini_resp = response.text
+#         # Construct list of messages in the required format
+#         user_msg = {"text": input, "files": []}
+#         bot_msg = {"text": gemini_resp, "files": []}
+#         chatbot.append([user_msg, bot_msg])
+
+#     except Exception as e:
+#         # Handling exceptions and raising error to the modal
+#         print(f"An error occurred: {e}")
+#         raise gr.Error(e)
+
+#     return chatbot, "", None
+
+# # Define the Gradio Blocks interface
+# with gr.Blocks() as demo:
+#     # Add a centered header using HTML
+#     gr.HTML("<center><h1>Gemini Chat PRO API</h1></center>")
+
+#     # Initialize the MultimodalChatbot component
+#     multi = MultimodalChatbot(value=[], height=800)
+
+#     with gr.Row():
+#         # Textbox for user input with increased scale for better visibility
+#         tb = gr.Textbox(scale=4, placeholder='Input text and press Enter')
+
+#     # Define the behavior on text submission
+#     tb.submit(gemini, [tb, multi], [multi, tb])

+# # Launch the demo with a queue to handle multiple users
+# demo.queue().launch()
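For reference, the new generate_gemini is a generator that yields the accumulated response after each streamed chunk, like the other generators in query_llm.py. A minimal sketch of wiring it into a standard gr.Chatbot, assuming format_prompt accepts a "gemini" style and GOOGLE_API_KEY is set (the MultimodalChatbot demo above stays commented out):

    import gradio as gr

    # Sketch: stream generate_gemini into a plain Chatbot component.
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(height=800)
        tb = gr.Textbox(scale=4, placeholder='Input text and press Enter')

        def respond(message, history):
            history = history + [[message, ""]]
            # Each value yielded by generate_gemini is the full text so far,
            # so overwrite (not append to) the last bot message.
            for partial in generate_gemini(message, history=""):
                history[-1][1] = partial
                yield history, ""

        tb.submit(respond, [tb, chatbot], [chatbot, tb])

    demo.queue().launch()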