File size: 13,541 Bytes
31add3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
import chainlit as cl # Import chainlit for its types

# Assuming your app structure allows these imports
# You might need to adjust paths or ensure __init__.py files are present
from app import (
    on_chat_start,
    _apply_chat_settings_to_state,
    execute_persona_tasks, 
    InsightFlowState,
    PersonaFactory, # Assuming PersonaFactory is accessible or can be mocked
    PERSONA_LLM_MAP, # Assuming this is accessible or can be mocked
    RAG_ENABLED_PERSONA_IDS # If this is defined globally in app.py
)
from utils.persona.base import PersonaReasoning # For creating mock persona objects

# Local copy of the default RAG_ENABLED_PERSONA_IDS from app.py (or import it if it's moved to a config).
# "scientific" is deliberately absent: the tests below use it as the non-RAG persona example.
# NOTE(review): this constant is not referenced by any test in this file as shown, and
# RAG_ENABLED_PERSONA_IDS is already imported from app above; consider dropping this copy
# or asserting the two are equal so they cannot silently drift apart.
APP_RAG_ENABLED_PERSONA_IDS = ["analytical", "philosophical", "metaphorical"]

@pytest.fixture
def mock_cl_user_session():
    """Swap ``chainlit.user_session`` for a MagicMock backed by a plain dict.

    The mock's ``get``, ``set`` and ``in`` operations read and write a private
    dict. That dict is seeded with a session ``id``, a real ``PersonaFactory``
    and a mocked ``embedding_model``. It is emptied once the test using the
    fixture has finished.
    """
    backing = {}

    def _lookup(key, default=None):
        return backing.get(key, default)

    def _store(key, value):
        backing[key] = value

    def _has(key):
        # Membership must consult the backing dict, not the mock's own attributes.
        return key in backing

    with patch('chainlit.user_session', new_callable=MagicMock) as session:
        session.get = MagicMock(side_effect=_lookup)
        session.set = MagicMock(side_effect=_store)
        # Lets `key in cl.user_session` behave like a dict membership test.
        session.__contains__ = MagicMock(side_effect=_has)

        # Baseline values other parts of the app logic expect to find.
        backing.update({
            "id": "test_session_id",
            "persona_factory": PersonaFactory(),  # Real factory; individual tests may override it
            "embedding_model": MagicMock(),  # Stand-in for an already-initialized model
        })

        yield session
        backing.clear()  # Reset the store after the test

@pytest.fixture
def mock_persona_factory():
    """Build a ``PersonaFactory`` mock that hands out canned persona mocks.

    ``create_persona(persona_id, llm)`` returns a fixed mock for
    ``"analytical"``, ``"philosophical"`` and ``"scientific"`` (the last is the
    non-RAG example). Any other id gets a freshly built generic persona mock.
    Every persona's ``generate_perspective`` is an ``AsyncMock`` that resolves
    to a fixed string.

    ``get_available_personas`` and ``persona_configs`` are stubbed for code
    that enumerates personas (e.g. ``on_chat_start`` building its switches).
    """
    factory = MagicMock(spec=PersonaFactory)

    def _make_persona(persona_id, name, perspective, expertise=None, role=None):
        """Return a PersonaReasoning-spec'd mock with the given attributes."""
        persona = MagicMock(spec=PersonaReasoning)
        persona.id = persona_id
        # Assign ``name`` after construction: passing ``name=`` to the MagicMock
        # constructor only sets the mock's repr name. It would leave ``.name``
        # as something other than this string.
        persona.name = name
        if expertise is not None:
            persona.expertise = expertise
        if role is not None:
            persona.role = role
        persona.generate_perspective = AsyncMock(return_value=perspective)
        return persona

    canned_personas = {
        "analytical": _make_persona(
            "analytical", "Analytical Persona", "Analytical perspective",
            expertise="Logic and reason", role="To analyze deeply",
        ),
        "philosophical": _make_persona(
            "philosophical", "Philosophical Persona", "Philosophical perspective",
            expertise="Wisdom and ethics", role="To ponder deeply",
        ),
        # Non-RAG example
        "scientific": _make_persona(
            "scientific", "Scientific Persona", "Scientific perspective",
            expertise="Empirical evidence", role="To investigate phenomena",
        ),
    }

    def create_persona_side_effect(persona_id, llm):
        # ``llm`` is accepted to match the real call signature but ignored.
        if persona_id in canned_personas:
            return canned_personas[persona_id]
        return _make_persona(
            persona_id,
            f"{persona_id.capitalize()} Persona",
            f"{persona_id} perspective",
        )

    factory.create_persona = MagicMock(side_effect=create_persona_side_effect)
    # Stub for code (e.g. _apply_chat_settings_to_state) that may list personas directly.
    factory.get_available_personas = MagicMock(return_value={"analytical": "Analytical", "philosophical": "Philosophical", "scientific": "Scientific"})
    factory.persona_configs = {  # For on_chat_start to build switches
        "analytical": {"name": "Analytical Persona"},
        "philosophical": {"name": "Philosophical Persona"},
        "scientific": {"name": "Scientific Persona"}
    }
    return factory

