import pandas as pd
import pytest

from data_access import (calculate_cumulative_statistics_for_all_questions, get_async_connection,
                         get_metadata, get_questions, get_run_ids, get_unified_sources)
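
# NOTE: these tests use async test functions and are marked with pytest.mark.asyncio,
# which assumes the pytest-asyncio plugin (or an equivalent asyncio_mode setting) is
# available. The run, question, and ranker IDs below are fixed values expected to
# exist in the test database.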


@pytest.mark.asyncio
async def test_get_questions():
    source_run_id = 2
    baseline_source_finder_run_id = 1
    async with get_async_connection() as conn:
        actual = await get_questions(conn, source_run_id, baseline_source_finder_run_id)
        assert len(actual) == 10


@pytest.mark.asyncio
async def test_get_unified_sources():
    async with get_async_connection() as conn:
        results, stats = await get_unified_sources(conn, 2, 2, 1)
        assert results is not None
        assert stats is not None
        # Check the number of rows in the results list
        assert len(results) > 4, "Results should contain more than 4 rows"
        # Check the number of rows in the stats DataFrame
        assert stats.shape[0] > 0, "Stats DataFrame should contain at least one row"
        # Check that specific stats columns are present
        assert "overlap_count" in stats.columns, "Stats should contain overlap_count"


@pytest.mark.asyncio
async def test_calculate_cumulative_statistics_for_all_questions():
    # Test with a known source_finder_run_id and ranker_id
    source_finder_run_id = 2
    ranker_id = 1
    # Call the function to test
    async with get_async_connection() as conn:
        questions = await get_questions(conn, source_finder_run_id, ranker_id)
        question_ids = [question['id'] for question in questions]
        result = await calculate_cumulative_statistics_for_all_questions(
            conn, question_ids, source_finder_run_id, ranker_id)
        # Check the basic structure of the result
        assert isinstance(result, pd.DataFrame), "Result should be a pandas DataFrame"
        assert result.shape[0] == 1, "Result should have one row"
        # Check that the required columns exist
        expected_columns = [
            "total_questions_analyzed",
            "total_baseline_sources",
            "total_found_sources",
            "total_overlap_count",
            "overall_overlap_percentage",
            "total_high_ranked_baseline_sources",
            "total_high_ranked_found_sources",
            "total_high_ranked_overlap_count",
            "overall_high_ranked_overlap_percentage",
            "avg_baseline_sources_per_question",
            "avg_found_sources_per_question",
        ]
        for column in expected_columns:
            assert column in result.columns, f"Column {column} should be in result DataFrame"
        # Check some basic value validations
        assert result["total_questions_analyzed"].iloc[0] >= 0, "Should have zero or more questions analyzed"
        assert result["total_baseline_sources"].iloc[0] >= 0, "Should have zero or more baseline sources"
        assert result["total_found_sources"].iloc[0] >= 0, "Should have zero or more found sources"
        # Check that percentages are within valid ranges
        assert 0 <= result["overall_overlap_percentage"].iloc[0] <= 100, \
            "Overlap percentage should be between 0 and 100"
        assert 0 <= result["overall_high_ranked_overlap_percentage"].iloc[0] <= 100, \
            "High ranked overlap percentage should be between 0 and 100"


@pytest.mark.asyncio
async def test_get_metadata_none_returned():
    # Test with a source_finder_run_id and question_id that have no metadata
    source_finder_run_id = 1
    question_id = 1
    # Call the function to test
    async with get_async_connection() as conn:
        result = await get_metadata(conn, question_id, source_finder_run_id)
        assert result == {}, "Should return an empty dict when no metadata is found"


@pytest.mark.asyncio
async def test_get_metadata():
    # Test with a source_finder_run_id and question_id that have metadata
    source_finder_run_id = 4
    question_id = 1
    # Call the function to test
    async with get_async_connection() as conn:
        result = await get_metadata(conn, question_id, source_finder_run_id)
        assert result is not None, "Should return metadata when it exists"


@pytest.mark.asyncio
async def test_get_run_ids():
    # Test with known question_id and source_finder_id
    question_id = 2  # Using a question ID that exists in the test database
    source_finder_id = 2  # Using a source finder ID that exists in the test database
    # Call the function to test
    async with get_async_connection() as conn:
        result = await get_run_ids(conn, source_finder_id, question_id)
        # Verify the result is a dictionary
        assert isinstance(result, dict), "Result should be a dictionary"
        # Check that the dictionary is not empty (assuming there are run IDs for this question/source finder)
        assert len(result) > 0, "Should return at least one run ID"
        # Test with a non-existent question_id
        non_existent_question_id = 9999
        empty_result = await get_run_ids(conn, source_finder_id, non_existent_question_id)
        assert isinstance(empty_result, dict), "Should return an empty dictionary for a non-existent question"
        assert len(empty_result) == 0, "Should return an empty dictionary for a non-existent question"


@pytest.mark.asyncio
async def test_get_run_ids_no_question_id():
    source_finder_id = 2  # Using a source finder ID that exists in the test database
    # Call the function to test
    async with get_async_connection() as conn:
        result = await get_run_ids(conn, source_finder_id)
        # Verify the result is a dictionary
        assert isinstance(result, dict), "Result should be a dictionary"
        # Check that the dictionary is not empty (assuming there are run IDs for this source finder)
        assert len(result) > 0, "Should return at least one run ID"
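
# A possible refactor (a sketch only, not used above): every test opens the same
# connection, so a shared pytest-asyncio fixture could replace the repeated
# `async with get_async_connection()` blocks. This assumes pytest-asyncio is installed.
#
#   import pytest_asyncio
#
#   @pytest_asyncio.fixture
#   async def conn():
#       async with get_async_connection() as connection:
#           yield connection
#
# Tests would then take `conn` as a parameter instead of opening the connection inline.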