|
|
|
|
|
import pytest |
|
|
import json |
|
|
from unittest.mock import Mock, AsyncMock, patch, mock_open |
|
|
from app.orchestrator import Orchestrator |
|
|
from app.schema import Company, Prospect |
|
|
from pathlib import Path |
|
|
import asyncio |
|
|
|
|
|
@pytest.mark.asyncio
async def test_pipeline_happy_path():
    """Drive the full pipeline for one well-formed prospect and inspect the event stream.

    Every collaborator is faked:
    - the companies file (via ``builtins.open``);
    - the MCP registry and the store/search/email/calendar clients it returns;
    - ``Path.exists`` / ``Path.read_text`` (footer text);
    - the writer's ``Retriever``;
    - the writer's ``aiohttp.ClientSession``, whose ``post`` raises
      ``Exception("Connection failed")``.

    Only the *kinds* of streamed events are checked (plus that the store and
    search clients were used); event contents are not inspected.
    """
    companies_payload = json.dumps([{
        "id": "test",
        "name": "Test Co",
        "domain": "test.com",
        "industry": "SaaS",
        "size": 100,
        "pains": ["Low NPS scores"],
        "notes": ["Growing company"],
    }])

    # Store client: every write resolves to None, nothing is suppressed,
    # and no contacts are known for the domain.
    store = AsyncMock()
    for write_method in ("save_prospect", "save_company", "save_fact",
                         "save_contact", "save_handoff"):
        setattr(store, write_method, AsyncMock(return_value=None))
    store.check_suppression = AsyncMock(return_value=False)
    store.list_contacts_by_domain = AsyncMock(return_value=[])

    # Search client: a single canned fact.
    search = AsyncMock()
    search.query = AsyncMock(return_value=[{
        "text": "Test Co focuses on customer experience",
        "source": "Industry Report",
        "confidence": 0.85,
    }])

    # Email client: sending yields a thread id; fetching that thread returns
    # the one outbound message.
    email = AsyncMock()
    email.send = AsyncMock(return_value={
        "thread_id": "test-thread-123",
        "message_id": "msg-456",
        "prospect_id": "test",
    })
    email.get_thread = AsyncMock(return_value={
        "id": "test-thread-123",
        "prospect_id": "test",
        "messages": [{
            "id": "msg-456",
            "thread_id": "test-thread-123",
            "direction": "outbound",
            "subject": "Test Subject",
            "body": "Test Body",
            "sent_at": "2024-01-01T00:00:00",
        }],
    })

    # Calendar client: one suggested slot and a stub ICS string.
    calendar = AsyncMock()
    calendar.suggest_slots = AsyncMock(return_value=[
        {"start_iso": "2024-01-02T14:00:00", "end_iso": "2024-01-02T14:30:00"},
    ])
    calendar.generate_ics = AsyncMock(return_value="BEGIN:VCALENDAR...")

    registry = Mock()
    registry.get_store_client.return_value = store
    registry.get_search_client.return_value = search
    registry.get_email_client.return_value = email
    registry.get_calendar_client.return_value = calendar

    retriever = Mock()
    retriever.retrieve.return_value = [{"text": "Relevant fact 1", "score": 0.9}]

    # Any HTTP POST from the writer fails, so no real LLM call can succeed.
    http_session = AsyncMock()
    http_session.post.side_effect = Exception("Connection failed")

    # Patches are entered in the same order as the nested original.
    with patch('builtins.open', mock_open(read_data=companies_payload)), \
            patch('app.orchestrator.MCPRegistry') as registry_cls, \
            patch.object(Path, 'exists', return_value=True), \
            patch.object(Path, 'read_text', return_value="\n---\nTest Footer"), \
            patch('agents.writer.Retriever') as retriever_cls, \
            patch('agents.writer.aiohttp.ClientSession') as session_cls:
        registry_cls.return_value = registry
        retriever_cls.return_value = retriever
        session_cls.return_value.__aenter__.return_value = http_session

        orchestrator = Orchestrator()
        events = [event async for event in orchestrator.run_pipeline(["test"])]

    event_types = [event.get("type") for event in events]

    for expected_type in ("agent_start", "agent_end", "mcp_call", "mcp_response"):
        assert expected_type in event_types

    # Either generation finished or a policy check stopped it.
    assert "llm_done" in event_types or "policy_block" in event_types

    assert store.save_prospect.called
    assert search.query.called