@pytest.mark.asyncio
async def test_on_chat_start_adds_rag_toggle(mock_cl_user_session, mock_persona_factory):
    """Test that on_chat_start includes the 'enable_rag' Switch, ON by default."""
    # Serve the mock factory for "persona_factory" and defer every other key to
    # the fixture's dict-backed getter. The fixture's store is a local of the
    # fixture, so it cannot be reached through ``globals()``. Delegating keeps
    # values that on_chat_start sets readable again via get().
    fallback_get = mock_cl_user_session.get.side_effect

    def get_with_factory(key, default=None):
        if key == "persona_factory":
            return mock_persona_factory
        return fallback_get(key, default)

    mock_cl_user_session.get.side_effect = get_with_factory

    # cl.ChatSettings(...) returns this instance; its ``send`` must be awaitable.
    mock_chat_settings_instance = AsyncMock()
    mock_chat_settings_instance.send = AsyncMock()

    with patch('app.initialize_configurations'), \
         patch('app.PersonaFactory', return_value=mock_persona_factory), \
         patch('chainlit.ChatSettings', return_value=mock_chat_settings_instance) as mock_chat_settings_class, \
         patch('chainlit.Message'), \
         patch('app.get_embedding_model', return_value=MagicMock()):  # Mock embedding model loading

        await on_chat_start()

        mock_chat_settings_class.assert_called_once()  # cl.ChatSettings was constructed exactly once

        # Accept ``inputs`` passed either positionally or by keyword.
        args, kwargs = mock_chat_settings_class.call_args
        inputs_list = kwargs["inputs"] if "inputs" in kwargs else args[0]

        # Require the widget to be a Switch, not merely to share the id.
        rag_toggle_widget = next(
            (
                widget for widget in inputs_list
                if getattr(widget, "id", None) == "enable_rag" and isinstance(widget, cl.Switch)
            ),
            None,
        )
        assert rag_toggle_widget is not None, "'enable_rag' Switch not found in ChatSettings"
        assert rag_toggle_widget.initial is True, "RAG toggle should be ON by default"

@pytest.mark.asyncio
async def test_apply_chat_settings_reads_rag_toggle(mock_cl_user_session, mock_persona_factory):
    """Test that _apply_chat_settings_to_state correctly reads and sets the RAG toggle."""
    # Per-test session contents; "chat_settings" simulates the UI sending the toggle as False.
    session_values = {
        "insight_flow_state": InsightFlowState(selected_personas=[]),
        "persona_factory": mock_persona_factory,
        "chat_settings": {"enable_rag": False, "selected_team": "none"},
    }

    # A keyed getter rather than an ordered side_effect list, so the test does
    # not depend on the order in which the app reads session keys.
    def get_from_session(key, default=None):
        return session_values.get(key, default)

    mock_cl_user_session.get = MagicMock(side_effect=get_from_session)

    await _apply_chat_settings_to_state()

    # Find the positional cl.user_session.set('enable_rag', ...) call; calls
    # made with keyword arguments only are skipped instead of raising IndexError.
    set_rag_call = next(
        (
            recorded for recorded in mock_cl_user_session.set.call_args_list
            if recorded.args and recorded.args[0] == 'enable_rag'
        ),
        None,
    )
    assert set_rag_call is not None, "cl.user_session.set was not called for 'enable_rag'"
    assert set_rag_call.args[1] is False, "'enable_rag' in session should be False"


@pytest.mark.asyncio
@patch('app.get_relevant_context_for_query', new_callable=AsyncMock)  # Mock the RAG context function
async def test_execute_persona_tasks_rag_toggle_on(mock_get_context, mock_cl_user_session, mock_persona_factory):
    """Test execute_persona_tasks: RAG attempted when toggle is ON for RAG-enabled persona."""
    mock_cl_user_session.set("enable_rag", True)
    mock_cl_user_session.set("persona_factory", mock_persona_factory)
    mock_cl_user_session.set("embedding_model", MagicMock())  # Mocked embedding model
    mock_cl_user_session.set("progress_msg", AsyncMock(spec=cl.Message))  # Mock progress message
    mock_cl_user_session.set("completed_steps_log", [])

    # Restrict PERSONA_LLM_MAP to the persona under test.
    with patch.dict(PERSONA_LLM_MAP, {"analytical": MagicMock(spec=ChatOpenAI)}, clear=True):
        initial_state = InsightFlowState(
            query="test query for analytical",
            selected_personas=["analytical"],  # RAG-enabled persona
            persona_responses={}
        )
        mock_get_context.return_value = "Retrieved RAG context."

        final_state = await execute_persona_tasks(initial_state)

        embedding_model = mock_cl_user_session.get("embedding_model")
        mock_get_context.assert_called_once_with("test query for analytical", "analytical", embedding_model)

        # Call the factory's side effect directly. It returns the same canned
        # persona mock without recording an extra create_persona() call.
        analytical_persona_mock = mock_persona_factory.create_persona.side_effect("analytical", None)
        perspective_call = analytical_persona_mock.generate_perspective.call_args
        assert perspective_call is not None, "generate_perspective was never called for 'analytical'"
        prompt = perspective_call.args[0]
        # The prompt should be augmented with the retrieved context.
        assert "Retrieved RAG context:" in prompt
        assert "User Query: test query for analytical" in prompt
        assert final_state["persona_responses"]["analytical"] == "Analytical perspective"

