File size: 2,556 Bytes
4962437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import pytest
from unittest.mock import MagicMock, patch
from swarms.worker.worker_node import WorkerNodeInitializer, WorkerNode  # replace your_module with actual module name


# Mock Tool for testing.
# The original subclassed a bare `Tool` that this file never imports, which
# raised NameError as soon as pytest collected the module (so every test in
# the file errored). Resolve the base class from the module under test.
# NOTE(review): assumes swarms.worker.worker_node exposes `Tool` (the type
# WorkerNodeInitializer works with) — confirm against that module. The
# fallback keeps this file importable so the remaining tests can still run.
try:
    from swarms.worker.worker_node import Tool as _ToolBase
except ImportError:
    _ToolBase = object


class MockTool(_ToolBase):
    """Inert tool used to populate WorkerNodeInitializer's tool list in tests.

    NOTE(review): if the real Tool base class requires constructor arguments
    (e.g. it is a validated model), the bare ``MockTool()`` calls in this file
    will need to supply them — confirm.
    """

@pytest.fixture
def mock_llm():
    """Language-model stand-in: a MagicMock that accepts any call or attribute."""
    fake_llm = MagicMock()
    return fake_llm

@pytest.fixture
def mock_vectorstore():
    """Vector-store stand-in: a MagicMock, so any retriever call is recorded."""
    fake_store = MagicMock()
    return fake_store

@pytest.fixture
def mock_tools():
    """Three independent MockTool instances for the initializer's tool list."""
    return [MockTool() for _ in range(3)]

@pytest.fixture
def worker_node(mock_llm, mock_tools, mock_vectorstore):
    """WorkerNodeInitializer wired to the mock llm, tool list and vector store."""
    initializer = WorkerNodeInitializer(
        llm=mock_llm,
        tools=mock_tools,
        vectorstore=mock_vectorstore,
    )
    return initializer

@pytest.fixture
def mock_worker_node():
    """WorkerNode constructed with a placeholder OpenAI API key.

    NOTE(review): despite its name this builds a real WorkerNode, not a mock;
    whether construction touches the network or other resources depends on
    WorkerNode.__init__, which is not visible here — confirm.
    """
    node = WorkerNode(openai_api_key="test_api_key")
    return node

# WorkerNodeInitializer Tests
def test_worker_node_init(worker_node):
    assert worker_node.llm is not None
    assert worker_node.tools is not None
    assert worker_node.vectorstore is not None

def test_worker_node_create_agent(worker_node):
    with patch.object(AutoGPT, 'from_llm_and_tools') as mock_method:
        worker_node.create_agent()
        mock_method.assert_called_once()

def test_worker_node_add_tool(worker_node):
    """add_tool() stores the given tool: one more entry, and it is that tool."""
    initial_tools_count = len(worker_node.tools)
    new_tool = MockTool()
    worker_node.add_tool(new_tool)
    assert len(worker_node.tools) == initial_tools_count + 1
    # Identity check (not ==): a count-only assertion would also pass if
    # add_tool() stored some other object in place of the one supplied.
    assert any(tool is new_tool for tool in worker_node.tools)

def test_worker_node_run(worker_node):
    with patch.object(worker_node.agent, 'run') as mock_run:
        worker_node.run(prompt="test prompt")
        mock_run.assert_called_once()

# WorkerNode Tests
def test_worker_node_llm(mock_worker_node):
    with patch.object(mock_worker_node, 'initialize_llm') as mock_method:
        mock_worker_node.initialize_llm(llm_class=MagicMock(), temperature=0.5)
        mock_method.assert_called_once()

def test_worker_node_tools(mock_worker_node):
    with patch.object(mock_worker_node, 'initialize_tools') as mock_method:
        mock_worker_node.initialize_tools(llm_class=MagicMock())
        mock_method.assert_called_once()

def test_worker_node_vectorstore(mock_worker_node):
    with patch.object(mock_worker_node, 'initialize_vectorstore') as mock_method:
        mock_worker_node.initialize_vectorstore()
        mock_method.assert_called_once()

def test_worker_node_create_worker_node(mock_worker_node):
    with patch.object(mock_worker_node, 'create_worker_node') as mock_method:
        mock_worker_node.create_worker_node()
        mock_method.assert_called_once()