NimaBoscarino commited on
Commit
c3ede35
β€’
1 Parent(s): 25eadae

Large rewrite, simplification, new UI

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .idea/
2
+ .DS_Store
3
+ __pycache__/
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🌍
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 3.23.0
8
  app_file: app.py
9
  pinned: false
10
  license: gpl-3.0
 
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.24.0
8
  app_file: app.py
9
  pinned: false
10
  license: gpl-3.0
app.py CHANGED
@@ -1,4 +1,11 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
 
4
  theme = gr.themes.Default(primary_hue="blue").set(
@@ -7,9 +14,49 @@ theme = gr.themes.Default(primary_hue="blue").set(
7
  )
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  demo = gr.Blocks(theme=theme, css="""\
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  .gradio-container {
12
- width: 100%;
13
  }
14
 
15
  .margin-top {
@@ -26,19 +73,24 @@ demo = gr.Blocks(theme=theme, css="""\
26
  }
27
 
28
  .blue {
29
- /**
30
  background-image: url("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/substra-banner.png");
31
  background-size: cover;
32
- **/
33
- background-color: #223fb3;
34
  }
35
 
36
  .blue p {
37
  color: white !important;
38
  }
39
 
 
 
 
 
40
  .info-box {
41
  background: transparent !important;
 
 
 
 
42
  }
43
  """)
44
 
@@ -49,7 +101,7 @@ with demo:
49
  gr.Markdown("# Federated Learning with Substra")
50
  with gr.Row():
51
  with gr.Column(scale=1, elem_classes=["blue", "column"]):
52
- gr.Markdown("Here you can run a quick simulation of Federated Learning with Substra.")
53
  gr.Markdown("Check out the accompanying blog post to learn more.")
54
  with gr.Box(elem_classes=["info-box"]):
55
  gr.Markdown("""\
@@ -60,22 +112,23 @@ with demo:
60
  with gr.Column(scale=3, elem_classes=["white", "column"]):
61
  gr.Markdown("""\
62
  Data scientists doing medical research often face a shortage of high quality and diverse data to \
63
- effectively train models. This challenge can be overcome by securely allowing training on pro- tected \
64
- data through (Federated Learning). Substra is a Python based Federated Learning soft- ware that \
65
- enables researchers to easily train ML models on remote data regardless of the ML library they are \
66
- using or the data modality they are working with.\
67
  """)
68
- gr.Markdown("### Here we show an example of image data located in two different hospitals.")
69
  gr.Markdown("""\
70
- By playing with the distribution of data in the 2 simulated hospitals, you'll be able to compare how \
71
  the federated models compare with models trained on single datasets. The data used is from the \
72
- Camelyon17 dataset, a commonly used benchmark in the medical world that comes from this challenge. \
73
- The sample below shows normal cells on the left compared with cancer cells on the right.\
 
74
  """)
75
  gr.HTML("""
76
  <img
77
  src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/substra-tumor.png"
78
- style="padding: 20px 150px;"
79
  />
80
  """)
81
  gr.Markdown("""\
@@ -87,8 +140,21 @@ with demo:
87
  """)
88
 
89
  with gr.Row(elem_classes=["margin-top"]):
90
- gr.Slider()
91
- gr.Slider()
92
- gr.Button(value="Launch Experiment πŸš€")
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  demo.launch()
 
1
  import gradio as gr
2
+ import uuid
3
+ import asyncio
4
+
5
+ from substra_launcher import launch_substra_space
6
+ from huggingface_hub import HfApi
7
+
8
+ hf_api = HfApi()
9
 
10
 
11
  theme = gr.themes.Default(primary_hue="blue").set(
 
14
  )
15
 
16
 
17
+ async def launch_experiment(hospital_a, hospital_b):
18
+ experiment_id = str(uuid.uuid4())
19
+
20
+ asyncio.create_task(launch_substra_space(
21
+ hf_api=hf_api,
22
+ repo_id=experiment_id,
23
+ hospital_a=hospital_a,
24
+ hospital_b=hospital_b,
25
+ ))
26
+
27
+ url = f"https://hf.space/NimaBoscarino/{experiment_id}"
28
+
29
+ return (
30
+ gr.Button.update(interactive=False),
31
+ gr.Markdown.update(
32
+ visible=True,
33
+ value=f"Your experiment is available at [hf.space/NimaBoscarino/{experiment_id}]({url})!"
34
+ )
35
+ )
36
+
37
+
38
  demo = gr.Blocks(theme=theme, css="""\
39
+ @font-face {
40
+ font-family: "Didact Gothic";
41
+ src: url('https://huggingface.co/datasets/NimaBoscarino/assets/resolve/main/substra/DidactGothic-Regular.ttf') format('truetype');
42
+ }
43
+
44
+ @font-face {
45
+ font-family: "Inter";
46
+ src: url('https://huggingface.co/datasets/NimaBoscarino/assets/resolve/main/substra/Inter-Regular.ttf') format('truetype');
47
+ }
48
+
49
+ h1 {
50
+ font-family: "Didact Gothic";
51
+ font-size: 40px !important;
52
+ }
53
+
54
+ p {
55
+ font-family: "Inter";
56
+ }
57
+
58
  .gradio-container {
59
+ min-width: 100% !important;
60
  }
61
 
62
  .margin-top {
 
73
  }
74
 
75
  .blue {
 
76
  background-image: url("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/substra-banner.png");
77
  background-size: cover;
 
 
78
  }
79
 
80
  .blue p {
81
  color: white !important;
82
  }
83
 
84
+ .blue strong {
85
+ color: white !important;
86
+ }
87
+
88
  .info-box {
89
  background: transparent !important;
90
+ border-radius: 20px !important;
91
+ border-color: white !important;
92
+ border-width: 4px !important;
93
+ padding: 20px !important;
94
  }
95
  """)
96
 
 
101
  gr.Markdown("# Federated Learning with Substra")
102
  with gr.Row():
103
  with gr.Column(scale=1, elem_classes=["blue", "column"]):
104
+ gr.Markdown("Here you can run a **quick simulation of Federated Learning**.")
105
  gr.Markdown("Check out the accompanying blog post to learn more.")
106
  with gr.Box(elem_classes=["info-box"]):
107
  gr.Markdown("""\
 
112
  with gr.Column(scale=3, elem_classes=["white", "column"]):
113
  gr.Markdown("""\
114
  Data scientists doing medical research often face a shortage of high quality and diverse data to \
115
+ effectively train models. This challenge can be overcome by securely allowing training on protected \
116
+ data through Federated Learning. [Substra](https://docs.substra.org/) is a Python based Federated \
117
+ Learning software that enables researchers to easily train ML models on remote data regardless of the \
118
+ ML library they are using or the data type they are working with.
119
  """)
120
+ gr.Markdown("### Here we show an example of image data located in **two different hospitals**.")
121
  gr.Markdown("""\
122
+ By playing with the distribution of data in the two simulated hospitals, you'll be able to compare how \
123
  the federated models compare with models trained on single datasets. The data used is from the \
124
+ Camelyon17 dataset, a commonly used benchmark in the medical world that comes from \
125
+ [this challenge](https://camelyon17.grand-challenge.org/). The sample below shows normal cells on the \
126
+ left compared with cancer cells on the right.
127
  """)
128
  gr.HTML("""
129
  <img
130
  src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/substra-tumor.png"
131
+ style="height: 300px; margin: auto;"
132
  />
133
  """)
134
  gr.Markdown("""\
 
140
  """)
141
 
142
  with gr.Row(elem_classes=["margin-top"]):
143
+ hospital_a_slider = gr.Slider(
144
+ label="Percentage of positive samples in Hospital A",
145
+ value=50,
146
+ )
147
+ hospital_b_slider = gr.Slider(
148
+ label="Percentage of positive samples in Hospital B",
149
+ value=50,
150
+ )
151
+ launch_experiment_button = gr.Button(value="Launch Experiment πŸš€")
152
+ visit_experiment_text = gr.Markdown(visible=False)
153
+
154
+ launch_experiment_button.click(
155
+ fn=launch_experiment,
156
+ inputs=[hospital_a_slider, hospital_b_slider],
157
+ outputs=[launch_experiment_button, visit_experiment_text]
158
+ )
159
 
160
  demo.launch()
fonts/DidactGothic-Regular.ttf ADDED
Binary file (181 kB). View file
 
fonts/Inter-Regular.ttf ADDED
Binary file (748 kB). View file
 
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- gradio==3.23.0
2
  pytest
3
  huggingface_hub
 
1
+ gradio
2
  pytest
3
  huggingface_hub
substra_launcher.py CHANGED
@@ -1,7 +1,10 @@
1
  from huggingface_hub import HfApi, RepoUrl
2
 
3
 
4
- def launch_substra_space(hf_api: HfApi, num_hospitals: int, repo_id: str) -> RepoUrl:
 
 
 
5
  repo_id = "NimaBoscarino/" + repo_id
6
 
7
  repo_url = hf_api.create_repo(
@@ -13,12 +16,13 @@ def launch_substra_space(hf_api: HfApi, num_hospitals: int, repo_id: str) -> Rep
13
  hf_api.upload_folder(
14
  repo_id=repo_id,
15
  repo_type="space",
16
- folder_path="substra_template/"
17
  )
18
 
19
  ENV_FILE = f"""\
20
- SUBSTRA_NUM_HOSPITALS={num_hospitals}
21
- """
 
22
 
23
  hf_api.upload_file(
24
  repo_id=repo_id,
 
1
  from huggingface_hub import HfApi, RepoUrl
2
 
3
 
4
+ async def launch_substra_space(
5
+ hf_api: HfApi, repo_id: str,
6
+ hospital_a: int, hospital_b: int,
7
+ ) -> RepoUrl:
8
  repo_id = "NimaBoscarino/" + repo_id
9
 
10
  repo_url = hf_api.create_repo(
 
16
  hf_api.upload_folder(
17
  repo_id=repo_id,
18
  repo_type="space",
19
+ folder_path="./substra_template/"
20
  )
21
 
22
  ENV_FILE = f"""\
23
+ SUBSTRA_ORG1_DISTR={hospital_a / 100}
24
+ SUBSTRA_ORG2_DISTR={hospital_b / 100}\
25
+ """
26
 
27
  hf_api.upload_file(
28
  repo_id=repo_id,
substra_template/Dockerfile CHANGED
@@ -1,31 +1,3 @@
1
- FROM python:3.10
2
 
3
- # Set the working directory to /code
4
- WORKDIR /code
5
-
6
- # Copy the current directory contents into the container at /code
7
- COPY ./requirements.txt /code/requirements.txt
8
- COPY ./mlflow-2.1.2.dev0-py3-none-any.whl /code/mlflow-2.1.2.dev0-py3-none-any.whl
9
-
10
- # Install requirements.txt
11
- RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
12
- RUN chmod -R 777 /usr/local/lib/python3.10/site-packages/
13
-
14
- # Set up a new user named "user" with user ID 1000
15
- RUN useradd -m -u 1000 user
16
- # Switch to the "user" user
17
- USER user
18
- # Set home to the user's home directory
19
- ENV HOME=/home/user \
20
- PATH=/home/user/.local/bin:$PATH
21
-
22
- # Set the working directory to the user's home directory
23
- WORKDIR $HOME/app
24
-
25
- # Copy the current directory contents into the container at $HOME/app setting the owner to the user
26
- COPY --chown=user . $HOME/app
27
-
28
- RUN chmod -R 777 $HOME/app/
29
-
30
- EXPOSE 7860
31
- CMD ["bash", "run.sh"]
 
1
+ FROM nimaboscarino/substra-trainer:latest
2
 
3
+ CMD ["bash", "docker-run.sh"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
substra_template/README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Substra Trainer
3
+ emoji: πŸš€
4
+ colorFrom: red
5
+ colorTo: gray
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
substra_template/__init__.py DELETED
File without changes
substra_template/mlflow-2.1.2.dev0-py3-none-any.whl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e1f15359f38fab62f43a7a3d51f56c86c882a4cb1c3dcabeda6daf5dc47f1613
3
- size 17638174
 
 
 
 
substra_template/mlflow_live_performances.py DELETED
@@ -1,45 +0,0 @@
1
- import pandas as pd
2
- import json
3
- from pathlib import Path
4
- from mlflow import log_metric
5
- import time
6
- import os
7
- from glob import glob
8
-
9
- TIMEOUT = 240 # Number of seconds to stop the script after the last update of the json file
10
- POLLING_FREQUENCY = 10 # Try to read the updates in the file every 10 seconds
11
-
12
- # Wait for the file to be found
13
- start = time.time()
14
- while not len(glob(str(Path("local-worker") / "live_performances" / "*" / "performances.json"))) > 0:
15
- time.sleep(POLLING_FREQUENCY)
16
- if time.time() - start >= TIMEOUT:
17
- raise TimeoutError("The performance file does not exist, maybe no test task has been executed yet.")
18
-
19
- path_to_json = Path(glob(str(Path("local-worker") / "live_performances" / "*" / "performances.json"))[0])
20
-
21
- logged_rows = []
22
- last_update = time.time()
23
-
24
- while (time.time() - last_update) <= TIMEOUT:
25
-
26
- if last_update == os.path.getmtime(str(path_to_json)):
27
- time.sleep(POLLING_FREQUENCY)
28
- continue
29
-
30
- last_update = os.path.getmtime(str(path_to_json))
31
-
32
- time.sleep(1) # Waiting for the json to be fully written
33
- dict_perf = json.load(path_to_json.open())
34
-
35
- df = pd.DataFrame(dict_perf)
36
-
37
- for _, row in df.iterrows():
38
- if row["testtask_key"] in logged_rows:
39
- continue
40
-
41
- logged_rows.append(row["testtask_key"])
42
-
43
- step = int(row["round_idx"]) if row["round_idx"] is not None else int(row["testtask_rank"])
44
-
45
- log_metric(f"{row['metric_name']}_{row['worker']}", row["performance"], step)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
substra_template/requirements.txt DELETED
@@ -1,13 +0,0 @@
1
- gradio
2
- substrafl
3
- datasets
4
- torch
5
- torchvision
6
- scikit-learn
7
- numpy==1.23.0
8
- Pillow
9
- transformers
10
- matplotlib
11
- pandas
12
- python-dotenv
13
- ./mlflow-2.1.2.dev0-py3-none-any.whl
 
 
 
 
 
 
 
 
 
 
 
 
 
 
substra_template/run.sh DELETED
@@ -1,13 +0,0 @@
1
- PYTHONPATH=$HOME/app python run_compute_plan.py &
2
- PYTHONPATH=$HOME/app python mlflow_live_performances.py &
3
-
4
- SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
5
-
6
- # Fix for the UI code being embedded in an iframe
7
- # Replace window.parent.location.origin with *
8
- for i in $SITE_PACKAGES/mlflow/server/js/build/static/js/*.js; do
9
- sed -i 's/window\.parent\.location\.origin)/"*")/' $i
10
- sed 's/window.top?.location.href || window.location.href/window.location.href/g' -i $i
11
- done
12
-
13
- mlflow ui --port 7860 --host 0.0.0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
substra_template/run_compute_plan.py DELETED
@@ -1,40 +0,0 @@
1
- from substra_helpers.substra_runner import SubstraRunner, algo_generator
2
- from substra_helpers.model import CNN
3
- from substra_helpers.dataset import TorchDataset
4
- from substrafl.strategies import FedAvg
5
-
6
- import torch
7
-
8
- from dotenv import load_dotenv
9
- import os
10
- load_dotenv()
11
-
12
- NUM_CLIENTS = int(os.environ["SUBSTRA_NUM_HOSPITALS"])
13
-
14
- seed = 42
15
- torch.manual_seed(seed)
16
- model = CNN()
17
- optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
18
- criterion = torch.nn.CrossEntropyLoss()
19
-
20
- runner = SubstraRunner(num_clients=NUM_CLIENTS)
21
- runner.set_up_clients()
22
- runner.prepare_data()
23
- runner.register_data()
24
- runner.register_metric()
25
-
26
- runner.algorithm = algo_generator(
27
- model=model,
28
- criterion=criterion,
29
- optimizer=optimizer,
30
- index_generator=runner.index_generator,
31
- dataset=TorchDataset,
32
- seed=seed
33
- )()
34
-
35
- runner.strategy = FedAvg()
36
-
37
- runner.set_aggregation()
38
- runner.set_testing()
39
-
40
- runner.run_compute_plan()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
substra_template/substra_helpers/__init__.py DELETED
File without changes
substra_template/substra_helpers/dataset.py DELETED
@@ -1,29 +0,0 @@
1
- import torch
2
- from torch.utils import data
3
- import torch.nn.functional as F
4
- import numpy as np
5
-
6
-
7
- class TorchDataset(data.Dataset):
8
- def __init__(self, datasamples, is_inference: bool):
9
- self.x = datasamples["image"]
10
- self.y = datasamples["label"]
11
- self.is_inference = is_inference
12
-
13
- def __getitem__(self, idx):
14
-
15
- if self.is_inference:
16
- x = torch.FloatTensor(np.array(self.x[idx])[None, ...]) / 255
17
- return x
18
-
19
- else:
20
- x = torch.FloatTensor(np.array(self.x[idx])[None, ...]) / 255
21
-
22
- y = torch.tensor(self.y[idx]).type(torch.int64)
23
- y = F.one_hot(y, 10)
24
- y = y.type(torch.float32)
25
-
26
- return x, y
27
-
28
- def __len__(self):
29
- return len(self.x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
substra_template/substra_helpers/dataset_assets/description.md DELETED
@@ -1,18 +0,0 @@
1
- # Mnist
2
-
3
- This dataset is [THE MNIST DATABASE of handwritten digits](http://yann.lecun.com/exdb/mnist/).
4
-
5
- The target is the number (0 -> 9) represented by the pixels.
6
-
7
- ## Data repartition
8
-
9
- ### Train and test
10
-
11
- ### Split data between organizations
12
-
13
- ## Opener usage
14
-
15
- The opener exposes 2 methods:
16
-
17
- - `get_data` returns a dictionary containing the images and the labels as numpy arrays
18
- - `fake_data` returns a fake data sample of images and labels in a dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
substra_template/substra_helpers/dataset_assets/opener.py DELETED
@@ -1,20 +0,0 @@
1
- import numpy as np
2
- import substratools as tools
3
- from datasets import load_from_disk
4
- from transformers import ImageFeatureExtractionMixin
5
-
6
-
7
- class MnistOpener(tools.Opener):
8
- def fake_data(self, n_samples=None):
9
- N_SAMPLES = n_samples if n_samples and n_samples <= 100 else 100
10
-
11
- fake_images = np.random.randint(256, size=(N_SAMPLES, 28, 28))
12
-
13
- fake_labels = np.random.randint(10, size=N_SAMPLES)
14
-
15
- data = {"image": fake_images, "label": fake_labels}
16
-
17
- return data
18
-
19
- def get_data(self, folders):
20
- return load_from_disk(folders[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
substra_template/substra_helpers/model.py DELETED
@@ -1,25 +0,0 @@
1
- from torch import nn
2
- import torch.nn.functional as F
3
-
4
-
5
- # TODO: Would be cool to use a simple Transformer model... then I could use the Trainer API πŸ‘€
6
- class CNN(nn.Module):
7
- def __init__(self):
8
- super(CNN, self).__init__()
9
- self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
10
- self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
11
- self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
12
- self.fc1 = nn.Linear(3 * 3 * 64, 256)
13
- self.fc2 = nn.Linear(256, 10)
14
-
15
- def forward(self, x, eval=False):
16
- x = F.relu(self.conv1(x))
17
- x = F.relu(F.max_pool2d(self.conv2(x), 2))
18
- x = F.dropout(x, p=0.5, training=not eval)
19
- x = F.relu(F.max_pool2d(self.conv3(x), 2))
20
- x = F.dropout(x, p=0.5, training=not eval)
21
- x = x.view(-1, 3 * 3 * 64)
22
- x = F.relu(self.fc1(x))
23
- x = F.dropout(x, p=0.5, training=not eval)
24
- x = self.fc2(x)
25
- return F.log_softmax(x, dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
substra_template/substra_helpers/substra_runner.py DELETED
@@ -1,194 +0,0 @@
1
- import pathlib
2
- import shutil
3
- from typing import Optional, List
4
-
5
- from substra import Client, BackendType
6
-
7
- from substra.sdk.schemas import (
8
- DatasetSpec,
9
- Permissions,
10
- DataSampleSpec
11
- )
12
-
13
- from substrafl.strategies import Strategy
14
- from substrafl.dependency import Dependency
15
- from substrafl.remote.register import add_metric
16
- from substrafl.index_generator import NpIndexGenerator
17
- from substrafl.algorithms.pytorch import TorchFedAvgAlgo
18
-
19
- from substrafl.nodes import TrainDataNode, AggregationNode, TestDataNode
20
- from substrafl.evaluation_strategy import EvaluationStrategy
21
-
22
- from substrafl.experiment import execute_experiment
23
- from substra.sdk.models import ComputePlan
24
-
25
- from datasets import load_dataset, Dataset
26
- from sklearn.metrics import accuracy_score
27
- import numpy as np
28
-
29
- import torch
30
-
31
-
32
- class SubstraRunner:
33
- def __init__(self, num_clients: int):
34
- self.num_clients = num_clients
35
- self.clients = {}
36
- self.algo_provider: Optional[Client] = None
37
-
38
- self.datasets: List[Dataset] = []
39
- self.test_dataset: Optional[Dataset] = None
40
- self.path = pathlib.Path(__file__).parent.resolve()
41
-
42
- self.dataset_keys = {}
43
- self.train_data_sample_keys = {}
44
- self.test_data_sample_keys = {}
45
-
46
- self.metric_key: Optional[str] = None
47
-
48
- NUM_UPDATES = 100
49
- BATCH_SIZE = 32
50
-
51
- self.index_generator = NpIndexGenerator(
52
- batch_size=BATCH_SIZE,
53
- num_updates=NUM_UPDATES,
54
- )
55
-
56
- self.algorithm: Optional[TorchFedAvgAlgo] = None
57
- self.strategy: Optional[Strategy] = None
58
-
59
- self.aggregation_node: Optional[AggregationNode] = None
60
- self.train_data_nodes = list()
61
- self.test_data_nodes = list()
62
- self.eval_strategy: Optional[EvaluationStrategy] = None
63
-
64
- self.NUM_ROUNDS = 3
65
- self.compute_plan: Optional[ComputePlan] = None
66
-
67
- self.experiment_folder = self.path / "experiment_summaries"
68
-
69
- def set_up_clients(self):
70
- self.algo_provider = Client(backend_type=BackendType.LOCAL_SUBPROCESS)
71
-
72
- self.clients = {
73
- c.organization_info().organization_id: c
74
- for c in [Client(backend_type=BackendType.LOCAL_SUBPROCESS) for _ in range(self.num_clients - 1)]
75
- }
76
-
77
- def prepare_data(self):
78
- dataset = load_dataset("mnist", split="train").shuffle()
79
- self.datasets = [dataset.shard(num_shards=self.num_clients - 1, index=i) for i in range(self.num_clients - 1)]
80
-
81
- self.test_dataset = load_dataset("mnist", split="test")
82
-
83
- data_path = self.path / "data"
84
- if data_path.exists() and data_path.is_dir():
85
- shutil.rmtree(data_path)
86
-
87
- for i, client_id in enumerate(self.clients):
88
- ds = self.datasets[i]
89
- ds.save_to_disk(data_path / client_id / "train")
90
- self.test_dataset.save_to_disk(data_path / client_id / "test")
91
-
92
- def register_data(self):
93
- for client_id, client in self.clients.items():
94
- permissions_dataset = Permissions(public=False, authorized_ids=[
95
- self.algo_provider.organization_info().organization_id
96
- ])
97
-
98
- dataset = DatasetSpec(
99
- name="MNIST",
100
- type="npy",
101
- data_opener=self.path / pathlib.Path("dataset_assets/opener.py"),
102
- description=self.path / pathlib.Path("dataset_assets/description.md"),
103
- permissions=permissions_dataset,
104
- logs_permission=permissions_dataset,
105
- )
106
- self.dataset_keys[client_id] = client.add_dataset(dataset)
107
- assert self.dataset_keys[client_id], "Missing dataset key"
108
-
109
- self.train_data_sample_keys[client_id] = client.add_data_sample(DataSampleSpec(
110
- data_manager_keys=[self.dataset_keys[client_id]],
111
- path=self.path / "data" / client_id / "train",
112
- ))
113
-
114
- data_sample = DataSampleSpec(
115
- data_manager_keys=[self.dataset_keys[client_id]],
116
- path=self.path / "data" / client_id / "test",
117
- )
118
- self.test_data_sample_keys[client_id] = client.add_data_sample(data_sample)
119
-
120
- def register_metric(self):
121
- permissions_metric = Permissions(
122
- public=False,
123
- authorized_ids=[
124
- self.algo_provider.organization_info().organization_id
125
- ] + list(self.clients.keys())
126
- )
127
-
128
- metric_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"])
129
-
130
- def accuracy(datasamples, predictions_path):
131
- y_true = datasamples["label"]
132
- y_pred = np.load(predictions_path)
133
-
134
- return accuracy_score(y_true, np.argmax(y_pred, axis=1))
135
-
136
- self.metric_key = add_metric(
137
- client=self.algo_provider,
138
- metric_function=accuracy,
139
- permissions=permissions_metric,
140
- dependencies=metric_deps,
141
- )
142
-
143
- def set_aggregation(self):
144
- self.aggregation_node = AggregationNode(self.algo_provider.organization_info().organization_id)
145
-
146
- for org_id in self.clients:
147
- train_data_node = TrainDataNode(
148
- organization_id=org_id,
149
- data_manager_key=self.dataset_keys[org_id],
150
- data_sample_keys=[self.train_data_sample_keys[org_id]],
151
- )
152
- self.train_data_nodes.append(train_data_node)
153
-
154
- def set_testing(self):
155
- for org_id in self.clients:
156
- test_data_node = TestDataNode(
157
- organization_id=org_id,
158
- data_manager_key=self.dataset_keys[org_id],
159
- test_data_sample_keys=[self.test_data_sample_keys[org_id]],
160
- metric_keys=[self.metric_key],
161
- )
162
- self.test_data_nodes.append(test_data_node)
163
-
164
- self.eval_strategy = EvaluationStrategy(test_data_nodes=self.test_data_nodes, rounds=1)
165
-
166
- def run_compute_plan(self):
167
- algo_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "torch==1.11.0"])
168
-
169
- self.compute_plan = execute_experiment(
170
- client=self.algo_provider,
171
- algo=self.algorithm,
172
- strategy=self.strategy,
173
- train_data_nodes=self.train_data_nodes,
174
- evaluation_strategy=self.eval_strategy,
175
- aggregation_node=self.aggregation_node,
176
- num_rounds=self.NUM_ROUNDS,
177
- experiment_folder=self.experiment_folder,
178
- dependencies=algo_deps,
179
- )
180
-
181
-
182
- def algo_generator(model, criterion, optimizer, index_generator, dataset, seed):
183
- class MyAlgo(TorchFedAvgAlgo):
184
- def __init__(self):
185
- super().__init__(
186
- model=model,
187
- criterion=criterion,
188
- optimizer=optimizer,
189
- index_generator=index_generator,
190
- dataset=dataset,
191
- seed=seed,
192
- )
193
-
194
- return MyAlgo