Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import configparser | |
| import logging | |
| import os | |
| import ast | |
| import json | |
| from dotenv import load_dotenv | |
| from typing import Optional, Dict, Any, List | |
| from models import GraphState | |
| load_dotenv() | |
| logger = logging.getLogger(__name__) | |
def getconfig(configfile_path: str) -> Optional[configparser.ConfigParser]:
    """
    Read a .cfg/.ini style configuration file.

    Args:
        configfile_path: File path of the .cfg file.

    Returns:
        The populated ``ConfigParser``, or ``None`` when the file cannot be
        opened or parsed (a warning is logged in that case).
    """
    log = logging.getLogger(__name__)
    config = configparser.ConfigParser()
    try:
        # Context manager guarantees the handle is closed even if parsing fails.
        with open(configfile_path) as config_file:
            config.read_file(config_file)
    except (OSError, TypeError):
        # Missing/unreadable path, or a non-path argument such as None.
        log.warning("config file not found")
        return None
    except (configparser.Error, UnicodeDecodeError) as err:
        # The file exists but is not a valid config; say so instead of
        # reporting it as missing.
        log.warning("config file could not be parsed: %s", err)
        return None
    return config
def get_auth(provider: str) -> dict:
    """
    Look up the API credentials for a supported provider.

    Args:
        provider: Provider name, matched case-insensitively
            ("huggingface" or "qdrant").

    Returns:
        A dict of the form ``{"api_key": <value>}``; the value is ``None``
        when the corresponding environment variable is unset or empty.

    Raises:
        ValueError: If the provider is not one of the supported names.
    """
    # Environment variable holding the key for each supported provider.
    env_var_for = {
        "huggingface": "HF_TOKEN",
        "qdrant": "QDRANT_API_KEY",
    }
    provider = provider.lower()
    env_var = env_var_for.get(provider)
    if env_var is None:
        raise ValueError(f"Unsupported provider: {provider}")

    token = os.getenv(env_var)
    if token:
        return {"api_key": token}

    # Missing or empty key: warn, and normalise the value to None.
    logging.warning(f"No API key found for provider '{provider}'")
    return {"api_key": None}
# Extension -> coarse file category used by the rest of the pipeline.
_FILE_TYPE_BY_EXTENSION = {
    '.geojson': 'geojson',
    '.json': 'json',
    '.pdf': 'text',
    '.docx': 'text',
    '.doc': 'text',
    '.txt': 'text',
    '.md': 'text',
    '.csv': 'text',
    '.xlsx': 'text',
    '.xls': 'text',
}

# Top-level "type" values that identify a JSON document as GeoJSON.
# Kept as a tuple (not a set) so an unhashable "type" value cannot raise.
_GEOJSON_TYPES = (
    'FeatureCollection', 'Feature', 'Point', 'LineString', 'Polygon',
    'MultiPoint', 'MultiLineString', 'MultiPolygon', 'GeometryCollection',
)


def detect_file_type(filename: str, file_content: Optional[bytes] = None) -> str:
    """
    Detect file type based on extension and, for JSON, on content.

    Args:
        filename: Name of the file; only its (case-insensitive) extension
            is used.
        file_content: Optional raw file content. When given for a ``.json``
            file, it is parsed to check whether the document is GeoJSON.
            Already-decoded ``str`` content is accepted as well.

    Returns:
        One of ``"geojson"``, ``"json"``, ``"text"`` or ``"unknown"``.
    """
    if not filename:
        return "unknown"

    _, ext = os.path.splitext(filename.lower())
    detected_type = _FILE_TYPE_BY_EXTENSION.get(ext, 'unknown')

    # For JSON files, check if it's actually GeoJSON.
    if detected_type == 'json' and file_content:
        try:
            if isinstance(file_content, (bytes, bytearray)):
                content_str = file_content.decode('utf-8')
            else:
                content_str = file_content
            data = json.loads(content_str)
        except (ValueError, TypeError, RecursionError):
            # Best-effort sniffing: undecodable or malformed content simply
            # stays classified as plain JSON.
            data = None
        if isinstance(data, dict) and data.get('type') in _GEOJSON_TYPES:
            detected_type = 'geojson'

    # getLogger(__name__) returns the same logger object as the module-level
    # `logger`.
    logging.getLogger(__name__).info(
        "Detected file type: %s for file: %s", detected_type, filename
    )
    return detected_type
