File size: 3,333 Bytes
d79115c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import torch.nn as nn
from typing import Dict, Any, List
import asyncio
import websockets
import json
from pydantic import BaseModel

class PeerMessage(BaseModel):
    """Envelope for every message exchanged with the network over the websocket.

    ``DecentModel`` serializes instances with ``.json()`` before sending and
    parses incoming text with ``PeerMessage.parse_raw``.
    """
    # Kind of message; this file sends "register" and "state_update" and
    # only acts on received "state_update" messages.
    message_type: str
    # Message-specific body: {"model_type": <class name>} for "register",
    # {"state": {<state_dict key>: <nested list>}} for "state_update".
    payload: Dict[str, Any]
    # ID of the sending peer (DecentModel uses a UUID4 string).
    peer_id: str

class DecentModel(nn.Module):
    """Base class for decentralized deep learning models.

    Subclasses define their layers and implement :meth:`forward`. On top of
    that, a peer can connect to a websocket endpoint, register itself,
    broadcast its ``state_dict`` as JSON, buffer the states received from
    other peers, and average them into its own parameters with
    :meth:`aggregate_states`.
    """

    def __init__(self):
        super().__init__()
        # Random UUID4 string stamped on every outgoing message.
        self.peer_id = self._generate_peer_id()
        # NOTE(review): not read or written anywhere else in this class;
        # presumably maintained by subclasses or callers.
        self.peers: List[str] = []
        # Open websocket connection; None until connect_to_network() is awaited.
        self.websocket: Any = None
        # Latest received state per sending peer_id; consumed by aggregate_states().
        self.state_updates: Dict[str, Dict[str, torch.Tensor]] = {}

    def _generate_peer_id(self) -> str:
        """Return a fresh random UUID4 string to use as this peer's ID."""
        import uuid
        return str(uuid.uuid4())

    def _require_connection(self) -> Any:
        """Return the open websocket, or raise if the peer is not connected.

        Raises:
            RuntimeError: if ``connect_to_network`` has not been awaited (or
                ``disconnect`` was called since).
        """
        if self.websocket is None:
            raise RuntimeError(
                "Not connected; await connect_to_network() before messaging peers"
            )
        return self.websocket

    async def connect_to_network(self, network_url: str):
        """Open a websocket to *network_url* and register this peer.

        A connection left over from an earlier call is closed first so that
        reconnecting does not leak the old socket.
        """
        await self.disconnect()
        self.websocket = await websockets.connect(network_url)
        await self._register_peer()

    async def disconnect(self):
        """Close the websocket connection, if any. Safe to call repeatedly."""
        if self.websocket is not None:
            # Clear the attribute before awaiting so a concurrent caller does
            # not try to close the same socket twice.
            websocket, self.websocket = self.websocket, None
            await websocket.close()

    async def _register_peer(self):
        """Announce this peer, tagged with its model class name, to the network."""
        message = PeerMessage(
            message_type="register",
            payload={"model_type": self.__class__.__name__},
            peer_id=self.peer_id,
        )
        await self._require_connection().send(message.json())

    async def broadcast_state_update(self, state_dict: Dict[str, torch.Tensor]):
        """Send *state_dict* to the network as a "state_update" message.

        Args:
            state_dict: mapping of parameter/buffer names to tensors, for
                example ``self.state_dict()``.

        Raises:
            RuntimeError: if the peer is not connected.
        """
        message = PeerMessage(
            message_type="state_update",
            payload={"state": self._serialize_state_dict(state_dict)},
            peer_id=self.peer_id,
        )
        await self._require_connection().send(message.json())

    def _serialize_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Convert each tensor to (possibly nested) Python lists for JSON encoding.

        ``detach()`` makes this work for tensors that require grad (such as a
        dict built from ``named_parameters()``). ``Tensor.tolist`` avoids the
        NumPy round-trip, which fails for dtypes NumPy lacks (e.g. bfloat16).
        0-d tensors become plain Python scalars.
        """
        return {k: v.detach().cpu().tolist() for k, v in state_dict.items()}

    async def receive_state_updates(self):
        """Buffer "state_update" messages from the network until the socket closes.

        Only the most recent update per sender ``peer_id`` is kept; other
        message types are ignored. Iterating the connection ends the loop
        normally on a clean close, while an abnormal close still raises from
        ``websockets``.

        Raises:
            RuntimeError: if the peer is not connected.
        """
        async for message in self._require_connection():
            data = PeerMessage.parse_raw(message)
            if data.message_type == "state_update":
                self.state_updates[data.peer_id] = self._deserialize_state_dict(
                    data.payload["state"]
                )

    def _deserialize_state_dict(self, state_dict: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Rebuild tensors from the lists produced by ``_serialize_state_dict``.

        The original dtype is not transmitted; ``torch.tensor`` infers it
        (Python floats -> default float dtype, ints -> int64, bools -> bool).
        """
        return {k: torch.tensor(v) for k, v in state_dict.items()}

    def aggregate_states(self):
        """Average all buffered peer states into this model, then clear the buffer.

        Does nothing when no updates have been received. Keys are taken from
        the first buffered peer's state. Every other peer must provide each of
        those keys (otherwise ``KeyError``) with a tensor of matching shape.
        The averaged dict is applied with ``load_state_dict``, which copies
        the values into the existing parameters and buffers.
        """
        if not self.state_updates:
            return

        peer_states = list(self.state_updates.values())
        aggregated_state = {}
        for key in peer_states[0]:
            stacked = torch.stack([state[key] for state in peer_states])
            aggregated_state[key] = self._mean_over_peers(stacked)

        self.load_state_dict(aggregated_state)
        self.state_updates.clear()

    @staticmethod
    def _mean_over_peers(stacked: torch.Tensor) -> torch.Tensor:
        """Return the mean of *stacked* along dim 0, keeping its dtype.

        ``torch.mean`` rejects integer and bool inputs (for example BatchNorm's
        ``num_batches_tracked`` buffer). Those are averaged in float64, then
        rounded and cast back to their original dtype.
        """
        if stacked.is_floating_point() or stacked.is_complex():
            return stacked.mean(dim=0)
        return stacked.double().mean(dim=0).round().to(stacked.dtype)

    def forward(self, *args, **kwargs):
        """Forward pass; subclasses must override this."""
        raise NotImplementedError