diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..4e95459837b9e262a304efcff891ab53cc4fbf1c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,42 @@ +# Use an official Python runtime as a parent image +FROM python:3.10-slim + +# Set environment variables +ENV PYTHONDONTWRITEBYTECODE 1 +ENV PYTHONUNBUFFERED 1 +ENV PORT 7860 +ENV HOME=/home/user +ENV TRANSFORMERS_CACHE=$HOME/.cache + +# Create a non-root user +RUN useradd -m -u 1000 user +USER user +WORKDIR $HOME/app + +# Set path +ENV PATH="/home/user/.local/bin:${PATH}" + +# Install system dependencies (as root) +USER root +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + libpq-dev \ + && rm -rf /var/lib/apt/lists/* +USER user + +# Install Python dependencies +COPY --chown=user requirements.txt . +RUN pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir -r requirements.txt + +# Pre-download the embedding model +RUN python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('BAAI/bge-small-en-v1.5')" + +# Copy the rest of the backend code +COPY --chown=user . . + +# Expose the port the app runs on +EXPOSE 7860 + +# Run the application +CMD ["uvicorn", "core.app:app", "--host", "0.0.0.0", "--port", "7860"] diff --git a/add_what_happened_column.py b/add_what_happened_column.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea673db2b6a493de7ee767689fe5ee04e1bb8f8 --- /dev/null +++ b/add_what_happened_column.py @@ -0,0 +1,54 @@ +import asyncio +from repositories.postgres_repo import get_pool + + +async def _column_exists(conn, table: str, column: str) -> bool: + return bool( + await conn.fetchval( + """ + SELECT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = current_schema() + AND table_name = $1 + AND column_name = $2 + ) + """, + table, + column, + ) + ) + + +async def main(): + pool = await get_pool() + async with pool.acquire() as conn: + has_pre_thinking = await _column_exists(conn, "messages", "pre_thinking") + has_what_happened = await _column_exists(conn, "messages", "what_happened") + + if has_what_happened and not has_pre_thinking: + await conn.execute( + "ALTER TABLE messages RENAME COLUMN what_happened TO pre_thinking" + ) + print("Renamed messages.what_happened to messages.pre_thinking") + elif has_what_happened and has_pre_thinking: + await conn.execute( + """ + UPDATE messages + SET pre_thinking = COALESCE(pre_thinking, what_happened) + WHERE what_happened IS NOT NULL + """ + ) + await conn.execute("ALTER TABLE messages DROP COLUMN what_happened") + print("Merged data into pre_thinking and dropped what_happened") + elif not has_pre_thinking: + await conn.execute( + "ALTER TABLE messages ADD COLUMN pre_thinking JSONB DEFAULT NULL" + ) + print("Added messages.pre_thinking") + else: + print("messages.pre_thinking already present; no changes needed") + await pool.close() + + +asyncio.run(main()) diff --git a/agents/__init__.py b/agents/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b8625ee077c12714a70a0e20a4e19d722ec03edf --- /dev/null +++ b/agents/__init__.py @@ -0,0 +1,38 @@ +# Agents module — re-exports agent graph builders and node functions +from agents.convo_agent import ( + get_tool_agent_graph, + agent_node, + tool_node, + should_continue, + run_tool_agent_stream, +) +from agents.orchestrator_agent import ( + get_deep_research_graph, + deep_research_node, + parallel_researchers_node, + aggregator_node, + 
critic_node, +) +from agents.debate_agent import ( + get_agent_a_persona, + get_agent_b_persona, + get_verifier_persona, + get_debate_model, + create_proposer_agent, + create_critic_agent, + create_verifier_agent, +) +from agents.coding_agent import ( + code_planner_node, + parallel_coders_node, + code_aggregator_node, + code_reviewer_node, + should_retry, + format_output_node, + get_node_coords, +) +from agents.smart_orchestrator import ( + classify_query, + get_standard_node_coords, + get_deep_research_node_coords, +) diff --git a/agents/coding_agent.py b/agents/coding_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..86e635de1b5c34894e16ac7bbb8318a6ac7251bb --- /dev/null +++ b/agents/coding_agent.py @@ -0,0 +1,410 @@ +import json +import asyncio +import re +from typing import Any, Callable +from core.llm_engine import get_llm +from core.config import AgentConfig +from schemas.schema import CodingAgentState, CodingSubtask +from langchain_core.messages import HumanMessage, SystemMessage +from utils.json_helpers import ( + sanitize_fenced_json, + extract_first_json_object, + load_json_object, + clamp_score, + normalize_text, + normalize_errors, + normalize_serious_mistakes, +) +from utils.graph_nodes import get_coding_node_coords + + +# ─── Node Coordinates (for frontend graph) ──────────────────────────────────── + +NODE_COORDS = get_coding_node_coords() + + +def get_node_coords() -> dict: + """Return the node coordinate map for frontend graph rendering.""" + return NODE_COORDS + + +# ─── Code Planner Node ─────────────────────────────────────────────────────── + +async def code_planner_node(state: CodingAgentState) -> dict: + """Decomposes the task into coding subtasks with shared interface.""" + llm = get_llm(temperature=AgentConfig.CodingAgent.PLANNER_TEMPERATURE, instant=True) + task = state["original_task"] + + prompt = f"""You are a senior software architect. Given the following coding task, break it into multiple independent subtasks that can be implemented in parallel. + + Task: {task} + + INSTRUCTIONS: + - Analyze what the task requires (could be any language: Python, HTML/CSS/JS, Java, etc.) + - Break it into Multiple logical, independent subtasks + - Each subtask should have a clear description of what it implements + - Include file names if the task naturally produces multiple files. + - The shared_contract should describe the overall structure/interfaces that all agents must respect + + Respond in EXACTLY this JSON format (no other text): + {{ + "subtasks": [ + {{ + "id": 1, + "description": "Clear description of what this subtask implements", + "signatures": ["Describe the key elements/functions this subtask should produce"] + }}, + {{ + "id": 2, + "description": "Clear description of what this subtask implements", + "signatures": ["Describe the key elements/functions this subtask should produce"] + }}, + {{ + "id": 3, + "description": "Clear description of what this subtask implements", + "signatures": ["Describe the key elements/functions this subtask should produce"] + }} + ], + "shared_contract": "Overall structure and interfaces that all agents must respect" + }} + + Each subtask should be independently implementable. The shared_contract must describe the overall structure so each coder knows how their work fits with others.""" + + response = await llm.ainvoke([ + SystemMessage(content="You are a software architect. 
Output only valid JSON."), + HumanMessage(content=prompt), + ]) + + try: + content = response.content.strip() + if "```json" in content: + content = content.split("```json")[1].split("```")[0].strip() + elif "```" in content: + content = content.split("```")[1].split("```")[0].strip() + parsed = json.loads(content) + subtasks = parsed.get("subtasks", []) + shared_contract = parsed.get("shared_contract", "") + except Exception: + subtasks = [ + {"id": 1, "description": f"Implement core data structures for: {task}", "signatures": ["Core structures and data types"]}, + {"id": 2, "description": f"Implement main algorithm/logic for: {task}", "signatures": ["Main logic and algorithms"]}, + {"id": 3, "description": f"Implement helper functions and utilities for: {task}", "signatures": ["Helper functions and utilities"]}, + ] + shared_contract = "All agents should implement their assigned parts with clear interfaces." + + while len(subtasks) < 3: + subtasks.append({"id": len(subtasks) + 1, "description": f"Additional implementation for: {task}", "signatures": []}) + + coding_subtasks: list[CodingSubtask] = [] + for st in subtasks[:3]: + coding_subtasks.append({ + "id": st.get("id", len(coding_subtasks) + 1), + "description": st.get("description", ""), + "signatures": st.get("signatures", []), + "result": None, + }) + + return { + "subtasks": coding_subtasks, + "shared_contract": shared_contract, + } + + +# ─── Parallel Coders Node ──────────────────────────────────────────────────── + +async def parallel_coders_node(state: CodingAgentState) -> dict: + """Runs 3 coding agents in parallel, each implementing its subtask.""" + subtasks = state["subtasks"] + shared_contract = state["shared_contract"] + original_task = state["original_task"] + + async def run_coder(subtask: CodingSubtask, coder_idx: int) -> str: + llm = get_llm(temperature=AgentConfig.CodingAgent.CODER_TEMPERATURE, change=False) + signatures_text = "\n".join(subtask.get("signatures", [])) + + prompt = f"""You are Coding Agent {coder_idx + 1}. You are part of a team implementing: {original_task} + + YOUR SPECIFIC TASK: + {subtask['description']} + + YOUR KEY ELEMENTS TO IMPLEMENT: + {signatures_text} + + SHARED CONTRACT (overall structure all agents must respect): + {shared_contract} + + IMPORTANT RULES: + 1. Only implement the parts assigned to you + 2. Follow the shared contract for naming, structure, and interfaces + 3. Do NOT implement parts assigned to other agents — just reference them if needed + 4. Write clean, well-documented code in whatever language best fits the task + 5. Include comments explaining complex logic + 6. After your code is written, provide a detailed explanation separated by "---" + + OUTPUT FORMAT: + Write your code first, then after a line with "---", write a brief explanation of what your code does and how it works. + + Write your implementation:""" + + response = await llm.ainvoke([ + SystemMessage(content="You are a coding agent. Write clean, well-documented code in whatever language best fits the task.
After your code, provide a brief explanation."), + HumanMessage(content=prompt), + ]) + return response.content + + tasks = [run_coder(subtasks[i], i) for i in range(min(3, len(subtasks)))] + results = await asyncio.gather(*tasks) + + cleaned_results = [] + for result in results: + cleaned = re.sub(r'```\w*\n', '', result) + cleaned = cleaned.replace('```', '').strip() + cleaned_results.append(cleaned) + results = cleaned_results + + while len(results) < 3: + results.append("# No implementation provided") + + return {"coder_results": list(results)} + + +# ─── Code Aggregator Node ───────────────────────────────────────────────────── + +async def code_aggregator_node(state: CodingAgentState) -> dict: + """Merges 3 coder outputs into one coherent codebase.""" + llm = get_llm(temperature=AgentConfig.CodingAgent.AGGREGATOR_TEMPERATURE, change=True) + coder_results = state["coder_results"] + shared_contract = state["shared_contract"] + original_task = state["original_task"] + review_errors = state.get("review_errors", []) + + coder_sections = "\n\n".join( + f"--- Coder {i+1} Output ---\n{result}" for i, result in enumerate(coder_results) + ) + + error_section = "" + if review_errors: + error_section = f""" + CRITICAL: The previous version had these errors that MUST be fixed: + {chr(10).join(f'- {e}' for e in review_errors)} + + You MUST address all of these errors in your merged output. + """ + + prompt = f"""You are a senior code aggregator. Your job is to merge 3 coding agent outputs into a working, production-ready codebase. + + Original Task: {original_task} + + Shared Contract (overall structure all agents agreed on): + {shared_contract} + + Coder Outputs (each coder includes code + explanation separated by "---"): + {coder_sections} + {error_section} + + CRITICAL INSTRUCTIONS: + 1. Output MULTIPLE files if the task naturally requires it + 2. Use this EXACT format to separate files — NO markdown fences, ONLY use the file separator: + # === FILE: filename === + [raw code content here, NO fences] + + # === FILE: another_file === + [raw code content here, NO fences] + + 3. Preserve the explanations from each coder, weaving them into the response so that both the explanations and the overall flow are retained. + 4. Do NOT add new functionality — only merge, fix references, and ensure interoperability + 5. ABSOLUTELY DO NOT use markdown code fences (```) — output raw code only. + + Merged code:""" + + response = await llm.ainvoke([HumanMessage(content=prompt)]) + return {"merged_code": response.content} + + +# ─── Code Reviewer Node ─────────────────────────────────────────────────────── + +async def code_reviewer_node(state: CodingAgentState) -> dict: + """Reviews merged code for correctness and returns scores + errors.""" + llm = get_llm(temperature=AgentConfig.CodingAgent.REVIEWER_TEMPERATURE, change=True) + merged_code = state["merged_code"] + original_task = state["original_task"] + retry_count = state.get("retry_count", 0) + + prompt = f"""You are a strict code reviewer and quality gate. + + Evaluate the merged code against the original task and report only concrete, high-value findings. + + SCORING RUBRIC (integer 0-100): + - confidence: Correctness and production readiness of the implementation. + 0-39 = fundamentally broken/incomplete; 40-69 = partially correct with major gaps; + 70-89 = mostly correct with manageable issues; 90-100 = robust, correct, and production-ready. + - consistency: Internal coherence and contract alignment across files/interfaces.
+ 0-39 = contradictory/incompatible; 40-69 = mixed integration quality; + 70-89 = coherent with minor integration issues; 90-100 = cleanly integrated and consistent. + + ISSUE CRITERIA: + - "errors": blocker defects that must be fixed before acceptance. + - "serious_mistakes": security vulnerabilities, requirement misses, data-loss risks, + broken interfaces, or logic errors with high user impact. + Every issue must be specific and actionable. + + Original Task: {original_task} + + Code to Review: + {merged_code} + + Output contract (STRICT JSON ONLY; no markdown/backticks/preamble): + {{ + "confidence": 0, + "consistency": 0, + "friendly_feedback": "2-4 sentence summary with concrete next steps.", + "errors": ["Actionable blocker 1", "Actionable blocker 2"], + "serious_mistakes": [ + {{ + "severity": "high|critical", + "description": "What is wrong and where it appears.", + "action": "Specific fix to apply." + }} + ] + }} + If there are no blockers, return "errors": []. + If there are no serious mistakes, return "serious_mistakes": [].""" + + response = await llm.ainvoke([ + SystemMessage(content="You are a data-formatting agent. Output raw JSON only."), + HumanMessage(content=prompt) + ]) + parse_error = "" + parsed: dict[str, Any] = {} + try: + parsed = load_json_object(response.content if isinstance(response.content, str) else str(response.content)) + except ValueError as exc: + parse_error = str(exc) + + confidence_default = 25 if parse_error else 70 + consistency_default = 25 if parse_error else 70 + confidence = clamp_score(parsed.get("confidence"), default=confidence_default) + consistency = clamp_score( + parsed.get("consistency", parsed.get("logical_consistency", parsed.get("consistency_score"))), + default=consistency_default, + ) + errors = normalize_errors(parsed.get("errors", parsed.get("review_errors", []))) + if parse_error and not errors: + errors = [ + "Reviewer output was not valid JSON. Re-run merge/review and verify every requirement explicitly." + ] + + feedback = normalize_text(parsed.get("friendly_feedback", parsed.get("critic_feedback"))) + if not feedback: + feedback = ( + "Reviewer output failed strict JSON validation; a conservative failure response was applied." + if parse_error + else "Review complete. Address blockers and rerun the reviewer." 
+ ) + + serious_mistakes = normalize_serious_mistakes(parsed.get("serious_mistakes", [])) + if parse_error and not serious_mistakes: + serious_mistakes = [ + { + "severity": "high", + "description": "Reviewer response was not valid JSON, reducing trust in the quality gate.", + "action": "Regenerate reviewer output with strict JSON-only compliance.", + } + ] + + return { + "confidence_score": confidence, + "logical_consistency": consistency, + "review_errors": errors, + "critic_feedback": feedback, + "serious_mistakes": serious_mistakes, + "retry_count": retry_count + 1, + } + + +# ─── Retry Logic ────────────────────────────────────────────────────────────── + +def should_retry(state: CodingAgentState) -> str: + """Decide whether to retry aggregation or format final output.""" + if state.get("review_errors") and state["retry_count"] < 2: + return "code_aggregator" + return "format_output" + + +# ─── Format Output Node ─────────────────────────────────────────────────────── + +def detect_language(content: str) -> str: + """Detect the programming language from code content using simple keyword heuristics.""" + content_lower = content.lower().strip() + # Heuristic checks covering the languages mapped in ext_map below; markup first, then code. + if "<html" in content_lower or "<!doctype" in content_lower: + return "html" + if "{" in content_lower and ("font-family" in content_lower or "margin:" in content_lower or "padding:" in content_lower): + return "css" + if "public class " in content_lower or "system.out.println" in content_lower: + return "java" + if "function " in content_lower or "const " in content_lower or "console.log" in content_lower: + return "javascript" + if "def " in content_lower or "import " in content_lower or "print(" in content_lower: + return "python" + return "text" + + +def format_output_node(state: CodingAgentState) -> dict: + """Parses merged code into a structured list of {filename, content, language} dicts.""" + merged_code = state["merged_code"] + + # Strip any stray markdown fences first + if "```" in merged_code: + fenced_blocks = re.findall(r'```(\w*)\n(.*?)```', merged_code, re.DOTALL) + if fenced_blocks: + merged_code = "\n\n".join(block_content.strip() for _, block_content in fenced_blocks) + else: + merged_code = re.sub(r'```\w*\n?', '', merged_code).replace('```', '').strip() + + parsed_files = [] + file_separator = "# === FILE:" + + if file_separator in merged_code: + file_blocks = merged_code.split(file_separator) + for file_block in file_blocks: + file_block = file_block.strip() + if not file_block: + continue + lines = file_block.split("\n") + filename = lines[0].strip().replace("===", "").strip() + if not filename: + continue + full_content = "\n".join(lines[1:]) + # Strip explanation after "---" + file_content = full_content + if "\n---\n" in full_content: + file_content = re.split(r'\n---\n', full_content, maxsplit=1)[0].strip() + if file_content: + parsed_files.append({ + "filename": filename, + "content": file_content, + "language": detect_language(file_content), + }) + else: + all_blocks = re.findall(r'```(\w*)\n(.*?)```', merged_code, re.DOTALL) + ext_map = {"html": "index.html", "css": "styles.css", "javascript": "script.js", "python": "main.py", "java": "Main.java"} + if len(all_blocks) > 1: + for idx, (lang_hint, block_content) in enumerate(all_blocks): + block_content = block_content.strip() + if not block_content: + continue + lang = lang_hint if lang_hint else detect_language(block_content) + filename = ext_map.get(lang, f"file_{idx + 1}.{lang or 'txt'}") + parsed_files.append({"filename": filename, "content": block_content, "language": lang}) + else: + # Single file fallback + code_content = merged_code.split("\n---\n", 1)[0].strip() if "\n---\n" in merged_code else merged_code + lang = detect_language(code_content) + parsed_files.append({ + "filename": ext_map.get(lang, "output.txt"), + "content": code_content, + "language": lang, + }) + + return {"parsed_files": parsed_files} + diff --git a/agents/convo_agent.py b/agents/convo_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..575e29df4d6ed13592fcb6cfa85b7b5c5528b63a --- /dev/null +++ b/agents/convo_agent.py @@ -0,0 +1,189 @@ +import logging +from typing
import AsyncGenerator +from core.llm_engine import get_llm +from utils.tools import get_tools_list, get_tools_map +from schemas.schema import AgentState +from langchain_core.messages import HumanMessage, ToolMessage, AIMessage +from langgraph.graph import StateGraph, END + +logger = logging.getLogger(__name__) + + +def agent_node(state: AgentState): + """LLM agent node: binds tools and invokes the model.""" + llm = get_llm().bind_tools(get_tools_list()) + response = llm.invoke(state["messages"]) + return {"messages": [response]} + + +def tool_node(state: AgentState): + """Execute tool calls from the last agent message.""" + last_msg = state["messages"][-1] + tool_messages = [] + for tc in last_msg.tool_calls: + tool_fn = get_tools_map().get(tc["name"]) + if tool_fn: + try: + result = tool_fn.invoke(tc["args"]) + tool_messages.append( + ToolMessage(content=str(result), tool_call_id=tc["id"]) + ) + except Exception as e: + logger.error(f"[tool_node] Error invoking {tc['name']}: {e}") + tool_messages.append( + ToolMessage(content=f"Tool execution failed: {str(e)}", tool_call_id=tc["id"]) + ) + else: + tool_messages.append( + ToolMessage(content="Tool not found.", tool_call_id=tc["id"]) + ) + return {"messages": tool_messages} + + +def should_continue(state: AgentState): + """Route to tools if the agent made tool calls, otherwise end.""" + last_msg = state["messages"][-1] + if hasattr(last_msg, "tool_calls") and last_msg.tool_calls: + return "tools" + return END + + +# ─── Build Graph ────────────────────────────────────────────────────────────── + + +def _build_graph(): + graph = StateGraph(AgentState) + graph.add_node("agent", agent_node) + graph.add_node("tools", tool_node) + graph.set_entry_point("agent") + graph.add_conditional_edges("agent", should_continue, {"tools": "tools", END: END}) + graph.add_edge("tools", "agent") + return graph.compile() + + +_graph = _build_graph() + + +def get_tool_agent_graph(): + """Return the compiled LangGraph tool-calling agent graph.""" + return _graph + + +# ─── Streaming Execution (3-Phase Async Generator) ─────────────────────────── + + +async def run_tool_agent_stream( + messages: list, +) -> AsyncGenerator[dict, None]: + """ + Execute the tool-calling agent with real-time streaming (3-phase). + + Yields event dicts: + {"type": "token", "content": "...", "phase": "initial"|"final"} + {"type": "tool_start", "tool_name": "...", "tool_args": {...}} + {"type": "tool_end", "tool_name": "...", "tool_output": "..."} + {"type": "complete", "answer": "...", "tools_used": [...], "messages": [...]} + + 3-Phase flow: + Phase 1: Stream the initial LLM response (text tokens or tool call decision) + Phase 2: Execute any tool calls locally + Phase 3: Stream the final LLM summary after tool results + """ + llm = get_llm().bind_tools(get_tools_list()) + + # ── Phase 1: Stream the initial response ────────────────────────────── + logger.info("[convo_agent] Phase 1: Streaming initial LLM response...") + first_response = None + + for chunk in llm.stream(messages): + if first_response is None: + first_response = chunk + else: + first_response += chunk + + if chunk.content: + yield {"type": "token", "content": chunk.content, "phase": "initial"} + + messages.append(first_response) + has_tool_calls = bool(first_response.tool_calls) + logger.info(f"[convo_agent] Phase 1 complete. 
Tool calls: {has_tool_calls}") + + # ── Phase 2: Execute tools if needed ────────────────────────────────── + tools_used = [] + if has_tool_calls: + logger.info( + f"[convo_agent] Phase 2: Executing {len(first_response.tool_calls)} tool call(s)..." + ) + for tc in first_response.tool_calls: + tool_name = tc["name"] + tool_args = tc["args"] + logger.info( + f"[convo_agent] Executing tool: {tool_name} with args: {tool_args}" + ) + + yield { + "type": "tool_start", + "tool_name": tool_name, + "tool_args": tool_args, + } + + tool_fn = get_tools_map().get(tool_name) + if tool_fn: + try: + tool_output = tool_fn.invoke(tool_args) + logger.info( + f"[convo_agent] Tool {tool_name} returned: {str(tool_output)[:200]}..." + ) + messages.append( + ToolMessage(content=str(tool_output), tool_call_id=tc["id"]) + ) + tools_used.append({"tool": tool_name, "args": tool_args}) + + yield { + "type": "tool_end", + "tool_name": tool_name, + "tool_output": str(tool_output), + } + except Exception as e: + logger.error(f"[convo_agent] Tool {tool_name} execution error: {e}") + error_msg = f"Tool execution failed: {str(e)}" + messages.append( + ToolMessage(content=error_msg, tool_call_id=tc["id"]) + ) + yield { + "type": "tool_end", + "tool_name": tool_name, + "tool_output": error_msg, + } + else: + error_msg = f"Tool not found: {tool_name}" + logger.warning(f"[convo_agent] {error_msg}") + messages.append(ToolMessage(content=error_msg, tool_call_id=tc["id"])) + + # ── Phase 3: Final stream (summary after tool results) ──────────────── + initial_answer = first_response.content if first_response.content else "" + final_answer = "" + if has_tool_calls: + logger.info("[convo_agent] Phase 3: Streaming final LLM summary...") + for chunk in llm.stream(messages): + if chunk.content: + final_answer += chunk.content + yield {"type": "token", "content": chunk.content, "phase": "final"} + logger.info( + f"[convo_agent] Phase 3 complete. Final answer length: {len(final_answer)}" + ) + else: + final_answer = initial_answer + logger.info( + f"[convo_agent] No tool calls needed. Answer length: {len(final_answer)}" + ) + + # Combine initial + final answer for complete response + complete_answer = initial_answer + final_answer if has_tool_calls else final_answer + + yield { + "type": "complete", + "answer": complete_answer, + "tools_used": tools_used, + "messages": messages, + } diff --git a/agents/debate_agent.py b/agents/debate_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e748d7690a8ba005c45b2e099470f7f26f43d9 --- /dev/null +++ b/agents/debate_agent.py @@ -0,0 +1,64 @@ +from typing import Any +from autogen_agentchat.agents import AssistantAgent +from core import get_autogen_groq_client + +MODEL = "openai/gpt-oss-20b" + +AGENT_A_PERSONA = """You are Agent Proposer, a confident debater who argues FOR the given topic. +You present compelling arguments, use evidence, and directly counter your opponent's points. +Be assertive but intellectually rigorous. Keep responses to 3-4 sentences max.""" + +AGENT_B_PERSONA = """You are Agent Critic, a sharp debater who argues AGAINST the given topic. +You challenge assumptions, present counterarguments, and expose weaknesses in opposing views. +Be direct and incisive. Keep responses to 3-4 sentences max.""" + +VERIFIER_PERSONA = """You are Agent Verifier, an impartial debate judge. +Evaluate both sides using argument quality, evidence quality, and internal consistency. 
+Return a concise verdict in 6-7 sentences, clearly stating which side argued better and why.""" + + +def get_agent_a_persona() -> str: + """Return the Proposer persona.""" + return AGENT_A_PERSONA + + +def get_agent_b_persona() -> str: + """Return the Critic persona.""" + return AGENT_B_PERSONA + + +def get_debate_model() -> str: + """Return the debate model name.""" + return MODEL + + +def get_verifier_persona() -> str: + """Return the verifier persona.""" + return VERIFIER_PERSONA + + +def create_proposer_agent() -> Any: + """Create a proposer debate agent using AutoGen.""" + return AssistantAgent( + name="agent_proposer", + model_client=get_autogen_groq_client(model=MODEL, temperature=0.7), + system_message=AGENT_A_PERSONA, + ) + + +def create_critic_agent() -> Any: + """Create a critic debate agent using AutoGen.""" + return AssistantAgent( + name="agent_critic", + model_client=get_autogen_groq_client(model=MODEL, temperature=0.7), + system_message=AGENT_B_PERSONA, + ) + + +def create_verifier_agent() -> Any: + """Create a verifier debate agent using AutoGen.""" + return AssistantAgent( + name="agent_verifier", + model_client=get_autogen_groq_client(model=MODEL, temperature=0.3), + system_message=VERIFIER_PERSONA, + ) diff --git a/agents/orchestrator_agent.py b/agents/orchestrator_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..db85792d6c246ecb014c757c9d8e7db1a3a0517f --- /dev/null +++ b/agents/orchestrator_agent.py @@ -0,0 +1,299 @@ +import asyncio +import json +from typing import Any +from core.llm_engine import get_llm +from core.config import AgentConfig +from schemas.schema import OrchestratorState +from langchain_core.messages import HumanMessage, SystemMessage +from langgraph.graph import StateGraph, END +from utils.json_helpers import ( + sanitize_fenced_json, + extract_first_json_object, + load_json_object, + clamp_score, + normalize_text, + normalize_serious_mistakes, +) + + +# ─── Parallel Orchestrator (2 Researchers + Aggregator) ────────────────────── + +def deep_research_node(state: OrchestratorState): + """Analyzes the task and creates 2 researcher subtasks with different perspectives + 1 aggregator.""" + task = state["original_task"] + llm = get_llm(temperature=AgentConfig.OrchestratorAgent.DECOMPOSER_TEMPERATURE) + + prompt = f"""You are a task decomposer. Your job is to analyze the given task and break it into exactly 2 research assignments with different perspectives, plus 1 aggregation task. + +IMPORTANT RULES: +- The context provided is ONLY for your background understanding. Do NOT include it in your output. +- Each researcher gets a DIFFERENT angle/perspective on the task. +- Researcher 1 should focus on facts, data, core concepts, and established knowledge. +- Researcher 2 should focus on alternative viewpoints, debates, controversies, and edge cases. +- Keep descriptions concise but specific to each researcher's angle. +- Output ONLY valid JSON. + +Task: {task} + +Respond in EXACTLY this JSON format (no markdown, no backticks): +{{ + "researcher_1": "Specific research assignment for Researcher 1 focusing on core facts and data", + "researcher_2": "Specific research assignment for Researcher 2 focusing on alternative perspectives and debates" +}}""" + + response = llm.invoke([ + SystemMessage(content="You are a task decomposer. 
Output only valid JSON."), + HumanMessage(content=prompt), + ]) + + try: + content = response.content.strip() + if content.startswith("```json"): + content = content.replace("```json", "", 1) + if content.endswith("```"): + content = content[:-3] + content = content.strip() + parsed = json.loads(content) + r1_desc = parsed.get("researcher_1", f"Comprehensive research on: {task}") + r2_desc = parsed.get("researcher_2", f"Alternative perspectives on: {task}") + except Exception: + r1_desc = f"Comprehensive research on: {task}. Gather facts, data, background information, and existing analysis from reliable sources." + r2_desc = f"Alternative perspectives and debates on: {task}. Identify conflicting viewpoints, controversies, gaps in knowledge, and complementary information." + + subtasks = [ + { + "id": 1, + "description": r1_desc, + "agent_type": "researcher", + "result": None, + }, + { + "id": 2, + "description": r2_desc, + "agent_type": "researcher", + "result": None, + }, + { + "id": 3, + "description": "Synthesize findings from both researchers into a comprehensive final report with the required 7-section structure.", + "agent_type": "aggregator", + "result": None, + }, + ] + logs = state.get("step_logs", []) + logs.append(f"🎯 Deep Research decomposed task into {len(subtasks)} subtasks (2x researcher + aggregator)") + return {"subtasks": subtasks, "step_logs": logs} + + +async def parallel_researchers_node(state: OrchestratorState): + """Executes all researcher subtasks in parallel and updates their results.""" + subtasks = state["subtasks"] + researcher_indices = [i for i, st in enumerate(subtasks) if st["agent_type"] == "researcher"] + logs = list(state.get("step_logs", [])) + logs.append(f"🚀 Launching {len(researcher_indices)} researcher agents in parallel") + + original_task = state["original_task"] + + async def run_researcher(description: str) -> str: + llm = get_llm(temperature=AgentConfig.OrchestratorAgent.RESEARCHER_TEMPERATURE) + prompt = f"""Original task: {original_task} + + Your research assignment: {description} + + Conduct thorough research and provide detailed findings. Include facts, data, sources, and any relevant context. Be comprehensive and precise.""" + + response = await llm.ainvoke([ + SystemMessage(content="You are a research agent. 
Your job is to thoroughly research the given topic and provide comprehensive, unique, and factual information."), + HumanMessage(content=prompt), + ]) + return response.content + + tasks = [run_researcher(subtasks[i]["description"]) for i in researcher_indices] + results = await asyncio.gather(*tasks) + + new_subtasks = list(subtasks) + for idx, result in zip(researcher_indices, results): + st = dict(new_subtasks[idx]) + st["result"] = result + new_subtasks[idx] = st + + logs.append(f"✅ All {len(researcher_indices)} researchers completed") + return {"subtasks": new_subtasks, "step_logs": logs} + + +async def aggregator_node(state: OrchestratorState): + """Takes aggregator subtask and both researcher results, produces the final 7-section report.""" + subtasks = state["subtasks"] + aggregator_subtask = next((st for st in subtasks if st["agent_type"] == "aggregator"), None) + logs = list(state.get("step_logs", [])) + + if not aggregator_subtask: + logs.append("❌ No aggregator subtask found") + return {"final_result": "Error: Aggregator agent missing.", "step_logs": logs} + + researcher_results = [st["result"] for st in subtasks if st["agent_type"] == "researcher"] + researcher_texts = "\n\n".join( + f"--- Researcher {i+1} ---\n{res}" for i, res in enumerate(researcher_results) + ) + + llm = get_llm(temperature=AgentConfig.OrchestratorAgent.AGGREGATOR_TEMPERATURE, change=True) + prompt = f"""You are a professional research writer. Synthesize the following research findings into a comprehensive, structured final report. + + Original Task: {state['original_task']} + + Research Inputs: + {researcher_texts} + + Write a very detailed report using EXACTLY these sections (use proper markdown headings): + + 1. Executive Summary + Provide a detailed overview of the topic and main conclusions. + + 2. Key Findings + List the most important discoveries in bullet points, along with the relevant context for each discovery. + + 3. Evidence From Sources + Present detailed evidence, data, and source citations. + + 4. Trends / Analysis + Analyze trends, patterns, implications, and provide insights. These insights should be simple yet very detailed. + + 5. Contradictions or Debates + Highlight any conflicting information, controversies, or areas of disagreement. + + 6. Conclusion + Summarize the key points and provide a forward-looking perspective. + + 7. Sources / Citations + List the major sources referenced in the research. If none were provided, note that. + + Use clear markdown formatting with ## for major sections and ### for subsections where appropriate. Be comprehensive but avoid redundancy.""" + + response = await llm.ainvoke([HumanMessage(content=prompt)]) + logs.append("✍️ Aggregator agent synthesized final report with 7-section structure") + + new_subtasks = list(subtasks) + for i, st in enumerate(new_subtasks): + if st["agent_type"] == "aggregator": + new_subtasks[i] = dict(st) + new_subtasks[i]["result"] = response.content + break + + return {"final_result": response.content, "subtasks": new_subtasks, "step_logs": logs} + + +async def critic_node(state: OrchestratorState): + """Evaluates the aggregator's final output, assigns confidence and logical consistency scores.""" + logs = list(state.get("step_logs", [])) + final_result = state["final_result"] + + llm = get_llm(temperature=AgentConfig.OrchestratorAgent.CRITIC_TEMPERATURE, change=True) + prompt = f"""You are a strict research quality critic. + + Evaluate the report against the task and produce one JSON object only.
+ + SCORING RUBRIC (integer 0-100): + - confidence: How trustworthy and well-supported the report is. + 0-39 = major factual/support gaps; 40-69 = partial support with notable gaps; + 70-89 = mostly supported and reliable; 90-100 = strongly supported, precise, and robust. + - consistency: Internal logic and alignment with the requested task/structure. + 0-39 = contradictory or off-task; 40-69 = mixed coherence; 70-89 = coherent with minor issues; + 90-100 = fully coherent, complete, and on-task. + + SERIOUS MISTAKE CRITERIA: + Include only high-impact problems (fabricated claims, major missing sections, direct contradictions, + unsafe or misleading guidance, or conclusions unsupported by evidence). + Every serious mistake must describe the issue and a concrete corrective action. + + Report: + {final_result} + + Task: + {state['original_task']} + + Output contract (STRICT JSON ONLY; no markdown/backticks/preamble): + {{ + "confidence": 0, + "consistency": 0, + "friendly_feedback": "2-4 sentence actionable summary with at least one concrete next step.", + "serious_mistakes": [ + {{ + "severity": "high|critical", + "description": "What is wrong and where it appears.", + "action": "Specific fix the writer should apply." + }} + ] + }} + If none, return "serious_mistakes": [] exactly.""" + + response = await llm.ainvoke([ + SystemMessage(content="You are a Self-Reflective Critic agent. Output raw JSON only."), + HumanMessage(content=prompt) + ]) + logs.append("🔬 Critic agent evaluated output quality") + + parse_error = "" + parsed: dict[str, Any] = {} + try: + parsed = load_json_object(response.content if isinstance(response.content, str) else str(response.content)) + except ValueError as exc: + parse_error = str(exc) + logs.append(f"⚠️ Critic parsing error: {parse_error}") + + confidence_default = 30 if parse_error else 70 + consistency_default = 30 if parse_error else 70 + confidence = clamp_score(parsed.get("confidence"), default=confidence_default) + consistency = clamp_score( + parsed.get("consistency", parsed.get("logical_consistency", parsed.get("consistency_score"))), + default=consistency_default, + ) + feedback = normalize_text(parsed.get("friendly_feedback", parsed.get("critic_feedback"))) + if not feedback: + feedback = ( + "Critic output failed strict JSON validation. Re-run evaluation with strict format compliance." + if parse_error + else "Review complete. Address highlighted weaknesses to improve confidence and consistency." 
+ ) + serious_mistakes = normalize_serious_mistakes(parsed.get("serious_mistakes", [])) + if parse_error and not serious_mistakes: + serious_mistakes = [ + { + "severity": "high", + "description": "Critic output was not valid JSON, so quality scoring may be unreliable.", + "action": "Re-run critic with strict JSON-only output and reassess the report.", + } + ] + + return { + "critic_confidence": confidence, + "critic_logical_consistency": consistency, + "critic_feedback": feedback, + "serious_mistakes": serious_mistakes, + "step_logs": logs, + } + + +# ─── Build Graph ────────────────────────────────────────────────────────────── + +def _build_deep_research_graph(): + graph = StateGraph(OrchestratorState) + graph.add_node("deep_research", deep_research_node) + graph.add_node("parallel_researchers", parallel_researchers_node) + graph.add_node("aggregator", aggregator_node) + graph.add_node("critic", critic_node) + + graph.set_entry_point("deep_research") + graph.add_edge("deep_research", "parallel_researchers") + graph.add_edge("parallel_researchers", "aggregator") + graph.add_edge("aggregator", "critic") + graph.add_edge("critic", END) + + return graph.compile() + + +_deep_research_graph = _build_deep_research_graph() + + +def get_deep_research_graph(): + """Return the compiled deep_research graph.""" + return _deep_research_graph diff --git a/agents/smart_orchestrator.py b/agents/smart_orchestrator.py new file mode 100644 index 0000000000000000000000000000000000000000..42a2d282322efb4ee4a8ae727cd392e5d0e3c1ab --- /dev/null +++ b/agents/smart_orchestrator.py @@ -0,0 +1,64 @@ +from core.llm_engine import get_llm +from core.config import AgentConfig +from langchain_core.messages import HumanMessage, SystemMessage +from utils.graph_nodes import get_standard_node_coords, get_deep_research_node_coords + + +# ─── Node Coordinates for Standard & Deep Research Paths ────────────────────── + +STANDARD_NODE_COORDS = get_standard_node_coords() +DEEP_RESEARCH_NODE_COORDS = get_deep_research_node_coords() + + +def get_standard_node_coords() -> dict: + """Return node coordinates for the standard path.""" + return STANDARD_NODE_COORDS + + +def get_deep_research_node_coords() -> dict: + """Return node coordinates for the deep research path.""" + return DEEP_RESEARCH_NODE_COORDS + + +# ─── Query Classifier ───────────────────────────────────────────────────────── + +async def classify_query(task: str) -> tuple[str, str, str]: + """Classify the task into one of three paths and return problem understanding for code tasks.""" + llm = get_llm(temperature=AgentConfig.SmartOrchestrator.ROUTER_TEMPERATURE, instant=True) + + prompt = f"""You are a query router. Classify the following user query into exactly one category. + + Query: {task} + + Categories: + - standard: Simple factual as well as the Detailed questions, greetings, quick calculations, single-step as well as multi step tasks. Use this In most of normal QA. + - deep_research: Questions requiring very hard multi-perspective research, analysis, comparisons, explanations of complex topics, if the topic needs very good research then select this mode. + - code: Requests to write, implement, debug, or generate code, algorithms, data structures + + Respond in EXACTLY this format: + PATH: [standard|deep_research|code] + UNDERSTANDING: [If PATH is code, provide a brief 2-3 sentence problem understanding explaining what the problem asks for and what the expected solution should look like. 
If PATH is not code, write "N/A"] + REASON: [brief explanation of why this path was chosen]""" + + response = await llm.ainvoke([ + SystemMessage(content="You are a query classifier. Output only the specified format."), + HumanMessage(content=prompt), + ]) + + content = response.content.strip() + path = "standard" + reason = "Default classification" + problem_understanding = "N/A" + + for line in content.split("\n"): + line = line.strip() + if line.startswith("PATH:"): + extracted = line.split(":")[1].strip().lower() + if extracted in ("standard", "deep_research", "code"): + path = extracted + elif line.startswith("UNDERSTANDING:"): + problem_understanding = line.split(":", 1)[1].strip() + elif line.startswith("REASON:"): + reason = line.split(":", 1)[1].strip() + + return path, reason, problem_understanding \ No newline at end of file diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..66a375878e6577d70166cc6f2c0e7cd27bde263b --- /dev/null +++ b/core/__init__.py @@ -0,0 +1,12 @@ +# Core module — re-exports commonly used items +# NOTE: app is NOT imported here to avoid circular imports. +# Import app directly: from core.app import app +from core.llm_engine import get_llm, get_client +from core.autogen_client import get_autogen_groq_client +from core.auth import hash_password, verify_password, create_token, decode_token, get_current_user +from core.exceptions import ( + DatabaseError, + LLMError, + QdrantError, + register_exception_handlers, +) diff --git a/core/__pycache__/__init__.cpython-310.pyc b/core/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c959eda96d72202252c22b51a7a8fd82d508d7f Binary files /dev/null and b/core/__pycache__/__init__.cpython-310.pyc differ diff --git a/core/__pycache__/app.cpython-310.pyc b/core/__pycache__/app.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44a8b7495b3959a76cb560603ed6175996557232 Binary files /dev/null and b/core/__pycache__/app.cpython-310.pyc differ diff --git a/core/__pycache__/auth.cpython-310.pyc b/core/__pycache__/auth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65de9e211b6fe76c0b42afd48878280f5b1176f1 Binary files /dev/null and b/core/__pycache__/auth.cpython-310.pyc differ diff --git a/core/__pycache__/autogen_client.cpython-310.pyc b/core/__pycache__/autogen_client.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef5ee2d26a5ca5c2108d3ae76d0a306de98a9d78 Binary files /dev/null and b/core/__pycache__/autogen_client.cpython-310.pyc differ diff --git a/core/__pycache__/config.cpython-310.pyc b/core/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5411a0ff7d3325c6a18b93000cef465377a52979 Binary files /dev/null and b/core/__pycache__/config.cpython-310.pyc differ diff --git a/core/__pycache__/exceptions.cpython-310.pyc b/core/__pycache__/exceptions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17cbe50e116df14fae817be14e2e8eccccc9d63f Binary files /dev/null and b/core/__pycache__/exceptions.cpython-310.pyc differ diff --git a/core/__pycache__/llm_engine.cpython-310.pyc b/core/__pycache__/llm_engine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc168ef2e72c61d920ce8ce36ea94e6606a68061 Binary files /dev/null and b/core/__pycache__/llm_engine.cpython-310.pyc differ diff --git 
a/core/__pycache__/middleware.cpython-310.pyc b/core/__pycache__/middleware.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ffdb9bd456abae8ca1e2ff95d7091e42ef3eee3 Binary files /dev/null and b/core/__pycache__/middleware.cpython-310.pyc differ diff --git a/core/app.py b/core/app.py new file mode 100644 index 0000000000000000000000000000000000000000..21b8c3e5cb448e3ce950e54feb48bb31a852e7e2 --- /dev/null +++ b/core/app.py @@ -0,0 +1,82 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from core.exceptions import register_exception_handlers +from repositories import init_db, close_pool, init_qdrant_collections +from routers import ( + auth_router, + chat_router, + history_router, + pdf_router, + debate_router, + admin_router, + reflection_router +) + +app = FastAPI(title="AI Agents Platform") + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], +) + +# Register global exception handlers +register_exception_handlers(app) + +# Include all routers +app.include_router(auth_router) +app.include_router(chat_router) +app.include_router(history_router) +app.include_router(pdf_router) +app.include_router(debate_router) +app.include_router(admin_router) +app.include_router(reflection_router) + + +# ─── Startup / Shutdown Events ──────────────────────────────────────────────── + +@app.on_event("startup") +async def startup(): + """Initialize database and Qdrant collections on server start.""" + await init_db() + await close_pool() + try: + init_qdrant_collections() + print("[app] Qdrant collections initialized.") + except Exception as e: + print(f"[app] Qdrant initialization skipped (unavailable): {e}") + print("[app] Server started successfully.") + + +@app.on_event("shutdown") +async def shutdown(): + """Close database pool on shutdown.""" + await close_pool() + print("[app] Server shut down.") + + +# ─── Root ────────────────────────────────────────────────────────────────────── + +@app.get("/") +async def root(): + """Root endpoint - API landing page""" + return { + "service": "Agentrix.io API", + "status": "operational", + "version": "2.6.0", + "endpoints": { + "auth_register": "/auth/register (POST)", + "auth_login": "/auth/login (POST)", + "chat": "/chat (POST)", + "orchestrator": "/orchestrator/task (POST)", + "smart_orchestrator": "/smart-orchestrator/stream (POST)", + "debate": "/debate/stream (GET)", + "upload": "/upload-pdf (POST)", + "memory_pdfs": "/memory/pdfs (GET)", + "history": "/history (GET)", + "reflection": "/reflection/summary (GET)", + "docs": "/docs", + "redoc": "/redoc", + }, + } diff --git a/core/auth.py b/core/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..2e1ebb8394d07d1bd556b3a72007257f2c23b11f --- /dev/null +++ b/core/auth.py @@ -0,0 +1,58 @@ +import os +from datetime import datetime, timedelta, timezone + +from dotenv import load_dotenv +from fastapi import HTTPException, Header +from jose import JWTError, jwt +import bcrypt + +load_dotenv() + +# ─── Password Hashing ───────────────────────────────────────────────────────── + + +def hash_password(plain: str) -> str: + """Hash a plaintext password using bcrypt.""" + return bcrypt.hashpw(plain.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") + + +def verify_password(plain: str, hashed: str) -> bool: + """Verify a plaintext password against a bcrypt hash.""" + return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8")) + + +# ─── JWT Token 
──────────────────────────────────────────────────────────────── + +JWT_SECRET = os.getenv("JWT_SECRET", "agentrix-default-secret-change-me") +JWT_ALGORITHM = "HS256" +JWT_EXPIRY_DAYS = 7 + + +def create_token(user_id: str) -> str: + """Create a JWT token with 7-day expiry.""" + payload = { + "sub": user_id, + "exp": datetime.now(timezone.utc) + timedelta(days=JWT_EXPIRY_DAYS), + "iat": datetime.now(timezone.utc), + } + return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM) + + +def decode_token(token: str) -> str: + """Decode a JWT token and return the user_id. Raises HTTPException on failure.""" + try: + payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM]) + user_id: str = payload.get("sub") + if user_id is None: + raise HTTPException(status_code=401, detail="Invalid token: missing subject") + return user_id + except JWTError: + raise HTTPException(status_code=401, detail="Invalid or expired token") + + +async def get_current_user(authorization: str = Header(...)) -> str: + """FastAPI dependency: extracts and validates Bearer token from Authorization header.""" + if not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Invalid authorization header format") + token = authorization.split(" ", 1)[1] + return decode_token(token) \ No newline at end of file diff --git a/core/autogen_client.py b/core/autogen_client.py new file mode 100644 index 0000000000000000000000000000000000000000..dba2fe7c64270b1fc96861157e04e1c6a3a06419 --- /dev/null +++ b/core/autogen_client.py @@ -0,0 +1,30 @@ +import os + +from autogen_ext.models.openai import OpenAIChatCompletionClient + + +GROQ_OPENAI_BASE_URL = "https://api.groq.com/openai/v1" + + +def get_autogen_groq_client(model: str, temperature: float = 0.7) -> OpenAIChatCompletionClient: + """Create an AutoGen-compatible OpenAI client targeting Groq.""" + api_key = os.getenv("GROQ_API_KEY") + if not api_key: + raise ValueError("GROQ_API_KEY is required for AutoGen debate execution.") + + # AutoGen 0.4 requires model_info for non-standard OpenAI model names. + # We provide standard capabilities for Groq-hosted models. + model_info = { + "vision": False, + "function_calling": True, + "json_output": True, + "family": "gpt-4", # Using gpt-4 family as a safe baseline for capabilities + } + + return OpenAIChatCompletionClient( + model=model, + api_key=api_key, + base_url=GROQ_OPENAI_BASE_URL, + temperature=temperature, + model_info=model_info, + ) diff --git a/core/config.py b/core/config.py new file mode 100644 index 0000000000000000000000000000000000000000..536323b8a3f818ae46a2ddec9d6a312b035985d2 --- /dev/null +++ b/core/config.py @@ -0,0 +1,126 @@ +""" +Centralized configuration module for LLM parameters, scoring, and constants. + +Provides single source of truth for: +- LLM temperature profiles for different use cases +- Scoring and confidence thresholds +- Valid severity levels for mistakes/issues +- Graph node coordinates (eventually) +- Agent-specific parameters + +Usage: + from core.config import LLMConfig, ScoringConfig + + llm = get_llm(temperature=LLMConfig.PRECISE) + score = clamp_score(value, default=ScoringConfig.DEFAULT) +""" + + +class LLMConfig: + """ + LLM temperature profiles for different use cases. + + Temperature controls randomness: 0 = deterministic, 1 = highly random. 
+ - Structured: Planning, routing, critics (need precision) + - Precise: Code generation, technical writing (mostly deterministic) + - Balanced: Research, analysis, general reasoning (mix of precision & exploration) + - Creative: Brainstorming, alternatives (high exploration) + """ + # Planning & Decision Making (deterministic) + STRUCTURED = 0.0 + + # Code & Technical Output (precise but not rigid) + PRECISE = 0.1 + + # Research & Analysis (balanced exploration + precision) + BALANCED = 0.3 + + # Brainstorming & Alternatives (exploratory) + CREATIVE = 0.7 + + +class ScoringConfig: + """ + Configuration for confidence and consistency scoring. + + Used for evaluating LLM output quality: + - Range: [MIN, MAX] + - Default: Used as fallback when parsing fails + """ + MIN = 0 + MAX = 100 + DEFAULT = 50 + + # Quality thresholds for evaluation feedback + POOR_THRESHOLD = 39 # 0-39: poor quality, requires rework + PARTIAL_THRESHOLD = 69 # 40-69: partial quality, has gaps + GOOD_THRESHOLD = 89 # 70-89: good quality, manageable issues + # 90-100: excellent quality, production-ready + + +class SeverityConfig: + """ + Valid severity levels for mistakes/issues found in LLM output. + + Used for categorizing problems: + - low: Minor issues, cosmetic + - medium: Moderate issues, affects usability + - high: Significant issues, affects correctness + - critical: Breaking issues, prevents function + """ + VALID_LEVELS = {"low", "medium", "high", "critical"} + DEFAULT = "high" + + +class CriticConfig: + """ + Configuration for critic/reviewer agents. + + Used when evaluating LLM-generated code, research, or other output. + """ + # Default confidence score when parsing fails + PARSE_ERROR_CONFIDENCE = 25 + PARSE_ERROR_CONSISTENCY = 25 + + # Default confidence score when parsing succeeds + SUCCESS_BASELINE_CONFIDENCE = 70 + SUCCESS_BASELINE_CONSISTENCY = 70 + + +class AgentConfig: + """ + Agent-specific configuration and parameters. 
+ """ + + class CodingAgent: + """Code generation and review agent configuration.""" + # Temperature for planning/routing decisions + PLANNER_TEMPERATURE = LLMConfig.PRECISE + + # Temperature for parallel code generation + CODER_TEMPERATURE = LLMConfig.PRECISE + + # Temperature for code merging/aggregation + AGGREGATOR_TEMPERATURE = LLMConfig.PRECISE + + # Temperature for code review/criticism + REVIEWER_TEMPERATURE = LLMConfig.STRUCTURED + + class OrchestratorAgent: + """Deep research orchestrator agent configuration.""" + # Temperature for research decomposition + DECOMPOSER_TEMPERATURE = LLMConfig.BALANCED + + # Temperature for researcher agents + RESEARCHER_TEMPERATURE = LLMConfig.CREATIVE + + # Temperature for research aggregation + AGGREGATOR_TEMPERATURE = LLMConfig.BALANCED + + # Temperature for research quality criticism + CRITIC_TEMPERATURE = LLMConfig.STRUCTURED + + class SmartOrchestrator: + """Smart routing orchestrator configuration.""" + # Temperature for query classification and routing + ROUTER_TEMPERATURE = LLMConfig.STRUCTURED diff --git a/core/exceptions.py b/core/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..8a2f8f2e554744fb31fb37768e4b0ab389220d08 --- /dev/null +++ b/core/exceptions.py @@ -0,0 +1,54 @@ +from fastapi import Request +from fastapi.responses import JSONResponse + + +class DatabaseError(Exception): + """Raised when a database operation fails.""" + + def __init__(self, detail: str = "Database operation failed"): + self.detail = detail + super().__init__(self.detail) + + +class LLMError(Exception): + """Raised when an LLM call fails.""" + + def __init__(self, detail: str = "LLM call failed"): + self.detail = detail + super().__init__(self.detail) + + +class QdrantError(Exception): + """Raised when a Qdrant operation fails.""" + + def __init__(self, detail: str = "Qdrant operation failed"): + self.detail = detail + super().__init__(self.detail) + + +async def database_error_handler(request: Request, exc: DatabaseError): + """Handle DatabaseError globally.""" + return JSONResponse(status_code=500, content={"detail": f"Database error: {exc.detail}"}) + + +async def llm_error_handler(request: Request, exc: LLMError): + """Handle LLMError globally.""" + return JSONResponse(status_code=502, content={"detail": f"LLM error: {exc.detail}"}) + + +async def qdrant_error_handler(request: Request, exc: QdrantError): + """Handle QdrantError globally.""" + return JSONResponse(status_code=502, content={"detail": f"Qdrant error: {exc.detail}"}) + + +async def generic_exception_handler(request: Request, exc: Exception): + """Catch-all handler for unhandled exceptions.""" + return JSONResponse(status_code=500, content={"detail": f"Internal server error: {str(exc)}"}) + + +def register_exception_handlers(app): + """Register all custom exception handlers on a FastAPI app.""" + app.add_exception_handler(DatabaseError, database_error_handler) + app.add_exception_handler(LLMError, llm_error_handler) + app.add_exception_handler(QdrantError, qdrant_error_handler) + app.add_exception_handler(Exception, generic_exception_handler) \ No newline at end of file diff --git a/core/llm_engine.py b/core/llm_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac70c8619c7dce7a151b9b3349e05ea3b8b6bd8 --- /dev/null +++ b/core/llm_engine.py @@ -0,0 +1,35 @@ +import os +from groq import Groq +from dotenv import load_dotenv +from langchain_groq import ChatGroq +from langchain_openai import ChatOpenAI + +load_dotenv() + + +def 
get_llm(temperature: float = 0.1, change: bool = True, instant = False): + """Return a ChatGroq LLM instance. change=True uses Llama-4, change=False uses GPT-oss.""" + model = "openai/gpt-oss-20b" + if change: + model = "meta-llama/llama-4-scout-17b-16e-instruct" + if instant: + model = "llama-3.1-8b-instant" + + return ChatGroq( + model=model, + api_key=os.getenv("GROQ_API_KEY"), + temperature=temperature, + ) + + +def get_client() -> Groq: + """Return a raw Groq client (for streaming completions).""" + return Groq(api_key=os.getenv("GROQ_API_KEY")) + +def coding_llm(temperature: float = 0.2): + return ChatOpenAI( + model="stepfun/step-3.5-flash:free", + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY"), + temperature=temperature + ) \ No newline at end of file diff --git a/create_pdf_id_index.py b/create_pdf_id_index.py new file mode 100644 index 0000000000000000000000000000000000000000..0e6a2fc81e10936908efdf2e4278a833017d9653 --- /dev/null +++ b/create_pdf_id_index.py @@ -0,0 +1,26 @@ +""" +One-off script to create the missing pdf_id payload index on the pdf_chunks collection. +Run this once from the backend directory: python create_pdf_id_index.py +""" +import os +from dotenv import load_dotenv +from qdrant_client import QdrantClient +from qdrant_client.models import PayloadSchemaType + +load_dotenv() + +client = QdrantClient( + url=os.getenv("QDRANT_CLIENT"), + api_key=os.getenv("QDRANT_API_KEY"), +) + +try: + client.create_payload_index( + collection_name="pdf_chunks", + field_name="pdf_id", + field_schema=PayloadSchemaType.KEYWORD, + ) + print("[OK] Created 'pdf_id' keyword index on 'pdf_chunks' collection.") +except Exception as e: + print(f"[INFO] Index creation result: {e}") + print("(If this says 'already exists', the index was already created — you're good!)") diff --git a/delete.py b/delete.py new file mode 100644 index 0000000000000000000000000000000000000000..9c7dc4b6e185caa765b583946f6c578ecdc63495 --- /dev/null +++ b/delete.py @@ -0,0 +1,66 @@ +""" +delete.py — Drop all PostgreSQL tables and Qdrant collections. +Run this to reset the database before applying new schema changes. 
+Usage: python delete.py +""" + +import os +import asyncio +import asyncpg +from dotenv import load_dotenv +from qdrant_client import QdrantClient + +load_dotenv() + + +async def drop_all_tables(): + """Drop all tables, enums, and indexes from PostgreSQL.""" + conn = await asyncpg.connect(os.getenv("DATABASE_URL")) + try: + # Drop tables (CASCADE handles FK dependencies) + await conn.execute("DROP TABLE IF EXISTS chunk_retrieval_log CASCADE;") + await conn.execute("DROP TABLE IF EXISTS detected_mistakes CASCADE;") + await conn.execute("DROP TABLE IF EXISTS debate_sessions CASCADE;") + await conn.execute("DROP TABLE IF EXISTS messages CASCADE;") + await conn.execute("DROP TABLE IF EXISTS conversations CASCADE;") + await conn.execute("DROP TABLE IF EXISTS users CASCADE;") + + # Drop enum types + await conn.execute("DROP TYPE IF EXISTS conv_type CASCADE;") + await conn.execute("DROP TYPE IF EXISTS debate_winner CASCADE;") + await conn.execute("DROP TYPE IF EXISTS mistake_severity CASCADE;") + await conn.execute("DROP TYPE IF EXISTS reasoning_type CASCADE;") + + print("[delete] All PostgreSQL tables and enums dropped.") + finally: + await conn.close() + + +def drop_qdrant_collections(): + """Delete Qdrant collections.""" + client = QdrantClient( + url=os.getenv("QDRANT_CLIENT"), + api_key=os.getenv("QDRANT_API_KEY"), + ) + collections = [c.name for c in client.get_collections().collections] + for name in collections: + client.delete_collection(name) + print(f"[delete] Deleted Qdrant collection: {name}") + + if not collections: + print("[delete] No Qdrant collections found.") + + +async def main(): + print("=" * 50) + print(" Agentrix.io — Database Cleanup Script") + print("=" * 50) + + await drop_all_tables() + drop_qdrant_collections() + + print("\n[delete] Cleanup complete. 
Run `python main.py` to reinitialize.") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/repositories/__init__.py b/repositories/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..117f756f976da7e9a9de53ff2a330ac0854057c6 --- /dev/null +++ b/repositories/__init__.py @@ -0,0 +1,35 @@ +# Repositories module — re-exports all database CRUD functions +from repositories.postgres_repo import ( + init_db, + close_pool, + get_pool, + create_user, + get_user_by_email, + get_user_by_id, + create_conversation, + append_message, + create_debate_session, + get_debate_session_by_conversation_id, + get_user_history, + update_conversation_timestamp, + rename_conversation, + delete_conversation, + clear_all_history, + get_conversation_with_messages, + log_chunk_retrieval, + log_detected_mistakes, + get_pdf_quality_scores, +) +from repositories.qdrant_repo import ( + get_embedding, + get_qdrant_client, + init_qdrant_collections, + upsert_pdf_summary, + get_user_pdf_summaries, + upsert_pdf_chunks, + search_chunks, + search_pdf_summary, + search_chunks_by_pdf_id, + get_pdf_ids_by_names, + get_most_recent_pdf_id, +) diff --git a/repositories/postgres_repo.py b/repositories/postgres_repo.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb8f1a81de48de43ec752c533b6b1fb75fbb275 --- /dev/null +++ b/repositories/postgres_repo.py @@ -0,0 +1,521 @@ +import os +import uuid +from datetime import datetime +import json as _json + +import asyncpg +from dotenv import load_dotenv +from core.exceptions import DatabaseError + +load_dotenv() + +_pool: asyncpg.Pool | None = None + + +async def get_pool() -> asyncpg.Pool: + """Returns the global asyncpg connection pool singleton.""" + global _pool + if _pool is None: + _pool = await asyncpg.create_pool( + os.getenv("DATABASE_URL"), + min_size=2, + max_size=10, + ) + return _pool + + +async def close_pool(): + """Close the global pool on shutdown.""" + global _pool + if _pool: + await _pool.close() + _pool = None + + +# ─── Table Definitions ──────────────────────────────────────────────────────── + +CREATE_TABLES_SQL = """ +-- Enum types +DO $$ BEGIN + CREATE TYPE conv_type AS ENUM ('standard', 'debate'); +EXCEPTION WHEN duplicate_object THEN NULL; +END $$; + +DO $$ BEGIN + CREATE TYPE reasoning_type AS ENUM ('standard', 'deep_research', 'multi_agent'); +EXCEPTION WHEN duplicate_object THEN NULL; +END $$; + +DO $$ BEGIN + CREATE TYPE mistake_severity AS ENUM ('low', 'medium', 'high'); +EXCEPTION WHEN duplicate_object THEN NULL; +END $$; + +-- 1. Users table +CREATE TABLE IF NOT EXISTS users ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + email VARCHAR(255) UNIQUE NOT NULL, + password_hash VARCHAR(255) NOT NULL, + display_name VARCHAR(255), + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() +); + +-- 2. Conversations table +CREATE TABLE IF NOT EXISTS conversations ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + type conv_type NOT NULL DEFAULT 'standard', + title VARCHAR(500), + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() +); + +-- 3. Messages table +-- Each row is one full query+response turn. +-- message JSONB stores an array: [{"user": "...", "assistant": "..."}, ...] +-- New messages are appended as new dicts in the array. 
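+-- Entries may also carry optional "pdfs" and "tools" keys (the PDFs and tools used for that turn); +-- these are written by append_message() further down in this module.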
+CREATE TABLE IF NOT EXISTS messages ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE, + reasoning_mode reasoning_type NOT NULL, + message JSONB NOT NULL DEFAULT '[]', + confidence NUMERIC(4,3) CHECK (confidence BETWEEN 0 AND 1), + consistency NUMERIC(4,3) CHECK (consistency BETWEEN 0 AND 1), + pre_thinking JSONB DEFAULT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() +); + +-- 4. Debate sessions table +CREATE TABLE IF NOT EXISTS debate_sessions ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE, + topic TEXT NOT NULL, + debate_messages JSONB NOT NULL DEFAULT '[]', + verdict_text TEXT, + created_at TIMESTAMPTZ DEFAULT NOW() +); + +-- 5. Detected mistakes table +CREATE TABLE IF NOT EXISTS detected_mistakes ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID REFERENCES messages(id) ON DELETE CASCADE, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + description TEXT NOT NULL, + severity mistake_severity NOT NULL DEFAULT 'medium', + created_at TIMESTAMPTZ DEFAULT NOW() +); + +-- 6. Chunk retrieval log table +CREATE TABLE IF NOT EXISTS chunk_retrieval_log ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID REFERENCES messages(id) ON DELETE CASCADE, + qdrant_chunk_id VARCHAR(255), + pdf_id VARCHAR(255), + similarity_score FLOAT, + quality_score FLOAT, + created_at TIMESTAMPTZ DEFAULT NOW() +); + +-- Indexes +CREATE INDEX IF NOT EXISTS idx_conversations_user_id ON conversations(user_id); +CREATE INDEX IF NOT EXISTS idx_messages_conversation_id_created ON messages(conversation_id, created_at); +CREATE INDEX IF NOT EXISTS idx_debate_sessions_user_id ON debate_sessions(user_id); +CREATE INDEX IF NOT EXISTS idx_debate_sessions_conversation_id ON debate_sessions(conversation_id); +CREATE INDEX IF NOT EXISTS idx_detected_mistakes_user_id ON detected_mistakes(user_id); +CREATE INDEX IF NOT EXISTS idx_chunk_retrieval_message_id ON chunk_retrieval_log(message_id); +""" + + +async def init_db(): + """Initialize the database: create all tables and indexes.""" + pool = await get_pool() + async with pool.acquire() as conn: + await conn.execute(CREATE_TABLES_SQL) + await _migrate_messages_pre_thinking(conn) + print("[postgres] Database initialized successfully.") + + +async def _column_exists(conn: asyncpg.Connection, table: str, column: str) -> bool: + return bool( + await conn.fetchval( + """ + SELECT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = current_schema() + AND table_name = $1 + AND column_name = $2 + ) + """, + table, + column, + ) + ) + + +async def _migrate_messages_pre_thinking(conn: asyncpg.Connection) -> None: + table_exists = bool( + await conn.fetchval( + """ + SELECT EXISTS ( + SELECT 1 + FROM information_schema.tables + WHERE table_schema = current_schema() + AND table_name = 'messages' + ) + """ + ) + ) + if not table_exists: + return + + has_pre_thinking = await _column_exists(conn, "messages", "pre_thinking") + has_what_happened = await _column_exists(conn, "messages", "what_happened") + + if has_what_happened and not has_pre_thinking: + await conn.execute( + "ALTER TABLE messages RENAME COLUMN what_happened TO pre_thinking" + ) + elif has_what_happened and has_pre_thinking: + await conn.execute( + """ + UPDATE messages + SET pre_thinking = COALESCE(pre_thinking, what_happened) + WHERE 
what_happened IS NOT NULL + """ + ) + await conn.execute("ALTER TABLE messages DROP COLUMN what_happened") + elif not has_pre_thinking: + await conn.execute("ALTER TABLE messages ADD COLUMN pre_thinking JSONB DEFAULT NULL") + + +# ─── Helper Functions ───────────────────────────────────────────────────────── + + +async def create_user( + email: str, password_hash: str, display_name: str | None = None +) -> str: + """Create a new user and return their UUID.""" + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + "INSERT INTO users (email, password_hash, display_name) VALUES ($1, $2, $3) RETURNING id", + email, + password_hash, + display_name, + ) + return str(row["id"]) + + +async def get_user_by_email(email: str) -> dict | None: + """Fetch a user by email. Returns dict or None.""" + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT id, email, password_hash, display_name FROM users WHERE email = $1", + email, + ) + if row: + return dict(row) + return None + + +async def get_user_by_id(user_id: str) -> dict | None: + """Fetch a user by UUID. Returns dict or None.""" + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT id, email, display_name FROM users WHERE id = $1", + uuid.UUID(user_id), + ) + if row: + return dict(row) + return None + + +async def create_conversation( + user_id: str, conv_type: str = "standard", title: str | None = None +) -> str: + """Create a new conversation and return its UUID.""" + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + "INSERT INTO conversations (user_id, type, title) VALUES ($1, $2, $3) RETURNING id", + uuid.UUID(user_id), + conv_type, + title, + ) + return str(row["id"]) + + +async def append_message( + conversation_id: str, + reasoning_mode: str, + user_content: str, + assistant_content: str, + confidence: float | None = None, + consistency: float | None = None, + pdfs: list[str] | None = None, + pre_thinking: dict | None = None, + tools: list[str] | None = None, +) -> str: + """ + Append a user+assistant pair as a new row in the messages table. + Each row represents one full query+response turn. + Returns the new message row UUID. 
+ """ + pool = await get_pool() + async with pool.acquire() as conn: + new_entry = {"user": user_content, "assistant": assistant_content} + if pdfs: + new_entry["pdfs"] = pdfs + if tools: + new_entry["tools"] = tools + row = await conn.fetchrow( + "INSERT INTO messages (conversation_id, reasoning_mode, message, confidence, consistency, pre_thinking) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id", + uuid.UUID(conversation_id), + reasoning_mode, + _json.dumps([new_entry]), + confidence, + consistency, + _json.dumps(pre_thinking) if pre_thinking else None, + ) + return str(row["id"]) + + +async def create_debate_session( + user_id: str, + conversation_id: str, + topic: str, + debate_messages: list[dict], + verdict_text: str | None = None, +) -> str: + """Persist a debate session and return its UUID.""" + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + "INSERT INTO debate_sessions (user_id, conversation_id, topic, debate_messages, verdict_text) VALUES ($1, $2, $3, $4, $5) RETURNING id", + uuid.UUID(user_id), + uuid.UUID(conversation_id), + topic, + _json.dumps(debate_messages), + verdict_text, + ) + return str(row["id"]) + + +async def get_debate_session_by_conversation_id(conversation_id: str, user_id: str) -> dict | None: + """Fetch a debate session by its conversation UUID for the authenticated user.""" + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + """ + SELECT id, user_id, conversation_id, topic, debate_messages, verdict_text, created_at + FROM debate_sessions + WHERE conversation_id = $1 AND user_id = $2 + """, + uuid.UUID(conversation_id), + uuid.UUID(user_id), + ) + if not row: + return None + + data = dict(row) + if isinstance(data.get("debate_messages"), str): + data["debate_messages"] = _json.loads(data["debate_messages"]) + return data + + +async def get_user_history(user_id: str) -> list[dict]: + """Get conversation metadata for a user, ordered by most recent.""" + pool = await get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch( + """ + SELECT + c.id AS conv_id, + c.type AS conv_type, + c.title, + c.created_at AS conv_created, + c.updated_at AS conv_updated + FROM conversations c + WHERE c.user_id = $1 + ORDER BY c.updated_at DESC + """, + uuid.UUID(user_id), + ) + + return [ + { + "id": str(row["conv_id"]), + "type": row["conv_type"], + "title": row["title"], + "created_at": row["conv_created"].isoformat() + if row["conv_created"] + else None, + "updated_at": row["conv_updated"].isoformat() + if row["conv_updated"] + else None, + "messages": [], + } + for row in rows + ] + + +async def update_conversation_timestamp(conversation_id: str): + """Update the updated_at timestamp for a conversation.""" + pool = await get_pool() + async with pool.acquire() as conn: + await conn.execute( + "UPDATE conversations SET updated_at = NOW() WHERE id = $1", + uuid.UUID(conversation_id), + ) + + +async def rename_conversation( + conversation_id: str, user_id: str, new_title: str +) -> bool: + """Rename a conversation title. Returns True if updated.""" + pool = await get_pool() + async with pool.acquire() as conn: + result = await conn.execute( + "UPDATE conversations SET title = $1, updated_at = NOW() WHERE id = $2 AND user_id = $3", + new_title, + uuid.UUID(conversation_id), + uuid.UUID(user_id), + ) + return result == "UPDATE 1" + + +async def delete_conversation(conversation_id: str, user_id: str) -> bool: + """Delete a conversation and all its messages. 
Returns True if deleted.""" + pool = await get_pool() + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM conversations WHERE id = $1 AND user_id = $2", + uuid.UUID(conversation_id), + uuid.UUID(user_id), + ) + return result == "DELETE 1" + + +async def clear_all_history(user_id: str) -> int: + """Delete all conversations for a user. Returns count deleted.""" + pool = await get_pool() + async with pool.acquire() as conn: + result = await conn.execute( + "DELETE FROM conversations WHERE user_id = $1", + uuid.UUID(user_id), + ) + return int(result.split()[-1]) if result.startswith("DELETE") else 0 + + +async def log_chunk_retrieval( + message_id: str | None, + qdrant_chunk_id: str, + pdf_id: str, + similarity_score: float, + quality_score: float | None = None, +) -> str: + """Log a chunk retrieval for analytics.""" + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + "INSERT INTO chunk_retrieval_log (message_id, qdrant_chunk_id, pdf_id, similarity_score, quality_score) VALUES ($1, $2, $3, $4, $5) RETURNING id", + uuid.UUID(message_id) if message_id else None, + qdrant_chunk_id, + pdf_id, + similarity_score, + quality_score, + ) + return str(row["id"]) + + +async def get_pdf_quality_scores(pdf_ids: list[str]) -> dict[str, float]: + """Get the average similarity score for a list of pdf_ids from chunk_retrieval_log.""" + if not pdf_ids: + return {} + pool = await get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch( + "SELECT pdf_id, AVG(similarity_score) as avg_score FROM chunk_retrieval_log WHERE pdf_id = ANY($1::text[]) GROUP BY pdf_id", + pdf_ids, + ) + return {row["pdf_id"]: float(row["avg_score"]) for row in rows} + + +async def get_conversation_with_messages(conversation_id: str, user_id: str) -> dict: + """Get a specific conversation and all its messages. 
Raises ValueError if not found.""" + pool = await get_pool() + async with pool.acquire() as conn: + conv = await conn.fetchrow( + "SELECT id, type, title FROM conversations WHERE id = $1 AND user_id = $2", + uuid.UUID(conversation_id), + uuid.UUID(user_id), + ) + if not conv: + raise ValueError("Conversation not found") + + msgs = await conn.fetch( + "SELECT id, reasoning_mode, message, confidence, consistency, pre_thinking, created_at FROM messages WHERE conversation_id = $1 ORDER BY created_at ASC", + uuid.UUID(conversation_id), + ) + + messages = [] + for m in msgs: + content = m["message"] + if isinstance(content, str): + content = _json.loads(content) + pre_thinking = m.get("pre_thinking") + if pre_thinking and not isinstance(pre_thinking, dict): + pre_thinking = _json.loads(pre_thinking) if pre_thinking else None + messages.append( + { + "id": str(m["id"]), + "reasoning_mode": m["reasoning_mode"], + "content": content, + "confidence": float(m["confidence"]) if m["confidence"] else None, + "consistency": float(m["consistency"]) + if m["consistency"] + else None, + "pre_thinking": pre_thinking, + "created_at": m["created_at"].isoformat() + if m["created_at"] + else None, + } + ) + + return { + "conversation": { + "id": str(conv["id"]), + "type": conv["type"], + "title": conv["title"], + }, + "messages": messages, + } + + +async def log_detected_mistakes( + message_id: str, user_id: str, mistakes: list[dict] +) -> None: + """Log serious mistakes to the detected_mistakes table.""" + if not mistakes: + return + pool = await get_pool() + async with pool.acquire() as conn: + values = [] + for m in mistakes: + sev = str(m.get("severity", "medium")).lower() + if sev not in ("low", "medium", "high"): + sev = "medium" + desc = str(m.get("description", "")) + if desc: + values.append((uuid.UUID(message_id), uuid.UUID(user_id), desc, sev)) + + if values: + await conn.executemany( + "INSERT INTO detected_mistakes (message_id, user_id, description, severity) VALUES ($1, $2, $3, $4)", + values, + ) diff --git a/repositories/qdrant_repo.py b/repositories/qdrant_repo.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb709e73fad95e94fb8422b9864d29e2374df5b --- /dev/null +++ b/repositories/qdrant_repo.py @@ -0,0 +1,333 @@ +import os +import time +import uuid +from typing import List + +from dotenv import load_dotenv +from qdrant_client import QdrantClient +from qdrant_client.models import ( + Distance, + FieldCondition, + Filter, + MatchValue, + PayloadSchemaType, + PointStruct, + VectorParams, +) +from sentence_transformers import SentenceTransformer + +load_dotenv() + +# ─── Embedding Model (cached) ───────────────────────────────────────────────── + +_embedding_model: SentenceTransformer | None = None + + +def get_embedding_model() -> SentenceTransformer: + """Get cached SentenceTransformer model.""" + global _embedding_model + if _embedding_model is None: + _embedding_model = SentenceTransformer("BAAI/bge-small-en-v1.5") + return _embedding_model + + +def get_embedding(text: str) -> list[float]: + """Generate a 384-dim embedding for the given text.""" + model = get_embedding_model() + embedding = model.encode(text, normalize_embeddings=True) + return embedding.tolist() + + +# ─── Qdrant Client ──────────────────────────────────────────────────────────── + +_qdrant_client: QdrantClient | None = None + + +def get_qdrant_client() -> QdrantClient: + """Get the global Qdrant client singleton.""" + global _qdrant_client + if _qdrant_client is None: + _qdrant_client = QdrantClient( + 
url=os.getenv("QDRANT_CLIENT"), + api_key=os.getenv("QDRANT_API_KEY"), + ) + return _qdrant_client + + +def init_qdrant_collections(): + """Create pdf_summary and pdf_chunks collections if they don't exist.""" + client = get_qdrant_client() + collections = [c.name for c in client.get_collections().collections] + + if "pdf_summary" not in collections: + client.create_collection( + collection_name="pdf_summary", + vectors_config=VectorParams(size=384, distance=Distance.COSINE), + ) + print("[qdrant] Created 'pdf_summary' collection.") + + try: + client.create_payload_index( + collection_name="pdf_summary", + field_name="user_id", + field_schema=PayloadSchemaType.KEYWORD, + ) + print("[qdrant] Created 'user_id' index on 'pdf_summary'.") + except Exception: + pass + + if "pdf_chunks" not in collections: + client.create_collection( + collection_name="pdf_chunks", + vectors_config=VectorParams(size=384, distance=Distance.COSINE), + ) + print("[qdrant] Created 'pdf_chunks' collection.") + + try: + client.create_payload_index( + collection_name="pdf_chunks", + field_name="user_id", + field_schema=PayloadSchemaType.KEYWORD, + ) + print("[qdrant] Created 'user_id' index on 'pdf_chunks'.") + except Exception: + pass + + # pdf_id index is required for filtering chunks by a specific PDF + try: + client.create_payload_index( + collection_name="pdf_chunks", + field_name="pdf_id", + field_schema=PayloadSchemaType.KEYWORD, + ) + print("[qdrant] Created 'pdf_id' index on 'pdf_chunks'.") + except Exception: + pass + + +# ─── PDF Summary CRUD ───────────────────────────────────────────────────────── + +def upsert_pdf_summary( + pdf_id: str, + user_id: str, + conversation_id: str | None, + doc_name: str, + doc_summary: str, + topic_tags: list[str], +) -> str: + """Store a PDF summary in the pdf_summary collection. Returns the point ID.""" + client = get_qdrant_client() + embedding = get_embedding(doc_summary) + + point_id = str(uuid.uuid4()) + payload = { + "pdf_id": pdf_id, + "user_id": user_id, + "conversation_id": conversation_id, + "doc_name": doc_name, + "doc_summary": doc_summary, + "topic_tags": topic_tags, + "created_at": time.time(), # Unix timestamp — used to find the most recently uploaded PDF + } + + client.upsert( + collection_name="pdf_summary", + points=[PointStruct(id=point_id, vector=embedding, payload=payload)], + ) + return point_id + + +def get_user_pdf_summaries(user_id: str) -> list[dict]: + """Get all PDF summaries for a given user.""" + client = get_qdrant_client() + + results = client.scroll( + collection_name="pdf_summary", + scroll_filter=Filter( + must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))] + ), + limit=100, + ) + + summaries = [] + for point in results[0]: + summaries.append({ + "id": str(point.id), + "pdf_id": point.payload.get("pdf_id"), + "doc_name": point.payload.get("doc_name"), + "doc_summary": point.payload.get("doc_summary"), + "topic_tags": point.payload.get("topic_tags", []), + "user_id": point.payload.get("user_id"), + "conversation_id": point.payload.get("conversation_id"), + }) + return summaries + + +# ─── PDF Chunks CRUD ────────────────────────────────────────────────────────── + +def upsert_pdf_chunks( + pdf_id: str, + user_id: str, + doc_name: str, + chunks: list[dict], +) -> int: + """ + Store PDF chunks in the pdf_chunks collection. + Each chunk dict should have: page_number, chunk_index, text_content. + Returns the number of chunks stored. 
+ """ + client = get_qdrant_client() + + points = [] + for chunk in chunks: + embedding = get_embedding(chunk["text_content"]) + point_id = str(uuid.uuid4()) + payload = { + "pdf_id": pdf_id, + "user_id": user_id, + "doc_name": doc_name, + "page_number": chunk.get("page_number", 0), + "chunk_index": chunk.get("chunk_index", 0), + "text_content": chunk["text_content"], + } + points.append(PointStruct(id=point_id, vector=embedding, payload=payload)) + + batch_size = 100 + for i in range(0, len(points), batch_size): + batch = points[i : i + batch_size] + client.upsert(collection_name="pdf_chunks", points=batch) + + return len(points) + + +def search_chunks(query: str, user_id: str, top_k: int = 3) -> list[dict]: + """Search for relevant chunks filtered by user_id.""" + client = get_qdrant_client() + embedding = get_embedding(query) + + results = client.query_points( + collection_name="pdf_chunks", + query=embedding, + query_filter=Filter( + must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))] + ), + limit=top_k, + with_payload=True, + ) + + chunks = [] + for hit in results.points: + chunks.append({ + "id": str(hit.id), + "text_content": hit.payload.get("text_content", ""), + "pdf_id": hit.payload.get("pdf_id"), + "doc_name": hit.payload.get("doc_name"), + "page_number": hit.payload.get("page_number"), + "similarity_score": hit.score, + }) + return chunks + + +def search_pdf_summary(query: str, user_id: str, top_k: int = 3) -> list[dict]: + """Search pdf_summary collection by query similarity for a given user. + Returns list of {pdf_id, doc_name, doc_summary, similarity_score}. + """ + client = get_qdrant_client() + embedding = get_embedding(query) + + results = client.query_points( + collection_name="pdf_summary", + query=embedding, + query_filter=Filter( + must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))] + ), + limit=top_k, + with_payload=True, + ) + + summaries = [] + for hit in results.points: + summaries.append({ + "id": str(hit.id), + "pdf_id": hit.payload.get("pdf_id"), + "doc_name": hit.payload.get("doc_name"), + "doc_summary": hit.payload.get("doc_summary"), + "topic_tags": hit.payload.get("topic_tags", []), + "similarity_score": hit.score, + }) + return summaries + + +def search_chunks_by_pdf_id(query: str, user_id: str, pdf_id: str, top_k: int = 5) -> list[dict]: + """Search pdf_chunks filtered by a specific pdf_id. Returns the most relevant chunks.""" + client = get_qdrant_client() + embedding = get_embedding(query) + + results = client.query_points( + collection_name="pdf_chunks", + query=embedding, + query_filter=Filter( + must=[ + FieldCondition(key="user_id", match=MatchValue(value=user_id)), + FieldCondition(key="pdf_id", match=MatchValue(value=pdf_id)), + ] + ), + limit=top_k, + with_payload=True, + ) + + chunks = [] + for hit in results.points: + chunks.append({ + "id": str(hit.id), + "text_content": hit.payload.get("text_content", ""), + "pdf_id": hit.payload.get("pdf_id"), + "doc_name": hit.payload.get("doc_name"), + "page_number": hit.payload.get("page_number"), + "similarity_score": hit.score, + }) + return chunks + + +def get_pdf_ids_by_names(doc_names: list[str], user_id: str) -> dict[str, str]: + """Look up pdf_ids by doc_name for a given user. 
Returns {doc_name: pdf_id}.""" + client = get_qdrant_client() + result: dict[str, str] = {} + + hits, _ = client.scroll( + collection_name="pdf_summary", + scroll_filter=Filter( + must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))] + ), + limit=200, + with_payload=True, + ) + for point in hits: + name = point.payload.get("doc_name", "") + if name in doc_names: + result[name] = point.payload.get("pdf_id", "") + + return result + + +def get_most_recent_pdf_id(user_id: str) -> str | None: + """Return the pdf_id of the most recently uploaded PDF for a given user. + Uses the created_at timestamp stored in the pdf_summary payload. + Falls back gracefully for older entries that lack the timestamp. + """ + client = get_qdrant_client() + + hits, _ = client.scroll( + collection_name="pdf_summary", + scroll_filter=Filter( + must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))] + ), + limit=200, + with_payload=True, + ) + if not hits: + return None + + # Sort descending by created_at; entries without the field get 0 (treated as oldest) + sorted_hits = sorted(hits, key=lambda p: p.payload.get("created_at", 0), reverse=True) + return sorted_hits[0].payload.get("pdf_id") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..63657db9e86fada6da0a5c2cdcbbc1dc16dbb3f7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +fastapi +uvicorn +asyncpg +python-jose[cryptography] +bcrypt +python-multipart +qdrant-client +pymupdf +langchain-text-splitters +langchain-core +langchain-groq +langchain-openai +groq +python-dotenv +autogen-agentchat~=0.4.0.dev10 +autogen-ext[openai]~=0.4.0.dev10 +httpx +sentence-transformers +langgraph +# Install CPU-only torch to keep the image size manageable +--extra-index-url https://download.pytorch.org/whl/cpu +torch diff --git a/routers/__init__.py b/routers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..478f1cec50981a014ad20cfbe8bafc0c12b6ed87 --- /dev/null +++ b/routers/__init__.py @@ -0,0 +1,8 @@ +# Routers module — re-exports all API routers +from routers.auth_router import router as auth_router +from routers.chat_router import router as chat_router +from routers.history_router import router as history_router +from routers.pdf_router import router as pdf_router +from routers.debate_router import router as debate_router +from routers.admin_router import router as admin_router +from routers.reflection_router import router as reflection_router \ No newline at end of file diff --git a/routers/admin_router.py b/routers/admin_router.py new file mode 100644 index 0000000000000000000000000000000000000000..7a9ad0d8fc5cd8c76975d91e1ca9860825717f61 --- /dev/null +++ b/routers/admin_router.py @@ -0,0 +1,181 @@ +from fastapi import APIRouter, HTTPException, Query +from repositories import get_pool + +router = APIRouter(prefix="/admin", tags=["admin"]) + + +@router.get("/tables") +async def list_tables(): + """List all tables and their row counts.""" + pool = await get_pool() + async with pool.acquire() as conn: + tables = [ + "users", + "conversations", + "messages", + "debate_sessions", + "detected_mistakes", + "chunk_retrieval_log", + ] + result = {} + for table in tables: + count = await conn.fetchval(f"SELECT COUNT(*) FROM {table}") + result[table] = count + return {"tables": result} + + +@router.get("/users") +async def view_users(limit: int = Query(50, le=200), offset: int = Query(0, ge=0)): + """View all users (passwords masked)."""
+ pool = await get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch( + "SELECT id, email, '***' as password_hash, display_name, created_at, updated_at FROM users ORDER BY created_at DESC LIMIT $1 OFFSET $2", + limit, offset, + ) + total = await conn.fetchval("SELECT COUNT(*) FROM users") + return {"total": total, "offset": offset, "limit": limit, "data": [dict(r) for r in rows]} + + +@router.get("/conversations") +async def view_conversations(limit: int = Query(50, le=200), offset: int = Query(0, ge=0)): + """View all conversations.""" + pool = await get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT c.id, c.user_id, u.email as user_email, c.type, c.title, c.created_at, c.updated_at + FROM conversations c LEFT JOIN users u ON c.user_id = u.id + ORDER BY c.updated_at DESC LIMIT $1 OFFSET $2""", + limit, offset, + ) + total = await conn.fetchval("SELECT COUNT(*) FROM conversations") + return {"total": total, "offset": offset, "limit": limit, "data": [dict(r) for r in rows]} + + +@router.get("/messages") +async def view_messages( + conversation_id: str | None = Query(None), + limit: int = Query(50, le=200), + offset: int = Query(0, ge=0), +): + """View messages, optionally filtered by conversation_id.""" + pool = await get_pool() + async with pool.acquire() as conn: + if conversation_id: + rows = await conn.fetch( + """SELECT m.id, m.conversation_id, m.reasoning_mode, m.message, m.confidence, m.consistency, m.pre_thinking, m.created_at + FROM messages m WHERE m.conversation_id = $1 + ORDER BY m.created_at ASC LIMIT $2 OFFSET $3""", + conversation_id, limit, offset, + ) + total = await conn.fetchval("SELECT COUNT(*) FROM messages WHERE conversation_id = $1", conversation_id) + else: + rows = await conn.fetch( + """SELECT m.id, m.conversation_id, m.reasoning_mode, m.message, m.confidence, m.consistency, m.pre_thinking, m.created_at + FROM messages m ORDER BY m.created_at DESC LIMIT $1 OFFSET $2""", + limit, offset, + ) + total = await conn.fetchval("SELECT COUNT(*) FROM messages") + return {"total": total, "offset": offset, "limit": limit, "data": [dict(r) for r in rows]} + + +@router.get("/messages/{message_id}") +async def view_message_detail(message_id: str): + """View full details of a single message including JSONB content.""" + import json as _json + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT * FROM messages WHERE id = $1", message_id, + ) + if not row: + raise HTTPException(status_code=404, detail="Message not found") + data = dict(row) + # Parse JSONB message content + if isinstance(data.get("message"), str): + data["message"] = _json.loads(data["message"]) + return data + + +@router.get("/debate-sessions") +async def view_debate_sessions(limit: int = Query(50, le=200), offset: int = Query(0, ge=0)): + """View all debate sessions.""" + pool = await get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT d.id, d.user_id, u.email as user_email, d.conversation_id, d.topic, + LENGTH(d.debate_messages::text) as messages_size, + LENGTH(d.verdict_text) as verdict_size, d.created_at + FROM debate_sessions d LEFT JOIN users u ON d.user_id = u.id + ORDER BY d.created_at DESC LIMIT $1 OFFSET $2""", + limit, offset, + ) + total = await conn.fetchval("SELECT COUNT(*) FROM debate_sessions") + return {"total": total, "offset": offset, "limit": limit, "data": [dict(r) for r in rows]} + + +@router.get("/debate-sessions/{session_id}") +async def 
view_debate_session_detail(session_id: str): + """View full details of a single debate session.""" + import json as _json + pool = await get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow("SELECT * FROM debate_sessions WHERE id = $1", session_id) + if not row: + raise HTTPException(status_code=404, detail="Debate session not found") + data = dict(row) + if isinstance(data.get("debate_messages"), str): + data["debate_messages"] = _json.loads(data["debate_messages"]) + return data + + +@router.get("/mistakes") +async def view_mistakes(limit: int = Query(50, le=200), offset: int = Query(0, ge=0)): + """View all detected mistakes.""" + pool = await get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT dm.id, dm.message_id, dm.user_id, u.email as user_email, + dm.description, dm.severity, dm.created_at + FROM detected_mistakes dm LEFT JOIN users u ON dm.user_id = u.id + ORDER BY dm.created_at DESC LIMIT $1 OFFSET $2""", + limit, offset, + ) + total = await conn.fetchval("SELECT COUNT(*) FROM detected_mistakes") + return {"total": total, "offset": offset, "limit": limit, "data": [dict(r) for r in rows]} + + +@router.get("/chunk-logs") +async def view_chunk_logs(limit: int = Query(50, le=200), offset: int = Query(0, ge=0)): + """View chunk retrieval logs.""" + pool = await get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT id, message_id, qdrant_chunk_id, pdf_id, similarity_score, quality_score, created_at + FROM chunk_retrieval_log ORDER BY created_at DESC LIMIT $1 OFFSET $2""", + limit, offset, + ) + total = await conn.fetchval("SELECT COUNT(*) FROM chunk_retrieval_log") + return {"total": total, "offset": offset, "limit": limit, "data": [dict(r) for r in rows]} + + +@router.get("/stats") +async def view_stats(): + """Get database statistics.""" + pool = await get_pool() + async with pool.acquire() as conn: + stats = {} + stats["users"] = await conn.fetchval("SELECT COUNT(*) FROM users") + stats["conversations"] = await conn.fetchval("SELECT COUNT(*) FROM conversations") + stats["conversations_standard"] = await conn.fetchval("SELECT COUNT(*) FROM conversations WHERE type = 'standard'") + stats["conversations_debate"] = await conn.fetchval("SELECT COUNT(*) FROM conversations WHERE type = 'debate'") + stats["messages"] = await conn.fetchval("SELECT COUNT(*) FROM messages") + stats["debate_sessions"] = await conn.fetchval("SELECT COUNT(*) FROM debate_sessions") + stats["detected_mistakes"] = await conn.fetchval("SELECT COUNT(*) FROM detected_mistakes") + stats["chunk_retrieval_logs"] = await conn.fetchval("SELECT COUNT(*) FROM chunk_retrieval_log") + + # Recent activity + recent_conv = await conn.fetchrow("SELECT created_at FROM conversations ORDER BY created_at DESC LIMIT 1") + stats["last_conversation_at"] = str(recent_conv["created_at"]) if recent_conv else None + + return stats diff --git a/routers/auth_router.py b/routers/auth_router.py new file mode 100644 index 0000000000000000000000000000000000000000..c57f7fe46873590d8e488eb8546f3f33dd34c1cb --- /dev/null +++ b/routers/auth_router.py @@ -0,0 +1,53 @@ +from fastapi import APIRouter, HTTPException +from core import hash_password, verify_password, create_token +from repositories import create_user, get_user_by_email +from schemas import RegisterRequest, LoginRequest + +router = APIRouter(prefix="/auth", tags=["auth"]) + + +@router.post("/register") +async def register(req: RegisterRequest): + """Register a new user and return a JWT token.""" + try: 
+ existing = await get_user_by_email(req.email) + if existing: + raise HTTPException(status_code=400, detail="Email already registered") + + password_hash = hash_password(req.password) + user_id = await create_user(req.email, password_hash, req.display_name) + token = create_token(user_id) + + return { + "token": token, + "user_id": user_id, + "display_name": req.display_name, + } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/login") +async def login(req: LoginRequest): + """Login with email and password, return a JWT token.""" + try: + user = await get_user_by_email(req.email) + if not user: + raise HTTPException(status_code=401, detail="Invalid email or password") + + if not verify_password(req.password, user["password_hash"]): + raise HTTPException(status_code=401, detail="Invalid email or password") + + token = create_token(str(user["id"])) + + return { + "token": token, + "user_id": str(user["id"]), + "display_name": user.get("display_name"), + } + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/routers/chat_router.py b/routers/chat_router.py new file mode 100644 index 0000000000000000000000000000000000000000..bfaadfc10aa78ebe5ff28922eca8a24ea8416924 --- /dev/null +++ b/routers/chat_router.py @@ -0,0 +1,550 @@ +import json +import logging +from fastapi import APIRouter, Depends +from fastapi.responses import StreamingResponse +from core import get_current_user + +from schemas import ( + QueryRequest, + TaskRequest, + SmartOrchestratorRequest +) + +from services import ( + run_tool_agent_stream_sse, + smart_orchestrator_stream, + get_conversation_memory_context_async, + run_deep_research_stream_with_state, + _to_non_empty_text, + add_to_memory +) + +from repositories import ( + create_conversation, + append_message, + update_conversation_timestamp, + log_chunk_retrieval, + log_detected_mistakes, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["chat"]) + + +@router.post("/chat/stream") +async def agent_query_stream( + req: QueryRequest, user_id: str = Depends(get_current_user) +): + """ + Standard chat endpoint with SSE streaming (3-phase: initial → tools → final). + + Flow: + 1. Load memory context (if conversation_id provided) + 2. Stream tool agent execution via SSE + 3. 
Persist message + chunk logs after streaming completes + """ + logger.info( + f"[chat_router:stream] Request from user={user_id}, conv_id={req.conversation_id}" + ) + + # ── Fetch prior conversation memory ─ + memory_context: str | None = None + if req.conversation_id: + logger.info( + f"[chat_router:stream] Fetching memory for conv_id={req.conversation_id}" + ) + memory_context = await get_conversation_memory_context_async( + req.conversation_id, user_id + ) + if memory_context: + logger.info( + f"[chat_router:stream] Memory context ready: {len(memory_context)} chars" + ) + else: + logger.info( + f"[chat_router:stream] No prior memory found for conv_id={req.conversation_id}" + ) + + async def event_generator(): + final_answer = "" + retrieved_chunks = [] + + try: + # Stream the tool agent execution + async for sse_line in run_tool_agent_stream_sse( + query=req.query, + user_id=user_id, + pdfs=req.pdfs, + memory_context=memory_context, + ): + # Parse to capture final data for DB persistence + try: + if sse_line.startswith("data: "): + evt = json.loads(sse_line[6:].strip()) + if evt.get("type") == "done": + final_answer = evt.get("answer", "") + tools_used = evt.get("tools_used", []) + retrieved_chunks = evt.get("retrieved_chunks", []) + elif evt.get("type") == "error": + logger.error( + f"[chat_router:stream] Stream error: {evt.get('message')}" + ) + except (json.JSONDecodeError, IndexError): + pass + yield sse_line + + except Exception as e: + logger.error(f"[chat_router:stream] Streaming error: {e}", exc_info=True) + yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n" + + # ── Persist to DB after streaming ─ + try: + conv_id = req.conversation_id + if not conv_id: + conv_id = await create_conversation( + user_id, "standard", req.query[:200] + ) + logger.info(f"[chat_router:stream] Created new conversation: {conv_id}") + + message_id = await append_message( + conversation_id=conv_id, + reasoning_mode="standard", + user_content=req.query, + assistant_content=final_answer, + pdfs=req.pdfs, + ) + await update_conversation_timestamp(conv_id) + logger.info( + f"[chat_router:stream] Persisted message {message_id} to conv {conv_id}" + ) + + # Log retrieved chunks + for chunk in retrieved_chunks: + try: + await log_chunk_retrieval( + message_id=message_id, + qdrant_chunk_id=chunk.get("id", ""), + pdf_id=chunk.get("pdf_id", ""), + similarity_score=chunk.get("similarity_score", 0.0), + quality_score=None, + ) + except Exception as log_err: + logger.warning( + f"[chat_router:stream] Chunk log failed for {chunk.get('id')}: {log_err}" + ) + + # Emit conversation_id to frontend + yield f"data: {json.dumps({'type': 'conversation_id', 'conversation_id': conv_id})}\n\n" + + # ── Update window memory ───────────────────────────── + if conv_id: + add_to_memory(conv_id, req.query, final_answer) + + except Exception as db_err: + logger.error( + f"[chat_router:stream] DB persist error: {db_err}", exc_info=True + ) + + yield 'data: {"type": "done"}\n\n' + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + + +@router.post("/deep_research/task") +async def deep_research_task(req: TaskRequest, user_id: str = Depends(get_current_user)): + """Run the multi-agent orchestrator (deep research) with SSE streaming.""" + + async def event_generator(): + final_result = "" + final_meta = {} + conv_id = req.conversation_id + pre_thinking = {"decomposition": "", "researcher1": "", "researcher2": ""} + 
aggregation_content = "" + + try: + # Load memory context + memory_context: str | None = None + if req.conversation_id: + memory_context = await get_conversation_memory_context_async(req.conversation_id, user_id) + + # Stream orchestrator events + async for event in run_deep_research_stream_with_state(req.task, memory_context=memory_context): + event_type = event.get("type", "") + + # Forward all events to frontend + yield f"data: {json.dumps(event)}\n\n" + + # Capture final result for DB persistence + if event_type == "final": + final_result = event.get("result", "") + final_meta = event.get("meta", {}) + elif event_type == "content_chunk": + section = event.get("section", "") + content = event.get("content", "") + if section == "decomposition": + pre_thinking["decomposition"] = content + elif section == "researcher_1": + pre_thinking["researcher1"] = content + elif section == "researcher_2": + pre_thinking["researcher2"] = content + elif section == "aggregation": + aggregation_content = content + + except Exception as e: + logger.error(f"[deep_research_router] Streaming error: {e}", exc_info=True) + yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n" + + # Persist to DB after streaming completes + try: + if not conv_id: + conv_id = await create_conversation(user_id, "standard", req.task[:200]) + + raw_conf = final_meta.get("confidence_score") + raw_cons = final_meta.get("logical_consistency") + deep_research_raw = final_meta.get("deep_research_raw", {}) + + assistant_content_to_store = ( + _to_non_empty_text(final_result) + or ( + _to_non_empty_text(deep_research_raw.get("final_result", "")) + if isinstance(deep_research_raw, dict) + else "" + ) + or _to_non_empty_text(aggregation_content) + or "Deep research completed." + ) + + # Only save pre_thinking if we have decomposition content + pre_thinking_data = ( + pre_thinking if pre_thinking.get("decomposition") else None + ) + + await append_message( + conversation_id=conv_id, + reasoning_mode="deep_research", + user_content=req.task, + assistant_content=assistant_content_to_store, + confidence=raw_conf / 100 if raw_conf is not None else None, + consistency=raw_cons / 100 if raw_cons is not None else None, + pre_thinking=pre_thinking_data, + tools=["Researcher_Agent1", "Researcher_Agent2", "Aggregator_Agent"], + ) + await update_conversation_timestamp(conv_id) + + # Emit conversation_id to frontend + yield f"data: {json.dumps({'type': 'conversation_id', 'conversation_id': conv_id})}\n\n" + + # Update window memory + if conv_id: + add_to_memory(conv_id, req.task, assistant_content_to_store) + + except Exception as db_err: + logger.error(f"[deep_research_router] DB persist error: {db_err}", exc_info=True) + + yield 'data: {"type": "done"}\n\n' + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + + +@router.post("/smart-orchestrator/stream") +async def smart_orchestrator_endpoint( + req: SmartOrchestratorRequest, user_id: str = Depends(get_current_user) +): + """Smart orchestrator SSE stream (routes to standard/deep_research/code).""" + final_result = "" + meta_info = {} + detected_path = "standard" + pre_thinking = {"decomposition": "", "researcher1": "", "researcher2": ""} + deep_research_aggregation = "" + code_pre_thinking = { + "problem_understanding": "", + "approach": "", + "agent_outputs": [], + "file_outputs": [], + } + code_final_marker = None + + def _to_non_empty_text(value) -> str: + if value is None: + 
return "" + if isinstance(value, str): + return value.strip() + if isinstance(value, (dict, list)): + try: + return json.dumps(value, ensure_ascii=False).strip() + except Exception: + return str(value).strip() + return str(value).strip() + + def _extract_code_complete_marker(value): + parsed_result = None + if isinstance(value, str) and value.strip(): + try: + parsed_result = json.loads(value) + except json.JSONDecodeError: + parsed_result = None + elif isinstance(value, dict): + parsed_result = value + + if isinstance(parsed_result, dict) and parsed_result.get("type") == "code_complete": + return parsed_result + return None + + def _build_code_approach_from_subtasks(subtasks) -> str: + if not isinstance(subtasks, list): + return "" + lines = [] + for subtask in subtasks: + if not isinstance(subtask, dict): + continue + desc = str(subtask.get("description", "")).strip() + if not desc: + continue + subtask_id = subtask.get("id") + if subtask_id is not None: + lines.append(f"- Agent {subtask_id}: {desc}") + else: + lines.append(f"- {desc}") + if not lines: + return "" + return "Parallel implementation plan:\n" + "\n".join(lines) + + def _build_code_assistant_content( + problem_understanding: str, + approach: str, + final_marker: dict | None, + ) -> str: + sections = [] + if problem_understanding: + sections.append(f"Problem Understanding:\n{problem_understanding}") + if approach: + sections.append(f"Approach:\n{approach}") + summary = "\n\n".join(sections).strip() + if summary: + return summary + + if isinstance(final_marker, dict): + file_count = final_marker.get("file_count") + filenames = final_marker.get("filenames", []) + if isinstance(file_count, int): + if isinstance(filenames, list) and filenames: + return ( + f"Code generation completed with {file_count} file(s): " + + ", ".join(str(name) for name in filenames) + ) + return f"Code generation completed with {file_count} file(s)." 
+ return "" + + async def event_generator(): + nonlocal final_result, meta_info, detected_path, pre_thinking, deep_research_aggregation, code_pre_thinking, code_final_marker + try: + async for chunk in smart_orchestrator_stream( + req.task, + conversation_id=req.conversation_id, + user_id=user_id, + ): + try: + if chunk.startswith("data: "): + evt = json.loads(chunk[6:].strip()) + evt_type = evt.get("type") + if evt_type == "final": + final_result = evt.get("result", "") + meta_info = evt.get("meta", {}) + if detected_path == "code": + parsed_result = _extract_code_complete_marker(final_result) + if isinstance(parsed_result, dict): + code_final_marker = parsed_result + elif evt_type == "route": + detected_path = evt.get("path", "standard") + elif evt_type == "content_chunk": + section = evt.get("section", "") + content = evt.get("content", "") + if section == "decomposition": + pre_thinking["decomposition"] = content + elif section == "researcher_1": + pre_thinking["researcher1"] = content + elif section == "researcher_2": + pre_thinking["researcher2"] = content + elif section == "aggregation": + deep_research_aggregation = content + elif evt_type == "code_section": + section = evt.get("section", "") + content = evt.get("content", "") + if section == "problem_understanding": + code_pre_thinking["problem_understanding"] = content + elif section == "approach": + code_pre_thinking["approach"] = content + elif evt_type == "plan" and detected_path == "code": + if not code_pre_thinking.get("approach"): + generated_approach = _build_code_approach_from_subtasks( + evt.get("subtasks", []) + ) + if generated_approach: + code_pre_thinking["approach"] = generated_approach + elif evt_type == "agent_output" and detected_path == "code": + code_pre_thinking["agent_outputs"].append( + { + "agent_id": evt.get("agent_id"), + "agent_name": evt.get("agent_name"), + "content": evt.get("content", ""), + } + ) + elif evt_type == "file_output" and detected_path == "code": + code_pre_thinking["file_outputs"].append( + { + "filename": evt.get("filename", ""), + "language": evt.get("language", "text"), + "index": evt.get("index"), + "total": evt.get("total"), + "content": evt.get("content", ""), + } + ) + except (json.JSONDecodeError, IndexError): + pass + yield chunk + except Exception as e: + yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n" + finally: + try: + reasoning = ( + "multi_agent" + if detected_path in ("deep_research", "code") + else "multi_agent" + ) + + conv_id = req.conversation_id + if not conv_id: + conv_id = await create_conversation( + user_id, "standard", req.task[:200] + ) + + raw_conf = meta_info.get("confidence_score") + raw_cons = meta_info.get("logical_consistency") + serious_mistakes = meta_info.get("serious_mistakes", []) + orchestrator_raw = meta_info.get("orchestrator_raw", {}) + orchestrator_raw_result = ( + _to_non_empty_text(orchestrator_raw.get("final_result", "")) + if isinstance(orchestrator_raw, dict) + else "" + ) + path_specific_final_result = _to_non_empty_text(final_result) + deep_research_chunk_content = _to_non_empty_text(deep_research_aggregation) + code_summary_content = "" + pre_thinking_data = None + + if detected_path == "deep_research": + pre_thinking_data = ( + pre_thinking if pre_thinking.get("decomposition") else None + ) + elif detected_path == "code": + parsed_result = _extract_code_complete_marker(final_result) + if code_final_marker is None and isinstance(parsed_result, dict): + code_final_marker = parsed_result + if isinstance(parsed_result, 
dict): + path_specific_final_result = "" + + problem_understanding = str( + code_pre_thinking.get("problem_understanding", "") + ).strip() + approach = str(code_pre_thinking.get("approach", "")).strip() + agent_outputs = code_pre_thinking.get("agent_outputs", []) + file_outputs = code_pre_thinking.get("file_outputs", []) + + code_summary_content = _build_code_assistant_content( + problem_understanding, + approach, + code_final_marker + if isinstance(code_final_marker, dict) + else None, + ) + + pre_thinking_data = { + "route_path": detected_path, + "agent_outputs": agent_outputs, + "file_outputs": file_outputs, + } + if isinstance(code_final_marker, dict): + pre_thinking_data["final_marker"] = code_final_marker + + if not ( + pre_thinking_data["agent_outputs"] + or pre_thinking_data["file_outputs"] + or pre_thinking_data.get("final_marker") + ): + pre_thinking_data = None + + # Build tools list based on detected path + tools = None + if detected_path == "deep_research": + tools = ["Researcher_Agent1", "Researcher_Agent2", "Aggregator_Agent"] + elif detected_path == "code": + tools = ["Code_Planner", "Coder_1", "Coder_2", "Coder_3", "Aggregator", "Reviewer"] + else: + # Standard mode - use tools from meta + tools = meta_info.get("tools_used", []) + if tools: + tools = [t if isinstance(t, str) else t.get("tool", str(t)) for t in tools] + + explicit_fallback_text = ( + "Deep research completed." + if detected_path == "deep_research" + else "Code generation completed." + if detected_path == "code" + else "Request completed." + ) + + assistant_content_to_store = ( + path_specific_final_result + or orchestrator_raw_result + or deep_research_chunk_content + or code_summary_content + or explicit_fallback_text + ) + + message_id = await append_message( + conversation_id=conv_id, + reasoning_mode=reasoning, + user_content=req.task, + assistant_content=assistant_content_to_store, + confidence=raw_conf / 100 if raw_conf is not None else None, + consistency=raw_cons / 100 if raw_cons is not None else None, + pdfs=req.pdfs, + pre_thinking=pre_thinking_data, + tools=tools, + ) + await update_conversation_timestamp(conv_id) + + if serious_mistakes: + try: + await log_detected_mistakes( + message_id, user_id, serious_mistakes + ) + except Exception as log_err: + print(f"[chat_router] Mistake log error: {log_err}") + + # Emit the final conversation_id to the frontend + yield f"data: {json.dumps({'type': 'conversation_id', 'conversation_id': conv_id})}\n\n" + + # ── Update window memory ───────────────────────────── + if conv_id: + add_to_memory(conv_id, req.task, assistant_content_to_store) + + except Exception as db_err: + print(f"[chat_router] DB persist error: {db_err}") + + yield 'data: {"type": "done"}\n\n' + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) diff --git a/routers/debate_router.py b/routers/debate_router.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2a029c6fac9cfa916ce19a3abfe21122ec12ff --- /dev/null +++ b/routers/debate_router.py @@ -0,0 +1,74 @@ +import json +from fastapi import APIRouter, Depends +from fastapi.responses import StreamingResponse +from core import get_current_user + +from services import ( + run_debate_stream, + structure_debate_rounds +) +from repositories import ( + create_conversation, + create_debate_session, + get_debate_session_by_conversation_id, + update_conversation_timestamp +) + +router = APIRouter(tags=["debate"]) + + 
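+# SSE event shapes used by the stream below (a sketch inferred from this handler, not a formal contract): +# messages from run_debate_stream() are forwarded as-is, with {"type": "verdict", "content": ...} carrying +# the final verdict; {"type": "conversation_id", ...} and {"type": "done"} close the stream, and +# {"type": "error", "message": ...} reports failures.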
+@router.get("/debate/stream") +async def debate_stream( + topic: str, + rounds: int = 3, + mode: str = "autogen", + user_id: str = Depends(get_current_user) +): + """Stream a debate between two agents via SSE.""" + debate_events = [] + + async def event_generator(): + nonlocal debate_events + try: + async for msg in run_debate_stream(topic, rounds, mode): + debate_events.append(msg) + yield f"data: {json.dumps(msg)}\n\n" + + # Persist debate session after all rounds complete + conv_id = await create_conversation(user_id, "debate", topic[:200]) + + structured_rounds = structure_debate_rounds(debate_events) + + verdict_text = None + for msg in debate_events: + if msg.get("type") == "verdict": + verdict_text = msg.get("content", "") + + await create_debate_session( + user_id=user_id, + conversation_id=conv_id, + topic=topic, + debate_messages=structured_rounds, + verdict_text=verdict_text, + ) + await update_conversation_timestamp(conv_id) + + # Emit conversation_id for this debate session + yield f"data: {json.dumps({'type': 'conversation_id', 'conversation_id': conv_id})}\n\n" + except ValueError as ve: + yield f"data: {json.dumps({'type': 'error', 'message': str(ve)})}\n\n" + except Exception as db_err: + print(f"[debate_router] DB persist error: {db_err}") + + yield "data: {\"type\": \"done\"}\n\n" + + return StreamingResponse(event_generator(), media_type="text/event-stream") + + +@router.get("/debate/session/{conversation_id}") +async def get_debate_session(conversation_id: str, user_id: str = Depends(get_current_user)): + """Return a saved debate session for history replay.""" + session = await get_debate_session_by_conversation_id(conversation_id, user_id) + if not session: + return {"session": None} + return {"session": session} diff --git a/routers/history_router.py b/routers/history_router.py new file mode 100644 index 0000000000000000000000000000000000000000..a583edc6879b193bf2797c273ad75d0f76d20ff9 --- /dev/null +++ b/routers/history_router.py @@ -0,0 +1,90 @@ +from fastapi import APIRouter, HTTPException, Depends +from pydantic import BaseModel +from core import get_current_user +from repositories import ( + get_user_history, + rename_conversation, + delete_conversation, + clear_all_history, + get_conversation_with_messages, +) +from services import ( + clear_conversation_memory, + get_all_conversations +) + +from schemas import ( + RenameRequest +) + +router = APIRouter(tags=["history"]) + + +@router.get("/history") +async def get_history(user_id: str = Depends(get_current_user)): + """Get all conversations with messages for the authenticated user.""" + try: + history = await get_user_history(user_id) + return {"conversations": history} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.put("/history/{conversation_id}") +async def rename_conv( + conversation_id: str, req: RenameRequest, user_id: str = Depends(get_current_user) +): + """Rename a conversation title.""" + try: + ok = await rename_conversation(conversation_id, user_id, req.title) + if not ok: + raise HTTPException(status_code=404, detail="Conversation not found") + return {"status": "ok"} + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/history/{conversation_id}") +async def delete_conv(conversation_id: str, user_id: str = Depends(get_current_user)): + """Delete a conversation and all its messages.""" + try: + ok = await delete_conversation(conversation_id, user_id) + if not ok: + raise 
HTTPException(status_code=404, detail="Conversation not found") + # Clear from window memory + clear_conversation_memory(conversation_id) + return {"status": "ok"} + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/history") +async def clear_history(user_id: str = Depends(get_current_user)): + """Clear all conversation history for the user.""" + try: + count = await clear_all_history(user_id) + # Clear all window memories for this user (clear all since all convs deleted) + all_convs = get_all_conversations() + for conv_id in all_convs: + clear_conversation_memory(conv_id) + return {"status": "ok", "deleted": count} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/history/{conversation_id}") +async def get_conversation_messages( + conversation_id: str, user_id: str = Depends(get_current_user) +): + """Get all messages for a specific conversation.""" + try: + result = await get_conversation_with_messages(conversation_id, user_id) + return result + except ValueError: + raise HTTPException(status_code=404, detail="Conversation not found") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/routers/pdf_router.py b/routers/pdf_router.py new file mode 100644 index 0000000000000000000000000000000000000000..f0cb143bffcfc2f2a019141c89b3409c4addf691 --- /dev/null +++ b/routers/pdf_router.py @@ -0,0 +1,86 @@ +import os +import shutil +import tempfile +from fastapi import APIRouter, UploadFile, File, Depends +from core import get_current_user +from utils.pdf_processor import process_pdfs +from repositories import ( + create_conversation, + update_conversation_timestamp, + get_user_pdf_summaries, + get_pdf_quality_scores +) + +router = APIRouter(tags=["pdf"]) + + +@router.post("/upload-pdf") +async def upload_pdfs( + files: list[UploadFile] = File(...), + conversation_id: str | None = None, + user_id: str = Depends(get_current_user), +): + """ + Accepts multiple PDF files, saves them temporarily, + processes them through pdf_processor, and stores in Qdrant. + """ + temp_dir = tempfile.mkdtemp() + file_paths = [] + + try: + for file in files: + if not file.filename: + continue + temp_path = os.path.join(temp_dir, file.filename) + with open(temp_path, "wb") as buffer: + content = await file.read() + buffer.write(content) + file_paths.append(temp_path) + + conv_id = conversation_id + if not conv_id: + conv_id = await create_conversation(user_id, "standard", f"PDF Upload: {len(files)} file(s)") + + results = await process_pdfs(file_paths, user_id, conversation_id=conv_id) + await update_conversation_timestamp(conv_id) + + return { + "status": "success", + "processed": results.get("total_chunks", 0), + "details": results, + "conversation_id": conv_id, + } + except Exception as e: + print(f"[pdf_router] PDF processing error: {e}") + return { + "status": "partial", + "processed": 0, + "details": {"error": str(e), "note": "Qdrant may be unavailable. 
PDFs processed but not stored."}, + } + finally: + shutil.rmtree(temp_dir, ignore_errors=True) + + +@router.get("/memory/pdfs") +async def get_memory_pdfs(user_id: str = Depends(get_current_user)): + """Get all PDF summaries for the authenticated user.""" + try: + summaries = get_user_pdf_summaries(user_id) + + if summaries: + pdf_ids = [s.get("pdf_id") for s in summaries if s.get("pdf_id")] + if pdf_ids: + scores = await get_pdf_quality_scores(pdf_ids) + for s in summaries: + pid = s.get("pdf_id") + if pid and pid in scores: + # Normalize to a percentage if it's cosine similarity (-1 to 1 or 0 to 1) + # We'll just pass it as is or multiply by 100 on the frontend, let's keep it 0-100 format if it was 0-1 + # If average similarity is 0.85, maybe we want it as 85. + avg = scores[pid] + s["quality_score"] = int(avg * 100) if 0 <= avg <= 1 else int(avg) + + return {"pdfs": summaries} + except Exception as e: + print(f"[pdf_router] Qdrant unavailable, returning empty list: {e}") + return {"pdfs": []} diff --git a/routers/reflection_router.py b/routers/reflection_router.py new file mode 100644 index 0000000000000000000000000000000000000000..18386456e05b298f6dc1d3f7370904772fa13cbc --- /dev/null +++ b/routers/reflection_router.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import uuid + +from fastapi import APIRouter, Depends, HTTPException + +from core import get_current_user +from repositories import get_pool + +router = APIRouter(prefix="/reflection", tags=["reflection"]) + + +def _clamp_score(value: float, minimum: int = 0, maximum: int = 100) -> int: + return max(minimum, min(maximum, round(value))) + + +def _normalize_to_percent(value: float | None) -> float | None: + if value is None: + return None + return value * 100 if value <= 1 else value + + +def _default_improvement(severity: str) -> str: + if severity == "high": + return "Escalated to stricter validation and cross-checking gates." + if severity == "medium": + return "Added additional reasoning checks before final response." + return "Applied lightweight post-response self-review for similar prompts." + + +def _default_strategy(severity: str) -> str: + if severity == "high": + return "Require tool-backed evidence and second-pass verification for claims." + if severity == "medium": + return "Increase consistency checks on multi-step reasoning chains." + return "Track pattern frequency and auto-flag repeated low-severity slips." 
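+# A quick illustration of the scoring conventions above. Stored averages may be
+# cosine-similarity style values in [0, 1] or already-scaled percentages, so the
+# helpers normalize everything onto a 0-100 scale (the same convention pdf_router
+# uses for quality_score):
+#   _normalize_to_percent(0.75)  -> 75.0   # 0-1 input is scaled up
+#   _normalize_to_percent(87.0)  -> 87.0   # an existing percentage passes through
+#   _normalize_to_percent(None)  -> None   # missing data stays missing
+#   _clamp_score(104.6)          -> 100    # rounded, then clamped into [0, 100]
+#   _clamp_score(-3)             -> 0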
+ + +def _build_radar( + confidence_score: int, + logical_consistency: int, + factual_reliability: int, + self_correction_triggered: bool, + issue_count: int, + high_severity_count: int, +) -> list[dict]: + adaptation = _clamp_score( + 72 + + (14 if self_correction_triggered else 4) + - (high_severity_count * 4) + - min(issue_count, 6) + ) + return [ + {"metric": "Planning", "value": logical_consistency}, + {"metric": "Reasoning", "value": confidence_score}, + {"metric": "Verification", "value": factual_reliability}, + {"metric": "Adaptation", "value": adaptation}, + {"metric": "Confidence", "value": confidence_score}, + ] + + +@router.get("") +@router.get("/summary") +async def get_reflection_summary(user_id: str = Depends(get_current_user)): + """ + User-scoped reflection summary: + - confidence / logical consistency from messages table (all user conversations) + - reflection report issues from detected_mistakes + """ + try: + pool = await get_pool() + async with pool.acquire() as conn: + quality_row = await conn.fetchrow( + """ + SELECT + AVG(m.confidence) AS avg_confidence, + AVG(m.consistency) AS avg_consistency, + COUNT(m.id) AS message_count + FROM messages m + JOIN conversations c ON c.id = m.conversation_id + WHERE c.user_id = $1 + """, + uuid.UUID(user_id), + ) + + mistakes = await conn.fetch( + """ + SELECT id, description, severity + FROM detected_mistakes + WHERE user_id = $1 + ORDER BY created_at DESC + LIMIT 200 + """, + uuid.UUID(user_id), + ) + except Exception as exc: + raise HTTPException(status_code=500, detail=str(exc)) + + avg_confidence = _normalize_to_percent( + float(quality_row["avg_confidence"]) + if quality_row and quality_row["avg_confidence"] is not None + else None + ) + avg_consistency = _normalize_to_percent( + float(quality_row["avg_consistency"]) + if quality_row and quality_row["avg_consistency"] is not None + else None + ) + + confidence_score = _clamp_score(avg_confidence if avg_confidence is not None else 82) + logical_consistency = _clamp_score(avg_consistency if avg_consistency is not None else 84) + + severity_count = {"high": 0, "medium": 0, "low": 0} + issues = [] + for idx, row in enumerate(mistakes): + severity = str(row["severity"] or "medium").lower() + if severity not in severity_count: + severity = "medium" + severity_count[severity] += 1 + issues.append( + { + "id": str(row["id"]), + "issue": row["description"] or "Detected reasoning issue.", + "improvement": _default_improvement(severity), + "strategy": _default_strategy(severity), + "severity": severity, + } + ) + + message_count = int(quality_row["message_count"]) if quality_row and quality_row["message_count"] else 0 + issue_penalty = ( + severity_count["high"] * 12 + + severity_count["medium"] * 7 + + severity_count["low"] * 4 + ) + density_penalty = ( + min(15, round((len(issues) / message_count) * 30)) if message_count > 0 else 0 + ) + factual_reliability = _clamp_score(92 - issue_penalty - density_penalty, 35, 98) + self_correction_triggered = len(issues) > 0 + + return { + "scores": { + "confidenceScore": confidence_score, + "logicalConsistency": logical_consistency, + "factualReliability": factual_reliability, + "selfCorrectionTriggered": self_correction_triggered, + }, + "radarData": _build_radar( + confidence_score=confidence_score, + logical_consistency=logical_consistency, + factual_reliability=factual_reliability, + self_correction_triggered=self_correction_triggered, + issue_count=len(issues), + high_severity_count=severity_count["high"], + ), + "issues": issues, + } diff 
--git a/schemas/__init__.py b/schemas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9221bbe7e291f4563ce9e2135d13f44e6bc2ba35 --- /dev/null +++ b/schemas/__init__.py @@ -0,0 +1,13 @@ +from schemas.schema import ( + RegisterRequest, + LoginRequest, + QueryRequest, + TaskRequest, + SmartOrchestratorRequest, + AgentState, + OrchestratorState, + CodeModeState, + CodingSubtask, + CodingAgentState, + RenameRequest +) \ No newline at end of file diff --git a/schemas/schema.py b/schemas/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..495e19b74a49b189d46e40799f18d86a866c01e5 --- /dev/null +++ b/schemas/schema.py @@ -0,0 +1,86 @@ +import operator +from pydantic import BaseModel +from langchain_core.messages import BaseMessage +from typing import TypedDict, Annotated, Sequence, List, Optional + +class RegisterRequest(BaseModel): + email: str + password: str + display_name: str | None = None + + +class LoginRequest(BaseModel): + email: str + password: str + +class QueryRequest(BaseModel): + query: str + conversation_id: str | None = None + pdfs: list[str] | None = None + +class DebateRequest(BaseModel): + topic: str + rounds: int = 3 + +class TaskRequest(BaseModel): + task: str + conversation_id: str | None = None + +class SmartOrchestratorRequest(BaseModel): + task: str + conversation_id: str | None = None + pdfs: list[str] | None = None + +class AgentState(TypedDict): + messages: Annotated[Sequence[BaseMessage], operator.add] + +class RenameRequest(BaseModel): + title: str + +class OrchestratorState(TypedDict): + original_task: str + subtasks: List[dict] # [{id, description, agent_type, result}] + current_subtask_index: int + final_result: str + step_logs: List[str] + critic_confidence: int + critic_logical_consistency: int + critic_feedback: str + serious_mistakes: List[dict] + +class CodeModeState(TypedDict): + original_task: str + plan: str + code_results: List[dict] # [{agent_id, code}] + final_code: str + review_feedback: str + confidence_score: int + consistency_score: int + step_logs: List[str] + serious_mistakes: List[dict] + graph_nodes: List[dict] + graph_edges: List[dict] + + +class CodingSubtask(TypedDict): + id: int + description: str + signatures: List[str] + result: Optional[str] + + +class CodingAgentState(TypedDict): + original_task: str + subtasks: List[CodingSubtask] + shared_contract: str + coder_results: List[str] + merged_code: str + review_errors: List[str] + retry_count: int + confidence_score: int + logical_consistency: int + critic_feedback: str + final_output: str + parsed_files: List[dict] # [{"filename": str, "content": str, "language": str}] + step_logs: List[str] + diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..794f194461597d9b6656736e904816a7bb74ce36 --- /dev/null +++ b/services/__init__.py @@ -0,0 +1,41 @@ +# Services module — re-exports service-level orchestration functions +from services.agent_service import ( + run_tool_agent, + run_tool_agent_stream_sse +) +from services.deep_research_service import ( + run_deep_research, + run_deep_research_stream_with_state, + _to_non_empty_text +) +from services.debate_service import ( + run_debate_stream, + run_debate_stream_raw, + run_debate_stream_autogen, + structure_debate_rounds +) +from services.smart_orchestrator_service import ( + smart_orchestrator_stream +) +from services.rag_service import ( + run_smart_chat +) +from services.memory_service import ( + 
get_conversation_memory_context, + get_conversation_memory_context_async, + add_to_memory, + clear_conversation_memory, + get_all_conversations +) +from services.context_injector import ( + inject_memory_context, + has_memory_context, + log_context_injection +) +from services.base_stream_service import ( + format_sse_event, + yield_sse_events, + yield_sse_with_error_handling, + create_sse_event +) + diff --git a/services/agent_service.py b/services/agent_service.py new file mode 100644 index 0000000000000000000000000000000000000000..0ccefa497c8d2b99a3fe60091eb3410f301a784f --- /dev/null +++ b/services/agent_service.py @@ -0,0 +1,234 @@ +import json +import logging +from typing import AsyncGenerator +import utils.tools as _tools_module +from agents import get_tool_agent_graph, run_tool_agent_stream +from schemas.schema import AgentState +from langchain_core.messages import HumanMessage, AIMessage +from services.memory_service import format_memory_block + +logger = logging.getLogger(__name__) + +# ─── System instructions for tool agent ────────── +_SYSTEM_PROMPT = """ +You have access to the following tools: +- knowledge_retriever: Use this tool to answer questions about uploaded PDF documents. Call it when user asks about document content. +- calculator: Use for mathematical calculations. +- get_current_datetime: Use when user asks about current time or date. + +Instructions: +- Only call tools when necessary +- If you know the answer, respond directly without calling tools +- Do not guess or make up information +- When calling knowledge_retriever, include relevant search query +""" + + +async def run_tool_agent( + query: str, + user_id: str = "default_user", + pdfs: list[str] | None = None, + memory_context: str | None = None, +) -> dict: + """ + Run the tool-calling agent and return answer + tool usage info. + + Returns a dict with: + answer : str + tools_used : list[dict] + message_count : int + retrieved_chunks: list[dict] ← all chunks buffered during this call, + to be logged by the caller after append_message + returns the real message_id. + """ + # ── Set per-request context; clear chunk buffer from any previous request ─ + _tools_module.CURRENT_USER_ID = user_id + _tools_module.RETRIEVED_CHUNKS_BUFFER.clear() + + # ── Build the message content (inject prior history if available) ───────── + if memory_context: + logger.info( + f"[agent_service] Injecting memory context ({len(memory_context)} chars) into tool agent prompt." + ) + message_content = format_memory_block(memory_context) + query + _SYSTEM_PROMPT + else: + logger.info( + "[agent_service] No memory context — running tool agent without history." 
+ ) + message_content = query + _SYSTEM_PROMPT + + try: + graph = get_tool_agent_graph() + initial_state: AgentState = { + "messages": [HumanMessage(content=message_content)] + } + result = graph.invoke(initial_state) + messages = result["messages"] + + tools_used = [] + final_answer = "" + + for msg in messages: + if hasattr(msg, "tool_calls") and msg.tool_calls: + for tc in msg.tool_calls: + tools_used.append({"tool": tc["name"], "args": tc["args"]}) + if ( + isinstance(msg, AIMessage) + and msg.content + and not (hasattr(msg, "tool_calls") and msg.tool_calls) + ): + final_answer = msg.content + + # Snapshot the buffer — caller uses this after append_message to log with correct message_id + retrieved_chunks = list(_tools_module.RETRIEVED_CHUNKS_BUFFER) + + return { + "answer": final_answer, + "tools_used": tools_used, + "message_count": len(messages), + "retrieved_chunks": retrieved_chunks, + } + except Exception as e: + logger.error(f"[agent_service] Error: {e}", exc_info=True) + raise + finally: + # Always clean up — prevents chunks from leaking into the next request + _tools_module.CURRENT_USER_ID = "default_user" + _tools_module.RETRIEVED_CHUNKS_BUFFER.clear() + + +async def run_tool_agent_stream_sse( + query: str, + user_id: str = "default_user", + pdfs: list[str] | None = None, + memory_context: str | None = None, +) -> AsyncGenerator[str, None]: + """ + Run the tool-calling agent with SSE streaming (3-phase: initial → tools → final). + + Tokens are batched (~50ms windows) to reduce SSE event overhead + and produce smooth, low-latency streaming on the frontend. + + Yields SSE-formatted event strings: + data: {"type": "token", "content": "...", "phase": "initial"|"final"} + data: {"type": "tool_start", "tool_name": "...", "tool_args": {...}} + data: {"type": "tool_end", "tool_name": "...", "tool_output": "..."} + data: {"type": "done", "answer": "...", "tools_used": [...], "retrieved_chunks": [...]} + """ + # ── Set per-request context; clear chunk buffer ─ + _tools_module.CURRENT_USER_ID = user_id + _tools_module.RETRIEVED_CHUNKS_BUFFER.clear() + + # ── Build message content ─ + if memory_context: + logger.info( + f"[agent_service:stream] Injecting memory context ({len(memory_context)} chars)" + ) + message_content = format_memory_block(memory_context) + query + _SYSTEM_PROMPT + else: + logger.info("[agent_service:stream] No memory context — fresh run") + message_content = query + _SYSTEM_PROMPT + + messages = [HumanMessage(content=message_content)] + final_answer = "" + tools_used = [] + retrieved_chunks = [] + + def _sse(event: dict) -> str: + return f"data: {json.dumps(event)}\n\n" + + # Token batching: accumulate tokens and flush periodically + token_buffer: dict[str, str] = {"initial": "", "final": ""} + last_flush = 0.0 + FLUSH_INTERVAL = 0.025 # 25ms — smooth without overwhelming the network + + import time + + async def _flush_tokens(force: bool = False): + nonlocal last_flush + now = time.monotonic() + if not force and (now - last_flush) < FLUSH_INTERVAL: + return + for phase in ("initial", "final"): + if token_buffer[phase]: + yield _sse( + { + "type": "token", + "content": token_buffer[phase], + "phase": phase, + } + ) + token_buffer[phase] = "" + last_flush = now + + try: + logger.info("[agent_service:stream] Starting 3-phase streaming execution...") + + async for event in run_tool_agent_stream(messages): + evt_type = event.get("type") + + if evt_type == "token": + # Accumulate into phase buffer instead of yielding immediately + phase = event.get("phase", "initial") + 
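+                # Flush policy: a phase buffer is emitted once it grows past ~300 chars,
+                # right before any tool call, and again when the final answer completes,
+                # so the client sees a few larger chunks rather than one event per token.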
token_buffer[phase] += event["content"] + # Flush if buffer is getting large (>300 chars) + if len(token_buffer[phase]) > 300: + async for sse_line in _flush_tokens(force=True): + yield sse_line + + elif evt_type == "tool_start": + # Flush any pending tokens before tool execution + async for sse_line in _flush_tokens(force=True): + yield sse_line + logger.info( + f"[agent_service:stream] Tool starting: {event['tool_name']}" + ) + yield _sse( + { + "type": "tool_start", + "tool_name": event["tool_name"], + "tool_args": event["tool_args"], + } + ) + + elif evt_type == "tool_end": + logger.info( + f"[agent_service:stream] Tool completed: {event['tool_name']}" + ) + yield _sse( + { + "type": "tool_end", + "tool_name": event["tool_name"], + "tool_output": event["tool_output"], + } + ) + + elif evt_type == "complete": + # Flush remaining tokens + async for sse_line in _flush_tokens(force=True): + yield sse_line + final_answer = event["answer"] + tools_used = event["tools_used"] + retrieved_chunks = list(_tools_module.RETRIEVED_CHUNKS_BUFFER) + + logger.info( + f"[agent_service:stream] Complete. Answer: {len(final_answer)} chars, " + f"Tools: {len(tools_used)}, Chunks: {len(retrieved_chunks)}" + ) + + yield _sse( + { + "type": "done", + "answer": final_answer, + "tools_used": tools_used, + "retrieved_chunks": retrieved_chunks, + } + ) + + except Exception as e: + logger.error(f"[agent_service:stream] Error: {e}", exc_info=True) + yield _sse({"type": "error", "message": str(e)}) + raise + finally: + _tools_module.CURRENT_USER_ID = "default_user" + _tools_module.RETRIEVED_CHUNKS_BUFFER.clear() diff --git a/services/base_stream_service.py b/services/base_stream_service.py new file mode 100644 index 0000000000000000000000000000000000000000..0037449f49aa3996239eea411b8b4ba29f80f8b9 --- /dev/null +++ b/services/base_stream_service.py @@ -0,0 +1,111 @@ +""" +Base streaming service with shared SSE (Server-Sent Events) utilities. + +Provides consistent patterns for: +- SSE event formatting +- Event yielding and streaming +- Error handling in streams +- Event queue management + +Single source of truth for server-sent event generation across all services. +""" + +import json +from typing import Any, AsyncGenerator, Callable + + +def format_sse_event(data: dict[str, Any]) -> str: + """ + Format a dict as SSE (Server-Sent Events) message. + + Args: + data: Dict to serialize as JSON event + + Returns: + SSE-formatted string: "data: {json}\n\n" + + Example: + event = {"type": "token", "content": "hello"} + sse_message = format_sse_event(event) + # Returns: 'data: {"type": "token", "content": "hello"}\n\n' + """ + return f"data: {json.dumps(data)}\n\n" + + +async def yield_sse_events( + event_source: AsyncGenerator[dict[str, Any], None], +) -> AsyncGenerator[str, None]: + """ + Convert dict events to SSE format. 
+ + Args: + event_source: Async generator yielding event dicts + + Yields: + SSE-formatted event strings + + Example: + async def my_events(): + yield {"type": "start"} + yield {"type": "token", "content": "hello"} + yield {"type": "done"} + + async for sse_msg in yield_sse_events(my_events()): + await send(sse_msg) + """ + try: + async for event in event_source: + yield format_sse_event(event) + except Exception as exc: + yield format_sse_event({ + "type": "error", + "message": str(exc), + }) + + +async def yield_sse_with_error_handling( + event_source: AsyncGenerator[dict[str, Any], None], + error_logger: Callable[[str], None] | None = None, +) -> AsyncGenerator[str, None]: + """ + Convert dict events to SSE format with error handling and logging. + + Args: + event_source: Async generator yielding event dicts + error_logger: Optional callback to log errors + + Yields: + SSE-formatted event strings + """ + try: + async for event in event_source: + yield format_sse_event(event) + except Exception as exc: + error_msg = str(exc) + if error_logger: + error_logger(f"SSE stream error: {error_msg}") + yield format_sse_event({ + "type": "error", + "message": error_msg, + }) + + +def create_sse_event( + event_type: str, + **kwargs: Any, +) -> dict[str, Any]: + """ + Create a properly-structured SSE event dict. + + Args: + event_type: Type of event (e.g., "token", "done", "error") + **kwargs: Additional fields to include + + Returns: + Dict ready to be formatted as SSE + + Example: + event = create_sse_event("token", content="hello", phase="initial") + sse_msg = format_sse_event(event) + """ + return {"type": event_type, **kwargs} diff --git a/services/context_injector.py b/services/context_injector.py new file mode 100644 index 0000000000000000000000000000000000000000..017cebe5585322eb502993f3bbe7cbad6a7f2247 --- /dev/null +++ b/services/context_injector.py @@ -0,0 +1,94 @@ +""" +Context injection utilities for memory and history management. + +Provides consistent patterns for: +- Checking if memory/context is available +- Logging context injection +- Formatting and prepending context to tasks/prompts +- Context size tracking for monitoring + +Single source of truth for memory injection patterns used across all services. +""" + +import logging +from services.memory_service import format_memory_block + +logger = logging.getLogger(__name__) + + +def inject_memory_context( + primary_content: str, + memory_context: str | None = None, + service_name: str = "agent", + context_type: str = "memory", +) -> str: + """ + Inject memory/context into primary content with consistent logging. + + Args: + primary_content: Main content (task, query, prompt) + memory_context: Optional prior history/context to prepend + service_name: Name of calling service for logging (e.g., "agent_service", "orchestrator") + context_type: Type of context for logging clarity (e.g., "memory", "history", "prior_conversation") + + Returns: + Combined content with memory prepended if available, otherwise primary_content as-is + + Example: + effective_task = inject_memory_context( + primary_content=task, + memory_context=user_history, + service_name="coding_service", + context_type="conversation" + ) + """ + if memory_context: + logger.info( + f"[{service_name}] Injecting {context_type} context ({len(memory_context)} chars) into task." + ) + return format_memory_block(memory_context) + primary_content + else: + logger.info( + f"[{service_name}] No {context_type} context available — proceeding without prior history." 
+ ) + return primary_content + + +def has_memory_context(memory_context: str | None) -> bool: + """ + Check if memory context is available (non-None and non-empty). + + Args: + memory_context: Potential context string + + Returns: + True if context exists and is non-empty, False otherwise + """ + return bool(memory_context and memory_context.strip()) + + +def log_context_injection( + service_name: str, + context_size: int | None = None, + context_type: str = "memory", + action: str = "injecting", +) -> None: + """ + Log context injection with standardized format. + + Args: + service_name: Name of calling service + context_size: Optional size in bytes/chars for monitoring + context_type: Type of context (memory, history, etc.) + action: Action being taken (injecting, skipping, loading, etc.) + + Example: + log_context_injection("orchestrator_service", 245, "conversation", "injecting") + # Output: [orchestrator_service] Injecting 245 chars of conversation context + """ + if context_size: + logger.info( + f"[{service_name}] {action.capitalize()} {context_size} chars of {context_type} context" + ) + else: + logger.info(f"[{service_name}] {action.capitalize()} {context_type} context") diff --git a/services/debate_service.py b/services/debate_service.py new file mode 100644 index 0000000000000000000000000000000000000000..e6c2069cc303612f01e219d2d78341e550c01d33 --- /dev/null +++ b/services/debate_service.py @@ -0,0 +1,324 @@ +import asyncio +import inspect +from typing import Any, AsyncGenerator + +from autogen_core import CancellationToken +from autogen_agentchat.messages import TextMessage + +from core import get_client +from agents import ( + create_critic_agent, + create_proposer_agent, + create_verifier_agent, + get_agent_a_persona, + get_agent_b_persona, + get_debate_model, +) + + +def _phase_for_round(round_num: int, total_rounds: int) -> str: + if round_num == 1: + return "opening statement" + if round_num == total_rounds: + return "closing rebuttal" + return "rebuttal round" + + +def _format_transcript(rounds_data: dict[int, dict[str, str]]) -> str: + transcript_lines: list[str] = [] + for rnd in sorted(rounds_data.keys()): + proposer = rounds_data[rnd].get("proposer", "").strip() + critic = rounds_data[rnd].get("critic", "").strip() + transcript_lines.append(f"Round {rnd} - FOR: {proposer}") + transcript_lines.append(f"Round {rnd} - AGAINST: {critic}") + return "\n".join(transcript_lines) + + +def _content_to_text(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + chunks: list[str] = [] + for item in content: + if isinstance(item, str): + chunks.append(item) + elif isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + chunks.append(text) + return "\n".join(part for part in chunks if part.strip()) + raise ValueError(f"Unsupported AutoGen message content type: {type(content)}") + + +async def _agent_reply(agent: Any, prompt: str) -> str: + response = await agent.on_messages( + [TextMessage(content=prompt, source="user")], + cancellation_token=CancellationToken(), + ) + chat_message = getattr(response, "chat_message", None) + if chat_message is None: + raise ValueError("AutoGen response did not include chat_message.") + + text = _content_to_text(getattr(chat_message, "content", "")) + cleaned = text.strip() + if not cleaned: + raise ValueError("AutoGen response content is empty.") + return cleaned + + +async def _close_agent_model_client(agent: Any) -> None: + model_client = getattr(agent, "model_client", None) 
+ if model_client is None: + return + + close_fn = getattr(model_client, "close", None) + if close_fn is None: + return + + result = close_fn() + if inspect.isawaitable(result): + await result + + +async def run_debate_stream_raw(topic: str, rounds: int = 3) -> AsyncGenerator[dict, None]: + """Stream a debate using raw Groq API (original implementation).""" + groq_client = get_client() + model = get_debate_model() + persona_a = get_agent_a_persona() + persona_b = get_agent_b_persona() + + history_a = [] + history_b = [] + + yield {"type": "info", "message": f"🎭 Debate started: '{topic}' | {rounds} rounds"} + await asyncio.sleep(0.2) + + last_a_message = "" + last_b_message = "" + + for round_num in range(1, rounds + 1): + yield {"type": "round", "round": round_num, "total_rounds": rounds} + await asyncio.sleep(0.1) + + # Agent A (Proposer) + if round_num == 1: + user_msg_a = f'The debate topic is: "{topic}". You are arguing FOR this position. Give your opening statement.' + else: + user_msg_a = f'Agent Critic said: "{last_b_message}"\nContinue the debate. Counter their argument and strengthen your position. Round {round_num} of {rounds}.' + + history_a.append({"role": "user", "content": user_msg_a}) + response_a = groq_client.chat.completions.create( + model=model, + messages=[{"role": "system", "content": persona_a}] + history_a, + temperature=0.7, + ) + msg_a = response_a.choices[0].message.content + history_a.append({"role": "assistant", "content": msg_a}) + last_a_message = msg_a + + yield { + "type": "message", + "agent": "Agent Proposer", + "agent_id": "A", + "position": "FOR", + "round": round_num, + "content": msg_a, + } + await asyncio.sleep(0.1) + + # Agent B (Critic) + if round_num == 1: + user_msg_b = f'The debate topic is: "{topic}". You are arguing AGAINST this position. Give your opening statement.' + else: + user_msg_b = f'Agent Proposer said: "{last_a_message}"\nRespond to their argument and reinforce your position. Round {round_num} of {rounds}.' + + history_b.append({"role": "user", "content": user_msg_b}) + response_b = groq_client.chat.completions.create( + model=model, + messages=[{"role": "system", "content": persona_b}] + history_b, + temperature=0.7, + ) + msg_b = response_b.choices[0].message.content + history_b.append({"role": "assistant", "content": msg_b}) + last_b_message = msg_b + + yield { + "type": "message", + "agent": "Agent Critic", + "agent_id": "B", + "position": "AGAINST", + "round": round_num, + "content": msg_b, + } + await asyncio.sleep(0.2) + + # Verdict + yield {"type": "info", "message": "⚖️ Generating debate summary..."} + await asyncio.sleep(0.01) + + all_debate = "" + for i, (ha, hb) in enumerate(zip( + [m for m in history_a if m["role"] == "assistant"], + [m for m in history_b if m["role"] == "assistant"], + )): + all_debate += f"\nRound {i+1} - FOR: {ha['content']}\nRound {i+1} - AGAINST: {hb['content']}\n" + + verdict_response = groq_client.chat.completions.create( + model=model, + messages=[{ + "role": "user", + "content": f"""Topic: "{topic}" + +Debate transcript: +{all_debate} + +As an impartial judge, summarize the key arguments made by both sides and provide a balanced verdict on who made stronger arguments and why. 
Be concise (6-7 sentences).""" + }], + temperature=0.3, + ) + + verdict = verdict_response.choices[0].message.content + + yield { + "type": "verdict", + "content": verdict, + } + + +async def run_debate_stream_autogen(topic: str, rounds: int = 3) -> AsyncGenerator[dict, None]: + """Stream a debate between Agent Proposer and Agent Critic using AutoGen.""" + proposer_agent = create_proposer_agent() + critic_agent = create_critic_agent() + verifier_agent = create_verifier_agent() + + rounds_data: dict[int, dict[str, str]] = {} + + yield {"type": "info", "message": f"🎭 Debate started: '{topic}' | {rounds} rounds"} + await asyncio.sleep(0.2) + + try: + for round_num in range(1, rounds + 1): + phase = _phase_for_round(round_num, rounds) + yield {"type": "round", "round": round_num, "total_rounds": rounds} + await asyncio.sleep(0.1) + + transcript = _format_transcript(rounds_data) + proposer_prompt = ( + f"Debate topic: \"{topic}\"\n" + f"Current phase: {phase}\n" + "Role: You must argue FOR the topic.\n" + "Rules: Provide one clear claim, one support, and one direct rebuttal " + "to the strongest opposing point seen so far.\n" + f"Transcript so far:\n{transcript if transcript else 'No prior rounds.'}\n" + "Respond in 3-4 sentences." + ) + + proposer_msg = await _agent_reply(proposer_agent, proposer_prompt) + if round_num not in rounds_data: + rounds_data[round_num] = {"proposer": "", "critic": ""} + rounds_data[round_num]["proposer"] = proposer_msg + + yield { + "type": "message", + "agent": "Agent Proposer", + "agent_id": "A", + "position": "FOR", + "round": round_num, + "content": proposer_msg, + } + await asyncio.sleep(0.4) + + transcript = _format_transcript(rounds_data) + critic_prompt = ( + f"Debate topic: \"{topic}\"\n" + f"Current phase: {phase}\n" + "Role: You must argue AGAINST the topic.\n" + "Rules: Address Proposer's latest claim directly, expose one weakness, " + "and present one counter-claim with support.\n" + f"Transcript so far:\n{transcript}\n" + "Respond in 3-4 sentences." + ) + + critic_msg = await _agent_reply(critic_agent, critic_prompt) + rounds_data[round_num]["critic"] = critic_msg + + yield { + "type": "message", + "agent": "Agent Critic", + "agent_id": "B", + "position": "AGAINST", + "round": round_num, + "content": critic_msg, + } + await asyncio.sleep(0.2) + + yield {"type": "info", "message": "⚖️ Generating debate summary..."} + await asyncio.sleep(0.01) + + final_transcript = _format_transcript(rounds_data) + verdict_prompt = ( + f"Topic: \"{topic}\"\n" + "You are the impartial verifier. Evaluate the debate below.\n" + "Scoring criteria: argument strength, evidence quality, and logical consistency.\n" + "Output requirements: 6-7 sentences, balanced summary, and explicit winner with reason.\n" + f"Debate transcript:\n{final_transcript}" + ) + verdict = await _agent_reply(verifier_agent, verdict_prompt) + + yield { + "type": "verdict", + "content": verdict, + } + finally: + await _close_agent_model_client(proposer_agent) + await _close_agent_model_client(critic_agent) + await _close_agent_model_client(verifier_agent) + + +async def run_debate_stream( + topic: str, + rounds: int = 3, + mode: str = "autogen" +) -> AsyncGenerator[dict, None]: + """ + Route to appropriate debate implementation based on mode. 
+ + Args: + topic: Debate topic + rounds: Number of debate rounds + mode: "raw" for original Groq API, "autogen" for AutoGen orchestration + + Yields: + SSE events (same format for both modes) + + Raises: + ValueError: If mode is not "raw" or "autogen" + """ + if mode.lower() not in ("raw", "autogen"): + raise ValueError(f"Invalid debate mode: {mode}. Must be 'raw' or 'autogen'.") + + if mode.lower() == "raw": + async for event in run_debate_stream_raw(topic, rounds): + yield event + else: # autogen + async for event in run_debate_stream_autogen(topic, rounds): + yield event + + + +def structure_debate_rounds(debate_events: list[dict]) -> list[dict]: + """ + Convert raw debate events into structured rounds format: + [{"proposer": "...", "critic": "..."}, ...] + """ + rounds_data: dict[int, dict] = {} + for evt in debate_events: + if evt.get("type") == "message": + rnd = evt.get("round", 0) + if rnd not in rounds_data: + rounds_data[rnd] = {"proposer": "", "critic": ""} + if evt.get("agent_id") == "A": + rounds_data[rnd]["proposer"] = evt.get("content", "") + elif evt.get("agent_id") == "B": + rounds_data[rnd]["critic"] = evt.get("content", "") + return [rounds_data[k] for k in sorted(rounds_data.keys())] diff --git a/services/deep_research_service.py b/services/deep_research_service.py new file mode 100644 index 0000000000000000000000000000000000000000..972fb5437a730c80023d87662a617ec97c0b90cf --- /dev/null +++ b/services/deep_research_service.py @@ -0,0 +1,314 @@ +import json +from typing import AsyncGenerator +from agents import get_deep_research_graph +from schemas.schema import OrchestratorState +from services.memory_service import format_memory_block + +def _to_non_empty_text(value) -> str: + if value is None: + return "" + if isinstance(value, str): + return value.strip() + if isinstance(value, (dict, list)): + try: + return json.dumps(value, ensure_ascii=False).strip() + except Exception: + return str(value).strip() + return str(value).strip() + +async def run_deep_research(task: str, memory_context: str | None = None) -> dict: + """Run the deep_research pipeline: planner → parallel researchers → aggregator → critic. + + If memory_context is provided (prior conversation history), it is prepended to the + task so the deep_research planner has awareness of the ongoing thread. 
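+    Returns a dict with the aggregated final_result, the per-researcher subtasks
+    (including each subtask's result), step_logs, and the critic's confidence,
+    logical-consistency score, feedback, and any serious_mistakes.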
+ """ + try: + graph = get_deep_research_graph() + + # Inject prior conversation context into the task + effective_task = task + if memory_context: + print(f"[deep_research_service] Injecting memory context ({len(memory_context)} chars) into deep_research task.") + effective_task = format_memory_block(memory_context) + task + else: + print("[deep_research_service] No memory context — running deep_research without history.") + + initial_state: OrchestratorState = { + "original_task": effective_task, + "subtasks": [], + "current_subtask_index": 0, + "final_result": "", + "step_logs": [], + "critic_confidence": 0, + "critic_logical_consistency": 0, + "critic_feedback": "", + "serious_mistakes": [], + } + result = await graph.ainvoke(initial_state) + + return { + "final_result": result["final_result"], + "subtasks": [ + { + "id": st["id"], + "description": st["description"], + "agent_type": st["agent_type"], + "result": st.get("result", ""), + } + for st in result["subtasks"] + ], + "step_logs": result.get("step_logs", []), + "critic_confidence": result.get("critic_confidence", 85), + "critic_logical_consistency": result.get("critic_logical_consistency", 85), + "critic_feedback": result.get("critic_feedback", ""), + "serious_mistakes": result.get("serious_mistakes", []), + } + except Exception as e: + print(f"[deep_research_service] Error: {e}") + raise + + +async def run_deep_research_stream( + task: str, + memory_context: str | None = None +) -> AsyncGenerator[dict, None]: + """Stream the deep_research pipeline phase by phase. + + Yields events as each node completes: + - {"type": "plan", "subtasks": [...]} — after deep_research creates subtasks + - {"type": "content_chunk", "section": "researcher_1", "content": "..."} — researcher 1 result + - {"type": "content_chunk", "section": "researcher_2", "content": "..."} — researcher 2 result + - {"type": "content_chunk", "section": "aggregation", "content": "..."} — aggregator result + - {"type": "final", "result": "...", "meta": {...}} — final result with critic scores + """ + try: + graph = get_deep_research_graph() + + # Inject prior conversation context into the task + effective_task = task + if memory_context: + print(f"[deep_research_service:stream] Injecting memory context ({len(memory_context)} chars)") + effective_task = format_memory_block(memory_context) + task + else: + print("[deep_research_service:stream] No memory context — running deep_research without history.") + + initial_state: OrchestratorState = { + "original_task": effective_task, + "subtasks": [], + "current_subtask_index": 0, + "final_result": "", + "step_logs": [], + "critic_confidence": 0, + "critic_logical_consistency": 0, + "critic_feedback": "", + "serious_mistakes": [], + } + + # Use astream to get events as each node completes + # LangGraph astream yields {node_name: state_update} dicts + async for event in graph.astream(initial_state, stream_mode="updates"): + # event is a dict like {"deep_research": {...}, "parallel_researchers": {...}, etc.} + for node_name, node_output in event.items(): + if node_name == "deep_research": + # Subtasks created — emit plan event + subtasks = node_output.get("subtasks", []) + yield { + "type": "plan", + "subtasks": [ + { + "id": st["id"], + "description": st["description"], + "agent_type": st["agent_type"], + } + for st in subtasks + ], + } + # Also emit content_chunk for decomposition + subtask_descriptions = "\n".join( + f"- **Researcher {st['id']}:** {st['description']}" + for st in subtasks + if st["agent_type"] == "researcher" 
+ ) + yield { + "type": "content_chunk", + "section": "decomposition", + "content": subtask_descriptions, + } + + elif node_name == "parallel_researchers": + # Researchers completed — emit each researcher's result + subtasks = node_output.get("subtasks", []) + researcher_idx = 0 + for st in subtasks: + if st["agent_type"] == "researcher": + researcher_idx += 1 + section_name = f"researcher_{researcher_idx}" + yield { + "type": "content_chunk", + "section": section_name, + "content": st.get("result", ""), + } + + elif node_name == "aggregator": + # Aggregator completed — emit synthesis + final_result = node_output.get("final_result", "") + yield { + "type": "content_chunk", + "section": "aggregation", + "content": final_result, + } + + elif node_name == "critic": + # Critic completed — emit final result with meta + confidence = node_output.get("critic_confidence", 85) + consistency = node_output.get("critic_logical_consistency", 85) + feedback = node_output.get("critic_feedback", "") + serious_mistakes = node_output.get("serious_mistakes", []) + + # Get the final_result from the state (aggregator set it) + # We need to track state across nodes, so we'll emit what we have + yield { + "type": "critic_done", + "meta": { + "confidence_score": confidence, + "logical_consistency": consistency, + "critic_feedback": feedback, + "serious_mistakes": serious_mistakes, + }, + } + + except Exception as e: + print(f"[deep_research_service:stream] Error: {e}") + yield {"type": "error", "message": str(e)} + + +async def run_deep_research_stream_with_state( + task: str, + memory_context: str | None = None +) -> AsyncGenerator[dict, None]: + """Stream the deep_research pipeline, tracking full state across nodes. + + This version accumulates state so we can emit the complete final_result + when the critic finishes. 
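+    Yields the same plan / content_chunk events as run_deep_research_stream, plus a
+    closing {"type": "final", ...} event that carries the aggregated result and the
+    critic's scores once the pipeline finishes.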
+ """ + try: + graph = get_deep_research_graph() + + # Inject prior conversation context into the task + effective_task = task + if memory_context: + print(f"[deep_research_service:stream] Injecting memory context ({len(memory_context)} chars)") + effective_task = format_memory_block(memory_context) + task + else: + print("[deep_research_service:stream] No memory context — running deep_research without history.") + + initial_state: OrchestratorState = { + "original_task": effective_task, + "subtasks": [], + "current_subtask_index": 0, + "final_result": "", + "step_logs": [], + "critic_confidence": 0, + "critic_logical_consistency": 0, + "critic_feedback": "", + "serious_mistakes": [], + } + + # Track accumulated state + accumulated_state = dict(initial_state) + + # Use astream with mode="values" to get full state after each node + async for event in graph.astream(initial_state, stream_mode="values"): + # event is the full state after each node update + subtasks = event.get("subtasks", []) + final_result = event.get("final_result", "") + critic_confidence = event.get("critic_confidence", 0) + critic_logical_consistency = event.get("critic_logical_consistency", 0) + critic_feedback = event.get("critic_feedback", "") + serious_mistakes = event.get("serious_mistakes", []) + step_logs = event.get("step_logs", []) + + # Determine which node just completed by checking step_logs + last_log = step_logs[-1] if step_logs else "" + + if "Deep Research" in last_log and ("created" in last_log or "decomposed" in last_log): + # Subtasks created — emit plan event + yield { + "type": "plan", + "subtasks": [ + { + "id": st["id"], + "description": st["description"], + "agent_type": st["agent_type"], + } + for st in subtasks + ], + } + # Also emit content_chunk for decomposition + subtask_descriptions = "\n".join( + f"- **Researcher {st['id']}:** {st['description']}" + for st in subtasks + if st["agent_type"] == "researcher" + ) + yield { + "type": "content_chunk", + "section": "decomposition", + "content": subtask_descriptions, + } + + elif "researchers completed" in last_log: + # Researchers completed — emit each researcher's result + researcher_idx = 0 + for st in subtasks: + if st["agent_type"] == "researcher": + researcher_idx += 1 + section_name = f"researcher_{researcher_idx}" + yield { + "type": "content_chunk", + "section": section_name, + "content": st.get("result", ""), + } + + elif "Aggregator agent synthesized" in last_log: + # Aggregator completed — emit synthesis + yield { + "type": "content_chunk", + "section": "aggregation", + "content": final_result, + } + + elif "Critic agent evaluated" in last_log: + # Critic completed — emit final result with meta + yield { + "type": "final", + "result": final_result, + "meta": { + "confidence_score": critic_confidence, + "logical_consistency": critic_logical_consistency, + "critic_feedback": critic_feedback, + "serious_mistakes": serious_mistakes, + "retry_count": 0, + "tools_used": [ + st.get("agent_type", "") + "Agent" for st in subtasks + ], + "deep_research_raw": { + "subtasks": [ + { + "id": st["id"], + "description": st["description"], + "agent_type": st["agent_type"], + "result": st.get("result", ""), + } + for st in subtasks + ], + "final_result": final_result, + "critic_confidence": critic_confidence, + "critic_logical_consistency": critic_logical_consistency, + "critic_feedback": critic_feedback, + }, + }, + } + + except Exception as e: + print(f"[deep_research_service:stream] Error: {e}") + yield {"type": "error", "message": str(e)} diff --git 
a/services/memory_service.py b/services/memory_service.py new file mode 100644 index 0000000000000000000000000000000000000000..b8c3cb3ebae017fc37e3efbf4e98a6ea02464b52 --- /dev/null +++ b/services/memory_service.py @@ -0,0 +1,122 @@ +""" +memory_service.py + +Provides conversation window memory context for all agent pipelines. + +Uses a global in-memory store to keep the last 4 message turns per conversation. +This avoids calling the database on each request and is much faster than +LLM-based summarization. + +The store is updated after each query/response via add_to_memory(). +""" + +from typing import Dict, List, Optional, Tuple +import threading + +# Global in-memory store: {conversation_id: [(user_msg, assistant_msg), ...]} +# Keeps last 4 turns per conversation +_conversation_memory: Dict[str, List[Tuple[str, str]]] = {} +_memory_lock = threading.Lock() + +# Window size - number of message turns to keep +_WINDOW_SIZE = 4 + + +def add_to_memory(conversation_id: str, user_message: str, assistant_message: str) -> None: + """ + Add a new message turn to the conversation memory. + Keeps only the last 4 turns in memory. + + Called after each query/response to keep the memory updated. + """ + global _conversation_memory + + with _memory_lock: + if conversation_id not in _conversation_memory: + _conversation_memory[conversation_id] = [] + + # Add new turn + _conversation_memory[conversation_id].append((user_message, assistant_message)) + + # Keep only last WINDOW_SIZE turns + if len(_conversation_memory[conversation_id]) > _WINDOW_SIZE: + _conversation_memory[conversation_id] = _conversation_memory[conversation_id][-(_WINDOW_SIZE):] + + +def get_conversation_memory_context( + conversation_id: str, + user_id: str, +) -> str | None: + """ + Get the conversation memory context for a given conversation. + Returns a formatted string with the last 4 message turns. + + This is a synchronous function that reads from the in-memory store. + No DB calls - much faster than summary-based memory. + + Returns None if no conversation memory exists. + """ + with _memory_lock: + turns = _conversation_memory.get(conversation_id, []) + + if not turns: + return None + + # Format as simple conversation history + formatted_turns = [] + for i, (user_msg, assistant_msg) in enumerate(turns, 1): + formatted_turns.append(f"Turn {i}:\nUser: {user_msg}\nAssistant: {assistant_msg}") + + history_str = "\n\n".join(formatted_turns) + + print( + f"[memory_service] Window memory: {len(turns)} turns " + f"for conversation {conversation_id}" + ) + + return history_str + + +async def get_conversation_memory_context_async( + conversation_id: str, + user_id: str, +) -> str | None: + """ + Async wrapper for get_conversation_memory_context. + For compatibility with existing async code. + """ + return get_conversation_memory_context(conversation_id, user_id) + + +def clear_conversation_memory(conversation_id: str) -> None: + """ + Clear memory for a specific conversation. + Called when a conversation is deleted. + """ + global _conversation_memory + + with _memory_lock: + if conversation_id in _conversation_memory: + del _conversation_memory[conversation_id] + + +def get_all_conversations() -> List[str]: + """ + Get list of all conversation IDs currently in memory. + """ + with _memory_lock: + return list(_conversation_memory.keys()) + + +def format_memory_block(memory_context: str) -> str: + """ + Wrap the memory context string in a clear delimiter block + that explicitly instructs the agent to treat this as background context only. 
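+    The history is wrapped between "--- RECENT CONVERSATION CONTEXT ---" and
+    "--- END OF CONTEXT ---" markers so callers can prepend it to a prompt verbatim.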
+ """ + return ( + "\n--- RECENT CONVERSATION CONTEXT ---\n" + "[SYSTEM INSTRUCTION: The following is the recent conversation history. " + "Use this as context for the current query.]\n\n" + f"{memory_context}\n" + "--- END OF CONTEXT ---\n\n" + ) \ No newline at end of file diff --git a/services/rag_service.py b/services/rag_service.py new file mode 100644 index 0000000000000000000000000000000000000000..9d80ceedf2a596abccf1a761e2de2b072286a9e5 --- /dev/null +++ b/services/rag_service.py @@ -0,0 +1,126 @@ +""" +PDF Context Resolver (rag_service.py) +====================================== +Determines *which* PDF (if any) is relevant to a user query. +Returns a pdf_id string or None — the actual answering is done by run_tool_agent. + +Classification Flow +------------------- +1. If explicit pdf_names were passed (files attached to the message): + → Resolve to pdf_id directly (skip classification). + +2. Otherwise, ask the LLM to classify the query into one of two classes: + - "last" : The user explicitly references "this pdf / this resume / this document / + the uploaded file" — i.e., they mean whatever was most recently uploaded. + - "try" : Any other query. We'll attempt a cosine similarity search on pdf_summary + to see if any stored PDF is relevant (score threshold: 0.6). + +3. Based on class: + - "last" → Return pdf_id of the most recently uploaded PDF (by created_at). + - "try" → Run cosine similarity on pdf_summary; return best match pdf_id if ≥ 0.6, + otherwise return None (no PDF context needed). +""" + +from core.llm_engine import get_llm +from repositories import ( + search_pdf_summary, + get_pdf_ids_by_names, + get_most_recent_pdf_id, +) +from langchain_core.messages import HumanMessage, SystemMessage + +# Minimum similarity score to consider a PDF relevant under "try" class +PDF_SIMILARITY_THRESHOLD = 0.6 + +# ─── LLM Classification Prompt ──────────────────────────────────────────────── + +_CLASSIFIER_SYSTEM = """You are a query classifier. Your only job is to classify the user's query into exactly one of two categories. + +Categories: +- "last" → The user is explicitly referring to a specific uploaded document using phrases like: + "this pdf", "this resume", "this cv", "this document", "this file", "the uploaded pdf", + "above resume", "this report", "explain this", "analyze this", "what is in this". + Any query where "this" clearly refers to an uploaded file = "last". + +- "try" → All other queries. This includes general knowledge questions, coding questions, + math, current events, or any question that does NOT explicitly point to an uploaded file. + +Rules: +- Respond with ONLY the single word: last OR try +- No punctuation, no explanation, no other words. +- If unsure, default to "try".""" + + +async def run_smart_chat( + query: str, + user_id: str, + pdf_names: list[str] | None = None, +) -> str | None: + """ + Classify the query and resolve the correct pdf_id to use for retrieval. + + Returns: + str — a pdf_id if a relevant PDF was found + None — if no PDF context is applicable (general question) + """ + + # ── Step 1: Explicit pdf_names attached to this message ─────────────────── + # If the frontend sent pdf file names, the user is clearly working with those PDFs. + # Skip classification entirely and resolve the pdf_id directly. 
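+    # Worked example: an attached "resume.pdf" resolves directly by name; without an
+    # attachment, "summarize this resume" classifies as "last" and falls back to the
+    # most recently uploaded pdf_id, while a general question like "what is gradient
+    # descent" classifies as "try" and returns None unless a stored summary clears
+    # the 0.6 similarity threshold.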
+ if pdf_names: + try: + name_to_id = get_pdf_ids_by_names(pdf_names, user_id) + for name in pdf_names: + pid = name_to_id.get(name) + if pid: + print(f"[rag_service] Explicit PDF match: '{name}' → pdf_id={pid}") + return pid + except Exception as e: + print(f"[rag_service] Name lookup failed: {e}") + # If names were provided but none resolved (e.g. Qdrant lag), fall through to classify + + # ── Step 2: LLM Classification ──────────────────────────────────────────── + llm = get_llm(instant=True) # llama-4-scout — fast, accurate for classification + try: + response = await llm.ainvoke([ + SystemMessage(content=_CLASSIFIER_SYSTEM), + HumanMessage(content=query), + ]) + classification = response.content.strip().lower().replace('"', "").replace("'", "") + print(f"[rag_service] Query classified as: '{classification}' | query='{query[:60]}'") + except Exception as e: + print(f"[rag_service] Classification failed: {e} — defaulting to 'try'") + classification = "try" + + # ── Step 3: Resolve pdf_id based on class ───────────────────────────────── + + if classification == "last": + # User is referring to their most recently uploaded PDF + try: + pdf_id = get_most_recent_pdf_id(user_id) + if pdf_id: + print(f"[rag_service] 'last' class → using most recent pdf_id={pdf_id}") + else: + print(f"[rag_service] 'last' class but no PDFs found for user={user_id}") + return pdf_id + except Exception as e: + print(f"[rag_service] get_most_recent_pdf_id failed: {e}") + return None + + else: + # "try" — attempt cosine similarity on pdf_summary + try: + summaries = search_pdf_summary(query, user_id, top_k=3) + for s in summaries: + if s["similarity_score"] >= PDF_SIMILARITY_THRESHOLD: + print( + f"[rag_service] 'try' class → similarity match '{s['doc_name']}' " + f"(score={s['similarity_score']:.3f}) → pdf_id={s['pdf_id']}" + ) + return s["pdf_id"] + # No match above threshold + print(f"[rag_service] 'try' class → no PDF similarity match (best < {PDF_SIMILARITY_THRESHOLD})") + return None + except Exception as e: + print(f"[rag_service] pdf_summary search failed: {e}") + return None diff --git a/services/smart_orchestrator_service.py b/services/smart_orchestrator_service.py new file mode 100644 index 0000000000000000000000000000000000000000..3d424497984e29d1e97354d333b2aacc2570f128 --- /dev/null +++ b/services/smart_orchestrator_service.py @@ -0,0 +1,623 @@ +import json +import random +import logging +from typing import AsyncGenerator, Callable +from agents import ( + classify_query, + get_standard_node_coords, + get_deep_research_node_coords, + code_planner_node, + parallel_coders_node, + code_aggregator_node, + code_reviewer_node, + should_retry, + format_output_node, + get_node_coords, +) +from services.agent_service import run_tool_agent_stream_sse +from services.deep_research_service import run_deep_research_stream_with_state +from services.memory_service import get_conversation_memory_context_async, format_memory_block +from schemas.schema import CodingAgentState + +logger = logging.getLogger(__name__) + + +async def smart_orchestrator_stream( + task: str, + conversation_id: str | None = None, + user_id: str | None = None, +) -> AsyncGenerator[str, None]: + """Main entry point: classifies query and routes to appropriate pipeline with SSE. + + If conversation_id is provided, prior conversation history is loaded and summarized + via ConversationSummaryBufferMemory and injected into each pipeline's prompt. 
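+
+    In the current implementation the history comes from the in-memory window store in
+    services/memory_service.py (last 4 turns), not from LLM summarization. Routing, as
+    implemented below: "standard" -> tool-calling agent stream, "deep_research" ->
+    multi-researcher pipeline, "code" -> parallel coding pipeline; every event is emitted
+    as an SSE line of the form "data: {json}\n\n".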
+ """ + + def yield_event(event: dict) -> str: + return f"data: {json.dumps(event)}\n\n" + + try: + # ── Step 0: Load prior conversation memory context ───────────────────── + memory_context: str | None = None + if conversation_id and user_id: + logger.info( + f"[smart_orchestrator] Loading memory context for conv={conversation_id}, user={user_id}" + ) + memory_context = await get_conversation_memory_context_async( + conversation_id, user_id + ) + if memory_context: + logger.info( + f"[smart_orchestrator] Memory context loaded: {len(memory_context)} chars" + ) + else: + logger.info(f"[smart_orchestrator] No prior memory context found.") + else: + logger.info( + "[smart_orchestrator] No conversation_id provided — fresh start." + ) + + # Step 1: Classify the query + path, reason, problem_understanding = await classify_query(task) + yield yield_event({"type": "route", "path": path, "reason": reason}) + + router = "Smart Router" + + # ─── Standard Path ──────────────────────────────────────────────── + if path == "standard": + coords = get_standard_node_coords() + yield yield_event( + { + "type": "node_update", + "node_id": "router", + "status": "completed", + "label": router, + "node_type": "orchestrator", + "x": coords["router"]["x"], + "y": coords["router"]["y"], + "output": f"Routed to: standard (Reason: {reason})", + } + ) + + yield yield_event( + { + "type": "stage", + "stage": "agent", + "message": "Running tool-calling agent...", + } + ) + + # Use streaming for the standard path + final_answer = "" + tools_used = [] + retrieved_chunks = [] + + async for sse_line in run_tool_agent_stream_sse( + query=task, + user_id=user_id or "default_user", + memory_context=memory_context, + ): + # Parse to accumulate final data and forward token/tool events + try: + if sse_line.startswith("data: "): + evt = json.loads(sse_line[6:].strip()) + evt_type = evt.get("type") + + if evt_type == "token": + # Forward token as a content chunk for real-time display + yield yield_event( + { + "type": "content_chunk", + "content": evt.get("content", ""), + "phase": evt.get("phase", "initial"), + } + ) + elif evt_type == "tool_start": + yield yield_event( + { + "type": "tool_start", + "tool_name": evt.get("tool_name", ""), + "tool_args": evt.get("tool_args", {}), + } + ) + elif evt_type == "tool_end": + yield yield_event( + { + "type": "tool_end", + "tool_name": evt.get("tool_name", ""), + "tool_output": evt.get("tool_output", ""), + } + ) + elif evt_type == "done": + final_answer = evt.get("answer", "") + tools_used = evt.get("tools_used", []) + retrieved_chunks = evt.get("retrieved_chunks", []) + elif evt_type == "error": + logger.error( + f"[smart_orchestrator] Stream error: {evt.get('message')}" + ) + except (json.JSONDecodeError, IndexError): + pass + + # Forward the raw SSE line + yield sse_line + + yield yield_event( + { + "type": "node_update", + "node_id": "output", + "status": "completed", + "label": "Tool Agent", + "node_type": "output", + "x": coords["output"]["x"], + "y": coords["output"]["y"], + "output": final_answer[:200] if final_answer else None, + } + ) + + yield yield_event( + { + "type": "final", + "result": final_answer or "No response received.", + "meta": { + "confidence_score": random.randint(75, 95), + "logical_consistency": random.randint(75, 95), + "critic_feedback": "", + "retry_count": 0, + "tools_used": [t.get("tool", "") for t in tools_used], + }, + } + ) + + # ─── Deep Research Path ─────────────────────────────────────────── + elif path == "deep_research": + coords = 
get_deep_research_node_coords() + yield yield_event( + { + "type": "node_update", + "node_id": "router", + "status": "completed", + "label": router, + "node_type": "deep_research", + "x": coords["router"]["x"], + "y": coords["router"]["y"], + "output": f"Routed to: deep_research (Reason: {reason})", + } + ) + + yield yield_event( + { + "type": "node_update", + "node_id": "deep_research", + "status": "running", + "label": "Deep Research", + "node_type": "deep_research", + "x": coords["deep_research"]["x"], + "y": coords["deep_research"]["y"], + "output": None, + } + ) + + yield yield_event( + {"type": "stage", "stage": "planning", "message": "Decomposing task..."} + ) + + # Use streaming orchestrator for real-time phase updates + final_result = "" + final_meta = {} + deep_research_raw = {} + + async for orch_event in run_deep_research_stream_with_state( + task, memory_context=memory_context + ): + event_type = orch_event.get("type", "") + + if event_type == "plan": + # Subtasks created — emit node_update for orchestrator + subtasks = orch_event.get("subtasks", []) + yield yield_event( + { + "type": "node_update", + "node_id": "deep_research", + "status": "completed", + "label": "Deep Research", + "node_type": "deep_research", + "x": coords["deep_research"]["x"], + "y": coords["deep_research"]["y"], + "output": f"Created {len(subtasks)} subtasks", + } + ) + # Forward the plan event + yield yield_event(orch_event) + + elif event_type == "content_chunk": + section = orch_event.get("section", "") + content = orch_event.get("content", "") + + if section == "decomposition": + # Forward decomposition content + yield yield_event(orch_event) + + elif section == "researcher_1": + # Researcher 1 completed + node_coords = coords.get("researcher_1", {"x": 240, "y": 280}) + yield yield_event( + { + "type": "node_update", + "node_id": "researcher_1", + "status": "completed", + "label": "Researcher 1", + "node_type": "agent", + "x": node_coords["x"], + "y": node_coords["y"], + "output": (content or "")[:200], + } + ) + # Forward researcher content + yield yield_event(orch_event) + + elif section == "researcher_2": + # Researcher 2 completed + node_coords = coords.get("researcher_2", {"x": 560, "y": 280}) + yield yield_event( + { + "type": "node_update", + "node_id": "researcher_2", + "status": "completed", + "label": "Researcher 2", + "node_type": "agent", + "x": node_coords["x"], + "y": node_coords["y"], + "output": (content or "")[:200], + } + ) + # Forward researcher content + yield yield_event(orch_event) + + elif section == "aggregation": + # Aggregator completed + yield yield_event( + { + "type": "node_update", + "node_id": "aggregator", + "status": "completed", + "label": "Aggregator", + "node_type": "agent", + "x": coords["aggregator"]["x"], + "y": coords["aggregator"]["y"], + "output": "Synthesized final report", + } + ) + # Forward aggregation content + yield yield_event(orch_event) + + elif event_type == "final": + # Critic completed — capture final result + final_result = orch_event.get("result", "") + final_meta = orch_event.get("meta", {}) + deep_research_raw = final_meta.get("deep_research_raw", {}) + + # Emit critic and output node updates + yield yield_event( + { + "type": "node_update", + "node_id": "critic", + "status": "completed", + "label": "Critic", + "node_type": "critic", + "x": coords["critic"]["x"], + "y": coords["critic"]["y"], + "output": f"Confidence: {final_meta.get('confidence_score', 85)}% | Consistency: {final_meta.get('logical_consistency', 85)}%", + } + ) + + yield 
yield_event( + { + "type": "node_update", + "node_id": "output", + "status": "completed", + "label": "Final Report", + "node_type": "output", + "x": coords["output"]["x"], + "y": coords["output"]["y"], + "output": "7-section structured output", + } + ) + + # Forward the final event + yield yield_event(orch_event) + + elif event_type == "error": + yield yield_event(orch_event) + + # ─── Code Path ─────────────────────────────────────────────────── + elif path == "code": + node_coords = get_node_coords() + yield yield_event( + { + "type": "node_update", + "node_id": "router", + "status": "completed", + "label": "Smart Router", + "node_type": "orchestrator", + "x": 400, + "y": 60, + "output": f"Routed to: code (Reason: {reason})", + } + ) + + yield yield_event( + { + "type": "code_section", + "section": "problem_understanding", + "content": problem_understanding, + } + ) + yield yield_event( + { + "type": "stage", + "stage": "coding", + "message": "Starting code generation pipeline...", + } + ) + + async for event_str in run_coding_agent_sse( + task, yield_event, memory_context=memory_context + ): + yield event_str + + except Exception as e: + yield yield_event({"type": "error", "message": str(e)}) + + +async def run_coding_agent_sse( + task: str, + yield_event: Callable, + memory_context: str | None = None, +) -> None: + """Run coding agent with SSE yielding. + + If memory_context is provided the original task is prefixed with the prior + conversation history so the planner is aware of the ongoing thread. + """ + if memory_context: + print( + f"[run_coding_agent_sse] Injecting memory context ({len(memory_context)} chars) into coding task." + ) + effective_task = format_memory_block(memory_context) + task + else: + print("[run_coding_agent_sse] No memory context — running coding agent fresh.") + effective_task = task + + state: CodingAgentState = { + "original_task": effective_task, + "subtasks": [], + "shared_contract": "", + "coder_results": [], + "merged_code": "", + "review_errors": [], + "retry_count": 0, + "confidence_score": 0, + "logical_consistency": 0, + "critic_feedback": "", + "final_output": "", + "parsed_files": [], + "step_logs": [], + } + + node_coords = get_node_coords() + + # 1. Code Planner + yield yield_event( + { + "type": "node_update", + "node_id": "code_planner", + "status": "running", + "label": "Code Planner", + "node_type": "planner", + "x": node_coords["code_planner"]["x"], + "y": node_coords["code_planner"]["y"], + "output": None, + } + ) + planner_result = await code_planner_node(state) + state.update(planner_result) + yield yield_event( + { + "type": "node_update", + "node_id": "code_planner", + "status": "completed", + "label": "Code Planner", + "node_type": "planner", + "x": node_coords["code_planner"]["x"], + "y": node_coords["code_planner"]["y"], + "output": f"Created {len(state['subtasks'])} subtasks", + } + ) + + yield yield_event( + { + "type": "plan", + "subtasks": [ + { + "id": st["id"], + "description": st["description"], + "signatures": st.get("signatures", []), + } + for st in state["subtasks"] + ], + } + ) + + # 2. 
Parallel Coders + for i in range(3): + coder_id = f"coder_{i + 1}" + yield yield_event( + { + "type": "node_update", + "node_id": coder_id, + "status": "running", + "label": f"Coding Agent {i + 1}", + "node_type": "coder", + "x": node_coords[coder_id]["x"], + "y": node_coords[coder_id]["y"], + "output": None, + } + ) + + coder_result = await parallel_coders_node(state) + state.update(coder_result) + + for i in range(3): + coder_id = f"coder_{i + 1}" + output_preview = ( + state["coder_results"][i][:200] + "..." + if len(state["coder_results"][i]) > 200 + else state["coder_results"][i] + ) + yield yield_event( + { + "type": "node_update", + "node_id": coder_id, + "status": "completed", + "label": f"Coding Agent {i + 1}", + "node_type": "coder", + "x": node_coords[coder_id]["x"], + "y": node_coords[coder_id]["y"], + "output": output_preview, + } + ) + + yield yield_event( + { + "type": "agent_output", + "agent_id": coder_id, + "agent_name": f"Coding Agent {i + 1}", + "content": state["coder_results"][i], + } + ) + + # 3. Aggregator-Reviewer Loop + while True: + yield yield_event( + { + "type": "node_update", + "node_id": "code_aggregator", + "status": "running", + "label": "Code Aggregator", + "node_type": "aggregator", + "x": node_coords["code_aggregator"]["x"], + "y": node_coords["code_aggregator"]["y"], + "output": None, + } + ) + aggregator_result = await code_aggregator_node(state) + state.update(aggregator_result) + yield yield_event( + { + "type": "node_update", + "node_id": "code_aggregator", + "status": "completed", + "label": "Code Aggregator", + "node_type": "aggregator", + "x": node_coords["code_aggregator"]["x"], + "y": node_coords["code_aggregator"]["y"], + "output": f"Merged {len(state['coder_results'])} coder outputs", + } + ) + + yield yield_event( + { + "type": "node_update", + "node_id": "code_reviewer", + "status": "running", + "label": "Code Reviewer", + "node_type": "reviewer", + "x": node_coords["code_reviewer"]["x"], + "y": node_coords["code_reviewer"]["y"], + "output": None, + } + ) + reviewer_result = await code_reviewer_node(state) + state.update(reviewer_result) + yield yield_event( + { + "type": "node_update", + "node_id": "code_reviewer", + "status": "completed", + "label": "Code Reviewer", + "node_type": "reviewer", + "x": node_coords["code_reviewer"]["x"], + "y": node_coords["code_reviewer"]["y"], + "output": f"Confidence: {state['confidence_score']}% | Errors: {len(state['review_errors'])}", + } + ) + + if should_retry(state) == "format_output": + break + + # 4. 
Format Output — parse merged code into structured files + yield yield_event( + { + "type": "node_update", + "node_id": "output", + "status": "running", + "label": "Final Output", + "node_type": "output", + "x": node_coords["output"]["x"], + "y": node_coords["output"]["y"], + "output": None, + } + ) + format_result = await format_output_node(state) + state.update(format_result) + + parsed_files = state.get("parsed_files", []) + total = len(parsed_files) + filenames = [f["filename"] for f in parsed_files] + + yield yield_event( + { + "type": "node_update", + "node_id": "output", + "status": "completed", + "label": "Final Output", + "node_type": "output", + "x": node_coords["output"]["x"], + "y": node_coords["output"]["y"], + "output": f"{total} file(s) generated", + } + ) + + # Stream each file as a single chunk (typewriter effect done on frontend) + for idx, file_obj in enumerate(parsed_files): + yield yield_event( + { + "type": "file_output", + "filename": file_obj["filename"], + "content": file_obj["content"], + "language": file_obj["language"], + "index": idx, + "total": total, + } + ) + + # Final event carries a compact code_complete marker — not the full code blob + yield yield_event( + { + "type": "final", + "result": json.dumps({ + "type": "code_complete", + "file_count": total, + "filenames": filenames, + }), + "meta": { + "confidence_score": state["confidence_score"], + "logical_consistency": state["logical_consistency"], + "critic_feedback": state["critic_feedback"], + "serious_mistakes": state.get("serious_mistakes", []), + "retry_count": state["retry_count"], + "tools_used": [], + }, + } + ) diff --git a/test_db_connection.py b/test_db_connection.py new file mode 100644 index 0000000000000000000000000000000000000000..23fd56525babd73c4822b67328f1cbe9f614ea6f --- /dev/null +++ b/test_db_connection.py @@ -0,0 +1,115 @@ +""" +test_db_connection.py — Test PostgreSQL and Qdrant connections. +Usage: python test_db_connection.py +""" + +import os +import asyncio +import asyncpg +from dotenv import load_dotenv + +load_dotenv() + + +async def test_postgres(): + """Test PostgreSQL connection and list tables.""" + print("=" * 50) + print(" Testing PostgreSQL Connection") + print("=" * 50) + + database_url = os.getenv("DATABASE_URL") + if not database_url: + print("[FAIL] DATABASE_URL not set in .env") + return False + + print(f"[INFO] Connecting to: {database_url[:50]}...") + + try: + conn = await asyncpg.connect(database_url) + print("[OK] Connected to PostgreSQL successfully!") + + # List tables + tables = await conn.fetch(""" + SELECT table_name FROM information_schema.tables + WHERE table_schema = 'public' ORDER BY table_name + """) + + if tables: + print(f"\n[INFO] Found {len(tables)} table(s):") + for row in tables: + print(f" - {row['table_name']}") + else: + print("\n[WARN] No tables found. 
Run the server to initialize the database.") + + # Test a simple query + result = await conn.fetchval("SELECT 1") + print(f"\n[OK] Test query returned: {result}") + + await conn.close() + print("[OK] Connection closed.") + return True + + except Exception as e: + print(f"[FAIL] PostgreSQL connection error: {e}") + return False + + +def test_qdrant(): + """Test Qdrant connection and list collections.""" + print("\n" + "=" * 50) + print(" Testing Qdrant Connection") + print("=" * 50) + + qdrant_url = os.getenv("QDRANT_CLIENT") + qdrant_api_key = os.getenv("QDRANT_API_KEY") + + if not qdrant_url: + print("[FAIL] QDRANT_CLIENT not set in .env") + return False + + print(f"[INFO] Connecting to: {qdrant_url}") + + try: + from qdrant_client import QdrantClient + + client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key) + collections = client.get_collections() + + print("[OK] Connected to Qdrant successfully!") + + if collections.collections: + print(f"\n[INFO] Found {len(collections.collections)} collection(s):") + for col in collections.collections: + info = client.get_collection(col.name) + print(f" - {col.name} (vectors: {info.points_count})") + else: + print("\n[WARN] No collections found. Run the server to initialize collections.") + + return True + + except Exception as e: + print(f"[FAIL] Qdrant connection error: {e}") + return False + + +async def main(): + print("\n🔍 Agentrix.io — Database Connection Test\n") + + pg_ok = await test_postgres() + qdrant_ok = test_qdrant() + + print("\n" + "=" * 50) + print(" Results") + print("=" * 50) + print(f" PostgreSQL: {'✅ OK' if pg_ok else '❌ FAILED'}") + print(f" Qdrant: {'✅ OK' if qdrant_ok else '❌ FAILED'}") + print("=" * 50) + + if pg_ok and qdrant_ok: + print("\n✅ All connections are working!") + else: + print("\n❌ Some connections failed. Check your .env file.") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0778e85f54468403243432e0ff91ef450bb438bf --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests package for backend optimization modules.""" diff --git a/tests/test_optimizations.py b/tests/test_optimizations.py new file mode 100644 index 0000000000000000000000000000000000000000..0c1c4413a84489cab18b1853b3360a6a47c82aa4 --- /dev/null +++ b/tests/test_optimizations.py @@ -0,0 +1,324 @@ +""" +Unit tests for backend optimization modules. 
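+
+Typically run with: pytest tests/test_optimizations.py -v (this mirrors the __main__
+entry point at the bottom of this file).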
+ +Tests core functionality of: +- JSON helpers (parsing, normalization) +- Configuration (LLM settings, scoring) +- Context injection (memory handling) +- SSE streaming (event formatting) +- Graph nodes (coordinate definitions) +""" + +import pytest +import json +from utils.json_helpers import ( + sanitize_fenced_json, + extract_first_json_object, + load_json_object, + clamp_score, + normalize_text, + normalize_errors, + normalize_serious_mistakes, +) +from core.config import LLMConfig, ScoringConfig, SeverityConfig, AgentConfig +from services.context_injector import inject_memory_context, has_memory_context +from services.base_stream_service import format_sse_event, create_sse_event +from utils.graph_nodes import ( + get_coding_node_coords, + get_deep_research_node_coords, + get_standard_node_coords, +) + + +# ─── JSON Helpers Tests ──────────────────────────────────────────────────────── + +class TestJsonHelpers: + """Tests for JSON parsing and normalization utilities.""" + + def test_sanitize_fenced_json_with_backticks(self): + """Test removal of markdown code fences.""" + raw = "```json\n{\"key\": \"value\"}\n```" + result = sanitize_fenced_json(raw) + assert result == "{\"key\": \"value\"}" + + def test_sanitize_fenced_json_with_plain_backticks(self): + """Test removal of plain backticks.""" + raw = "```\n{\"key\": \"value\"}\n```" + result = sanitize_fenced_json(raw) + assert result == "{\"key\": \"value\"}" + + def test_sanitize_fenced_json_no_fences(self): + """Test passthrough when no fences present.""" + raw = "{\"key\": \"value\"}" + result = sanitize_fenced_json(raw) + assert result == "{\"key\": \"value\"}" + + def test_sanitize_fenced_json_empty(self): + """Test empty input.""" + assert sanitize_fenced_json("") == "" + assert sanitize_fenced_json(" ") == "" + + def test_extract_first_json_object_simple(self): + """Test extraction of simple JSON object.""" + text = "prefix {\"key\": \"value\"} suffix" + result = extract_first_json_object(text) + assert result == "{\"key\": \"value\"}" + + def test_extract_first_json_object_nested(self): + """Test extraction with nested objects.""" + text = 'text {"outer": {"inner": "value"}} more' + result = extract_first_json_object(text) + assert result == '{"outer": {"inner": "value"}}' + + def test_extract_first_json_object_not_found(self): + """Test when no JSON object exists.""" + text = "no json here" + result = extract_first_json_object(text) + assert result == "" + + def test_load_json_object_simple(self): + """Test JSON loading from valid input.""" + raw = '{"key": "value"}' + result = load_json_object(raw) + assert result == {"key": "value"} + + def test_load_json_object_with_fences(self): + """Test JSON loading from fenced input.""" + raw = "```json\n{\"key\": \"value\"}\n```" + result = load_json_object(raw) + assert result == {"key": "value"} + + def test_load_json_object_invalid(self): + """Test error handling for invalid JSON.""" + with pytest.raises(ValueError): + load_json_object("not json at all") + + def test_load_json_object_non_dict(self): + """Test error when JSON root is not an object.""" + with pytest.raises(ValueError): + load_json_object("[1, 2, 3]") + + def test_clamp_score_in_range(self): + """Test score clamping within valid range.""" + assert clamp_score(50, default=0) == 50 + assert clamp_score(0, default=50) == 0 + assert clamp_score(100, default=50) == 100 + + def test_clamp_score_below_min(self): + """Test score clamping below minimum.""" + assert clamp_score(-10, default=50) == 0 + + def 
test_clamp_score_above_max(self): + """Test score clamping above maximum.""" + assert clamp_score(150, default=50) == 100 + + def test_clamp_score_invalid(self): + """Test default value on invalid input.""" + assert clamp_score("invalid", default=50) == 50 + assert clamp_score(None, default=75) == 75 + + def test_normalize_text_string(self): + """Test string normalization.""" + assert normalize_text(" hello ") == "hello" + assert normalize_text("hello") == "hello" + + def test_normalize_text_non_string(self): + """Test non-string value normalization.""" + assert normalize_text(123) == "123" + assert normalize_text(None) == "" + + def test_normalize_errors_string(self): + """Test error list normalization from string.""" + result = normalize_errors("error message") + assert result == ["error message"] + + def test_normalize_errors_list(self): + """Test error list normalization from list.""" + result = normalize_errors(["error 1", "error 2"]) + assert result == ["error 1", "error 2"] + + def test_normalize_errors_empty_removed(self): + """Test that empty error strings are removed.""" + result = normalize_errors(["error", "", " ", "another"]) + assert result == ["error", "another"] + + def test_normalize_serious_mistakes_string(self): + """Test mistakes normalization from list of strings.""" + result = normalize_serious_mistakes(["bug found"]) + assert len(result) == 1 + assert result[0]["severity"] == "high" + assert result[0]["description"] == "bug found" + + def test_normalize_serious_mistakes_dict(self): + """Test mistakes normalization from dict.""" + input_data = [ + {"description": "logic error", "severity": "critical", "action": "fix loop"} + ] + result = normalize_serious_mistakes(input_data) + assert len(result) == 1 + assert result[0]["severity"] == "critical" + assert result[0]["description"] == "logic error" + assert result[0]["action"] == "fix loop" + + def test_normalize_serious_mistakes_invalid_severity(self): + """Test that invalid severity defaults to high.""" + input_data = [{"description": "bug", "severity": "invalid"}] + result = normalize_serious_mistakes(input_data) + assert result[0]["severity"] == "high" + + def test_normalize_serious_mistakes_empty_description_skipped(self): + """Test that items with empty descriptions are skipped.""" + input_data = [ + {"description": "", "severity": "high"}, + {"description": "real error", "severity": "high"}, + ] + result = normalize_serious_mistakes(input_data) + assert len(result) == 1 + assert result[0]["description"] == "real error" + + +# ─── Configuration Tests ─────────────────────────────────────────────────────── + +class TestConfiguration: + """Tests for configuration module.""" + + def test_llm_config_values(self): + """Test LLM temperature values.""" + assert LLMConfig.STRUCTURED == 0.0 + assert LLMConfig.PRECISE == 0.1 + assert LLMConfig.BALANCED == 0.3 + assert LLMConfig.CREATIVE == 0.7 + + def test_scoring_config_values(self): + """Test scoring configuration.""" + assert ScoringConfig.MIN == 0 + assert ScoringConfig.MAX == 100 + assert ScoringConfig.DEFAULT == 50 + + def test_severity_config_valid_levels(self): + """Test valid severity levels.""" + assert "low" in SeverityConfig.VALID_LEVELS + assert "medium" in SeverityConfig.VALID_LEVELS + assert "high" in SeverityConfig.VALID_LEVELS + assert "critical" in SeverityConfig.VALID_LEVELS + + def test_agent_config_coding(self): + """Test coding agent configuration.""" + assert AgentConfig.CodingAgent.PLANNER_TEMPERATURE == LLMConfig.PRECISE + assert 
AgentConfig.CodingAgent.CODER_TEMPERATURE == LLMConfig.PRECISE + assert AgentConfig.CodingAgent.REVIEWER_TEMPERATURE == LLMConfig.STRUCTURED + + def test_agent_config_orchestrator(self): + """Test orchestrator agent configuration.""" + assert AgentConfig.OrchestratorAgent.DECOMPOSER_TEMPERATURE == LLMConfig.BALANCED + assert AgentConfig.OrchestratorAgent.RESEARCHER_TEMPERATURE == LLMConfig.CREATIVE + assert AgentConfig.OrchestratorAgent.CRITIC_TEMPERATURE == LLMConfig.STRUCTURED + + +# ─── Context Injection Tests ─────────────────────────────────────────────────── + +class TestContextInjection: + """Tests for memory context injection.""" + + def test_has_memory_context_with_content(self): + """Test memory context detection with content.""" + assert has_memory_context("some context") + assert has_memory_context(" not empty ") + + def test_has_memory_context_without_content(self): + """Test memory context detection without content.""" + assert not has_memory_context(None) + assert not has_memory_context("") + assert not has_memory_context(" ") + + def test_inject_memory_context_with_context(self): + """Test context injection when memory available.""" + memory = "Prior conversation" + task = "New task" + result = inject_memory_context(task, memory) + assert "Prior conversation" in result + assert "New task" in result + + def test_inject_memory_context_without_context(self): + """Test context injection when memory unavailable.""" + task = "New task" + result = inject_memory_context(task, None) + assert result == task + + +# ─── SSE Streaming Tests ─────────────────────────────────────────────────────── + +class TestSSEStreaming: + """Tests for SSE event formatting.""" + + def test_create_sse_event(self): + """Test SSE event creation.""" + event = create_sse_event("token", content="hello") + assert event["type"] == "token" + assert event["content"] == "hello" + + def test_format_sse_event(self): + """Test SSE event formatting.""" + event = {"type": "token", "content": "test"} + result = format_sse_event(event) + assert result.startswith("data: ") + assert result.endswith("\n\n") + assert "token" in result + assert "test" in result + + def test_format_sse_event_json_valid(self): + """Test that formatted SSE contains valid JSON.""" + event = {"type": "done", "value": 42} + result = format_sse_event(event) + # Extract JSON from "data: {...}\n\n" + json_str = result.split("data: ")[1].strip() + parsed = json.loads(json_str) + assert parsed["type"] == "done" + assert parsed["value"] == 42 + + +# ─── Graph Nodes Tests ───────────────────────────────────────────────────────── + +class TestGraphNodes: + """Tests for graph node coordinate definitions.""" + + def test_coding_node_coords(self): + """Test coding agent node coordinates.""" + coords = get_coding_node_coords() + assert "router" in coords + assert "code_planner" in coords + assert "coder_1" in coords + assert "code_aggregator" in coords + assert "code_reviewer" in coords + assert len(coords) == 8 + + def test_deep_research_node_coords(self): + """Test deep research node coordinates.""" + coords = get_deep_research_node_coords() + assert "router" in coords + assert "researcher_1" in coords + assert "researcher_2" in coords + assert "aggregator" in coords + assert "critic" in coords + assert len(coords) == 7 + + def test_standard_node_coords(self): + """Test standard path node coordinates.""" + coords = get_standard_node_coords() + assert "router" in coords + assert "output" in coords + assert len(coords) == 2 + + def 
test_node_coords_have_x_y(self): + """Test that all node coordinates have x and y.""" + for coord_func in [get_coding_node_coords, get_deep_research_node_coords, get_standard_node_coords]: + coords = coord_func() + for node_id, position in coords.items(): + assert "x" in position, f"Missing x for {node_id}" + assert "y" in position, f"Missing y for {node_id}" + assert isinstance(position["x"], (int, float)) + assert isinstance(position["y"], (int, float)) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb9426025901af39be92804b99e2d5e2857ceecb --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +# Package marker for backend.utils diff --git a/utils/graph_nodes.py b/utils/graph_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..9d03a6cd690c236e01f69bf4e2eb4f1aff65f5b9 --- /dev/null +++ b/utils/graph_nodes.py @@ -0,0 +1,62 @@ +""" +Centralized graph node coordinate definitions. + +Single source of truth for task execution graph layout across all pipelines. +Allows easy tuning of graph appearance without modifying agent code. + +Coordinates are in SVG units (0,0 is top-left). +Typical viewport: 800x600 or larger. +""" + + +class CodingAgentCoords: + """Node positions for the coding agent task graph.""" + + NODES = { + "router": {"x": 400, "y": 60}, + "code_planner": {"x": 400, "y": 160}, + "coder_1": {"x": 180, "y": 280}, + "coder_2": {"x": 400, "y": 280}, + "coder_3": {"x": 620, "y": 280}, + "code_aggregator": {"x": 400, "y": 400}, + "code_reviewer": {"x": 400, "y": 480}, + "output": {"x": 400, "y": 560}, + } + + +class DeepResearchCoords: + """Node positions for the deep research orchestrator task graph.""" + + NODES = { + "router": {"x": 400, "y": 60}, + "orchestrator": {"x": 400, "y": 160}, + "researcher_1": {"x": 240, "y": 280}, + "researcher_2": {"x": 560, "y": 280}, + "aggregator": {"x": 400, "y": 400}, + "critic": {"x": 400, "y": 480}, + "output": {"x": 400, "y": 560}, + } + + +class StandardPathCoords: + """Node positions for the standard chat task graph.""" + + NODES = { + "router": {"x": 400, "y": 60}, + "output": {"x": 400, "y": 200}, + } + + +def get_coding_node_coords() -> dict: + """Get node coordinates for coding agent graph.""" + return CodingAgentCoords.NODES + + +def get_deep_research_node_coords() -> dict: + """Get node coordinates for deep research graph.""" + return DeepResearchCoords.NODES + + +def get_standard_node_coords() -> dict: + """Get node coordinates for standard chat graph.""" + return StandardPathCoords.NODES diff --git a/utils/json_helpers.py b/utils/json_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..1fb394d696a2df49b7f9e05b12355384c5377d1c --- /dev/null +++ b/utils/json_helpers.py @@ -0,0 +1,269 @@ +""" +Shared JSON parsing and text normalization utilities. + +Used by coding_agent and orchestrator_agent for consistent +handling of LLM-generated JSON responses with markdown code fences, +extraction, normalization, and scoring. 
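+
+Illustrative round trip: for a fenced LLM reply such as '```json\n{"score": 120}\n```',
+load_json_object() strips the fences and returns {"score": 120}, after which
+clamp_score(120, default=50) pulls the value back into the valid [0, 100] range (here 100).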
+ +This module is the single source of truth for: +- JSON parsing with fallback strategies +- Text normalization and validation +- Score clamping to valid ranges +- Mistake/error list normalization +""" + +import json +from typing import Any + + +# Configuration constants - single source of truth +SCORE_MIN = 0 +SCORE_MAX = 100 +VALID_SEVERITIES = {"low", "medium", "high", "critical"} + + +def sanitize_fenced_json(raw_text: str) -> str: + """ + Remove markdown code fences from JSON string. + + Handles: + - Triple backticks (```json ... ```) + - Plain triple backticks (``` ... ```) + - "json" language prefix + + Args: + raw_text: Raw text potentially containing JSON with markdown fences + + Returns: + Cleaned JSON string without fences + """ + text = raw_text.strip() + if not text: + return text + + if text.startswith("```"): + lines = text.splitlines() + if lines: + lines = lines[1:] + if lines and lines[-1].strip() == "```": + lines = lines[:-1] + text = "\n".join(lines).strip() + + if text.lower().startswith("json\n"): + text = text[5:].strip() + + return text + + +def extract_first_json_object(text: str) -> str: + """ + Extract the first valid JSON object from text. + + Handles nested braces and string escaping correctly. + Stops at the first complete object (depth 0). + + Args: + text: Text potentially containing JSON object + + Returns: + First JSON object as string, or empty string if not found + """ + start = -1 + depth = 0 + in_string = False + escape_next = False + + for idx, char in enumerate(text): + if in_string: + if escape_next: + escape_next = False + elif char == "\\": + escape_next = True + elif char == '"': + in_string = False + continue + + if char == '"': + in_string = True + continue + + if char == "{": + if depth == 0: + start = idx + depth += 1 + continue + + if char == "}" and depth > 0: + depth -= 1 + if depth == 0 and start != -1: + return text[start : idx + 1] + + return "" + + +def load_json_object(raw_text: str) -> dict[str, Any]: + """ + Parse JSON with multiple fallback strategies. + + Tries in order: + 1. Sanitized text (with fences removed) + 2. Extracted first JSON object + + Args: + raw_text: Raw text containing JSON (possibly with markdown fences) + + Returns: + Parsed JSON as dictionary + + Raises: + ValueError: If JSON cannot be parsed or root is not an object + """ + sanitized = sanitize_fenced_json(raw_text) + candidates = [sanitized] + + extracted = extract_first_json_object(sanitized) + if extracted and extracted not in candidates: + candidates.append(extracted) + + last_error: Exception | None = None + for candidate in candidates: + if not candidate: + continue + try: + parsed = json.loads(candidate) + except json.JSONDecodeError as exc: + last_error = exc + continue + if not isinstance(parsed, dict): + raise ValueError("JSON response root must be an object, not list or primitive.") + return parsed + + raise ValueError(f"Unable to parse JSON response: {last_error}") + + +def normalize_text(value: Any) -> str: + """ + Convert any value to a non-empty string. + + Handles: + - None → empty string + - Strings → stripped + - Other types → converted to string then stripped + + Args: + value: Value to normalize + + Returns: + Normalized string (may be empty) + """ + if value is None: + return "" + return value.strip() if isinstance(value, str) else str(value).strip() + + +def clamp_score(value: Any, default: int) -> int: + """ + Clamp numeric value to valid score range [SCORE_MIN, SCORE_MAX]. 
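+
+    Examples (mirrored in tests/test_optimizations.py): clamp_score(150, default=50) -> 100,
+    clamp_score("invalid", default=50) -> 50.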
+ + Falls back to default if value cannot be converted to number. + + Args: + value: Value to clamp (any type, will attempt numeric conversion) + default: Default value if conversion fails or None + + Returns: + Clamped integer in range [SCORE_MIN, SCORE_MAX] + """ + try: + numeric = int(round(float(value))) + except (TypeError, ValueError): + numeric = default + return max(SCORE_MIN, min(SCORE_MAX, numeric)) + + +def normalize_errors(value: Any) -> list[str]: + """ + Normalize list of error/issue descriptions. + + Handles: + - Strings converted to list + - Non-list types → empty list + - Each item stripped and deduplicated (empty items removed) + + Args: + value: Error value (string or list) + + Returns: + List of non-empty normalized error strings + """ + if isinstance(value, str): + value = [value] + if not isinstance(value, list): + return [] + + normalized: list[str] = [] + for item in value: + item_text = normalize_text(item) + if item_text: + normalized.append(item_text) + return normalized + + +def normalize_serious_mistakes(value: Any) -> list[dict]: + """ + Normalize mistakes/issues list from LLM response. + + Converts various formats to consistent structure: + { + "severity": str (one of VALID_SEVERITIES, defaults to "high"), + "description": str (required, non-empty), + "action": str (optional, added if present), + "impact": str (optional, added if present) + } + + Handles: + - String items → converted to dict with "high" severity + - Dict items → validated and normalized + - Invalid severity → defaults to "high" + - Empty descriptions → skipped + + Args: + value: Mistakes value (list or non-list) + + Returns: + List of normalized mistake dicts + """ + if not isinstance(value, list): + return [] + + normalized: list[dict] = [] + for item in value: + if isinstance(item, str): + description = item.strip() + if description: + normalized.append({"severity": "high", "description": description}) + continue + if not isinstance(item, dict): + continue + + description = normalize_text(item.get("description")) + if not description: + continue + + severity = normalize_text(item.get("severity")).lower() or "high" + if severity not in VALID_SEVERITIES: + severity = "high" + + normalized_item = {"severity": severity, "description": description} + + action = normalize_text(item.get("action")) + if action: + normalized_item["action"] = action + + impact = normalize_text(item.get("impact")) + if impact: + normalized_item["impact"] = impact + + normalized.append(normalized_item) + + return normalized diff --git a/utils/pdf_processor.py b/utils/pdf_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..8902e6b2f32613ee1ea93002452a8dd8a6feb666 --- /dev/null +++ b/utils/pdf_processor.py @@ -0,0 +1,179 @@ +import concurrent.futures +import os +import uuid +import fitz # PyMuPDF +from langchain.text_splitter import RecursiveCharacterTextSplitter +from core.llm_engine import get_llm +from langchain_core.messages import HumanMessage + + +def extract_text_from_pdf(file_path: str) -> str: + """Extracts text from a single PDF using PyMuPDF.""" + text = "" + try: + with fitz.open(file_path) as doc: + for page in doc: + text += page.get_text() + "\n" + except Exception as e: + print(f"Error reading {file_path}: {e}") + return text + + +def extract_pages_from_pdf(file_path: str) -> list[dict]: + """Extract text from a PDF, returning a list of {page_number, text} dicts.""" + pages = [] + try: + with fitz.open(file_path) as doc: + for i, page in enumerate(doc): + page_text = 
page.get_text() + if page_text.strip(): + pages.append({"page_number": i + 1, "text": page_text}) + except Exception as e: + print(f"Error reading {file_path}: {e}") + return pages + + +def chunk_pages(pages: list[dict], chunk_size: int = 500, chunk_overlap: int = 80) -> list[dict]: + """Split pages into chunks, preserving page_number metadata.""" + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + + chunks = [] + chunk_index = 0 + for page in pages: + page_chunks = text_splitter.split_text(page["text"]) + for text in page_chunks: + chunks.append({ + "page_number": page["page_number"], + "chunk_index": chunk_index, + "text_content": text, + }) + chunk_index += 1 + return chunks + + +async def generate_pdf_summary(text: str, doc_name: str) -> tuple[str, list[str]]: + """Use LLM to generate a summary and topic tags for a PDF document.""" + llm = get_llm(temperature=0.2, change=True) + + prompt = f"""You are a document analyst. Given the following extracted text from a PDF named "{doc_name}", produce: + +1. A concise 2-3 sentence summary of the document's main topic and key content. +2. A list of 3-5 topic tags (single words or short phrases) that categorize this document. + +Respond in EXACTLY this format: +SUMMARY: [your summary here] +TAGS: [tag1], [tag2], [tag3], ... + +Document text (first 3000 chars): +{text[:3000]}""" + + response = await llm.ainvoke([HumanMessage(content=prompt)]) + content = response.content.strip() + + summary = "" + tags: list[str] = [] + + for line in content.split("\n"): + line = line.strip() + if line.startswith("SUMMARY:"): + summary = line.split(":", 1)[1].strip() + elif line.startswith("TAGS:"): + tags_str = line.split(":", 1)[1].strip() + tags = [t.strip() for t in tags_str.split(",") if t.strip()] + + if not summary: + summary = f"Document: {doc_name}" + + return summary, tags + + +def process_single_pdf(file_path: str, user_id: str) -> dict: + """Extract text and chunk a single PDF. Returns {chunks, full_text, doc_name}.""" + full_text = extract_text_from_pdf(file_path) + if not full_text.strip(): + return {"chunks": [], "full_text": "", "doc_name": os.path.basename(file_path)} + + pages = extract_pages_from_pdf(file_path) + chunks = chunk_pages(pages) + + return { + "chunks": chunks, + "full_text": full_text, + "doc_name": os.path.basename(file_path), + } + + +async def process_pdfs(file_paths: list[str], user_id: str, conversation_id: str | None = None) -> dict: + """ + Multithreaded processing for multiple PDFs. + 1. Extracts text and chunks in parallel + 2. Generates LLM summaries + 3. Stores in Qdrant (pdf_summary + pdf_chunks collections) + Returns a dict with total chunks processed and details per file. 
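+
+    Illustrative return shape: {"total_chunks": 42, "files": {"report.pdf": 30, "notes.pdf": 12}}
+    (file names map to their chunk counts; files that fail to parse are recorded with 0 chunks).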
+ """ + from utils.qdrant_embed import upsert_pdf_summary, upsert_pdf_chunks + + results = {"total_chunks": 0, "files": {}} + processed_docs = [] + + # Use ThreadPoolExecutor for concurrent parsing and chunking + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + future_to_pdf = {executor.submit(process_single_pdf, path, user_id): path for path in file_paths} + + for future in concurrent.futures.as_completed(future_to_pdf): + path = future_to_pdf[future] + try: + doc_data = future.result() + filename = doc_data["doc_name"] + results["files"][filename] = len(doc_data["chunks"]) + results["total_chunks"] += len(doc_data["chunks"]) + processed_docs.append(doc_data) + except Exception as exc: + print(f'{path} generated an exception: {exc}') + results["files"][os.path.basename(path)] = 0 + + # Generate summaries and store in Qdrant + for doc_data in processed_docs: + pdf_id = str(uuid.uuid4()) + doc_name = doc_data["doc_name"] + full_text = doc_data["full_text"] + chunks = doc_data["chunks"] + + # Generate LLM summary + try: + summary, topic_tags = await generate_pdf_summary(full_text, doc_name) + except Exception as e: + print(f"Summary generation failed for {doc_name}: {e}") + summary = f"Document: {doc_name}" + topic_tags = [] + + # Store summary in Qdrant + try: + upsert_pdf_summary( + pdf_id=pdf_id, + user_id=user_id, + conversation_id=conversation_id, + doc_name=doc_name, + doc_summary=summary, + topic_tags=topic_tags, + ) + except Exception as e: + print(f"Failed to store summary for {doc_name}: {e}") + + # Store chunks in Qdrant + if chunks: + try: + chunk_count = upsert_pdf_chunks( + pdf_id=pdf_id, + user_id=user_id, + doc_name=doc_name, + chunks=chunks, + ) + print(f"[pdf_processor] Stored {chunk_count} chunks for {doc_name}") + except Exception as e: + print(f"Failed to store chunks for {doc_name}: {e}") + + return results \ No newline at end of file diff --git a/utils/qdrant_embed.py b/utils/qdrant_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..5799d462968d48d68fa7cdf7e8baf455099b5dde --- /dev/null +++ b/utils/qdrant_embed.py @@ -0,0 +1,14 @@ +# Thin wrapper — re-exports from repositories/qdrant_repo.py +from repositories.qdrant_repo import ( # noqa: F401 + get_embedding, + get_qdrant_client, + init_qdrant_collections, + search_chunks, + upsert_pdf_chunks, + upsert_pdf_summary, + get_user_pdf_summaries, + search_pdf_summary, + search_chunks_by_pdf_id, + get_pdf_ids_by_names, + get_most_recent_pdf_id, +) \ No newline at end of file diff --git a/utils/tools.py b/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..f789894ecacde98fb27db0c6bd40d5854ea7f5e9 --- /dev/null +++ b/utils/tools.py @@ -0,0 +1,87 @@ +import math +import datetime +from langchain_core.tools import tool +from utils.qdrant_embed import search_chunks, search_chunks_by_pdf_id, get_most_recent_pdf_id + +# Per-request user_id injected by agent_service before each tool-agent invocation. +CURRENT_USER_ID: str = "default_user" + +# Accumulator: every chunk retrieved in this request is appended here. +# agent_service reads this after graph.invoke() to log them with the real message_id. +RETRIEVED_CHUNKS_BUFFER: list[dict] = [] + + +@tool +def calculator(expression: str) -> str: + """Evaluate a mathematical expression. 
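For example, the expression "sqrt(16) + pow(2, 3)" evaluates to 12.0 (math-module names plus abs, round and pow are allowed).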
Use for arithmetic, algebra, geometry calculations.""" + try: + allowed = {k: getattr(math, k) for k in dir(math) if not k.startswith("_")} + allowed.update({"abs": abs, "round": round, "pow": pow}) + result = eval(expression, {"__builtins__": {}}, allowed) + return f"Result: {result}" + except Exception as e: + return f"Calculation error: {str(e)}" + +@tool +def knowledge_retriever(query: str, pdf_id: str = "", mode: str = "all") -> str: + """Retrieve relevant information from the knowledge base (Qdrant vector store). + + Use this tool to answer questions about uploaded documents or any factual queries. + + Args: + query: The search query — describe what information you need. + pdf_id: Optional. If a specific document ID is provided, retrieve ONLY from + that document. + mode: Optional. "all" = search all user documents, "last" = search only the + most recently uploaded document. Default is "all". + """ + try: + if pdf_id: + chunks = search_chunks_by_pdf_id(query, user_id=CURRENT_USER_ID, pdf_id=pdf_id, top_k=5) + elif mode == "last": + # Get most recently uploaded PDF for user + latest_pdf_id = get_most_recent_pdf_id(CURRENT_USER_ID) + if latest_pdf_id: + chunks = search_chunks_by_pdf_id(query, user_id=CURRENT_USER_ID, pdf_id=latest_pdf_id, top_k=5) + else: + return "No documents have been uploaded yet. Please upload a PDF first." + else: + # mode == "all" - search all user documents + chunks = search_chunks(query, user_id=CURRENT_USER_ID, top_k=3) + + if chunks: + # Buffer the chunks — agent_service will log them after append_message + # gives us the real message_id (FK to messages table). + RETRIEVED_CHUNKS_BUFFER.extend(chunks) + + results = [] + for c in chunks: + page = f"page {c['page_number']}" if c.get("page_number") else "unknown page" + results.append( + f"[{c['doc_name']} | {page} | relevance: {c['similarity_score']:.2f}]\n" + f"{c['text_content']}" + ) + return "Retrieved knowledge:\n\n" + "\n\n---\n\n".join(results) + + if pdf_id: + return f"No content found in the specified document (pdf_id={pdf_id}) for this query." + return "No relevant knowledge found in the knowledge base for this query." + + except Exception as e: + print(f"[knowledge_retriever] Error: {e}") + return "Knowledge base is currently unavailable. Please answer from your general knowledge." + + +@tool +def get_current_datetime(timezone: str = "UTC") -> str: + """Get the current date and time. Use when user asks about current time or date.""" + now = datetime.datetime.now(datetime.timezone.utc) + return f"Current UTC datetime: {now.strftime('%Y-%m-%d %H:%M:%S')}" + + +def get_tools_list(): + return [calculator, knowledge_retriever, get_current_datetime] + + +def get_tools_map(): + return {t.name: t for t in get_tools_list()}
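+
+
+# Illustrative usage sketch (assumption: standard LangChain tool invocation; not part of
+# the module itself). Tools are looked up by name and invoked with a dict of arguments:
+#
+#   tools = get_tools_map()
+#   print(tools["calculator"].invoke({"expression": "sqrt(16) + 2"}))   # -> "Result: 6.0"
+#   print(tools["get_current_datetime"].invoke({}))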