alisrbdni commited on
Commit
4da9684
1 Parent(s): 3437167

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%writefile app.py
2
+
3
+ import streamlit as st
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, AdamW
7
+ from datasets import load_dataset
8
+ from evaluate import load as load_metric
9
+ from torch.utils.data import DataLoader
10
+ import random
11
+
12
+ DEVICE = torch.device("cpu")
13
+ NUM_ROUNDS = 3
14
+
15
+ def load_data(dataset_name):
16
+ raw_datasets = load_dataset(dataset_name)
17
+ raw_datasets = raw_datasets.shuffle(seed=42)
18
+ del raw_datasets["unsupervised"]
19
+
20
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
21
+
22
+ def tokenize_function(examples):
23
+ return tokenizer(examples["text"], truncation=True)
24
+
25
+ train_population = random.sample(range(len(raw_datasets["train"])), 20)
26
+ test_population = random.sample(range(len(raw_datasets["test"])), 20)
27
+
28
+ tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
29
+ tokenized_datasets["train"] = tokenized_datasets["train"].select(train_population)
30
+ tokenized_datasets["test"] = tokenized_datasets["test"].select(test_population)
31
+
32
+ tokenized_datasets = tokenized_datasets.remove_columns("text")
33
+ tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
34
+
35
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
36
+ trainloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=32, collate_fn=data_collator)
37
+ testloader = DataLoader(tokenized_datasets["test"], batch_size=32, collate_fn=data_collator)
38
+
39
+ return trainloader, testloader
40
+
41
+ def train(net, trainloader, epochs):
42
+ optimizer = AdamW(net.parameters(), lr=5e-5)
43
+ net.train()
44
+ for _ in range(epochs):
45
+ for batch in trainloader:
46
+ batch = {k: v.to(DEVICE) for k, v in batch.items()}
47
+ outputs = net(**batch)
48
+ loss = outputs.loss
49
+ loss.backward()
50
+ optimizer.step()
51
+ optimizer.zero_grad()
52
+
53
+ def test(net, testloader):
54
+ metric = load_metric("accuracy")
55
+ loss = 0
56
+ net.eval()
57
+ for batch in testloader:
58
+ batch = {k: v.to(DEVICE) for k, v in batch.items()}
59
+ with torch.no_grad():
60
+ outputs = net(**batch)
61
+ logits = outputs.logits
62
+ loss += outputs.loss.item()
63
+ predictions = torch.argmax(logits, dim=-1)
64
+ metric.add_batch(predictions=predictions, references=batch["labels"])
65
+ loss /= len(testloader.dataset)
66
+ accuracy = metric.compute()["accuracy"]
67
+ return loss, accuracy
68
+
69
+ net = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2).to(DEVICE)
70
+
71
+ def main():
72
+ st.write("## Federated Learning with dynamic models and datasets for mobile devices")
73
+ dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
74
+ model_name = st.selectbox("Model", ["bert-base-uncased", "distilbert-base-uncased"])
75
+ NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
76
+ NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
77
+
78
+ trainloader, testloader = load_data(dataset_name)
79
+
80
+ if st.button("Start Training"):
81
+ round_losses = []
82
+ round_accuracies = [] # Store accuracy values for each round
83
+ for round_num in range(1, NUM_ROUNDS + 1):
84
+ st.write(f"## Round {round_num}")
85
+
86
+ st.write("### Training Metrics for Each Client")
87
+ for client in range(1, NUM_CLIENTS + 1):
88
+ client_loss, client_accuracy = test(net, testloader) # Placeholder for actual client metrics
89
+ st.write(f"Client {client}: Loss: {client_loss}, Accuracy: {client_accuracy}")
90
+
91
+ st.write("### Accuracy Over Rounds")
92
+ round_accuracies.append(client_accuracy) # Append the accuracy for this round
93
+ plt.plot(range(1, round_num + 1), round_accuracies, marker='o') # Plot accuracy over rounds
94
+ plt.xlabel("Round")
95
+ plt.ylabel("Accuracy")
96
+ plt.title("Accuracy Over Rounds")
97
+ st.pyplot()
98
+
99
+ st.write("### Loss Over Rounds")
100
+ loss_value = random.random() # Placeholder for loss values
101
+ round_losses.append(loss_value)
102
+ rounds = list(range(1, round_num + 1))
103
+ plt.plot(rounds, round_losses)
104
+ plt.xlabel("Round")
105
+ plt.ylabel("Loss")
106
+ plt.title("Loss Over Rounds")
107
+ st.pyplot()
108
+
109
+ st.success(f"Round {round_num} completed successfully!")
110
+
111
+ else:
112
+ st.write("Click the 'Start Training' button to start the training process.")
113
+
114
+ if __name__ == "__main__":
115
+ main()