# File size: 3,333 Bytes
# d79115c
import torch
import torch.nn as nn
from typing import Dict, Any, List
import asyncio
import websockets
import json
from pydantic import BaseModel
class PeerMessage(BaseModel):
    """Envelope for a single JSON message exchanged with the peer network."""

    # Kind of message; the values sent in this file are "register" and
    # "state_update".
    message_type: str
    # Message-specific data; its keys depend on message_type.
    payload: Dict[str, Any]
    # Identifier of the peer that produced the message.
    peer_id: str
class DecentModel(nn.Module):
    """Base class for decentralized deep learning models.

    Couples an ``nn.Module`` with a websocket client: the model registers
    itself with a network endpoint, broadcasts state dicts as JSON, stores
    the state updates it receives from other peers, and can average those
    stored updates into its own state. Subclasses must implement
    :meth:`forward`.
    """

    def __init__(self):
        super().__init__()
        self.peer_id = self._generate_peer_id()
        # Not populated anywhere in this class.
        self.peers: List[str] = []
        # Assigned by connect_to_network(); None until then.
        self.websocket = None
        # Most recent deserialized state received from each peer, keyed by
        # the sender's peer_id. Emptied by aggregate_states().
        self.state_updates: Dict[str, Dict[str, torch.Tensor]] = {}

    def _generate_peer_id(self) -> str:
        """Generate a unique peer ID (a random UUID4 rendered as a string)."""
        import uuid

        return str(uuid.uuid4())

    def _require_connection(self):
        """Return the websocket set by :meth:`connect_to_network`.

        Raises:
            RuntimeError: If no connection has been established yet. This
                replaces the opaque ``AttributeError`` on ``None`` that a
                premature send/receive would otherwise produce.
        """
        if self.websocket is None:
            raise RuntimeError(
                "Not connected to a network; call connect_to_network() first"
            )
        return self.websocket

    async def connect_to_network(self, network_url: str):
        """Connect to the decentralized network and register this peer.

        Args:
            network_url: Websocket URL passed to ``websockets.connect``.

        Raises:
            Exception: Whatever ``websockets.connect`` or the registration
                send raises. If registration fails, the freshly opened
                connection is closed and ``self.websocket`` is reset to
                ``None`` before the error propagates.
        """
        connection = await websockets.connect(network_url)
        self.websocket = connection
        try:
            await self._register_peer()
        except Exception:
            # Do not leak a connection the network never acknowledged.
            self.websocket = None
            await connection.close()
            raise

    async def _register_peer(self):
        """Register this peer with the network.

        Sends a "register" message whose payload carries the concrete class
        name of the model.

        Raises:
            RuntimeError: If called before a connection exists.
        """
        connection = self._require_connection()
        # NOTE(review): .json() is the pydantic v1 spelling (deprecated in
        # v2 in favour of model_dump_json) - confirm the pinned version.
        message = PeerMessage(
            message_type="register",
            payload={"model_type": self.__class__.__name__},
            peer_id=self.peer_id,
        )
        await connection.send(message.json())

    async def broadcast_state_update(self, state_dict: Dict[str, torch.Tensor]):
        """Broadcast model state updates to other peers.

        Args:
            state_dict: Mapping of names to tensors; each tensor is converted
                to nested Python lists and sent as a "state_update" message.

        Raises:
            RuntimeError: If called before a connection exists.
        """
        # Check first so a missing connection fails before any serialization.
        connection = self._require_connection()
        message = PeerMessage(
            message_type="state_update",
            payload={"state": self._serialize_state_dict(state_dict)},
            peer_id=self.peer_id,
        )
        await connection.send(message.json())

    def _serialize_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, List[float]]:
        """Serialize model state for transmission.

        Each tensor becomes a (possibly nested) Python list; a 0-dim tensor
        becomes a plain scalar. Dtype and device information are not kept.
        """
        # detach() allows tensors that require grad (e.g. taken from
        # named_parameters()); Tensor.tolist() avoids the NumPy round trip,
        # which rejects such tensors and dtypes NumPy does not have.
        return {
            name: tensor.detach().cpu().tolist()
            for name, tensor in state_dict.items()
        }

    async def receive_state_updates(self):
        """Receive and process state updates from other peers.

        Loops forever: each received message is parsed as a ``PeerMessage``
        and, when it is a "state_update", its deserialized state replaces
        any earlier entry stored for that sender. Messages of other types
        are ignored. The loop has no exit condition; it ends only when
        receiving or parsing raises.

        Raises:
            RuntimeError: If no connection exists when a receive is attempted.
            KeyError: If a "state_update" payload has no "state" entry.
        """
        while True:
            raw_message = await self._require_connection().recv()
            # NOTE(review): parse_raw is the pydantic v1 spelling (deprecated
            # in v2 in favour of model_validate_json) - confirm the version.
            data = PeerMessage.parse_raw(raw_message)
            if data.message_type == "state_update":
                self.state_updates[data.peer_id] = self._deserialize_state_dict(
                    data.payload["state"]
                )

    def _deserialize_state_dict(self, state_dict: Dict[str, List[float]]) -> Dict[str, torch.Tensor]:
        """Deserialize received model state.

        Every value is passed to ``torch.tensor``, so the dtype is inferred
        from the Python values rather than restored from the sender.
        """
        return {name: torch.tensor(values) for name, values in state_dict.items()}

    def aggregate_states(self):
        """Aggregate state updates from all peers.

        Averages, entry by entry, the states stored in ``self.state_updates``,
        loads the result with ``load_state_dict`` and then clears the stored
        updates. Does nothing when no updates are stored. The set of entries
        is taken from the first stored update.

        Raises:
            KeyError: If another peer's update lacks one of those entries.
            RuntimeError: If tensors for an entry cannot be stacked, or if
                ``load_state_dict`` rejects the averaged state.
        """
        if not self.state_updates:
            return

        peer_states = list(self.state_updates.values())
        reference_state = peer_states[0]

        aggregated_state = {}
        for key in reference_state:
            stacked = torch.stack([state[key] for state in peer_states])
            if stacked.is_floating_point() or stacked.is_complex():
                aggregated_state[key] = stacked.mean(dim=0)
            else:
                # torch.mean rejects integer/bool inputs (e.g. BatchNorm's
                # num_batches_tracked counter): average in floating point,
                # round to the nearest value and restore the original dtype.
                averaged = stacked.to(torch.get_default_dtype()).mean(dim=0)
                aggregated_state[key] = averaged.round().to(stacked.dtype)

        # Update model with aggregated state
        self.load_state_dict(aggregated_state)
        self.state_updates.clear()

    def forward(self, *args, **kwargs):
        """Forward pass - to be implemented by child classes."""
        raise NotImplementedError