# convo history truncation finalized (commit c8f5440)
import configparser
import logging
import os
import ast
import json
from dotenv import load_dotenv
from typing import Optional, Dict, Any, List
from models import GraphState
load_dotenv()
logger = logging.getLogger(__name__)
def getconfig(configfile_path: str):
    """
    Read an INI-style config file.

    Args:
        configfile_path: File path of the .cfg file.

    Returns:
        A populated ConfigParser on success, or None when the file is
        missing/unreadable/malformed (a warning is logged instead of raising,
        preserving the original best-effort contract).
    """
    config = configparser.ConfigParser()
    try:
        # `with` guarantees the handle is closed (the original leaked it).
        with open(configfile_path) as handle:
            config.read_file(handle)
        return config
    except (OSError, configparser.Error):
        # Narrowed from a bare `except:`; still warn-and-return-None.
        logging.getLogger(__name__).warning("config file not found")
        return None
def get_auth(provider: str) -> dict:
    """
    Get the authentication configuration for a provider.

    Args:
        provider: Provider name, case-insensitive ("huggingface" or "qdrant").

    Returns:
        Dict with an "api_key" entry read from the environment at call time
        (normalized to None when the variable is unset or empty; a warning
        is logged in that case).

    Raises:
        ValueError: If the provider is not supported.
    """
    auth_configs = {
        "huggingface": {"api_key": os.getenv("HF_TOKEN")},
        "qdrant": {"api_key": os.getenv("QDRANT_API_KEY")},
    }
    provider = provider.lower()
    if provider not in auth_configs:
        raise ValueError(f"Unsupported provider: {provider}")
    auth_config = auth_configs[provider]
    if not auth_config.get("api_key"):
        # Module logger instead of the root logger the original wrote to.
        logging.getLogger(__name__).warning(f"No API key found for provider '{provider}'")
        # Normalize falsy values (e.g. empty string) to None for callers.
        auth_config["api_key"] = None
    return auth_config
def detect_file_type(filename: str, file_content: Optional[bytes] = None) -> str:
    """
    Classify a file as 'geojson', 'json', 'text', or 'unknown'.

    Detection is primarily extension-based; for ``.json`` files the raw
    content (when provided) is parsed to check whether it is actually GeoJSON.

    Args:
        filename: Name of the file; empty/None yields "unknown".
        file_content: Optional raw bytes used to disambiguate JSON vs GeoJSON.

    Returns:
        One of 'geojson', 'json', 'text', or 'unknown'.
    """
    if not filename:
        return "unknown"
    _, ext = os.path.splitext(filename.lower())
    file_type_mappings = {
        '.geojson': 'geojson',
        '.json': 'json',
        '.pdf': 'text',
        '.docx': 'text',
        '.doc': 'text',
        '.txt': 'text',
        '.md': 'text',
        '.csv': 'text',
        '.xlsx': 'text',
        '.xls': 'text'
    }
    detected_type = file_type_mappings.get(ext, 'unknown')
    # For JSON files, check if it's actually GeoJSON (type values per RFC 7946).
    if detected_type == 'json' and file_content:
        geojson_types = (
            'FeatureCollection', 'Feature', 'Point', 'LineString', 'Polygon',
            'MultiPoint', 'MultiLineString', 'MultiPolygon', 'GeometryCollection'
        )
        try:
            data = json.loads(file_content.decode('utf-8'))
            if isinstance(data, dict) and data.get('type') in geojson_types:
                detected_type = 'geojson'
        except (UnicodeDecodeError, ValueError):
            # Undecodable or malformed JSON: keep the extension-based type.
            pass
    # Original log printed the literal "(unknown)" where the filename belongs.
    logging.getLogger(__name__).info(f"Detected file type: {detected_type} for file: {filename}")
    return detected_type
def convert_context_to_list(context: str) -> List[Dict[str, Any]]:
    """
    Convert string context into the list-of-dicts format the generator expects.

    A string that starts with '[' is parsed as a Python list literal via
    ``ast.literal_eval``; anything else — including unparseable input — is
    wrapped in a single-entry fallback record with placeholder metadata.

    Args:
        context: Raw context string from the retriever.

    Returns:
        A list of dicts, each with "answer" and "answer_metadata" keys
        (when parsed, the literal's own structure is returned as-is).
    """
    def _wrap(text) -> List[Dict[str, Any]]:
        # Single-entry fallback; previously this dict was duplicated verbatim.
        return [{
            "answer": text,
            "answer_metadata": {
                "filename": "Retrieved Context",
                "page": "Unknown",
                "year": "Unknown",
                "source": "Retriever"
            }
        }]

    try:
        if context.startswith('['):
            return ast.literal_eval(context)
    except (AttributeError, ValueError, SyntaxError, MemoryError, RecursionError):
        # Narrowed from a bare `except:` to what startswith/literal_eval raise.
        pass
    return _wrap(context)
def merge_state(base_state: GraphState, updates: dict) -> GraphState:
    """Return a new state mapping with ``updates`` merged over ``base_state``.

    Keys present in ``updates`` win; neither input is mutated.
    """
    merged = dict(base_state)
    merged.update(updates)
    return merged
def build_conversation_context(messages, max_turns: int = 3, max_chars: int = 8000) -> str:
    """
    Build conversation context from structured messages to send to generator.

    Always keeps the first user and first assistant messages (the anchor
    exchange), plus the last ``max_turns`` turns. A "turn" is one user
    message + the assistant response that follows it.

    Args:
        messages: List of message objects exposing ``.role`` and ``.content``.
        max_turns: Maximum number of recent user-assistant pairs to include.
        max_chars: Maximum total characters across all included messages.

    Returns:
        Messages formatted as "ROLE: content" blocks joined by blank lines,
        or "" when there are no messages.
    """
    if not messages:
        return ""

    log = logging.getLogger(__name__)
    anchor_parts = []
    used_chars = 0

    def _try_add(text, dest):
        # Append `text` to `dest` iff it fits in the remaining char budget.
        nonlocal used_chars
        if used_chars + len(text) <= max_chars:
            dest.append(text)
            used_chars += len(text)
            return True
        return False

    # Anchor: the first user and first assistant messages, wherever they occur.
    first_user = next((m for m in messages if m.role == 'user'), None)
    first_assistant = next((m for m in messages if m.role == 'assistant'), None)
    if first_user is not None:
        _try_add(f"USER: {first_user.content}", anchor_parts)
    if first_assistant is not None:
        _try_add(f"ASSISTANT: {first_assistant.content}", anchor_parts)

    # Last N user turns, excluding the anchor user message. Tracking positions
    # with enumerate (instead of messages.index(), which resolves by equality)
    # guarantees each user message pairs with ITS following assistant reply
    # even when two messages compare equal, and avoids O(n^2) scans.
    user_positions = [i for i, m in enumerate(messages) if m.role == 'user']
    recent_positions = user_positions[1:][-max_turns:] if len(user_positions) > 1 else []

    recent_parts = []
    for pos in recent_positions:
        # The assistant reply is the next assistant message after this user msg.
        reply = next((m for m in messages[pos + 1:] if m.role == 'assistant'), None)

        if not _try_add(f"USER: {messages[pos].content}", recent_parts):
            log.info(f"Stopping context build: would exceed max_chars ({max_chars})")
            break
        if reply is not None:
            if not _try_add(f"ASSISTANT: {reply.content}", recent_parts):
                log.info(f"Stopping context build: would exceed max_chars ({max_chars})")
                break

    return "\n\n".join(anchor_parts + recent_parts)