def _wrap_plain_context(context: Any) -> List[Dict[str, Any]]:
    """Wrap raw context in the single-item list shape with placeholder metadata."""
    return [{
        "answer": context,
        "answer_metadata": {
            "filename": "Retrieved Context",
            "page": "Unknown",
            "year": "Unknown",
            "source": "Retriever",
        },
    }]


def convert_context_to_list(context: str) -> List[Dict[str, Any]]:
    """
    Convert string context to list format expected by generator.

    Args:
        context: Either the string form of a Python list literal (starting
            with ``[``) or free text.

    Returns:
        The parsed list when ``context`` is a valid list literal; otherwise
        a single-item list wrapping ``context`` with placeholder metadata.
    """
    if isinstance(context, str) and context.startswith('['):
        try:
            # literal_eval only evaluates Python literals; it never executes code.
            parsed = ast.literal_eval(context)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            parsed = None
        # "[1], [2]" parses to a tuple; only a real list satisfies the contract.
        if isinstance(parsed, list):
            return parsed
    return _wrap_plain_context(context)
def merge_state(base_state: GraphState, updates: dict) -> GraphState:
    """
    Combine a node's updates with the existing graph state.

    Returns a new mapping; neither argument is modified. Keys present in
    ``updates`` take precedence over those in ``base_state``.
    """
    merged = {**base_state}
    merged.update(updates)
    return merged
def _find_reply_index(messages, user_index: int) -> Optional[int]:
    """Return the index of the assistant reply to the user message at ``user_index``.

    The search stops at the next user message, so a reply that belongs to a
    later turn is never attributed to this one. Returns ``None`` if the user
    message was never answered.
    """
    for idx in range(user_index + 1, len(messages)):
        role = messages[idx].role
        if role == 'assistant':
            return idx
        if role == 'user':
            return None
    return None


def build_conversation_context(messages, max_turns: int = 3, max_chars: int = 8000) -> str:
    """
    Build conversation context from structured messages to send to generator.

    Always keeps the first user and assistant messages, plus the last N turns.
    A "turn" is one user message + following assistant response.

    Args:
        messages: Sequence of message objects exposing ``role`` and
            ``content`` attributes.
        max_turns: Maximum number of user-assistant exchange pairs to include
            (from the end). Zero or a negative value keeps only the first
            exchange.
        max_chars: Maximum total characters of message text in the context.
            The blank-line separators between messages are not counted.

    Returns:
        The selected messages rendered as ``"USER: ..."`` / ``"ASSISTANT: ..."``
        blocks joined by blank lines, or ``""`` when there are no messages.
    """
    if not messages:
        return ""

    # getLogger(__name__) returns the same logger object as the module-level
    # `logger`.
    log = logging.getLogger(__name__)

    context_parts = []
    char_count = 0
    # Positions already emitted. Tracking positions (not message equality)
    # keeps repeated questions distinct and prevents emitting a message twice.
    included = set()

    user_indices = [i for i, msg in enumerate(messages) if msg.role == 'user']
    first_user_idx = user_indices[0] if user_indices else None
    first_assistant_idx = next(
        (i for i, msg in enumerate(messages) if msg.role == 'assistant'), None
    )

    # Always try to include the first user and assistant messages. Each one is
    # skipped on its own if it does not fit in the budget.
    for idx, label in ((first_user_idx, "USER"), (first_assistant_idx, "ASSISTANT")):
        if idx is None:
            continue
        msg_text = f"{label}: {messages[idx].content}"
        if char_count + len(msg_text) <= max_chars:
            context_parts.append(msg_text)
            char_count += len(msg_text)
            included.add(idx)

    # Last N user messages, excluding the first one handled above. The explicit
    # guard matters because a slice of [-0:] would select every message.
    recent_user_indices = user_indices[1:][-max_turns:] if max_turns > 0 else []

    for user_idx in recent_user_indices:
        user_text = f"USER: {messages[user_idx].content}"
        if char_count + len(user_text) > max_chars:
            log.info("Stopping context build: would exceed max_chars (%s)", max_chars)
            break
        context_parts.append(user_text)
        char_count += len(user_text)
        included.add(user_idx)

        reply_idx = _find_reply_index(messages, user_idx)
        if reply_idx is None or reply_idx in included:
            continue
        assistant_text = f"ASSISTANT: {messages[reply_idx].content}"
        if char_count + len(assistant_text) > max_chars:
            log.info("Stopping context build: would exceed max_chars (%s)", max_chars)
            break
        context_parts.append(assistant_text)
        char_count += len(assistant_text)
        included.add(reply_idx)

    return "\n\n".join(context_parts)