@pytest.mark.asyncio
@patch('app.get_relevant_context_for_query', new_callable=AsyncMock)
async def test_execute_persona_tasks_rag_toggle_off(mock_get_context, mock_cl_user_session, mock_persona_factory):
    """Test execute_persona_tasks: RAG NOT attempted when toggle is OFF."""
    mock_cl_user_session.set("enable_rag", False)  # RAG is OFF
    mock_cl_user_session.set("persona_factory", mock_persona_factory)
    mock_cl_user_session.set("embedding_model", MagicMock())
    mock_cl_user_session.set("progress_msg", AsyncMock(spec=cl.Message))
    mock_cl_user_session.set("completed_steps_log", [])

    # Restrict PERSONA_LLM_MAP to the persona under test.
    with patch.dict(PERSONA_LLM_MAP, {"analytical": MagicMock(spec=ChatOpenAI)}, clear=True):
        initial_state = InsightFlowState(
            query="test query for analytical rag off",
            selected_personas=["analytical"],
            persona_responses={}
        )

        final_state = await execute_persona_tasks(initial_state)

        mock_get_context.assert_not_called()

        # Call the factory's side effect directly. It returns the same canned
        # persona mock without recording an extra create_persona() call.
        analytical_persona_mock = mock_persona_factory.create_persona.side_effect("analytical", None)
        perspective_call = analytical_persona_mock.generate_perspective.call_args
        assert perspective_call is not None, "generate_perspective was never called for 'analytical'"
        prompt = perspective_call.args[0]
        assert "Retrieved RAG context:" not in prompt  # No RAG context section when the toggle is OFF
        assert "User Query: test query for analytical rag off" in prompt
        # Assumes the app's non-RAG prompt for a RAG-enabled persona uses this wording.
        assert "No specific context from your knowledge base was retrieved" in prompt
        assert final_state["persona_responses"]["analytical"] == "Analytical perspective"

@pytest.mark.asyncio
@patch('app.get_relevant_context_for_query', new_callable=AsyncMock)
async def test_execute_persona_tasks_rag_on_non_rag_persona(mock_get_context, mock_cl_user_session, mock_persona_factory):
    """Test execute_persona_tasks: RAG NOT attempted for non-RAG-enabled persona even if toggle is ON."""
    mock_cl_user_session.set("enable_rag", True)  # RAG is ON globally
    mock_cl_user_session.set("persona_factory", mock_persona_factory)
    mock_cl_user_session.set("embedding_model", MagicMock())
    mock_cl_user_session.set("progress_msg", AsyncMock(spec=cl.Message))
    mock_cl_user_session.set("completed_steps_log", [])

    # Restrict PERSONA_LLM_MAP to the persona under test.
    with patch.dict(PERSONA_LLM_MAP, {"scientific": MagicMock(spec=ChatOpenAI)}, clear=True):
        initial_state = InsightFlowState(
            query="test query for scientific",
            selected_personas=["scientific"],  # NON-RAG-enabled persona
            persona_responses={}
        )

        final_state = await execute_persona_tasks(initial_state)

        mock_get_context.assert_not_called()

        # Call the factory's side effect directly. It returns the same canned
        # persona mock without recording an extra create_persona() call.
        scientific_persona_mock = mock_persona_factory.create_persona.side_effect("scientific", None)
        perspective_call = scientific_persona_mock.generate_perspective.call_args
        assert perspective_call is not None, "generate_perspective was never called for 'scientific'"
        prompt = perspective_call.args[0]
        assert "Retrieved RAG context:" not in prompt
        assert prompt == "test query for scientific"  # Original query passed directly
        assert final_state["persona_responses"]["scientific"] == "Scientific perspective"

# Need to import ChatOpenAI for the PERSONA_LLM_MAP patching to work if it's type hinted.
from langchain_openai import ChatOpenAI