# Tests for ankigen_core/llm_interface.py
import pytest
from unittest.mock import patch, MagicMock, ANY, AsyncMock
from openai import OpenAIError
import json
import tenacity
import asyncio
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai import RateLimitError, APIConnectionError, AsyncOpenAI
# Modules to test
from ankigen_core.llm_interface import (
OpenAIClientManager,
structured_output_completion,
process_crawled_page,
process_crawled_pages,
)
from ankigen_core.utils import (
ResponseCache,
) # Need ResponseCache for testing structured_output_completion
from ankigen_core.models import CrawledPage, AnkiCardData
from openai import APIError
# --- OpenAIClientManager Tests ---
@pytest.mark.anyio
async def test_client_manager_init():
"""Test initial state of the client manager."""
manager = OpenAIClientManager()
assert manager._client is None
assert manager._api_key is None
@pytest.mark.anyio
async def test_client_manager_initialize_success():
"""Test successful client initialization."""
manager = OpenAIClientManager()
valid_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
# We don't need to actually connect, so patch the AsyncOpenAI constructor in the llm_interface module
with patch(
"ankigen_core.llm_interface.AsyncOpenAI"
) as mock_async_openai_constructor:
await manager.initialize_client(valid_key)
mock_async_openai_constructor.assert_called_once_with(api_key=valid_key)
assert manager.get_client() is not None
@pytest.mark.anyio
async def test_client_manager_initialize_invalid_key_format():
"""Test initialization failure with invalid API key format."""
manager = OpenAIClientManager()
invalid_key = "invalid-key-format"
with pytest.raises(ValueError, match="Invalid OpenAI API key format."):
await manager.initialize_client(invalid_key)
assert manager._client is None
assert manager._api_key is None # Should remain None
@pytest.mark.anyio
async def test_client_manager_initialize_openai_error():
"""Test handling of OpenAIError during client initialization."""
manager = OpenAIClientManager()
valid_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
error_message = "Test OpenAI Init Error"
with patch(
"ankigen_core.llm_interface.AsyncOpenAI", side_effect=OpenAIError(error_message)
) as mock_async_openai_constructor:
with pytest.raises(OpenAIError, match=error_message):
await manager.initialize_client(valid_key)
mock_async_openai_constructor.assert_called_once_with(api_key=valid_key)
@pytest.mark.anyio
async def test_client_manager_get_client_success():
"""Test getting the client after successful initialization."""
manager = OpenAIClientManager()
valid_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
with patch(
"ankigen_core.llm_interface.AsyncOpenAI"
) as mock_async_openai_constructor:
mock_instance = mock_async_openai_constructor.return_value
await manager.initialize_client(valid_key)
assert manager.get_client() == mock_instance
def test_client_manager_get_client_not_initialized():
"""Test getting the client before initialization."""
manager = OpenAIClientManager()
with pytest.raises(RuntimeError, match="OpenAI client is not initialized."):
manager.get_client()
# --- structured_output_completion Tests ---
# Fixture for mock OpenAI client
@pytest.fixture
def mock_openai_client():
client = MagicMock(spec=AsyncOpenAI)
client.chat = AsyncMock()
client.chat.completions = AsyncMock()
client.chat.completions.create = AsyncMock()
mock_chat_completion_response = create_mock_chat_completion(
json.dumps([{"data": "mocked success"}])
)
client.chat.completions.create.return_value = mock_chat_completion_response
return client
# Fixture for mock ResponseCache
@pytest.fixture
def mock_response_cache():
cache = MagicMock(spec=ResponseCache)
return cache
@pytest.mark.anyio
async def test_structured_output_completion_cache_hit(
mock_openai_client, mock_response_cache
):
"""Test behavior when the response is found in the cache."""
system_prompt = "System prompt"
user_prompt = "User prompt"
model = "test-model"
cached_result = {"data": "cached result"}
# Configure mock cache to return the cached result
mock_response_cache.get.return_value = cached_result
result = await structured_output_completion(
openai_client=mock_openai_client,
model=model,
response_format={"type": "json_object"},
system_prompt=system_prompt,
user_prompt=user_prompt,
cache=mock_response_cache,
)
# Assertions
mock_response_cache.get.assert_called_once_with(
f"{system_prompt}:{user_prompt}", model
)
mock_openai_client.chat.completions.create.assert_not_called() # API should not be called
mock_response_cache.set.assert_not_called() # Cache should not be set again
assert result == cached_result
@pytest.mark.anyio
async def test_structured_output_completion_cache_miss_success(
mock_openai_client, mock_response_cache
):
"""Test behavior on cache miss with a successful API call."""
system_prompt = "System prompt for success"
user_prompt = "User prompt for success"
model = "test-model-success"
expected_result = {"data": "successful API result"}
# Configure mock cache to return None (cache miss)
mock_response_cache.get.return_value = None
# Configure mock API response
mock_completion = MagicMock()
mock_message = MagicMock()
mock_message.content = json.dumps(expected_result)
mock_choice = MagicMock()
mock_choice.message = mock_message
mock_completion.choices = [mock_choice]
mock_openai_client.chat.completions.create.return_value = mock_completion
result = await structured_output_completion(
openai_client=mock_openai_client,
model=model,
response_format={"type": "json_object"},
system_prompt=system_prompt,
user_prompt=user_prompt,
cache=mock_response_cache,
)
# Assertions
mock_response_cache.get.assert_called_once_with(
f"{system_prompt}:{user_prompt}", model
)
mock_openai_client.chat.completions.create.assert_called_once_with(
model=model,
messages=[
{
"role": "system",
"content": ANY,
}, # Check prompt structure later if needed
{"role": "user", "content": user_prompt},
],
response_format={"type": "json_object"},
temperature=0.7,
)
mock_response_cache.set.assert_called_once_with(
f"{system_prompt}:{user_prompt}", model, expected_result
)
assert result == expected_result
@pytest.mark.anyio
async def test_structured_output_completion_api_error(
mock_openai_client, mock_response_cache
):
"""Test behavior when the OpenAI API call raises an error."""
system_prompt = "System prompt for error"
user_prompt = "User prompt for error"
model = "test-model-error"
error_message = "Test API Error"
# Configure mock cache for cache miss
mock_response_cache.get.return_value = None
    # Configure the mock API call to raise an error on every attempt.
    # The @retry decorator is hard to mock precisely without depending on tenacity
    # internals; we rely on it surfacing a RetryError once all attempts fail.
mock_openai_client.chat.completions.create.side_effect = OpenAIError(error_message)
with pytest.raises(tenacity.RetryError):
await structured_output_completion(
openai_client=mock_openai_client,
model=model,
response_format={"type": "json_object"},
system_prompt=system_prompt,
user_prompt=user_prompt,
cache=mock_response_cache,
)
    # Optionally, capture the error with `pytest.raises(tenacity.RetryError) as excinfo`
    # above and check the underlying exception:
    # assert isinstance(excinfo.value.last_attempt.exception(), OpenAIError)
    # assert str(excinfo.value.last_attempt.exception()) == error_message
# Assertions
# cache.get is called on each retry attempt
assert mock_response_cache.get.call_count == 3, (
f"Expected cache.get to be called 3 times due to retries, but was {mock_response_cache.get.call_count}"
)
# Check that create was called 3 times due to retry
assert mock_openai_client.chat.completions.create.call_count == 3, (
f"Expected create to be called 3 times due to retries, but was {mock_openai_client.chat.completions.create.call_count}"
)
mock_response_cache.set.assert_not_called() # Cache should not be set on error
@pytest.mark.anyio
async def test_structured_output_completion_invalid_json(
mock_openai_client, mock_response_cache
):
"""Test behavior when the API returns invalid JSON."""
system_prompt = "System prompt for invalid json"
user_prompt = "User prompt for invalid json"
model = "test-model-invalid-json"
invalid_json_content = "this is not json"
# Configure mock cache for cache miss
mock_response_cache.get.return_value = None
# Configure mock API response with invalid JSON
mock_completion = MagicMock()
mock_message = MagicMock()
mock_message.content = invalid_json_content
mock_choice = MagicMock()
mock_choice.message = mock_message
mock_completion.choices = [mock_choice]
mock_openai_client.chat.completions.create.return_value = mock_completion
with pytest.raises(tenacity.RetryError):
await structured_output_completion(
openai_client=mock_openai_client,
model=model,
response_format={"type": "json_object"},
system_prompt=system_prompt,
user_prompt=user_prompt,
cache=mock_response_cache,
)
# Assertions
# cache.get is called on each retry attempt
assert mock_response_cache.get.call_count == 3, (
f"Expected cache.get to be called 3 times due to retries, but was {mock_response_cache.get.call_count}"
)
# create is also called on each retry attempt
assert mock_openai_client.chat.completions.create.call_count == 3, (
f"Expected create to be called 3 times due to retries, but was {mock_openai_client.chat.completions.create.call_count}"
)
mock_response_cache.set.assert_not_called() # Cache should not be set on error
@pytest.mark.anyio
async def test_structured_output_completion_no_choices(
mock_openai_client, mock_response_cache
):
"""Test behavior when API completion has no choices."""
system_prompt = "System prompt no choices"
user_prompt = "User prompt no choices"
model = "test-model-no-choices"
mock_response_cache.get.return_value = None
mock_completion = MagicMock()
mock_completion.choices = [] # No choices
mock_openai_client.chat.completions.create.return_value = mock_completion
    # Currently the function logs a warning and returns None, so we assert None.
result = await structured_output_completion(
openai_client=mock_openai_client,
model=model,
response_format={"type": "json_object"},
system_prompt=system_prompt,
user_prompt=user_prompt,
cache=mock_response_cache,
)
assert result is None
mock_response_cache.set.assert_not_called()
@pytest.mark.anyio
async def test_structured_output_completion_no_message_content(
mock_openai_client, mock_response_cache
):
"""Test behavior when API choice has no message content."""
system_prompt = "System prompt no content"
user_prompt = "User prompt no content"
model = "test-model-no-content"
mock_response_cache.get.return_value = None
mock_completion = MagicMock()
mock_message = MagicMock()
mock_message.content = None # No content
mock_choice = MagicMock()
mock_choice.message = mock_message
mock_completion.choices = [mock_choice]
mock_openai_client.chat.completions.create.return_value = mock_completion
    # Currently the function logs a warning and returns None, so we assert None.
result = await structured_output_completion(
openai_client=mock_openai_client,
model=model,
response_format={"type": "json_object"},
system_prompt=system_prompt,
user_prompt=user_prompt,
cache=mock_response_cache,
)
assert result is None
mock_response_cache.set.assert_not_called()
# --- Tests for process_crawled_page ---
def create_mock_chat_completion(content: str) -> ChatCompletion:
return ChatCompletion(
id="chatcmpl-test123",
choices=[
ChatCompletionChoice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content=content, role="assistant"),
logprobs=None,
)
],
created=1677652288,
model="gpt-4o",
object="chat.completion",
system_fingerprint="fp_test",
usage=None, # Not testing usage here
)
@pytest.mark.anyio
async def test_process_crawled_page_success(mock_openai_client, sample_crawled_page):
mock_response_content = json.dumps(
[
{"front": "Q1", "back": "A1", "tags": ["tag1"]},
{"front": "Q2", "back": "A2", "tags": ["tag2", "python"]},
]
)
mock_openai_client.chat.completions.create.return_value = (
create_mock_chat_completion(mock_response_content)
)
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
assert len(result_cards) == 2
assert result_cards[0].front == "Q1"
assert result_cards[0].source_url == sample_crawled_page.url
assert result_cards[1].tags == ["tag2", "python"]
mock_openai_client.chat.completions.create.assert_awaited_once()
@pytest.mark.anyio
async def test_process_crawled_page_empty_llm_response_content(
mock_openai_client, sample_crawled_page
):
mock_openai_client.chat.completions.create.return_value = (
create_mock_chat_completion("")
) # Empty string content
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
assert len(result_cards) == 0
@pytest.mark.anyio
async def test_process_crawled_page_llm_returns_not_a_list(
mock_openai_client, sample_crawled_page
):
mock_response_content = json.dumps(
{"error": "not a list as expected"}
) # Not a list
mock_openai_client.chat.completions.create.return_value = (
create_mock_chat_completion(mock_response_content)
)
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
assert len(result_cards) == 0
@pytest.mark.anyio
async def test_process_crawled_page_llm_returns_dict_with_cards_key(
mock_openai_client, sample_crawled_page
):
mock_response_content = json.dumps(
{"cards": [{"front": "Q1", "back": "A1", "tags": []}]}
)
mock_openai_client.chat.completions.create.return_value = (
create_mock_chat_completion(mock_response_content)
)
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
assert len(result_cards) == 1
assert result_cards[0].front == "Q1"
@pytest.mark.anyio
async def test_process_crawled_page_json_decode_error(
mock_openai_client, sample_crawled_page
):
mock_openai_client.chat.completions.create.return_value = (
create_mock_chat_completion("this is not valid json")
)
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
assert len(result_cards) == 0
@pytest.mark.anyio
async def test_process_crawled_page_empty_text_content(mock_openai_client):
empty_content_page = CrawledPage(
url="http://example.com/empty",
html_content="",
text_content=" ",
title="Empty",
)
result_cards = await process_crawled_page(mock_openai_client, empty_content_page)
assert len(result_cards) == 0
mock_openai_client.chat.completions.create.assert_not_awaited() # Should not call LLM
@pytest.mark.anyio
async def test_process_crawled_page_openai_api_error_retry(
mock_openai_client, sample_crawled_page, caplog
):
# Simulate API errors that should be retried
errors_to_raise = [
RateLimitError("rate limited", response=MagicMock(), body=None)
] * 2 + [
create_mock_chat_completion(
json.dumps([{"front": "Q1", "back": "A1", "tags": []}])
)
]
mock_openai_client.chat.completions.create.side_effect = errors_to_raise
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
assert len(result_cards) == 1
assert result_cards[0].front == "Q1"
assert (
mock_openai_client.chat.completions.create.await_count == 3
) # 2 retries + 1 success
assert "Retrying OpenAI call (attempt 1)" in caplog.text
assert "Retrying OpenAI call (attempt 2)" in caplog.text
@pytest.mark.anyio
async def test_process_crawled_page_openai_persistent_api_error(
mock_openai_client, sample_crawled_page, caplog
):
# Simulate API errors that persist beyond retries
mock_openai_client.chat.completions.create.side_effect = APIConnectionError(
request=MagicMock()
)
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
assert len(result_cards) == 0
assert (
mock_openai_client.chat.completions.create.await_count == 3
) # Default 3 attempts
assert "OpenAI API error after retries" in caplog.text
@pytest.mark.anyio
async def test_process_crawled_page_tiktoken_truncation(
mock_openai_client, sample_crawled_page
):
# Make text_content very long
long_text = "word " * 8000 # Approx 8000 tokens with cl100k_base
sample_crawled_page.text_content = long_text
# Mock successful response
mock_response_content = json.dumps(
[{"front": "TruncatedQ", "back": "TruncatedA", "tags": []}]
)
mock_openai_client.chat.completions.create.return_value = (
create_mock_chat_completion(mock_response_content)
)
# Using default max_prompt_content_tokens=6000
await process_crawled_page(mock_openai_client, sample_crawled_page)
# Check that the user_prompt content passed to create was truncated
# The actual user_prompt construction is inside process_crawled_page, so we inspect the call args
call_args = mock_openai_client.chat.completions.create.call_args
user_prompt_message_content = next(
m["content"] for m in call_args.kwargs["messages"] if m["role"] == "user"
)
    # Rough check: the token count of the CONTENT part should be close to 6000.
    # This is an indirect check; a variant that mocks tiktoken directly appears
    # later in this file (test_process_crawled_page_content_truncation).
assert "CONTENT:\n" in user_prompt_message_content
content_part = user_prompt_message_content.split("CONTENT:\n")[1].split(
"\n\nReturn a JSON array"
)[0]
import tiktoken
encoding = tiktoken.get_encoding(
"cl100k_base"
) # Assuming cl100k_base was used as fallback or for model
num_tokens = len(encoding.encode(content_part))
# Check it's close to 6000 (allowing some leeway for prompt structure around content)
assert 5900 < num_tokens < 6100
# --- Tests for process_crawled_pages ---
@pytest.mark.anyio
async def test_process_crawled_pages_success(mock_openai_client, sample_crawled_page):
pages_to_process = [
sample_crawled_page,
CrawledPage(
url="http://example.com/page2",
html_content="",
text_content="Content for page 2",
title="Page 2",
),
]
# Mock process_crawled_page to return different cards for different pages
async def mock_single_page_processor(client, page, model, max_tokens):
if page.url == pages_to_process[0].url:
return [AnkiCardData(front="P1Q1", back="P1A1", source_url=page.url)]
elif page.url == pages_to_process[1].url:
return [
AnkiCardData(front="P2Q1", back="P2A1", source_url=page.url),
AnkiCardData(front="P2Q2", back="P2A2", source_url=page.url),
]
return []
with patch(
"ankigen_core.llm_interface.process_crawled_page",
side_effect=mock_single_page_processor,
) as mock_processor:
result_cards = await process_crawled_pages(
mock_openai_client, pages_to_process, max_concurrent_requests=1
)
assert len(result_cards) == 3
assert result_cards[0].front == "P1Q1"
assert result_cards[1].front == "P2Q1"
assert result_cards[2].front == "P2Q2"
assert mock_processor.call_count == 2
@pytest.mark.anyio
async def test_process_crawled_pages_partial_failure(
mock_openai_client, sample_crawled_page
):
pages_to_process = [
sample_crawled_page, # This one will succeed
CrawledPage(
url="http://example.com/page_fail",
html_content="",
text_content="Content for page fail",
title="Page Fail",
),
CrawledPage(
url="http://example.com/page3",
html_content="",
text_content="Content for page 3",
title="Page 3",
), # This one will succeed
]
async def mock_single_page_processor_with_failure(client, page, model, max_tokens):
if page.url == pages_to_process[0].url:
return [AnkiCardData(front="P1Q1", back="P1A1", source_url=page.url)]
elif page.url == pages_to_process[1].url: # page_fail
raise APIConnectionError(request=MagicMock())
elif page.url == pages_to_process[2].url:
return [AnkiCardData(front="P3Q1", back="P3A1", source_url=page.url)]
return []
with patch(
"ankigen_core.llm_interface.process_crawled_page",
side_effect=mock_single_page_processor_with_failure,
) as mock_processor:
result_cards = await process_crawled_pages(
mock_openai_client, pages_to_process, max_concurrent_requests=2
)
assert len(result_cards) == 2 # Only cards from successful pages
successful_urls = [card.source_url for card in result_cards]
assert pages_to_process[0].url in successful_urls
assert pages_to_process[2].url in successful_urls
assert pages_to_process[1].url not in successful_urls
assert mock_processor.call_count == 3
@pytest.mark.anyio
async def test_process_crawled_pages_progress_callback(
mock_openai_client, sample_crawled_page
):
pages_to_process = [sample_crawled_page] * 3 # 3 identical pages for simplicity
progress_log = []
def callback(completed_count, total_count):
progress_log.append((completed_count, total_count))
async def mock_simple_processor(client, page, model, max_tokens):
await asyncio.sleep(0.01) # Simulate work
return [AnkiCardData(front=f"{page.url}-Q", back="A", source_url=page.url)]
with patch(
"ankigen_core.llm_interface.process_crawled_page",
side_effect=mock_simple_processor,
):
await process_crawled_pages(
mock_openai_client,
pages_to_process,
progress_callback=callback,
max_concurrent_requests=1,
)
assert len(progress_log) == 3
assert progress_log[0] == (1, 3)
assert progress_log[1] == (2, 3)
assert progress_log[2] == (3, 3)
# --- Fixtures ---
# Placeholder for API key, can be anything for tests
TEST_API_KEY = "sk-testkey1234567890abcdefghijklmnopqrstuvwxyz"
@pytest.fixture
def client_manager():
"""Fixture for OpenAIClientManager."""
return OpenAIClientManager()
@pytest.fixture
def mock_async_openai_client():
"""Mocks an AsyncOpenAI client instance."""
mock_client = AsyncMock()
mock_client.chat = AsyncMock()
mock_client.chat.completions = AsyncMock()
mock_client.chat.completions.create = AsyncMock()
# Mock the response structure for the .create method
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message = MagicMock()
mock_response.choices[
0
].message.content = '{"question": "Q1", "answer": "A1"}' # Default valid JSON
mock_response.usage = MagicMock()
mock_response.usage.total_tokens = 100
mock_client.chat.completions.create.return_value = mock_response
return mock_client
@pytest.fixture
def sample_crawled_page():
"""Fixture for a sample CrawledPage object."""
return CrawledPage(
url="http://example.com",
html_content="<html><body>This is some test content for the page.</body></html>",
text_content="This is some test content for the page.",
title="Test Page",
meta_description="A test page.",
meta_keywords=["test", "page"],
crawl_depth=0,
)
@pytest.mark.anyio
async def test_process_crawled_page_success_tuple_result(
    client_manager, mock_async_openai_client, sample_crawled_page
):
    """Test successful processing of a single crawled page via the (card, tokens) tuple interface."""
with patch.object(
client_manager, "get_client", return_value=mock_async_openai_client
):
result, tokens = await process_crawled_page(
mock_async_openai_client,
sample_crawled_page,
"gpt-4o", # model
max_prompt_content_tokens=1000,
)
assert isinstance(result, AnkiCardData)
assert result.front == "Q1"
assert result.back == "A1"
assert tokens == 100
mock_async_openai_client.chat.completions.create.assert_called_once()
@pytest.mark.anyio
async def test_process_crawled_page_json_error(
client_manager, mock_async_openai_client, sample_crawled_page
):
"""Test handling of invalid JSON response from LLM."""
mock_async_openai_client.chat.completions.create.return_value.choices[
0
].message.content = "This is not JSON"
with patch.object(
client_manager, "get_client", return_value=mock_async_openai_client
):
# Reset call count for this specific test scenario
mock_async_openai_client.chat.completions.create.reset_mock()
result, tokens = await process_crawled_page(
mock_async_openai_client,
sample_crawled_page,
"gpt-4o",
max_prompt_content_tokens=1000,
)
assert result is None
assert (
tokens == 100
) # Tokens are still counted even if parsing fails on the first attempt response
        # Check tenacity retries: the JSON-parsing retry is internal to
        # _parse_json_response (default 3 attempts), while the @retry on
        # process_crawled_page only covers API errors, so a JSON parsing failure
        # does not re-trigger the API call. The parse attempts therefore cannot be
        # observed from here; verifying that create was called and that the result
        # is None due to the parsing failure is sufficient at this level.
assert mock_async_openai_client.chat.completions.create.call_count >= 1
        # To assert exact retry counts for JSON parsing, mock json.loads inside the
        # module or test `_parse_json_response` separately; a commented-out sketch
        # of the json.loads approach follows this test.
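# A minimal, commented-out sketch of the json.loads-mocking approach described above.
# It assumes ankigen_core.llm_interface does `import json` at module level (so the
# patch target resolves) and that the internal parser makes 3 attempts; both are
# assumptions about the implementation, not verified facts, hence the sketch stays
# disabled like the other commented-out tests in this file.
# @pytest.mark.anyio
# async def test_process_crawled_page_json_parse_retry_count(
#     mock_async_openai_client, sample_crawled_page
# ):
#     with patch(
#         "ankigen_core.llm_interface.json.loads",
#         side_effect=json.JSONDecodeError("bad", "doc", 0),
#     ) as mock_loads:
#         result, _tokens = await process_crawled_page(
#             mock_async_openai_client,
#             sample_crawled_page,
#             "gpt-4o",
#             max_prompt_content_tokens=1000,
#         )
#     assert result is None
#     assert mock_loads.call_count == 3  # one call per assumed parse attempt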
@pytest.mark.anyio
async def test_process_crawled_page_api_error(
client_manager, mock_async_openai_client, sample_crawled_page
):
"""Test handling of API error during LLM call."""
# Correctly instantiate APIError: needs a 'request' argument.
# The 'response' is typically part of the error object after it's raised by httpx, not a constructor arg.
mock_request = MagicMock() # Mock an httpx.Request object
mock_async_openai_client.chat.completions.create.side_effect = APIError(
message="Test API Error", request=mock_request, body=None
)
with patch.object(
client_manager, "get_client", return_value=mock_async_openai_client
):
# Reset call count for this specific test scenario
mock_async_openai_client.chat.completions.create.reset_mock()
result, tokens = await process_crawled_page(
mock_async_openai_client,
sample_crawled_page,
"gpt-4o",
max_prompt_content_tokens=1000,
)
assert result is None
assert tokens == 0 # No tokens if API call fails before response
        # Check tenacity retries: more than one call shows a retry occurred
        # (the exact count depends on the retry configuration, default 3 attempts).
assert mock_async_openai_client.chat.completions.create.call_count > 1
@pytest.mark.anyio
async def test_process_crawled_page_content_truncation(
client_manager, mock_async_openai_client, sample_crawled_page
):
"""Test content truncation based on max_prompt_content_tokens."""
long_content_piece = "This is a word. "
repetitions = 10
sample_crawled_page.content = long_content_piece * repetitions
with (
patch.object(
client_manager, "get_client", return_value=mock_async_openai_client
),
patch("tiktoken.get_encoding") as mock_get_encoding,
):
mock_encoding = MagicMock()
original_tokens = []
for i in range(repetitions):
original_tokens.extend([i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3])
mock_encoding.encode.return_value = original_tokens
def mock_decode_side_effect(token_ids):
num_tokens_to_decode = len(token_ids)
num_full_pieces = num_tokens_to_decode // 4
partial_piece_tokens = num_tokens_to_decode % 4
decoded_str = long_content_piece * num_full_pieces
if partial_piece_tokens > 0:
words_in_piece = long_content_piece.strip().split(" ")
num_words_to_take = min(partial_piece_tokens, len(words_in_piece))
decoded_str += " ".join(words_in_piece[:num_words_to_take])
return decoded_str.strip()
mock_encoding.decode.side_effect = mock_decode_side_effect
mock_get_encoding.return_value = mock_encoding
mock_async_openai_client.chat.completions.create.reset_mock()
await process_crawled_page(
mock_async_openai_client,
sample_crawled_page,
"gpt-4o",
max_prompt_content_tokens=5,
)
mock_get_encoding.assert_called_once_with("cl100k_base")
mock_encoding.encode.assert_called_once_with(
sample_crawled_page.content, disallowed_special=()
)
mock_encoding.decode.assert_called_once_with(original_tokens[:5])
call_args = mock_async_openai_client.chat.completions.create.call_args
assert call_args is not None
messages = call_args.kwargs["messages"]
user_prompt_content = messages[1]["content"]
expected_truncated_content = mock_decode_side_effect(original_tokens[:5])
assert f"Content: {expected_truncated_content}" in user_prompt_content
# The following tests are commented out because they treat process_crawled_pages as an
# async generator, which does not match the awaited-list interface exercised by the
# active tests above. A working empty-list variant against that interface follows this block.
# @pytest.mark.anyio
# async def test_process_crawled_pages_empty_list(client_manager):
# """Test processing an empty list of crawled pages."""
# results = []
# # Correctly iterate over the async generator
# async for result_item in process_crawled_pages(
# pages=[], openai_client=mock_async_openai_client, model="gpt-4o"
# ):
# results.append(result_item)
# assert len(results) == 0
# @pytest.mark.anyio
# async def test_process_crawled_pages_single_page_success(
# client_manager, mock_async_openai_client, sample_crawled_page
# ):
# """Test processing a list with a single successful page."""
# pages = [sample_crawled_page]
# # We mock process_crawled_page itself since its unit tests cover its internal logic
# with patch(
# "ankigen_core.llm_interface.process_crawled_page", new_callable=AsyncMock
# ) as mock_single_process:
# mock_single_process.return_value = (
# AnkiCardData(front="Q1", back="A1"),
# 100,
# )
# results = []
# async for result_tuple in process_crawled_pages(
# pages=pages, openai_client=mock_async_openai_client, model="gpt-4o"
# ):
# results.append(result_tuple)
# assert len(results) == 1
# page, card_data, tokens = results[0]
# assert page == sample_crawled_page
# assert isinstance(card_data, AnkiCardData)
# assert card_data.front == "Q1"
# assert card_data.back == "A1"
# assert tokens == 100
# # Check that process_crawled_page was called with correct default parameters from process_crawled_pages
# mock_single_process.assert_called_once_with(
# sample_crawled_page,
# mock_async_openai_client,
# "gpt-4o", # model
# max_prompt_content_tokens=5000, # default from process_crawled_pages
# # The following are also defaults from process_crawled_pages
# # Ensure they are passed down if not overridden in the call to process_crawled_pages
# )
# @pytest.mark.anyio
# async def test_process_crawled_pages_multiple_pages_mixed_results(client_manager):
# """Test processing multiple pages with mixed success and failure."""
# page1 = CrawledPage(
# url="http://example.com/1",
# html_content="",
# text_content="Content 1",
# title="Page 1",
# )
# page2 = CrawledPage(
# url="http://example.com/2",
# html_content="",
# text_content="Content 2",
# title="Page 2",
# ) # This one will fail
# page3 = CrawledPage(
# url="http://example.com/3",
# html_content="",
# text_content="Content 3",
# title="Page 3",
# )
# pages_to_process = [page1, page2, page3]
# async def mock_single_process_side_effect(page, manager, model, **kwargs):
# await asyncio.sleep(0.01) # simulate async work
# if page.url.endswith("1"):
# return (AnkiCardData(front="Q1", back="A1"), 100)
# elif page.url.endswith("2"):
# return (None, 50) # Failed processing, some tokens consumed
# elif page.url.endswith("3"):
# return (AnkiCardData(front="Q3", back="A3"), 150)
# return (None, 0)
# with patch(
# "ankigen_core.llm_interface.process_crawled_page",
# side_effect=mock_single_process_side_effect,
# ) as mock_process_call:
# results = []
# async for result_tuple in process_crawled_pages(
# pages=pages_to_process,
# openai_client=mock_async_openai_client,
# model="gpt-4o",
# max_concurrent_requests=2, # Test with concurrency
# ):
# results.append(result_tuple)
# assert len(results) == 3
# assert mock_process_call.call_count == 3
# results_map = {res[0].url: res for res in results}
# assert results_map["http://example.com/1"][1] is not None
# assert results_map["http://example.com/1"][1].front == "Q1"
# assert results_map["http://example.com/1"][1].back == "A1"
# assert results_map["http://example.com/1"][2] == 100
# assert results_map["http://example.com/2"][1] is None
# assert results_map["http://example.com/2"][2] == 50
# assert results_map["http://example.com/3"][1] is not None
# assert results_map["http://example.com/3"][1].front == "Q3"
# assert results_map["http://example.com/3"][1].back == "A3"
# assert results_map["http://example.com/3"][2] == 150
# # Check that parameters were passed down correctly from process_crawled_pages to process_crawled_page
# for call_args in mock_process_call.call_args_list:
# args, kwargs = call_args
# assert kwargs["max_prompt_content_tokens"] == 5000 # default
# # These were passed to process_crawled_pages and should be passed down
# # However, process_crawled_page itself doesn't directly use max_concurrent_requests or request_delay
# # These are used by process_crawled_pages for its own loop control.
# # So we can't directly check them in the call to process_crawled_page mock here.
# # The important check is that process_crawled_page is called for each page.
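# A working empty-list variant of the case sketched above, written against the
# awaited-list interface used by the active process_crawled_pages tests earlier in
# this file; the call shape mirrors those tests rather than a verified signature.
@pytest.mark.anyio
async def test_process_crawled_pages_empty_list_awaited(mock_openai_client):
    result_cards = await process_crawled_pages(mock_openai_client, [])
    assert len(result_cards) == 0
    mock_openai_client.chat.completions.create.assert_not_awaited()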
@pytest.mark.anyio
async def test_openai_client_manager_get_client(
client_manager, mock_async_openai_client
):
"""Test that get_client returns the AsyncOpenAI client instance and initializes it once."""
with patch(
"openai.AsyncOpenAI", return_value=mock_async_openai_client
) as mock_constructor:
client1 = client_manager.get_client() # First call, should initialize
client2 = client_manager.get_client() # Second call, should return existing
assert client1 is mock_async_openai_client
assert client2 is mock_async_openai_client
mock_constructor.assert_called_once_with(api_key=TEST_API_KEY)
# Notes for further tests:
# - Test progress callback in process_crawled_pages if it were implemented.
# - Test specific retry conditions for tenacity if more complex logic added.
# - Test behavior of semaphore in process_crawled_pages more directly (a minimal sketch follows below).
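# A minimal sketch of the semaphore check mentioned in the note above: the patched
# process_crawled_page tracks how many invocations run at once and the test asserts
# the peak never exceeds max_concurrent_requests. It assumes process_crawled_pages
# calls process_crawled_page with (client, page, model, max_tokens), mirroring the
# mocks used by the active tests above.
@pytest.mark.anyio
async def test_process_crawled_pages_respects_concurrency_limit(
    mock_openai_client, sample_crawled_page
):
    pages = [sample_crawled_page] * 5
    active = 0
    peak = 0

    async def tracking_processor(client, page, model, max_tokens):
        nonlocal active, peak
        active += 1
        peak = max(peak, active)
        await asyncio.sleep(0.01)  # yield control so overlapping calls can be observed
        active -= 1
        return []

    with patch(
        "ankigen_core.llm_interface.process_crawled_page",
        side_effect=tracking_processor,
    ):
        await process_crawled_pages(
            mock_openai_client, pages, max_concurrent_requests=2
        )
    assert 1 <= peak <= 2  # never more than the configured limit in flight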