|
import torch
|
|
import torch.nn as nn
|
|
from typing import Dict, Any, List
|
|
import asyncio
|
|
import websockets
|
|
import json
|
|
from pydantic import BaseModel
|
|
|
|
class PeerMessage(BaseModel):
    """Envelope for every message exchanged with the network over the websocket.

    Instances are sent as JSON via ``.json()`` and rebuilt from received text
    via ``PeerMessage.parse_raw``.
    """

    # Message kind; the values used in this file are "register" and
    # "state_update".
    message_type: str
    # Message-specific body, e.g. {"model_type": <class name>} for
    # "register" or {"state": <serialized state dict>} for "state_update".
    payload: Dict[str, Any]
    # ID of the peer that sent the message; receivers key stored updates by it.
    peer_id: str
|
|
|
|
class DecentModel(nn.Module):
    """Base class for decentralized deep learning models.

    Subclasses define their layers and ``forward``. This base class adds a
    peer identity and a websocket channel. Over that channel it broadcasts
    serialized ``state_dict`` snapshots, collects the ones sent by other
    peers, and averages them into the local model.
    """

    def __init__(self):
        super().__init__()
        # Random UUID4 string identifying this peer on the network.
        self.peer_id = self._generate_peer_id()
        # Not read or written anywhere else in this class.
        self.peers: List[str] = []
        # Open websocket connection; None until connect_to_network() runs.
        self.websocket = None
        # Latest deserialized state received from each peer, keyed by peer_id.
        self.state_updates: Dict[str, Dict[str, torch.Tensor]] = {}

    def _generate_peer_id(self) -> str:
        """Return a fresh random UUID4 string to use as this peer's ID."""
        import uuid

        return str(uuid.uuid4())

    def _require_connection(self):
        """Return the open websocket.

        Raises:
            RuntimeError: if ``connect_to_network`` has not been called yet.
        """
        if self.websocket is None:
            raise RuntimeError("Not connected: call connect_to_network() first")
        return self.websocket

    async def connect_to_network(self, network_url: str):
        """Open a websocket to ``network_url`` and register this peer.

        Any previously opened connection is closed first, so calling this
        again to reconnect does not leak the old socket.
        """
        if self.websocket is not None:
            old_socket, self.websocket = self.websocket, None
            await old_socket.close()
        self.websocket = await websockets.connect(network_url)
        await self._register_peer()

    async def _register_peer(self):
        """Send a ``register`` message announcing this peer's model class name."""
        message = PeerMessage(
            message_type="register",
            payload={"model_type": self.__class__.__name__},
            peer_id=self.peer_id,
        )
        await self._require_connection().send(message.json())

    async def broadcast_state_update(self, state_dict: Dict[str, torch.Tensor]):
        """Send ``state_dict`` to the network as a ``state_update`` message.

        Args:
            state_dict: Mapping of parameter/buffer names to tensors, for
                example ``self.state_dict()``.
        """
        message = PeerMessage(
            message_type="state_update",
            payload={"state": self._serialize_state_dict(state_dict)},
            peer_id=self.peer_id,
        )
        await self._require_connection().send(message.json())

    def _serialize_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, List[float]]:
        """Convert each tensor into (nested) Python lists for JSON transport.

        ``Tensor.tolist`` is used instead of a NumPy round-trip. The NumPy
        path rejects tensors that require grad (e.g. from
        ``named_parameters()``) and dtypes NumPy lacks, such as bfloat16.
        Dtype information is not transmitted.
        """
        return {
            name: tensor.detach().cpu().tolist()
            for name, tensor in state_dict.items()
        }

    async def receive_state_updates(self):
        """Store every ``state_update`` message received over the connection.

        Iterates the websocket until the connection closes. Each peer's
        entry in ``state_updates`` is overwritten by its most recent update.
        Other message types are ignored.

        NOTE(review): if the server echoes this peer's own broadcasts back,
        they are stored here too. Confirm the server's behavior.
        """
        async for raw_message in self._require_connection():
            data = PeerMessage.parse_raw(raw_message)
            if data.message_type == "state_update":
                self.state_updates[data.peer_id] = self._deserialize_state_dict(
                    data.payload["state"]
                )

    def _deserialize_state_dict(self, state_dict: Dict[str, List[float]]) -> Dict[str, torch.Tensor]:
        """Rebuild tensors from the list form produced by ``_serialize_state_dict``."""
        return {name: torch.tensor(values) for name, values in state_dict.items()}

    @staticmethod
    def _mean_tensors(tensors: List[torch.Tensor]) -> torch.Tensor:
        """Return the element-wise mean of same-shaped tensors, keeping their dtype.

        ``torch.mean`` only accepts floating/complex inputs. Integer and bool
        tensors (e.g. BatchNorm's ``num_batches_tracked``) are therefore
        averaged in float64, rounded, and cast back to the stacked dtype.
        """
        stacked = torch.stack(tensors)
        if stacked.is_floating_point() or stacked.is_complex():
            return stacked.mean(dim=0)
        return stacked.double().mean(dim=0).round().to(stacked.dtype)

    def aggregate_states(self):
        """Average all stored peer states into this model, then clear them.

        Does nothing when no updates are pending. Keys are taken from the
        first stored update. Every other update must contain the same keys
        with matching shapes, and ``load_state_dict`` is strict, so the keys
        must match this model's state dict.

        NOTE(review): only received peer states are averaged. This model's
        own current weights are not part of the mean.
        """
        if not self.state_updates:
            return

        peer_states = list(self.state_updates.values())
        aggregated_state = {
            key: self._mean_tensors([state[key] for state in peer_states])
            for key in peer_states[0]
        }

        self.load_state_dict(aggregated_state)
        self.state_updates.clear()

    def forward(self, *args, **kwargs):
        """Forward pass; subclasses must override this."""
        raise NotImplementedError