|
|
|
|
|
""" |
|
|
Context Relevance Classification Module |
|
|
Uses LLM inference to identify relevant session contexts and generate dynamic summaries |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import asyncio |
|
|
from typing import Dict, List
|
|
from datetime import datetime |
|
|
|
|
|
logger = logging.getLogger(__name__) |
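
# Note: the llm_router collaborator is only assumed (inferred from the call
# sites below) to expose an async method
#   route_inference(task_type: str, prompt: str, max_tokens: int,
#                   temperature: float) -> str
# No other interface is relied upon.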
|
|
|
|
|
|
|
|
class ContextRelevanceClassifier: |
|
|
""" |
|
|
Classify which session contexts are relevant to current conversation |
|
|
and generate 2-line summaries for each relevant session |
|
|
|
|
|
Performance Priority: |
|
|
- LLM inference first (accuracy over speed) |
|
|
- Parallel processing for multiple sessions |
|
|
- Caching for repeated queries |
|
|
- Graceful degradation on failures |
|
|
""" |
|
|
|
|
|
def __init__(self, llm_router): |
|
|
""" |
|
|
Initialize classifier with LLM router |
|
|
|
|
|
Args: |
|
|
llm_router: LLMRouter instance for inference calls |
|
|
""" |
|
|
        self.llm_router = llm_router
        # In-memory, process-local caches; entries expire after the TTL below.
        self._relevance_cache = {}
        self._summary_cache = {}
        self._cache_ttl = 3600  # seconds (1 hour)
|
|
|
|
|
async def classify_and_summarize_relevant_contexts(self, |
|
|
current_input: str, |
|
|
session_contexts: List[Dict], |
|
|
user_id: str = "Test_Any") -> Dict: |
|
|
""" |
|
|
Main method: Classify relevant contexts AND generate 2-line summaries |
|
|
|
|
|
Performance Strategy: |
|
|
1. Extract current topic (LLM inference - single call) |
|
|
2. Calculate relevance in parallel (multiple LLM calls in parallel) |
|
|
3. Generate summaries in parallel (only for relevant sessions) |
|
|
|
|
|
Args: |
|
|
current_input: Current user query |
|
|
session_contexts: List of session context dictionaries |
|
|
user_id: User identifier for logging |
|
|
|
|
|
Returns: |
|
|
{ |
|
|
'relevant_summaries': List[str], # 2-line summaries |
|
|
'combined_user_context': str, # Combined summaries |
|
|
'relevance_scores': Dict, # Scores for each session |
|
|
'classification_confidence': float, |
|
|
'topic': str, |
|
|
'processing_time': float |
|
|
} |
|
|
""" |
|
|
start_time = datetime.now() |
|
|
|
|
|
try: |
|
|
|
|
|
            # Fast path: nothing to classify.
            if not session_contexts:
                logger.info("No session contexts provided for classification")
|
|
return { |
|
|
'relevant_summaries': [], |
|
|
'combined_user_context': '', |
|
|
'relevance_scores': {}, |
|
|
'classification_confidence': 1.0, |
|
|
'topic': '', |
|
|
'processing_time': 0.0 |
|
|
} |
|
|
|
|
|
|
|
|
            # Step 1: extract the current topic (single cached LLM call)
            current_topic = await self._extract_current_topic(current_input)
            logger.info(f"Extracted current topic: '{current_topic}'")
|
|
|
|
|
|
|
|
            # Step 2: score each session's relevance concurrently
            relevance_tasks = []
|
|
for session_ctx in session_contexts: |
|
|
task = self._calculate_relevance_with_cache( |
|
|
current_topic, |
|
|
current_input, |
|
|
session_ctx |
|
|
) |
|
|
relevance_tasks.append((session_ctx, task)) |
|
|
|
|
|
|
|
|
            # Await all relevance checks; exceptions are returned, not raised
            relevance_results = await asyncio.gather(
                *[task for _, task in relevance_tasks],
                return_exceptions=True
            )
|
|
|
|
|
|
|
|
            # Collect scores and keep sessions above the relevance threshold
            relevant_sessions = []
            relevance_scores = {}
|
|
|
|
|
for (session_ctx, _), result in zip(relevance_tasks, relevance_results): |
|
|
if isinstance(result, Exception): |
|
|
logger.error(f"Error calculating relevance: {result}") |
|
|
continue |
|
|
|
|
|
session_id = session_ctx.get('session_id', 'unknown') |
|
|
score = result.get('score', 0.0) |
|
|
relevance_scores[session_id] = score |
|
|
|
|
|
                if score >= 0.6:  # relevance threshold
|
|
relevant_sessions.append({ |
|
|
'session_id': session_id, |
|
|
'summary': session_ctx.get('summary', ''), |
|
|
'relevance_score': score, |
|
|
'interaction_contexts': session_ctx.get('interaction_contexts', []), |
|
|
'created_at': session_ctx.get('created_at', '') |
|
|
}) |
|
|
|
|
|
logger.info(f"Found {len(relevant_sessions)} relevant sessions out of {len(session_contexts)}") |
|
|
|
|
|
|
|
|
            # Step 3: generate 2-line summaries for relevant sessions in parallel
            summary_tasks = []
|
|
for relevant_session in relevant_sessions: |
|
|
task = self._generate_session_summary( |
|
|
relevant_session, |
|
|
current_input, |
|
|
current_topic |
|
|
) |
|
|
summary_tasks.append(task) |
|
|
|
|
|
|
|
|
summary_results = await asyncio.gather(*summary_tasks, return_exceptions=True) |
|
|
|
|
|
|
|
|
            # Keep successful summaries; log failures without aborting
            valid_summaries = []
|
|
for summary in summary_results: |
|
|
if isinstance(summary, str) and summary.strip(): |
|
|
valid_summaries.append(summary.strip()) |
|
|
elif isinstance(summary, Exception): |
|
|
logger.error(f"Error generating summary: {summary}") |
|
|
|
|
|
|
|
|
combined_user_context = self._combine_summaries(valid_summaries, current_topic) |
|
|
|
|
|
processing_time = (datetime.now() - start_time).total_seconds() |
|
|
|
|
|
logger.info( |
|
|
f"Relevance classification complete: {len(valid_summaries)} summaries, " |
|
|
f"topic '{current_topic}', time: {processing_time:.2f}s" |
|
|
) |
|
|
|
|
|
return { |
|
|
'relevant_summaries': valid_summaries, |
|
|
'combined_user_context': combined_user_context, |
|
|
'relevance_scores': relevance_scores, |
|
|
                'classification_confidence': 0.8,  # fixed heuristic, not model-derived
|
|
'topic': current_topic, |
|
|
'processing_time': processing_time |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error in relevance classification: {e}", exc_info=True) |
|
|
processing_time = (datetime.now() - start_time).total_seconds() |
|
|
|
|
|
|
|
|
            # Graceful degradation: return an empty result carrying the error
            return {
|
|
'relevant_summaries': [], |
|
|
'combined_user_context': '', |
|
|
'relevance_scores': {}, |
|
|
'classification_confidence': 0.0, |
|
|
'topic': '', |
|
|
'processing_time': processing_time, |
|
|
'error': str(e) |
|
|
} |
|
|
|
|
|
async def _extract_current_topic(self, user_input: str) -> str: |
|
|
""" |
|
|
Extract main topic from current input using LLM inference |
|
|
|
|
|
Performance: Single LLM call with caching |
|
|
""" |
|
|
try: |
|
|
|
|
|
cache_key = f"topic_{hash(user_input[:200])}" |
|
|
if cache_key in self._relevance_cache: |
|
|
cached = self._relevance_cache[cache_key] |
|
|
if cached.get('timestamp', 0) + self._cache_ttl > datetime.now().timestamp(): |
|
|
return cached['value'] |
|
|
|
|
|
            # Fallback without an LLM: use the first few words as the topic
            if not self.llm_router:
                words = user_input.split()[:5]
                return ' '.join(words) if words else 'general query'
|
|
|
|
|
prompt = f"""Extract the main topic (2-5 words) from this query: |
|
|
|
|
|
Query: "{user_input}" |
|
|
|
|
|
Respond with ONLY the topic name. Maximum 5 words.""" |
|
|
|
|
|
result = await self.llm_router.route_inference( |
|
|
task_type="classification", |
|
|
prompt=prompt, |
|
|
max_tokens=20, |
|
|
temperature=0.2 |
|
|
) |
|
|
|
|
|
topic = result.strip() if result else user_input[:100] |
|
|
|
|
|
|
|
|
            # Cache the extracted topic for reuse within the TTL
            self._relevance_cache[cache_key] = {
|
|
'value': topic, |
|
|
'timestamp': datetime.now().timestamp() |
|
|
} |
|
|
|
|
|
return topic |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error extracting topic: {e}", exc_info=True) |
|
|
|
|
|
            return user_input[:100]  # degrade to the raw input prefix
|
|
|
|
|
async def _calculate_relevance_with_cache(self, |
|
|
current_topic: str, |
|
|
current_input: str, |
|
|
session_ctx: Dict) -> Dict: |
|
|
""" |
|
|
Calculate relevance score with caching to reduce LLM calls |
|
|
|
|
|
Returns: {'score': float, 'cached': bool} |
|
|
""" |
|
|
try: |
|
|
session_id = session_ctx.get('session_id', 'unknown') |
|
|
session_summary = session_ctx.get('summary', '') |
|
|
|
|
|
|
|
|
cache_key = f"rel_{session_id}_{hash(current_input[:100] + current_topic)}" |
|
|
if cache_key in self._relevance_cache: |
|
|
cached = self._relevance_cache[cache_key] |
|
|
if cached.get('timestamp', 0) + self._cache_ttl > datetime.now().timestamp(): |
|
|
return {'score': cached['value'], 'cached': True} |
|
|
|
|
|
|
|
|
score = await self._calculate_relevance( |
|
|
current_topic, |
|
|
current_input, |
|
|
session_summary |
|
|
) |
|
|
|
|
|
|
|
|
self._relevance_cache[cache_key] = { |
|
|
'value': score, |
|
|
'timestamp': datetime.now().timestamp() |
|
|
} |
|
|
|
|
|
return {'score': score, 'cached': False} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error in cached relevance calculation: {e}", exc_info=True) |
|
|
            return {'score': 0.5, 'cached': False}  # neutral score on failure
|
|
|
|
|
async def _calculate_relevance(self, |
|
|
current_topic: str, |
|
|
current_input: str, |
|
|
context_text: str) -> float: |
|
|
""" |
|
|
Calculate relevance score (0.0 to 1.0) using LLM inference |
|
|
|
|
|
Performance: Single LLM call per session context |
|
|
""" |
|
|
try: |
|
|
if not context_text: |
|
|
return 0.0 |
|
|
|
|
|
if not self.llm_router: |
|
|
|
|
|
return self._simple_keyword_relevance(current_input, context_text) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prompt = f"""Rate the relevance (0.0 to 1.0) of this session context to the current conversation. |
|
|
|
|
|
Current Topic: {current_topic} |
|
|
Current Query: "{current_input[:200]}" |
|
|
|
|
|
Session Context: |
|
|
"{context_text[:500]}" |
|
|
|
|
|
Consider: |
|
|
- Topic similarity (0.0-1.0) |
|
|
- Discussion depth alignment |
|
|
- Information continuity |
|
|
|
|
|
Respond with ONLY a number between 0.0 and 1.0 (e.g., 0.75).""" |
|
|
|
|
|
result = await self.llm_router.route_inference( |
|
|
task_type="general_reasoning", |
|
|
prompt=prompt, |
|
|
max_tokens=10, |
|
|
temperature=0.1 |
|
|
) |
|
|
|
|
|
            if result:
                try:
                    score = float(result.strip())
                    return max(0.0, min(1.0, score))  # clamp to [0.0, 1.0]
                except ValueError:
                    logger.warning(f"Could not parse relevance score: {result}")

            # No LLM result or unparseable output: fall back to keyword overlap
            return self._simple_keyword_relevance(current_input, context_text)
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error calculating relevance: {e}", exc_info=True) |
|
|
return 0.5 |
|
|
|
|
|
def _simple_keyword_relevance(self, current_input: str, context_text: str) -> float: |
|
|
"""Fallback keyword-based relevance calculation""" |
|
|
try: |
|
|
current_lower = current_input.lower() |
|
|
context_lower = context_text.lower() |
|
|
|
|
|
current_words = set(current_lower.split()) |
|
|
context_words = set(context_lower.split()) |
|
|
|
|
|
|
|
|
stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'} |
|
|
current_words = current_words - stop_words |
|
|
context_words = context_words - stop_words |
|
|
|
|
|
            if not current_words:
                return 0.5  # no signal; return a neutral score
|
|
|
|
|
|
|
|
            # Jaccard similarity over the remaining tokens
            intersection = len(current_words & context_words)
            union = len(current_words | context_words)

            return (intersection / union) if union > 0 else 0.0
|
|
|
|
|
except Exception: |
|
|
return 0.5 |
|
|
|
|
|
async def _generate_session_summary(self, |
|
|
session_data: Dict, |
|
|
current_input: str, |
|
|
current_topic: str) -> str: |
|
|
""" |
|
|
Generate 2-line summary for a relevant session context |
|
|
|
|
|
Performance: LLM inference with caching and timeout protection |
|
|
Builds depth and width of topic discussion |
|
|
""" |
|
|
try: |
|
|
session_id = session_data.get('session_id', 'unknown') |
|
|
session_summary = session_data.get('summary', '') |
|
|
interaction_contexts = session_data.get('interaction_contexts', []) |
|
|
|
|
|
|
|
|
cache_key = f"summary_{session_id}_{hash(current_topic)}" |
|
|
if cache_key in self._summary_cache: |
|
|
cached = self._summary_cache[cache_key] |
|
|
if cached.get('timestamp', 0) + self._cache_ttl > datetime.now().timestamp(): |
|
|
return cached['value'] |
|
|
|
|
|
|
|
|
if not session_summary and not interaction_contexts: |
|
|
logger.warning(f"No content for summarization: session {session_id}") |
|
|
return f"Previous discussion on {current_topic}.\nContext details unavailable." |
|
|
|
|
|
|
|
|
            # Build a bounded context block: session summary plus recent interactions
            session_context_text = session_summary[:500] if session_summary else ""
|
|
|
|
|
if interaction_contexts: |
|
|
recent_interactions = "\n".join([ |
|
|
ic.get('summary', '')[:100] |
|
|
for ic in interaction_contexts[-5:] |
|
|
if ic.get('summary') |
|
|
]) |
|
|
if recent_interactions: |
|
|
session_context_text = f"{session_context_text}\n\nRecent interactions:\n{recent_interactions[:400]}" |
|
|
|
|
|
|
|
|
            # Hard cap on the context passed to the model
            if len(session_context_text) > 1000:
                session_context_text = session_context_text[:1000] + "..."
|
|
|
|
|
            # Fallback without an LLM: stitch a 2-line summary from stored text
            if not self.llm_router:
                return f"Previous {current_topic} discussion.\nCovered: {session_summary[:80]}..."
|
|
|
|
|
|
|
|
prompt = f"""Generate a precise 2-line summary (maximum 2 sentences, ~100 tokens total) that captures the depth and breadth of the topic discussion: |
|
|
|
|
|
Current Topic: {current_topic} |
|
|
Current Query: "{current_input[:150]}" |
|
|
|
|
|
Previous Session Context: |
|
|
{session_context_text} |
|
|
|
|
|
Requirements: |
|
|
- Line 1: Summarize the MAIN TOPICS/SUBJECTS discussed (breadth/width) |
|
|
- Line 2: Summarize the DEPTH/LEVEL of discussion (technical depth, detail level, approach) |
|
|
- Focus on relevance to: "{current_topic}" |
|
|
- Keep total under 100 tokens |
|
|
- Be specific about what was covered |
|
|
|
|
|
Respond with ONLY the 2-line summary, no explanations.""" |
|
|
|
|
|
            # Bound the inference call so a slow model cannot stall the pipeline
            try:
|
|
result = await asyncio.wait_for( |
|
|
self.llm_router.route_inference( |
|
|
task_type="general_reasoning", |
|
|
prompt=prompt, |
|
|
max_tokens=100, |
|
|
temperature=0.4 |
|
|
), |
|
|
timeout=10.0 |
|
|
) |
|
|
except asyncio.TimeoutError: |
|
|
logger.warning(f"Summary generation timeout for session {session_id}") |
|
|
return f"Previous {current_topic} discussion.\nDepth and approach covered in prior session." |
|
|
|
|
|
|
|
|
            if result and isinstance(result, str) and result.strip():
                summary = result.strip()
                # Normalize the model output to at most two non-empty lines
                # (splitting a non-empty string always yields at least one line).
                lines = [line.strip() for line in summary.split('\n') if line.strip()]

                if len(lines) > 2:
                    # Fold any overflow lines into the second line
                    combined = f"{lines[0]}\n{'. '.join(lines[1:])}"
                    formatted_summary = combined[:200]
                else:
                    formatted_summary = '\n'.join(lines[:2])[:200]

                # Guard against degenerate (too short) model output
                if len(formatted_summary) < 20:
                    formatted_summary = f"Previous {current_topic} discussion.\nDetails from previous session."

                self._summary_cache[cache_key] = {
                    'value': formatted_summary,
                    'timestamp': datetime.now().timestamp()
                }

                return formatted_summary

            logger.warning(f"Invalid summary result for session {session_id}")
            return f"Previous {current_topic} discussion.\nDepth and approach covered previously."
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error generating session summary: {e}", exc_info=True) |
|
|
session_summary = session_data.get('summary', '')[:100] if session_data.get('summary') else 'topic discussion' |
|
|
return f"{session_summary}...\n{current_topic} discussion from previous session." |
|
|
|
|
|
def _combine_summaries(self, summaries: List[str], current_topic: str) -> str: |
|
|
""" |
|
|
Combine multiple 2-line summaries into coherent user context |
|
|
|
|
|
Builds width (multiple topics) and depth (summarized discussions) |
|
|
""" |
|
|
try: |
|
|
if not summaries: |
|
|
return '' |
|
|
|
|
|
if len(summaries) == 1: |
|
|
return summaries[0] |
|
|
|
|
|
|
|
|
combined = f"Relevant Previous Discussions (Topic: {current_topic}):\n\n" |
|
|
|
|
|
for idx, summary in enumerate(summaries, 1): |
|
|
combined += f"[Session {idx}]\n{summary}\n\n" |
|
|
|
|
|
|
|
|
combined += f"These sessions provide context for {current_topic} discussions, covering multiple aspects and depth levels." |
|
|
|
|
|
return combined |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error combining summaries: {e}", exc_info=True) |
|
|
|
|
|
            # Degrade to a plain concatenation of at most five summaries
            return '\n\n'.join(summaries[:5])
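

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The stub router below is
# hypothetical: it implements the async route_inference(...) interface the
# classifier calls above and returns canned strings so the full flow can be
# exercised without a real model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _StubRouter:
        """Hypothetical stand-in for LLMRouter; returns canned responses."""

        async def route_inference(self, task_type: str, prompt: str,
                                  max_tokens: int, temperature: float) -> str:
            if task_type == "classification":
                return "kubernetes networking"  # topic extraction call
            if prompt.startswith("Rate the relevance"):
                return "0.8"  # relevance scoring call
            # Otherwise assume a summary request: return a 2-line summary.
            return ("Discussed pod networking, CNI plugins, and service DNS.\n"
                    "Deep, hands-on troubleshooting with kubectl examples.")

    async def _demo():
        classifier = ContextRelevanceClassifier(_StubRouter())
        contexts = [{
            'session_id': 's1',
            'summary': 'Explored Kubernetes networking and service meshes.',
            'interaction_contexts': [],
            'created_at': '2024-01-01T00:00:00',
        }]
        result = await classifier.classify_and_summarize_relevant_contexts(
            "How do I debug DNS resolution inside a Kubernetes cluster?",
            contexts,
        )
        print(result['topic'])
        print(result['relevance_scores'])
        print(result['combined_user_context'])

    asyncio.run(_demo())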
|
|
|
|
|
|