"""Unit tests for ``substra_template.substra_runner.SubstraRunner``."""
from unittest.mock import Mock, call

import pytest
from datasets import Dataset

from substra_template.substra_runner import SubstraRunner


class TestSubstraRunner:
    """Tests for the individual setup steps of ``SubstraRunner``.

    The ``Client`` class and the ``load_dataset`` function, as looked up from
    ``substra_template.substra_runner``, are replaced with ``Mock`` objects by
    the fixtures below. Tests that request those fixtures therefore do not
    depend on the real implementations.
    """

    @pytest.fixture
    def mock_substra_client_class(self, monkeypatch):
        """Patch ``substra_runner.Client`` with a ``Mock`` and return the mock."""
        mock_substra_client_class = Mock()
        monkeypatch.setattr(
            "substra_template.substra_runner.Client", mock_substra_client_class
        )
        return mock_substra_client_class

    @pytest.fixture
    def mock_load_dataset(self, monkeypatch):
        """Patch ``substra_runner.load_dataset`` with a ``Mock`` and return the mock."""
        mock_load_dataset = Mock()
        monkeypatch.setattr(
            "substra_template.substra_runner.load_dataset", mock_load_dataset
        )
        return mock_load_dataset

    def test_set_up_clients(self, mock_substra_client_class):
        """``set_up_clients`` instantiates the (mocked) ``Client`` class at least once."""
        runner = SubstraRunner()
        runner.set_up_clients()
        mock_substra_client_class.assert_called()

    def test_prepare_data(self, mock_load_dataset):
        """``prepare_data`` loads both MNIST splits and builds ``num_clients - 1`` datasets.

        The train/test loads may happen in either order (``any_order=True``).
        """
        runner = SubstraRunner()
        runner.prepare_data()
        mock_load_dataset.assert_has_calls(
            [
                call("mnist", split="train"),
                call("mnist", split="test"),
            ],
            any_order=True,
        )
        assert len(runner.datasets) == runner.num_clients - 1

    def test_register_data(self, mock_load_dataset):
        """``register_data`` runs without raising on ``num_clients - 1`` empty datasets.

        ``mock_load_dataset`` is requested but not asserted on. Presumably it
        guards against a real ``load_dataset`` call inside ``register_data``;
        confirm against the implementation.
        """
        runner = SubstraRunner()
        runner.datasets = [
            Dataset.from_dict({}) for _ in range(runner.num_clients - 1)
        ]
        runner.register_data()

    def test_register_metric(self, mock_substra_client_class):
        """``register_metric`` runs without raising after clients are set up.

        Requests the mocked ``Client`` class, as ``test_set_up_clients`` does, so
        that ``set_up_clients`` does not construct real Substra clients.
        """
        runner = SubstraRunner()
        runner.set_up_clients()
        runner.register_metric()

    @pytest.mark.skip(reason="not implemented yet")
    def test_set_aggregation(self):
        """Placeholder with no assertions; skipped so it is not reported as passing."""

    @pytest.mark.skip(reason="not implemented yet")
    def test_set_testing(self):
        """Placeholder with no assertions; skipped so it is not reported as passing."""