# poc_hey/input_classifier.py
# Last commit: 8fd80b9 (author: Ing)
# from openai import OpenAI
from dotenv import load_dotenv
import os
from utils.chat_prompts import CLASSIFICATION_INPUT_PROMPT, CLASSIFICATION_LANGUAGE_PROMPT
from google import genai
load_dotenv()

# Gemini API key, read from the process environment. load_dotenv() above also
# populates the environment from a .env file when one is present.
gemi = os.environ.get("GEMINI_API_KEY")
if not gemi:
    # Fail fast at import with an actionable message instead of a bare KeyError.
    raise RuntimeError(
        "GEMINI_API_KEY is not set; define it in the environment or in a .env file."
    )

# Shared Gemini client used by the classification functions in this module.
# The name `client_jai` is kept from the earlier JAI/OpenAI-based client so
# existing imports of it keep working.
client_jai = genai.Client(api_key=gemi)

# Model name passed to every generate_content call in this module.
model = "gemini-2.0-flash"
# Sampling temperature intended for classification calls; 0.0 requests the
# least random sampling the model offers.
temperature = 0.0
def classify_input_type(user_input: str, history: list[str] = None) -> str:
    """Ask the LLM whether ``user_input`` should be handled as 'RAG' or 'Non-RAG'.

    The prompt is built from ``CLASSIFICATION_INPUT_PROMPT`` together with a
    short excerpt of the chat history. The model's reply is returned with
    surrounding whitespace stripped; it is NOT validated against the expected
    labels here.

    Args:
        user_input: The user message to classify.
        history: Previous chat messages. Only the last three entries are sent
            to the model, each rendered as a "- <message>" line. ``None`` or an
            empty list is rendered as the literal text "None".

    Returns:
        The stripped text of the model's reply.

    Raises:
        ValueError: If formatting the prompt yields an empty result.
        RuntimeError: If the model response carries no text (for example when
            the response was blocked or has no candidates).
    """
    history_text = "\n".join(f"- {msg}" for msg in history[-3:]) if history else "None"

    # `.format` on the prompt template yields the complete prompt text, which is
    # sent to the model as-is.
    prompt = CLASSIFICATION_INPUT_PROMPT.format(
        user_input=user_input,
        chat_history=history_text,
    )
    if not prompt:
        raise ValueError("CLASSIFICATION_INPUT_PROMPT did not produce any messages.")

    response = client_jai.models.generate_content(
        model=model,
        contents=prompt,
        # google-genai takes sampling parameters through `config`, not as
        # direct keyword arguments; this applies the module-level temperature.
        config={"temperature": temperature},
    )

    # `response.text` is None when the response has no text parts.
    text = response.text
    if text is None:
        raise RuntimeError("Gemini returned no text for the input-type classification.")
    return text.strip()
def detect_language(user_input: str) -> str:
    """Ask the LLM which language ``user_input`` is written in.

    The prompt comes from ``CLASSIFICATION_LANGUAGE_PROMPT`` and is intended to
    make the model answer with one of 'Thai', 'Korean' or 'English'. The
    model's reply is returned with surrounding whitespace stripped and is NOT
    validated against those labels here.

    Args:
        user_input: The user message whose language should be detected.

    Returns:
        The stripped text of the model's reply.

    Raises:
        RuntimeError: If the model response carries no text (for example when
            the response was blocked or has no candidates).
    """
    prompt = CLASSIFICATION_LANGUAGE_PROMPT.format(user_input=user_input)

    response = client_jai.models.generate_content(
        model=model,
        contents=prompt,
        # google-genai takes sampling parameters through `config`, not as
        # direct keyword arguments; this applies the module-level temperature.
        config={"temperature": temperature},
    )

    # `response.text` is None when the response has no text parts.
    text = response.text
    if text is None:
        raise RuntimeError("Gemini returned no text for the language detection.")
    return text.strip()