Spaces:
Running
Running
import pytest | |
from unittest.mock import MagicMock, AsyncMock, patch | |
from langchain_core.language_models.chat_models import BaseChatModel | |
from langchain_core.messages import SystemMessage, HumanMessage, AIMessageChunk | |
# Assuming utils.persona is in the Python path | |
from utils.persona import PersonaReasoning, PersonaFactory | |
import inspect | |
# Import-time diagnostics: report which source files the persona classes were
# loaded from (presumably to spot a stale or shadowing copy of ``utils.persona``
# on the import path).
for _debug_label, _debug_cls in (
    ("PersonaFactory", PersonaFactory),
    ("PersonaReasoning", PersonaReasoning),
):
    print(f"DEBUG: {_debug_label} imported from: {inspect.getfile(_debug_cls)}")
del _debug_label, _debug_cls
@pytest.fixture
def mock_llm():
    """Provide a ``BaseChatModel`` mock whose ``astream`` yields canned chunks.

    ``astream`` is replaced by a ``MagicMock`` so tests can inspect how it was
    called. Its ``side_effect`` is an async generator function, so every call
    returns a fresh async iterator over the chunks "Hello, ", "world!", "" and
    " How are you?". Joined together, these read "Hello, world! How are you?".

    Returns:
        A ``MagicMock`` constrained to the ``BaseChatModel`` interface.
    """
    llm = MagicMock(spec=BaseChatModel)

    async def mock_astream_behavior(messages, *args, **kwargs):
        # Extra positional/keyword arguments (e.g. a ``config`` or ``stop``
        # value) are accepted and ignored so the fake does not raise TypeError
        # if the code under test forwards them to ``astream``.
        yield AIMessageChunk(content="Hello, ")
        yield AIMessageChunk(content="world!")
        # An empty chunk, which real streaming models can also emit; the
        # consumer must tolerate it when aggregating.
        yield AIMessageChunk(content="")
        yield AIMessageChunk(content=" How are you?")

    llm.astream = MagicMock(side_effect=mock_astream_behavior)
    return llm
@pytest.mark.asyncio
async def test_persona_reasoning_generate_perspective(mock_llm):
    """Check that generate_perspective streams from its LLM and joins the chunks.

    The LLM's ``astream`` must be called exactly once, with a two-message list:
    a SystemMessage carrying the persona's system prompt, then a HumanMessage
    carrying the query. The returned text must be the concatenation of every
    chunk streamed by the ``mock_llm`` fixture.

    Without the ``asyncio`` mark, pytest-asyncio in its default strict mode
    does not await this coroutine, and the assertions below never execute.
    """
    persona_id = "test_persona"
    name = "Test Persona"
    system_prompt = "You are a test persona."
    reasoning = PersonaReasoning(persona_id, name, system_prompt, mock_llm)

    query = "What is the meaning of life?"
    expected_response = "Hello, world! How are you?"

    actual_response = await reasoning.generate_perspective(query)

    # Verify the LLM was streamed exactly once with the expected prompt.
    mock_llm.astream.assert_called_once()
    messages = mock_llm.astream.call_args[0][0]  # first positional arg: message list
    assert len(messages) == 2
    system_message, human_message = messages
    assert isinstance(system_message, SystemMessage)
    assert system_message.content == system_prompt
    assert isinstance(human_message, HumanMessage)
    assert human_message.content == query

    # Verify every streamed chunk (including the empty one) was aggregated.
    assert actual_response == expected_response
def test_persona_factory_initialization():
    """A fresh PersonaFactory loads a non-empty config set containing "analytical"."""
    factory = PersonaFactory()
    configs = factory.persona_configs

    # At least one persona configuration must have been loaded.
    assert configs
    assert "analytical" in configs

    analytical_config = configs["analytical"]
    assert analytical_config["name"] == "Analytical"
def test_persona_factory_create_persona_success(mock_llm):
    """create_persona builds a PersonaReasoning wired to its config and the given LLM."""
    factory = PersonaFactory()
    requested_id = "analytical"

    created = factory.create_persona(requested_id, mock_llm)

    assert created is not None
    assert isinstance(created, PersonaReasoning)
    assert created.persona_id == requested_id

    # Name and system prompt must be taken straight from the factory's config.
    expected_config = factory.persona_configs[requested_id]
    assert created.name == expected_config["name"]
    assert created.system_prompt == expected_config["system_prompt"]
    assert created.llm == mock_llm
def test_persona_factory_create_persona_invalid_id(mock_llm):
    """An unknown persona id yields None instead of a persona instance."""
    result = PersonaFactory().create_persona("non_existent_persona", mock_llm)
    assert result is None
def test_persona_factory_create_persona_no_llm():
    """Passing None as the LLM makes create_persona return None and print an error.

    The test expects create_persona to reject a missing LLM instance: it must
    not build a PersonaReasoning, and it must print the error message below.
    """
    factory = PersonaFactory()
    expected_message = "DEBUG Error: LLM instance not provided for persona analytical"

    # Capture ``print`` as seen from utils.persona.base. This assumes the error
    # print in create_persona executes in that module — TODO confirm if the
    # factory implementation moves to another submodule.
    with patch('utils.persona.base.print') as captured_print:
        result = factory.create_persona("analytical", None)

    # The mock keeps its recorded calls after the patch is undone.
    assert result is None
    captured_print.assert_any_call(expected_message)
def test_get_available_personas():
    """get_available_personas returns one entry per configured persona.

    The result must be a dict with the same number of entries as
    ``persona_configs``, and must map "analytical" to "Analytical".
    """
    factory = PersonaFactory()
    personas = factory.get_available_personas()

    assert isinstance(personas, dict)
    assert "analytical" in personas
    assert personas["analytical"] == "Analytical"

    configured_count = len(factory.persona_configs)
    assert len(personas) == configured_count