# Standard Library Imports
import json
from typing import Callable, Dict, List, Tuple

# Third-Party Library Imports
from sqlalchemy.ext.asyncio import AsyncSession

# Local Application Imports
from src.common import (
    ComparisonType,
    LeaderboardEntry,
    OptionKey,
    OptionMap,
    TTSProviderName,
    VotingResults,
    constants,
    logger,
)
from src.database import (
    AsyncDBSessionMaker,
    create_vote,
    get_head_to_head_battle_stats,
    get_head_to_head_win_rate_stats,
    get_leaderboard_stats,
)


class VotingService:
    """
    Service for handling all database interactions related to voting and leaderboards.

    Encapsulates logic for submitting votes and retrieving formatted leaderboard statistics.
    """

    def __init__(self, db_session_maker: AsyncDBSessionMaker):
        """
        Initializes the VotingService.

        Args:
            db_session_maker: An asynchronous database session factory.
        """
        self.db_session_maker: AsyncDBSessionMaker = db_session_maker
        logger.debug("VotingService initialized.")

    async def _create_db_session(self) -> AsyncSession | None:
        """
        Creates a new database session, returning None if it's a dummy session.

        A session counts as a dummy when it has a truthy ``is_dummy`` attribute. Dummy
        sessions are closed right away (when they expose ``close``) and never returned.

        Returns:
            An active AsyncSession, or None if the factory produced a dummy session.
        """
        session = self.db_session_maker()

        if getattr(session, "is_dummy", False):
            logger.debug("Using dummy DB session; operations will be skipped.")
            # Release anything the dummy session may hold before discarding it.
            if hasattr(session, "close"):
                await session.close()
            return None

        logger.debug("Created new DB session.")
        return session

    def _determine_comparison_type(self, provider_a: TTSProviderName, provider_b: TTSProviderName) -> ComparisonType:
        """
        Determine the comparison type based on the given TTS provider names.

        The result does not depend on argument order: (A, B) and (B, A) map to the same type.

        Args:
            provider_a (TTSProviderName): The first TTS provider.
            provider_b (TTSProviderName): The second TTS provider.

        Returns:
            ComparisonType: The determined comparison type.

        Raises:
            ValueError: If the combination of providers is not recognized.
        """
        if provider_a == constants.HUME_AI and provider_b == constants.HUME_AI:
            return constants.HUME_TO_HUME

        providers = (provider_a, provider_b)

        if constants.HUME_AI in providers and constants.ELEVENLABS in providers:
            return constants.HUME_TO_ELEVENLABS

        if constants.HUME_AI in providers and constants.OPENAI in providers:
            return constants.HUME_TO_OPENAI

        if constants.ELEVENLABS in providers and constants.OPENAI in providers:
            return constants.OPENAI_TO_ELEVENLABS

        raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}")

    async def _persist_vote(self, voting_results: VotingResults) -> None:
        """
        Persists a vote record in the database using a dedicated session.

        The vote details are logged before any database work, so they are recorded even when
        persistence is skipped (dummy session) or fails. The write is delegated to
        ``create_vote``; this method performs no explicit commit or rollback
        (NOTE(review): presumably ``create_vote`` commits -- confirm). Errors raised by
        ``create_vote`` are logged rather than propagated, and the session is always closed.

        Args:
            voting_results: A dictionary containing the vote details.
        """
        # Log first: a failure creating or using the session must not lose the vote details.
        self._log_voting_results(voting_results)

        session = await self._create_db_session()
        if session is None:
            logger.info("Skipping vote persistence (dummy session).")
            return

        try:
            await create_vote(session, voting_results)
            logger.info("Vote successfully persisted.")
        except Exception as e:
            logger.error(f"Failed to persist vote record: {e}", exc_info=True)
        finally:
            await session.close()
            logger.debug("DB session closed after persisting vote.")

    def _log_voting_results(self, voting_results: VotingResults) -> None:
        """
        Logs the full voting results dictionary as indented JSON.

        Values that JSON cannot encode are stringified via ``default=str``. If serialization
        still fails (unsupported key types raise TypeError, circular references raise
        ValueError), the raw dict is logged instead and the error is not propagated.
        """
        try:
            logger.info("Voting results:\n%s", json.dumps(voting_results, indent=4, default=str))
        except (TypeError, ValueError):
            logger.error("Could not serialize voting results for logging.")
            logger.info(f"Voting results (raw): {voting_results}")

    # --- HTML cell helpers -------------------------------------------------------------
    # NOTE(review): confirm the cell markup/styling below matches the leaderboard table CSS.

    @staticmethod
    def _centered_cell(value: object) -> str:
        """Wraps a value in a centered paragraph for an HTML table cell."""
        return f'<p style="text-align: center;">{value}</p>'

    @staticmethod
    def _link_cell(href: str, label: object) -> str:
        """Renders ``label`` as a link to ``href`` that opens in a new tab."""
        # noopener/noreferrer stops the opened page from reaching back via window.opener.
        return f'<a href="{href}" target="_blank" rel="noopener noreferrer">{label}</a>'

    @staticmethod
    def _row_header_cell(label: object) -> str:
        """Renders a bold row-header cell for the provider matrices."""
        return f"<p><strong>{label}</strong></p>"

    def _format_leaderboard_data(self, leaderboard_data_raw: List[LeaderboardEntry]) -> List[List[str]]:
        """
        Formats raw leaderboard entries into HTML strings for the UI table.

        Each entry is unpacked as ``(rank, provider, model, win_rate, votes)``. Provider and
        model are rendered as links from ``constants.TTS_PROVIDER_LINKS``, falling back to
        ``"#"`` when the provider (or one of its link keys) is missing.

        Args:
            leaderboard_data_raw: Raw rows as returned by ``get_leaderboard_stats``.

        Returns:
            One list of five HTML cell strings per entry, in input order.
        """
        formatted_data: List[List[str]] = []
        for rank, provider, model, win_rate, votes in leaderboard_data_raw:
            provider_info = constants.TTS_PROVIDER_LINKS.get(provider, {})
            provider_link = provider_info.get("provider_link", "#")
            model_link = provider_info.get("model_link", "#")

            formatted_data.append(
                [
                    self._centered_cell(rank),
                    self._link_cell(provider_link, provider),
                    self._link_cell(model_link, model),
                    self._centered_cell(win_rate),
                    self._centered_cell(votes),
                ]
            )
        return formatted_data

    def _build_provider_matrix(
        self,
        cell_value: Callable[[TTSProviderName, TTSProviderName], str],
    ) -> List[List[str]]:
        """
        Builds a provider-by-provider HTML matrix over ``constants.TTS_PROVIDERS``.

        Each row begins with a header cell naming the row provider. Diagonal cells show
        ``"-"``; every other cell shows ``cell_value(row_provider, col_provider)``, which is
        only called for off-diagonal pairs.

        Args:
            cell_value: Callback producing the display value for an off-diagonal cell.

        Returns:
            The matrix as a list of rows of HTML cell strings.
        """
        providers = constants.TTS_PROVIDERS
        matrix: List[List[str]] = []
        for row_provider in providers:
            row = [self._row_header_cell(row_provider)]
            for col_provider in providers:
                value = "-" if row_provider == col_provider else cell_value(row_provider, col_provider)
                row.append(self._centered_cell(value))
            matrix.append(row)
        return matrix

    def _format_battle_counts_data(self, battle_counts_data_raw: List[List[str]]) -> List[List[str]]:
        """
        Formats raw battle counts into an HTML matrix for the UI.

        Args:
            battle_counts_data_raw: Rows whose first item is a comparison type and whose
                second item is the battle count for that pairing.

        Returns:
            The provider matrix. Pairings absent from the raw data show ``"0"``.

        Raises:
            ValueError: If ``constants.TTS_PROVIDERS`` contains a pairing that
                ``_determine_comparison_type`` does not recognize.
        """
        battle_counts = {item[0]: str(item[1]) for item in battle_counts_data_raw}

        # _determine_comparison_type is order-independent, so (A, B) and (B, A) show the same count.
        return self._build_provider_matrix(
            lambda row_provider, col_provider: battle_counts.get(
                self._determine_comparison_type(row_provider, col_provider), "0"
            )
        )

    def _format_win_rate_data(self, win_rate_data_raw: List[List[str]]) -> List[List[str]]:
        """
        Formats raw win rates into an HTML matrix for the UI.

        Args:
            win_rate_data_raw: Rows of ``[comparison_type, first_win_rate, second_win_rate]``,
                where ``comparison_type`` is expected to look like ``"ProviderA - ProviderB"``.
                Cell (ProviderA, ProviderB) shows ``first_win_rate``; cell
                (ProviderB, ProviderA) shows ``second_win_rate``. Rows whose comparison type
                does not split into exactly two providers are logged and skipped.

        Returns:
            The provider matrix. Pairs with no data show ``"0%"``.
        """
        win_rates: Dict[Tuple[str, str], str] = {}
        for comparison_type, first_win_rate, second_win_rate in win_rate_data_raw:
            try:
                provider1, provider2 = comparison_type.split(" - ")
            except ValueError:
                logger.warning(f"Could not parse comparison_type '{comparison_type}' in win rate data.")
                continue  # Skip malformed entry
            win_rates[(provider1, provider2)] = first_win_rate
            win_rates[(provider2, provider1)] = second_win_rate

        return self._build_provider_matrix(
            lambda row_provider, col_provider: win_rates.get((row_provider, col_provider), "0%")
        )

    async def get_formatted_leaderboard_data(self) -> Tuple[
        List[List[str]],
        List[List[str]],
        List[List[str]],
    ]:
        """
        Fetches raw leaderboard stats and formats them for UI display.

        Retrieves overall rankings, battle counts, and win rates, then formats them into HTML
        strings suitable for Gradio DataFrames.

        Returns:
            A tuple containing formatted lists of lists for:
            - Leaderboard rankings table
            - Battle counts matrix
            - Win rate matrix
            Returns empty structures ([[]], [[]], [[]]) when a dummy session is in use or
            when fetching/formatting fails; failures are logged, not raised.
        """
        session = await self._create_db_session()
        if session is None:
            logger.info("Skipping leaderboard fetch (dummy session).")
            return [[]], [[]], [[]]

        try:
            # Fetch raw data using underlying CRUD functions
            leaderboard_data_raw = await get_leaderboard_stats(session)
            battle_counts_data_raw = await get_head_to_head_battle_stats(session)
            win_rate_data_raw = await get_head_to_head_win_rate_stats(session)
            logger.debug("Fetched raw leaderboard data successfully.")

            # Format the data
            leaderboard_data = self._format_leaderboard_data(leaderboard_data_raw)
            battle_counts_data = self._format_battle_counts_data(battle_counts_data_raw)
            win_rate_data = self._format_win_rate_data(win_rate_data_raw)

            return leaderboard_data, battle_counts_data, win_rate_data
        except Exception as e:
            logger.error(f"Failed to fetch and format leaderboard data: {e}", exc_info=True)
            return [[]], [[]], [[]]  # Return empty structure on error
        finally:
            await session.close()
            logger.debug("DB session closed after fetching leaderboard data.")

    async def submit_vote(
        self,
        option_map: OptionMap,
        selected_option: OptionKey,
        text_modified: bool,
        character_description: str,
        text: str,
    ) -> None:
        """
        Constructs and persists a vote record based on user selection and context.

        This method is designed to be called safely from background tasks: every exception is
        caught and logged rather than propagated.

        Args:
            option_map: Mapping of comparison data and TTS options; each option entry is read
                for its "provider" and "generation_id" keys.
            selected_option: The option key ('option_a' or 'option_b') selected by the user.
            text_modified: Indicates if the text was custom vs. generated.
            character_description: Description used for TTS generation.
            text: The text synthesized.
        """
        try:
            option_a = option_map[constants.OPTION_A_KEY]
            option_b = option_map[constants.OPTION_B_KEY]
            provider_a: TTSProviderName = option_a["provider"]
            provider_b: TTSProviderName = option_b["provider"]
            comparison_type: ComparisonType = self._determine_comparison_type(provider_a, provider_b)

            voting_results: VotingResults = {
                "comparison_type": comparison_type,
                "winning_provider": option_map[selected_option]["provider"],
                "winning_option": selected_option,
                "option_a_provider": provider_a,
                "option_b_provider": provider_b,
                "option_a_generation_id": option_a["generation_id"],
                "option_b_generation_id": option_b["generation_id"],
                "character_description": character_description,
                "text": text,
                "is_custom_text": text_modified,
            }

            await self._persist_vote(voting_results)
        except KeyError as e:
            logger.error(
                f"Missing key in option_map during vote submission: {e}. OptionMap: {option_map}", exc_info=True
            )
        except Exception as e:
            logger.error(f"Unexpected error in submit_vote: {e}", exc_info=True)