substra / tests /test_substra_runner.py
NimaBoscarino's picture
WIP: Substra orchestrator
04a30fc
raw
history blame
1.61 kB
import pytest
from unittest.mock import Mock, call
from datasets import Dataset
from substra_template.substra_runner import SubstraRunner
class TestSubstraRunner:
    """Tests for ``SubstraRunner``.

    The ``Client`` class and the ``load_dataset`` function referenced by
    ``substra_template.substra_runner`` are replaced with ``Mock`` objects via
    the fixtures below. Only tests that request a fixture get that patch.
    """

    @pytest.fixture
    def mock_substra_client_class(self, monkeypatch):
        """Patch ``substra_template.substra_runner.Client`` with a ``Mock``.

        Returns:
            The ``Mock`` standing in for the client class, so a test can
            assert how it was called. ``monkeypatch`` restores the original
            attribute when the test finishes.
        """
        client_cls = Mock()
        monkeypatch.setattr("substra_template.substra_runner.Client", client_cls)
        return client_cls

    @pytest.fixture
    def mock_load_dataset(self, monkeypatch):
        """Patch ``substra_template.substra_runner.load_dataset`` with a ``Mock``.

        Returns:
            The ``Mock`` standing in for ``load_dataset``, so a test can
            assert which datasets/splits were requested.
        """
        load_dataset_mock = Mock()
        monkeypatch.setattr(
            "substra_template.substra_runner.load_dataset", load_dataset_mock
        )
        return load_dataset_mock

    def test_set_up_clients(self, mock_substra_client_class):
        """``set_up_clients()`` instantiates ``Client`` at least once."""
        runner = SubstraRunner()
        runner.set_up_clients()
        mock_substra_client_class.assert_called()

    def test_prepare_data(self, mock_load_dataset):
        """``prepare_data()`` loads both MNIST splits and fills ``runner.datasets``."""
        runner = SubstraRunner()
        runner.prepare_data()
        mock_load_dataset.assert_has_calls(
            calls=[
                call("mnist", split="train"),
                call("mnist", split="test"),
            ],
            any_order=True,
        )
        # NOTE(review): one dataset per client except one; presumably the
        # remaining client holds no data (e.g. it aggregates) - TODO confirm
        # against SubstraRunner.prepare_data.
        assert len(runner.datasets) == runner.num_clients - 1

    def test_register_data(self, mock_load_dataset):
        """Smoke test: ``register_data()`` runs on empty per-client datasets.

        No assertions are made; the test only checks that no exception is
        raised.
        """
        # NOTE(review): mock_load_dataset is requested but not used directly;
        # presumably it guards against an accidental real dataset download -
        # confirm whether the Client mock was intended here instead.
        runner = SubstraRunner()
        runner.datasets = [Dataset.from_dict({}) for _ in range(runner.num_clients - 1)]
        runner.register_data()

    def test_register_metric(self):
        """Smoke test: ``register_metric()`` runs after ``set_up_clients()``.

        No assertions are made; the test only checks that no exception is
        raised.
        """
        # NOTE(review): unlike test_set_up_clients, Client is NOT mocked here,
        # so set_up_clients() constructs real substra clients. Confirm this is
        # meant to be an integration test.
        runner = SubstraRunner()
        runner.set_up_clients()
        runner.register_metric()

    def test_set_aggregation(self):
        """Placeholder: not implemented yet."""
        # A bare ``pass`` would report this test as passing even though
        # nothing is checked; skip explicitly so the gap is visible.
        pytest.skip("test not implemented yet")

    def test_set_testing(self):
        """Placeholder: not implemented yet."""
        # Skip explicitly rather than vacuously passing (see above).
        pytest.skip("test not implemented yet")