|
|
|
import pytest |
|
from unittest.mock import patch, MagicMock, ANY, AsyncMock |
|
from openai import OpenAIError |
|
import json |
|
import tenacity |
|
import asyncio |
|
from openai.types.chat import ChatCompletion |
|
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice |
|
from openai.types.chat.chat_completion_message import ChatCompletionMessage |
|
from openai import RateLimitError, APIConnectionError, AsyncOpenAI |
|
|
|
|
|
from ankigen_core.llm_interface import ( |
|
OpenAIClientManager, |
|
structured_output_completion, |
|
process_crawled_page, |
|
process_crawled_pages, |
|
) |
|
from ankigen_core.utils import ( |
|
ResponseCache, |
|
) |
|
from ankigen_core.models import CrawledPage, AnkiCardData |
|
from openai import APIError |
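# These tests target the async OpenAI SDK (openai>=1.x) and rely on the anyio
# pytest plugin for the `@pytest.mark.anyio` markers.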
|
|
|
|
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_client_manager_init(): |
|
"""Test initial state of the client manager.""" |
|
manager = OpenAIClientManager() |
|
assert manager._client is None |
|
assert manager._api_key is None |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_client_manager_initialize_success(): |
|
"""Test successful client initialization.""" |
|
manager = OpenAIClientManager() |
|
valid_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" |
|
|
|
with patch( |
|
"ankigen_core.llm_interface.AsyncOpenAI" |
|
) as mock_async_openai_constructor: |
|
await manager.initialize_client(valid_key) |
|
mock_async_openai_constructor.assert_called_once_with(api_key=valid_key) |
|
assert manager.get_client() is not None |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_client_manager_initialize_invalid_key_format(): |
|
"""Test initialization failure with invalid API key format.""" |
|
manager = OpenAIClientManager() |
|
invalid_key = "invalid-key-format" |
|
with pytest.raises(ValueError, match="Invalid OpenAI API key format."): |
|
await manager.initialize_client(invalid_key) |
|
assert manager._client is None |
|
assert manager._api_key is None |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_client_manager_initialize_openai_error(): |
|
"""Test handling of OpenAIError during client initialization.""" |
|
manager = OpenAIClientManager() |
|
valid_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" |
|
error_message = "Test OpenAI Init Error" |
|
|
|
with patch( |
|
"ankigen_core.llm_interface.AsyncOpenAI", side_effect=OpenAIError(error_message) |
|
) as mock_async_openai_constructor: |
|
with pytest.raises(OpenAIError, match=error_message): |
|
await manager.initialize_client(valid_key) |
|
mock_async_openai_constructor.assert_called_once_with(api_key=valid_key) |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_client_manager_get_client_success(): |
|
"""Test getting the client after successful initialization.""" |
|
manager = OpenAIClientManager() |
|
valid_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" |
|
with patch( |
|
"ankigen_core.llm_interface.AsyncOpenAI" |
|
) as mock_async_openai_constructor: |
|
mock_instance = mock_async_openai_constructor.return_value |
|
await manager.initialize_client(valid_key) |
|
assert manager.get_client() == mock_instance |
|
|
|
|
|
def test_client_manager_get_client_not_initialized(): |
|
"""Test getting the client before initialization.""" |
|
manager = OpenAIClientManager() |
|
with pytest.raises(RuntimeError, match="OpenAI client is not initialized."): |
|
manager.get_client() |
|
|
|
|
|
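# Fixtures for the structured_output_completion tests. The default mocked
# completion uses create_mock_chat_completion, which is defined further down in
# this module; the reference is resolved when the fixture runs, not at import time.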
@pytest.fixture |
|
def mock_openai_client(): |
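"""Mock AsyncOpenAI client whose chat.completions.create returns a canned JSON-array completion by default."""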
|
client = MagicMock(spec=AsyncOpenAI) |
|
client.chat = AsyncMock() |
|
client.chat.completions = AsyncMock() |
|
client.chat.completions.create = AsyncMock() |
|
mock_chat_completion_response = create_mock_chat_completion( |
|
json.dumps([{"data": "mocked success"}]) |
|
) |
|
client.chat.completions.create.return_value = mock_chat_completion_response |
|
return client |
|
|
|
|
|
|
|
@pytest.fixture |
|
def mock_response_cache(): |
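"""Mock ResponseCache; individual tests configure get()/set() expectations as needed."""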
|
cache = MagicMock(spec=ResponseCache) |
|
return cache |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_structured_output_completion_cache_hit( |
|
mock_openai_client, mock_response_cache |
|
): |
|
"""Test behavior when the response is found in the cache.""" |
|
system_prompt = "System prompt" |
|
user_prompt = "User prompt" |
|
model = "test-model" |
|
cached_result = {"data": "cached result"} |
|
|
|
|
|
mock_response_cache.get.return_value = cached_result |
|
|
|
result = await structured_output_completion( |
|
openai_client=mock_openai_client, |
|
model=model, |
|
response_format={"type": "json_object"}, |
|
system_prompt=system_prompt, |
|
user_prompt=user_prompt, |
|
cache=mock_response_cache, |
|
) |
|
|
|
|
|
mock_response_cache.get.assert_called_once_with( |
|
f"{system_prompt}:{user_prompt}", model |
|
) |
|
mock_openai_client.chat.completions.create.assert_not_called() |
|
mock_response_cache.set.assert_not_called() |
|
assert result == cached_result |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_structured_output_completion_cache_miss_success( |
|
mock_openai_client, mock_response_cache |
|
): |
|
"""Test behavior on cache miss with a successful API call.""" |
|
system_prompt = "System prompt for success" |
|
user_prompt = "User prompt for success" |
|
model = "test-model-success" |
|
expected_result = {"data": "successful API result"} |
|
|
|
|
|
mock_response_cache.get.return_value = None |
|
|
|
|
|
mock_completion = MagicMock() |
|
mock_message = MagicMock() |
|
mock_message.content = json.dumps(expected_result) |
|
mock_choice = MagicMock() |
|
mock_choice.message = mock_message |
|
mock_completion.choices = [mock_choice] |
|
mock_openai_client.chat.completions.create.return_value = mock_completion |
|
|
|
result = await structured_output_completion( |
|
openai_client=mock_openai_client, |
|
model=model, |
|
response_format={"type": "json_object"}, |
|
system_prompt=system_prompt, |
|
user_prompt=user_prompt, |
|
cache=mock_response_cache, |
|
) |
|
|
|
|
|
mock_response_cache.get.assert_called_once_with( |
|
f"{system_prompt}:{user_prompt}", model |
|
) |
|
mock_openai_client.chat.completions.create.assert_called_once_with( |
|
model=model, |
|
messages=[ |
|
{ |
|
"role": "system", |
|
"content": ANY, |
|
}, |
|
{"role": "user", "content": user_prompt}, |
|
], |
|
response_format={"type": "json_object"}, |
|
temperature=0.7, |
|
) |
|
mock_response_cache.set.assert_called_once_with( |
|
f"{system_prompt}:{user_prompt}", model, expected_result |
|
) |
|
assert result == expected_result |
|
|
|
|
|
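# The retry-count assertions in the error tests below assume the retry policy in
# ankigen_core.llm_interface is configured for three attempts (e.g. tenacity's
# stop_after_attempt(3)); adjust them if that policy changes.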
@pytest.mark.anyio |
|
async def test_structured_output_completion_api_error( |
|
mock_openai_client, mock_response_cache |
|
): |
|
"""Test behavior when the OpenAI API call raises an error.""" |
|
system_prompt = "System prompt for error" |
|
user_prompt = "User prompt for error" |
|
model = "test-model-error" |
|
error_message = "Test API Error" |
|
|
|
|
|
mock_response_cache.get.return_value = None |
|
|
|
|
|
|
|
|
|
mock_openai_client.chat.completions.create.side_effect = OpenAIError(error_message) |
|
|
|
with pytest.raises(tenacity.RetryError): |
|
await structured_output_completion( |
|
openai_client=mock_openai_client, |
|
model=model, |
|
response_format={"type": "json_object"}, |
|
system_prompt=system_prompt, |
|
user_prompt=user_prompt, |
|
cache=mock_response_cache, |
|
) |
|
assert mock_response_cache.get.call_count == 3, ( |
|
f"Expected cache.get to be called 3 times due to retries, but was {mock_response_cache.get.call_count}" |
|
) |
|
|
|
assert mock_openai_client.chat.completions.create.call_count == 3, ( |
|
f"Expected create to be called 3 times due to retries, but was {mock_openai_client.chat.completions.create.call_count}" |
|
) |
|
mock_response_cache.set.assert_not_called() |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_structured_output_completion_invalid_json( |
|
mock_openai_client, mock_response_cache |
|
): |
|
"""Test behavior when the API returns invalid JSON.""" |
|
system_prompt = "System prompt for invalid json" |
|
user_prompt = "User prompt for invalid json" |
|
model = "test-model-invalid-json" |
|
invalid_json_content = "this is not json" |
|
|
|
|
|
mock_response_cache.get.return_value = None |
|
|
|
|
|
mock_completion = MagicMock() |
|
mock_message = MagicMock() |
|
mock_message.content = invalid_json_content |
|
mock_choice = MagicMock() |
|
mock_choice.message = mock_message |
|
mock_completion.choices = [mock_choice] |
|
mock_openai_client.chat.completions.create.return_value = mock_completion |
|
|
|
with pytest.raises(tenacity.RetryError): |
|
await structured_output_completion( |
|
openai_client=mock_openai_client, |
|
model=model, |
|
response_format={"type": "json_object"}, |
|
system_prompt=system_prompt, |
|
user_prompt=user_prompt, |
|
cache=mock_response_cache, |
|
) |
|
|
|
|
|
|
|
assert mock_response_cache.get.call_count == 3, ( |
|
f"Expected cache.get to be called 3 times due to retries, but was {mock_response_cache.get.call_count}" |
|
) |
|
|
|
assert mock_openai_client.chat.completions.create.call_count == 3, ( |
|
f"Expected create to be called 3 times due to retries, but was {mock_openai_client.chat.completions.create.call_count}" |
|
) |
|
mock_response_cache.set.assert_not_called() |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_structured_output_completion_no_choices( |
|
mock_openai_client, mock_response_cache |
|
): |
|
"""Test behavior when API completion has no choices.""" |
|
system_prompt = "System prompt no choices" |
|
user_prompt = "User prompt no choices" |
|
model = "test-model-no-choices" |
|
|
|
mock_response_cache.get.return_value = None |
|
mock_completion = MagicMock() |
|
mock_completion.choices = [] |
|
mock_openai_client.chat.completions.create.return_value = mock_completion |
|
|
|
|
|
result = await structured_output_completion( |
|
openai_client=mock_openai_client, |
|
model=model, |
|
response_format={"type": "json_object"}, |
|
system_prompt=system_prompt, |
|
user_prompt=user_prompt, |
|
cache=mock_response_cache, |
|
) |
|
assert result is None |
|
mock_response_cache.set.assert_not_called() |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_structured_output_completion_no_message_content( |
|
mock_openai_client, mock_response_cache |
|
): |
|
"""Test behavior when API choice has no message content.""" |
|
system_prompt = "System prompt no content" |
|
user_prompt = "User prompt no content" |
|
model = "test-model-no-content" |
|
|
|
mock_response_cache.get.return_value = None |
|
mock_completion = MagicMock() |
|
mock_message = MagicMock() |
|
mock_message.content = None |
|
mock_choice = MagicMock() |
|
mock_choice.message = mock_message |
|
mock_completion.choices = [mock_choice] |
|
mock_openai_client.chat.completions.create.return_value = mock_completion |
|
|
|
|
|
result = await structured_output_completion( |
|
openai_client=mock_openai_client, |
|
model=model, |
|
response_format={"type": "json_object"}, |
|
system_prompt=system_prompt, |
|
user_prompt=user_prompt, |
|
cache=mock_response_cache, |
|
) |
|
assert result is None |
|
mock_response_cache.set.assert_not_called() |
|
|
|
|
|
def create_mock_chat_completion(content: str) -> ChatCompletion: |
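"""Build a minimal ChatCompletion with a single assistant message carrying the given content."""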
|
return ChatCompletion( |
|
id="chatcmpl-test123", |
|
choices=[ |
|
ChatCompletionChoice( |
|
finish_reason="stop", |
|
index=0, |
|
message=ChatCompletionMessage(content=content, role="assistant"), |
|
logprobs=None, |
|
) |
|
], |
|
created=1677652288, |
|
model="gpt-4o", |
|
object="chat.completion", |
|
system_fingerprint="fp_test", |
|
usage=None, |
|
) |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_success(mock_openai_client, sample_crawled_page): |
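"""Test that a JSON array of card dicts is parsed into AnkiCardData objects tagged with the source URL."""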
|
mock_response_content = json.dumps( |
|
[ |
|
{"front": "Q1", "back": "A1", "tags": ["tag1"]}, |
|
{"front": "Q2", "back": "A2", "tags": ["tag2", "python"]}, |
|
] |
|
) |
|
mock_openai_client.chat.completions.create.return_value = ( |
|
create_mock_chat_completion(mock_response_content) |
|
) |
|
|
|
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page) |
|
|
|
assert len(result_cards) == 2 |
|
assert result_cards[0].front == "Q1" |
|
assert result_cards[0].source_url == sample_crawled_page.url |
|
assert result_cards[1].tags == ["tag2", "python"] |
|
mock_openai_client.chat.completions.create.assert_awaited_once() |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_empty_llm_response_content( |
|
mock_openai_client, sample_crawled_page |
|
): |
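"""Test that an empty LLM response yields no cards."""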
|
mock_openai_client.chat.completions.create.return_value = ( |
|
create_mock_chat_completion("") |
|
) |
|
|
|
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page) |
|
assert len(result_cards) == 0 |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_llm_returns_not_a_list( |
|
mock_openai_client, sample_crawled_page |
|
): |
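"""Test that a JSON object without a recognised card list yields no cards."""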
|
mock_response_content = json.dumps( |
|
{"error": "not a list as expected"} |
|
) |
|
mock_openai_client.chat.completions.create.return_value = ( |
|
create_mock_chat_completion(mock_response_content) |
|
) |
|
|
|
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page) |
|
assert len(result_cards) == 0 |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_llm_returns_dict_with_cards_key( |
|
mock_openai_client, sample_crawled_page |
|
): |
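"""Test that a JSON object wrapping the card list under a "cards" key is still parsed."""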
|
mock_response_content = json.dumps( |
|
{"cards": [{"front": "Q1", "back": "A1", "tags": []}]} |
|
) |
|
mock_openai_client.chat.completions.create.return_value = ( |
|
create_mock_chat_completion(mock_response_content) |
|
) |
|
|
|
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page) |
|
assert len(result_cards) == 1 |
|
assert result_cards[0].front == "Q1" |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_json_decode_error( |
|
mock_openai_client, sample_crawled_page |
|
): |
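"""Test that an unparsable (non-JSON) LLM response yields no cards instead of raising."""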
|
mock_openai_client.chat.completions.create.return_value = ( |
|
create_mock_chat_completion("this is not valid json") |
|
) |
|
|
|
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page) |
|
assert len(result_cards) == 0 |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_empty_text_content(mock_openai_client): |
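"""Test that a page with only whitespace text content is skipped without calling the API."""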
|
empty_content_page = CrawledPage( |
|
url="http://example.com/empty", |
|
html_content="", |
|
text_content=" ", |
|
title="Empty", |
|
) |
|
result_cards = await process_crawled_page(mock_openai_client, empty_content_page) |
|
assert len(result_cards) == 0 |
|
mock_openai_client.chat.completions.create.assert_not_awaited() |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_openai_api_error_retry( |
|
mock_openai_client, sample_crawled_page, caplog |
|
): |
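"""Test that transient rate-limit errors are retried and a subsequent success still yields cards."""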
|
|
|
side_effects = [  # two transient rate-limit errors followed by a successful completion
|
RateLimitError("rate limited", response=MagicMock(), body=None) |
|
] * 2 + [ |
|
create_mock_chat_completion( |
|
json.dumps([{"front": "Q1", "back": "A1", "tags": []}]) |
|
) |
|
] |
|
|
|
mock_openai_client.chat.completions.create.side_effect = side_effects
|
|
|
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page) |
|
|
|
assert len(result_cards) == 1 |
|
assert result_cards[0].front == "Q1" |
|
assert mock_openai_client.chat.completions.create.await_count == 3  # two rate-limited attempts, then one success
|
assert "Retrying OpenAI call (attempt 1)" in caplog.text |
|
assert "Retrying OpenAI call (attempt 2)" in caplog.text |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_openai_persistent_api_error( |
|
mock_openai_client, sample_crawled_page, caplog |
|
): |
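"""Test that a persistent connection error exhausts the retries and yields no cards."""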
|
|
|
mock_openai_client.chat.completions.create.side_effect = APIConnectionError( |
|
request=MagicMock() |
|
) |
|
|
|
result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page) |
|
|
|
assert len(result_cards) == 0 |
|
assert mock_openai_client.chat.completions.create.await_count == 3  # retries exhausted at the 3-attempt limit
|
assert "OpenAI API error after retries" in caplog.text |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_tiktoken_truncation( |
|
mock_openai_client, sample_crawled_page |
|
): |
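"""Test that overly long page content is truncated to the token budget before the API call."""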
|
|
|
long_text = "word " * 8000 |
|
sample_crawled_page.text_content = long_text |
|
|
|
|
|
mock_response_content = json.dumps( |
|
[{"front": "TruncatedQ", "back": "TruncatedA", "tags": []}] |
|
) |
|
mock_openai_client.chat.completions.create.return_value = ( |
|
create_mock_chat_completion(mock_response_content) |
|
) |
|
|
|
|
|
await process_crawled_page(mock_openai_client, sample_crawled_page) |
|
|
|
|
|
|
|
call_args = mock_openai_client.chat.completions.create.call_args |
|
user_prompt_message_content = next( |
|
m["content"] for m in call_args.kwargs["messages"] if m["role"] == "user" |
|
) |
|
|
|
|
|
|
|
assert "CONTENT:\n" in user_prompt_message_content |
|
content_part = user_prompt_message_content.split("CONTENT:\n")[1].split( |
|
"\n\nReturn a JSON array" |
|
)[0] |
|
|
|
import tiktoken  # local import: only needed for this token-count check
|
|
|
encoding = tiktoken.get_encoding("cl100k_base")
|
num_tokens = len(encoding.encode(content_part)) |
|
|
|
|
|
assert 5900 < num_tokens < 6100  # roughly the assumed 6000-token content budget
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_pages_success(mock_openai_client, sample_crawled_page): |
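"""Test that cards from multiple pages are aggregated in order."""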
|
pages_to_process = [ |
|
sample_crawled_page, |
|
CrawledPage( |
|
url="http://example.com/page2", |
|
html_content="", |
|
text_content="Content for page 2", |
|
title="Page 2", |
|
), |
|
] |
|
|
|
|
|
async def mock_single_page_processor(client, page, model, max_tokens): |
|
if page.url == pages_to_process[0].url: |
|
return [AnkiCardData(front="P1Q1", back="P1A1", source_url=page.url)] |
|
elif page.url == pages_to_process[1].url: |
|
return [ |
|
AnkiCardData(front="P2Q1", back="P2A1", source_url=page.url), |
|
AnkiCardData(front="P2Q2", back="P2A2", source_url=page.url), |
|
] |
|
return [] |
|
|
|
with patch( |
|
"ankigen_core.llm_interface.process_crawled_page", |
|
side_effect=mock_single_page_processor, |
|
) as mock_processor: |
|
result_cards = await process_crawled_pages( |
|
mock_openai_client, pages_to_process, max_concurrent_requests=1 |
|
) |
|
|
|
assert len(result_cards) == 3 |
|
assert result_cards[0].front == "P1Q1" |
|
assert result_cards[1].front == "P2Q1" |
|
assert result_cards[2].front == "P2Q2" |
|
assert mock_processor.call_count == 2 |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_pages_partial_failure( |
|
mock_openai_client, sample_crawled_page |
|
): |
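"""Test that a failure on one page does not prevent cards from the other pages."""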
|
pages_to_process = [ |
|
sample_crawled_page, |
|
CrawledPage( |
|
url="http://example.com/page_fail", |
|
html_content="", |
|
text_content="Content for page fail", |
|
title="Page Fail", |
|
), |
|
CrawledPage( |
|
url="http://example.com/page3", |
|
html_content="", |
|
text_content="Content for page 3", |
|
title="Page 3", |
|
), |
|
] |
|
|
|
async def mock_single_page_processor_with_failure(client, page, model, max_tokens): |
|
if page.url == pages_to_process[0].url: |
|
return [AnkiCardData(front="P1Q1", back="P1A1", source_url=page.url)] |
|
elif page.url == pages_to_process[1].url: |
|
raise APIConnectionError(request=MagicMock()) |
|
elif page.url == pages_to_process[2].url: |
|
return [AnkiCardData(front="P3Q1", back="P3A1", source_url=page.url)] |
|
return [] |
|
|
|
with patch( |
|
"ankigen_core.llm_interface.process_crawled_page", |
|
side_effect=mock_single_page_processor_with_failure, |
|
) as mock_processor: |
|
result_cards = await process_crawled_pages( |
|
mock_openai_client, pages_to_process, max_concurrent_requests=2 |
|
) |
|
|
|
assert len(result_cards) == 2 |
|
successful_urls = [card.source_url for card in result_cards] |
|
assert pages_to_process[0].url in successful_urls |
|
assert pages_to_process[2].url in successful_urls |
|
assert pages_to_process[1].url not in successful_urls |
|
assert mock_processor.call_count == 3 |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_pages_progress_callback( |
|
mock_openai_client, sample_crawled_page |
|
): |
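"""Test that the progress callback reports a running count for each processed page."""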
|
pages_to_process = [sample_crawled_page] * 3 |
|
progress_log = [] |
|
|
|
def callback(completed_count, total_count): |
|
progress_log.append((completed_count, total_count)) |
|
|
|
async def mock_simple_processor(client, page, model, max_tokens): |
|
await asyncio.sleep(0.01) |
|
return [AnkiCardData(front=f"{page.url}-Q", back="A", source_url=page.url)] |
|
|
|
with patch( |
|
"ankigen_core.llm_interface.process_crawled_page", |
|
side_effect=mock_simple_processor, |
|
): |
|
await process_crawled_pages( |
|
mock_openai_client, |
|
pages_to_process, |
|
progress_callback=callback, |
|
max_concurrent_requests=1, |
|
) |
|
|
|
assert len(progress_log) == 3 |
|
assert progress_log[0] == (1, 3) |
|
assert progress_log[1] == (2, 3) |
|
assert progress_log[2] == (3, 3) |
|
|
|
|
|
|
|
TEST_API_KEY = "sk-testkey1234567890abcdefghijklmnopqrstuvwxyz" |
|
|
|
|
|
@pytest.fixture |
|
def client_manager(): |
|
"""Fixture for OpenAIClientManager.""" |
|
return OpenAIClientManager() |
|
|
|
|
|
@pytest.fixture |
|
def mock_async_openai_client(): |
|
"""Mocks an AsyncOpenAI client instance.""" |
|
mock_client = AsyncMock() |
|
mock_client.chat = AsyncMock() |
|
mock_client.chat.completions = AsyncMock() |
|
mock_client.chat.completions.create = AsyncMock() |
|
|
|
|
|
mock_response = MagicMock() |
|
mock_response.choices = [MagicMock()] |
|
mock_response.choices[0].message = MagicMock() |
|
mock_response.choices[0].message.content = '{"question": "Q1", "answer": "A1"}'
|
mock_response.usage = MagicMock() |
|
mock_response.usage.total_tokens = 100 |
|
|
|
mock_client.chat.completions.create.return_value = mock_response |
|
return mock_client |
|
|
|
|
|
@pytest.fixture |
|
def sample_crawled_page(): |
|
"""Fixture for a sample CrawledPage object.""" |
|
return CrawledPage( |
|
url="http://example.com", |
|
html_content="<html><body>This is some test content for the page.</body></html>", |
|
text_content="This is some test content for the page.", |
|
title="Test Page", |
|
meta_description="A test page.", |
|
meta_keywords=["test", "page"], |
|
crawl_depth=0, |
|
) |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_success_returns_tokens(
|
client_manager, mock_async_openai_client, sample_crawled_page |
|
): |
|
"""Test successful processing of a single crawled page.""" |
|
with patch.object( |
|
client_manager, "get_client", return_value=mock_async_openai_client |
|
): |
|
result, tokens = await process_crawled_page( |
|
mock_async_openai_client, |
|
sample_crawled_page, |
|
"gpt-4o", |
|
max_prompt_content_tokens=1000, |
|
) |
|
assert isinstance(result, AnkiCardData) |
|
assert result.front == "Q1" |
|
assert result.back == "A1" |
|
assert tokens == 100 |
|
mock_async_openai_client.chat.completions.create.assert_called_once() |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_json_error( |
|
client_manager, mock_async_openai_client, sample_crawled_page |
|
): |
|
"""Test handling of invalid JSON response from LLM.""" |
|
mock_async_openai_client.chat.completions.create.return_value.choices[ |
|
0 |
|
].message.content = "This is not JSON" |
|
|
|
with patch.object( |
|
client_manager, "get_client", return_value=mock_async_openai_client |
|
): |
|
|
|
mock_async_openai_client.chat.completions.create.reset_mock() |
|
|
|
result, tokens = await process_crawled_page( |
|
mock_async_openai_client, |
|
sample_crawled_page, |
|
"gpt-4o", |
|
max_prompt_content_tokens=1000, |
|
) |
|
assert result is None |
|
assert tokens == 100  # usage tokens from the mocked response are still counted when parsing fails
|
assert mock_async_openai_client.chat.completions.create.call_count >= 1 |
|
|
|
|
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_api_error( |
|
client_manager, mock_async_openai_client, sample_crawled_page |
|
): |
|
"""Test handling of API error during LLM call.""" |
|
|
|
|
|
|
|
mock_request = MagicMock() |
|
mock_async_openai_client.chat.completions.create.side_effect = APIError( |
|
message="Test API Error", request=mock_request, body=None |
|
) |
|
|
|
with patch.object( |
|
client_manager, "get_client", return_value=mock_async_openai_client |
|
): |
|
|
|
mock_async_openai_client.chat.completions.create.reset_mock() |
|
|
|
result, tokens = await process_crawled_page( |
|
mock_async_openai_client, |
|
sample_crawled_page, |
|
"gpt-4o", |
|
max_prompt_content_tokens=1000, |
|
) |
|
assert result is None |
|
assert tokens == 0 |
|
|
|
assert mock_async_openai_client.chat.completions.create.call_count > 1 |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_process_crawled_page_content_truncation( |
|
client_manager, mock_async_openai_client, sample_crawled_page |
|
): |
|
"""Test content truncation based on max_prompt_content_tokens.""" |
|
long_content_piece = "This is a word. " |
|
repetitions = 10 |
|
sample_crawled_page.content = long_content_piece * repetitions |
|
|
|
with ( |
|
patch.object( |
|
client_manager, "get_client", return_value=mock_async_openai_client |
|
), |
|
patch("tiktoken.get_encoding") as mock_get_encoding, |
|
): |
|
mock_encoding = MagicMock() |
|
|
|
original_tokens = [] |
|
for i in range(repetitions): |
|
original_tokens.extend([i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3]) |
|
|
|
mock_encoding.encode.return_value = original_tokens |
|
|
|
def mock_decode_side_effect(token_ids): |
|
num_tokens_to_decode = len(token_ids) |
|
num_full_pieces = num_tokens_to_decode // 4 |
|
partial_piece_tokens = num_tokens_to_decode % 4 |
|
decoded_str = long_content_piece * num_full_pieces |
|
if partial_piece_tokens > 0: |
|
words_in_piece = long_content_piece.strip().split(" ") |
|
num_words_to_take = min(partial_piece_tokens, len(words_in_piece)) |
|
decoded_str += " ".join(words_in_piece[:num_words_to_take]) |
|
return decoded_str.strip() |
|
|
|
mock_encoding.decode.side_effect = mock_decode_side_effect |
|
mock_get_encoding.return_value = mock_encoding |
|
|
|
mock_async_openai_client.chat.completions.create.reset_mock() |
|
|
|
await process_crawled_page( |
|
mock_async_openai_client, |
|
sample_crawled_page, |
|
"gpt-4o", |
|
max_prompt_content_tokens=5, |
|
) |
|
|
|
mock_get_encoding.assert_called_once_with("cl100k_base") |
|
mock_encoding.encode.assert_called_once_with( |
|
sample_crawled_page.content, disallowed_special=() |
|
) |
|
mock_encoding.decode.assert_called_once_with(original_tokens[:5]) |
|
|
|
call_args = mock_async_openai_client.chat.completions.create.call_args |
|
assert call_args is not None |
|
messages = call_args.kwargs["messages"] |
|
user_prompt_content = messages[1]["content"] |
|
|
|
expected_truncated_content = mock_decode_side_effect(original_tokens[:5]) |
|
assert f"Content: {expected_truncated_content}" in user_prompt_content |
|
|
|
|
|
@pytest.mark.anyio |
|
async def test_openai_client_manager_get_client( |
|
client_manager, mock_async_openai_client |
|
): |
|
"""Test that get_client returns the AsyncOpenAI client instance and initializes it once.""" |
|
with patch( |
|
"openai.AsyncOpenAI", return_value=mock_async_openai_client |
|
) as mock_constructor: |
|
client1 = client_manager.get_client() |
|
client2 = client_manager.get_client() |
|
|
|
assert client1 is mock_async_openai_client |
|
assert client2 is mock_async_openai_client |
|
mock_constructor.assert_called_once_with(api_key=TEST_API_KEY) |
|