import os
import asyncio
import json

from google import genai
from google.genai import types
from dotenv import load_dotenv

from backend.core.prompts import (
    SYSTEM_INSTRUCTION,
    INTENT_DETECTION_PROMPT,
    DATA_DISCOVERY_PROMPT,
    SQL_GENERATION_PROMPT,
    EXPLANATION_PROMPT,
    SPATIAL_SQL_PROMPT,
    SQL_CORRECTION_PROMPT,
    LAYER_NAME_PROMPT,
)
| |
|
class LLMGateway:
    """Async-friendly wrapper around the Gemini API.

    Provides conversational responses, intent detection, text-to-SQL
    generation (analytical and spatial), SQL self-correction, result
    explanation, and map-layer naming. Every method degrades to a harmless
    fallback value when no API key is configured, instead of raising.
    """

    # Closed set of intents recognized by detect_intent(); anything else
    # falls back to GENERAL_CHAT.
    _VALID_INTENTS = ("GENERAL_CHAT", "DATA_QUERY", "MAP_REQUEST", "SPATIAL_OP", "STAT_QUERY")

    def __init__(self, model_name: str = "gemini-3-flash-preview"):
        """Load credentials from the environment and build the Gemini client.

        Args:
            model_name: Gemini model identifier used for every call.
        """
        load_dotenv()

        self.api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            print("WARNING: GEMINI_API_KEY/GOOGLE_API_KEY not found. LLM features will not work.")
            self.client = None
        else:
            # genai.Client() reads GEMINI_API_KEY from the environment, so a
            # GOOGLE_API_KEY-only setup is mirrored into the variable the SDK
            # expects. (The original also re-checked self.api_key here, but it
            # is already known truthy in this branch.)
            if "GEMINI_API_KEY" not in os.environ:
                os.environ["GEMINI_API_KEY"] = self.api_key

            self.client = genai.Client()

        self.model = model_name

    def _build_contents_from_history(self, history: list[dict], current_message: str) -> list:
        """Convert conversation history to the format expected by the Gemini API.

        History format: [{"role": "user"|"assistant", "content": "..."}].
        The Gemini API uses role "model" where our history says "assistant".
        """
        contents = []
        for msg in history:
            role = "model" if msg["role"] == "assistant" else "user"
            contents.append(
                types.Content(
                    role=role,
                    parts=[types.Part.from_text(text=msg["content"])]
                )
            )

        # The current user message is always the final turn.
        contents.append(
            types.Content(
                role="user",
                parts=[types.Part.from_text(text=current_message)]
            )
        )
        return contents

    @staticmethod
    def _iter_stream_parts(stream):
        """Yield every content part from a Gemini streaming response.

        Skips chunks that carry no candidates or no parts (e.g. usage-only
        chunks), which would otherwise raise IndexError/AttributeError.
        """
        for chunk in stream:
            if not chunk.candidates:
                continue
            content = chunk.candidates[0].content
            if not content or not content.parts:
                continue
            yield from content.parts

    @staticmethod
    def _format_history_context(history: list[dict] | None) -> str:
        """Render the last few turns as a compact text block for prompts.

        Returns an empty string when there is no history.
        """
        if not history:
            return ""
        context_str = "Previous conversation context:\n"
        for msg in history[-4:]:
            context_str += f"- {msg['role']}: {msg['content'][:100]}...\n"
        return context_str

    async def generate_response_stream(self, user_query: str, history: list[dict] | None = None):
        """Generate a streaming response using conversation history for context.

        Yields dicts with thought summaries and content chunks. On failure,
        yields a plain error string (kept for backward compatibility with
        existing consumers).
        """
        if not self.client:
            yield "I couldn't generate a response because the API key is missing."
            return

        if history is None:
            history = []

        try:
            contents = self._build_contents_from_history(history, user_query)

            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                thinking_config=types.ThinkingConfig(
                    include_thoughts=True
                )
            )

            # generate_content_stream is a blocking SDK call; start it off the
            # event loop.
            stream = await asyncio.to_thread(
                self.client.models.generate_content_stream,
                model=self.model,
                contents=contents,
                config=config,
            )

            for part in self._iter_stream_parts(stream):
                if part.thought:
                    # NOTE(review): thought payload uses key "content" here but
                    # "text" in every other streaming method — confirm consumers
                    # before unifying.
                    yield {"type": "thought", "content": part.text}
                elif part.text:
                    yield {"type": "content", "text": part.text}

        except Exception as e:
            print(f"Error calling Gemini stream: {e}")
            yield f"Error: {str(e)}"

    async def generate_response(self, user_query: str, history: list[dict] | None = None) -> str:
        """Generate a (non-streaming) response using conversation history."""
        if not self.client:
            return "I couldn't generate a response because the API key is missing."

        if history is None:
            history = []

        try:
            contents = self._build_contents_from_history(history, user_query)

            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
            )

            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=contents,
                config=config,
            )
            return response.text
        except Exception as e:
            print(f"Error calling Gemini: {e}")
            return f"I encountered an error: {e}"

    async def detect_intent(self, user_query: str, history: list[dict] | None = None) -> str:
        """Detect the intent of the user's query using Gemini thinking mode.

        Returns one of: GENERAL_CHAT, DATA_QUERY, MAP_REQUEST, SPATIAL_OP,
        STAT_QUERY. Falls back to GENERAL_CHAT on any failure or unexpected
        model output. (history is currently unused; kept for interface
        compatibility.)
        """
        if not self.client:
            return "GENERAL_CHAT"

        intent_prompt = INTENT_DETECTION_PROMPT.format(user_query=user_query)

        try:
            config = types.GenerateContentConfig(
                thinking_config=types.ThinkingConfig(
                    thinking_level="medium"
                )
            )

            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=intent_prompt,
                config=config,
            )
            intent = response.text.strip().upper()

            if intent in self._VALID_INTENTS:
                return intent

            # Unrecognized label from the model: treat as chit-chat.
            return "GENERAL_CHAT"
        except Exception as e:
            print(f"Error detecting intent: {e}")
            return "GENERAL_CHAT"

    async def stream_intent(self, user_query: str, history: list[dict] | None = None):
        """Stream intent detection, yielding thought and content dicts."""
        if not self.client:
            yield {"type": "error", "text": "API Key missing"}
            return

        intent_prompt = INTENT_DETECTION_PROMPT.format(user_query=user_query)

        try:
            config = types.GenerateContentConfig(
                thinking_config=types.ThinkingConfig(
                    thinking_level="medium",
                    include_thoughts=True
                )
            )

            stream = await asyncio.to_thread(
                self.client.models.generate_content_stream,
                model=self.model,
                contents=intent_prompt,
                config=config,
            )

            for part in self._iter_stream_parts(stream):
                if part.thought:
                    yield {"type": "thought", "text": part.text}
                elif part.text:
                    yield {"type": "content", "text": part.text}

        except Exception as e:
            print(f"Error detecting intent: {e}")
            yield {"type": "error", "text": str(e)}

    async def identify_relevant_tables(self, user_query: str, table_summaries: str) -> list[str]:
        """Identify which tables are relevant for the query from the catalog summary.

        Returns a list of table names; empty on failure or malformed output.
        """
        if not self.client:
            return []

        prompt = DATA_DISCOVERY_PROMPT.format(user_query=user_query, table_summaries=table_summaries)

        try:
            config = types.GenerateContentConfig(
                response_mime_type="application/json"
            )

            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            )

            # Defensive fence-stripping even though JSON output was requested.
            text = response.text.replace("```json", "").replace("```", "").strip()
            tables = json.loads(text)
            return tables if isinstance(tables, list) else []

        except Exception as e:
            print(f"Error identifying tables: {e}")
            return []

    async def generate_analytical_sql(self, user_query: str, table_schema: str, history: list[dict] | None = None) -> str:
        """Generate a DuckDB SQL query for analytical/statistical questions.

        Core of the text-to-SQL system. Returns SQL text, or a "-- Error ..."
        comment string on failure. (history is currently unused; kept for
        interface compatibility.)
        """
        if not self.client:
            return "-- Error: API Key missing"

        prompt = SQL_GENERATION_PROMPT.format(table_schema=table_schema, user_query=user_query)

        try:
            config = types.GenerateContentConfig(
                temperature=1,
                thinking_config=types.ThinkingConfig(
                    thinking_level="high"
                )
            )

            response = await asyncio.wait_for(
                asyncio.to_thread(
                    self.client.models.generate_content,
                    model=self.model,
                    contents=prompt,
                    config=config,
                ),
                timeout=120.0
            )

            sql = response.text.replace("```sql", "").replace("```", "").strip()

            # Salvage path: if the model emitted chatter before the query,
            # cut everything before the first SELECT.
            if not sql.upper().strip().startswith("SELECT") and "-- ERROR" not in sql:
                print(f"Warning: Generated SQL doesn't start with SELECT: {sql[:100]}")
                if "SELECT" in sql.upper():
                    start_idx = sql.upper().find("SELECT")
                    sql = sql[start_idx:]

            return sql

        except asyncio.TimeoutError:
            # Message fixed to match the actual 120s timeout above.
            print("Gemini API call timed out after 120 seconds")
            return "-- Error: API call timed out. Please try again."
        except Exception as e:
            print(f"Error calling Gemini for analytical SQL: {e}")
            return f"-- Error generating SQL: {str(e)}"

    async def stream_analytical_sql(self, user_query: str, table_schema: str, history: list[dict] | None = None):
        """Stream DuckDB SQL generation, yielding thought and content dicts."""
        if not self.client:
            yield {"type": "error", "text": "API Key missing"}
            return

        prompt = SQL_GENERATION_PROMPT.format(table_schema=table_schema, user_query=user_query)

        try:
            config = types.GenerateContentConfig(
                temperature=1,
                thinking_config=types.ThinkingConfig(
                    thinking_level="high",
                    include_thoughts=True
                )
            )

            stream = await asyncio.to_thread(
                self.client.models.generate_content_stream,
                model=self.model,
                contents=prompt,
                config=config,
            )

            for part in self._iter_stream_parts(stream):
                if part.thought:
                    yield {"type": "thought", "text": part.text}
                elif part.text:
                    yield {"type": "content", "text": part.text}

        except Exception as e:
            print(f"Error streaming SQL: {e}")
            yield {"type": "error", "text": str(e)}

    async def stream_explanation(self, user_query: str, sql_query: str, data_summary: str, history: list[dict] | None = None):
        """Stream an explanation of query results, yielding thought/content dicts."""
        if not self.client:
            yield {"type": "error", "text": "API Key missing"}
            return

        context_str = self._format_history_context(history)

        prompt = EXPLANATION_PROMPT.format(context_str=context_str, user_query=user_query, sql_query=sql_query, data_summary=data_summary)

        try:
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                thinking_config=types.ThinkingConfig(
                    thinking_level="low",
                    include_thoughts=True
                )
            )

            stream = await asyncio.to_thread(
                self.client.models.generate_content_stream,
                model=self.model,
                contents=prompt,
                config=config,
            )

            for part in self._iter_stream_parts(stream):
                if part.thought:
                    yield {"type": "thought", "text": part.text}
                elif part.text:
                    yield {"type": "content", "text": part.text}

        except Exception as e:
            print(f"Error generating explanation: {e}")
            yield {"type": "error", "text": str(e)}

    async def generate_explanation(self, user_query: str, sql_query: str, data_summary: str, history: list[dict] | None = None) -> str:
        """Explain the results of the query, maintaining conversation context."""
        if not self.client:
            return "I couldn't generate an explanation because the API key is missing."

        context_str = self._format_history_context(history)

        prompt = EXPLANATION_PROMPT.format(context_str=context_str, user_query=user_query, sql_query=sql_query, data_summary=data_summary)

        try:
            config = types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                thinking_config=types.ThinkingConfig(
                    thinking_level="low"
                )
            )

            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            )
            return response.text
        except Exception as e:
            print(f"Error generating explanation: {e}")
            # Best-effort fallback so the caller still has something to show.
            return "Here are the results from the query."

    async def generate_spatial_sql(self, user_query: str, layer_context: str, history: list[dict] | None = None) -> str:
        """Generate a DuckDB Spatial SQL query for geometric operations on layers.

        Returns SQL text, or a "-- Error ..." comment string on failure.
        (history is currently unused; kept for interface compatibility.)
        """
        if not self.client:
            return "-- Error: API Key missing"

        prompt = SPATIAL_SQL_PROMPT.format(layer_context=layer_context, user_query=user_query)

        try:
            config = types.GenerateContentConfig(
                temperature=1,
            )

            response = await asyncio.wait_for(
                asyncio.to_thread(
                    self.client.models.generate_content,
                    model=self.model,
                    contents=prompt,
                    config=config,
                ),
                timeout=120.0
            )

            sql = response.text.replace("```sql", "").replace("```", "").strip()
            return sql

        except asyncio.TimeoutError:
            # Message fixed to match the actual 120s timeout above.
            print("Gemini API call timed out after 120 seconds")
            return "-- Error: API call timed out. Please try again."
        except Exception as e:
            print(f"Error calling Gemini: {e}")
            return f"-- Error generating SQL: {str(e)}"

    async def correct_sql(self, user_query: str, incorrect_sql: str, error_message: str, schema_context: str) -> str:
        """Correct a failed SQL query based on the error message.

        Returns the corrected SQL, or the original SQL unchanged if the
        correction call itself fails.
        """
        if not self.client:
            return "-- Error: API Key missing"

        prompt = SQL_CORRECTION_PROMPT.format(
            error_message=error_message,
            incorrect_sql=incorrect_sql,
            user_query=user_query,
            schema_context=schema_context
        )

        try:
            config = types.GenerateContentConfig(
                temperature=1,
            )

            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            )

            sql = response.text.replace("```sql", "").replace("```", "").strip()
            return sql

        except Exception as e:
            print(f"Error correcting SQL: {e}")
            return incorrect_sql

    async def generate_layer_name(self, user_query: str, sql_query: str) -> dict:
        """Generate a short descriptive name, emoji, and point style for a map layer.

        Returns: {"name": str, "emoji": str, "pointStyle": str | None}
        """
        if not self.client:
            return {"name": "New Layer", "emoji": "π", "pointStyle": None}

        prompt = LAYER_NAME_PROMPT.format(user_query=user_query, sql_query=sql_query)

        try:
            config = types.GenerateContentConfig(
                temperature=1,
                response_mime_type="application/json"
            )

            response = await asyncio.to_thread(
                self.client.models.generate_content,
                model=self.model,
                contents=prompt,
                config=config,
            )

            result = json.loads(response.text)
            return {
                "name": result.get("name", "Map Layer"),
                "emoji": result.get("emoji", "π"),
                "pointStyle": result.get("pointStyle", None)
            }
        except Exception as e:
            print(f"Error generating layer name: {e}")
            return {"name": "Map Layer", "emoji": "π", "pointStyle": None}
|