import os

from dotenv import load_dotenv
from google import genai
from google.genai import types

from utils.chat_prompts import CLASSIFICATION_INPUT_PROMPT, CLASSIFICATION_LANGUAGE_PROMPT

# This module previously called JAI / OpenThaiGPT models ("jai-chat-1-3-2",
# "openthaigpt72b") through an OpenAI-compatible client; it now uses the
# google-genai SDK with a Gemini model. The client name is kept so existing
# call sites keep working.
load_dotenv()
client_jai = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

model = "gemini-2.0-flash"
temperature = 0.0  # deterministic output for classification tasks

def classify_input_type(user_input: str, history: list[str] | None = None) -> str:
    """
    Classifies the user input as 'RAG' or 'Non-RAG' using the LLM, considering chat history.
    """
    # Include at most the last three history turns for context.
    history_text = "\n".join(f"- {msg}" for msg in history[-3:]) if history else "None"

    # ChatPromptTemplate.format() renders the template to a single string;
    # format_messages() would be needed to get Message objects instead.
    prompt_content = CLASSIFICATION_INPUT_PROMPT.format(
        user_input=user_input,
        chat_history=history_text
    )
    if not prompt_content:
        raise ValueError("CLASSIFICATION_INPUT_PROMPT produced an empty prompt.")

    response = client_jai.models.generate_content(
        model=model,
        contents=prompt_content,
        config=types.GenerateContentConfig(temperature=temperature),
    )
    return response.text.strip()


def detect_language(user_input: str) -> str:
    """
    Classifies the language of the user input as 'Thai', 'Korean', or 'English'.
    """
    prompt = CLASSIFICATION_LANGUAGE_PROMPT.format(user_input=user_input)

    response = client_jai.models.generate_content(
        model=model,
        contents=prompt,
        config=types.GenerateContentConfig(temperature=temperature),
    )
    return response.text.strip()
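

# A minimal smoke test (a sketch, not part of the pipeline): it assumes
# GEMINI_API_KEY is set in the environment or in .env, and that
# utils.chat_prompts is importable; the sample inputs are illustrative only.
if __name__ == "__main__":
    sample_history = [
        "What documents are in the knowledge base?",
        "Summarize the onboarding policy for me.",
    ]
    question = "Does that policy also cover remote employees?"

    print("Input type:", classify_input_type(question, history=sample_history))  # e.g. 'RAG'
    print("Language:", detect_language(question))  # e.g. 'English'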