|
|
|
|
|
@pytest.mark.asyncio
async def test_pipeline_compliance_block():
    """Test that compliance violations block the pipeline.

    The store's ``check_suppression`` fake reports the fixture's domain, and
    any email address containing that domain, as suppressed. After running
    the pipeline for that single prospect, at least one event ``message`` or
    ``payload["reason"]`` must mention "suppressed", "dropped" or "blocked".

    The fixture's id, name and domain intentionally contain none of those
    keywords. Otherwise an event that merely names the company (for example
    a progress message) could satisfy the assertion on its own.
    """
    suppressed_domain = "denylisted.com"
    test_company = {
        "id": "denylisted-test",
        "name": "Denylisted Co",
        "domain": suppressed_domain,
        "industry": "SaaS",
        "size": 100,
        "pains": ["Test pain"],
        "notes": [],
    }

    # Parameter names are kept (``type`` shadows the builtin) so the fake
    # still works if the pipeline passes them as keyword arguments.
    async def check_suppression(type, value):  # noqa: A002
        if type == "domain" and value == suppressed_domain:
            return True
        if type == "email" and suppressed_domain in value:
            return True
        return False

    mock_store = AsyncMock()
    mock_store.save_prospect = AsyncMock(return_value=None)
    mock_store.save_fact = AsyncMock(return_value=None)
    mock_store.save_contact = AsyncMock(return_value=None)
    mock_store.check_suppression = AsyncMock(side_effect=check_suppression)
    mock_store.list_contacts_by_domain = AsyncMock(return_value=[])

    mock_search = AsyncMock()
    mock_search.query = AsyncMock(return_value=[])

    mock_mcp = Mock()
    mock_mcp.get_store_client.return_value = mock_store
    mock_mcp.get_search_client.return_value = mock_search
    mock_mcp.get_email_client.return_value = AsyncMock()
    mock_mcp.get_calendar_client.return_value = AsyncMock()

    mock_retriever = Mock()
    mock_retriever.retrieve.return_value = []

    with patch('builtins.open', mock_open(read_data=json.dumps([test_company]))), \
            patch('app.orchestrator.MCPRegistry') as MockMCPRegistry, \
            patch.object(Path, 'exists', return_value=True), \
            patch.object(Path, 'read_text', return_value="\n---\nTest Footer"), \
            patch('agents.writer.Retriever') as MockRetriever:
        MockMCPRegistry.return_value = mock_mcp
        MockRetriever.return_value = mock_retriever

        orchestrator = Orchestrator()
        events = [
            event async for event in orchestrator.run_pipeline([test_company["id"]])
        ]

    # ``payload`` may be absent or explicitly None; ``or {}`` handles both
    # (``.get("payload", {})`` would return None in the second case).
    texts = []
    for event in events:
        payload = event.get("payload") or {}
        texts.append(str(event.get("message", "")).lower())
        texts.append(str(payload.get("reason", "")).lower())
    all_text = " ".join(texts)

    assert any(word in all_text for word in ("suppressed", "dropped", "blocked")), \
        f"Should have suppression/dropped/blocked message; event text was: {all_text!r}"
|
|
|
|
|
@pytest.mark.asyncio
async def test_pipeline_scorer_drop():
    """Test that low scores drop prospects.

    The fixture has size 10, industry "Unknown" and no pains or notes, and
    the search client returns no facts. At least one event must report the
    drop in one of two ways:
    - "dropped" appears in its message, ``payload["reason"]`` or ``payload["status"]``;
    - "low fit score" appears in its message or ``payload["reason"]``.
    """
    test_company = {
        "id": "low-score",
        "name": "Small Co",
        "domain": "small.com",
        "industry": "Unknown",
        "size": 10,
        "pains": [],
        "notes": [],
    }

    mock_store = AsyncMock()
    mock_store.save_prospect = AsyncMock(return_value=None)
    mock_store.save_fact = AsyncMock(return_value=None)
    mock_store.save_contact = AsyncMock(return_value=None)
    mock_store.check_suppression = AsyncMock(return_value=False)
    mock_store.list_contacts_by_domain = AsyncMock(return_value=[])

    mock_search = AsyncMock()
    mock_search.query = AsyncMock(return_value=[])

    mock_mcp = Mock()
    mock_mcp.get_store_client.return_value = mock_store
    mock_mcp.get_search_client.return_value = mock_search
    mock_mcp.get_email_client.return_value = AsyncMock()
    mock_mcp.get_calendar_client.return_value = AsyncMock()

    mock_retriever = Mock()
    mock_retriever.retrieve.return_value = []

    with patch('builtins.open', mock_open(read_data=json.dumps([test_company]))), \
            patch('app.orchestrator.MCPRegistry') as MockMCPRegistry, \
            patch.object(Path, 'exists', return_value=True), \
            patch.object(Path, 'read_text', return_value="\n---\nTest Footer"), \
            patch('agents.writer.Retriever') as MockRetriever:
        MockMCPRegistry.return_value = mock_mcp
        MockRetriever.return_value = mock_retriever

        orchestrator = Orchestrator()
        events = [event async for event in orchestrator.run_pipeline(["low-score"])]

    def _reports_drop(event):
        """Return True if *event* reports that the prospect was dropped."""
        # ``payload`` may be absent or explicitly None; ``or {}`` handles both.
        payload = event.get("payload") or {}
        message = str(event.get("message", "")).lower()
        reason = str(payload.get("reason", "")).lower()
        status = str(payload.get("status", "")).lower()
        return (
            any("dropped" in text for text in (message, reason, status))
            or any("low fit score" in text for text in (message, reason))
        )

    assert any(_reports_drop(event) for event in events), (
        "Should have found drop message; event messages were: "
        f"{[event.get('message') for event in events]!r}"
    )