NCTCMumbai committed
Commit a095bcc • Parent: 8a239f7
Update backend/query_llm.py

backend/query_llm.py CHANGED (+122 −0)
@@ -8,6 +8,12 @@ from typing import Any, Dict, Generator, List
 
 from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer
+import google.generativeai as genai
+import os
+import PIL.Image
+import gradio as gr
+#from gradio_multimodalchatbot import MultimodalChatbot
+from gradio.data_classes import FileData
 
 #tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
 tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
@@ -18,6 +24,10 @@ temperature = 0.5
 top_p = 0.7
 repetition_penalty = 1.2
 
+
+# Fetch the Google API key from an environment variable.
+GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
+genai.configure(api_key=GOOGLE_API_KEY)
 OPENAI_KEY = getenv("OPENAI_API_KEY")
 HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
 
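Note: the hunk above configures the client at import time with genai.configure(api_key=GOOGLE_API_KEY); if the variable is unset, api_key is None and the failure only surfaces on the first request. A minimal fail-fast sketch (the explicit RuntimeError check is an assumption, not part of this commit):

import os
import google.generativeai as genai

# Fail fast at startup instead of letting genai.configure() silently
# accept api_key=None and error out on the first generate_content call.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise RuntimeError("GOOGLE_API_KEY is not set")
genai.configure(api_key=GOOGLE_API_KEY)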
@@ -160,3 +170,115 @@ def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new
             print("Unhandled Exception:", str(e))
             gr.Warning("Unfortunately OpenAI is unable to process")
             return "I do not know what happened, but I couldn't understand you."
+
+
+def generate_gemini(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 4000,
+                    top_p: float = 0.95, repetition_penalty: float = 1.0):
+
+    # For better security practices, retrieve sensitive information like API keys from environment variables.
+
+
+    # Initialize genai models
+    model = genai.GenerativeModel('gemini-pro')
+    api_key = os.environ.get("GOOGLE_API_KEY")
+    genai.configure(api_key=api_key)
+    #model = genai.GenerativeModel('gemini-pro')
+    #chat = model.start_chat(history=[])
+
+
+
+    candidate_count=1
+    max_output_tokens=max_new_tokens
+    temperature=temperature
+    top_p=top_p
+
+
+    formatted_prompt = format_prompt(prompt, "gemini")
+
+    try:
+        stream = model.generate_content(formatted_prompt, generation_config=genai.GenerationConfig(temperature=temperature, candidate_count=1, max_output_tokens=max_new_tokens, top_p=top_p),
+                                        stream=True)
+        output = ""
+        for response in stream:
+            output += response.text
+            yield output
+
+    except Exception as e:
+        if "Too Many Requests" in str(e):
+            print("ERROR: Too many requests on Gemini client")
+            gr.Warning("Unfortunately Gemini is unable to process")
+            return "Unfortunately, I am not able to process your request now."
+        elif "Authorization header is invalid" in str(e):
+            print("Authentication error:", str(e))
+            gr.Warning("Authentication error: Google API key was either not provided or incorrect")
+            return "Authentication error"
+        else:
+            print("Unhandled Exception:", str(e))
+            gr.Warning("Unfortunately Gemini is unable to process")
+            return "I do not know what happened, but I couldn't understand you."
+
+
+
+
+# def gemini(input, file, chatbot=[]):
+#     """
+#     Function to handle gemini model and gemini vision model interactions.
+#     Parameters:
+#         input (str): The input text.
+#         file (File): An optional file object for image processing.
+#         chatbot (list): A list to keep track of chatbot interactions.
+#     Returns:
+#         tuple: Updated chatbot interaction list, an empty string, and None.
+#     """
+
+#     messages = []
+#     print(chatbot)
+
+#     # Process previous chatbot messages if present
+#     if len(chatbot) != 0:
+#         for messages_dict in chatbot:
+#             user_text = messages_dict[0]['text']
+#             bot_text = messages_dict[1]['text']
+#             messages.extend([
+#                 {'role': 'user', 'parts': [user_text]},
+#                 {'role': 'model', 'parts': [bot_text]}
+#             ])
+#         messages.append({'role': 'user', 'parts': [input]})
+#     else:
+#         messages.append({'role': 'user', 'parts': [input]})
+
+#     try:
+#         response = model.generate_content(messages)
+#         gemini_resp = response.text
+#         # Construct list of messages in the required format
+#         user_msg = {"text": input, "files": []}
+#         bot_msg = {"text": gemini_resp, "files": []}
+#         chatbot.append([user_msg, bot_msg])
+
+#     except Exception as e:
+#         # Handling exceptions and raising error to the modal
+#         print(f"An error occurred: {e}")
+#         raise gr.Error(e)
+
+#     return chatbot, "", None
+
+# # Define the Gradio Blocks interface
+# with gr.Blocks() as demo:
+#     # Add a centered header using HTML
+#     gr.HTML("<center><h1>Gemini Chat PRO API</h1></center>")
+
+#     # Initialize the MultimodalChatbot component
+#     multi = MultimodalChatbot(value=[], height=800)
+
+#     with gr.Row():
+#         # Textbox for user input with increased scale for better visibility
+#         tb = gr.Textbox(scale=4, placeholder='Input text and press Enter')
+
+
+
+#     # Define the behavior on text submission
+#     tb.submit(gemini, [tb, multi], [multi, tb])
+
+
+#     # Launch the demo with a queue to handle multiple users
+#     demo.queue().launch()
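Note: generate_gemini is a generator that yields the accumulated text after each streamed chunk. A minimal consumption sketch (the import path assumes the repo layout shown in this commit; the prompt string is illustrative):

from backend.query_llm import generate_gemini

# Each yielded value is the full response so far, not a delta,
# so the last value seen is the complete answer.
final = ""
for partial in generate_gemini("What is a customs duty drawback?", history=""):
    final = partial
print(final)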