| import flwr as fl |
| import numpy as np |
| import json |
| import pandas as pd |
| from model import create_model |
| import sys |
|
|
# In-memory history of per-round evaluation records (dicts with "round",
# "loss" and "accuracy" keys). main() resets it; the evaluation closure
# built by get_evaluate_fn() appends to it and mirrors it to JSON.
results_log = list()
|
|
|
|
def get_evaluate_fn(file_path, num_rounds=3):
    """Build the server-side (centralized) evaluation function.

    Loads the CSV at *file_path*, treating every column except the last as a
    feature and the last column as the label. It z-score normalizes the
    features and keeps the final 20% of rows as the held-out test split.

    Args:
        file_path: Path to the CSV dataset. If the path contains "diabetes"
            or "heart", the final model is saved as "diabetes_model.keras" or
            "heart_model.keras"; any other path uses "global_model.keras".
        num_rounds: Server round after whose evaluation the model is saved.
            Keep this equal to the ``num_rounds`` given to the server config.
            It defaults to 3, the value that was previously hard-coded.

    Returns:
        A callable ``evaluate(server_round, parameters, config)`` that returns
        ``(loss, {"accuracy": accuracy})``, suitable as FedAvg's
        ``evaluate_fn``. Each call appends a record to the module-level
        ``results_log`` and rewrites "training_results.json".
    """
    data = pd.read_csv(file_path)
    X = data.iloc[:, :-1].values
    y = data.iloc[:, -1].values
    # NOTE(review): the mean/std are computed over all rows, including the
    # test split. Presumably this mirrors the clients' preprocessing; confirm
    # before changing, because mismatched scaling would skew evaluation.
    # The epsilon avoids division by zero on constant columns.
    X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)
    split = int(len(X) * 0.8)
    X_test, y_test = X[split:], y[split:]
    input_shape = X.shape[1]

    if "diabetes" in file_path:
        model_name = "diabetes_model.keras"
    elif "heart" in file_path:
        model_name = "heart_model.keras"
    else:
        model_name = "global_model.keras"

    def evaluate(server_round, parameters, config):
        """Score *parameters* on the test split, log the result and maybe save.

        *config* is accepted for interface compatibility and is not used.
        """
        model = create_model(input_shape=input_shape)
        model.set_weights(parameters)
        loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
        # Coerce to built-in floats: numpy scalars such as float32 are not
        # JSON-serializable.
        loss, accuracy = float(loss), float(accuracy)

        results_log.append({
            "round": server_round,
            "loss": round(loss, 4),
            "accuracy": round(accuracy, 4),
        })
        # Rewrite the whole log each round so the file is always complete.
        with open("training_results.json", "w", encoding="utf-8") as f:
            json.dump(results_log, f)

        # Save after the final configured round instead of a hard-coded 3.
        if server_round == num_rounds:
            model.save(model_name)
            print(f"Model saved as {model_name}")

        return loss, {"accuracy": accuracy}

    return evaluate
|
|
|
|
def main(file_path="diabetes.csv"):
    """Start a Flower FedAvg server on localhost:8080 for three rounds.

    First resets the in-memory ``results_log`` and overwrites
    "training_results.json" with an empty list. It then starts the server
    with centralized evaluation on *file_path*. The strategy's minimum fit,
    evaluate and available client counts are all set to 2.
    """
    global results_log
    results_log = []

    # Every run starts from an empty results file.
    with open("training_results.json", "w") as results_file:
        json.dump([], results_file)

    evaluate_fn = get_evaluate_fn(file_path)
    fedavg = fl.server.strategy.FedAvg(
        min_fit_clients=2,
        min_evaluate_clients=2,
        min_available_clients=2,
        evaluate_fn=evaluate_fn,
    )
    server_config = fl.server.ServerConfig(num_rounds=3)

    fl.server.start_server(
        server_address="localhost:8080",
        config=server_config,
        strategy=fedavg,
    )
|
|
|
|
if __name__ == "__main__":
    # Optional first CLI argument: path to the dataset CSV.
    cli_args = sys.argv[1:]
    main(cli_args[0] if cli_args else "diabetes.csv")