"""Input-type and language classification helpers backed by the Gemini API."""
from dotenv import load_dotenv
import os
from utils.chat_prompts import CLASSIFICATION_INPUT_PROMPT, CLASSIFICATION_LANGUAGE_PROMPT
from google import genai

# Load variables from .env before reading the API key below.
load_dotenv()
# Fail fast at import time: raises KeyError if GEMINI_API_KEY is not set.
gemi = os.environ["GEMINI_API_KEY"]
client_jai = genai.Client(api_key=gemi)

# Model and sampling settings shared by every classification call in this module.
model = "gemini-2.0-flash"
temperature = 0.0
def classify_input_type(user_input: str, history: list[str] | None = None) -> str:
    """
    Classify the user input as 'RAG' or 'Non-RAG' using the LLM, considering chat history.

    Args:
        user_input: The latest user message to classify.
        history: Optional prior messages; only the last three are included as context.

    Returns:
        The model's classification label, stripped of surrounding whitespace.

    Raises:
        ValueError: If the prompt template produces an empty/falsy prompt.
    """
    # Render up to the last three history entries as a bullet list; "None" when absent.
    history_text = "\n".join(f"- {msg}" for msg in history[-3:]) if history else "None"
    # CLASSIFICATION_INPUT_PROMPT is a project template; .format() fills in the slots.
    prompt_content = CLASSIFICATION_INPUT_PROMPT.format(
        user_input=user_input,
        chat_history=history_text,
    )
    if not prompt_content:
        raise ValueError("CLASSIFICATION_INPUT_PROMPT did not produce any messages.")
    # temperature is currently not passed through; the client default is used.
    response = client_jai.models.generate_content(
        model=model,
        contents=prompt_content,
    )
    return response.text.strip()
def detect_language(user_input: str) -> str:
    """
    Classify the language of the user input as Thai, Korean, or English.

    Args:
        user_input: The text whose language should be identified.

    Returns:
        The model's language label, stripped of surrounding whitespace.
    """
    # CLASSIFICATION_LANGUAGE_PROMPT is a project template; .format() fills in the slot.
    prompt = CLASSIFICATION_LANGUAGE_PROMPT.format(user_input=user_input)
    # temperature is currently not passed through; the client default is used.
    response = client_jai.models.generate_content(
        model=model,
        contents=prompt,
    )
    return response.text.strip()