import os
import uuid
from typing import Union

from dotenv import load_dotenv

from utils.chat_prompts import (
    NON_RAG_PROMPT,
    RAG_CHAT_PROMPT_ENG,
    RAG_CHAT_PROMPT_TH,
    RAG_CHAT_PROMPT_KOREAN,
    QUERY_REWRITING_PROMPT_OBJ
)
from get_retriever_2 import final_retrievers
from input_classifier import classify_input_type, detect_language

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI

from langfuse.callback import CallbackHandler

load_dotenv()

langfuse_handler = CallbackHandler(
    secret_key=os.environ['LANGFUSE_SECRET_KEY'],
    public_key=os.environ['LANGFUSE_PUBLIC_KEY'],
    host="https://us.cloud.langfuse.com"
)


class Chat:
    def __init__(self, model_name_llm="jai-chat-1-3-2", temperature=0):
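        """
        Initialise session state and the LLM clients.

        Supported backends: "jai-chat-1-3-2" (served through an OpenAI-compatible
        endpoint via ChatOpenAI) and "gemini-2.0-flash" (Google Generative AI).
        The rewriter LLM is used only for query rewriting before retrieval.
        """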
        self.session_id = str(uuid.uuid4())[:8]
        self.model_name_llm = model_name_llm

        if model_name_llm == "jai-chat-1-3-2":
            self.llm_main = ChatOpenAI(
                model=model_name_llm,
                api_key=os.getenv("JAI_API_KEY"),
                base_url=os.getenv("CHAT_BASE_URL"),
                temperature=temperature,
                max_tokens=2048,
                max_retries=2,
                seed=13
            )
            # The same model handles both answering and query rewriting.
            self.llm_rewriter = self.llm_main
        elif model_name_llm == "gemini-2.0-flash":
            GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
            if not GEMINI_API_KEY:
                raise ValueError("GOOGLE_API_KEY (for Gemini) not found in environment variables.")

            common_gemini_config = {
                "google_api_key": GEMINI_API_KEY,
                "temperature": temperature,
                "max_output_tokens": 2048,
                "convert_system_message_to_human": True,
            }
            self.llm_main = ChatGoogleGenerativeAI(
                model=model_name_llm,
                **common_gemini_config
            )
            self.llm_rewriter = ChatGoogleGenerativeAI(
                model="gemini-2.0-flash",
                **common_gemini_config
            )
        else:
            raise ValueError(f"Unsupported LLM model '{model_name_llm}'.")

        self.history = []

    def append_history(self, message: Union[HumanMessage, AIMessage]):
        self.history.append(message)

    def get_formatted_history_for_llm(self, n_turns: int = 3) -> list:
        """Returns the last n_turns of history as a list of Message objects."""
        return self.history[-(n_turns * 2):]

    def get_stringified_history_for_rewrite(self, n_turns: int = 2) -> str:
        """
        Formats the last n_turns of history (excluding the current un-added user input)
        as a string for the query rewriter prompt.
        """
        history_to_format = self.history[-(n_turns * 2):]
        if not history_to_format:
            return "No history available."

        history_str_parts = []
        for msg in history_to_format:
            role = "User" if isinstance(msg, HumanMessage) else "AI"
            history_str_parts.append(f"{role}: {msg.content}")
        return "\n".join(history_str_parts)

    def classify_input(self, user_input: str) -> str:
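        """Classifies the input type (e.g. RAG vs. non-RAG) using the full chat history."""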
        history_content_list = [msg.content for msg in self.history]
        return classify_input_type(user_input, history=history_content_list)

    def format_docs(self, docs: list) -> str:
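        """Joins retrieved document page contents into a single context string."""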
        return "\n\n".join(doc.page_content for doc in docs)

    def get_retriever_and_prompt(self, lang_code: str):
        """
        Returns the appropriate retriever and RAG prompt based on the language.
        Handles potential errors if retriever or prompt is not found.
        """
        retriever = final_retrievers.get(lang_code)

        if lang_code == "Thai":
            prompt_template = RAG_CHAT_PROMPT_TH
        elif lang_code == "Korean":
            prompt_template = RAG_CHAT_PROMPT_KOREAN
        elif lang_code == "English":
            prompt_template = RAG_CHAT_PROMPT_ENG
        else:
            print(f"Warning: Unsupported language '{lang_code}' for RAG. Defaulting to English.")
            retriever = final_retrievers.get('English')
            prompt_template = RAG_CHAT_PROMPT_ENG

        if not retriever:
            available_langs = list(final_retrievers.keys())
            if available_langs:
                fallback_lang = available_langs[0]
                retriever = final_retrievers[fallback_lang]
                print(f"Warning: No retriever for '{lang_code}' or 'English'. Using first available: '{fallback_lang}'.")

                if fallback_lang == "Thai":
                    prompt_template = RAG_CHAT_PROMPT_TH
                elif fallback_lang == "Korean":
                    prompt_template = RAG_CHAT_PROMPT_KOREAN
                else:
                    prompt_template = RAG_CHAT_PROMPT_ENG
            else:
                raise ValueError("CRITICAL: No retrievers configured at all.")

        if not prompt_template:
            raise ValueError(f"CRITICAL: No RAG prompt template found for language '{lang_code}' or effective fallback.")

        return retriever, prompt_template

    def _rewrite_query_if_needed(self, user_input: str, input_lang: str) -> str:
        """
        Internal method to rewrite the user query using the chat history, if any exists.
        """
        if not self.history:
            return user_input

        chat_history_str = self.get_stringified_history_for_rewrite(n_turns=2)

        try:
            rewrite_prompt_messages = QUERY_REWRITING_PROMPT_OBJ.format_messages(
                chat_history=chat_history_str,
                question=user_input
            )

            response = self.llm_rewriter.invoke(rewrite_prompt_messages)
            rewritten_query = response.content.strip()

            # Sanity check: accept the rewrite only if it is non-empty and not
            # drastically longer than the original question.
            if rewritten_query and len(rewritten_query) < (len(user_input) + 250):
                print(f"Original query: '{user_input}', Rewritten query for retriever: '{rewritten_query}'")
                return rewritten_query
            else:
                print(f"Rewritten query validation failed or empty. Using original: '{user_input}'")
                return user_input
        except Exception as e:
            print(f"Error during query rewriting: {e}. Using original query.")
            return user_input

    def call_rag(self, user_input: str, input_lang: str) -> str:
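        """
        Original RAG path: rewrites the query from self.history, retrieves context,
        and generates an answer. Note that chat() currently uses call_rag_v2 instead.
        """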
        try:
            retriever, selected_rag_prompt = self.get_retriever_and_prompt(input_lang)
        except ValueError as e:
            print(f"Error in RAG setup: {e}")
            return f"Sorry, I encountered a configuration issue for {input_lang} RAG. Please contact support."

        query_for_retriever = self._rewrite_query_if_needed(user_input, input_lang)

        print(f"Retrieving documents for query: '{query_for_retriever}' (lang: {input_lang})")
        try:
            context_docs = retriever.invoke(query_for_retriever)
        except Exception as e:
            print(f"Error during document retrieval: {e}")
            return "Sorry, I had trouble finding relevant information for your query."

        print(f"Retrieved {len(context_docs)} documents. (Max possible after rerank: {os.getenv('FINAL_TOP_K_RERANK', 'N/A')})")

        context_str = self.format_docs(context_docs)

        history_for_llm_prompt = self.get_formatted_history_for_llm(n_turns=3)

        rag_input_data = {
            "question": user_input,
            "context": context_str,
            "history": history_for_llm_prompt
        }

        try:
            prompt_messages = selected_rag_prompt.format_messages(**rag_input_data)
            response = self.llm_main.invoke(prompt_messages, config={"callbacks": [langfuse_handler]})
            return response.content.strip()
        except Exception as e:
            print(f"Error during RAG LLM call: {e}")
            return "Sorry, I encountered an error while generating the response."

    def call_non_rag(self, user_input: str, input_lang: str) -> str:
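        """Answers general (non-retrieval) requests directly with the main LLM."""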
        try:
            if hasattr(NON_RAG_PROMPT, "format_messages"):
                prompt_messages = NON_RAG_PROMPT.format_messages(user_input=user_input, input_lang=input_lang)
            elif isinstance(NON_RAG_PROMPT, str):
                formatted_prompt_str = NON_RAG_PROMPT.format(user_input=user_input, input_lang=input_lang)
                prompt_messages = [HumanMessage(content=formatted_prompt_str)]
            else:
                raise TypeError("NON_RAG_PROMPT is of an unsupported type.")

            response = self.llm_main.invoke(prompt_messages, config={"callbacks": [langfuse_handler]})

            return response.content.strip()

        except Exception as e:
            print(f"Error during Non-RAG LLM call: {e}")
            return "Sorry, I had trouble processing your general request."

    def chat(self, user_input: str) -> str:
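        """
        Main entry point for one conversational turn: detects the language, records the
        user message, classifies the input, routes it to the RAG or non-RAG flow, and
        appends the AI response to the history before returning it.
        """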
        print(f"\n\n-- USER INPUT: {user_input} --")

        try:
            input_lang_detected = detect_language(user_input)
            print(f"Language detected: {input_lang_detected}")
        except Exception as e:
            print(f"Error detecting language: {e}. Defaulting to Thai.")
            input_lang_detected = "Thai"

        # Snapshot the history before the current user message is appended, so the
        # query rewriter does not see the question it is rewriting.
        history_before_current_input = self.history[:]

        self.append_history(HumanMessage(content=user_input))

        try:
            input_type = self.classify_input(user_input)
        except Exception as e:
            print(f"Error classifying input type: {e}. Defaulting to Non-RAG.")
            input_type = "Non-RAG"

        ai_response_content = ""
        if input_type == "RAG":
            print("[RAG FLOW]")
            ai_response_content = self.call_rag_v2(user_input, input_lang_detected, history_before_current_input)
        else:
            print(f"[{input_type} FLOW (Treated as NON-RAG)]")
            ai_response_content = self.call_non_rag(user_input, input_lang_detected)

        self.append_history(AIMessage(content=ai_response_content))

        print(f"AI:::: {ai_response_content}")
        return ai_response_content

    def call_rag_v2(self, user_input: str, input_lang: str, history_for_rewrite: list) -> str:
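        """
        RAG path used by chat(): rewrites the query against the history snapshot taken
        before the current user message, retrieves context, and generates an answer.
        """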
        try:
            retriever, selected_rag_prompt = self.get_retriever_and_prompt(input_lang)
        except ValueError as e:
            print(f"Error in RAG setup: {e}")
            return f"Sorry, I encountered a configuration issue for {input_lang} RAG. Please contact support."

        query_for_retriever = self._rewrite_query_if_needed_v2(user_input, history_for_rewrite)

        print(f"Retrieving documents for query: '{query_for_retriever}' (lang: {input_lang})")
        try:
            context_docs = retriever.invoke(query_for_retriever)
        except Exception as e:
            print(f"Error during document retrieval: {e}")
            return "Sorry, I had trouble finding relevant information for your query."

        print(f"Retrieved {len(context_docs)} documents.")

        context_str = self.format_docs(context_docs)
        print(f"\n----> CONTEXT DOCS (from call_rag_v2)\n{context_str}")

        history_for_llm_prompt = self.get_formatted_history_for_llm(n_turns=3)

        rag_input_data = {
            "question": user_input,
            "context": context_str,
            "history": history_for_llm_prompt
        }

        try:
            prompt_messages = selected_rag_prompt.format_messages(**rag_input_data)
            response = self.llm_main.invoke(prompt_messages, config={"callbacks": [langfuse_handler]})
            return response.content.strip()
        except Exception as e:
            print(f"Error during RAG LLM call: {e}")
            return "Sorry, I encountered an error while generating the response."

    def _rewrite_query_if_needed_v2(self, user_input: str, history_list: list) -> str:
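        """
        Rewrites the user query using an explicitly supplied history list (rather than
        self.history); falls back to the original query if the rewrite fails validation
        or an error occurs.
        """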
        if not history_list:
            return user_input

        # Use only the last two turns (user/AI pairs) of the supplied history.
        history_str_parts = []
        for msg in history_list[-(2 * 2):]:
            role = "User" if isinstance(msg, HumanMessage) else "AI"
            history_str_parts.append(f"{role}: {msg.content}")
        chat_history_str = "\n".join(history_str_parts) if history_str_parts else "No relevant history."

        try:
            rewrite_prompt_messages = QUERY_REWRITING_PROMPT_OBJ.format_messages(
                chat_history=chat_history_str,
                question=user_input
            )
            response = self.llm_rewriter.invoke(rewrite_prompt_messages)
            rewritten_query = response.content.strip()

            if rewritten_query and len(rewritten_query) < (len(user_input) + 250):
                print(f"Original query: '{user_input}', Rewritten query for retriever: '{rewritten_query}'")
                return rewritten_query
            else:
                print(f"Rewritten query validation failed. Using original: '{user_input}'")
                return user_input
        except Exception as e:
            print(f"Error during query rewriting: {e}. Using original query.")
            return user_input
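

if __name__ == "__main__":
    # Minimal manual smoke test: a rough sketch of how the Chat class is driven,
    # assuming the environment variables used above (JAI_API_KEY and CHAT_BASE_URL,
    # or GOOGLE_API_KEY, plus the Langfuse keys) are set and the retrievers in
    # get_retriever_2 are configured. The questions below are placeholders.
    chat_session = Chat(model_name_llm="jai-chat-1-3-2", temperature=0)
    print(f"Session ID: {chat_session.session_id}")

    for question in [
        "Hello, what can you help me with?",
        "Do you have any documents about pricing?",
    ]:
        chat_session.chat(question)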