File size: 1,612 Bytes
04a30fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import pytest
from unittest.mock import Mock, call
from datasets import Dataset

from substra_template.substra_runner import SubstraRunner


class TestSubstraRunner:
    """Unit tests for ``SubstraRunner``.

    The ``Client`` and ``load_dataset`` names inside
    ``substra_template.substra_runner`` are replaced with ``Mock`` objects via
    fixtures, so tests that request them never build a real substra client or
    call the real dataset loader.
    """

    @pytest.fixture
    def mock_substra_client_class(self, monkeypatch):
        """Patch ``substra_runner.Client`` with a ``Mock`` and return the mock."""
        mock_substra_client_class = Mock()
        monkeypatch.setattr("substra_template.substra_runner.Client", mock_substra_client_class)

        return mock_substra_client_class

    @pytest.fixture
    def mock_load_dataset(self, monkeypatch):
        """Patch ``substra_runner.load_dataset`` with a ``Mock`` and return the mock."""
        mock_load_dataset = Mock()
        monkeypatch.setattr("substra_template.substra_runner.load_dataset", mock_load_dataset)

        return mock_load_dataset

    def test_set_up_clients(self, mock_substra_client_class):
        """``set_up_clients`` instantiates the (mocked) ``Client`` class at least once."""
        runner = SubstraRunner()
        runner.set_up_clients()

        mock_substra_client_class.assert_called()

    def test_prepare_data(self, mock_load_dataset):
        """``prepare_data`` loads the MNIST train and test splits and fills ``datasets``."""
        runner = SubstraRunner()
        runner.prepare_data()

        mock_load_dataset.assert_has_calls(calls=[
            call("mnist", split="train"),
            call("mnist", split="test"),
        ], any_order=True)

        # NOTE(review): expects one dataset per client except one — presumably
        # one client holds no data shard; confirm against SubstraRunner.
        assert len(runner.datasets) == runner.num_clients - 1

    def test_register_data(self, mock_substra_client_class, mock_load_dataset):
        """``register_data`` runs on pre-built empty datasets without raising."""
        runner = SubstraRunner()
        # Set up clients against the mocked ``Client`` (as test_register_metric
        # does) so registration never talks to a real substra client. Done
        # before assigning ``datasets`` so it cannot overwrite them.
        runner.set_up_clients()
        runner.datasets = [Dataset.from_dict({}) for _ in range(runner.num_clients - 1)]

        runner.register_data()

    def test_register_metric(self, mock_substra_client_class):
        """``register_metric`` runs after client setup without raising."""
        runner = SubstraRunner()
        # ``Client`` is mocked by the fixture, so set_up_clients does not
        # construct real substra clients.
        runner.set_up_clients()
        runner.register_metric()

    @pytest.mark.skip(reason="not implemented yet")
    def test_set_aggregation(self):
        """Placeholder for aggregation tests; skipped so it is not reported as passing."""

    @pytest.mark.skip(reason="not implemented yet")
    def test_set_testing(self):
        """Placeholder for testing-stage tests; skipped so it is not reported as passing."""