# OpenPeerLLM / src/decent_torch.py
# Uploaded by Mentors4EDU in commit d79115c ("Upload 27 files", verified); 3.33 kB.
import torch
import torch.nn as nn
from typing import Dict, Any, List
import asyncio
import websockets
import json
from pydantic import BaseModel
class PeerMessage(BaseModel):
    """Envelope for every JSON message exchanged over the peer websocket.

    Attributes:
        message_type: kind of message; this module sends ``"register"`` and
            ``"state_update"``.
        payload: message-specific, JSON-compatible content.
        peer_id: identifier of the sending peer.
    """

    # Field order is kept as-is: it determines key order in the JSON output.
    message_type: str
    payload: Dict[str, Any]
    peer_id: str
class DecentModel(nn.Module):
    """Base class for decentralized deep learning models.

    Subclasses implement :meth:`forward`. This base class adds a websocket
    connection to a peer network, JSON exchange of ``state_dict`` tensors
    with other peers, and plain averaging of received peer states into this
    module's parameters and buffers.
    """

    def __init__(self):
        super().__init__()
        self.peer_id = self._generate_peer_id()
        # Not read anywhere in this class; kept for subclasses / callers.
        self.peers: List[str] = []
        # Set by connect_to_network(); None until then.
        self.websocket = None
        # Latest deserialized state received from each peer, keyed by the
        # sender's peer_id. Consumed (and cleared) by aggregate_states().
        self.state_updates: Dict[str, Dict[str, torch.Tensor]] = {}

    def _generate_peer_id(self) -> str:
        """Generate a unique peer ID (a random UUID4 string)."""
        import uuid

        return str(uuid.uuid4())

    async def connect_to_network(self, network_url: str):
        """Open a websocket to *network_url* and register this peer.

        Args:
            network_url: URL passed to ``websockets.connect``.
        """
        self.websocket = await websockets.connect(network_url)
        await self._register_peer()

    async def _send(self, message: "PeerMessage"):
        """Send *message*, JSON-encoded, over the open websocket."""
        await self.websocket.send(message.json())

    async def _register_peer(self):
        """Register this peer with the network, announcing its model class name."""
        await self._send(
            PeerMessage(
                message_type="register",
                payload={"model_type": self.__class__.__name__},
                peer_id=self.peer_id,
            )
        )

    async def broadcast_state_update(self, state_dict: Dict[str, torch.Tensor]):
        """Broadcast model state updates to other peers.

        Args:
            state_dict: mapping of parameter/buffer names to tensors; encoded
                with :meth:`_serialize_state_dict` under ``payload["state"]``.
        """
        await self._send(
            PeerMessage(
                message_type="state_update",
                payload={"state": self._serialize_state_dict(state_dict)},
                peer_id=self.peer_id,
            )
        )

    def _serialize_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, List[float]]:
        """Serialize model state for transmission as (nested) Python lists.

        ``detach()`` allows tensors that require grad (e.g. ``nn.Parameter``
        objects or ``state_dict(keep_vars=True)`` output) to be serialized,
        and ``Tensor.tolist()`` -- unlike ``.numpy()`` -- also supports dtypes
        NumPy lacks, such as ``bfloat16``. For NumPy-compatible dtypes the
        result equals the former ``.numpy().tolist()`` output.
        """
        return {k: v.detach().cpu().tolist() for k, v in state_dict.items()}

    async def receive_state_updates(self):
        """Receive and store state updates from other peers, indefinitely.

        Each ``state_update`` message replaces any previously stored state
        from the same sender in :attr:`state_updates`; other message types
        are ignored. Errors from ``recv`` or from message parsing are not
        caught here and end the loop.
        """
        while True:
            raw = await self.websocket.recv()
            data = PeerMessage.parse_raw(raw)
            if data.message_type == "state_update":
                self.state_updates[data.peer_id] = self._deserialize_state_dict(
                    data.payload["state"]
                )

    def _deserialize_state_dict(self, state_dict: Dict[str, List[float]]) -> Dict[str, torch.Tensor]:
        """Deserialize received model state back into tensors.

        Dtypes are inferred by ``torch.tensor`` from the list contents
        (Python floats -> default float dtype, ints -> int64, bools -> bool).
        """
        return {k: torch.tensor(v) for k, v in state_dict.items()}

    def aggregate_states(self):
        """Average all pending peer states and load the result into this model.

        Keys are taken from the first stored peer state; every other peer
        must provide the same keys with stackable shapes. Does nothing when
        no updates are pending; otherwise clears :attr:`state_updates` after
        ``load_state_dict`` succeeds.
        """
        if not self.state_updates:
            return

        peer_states = list(self.state_updates.values())
        # Average every entry across all peers.
        aggregated_state = {
            key: self._mean_over_peers(torch.stack([state[key] for state in peer_states]))
            for key in peer_states[0]
        }

        # Update model with aggregated state (default strict key matching).
        self.load_state_dict(aggregated_state)
        self.state_updates.clear()

    @staticmethod
    def _mean_over_peers(stacked: torch.Tensor) -> torch.Tensor:
        """Average *stacked* along dim 0, keeping its dtype.

        ``torch.mean`` rejects integer and bool tensors, which occur in real
        state dicts (e.g. BatchNorm's ``num_batches_tracked``) and in every
        deserialized integer payload. Those are averaged in float64, rounded
        (ties to even) and cast back to the original dtype.
        """
        if stacked.is_floating_point() or stacked.is_complex():
            return stacked.mean(dim=0)
        return stacked.to(torch.float64).mean(dim=0).round().to(stacked.dtype)

    def forward(self, *args, **kwargs):
        """Forward pass - to be implemented by child classes."""
        raise NotImplementedError