"""
Feedback Management Module
This module provides a unified interface for handling user feedback,
including data preparation, validation, and Snowflake storage.
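
Example (illustrative sketch; the exact fields accepted by create_feedback_from_dict
are defined in feedback_schema and are only assumed here):

    manager = FeedbackManager()
    transcript = manager.extract_transcript(messages)  # list of HumanMessage/AIMessage
    retrievals = manager.build_retrievals_structure(rag_retrieval_history, messages)
    feedback = manager.create_feedback_from_dict({...})  # payload built from the pieces above
    manager.save_to_snowflake(feedback)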
"""
from typing import Dict, Any, List, Optional
from langchain_core.messages import HumanMessage, AIMessage
from .feedback_schema import UserFeedback, create_feedback_from_dict, generate_snowflake_schema_sql
from .snowflake_connector import SnowflakeFeedbackConnector, save_to_snowflake, get_snowflake_connector_from_env


class FeedbackManager:
    """
    Unified manager for feedback operations.

    This class provides a single interface for all feedback-related functionality,
    including data preparation, validation, and storage.
    """

    def __init__(self):
        """Initialize the FeedbackManager"""
        pass

    @staticmethod
    def extract_transcript(messages: List[Any]) -> List[Dict[str, str]]:
        """Extract transcript from messages - only user and bot messages, no extra metadata"""
        transcript = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                transcript.append({
                    "role": "user",
                    "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
                })
            elif isinstance(msg, AIMessage):
                transcript.append({
                    "role": "assistant",
                    "content": str(msg.content) if hasattr(msg, 'content') else str(msg)
                })
        return transcript

    @staticmethod
    def build_retrievals_structure(rag_retrieval_history: List[Dict[str, Any]], messages: List[Any]) -> List[Dict[str, Any]]:
        """Build retrievals structure from retrieval history"""
        retrievals = []
        for retrieval_idx, entry in enumerate(rag_retrieval_history):
            # Get the user message that triggered this retrieval.
            # The entry's conversation_up_to includes all messages up to that point.
            conversation_up_to = entry.get("conversation_up_to", [])
            # Find the last user message in conversation_up_to (this is the trigger)
            user_message_trigger = ""
            for msg_dict in reversed(conversation_up_to):
                if msg_dict.get("type") == "HumanMessage":
                    user_message_trigger = msg_dict.get("content", "")
                    break
            # Fallback: if not found in conversation_up_to, get it from the actual messages.
            # This handles edge cases where conversation_up_to is incomplete.
            if not user_message_trigger:
                # Each retrieval is assumed to follow one user turn, so the i-th
                # retrieval maps to the i-th user message.
                user_msgs = [msg for msg in messages if isinstance(msg, HumanMessage)]
                if retrieval_idx < len(user_msgs):
                    user_message_trigger = str(user_msgs[retrieval_idx].content)
                elif user_msgs:
                    # Fallback to the last user message
                    user_message_trigger = str(user_msgs[-1].content)
            # Get retrieved documents and truncate content to 100 chars
            docs_retrieved = entry.get("docs_retrieved", [])
            retrieved_docs = []
            for doc in docs_retrieved:
                doc_copy = doc.copy()
                # Truncate content to 100 characters (keep all other fields)
                if "content" in doc_copy:
                    doc_copy["content"] = doc_copy["content"][:100]
                retrieved_docs.append(doc_copy)
            retrievals.append({
                "retrieved_docs": retrieved_docs,
                "user_message_trigger": user_message_trigger
            })
        return retrievals

    @staticmethod
    def build_feedback_score_related_retrieval_docs(
        is_feedback_about_last_retrieval: bool,
        messages: List[Any],
        rag_retrieval_history: List[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Build feedback_score_related_retrieval_docs structure"""
        if not rag_retrieval_history:
            return None
        # Get the relevant retrieval entry
        if is_feedback_about_last_retrieval:
            relevant_entry = rag_retrieval_history[-1]
        else:
            # If feedback is about all retrievals, use the last one as the default
            relevant_entry = rag_retrieval_history[-1]
        # Get the conversation up to that point
        conversation_up_to = relevant_entry.get("conversation_up_to", [])
        # Convert to transcript format (role/content)
        conversation_up_to_point = []
        for msg_dict in conversation_up_to:
            if msg_dict.get("type") == "HumanMessage":
                conversation_up_to_point.append({
                    "role": "user",
                    "content": msg_dict.get("content", "")
                })
            elif msg_dict.get("type") == "AIMessage":
                conversation_up_to_point.append({
                    "role": "assistant",
                    "content": msg_dict.get("content", "")
                })
        # Get retrieved docs with full content (not truncated)
        retrieved_docs = relevant_entry.get("docs_retrieved", [])
        return {
            "conversation_up_to_point": conversation_up_to_point,
            "retrieved_docs": retrieved_docs
        }

    @staticmethod
    def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
        """Create UserFeedback instance from dictionary"""
        return create_feedback_from_dict(data)

    @staticmethod
    def save_to_snowflake(feedback: UserFeedback, table_name: Optional[str] = None) -> bool:
        """Save feedback to Snowflake"""
        return save_to_snowflake(feedback, table_name)

    @staticmethod
    def generate_snowflake_schema_sql(table_name: Optional[str] = None) -> str:
        """Generate Snowflake schema SQL"""
        return generate_snowflake_schema_sql(table_name)


__all__ = ["FeedbackManager", "UserFeedback", "save_to_snowflake", "SnowflakeFeedbackConnector"]