|
from typing import Dict, Any |
|
from langchain_openai import ChatOpenAI |
|
from langchain.prompts import ChatPromptTemplate |
|
from langchain.schema import StrOutputParser |
|
from scripts.rag_chat import build_general_qa_chain |
|
|
|
# Categories the router LLM is asked to choose between (must match the prompt).
ROUTE_CATEGORIES = ("code", "summarize", "calculate", "general")

# Classification prompt for the router LLM; ``{input}`` is the user request.
ROUTER_PROMPT_TEMPLATE = """
You are a routing assistant for a chatbot.

Classify the following user request into one of these categories:
- "code" for programming or debugging
- "summarize" for summary requests
- "calculate" for math or numeric calculations
- "general" for general Q&A using course files

Return ONLY the category word, without quotes or punctuation.

User request: {input}
"""


def parse_route_category(raw, default="general"):
    """Normalize the router LLM's reply to one of ``ROUTE_CATEGORIES``.

    The prompt lists each category wrapped in double quotes, so models often
    answer with a quoted, capitalized or punctuated word (``"code"``,
    ``Code.``). A bare ``strip().lower()`` comparison would miss those and
    silently send the request to the default route.

    Args:
        raw: Text returned by the router chain.
        default: Category returned when no known category word is found.

    Returns:
        A member of ``ROUTE_CATEGORIES``, or ``default``.
    """
    text = (raw or "").strip().lower()
    # Fast path: the whole reply is a single category, maybe quoted/punctuated.
    bare = text.strip(" \t\r\n\"'`.,:;!")
    if bare in ROUTE_CATEGORIES:
        return bare
    # Otherwise take the first known category word anywhere in the reply.
    words = "".join(ch if ch.isalpha() else " " for ch in text).split()
    for word in words:
        if word in ROUTE_CATEGORIES:
            return word
    return default


def build_router_chain(model_name=None):
    """Build a router that dispatches a user request to a specialised handler.

    An LLM classifies each request as ``code``, ``summarize``, ``calculate``
    or ``general``. The request is then answered by one of four handlers:

    * code and calculate use a dedicated prompt;
    * summarize summarises the documents returned by the general QA chain;
    * general uses the general QA chain directly, and also handles any reply
      that cannot be mapped to a category.

    Args:
        model_name: Chat model name. ``build_general_qa_chain`` receives it
            unchanged; the router's own LLM uses ``"gpt-4o-mini"`` when it is
            falsy.

    Returns:
        An object exposing ``invoke({"input": str})``. The code, summarize and
        calculate routes return ``{"result": str}``; the general route returns
        whatever the general QA chain returns.
    """
    general_qa = build_general_qa_chain(model_name=model_name)
    llm = ChatOpenAI(model_name=model_name or "gpt-4o-mini", temperature=0.0)

    # Built once here instead of on every request: none depend on the input.
    router_chain = (
        ChatPromptTemplate.from_template(ROUTER_PROMPT_TEMPLATE)
        | llm
        | StrOutputParser()
    )
    code_chain = (
        ChatPromptTemplate.from_template(
            "As a coding assistant, help with this Python question.\nQuestion: {input}\nAnswer:"
        )
        | llm
        | StrOutputParser()
    )
    calc_chain = (
        ChatPromptTemplate.from_template(
            "Solve the following calculation step-by-step:\n{input}"
        )
        | llm
        | StrOutputParser()
    )

    def run_general_qa(query):
        """Run the general QA chain on *query* with a ``{"query": ...}`` payload."""
        payload = {"query": query}
        # Chain.__call__ is deprecated in LangChain in favour of .invoke (same
        # result dict); keep the direct call only for plain callables.
        invoke = getattr(general_qa, "invoke", None)
        return invoke(payload) if callable(invoke) else general_qa(payload)

    def summarize(query):
        """Summarize the documents retrieved for *query*, or the query text itself.

        Appends a sorted, de-duplicated "Sources" line built from each
        retrieved document's ``metadata["source"]`` when documents were found.
        """
        # Imported lazily, as in the original branch; presumably to avoid
        # loading these modules until a summary is requested — confirm.
        from langchain.docstore.document import Document
        from scripts.summarizer import get_summarizer

        rag_result = run_general_qa(query)
        source_docs = rag_result.get("source_documents", []) or []
        docs = source_docs or [Document(page_content=query)]

        out = get_summarizer().invoke(docs)
        if isinstance(out, dict) and "output_text" in out:
            summary = out["output_text"]
        else:
            summary = str(out)

        sources = sorted({str(d.metadata.get("source", "unknown")) for d in source_docs})
        if sources:
            summary += f"\n\n📚 Sources: {', '.join(sources)}"
        return summary

    class Router:
        """Dispatch ``{"input": str}`` requests to the route chosen by the LLM."""

        def invoke(self, input_dict: Dict[str, Any]) -> Dict[str, Any]:
            """Classify ``input_dict["input"]`` and answer it via the chosen route.

            Raises:
                KeyError: if ``input_dict`` has no ``"input"`` key.
            """
            query = input_dict["input"]
            category = parse_route_category(router_chain.invoke({"input": query}))

            print(f"[ROUTER] User query routed to category: {category}")

            if category == "code":
                return {"result": code_chain.invoke({"input": query})}
            if category == "summarize":
                return {"result": summarize(query)}
            if category == "calculate":
                return {"result": calc_chain.invoke({"input": query})}
            return run_general_qa(query)

    return Router()
|
|