|
"""model.py module.""" |
|
|
|
from typing import Dict, List, Optional, Tuple |
|
import tensorflow as tf |
|
import numpy as np |
|
import logging |
|
import time |
|
from ..api.client import FederatedHTTPClient |
|
from .data_handler import FinancialDataHandler |
|
|
|
class FederatedClient:
    """Client side of a federated-learning session.

    The client waits for the federated server, registers with it, then polls
    the server's training status. For every new round it downloads the global
    model weights, trains locally, and submits its updated weights together
    with a small metrics dict.

    Attributes:
        client_id: String identifier sent to the server.
        config: The ``'client'`` section of the config passed to ``__init__``
            (an empty dict if that section is missing).
        model: Compiled Keras model with ``INPUT_DIM`` inputs and one output.
        data_handler: ``FinancialDataHandler`` built from the full config.
        server_url: Base URL of the federated server.
        http_client: ``FederatedHTTPClient`` used for all server calls.
        registered: ``True`` once registration has succeeded.
        current_round: Last server round this client completed successfully.
    """

    # Number of input features of the model; the fallback dummy-data generator
    # must produce the same width, so both read this constant.
    INPUT_DIM = 32
    # Number of samples requested for each dummy/synthetic dataset.
    DUMMY_NUM_SAMPLES = 100
    # Seconds between status polls, and after an error in the polling loop.
    POLL_INTERVAL_SECONDS = 5
    RETRY_INTERVAL_SECONDS = 10
    # Fallbacks used when the 'training' config section omits a key.
    DEFAULT_BATCH_SIZE = 32
    DEFAULT_LOCAL_EPOCHS = 3
    DEFAULT_LEARNING_RATE = 0.001

    def __init__(self, client_id: str, config: Dict, server_url: Optional[str] = None):
        """Initialize the federated client.

        Args:
            client_id: Identifier for this client; coerced to ``str``.
            config: Full configuration dict. Its ``'client'`` section is kept
                on ``self.config``; the whole dict is handed to
                ``FinancialDataHandler``.
            server_url: Server base URL. Falls back to
                ``config['client']['server_url']``, then to
                ``'http://localhost:8080'``.
        """
        self.client_id = str(client_id)
        self.config = config.get('client', {})
        # _build_model reads self.config, so config must be assigned first.
        self.model = self._build_model()
        self.data_handler = FinancialDataHandler(config)

        self.server_url = server_url or self.config.get('server_url', 'http://localhost:8080')
        self.http_client = FederatedHTTPClient(self.server_url, self.client_id)

        self.registered = False
        self.current_round = 0

    def _training_setting(self, key: str, default):
        """Return ``self.config['training'][key]``, or *default* if unset.

        A missing or ``None`` ``'training'`` section is treated as empty.
        """
        return (self.config.get('training') or {}).get(key, default)

    def start(self):
        """Run the client: wait for the server, register, then train.

        The HTTP client is closed on every exit path.

        Raises:
            ConnectionError: If ``http_client.wait_for_server()`` is falsy.
            Exception: Any error from registration or the learning loop is
                logged (with traceback) and re-raised.
        """
        logger = logging.getLogger(__name__)
        logger.info(f"Client {self.client_id} starting...")

        try:
            if not self.http_client.wait_for_server():
                raise ConnectionError(f"Cannot connect to server at {self.server_url}")
            self._register_with_server()
            self._federated_learning_loop()
        except Exception as e:
            logger.exception(f"Error during client execution: {e}")
            raise
        finally:
            self.http_client.close()

    def _register_with_server(self):
        """Register this client with the federated server.

        Sends the local dataset size, the model's parameter count and a fixed
        capability list, then marks the client as registered.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        logger = logging.getLogger(__name__)

        try:
            X, _ = self._generate_dummy_data()
            client_info = {
                'dataset_size': len(X),
                'model_params': self.model.count_params(),
                'capabilities': ['training', 'inference'],
            }
            self.http_client.register(client_info)
            self.registered = True

            logger.info("Successfully registered with server")
            logger.info(f"Dataset size: {client_info['dataset_size']}")
            logger.info(f"Model parameters: {client_info['model_params']:,}")
        except Exception as e:
            logger.error(f"Failed to register with server: {e}")
            raise

    def _federated_learning_loop(self):
        """Poll the server and join each new round until training ends.

        The server status is fetched every ``POLL_INTERVAL_SECONDS``. The loop
        exits when ``training_active`` is falsy; a missing key counts as still
        active. Errors are logged with a traceback and the loop resumes after
        ``RETRY_INTERVAL_SECONDS``. ``current_round`` only advances after a
        round succeeds, so a failed round is attempted again on the next poll.
        """
        logger = logging.getLogger(__name__)

        while True:
            try:
                status = self.http_client.get_training_status()
                if not status.get('training_active', True):
                    logger.info("Training completed on server")
                    break

                server_round = status.get('current_round', 0)
                if server_round > self.current_round:
                    self._participate_in_round(server_round)
                    self.current_round = server_round
            except Exception as e:
                logger.exception(f"Error in federated learning loop: {e}")
                time.sleep(self.RETRY_INTERVAL_SECONDS)
                continue

            time.sleep(self.POLL_INTERVAL_SECONDS)

    def _participate_in_round(self, round_num: int):
        """Run one federated round: pull global weights, train, push update.

        Args:
            round_num: Server round number; echoed back in the metrics.

        Raises:
            Exception: Any failure is logged and re-raised to the caller.
        """
        logger = logging.getLogger(__name__)
        logger.info(f"Participating in round {round_num}")

        try:
            model_response = self.http_client.get_global_model()
            global_weights = model_response.get('model_weights')
            # Missing or empty weights leave the local model untouched.
            if global_weights:
                self.set_weights(global_weights)
                logger.info("Updated local model with global weights")

            X, y = self._generate_dummy_data()
            logger.info(f"Training on {len(X)} samples")
            history = self.train_local((X, y))

            metrics = self._round_metrics(history, len(X), round_num)
            self.http_client.submit_model_update(self.get_weights(), metrics)

            logger.info(f"Round {round_num} completed - Final loss: {metrics['final_loss']:.4f}")
        except Exception as e:
            logger.error(f"Error in round {round_num}: {e}")
            raise

    @staticmethod
    def _round_metrics(history: Dict, dataset_size: int, round_num: int) -> Dict:
        """Build the metrics dict submitted alongside a model update.

        Args:
            history: Keras ``History.history`` dict. It may lack a ``'loss'``
                entry (e.g. if no epoch ran); the final loss is then ``0.0``.
            dataset_size: Number of training samples used this round.
            round_num: Server round number.
        """
        losses = history.get('loss') or []
        return {
            'dataset_size': dataset_size,
            # float() turns NumPy scalars into plain Python floats.
            'final_loss': float(losses[-1]) if losses else 0.0,
            'epochs_trained': len(losses),
            'round': round_num,
        }

    def _generate_dummy_data(self):
        """Generate dummy data for testing.

        Prefers ``data_handler.generate_synthetic_data``. If that raises, logs
        a warning and falls back to ``DUMMY_NUM_SAMPLES`` rows of
        standard-normal features of width ``INPUT_DIM``. The fallback target is
        the row sum, shaped ``(n, 1)``. Returns an ``(X, y)`` pair.
        """
        try:
            return self.data_handler.generate_synthetic_data(self.DUMMY_NUM_SAMPLES)
        except Exception as e:
            # Deliberate best-effort fallback, but record why it was needed.
            logging.getLogger(__name__).warning(
                f"Synthetic data generation failed ({e}); using random fallback data"
            )
            X = tf.random.normal((self.DUMMY_NUM_SAMPLES, self.INPUT_DIM))
            y = tf.reduce_sum(X, axis=1, keepdims=True)
            return X.numpy(), y.numpy()

    def _build_model(self):
        """Build and compile the initial model.

        The model is a dense network: ``INPUT_DIM`` inputs, ReLU layers of 128
        and 64 units, and one linear output. It is compiled with Adam, using
        ``config['training']['learning_rate']`` (default 0.001), and MSE loss.
        """
        model = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(self.INPUT_DIM,)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(1),
        ])
        model.compile(
            optimizer=tf.keras.optimizers.Adam(
                learning_rate=self._training_setting('learning_rate', self.DEFAULT_LEARNING_RATE)
            ),
            loss='mse',
        )
        return model

    def train_local(self, data):
        """Train the model on local data.

        Args:
            data: ``(X, y)`` pair. NumPy arrays are converted to float32
                tensors; other inputs are passed to ``model.fit`` unchanged.

        Returns:
            The Keras ``History.history`` dict (metric name -> per-epoch values).
        """
        logger = logging.getLogger(__name__)
        X, y = data

        if isinstance(X, np.ndarray):
            X = tf.convert_to_tensor(X, dtype=tf.float32)
        if isinstance(y, np.ndarray):
            y = tf.convert_to_tensor(y, dtype=tf.float32)

        # Resolve settings once so the logged values are exactly what fit() uses.
        batch_size = self._training_setting('batch_size', self.DEFAULT_BATCH_SIZE)
        epochs = self._training_setting('local_epochs', self.DEFAULT_LOCAL_EPOCHS)

        logger.info("Training Parameters:")
        logger.info(f"Input shape: {X.shape}")
        logger.info(f"Output shape: {y.shape}")
        logger.info(f"Batch size: {batch_size}")
        logger.info(f"Epochs: {epochs}")

        class LogCallback(tf.keras.callbacks.Callback):
            def on_epoch_end(self, epoch, logs=None):
                # logs defaults to None; don't let a missing loss abort training.
                loss = (logs or {}).get('loss')
                if loss is not None:
                    logger.debug(f"Epoch {epoch + 1} - loss: {loss:.4f}")

        history = self.model.fit(
            X, y,
            batch_size=batch_size,
            epochs=epochs,
            verbose=0,
            callbacks=[LogCallback()],
        )
        return history.history

    def get_weights(self) -> List:
        """Return the model weights as nested Python lists, one per weight array.

        Plain lists of numbers are JSON-serializable, unlike NumPy arrays.
        """
        return [w.tolist() for w in self.model.get_weights()]

    def set_weights(self, weights: List):
        """Load weights (e.g. the server's global weights) into the local model.

        Args:
            weights: One array-like per model weight, in
                ``model.get_weights()`` order. Nested lists as produced by
                ``get_weights`` are accepted.
        """
        self.model.set_weights([np.asarray(w) for w in weights])
|
|