import asyncio
import json
import logging
import uuid
from typing import Any, Dict, List

import torch
import torch.nn as nn
import websockets
from pydantic import BaseModel
|
|
|
class PeerMessage(BaseModel):
    """Envelope for every JSON message exchanged between a peer and the network.

    Field declaration order is the key order of the serialized JSON, so keep
    it stable.
    """

    # Kind of message; DecentModel sends "register" and "state_update".
    message_type: str
    # Message-specific body, e.g. {"model_type": ...} or {"state": ...}.
    payload: Dict[str, Any]
    # Id of the peer that sent the message.
    peer_id: str
|
|
|
class DecentModel(nn.Module):
    """Base class for decentralized deep learning models.

    Each instance gets a random peer id, talks to the network over a single
    websocket, broadcasts its state as JSON ``PeerMessage``s, and can average
    the states received from other peers into its own parameters and buffers.

    Subclasses must implement :meth:`forward`.
    """

    def __init__(self):
        super().__init__()
        self.peer_id: str = self._generate_peer_id()
        # Not read or written elsewhere in this class; presumably maintained
        # by subclasses or callers.
        self.peers: List[str] = []
        # Connection returned by ``websockets.connect``; None until
        # ``connect_to_network`` succeeds.
        self.websocket: Any = None
        # Latest deserialized state per peer id; a newer update from the same
        # peer replaces the older one.
        self.state_updates: Dict[str, Dict[str, torch.Tensor]] = {}

    def _generate_peer_id(self) -> str:
        """Return a random peer id (a UUID4 string)."""
        return str(uuid.uuid4())

    async def connect_to_network(self, network_url: str):
        """Open a websocket to ``network_url`` and register this peer.

        A previously open connection is closed first so it is not leaked. If
        registration fails, the new connection is closed, ``self.websocket`` is
        reset to ``None`` and the error is re-raised.
        """
        if self.websocket is not None:
            await self.websocket.close()
            self.websocket = None
        self.websocket = await websockets.connect(network_url)
        try:
            await self._register_peer()
        except BaseException:
            # BaseException so cancellation also releases the socket.
            await self.websocket.close()
            self.websocket = None
            raise

    async def _register_peer(self):
        """Announce this peer to the network, reporting its model class name."""
        await self._send("register", {"model_type": self.__class__.__name__})

    async def broadcast_state_update(self, state_dict: Dict[str, torch.Tensor]):
        """Send ``state_dict`` to the network as a ``state_update`` message.

        Raises:
            RuntimeError: if :meth:`connect_to_network` has not succeeded.
        """
        await self._send(
            "state_update", {"state": self._serialize_state_dict(state_dict)}
        )

    def _require_connection(self) -> Any:
        """Return the open websocket, or raise RuntimeError if not connected."""
        if self.websocket is None:
            raise RuntimeError(
                "Not connected to a network; call connect_to_network() first"
            )
        return self.websocket

    async def _send(self, message_type: str, payload: Dict[str, Any]) -> None:
        """Wrap ``payload`` in a PeerMessage from this peer and send it as JSON."""
        websocket = self._require_connection()
        message = PeerMessage(
            message_type=message_type,
            payload=payload,
            peer_id=self.peer_id,
        )
        await websocket.send(message.json())

    def _serialize_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Convert each tensor to nested Python lists (a bare number if 0-dim).

        ``Tensor.tolist`` is used directly rather than going through NumPy,
        because ``.numpy()`` rejects tensors that require grad and dtypes NumPy
        lacks (e.g. bfloat16). Dtype and device are not transmitted.
        """
        return {name: tensor.tolist() for name, tensor in state_dict.items()}

    async def receive_state_updates(self):
        """Consume incoming messages and store each peer's latest state.

        Runs until the connection closes. With ``async for``, the websockets
        library ends iteration on a clean close and raises its
        ConnectionClosed error on an abnormal one. Messages that fail to parse,
        lack a ``state`` object, or don't match this model's keys/shapes are
        logged and skipped so one bad peer cannot stop the loop. Messages of
        any other type are ignored.

        Raises:
            RuntimeError: if :meth:`connect_to_network` has not succeeded.
        """
        websocket = self._require_connection()
        log = logging.getLogger(__name__)
        async for raw in websocket:
            try:
                data = PeerMessage.parse_raw(raw)
            except ValueError:
                # pydantic's ValidationError (v1 and v2) subclasses ValueError.
                log.warning("Ignoring malformed peer message: %.200r", raw)
                continue
            if data.message_type != "state_update":
                continue
            state = data.payload.get("state")
            if not isinstance(state, dict):
                log.warning(
                    "Ignoring state_update without a 'state' object from peer %s",
                    data.peer_id,
                )
                continue
            try:
                self.state_updates[data.peer_id] = self._deserialize_state_dict(state)
            except (TypeError, ValueError, RuntimeError) as exc:
                log.warning(
                    "Ignoring incompatible state_update from peer %s: %s",
                    data.peer_id,
                    exc,
                )

    def _deserialize_state_dict(self, state_dict: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Rebuild tensors from a serialized state, matched against this model.

        Each tensor is created with the dtype and device of the same entry in
        ``self.state_dict()``. Int buffers therefore stay integral, and
        float64/half parameters are not silently turned into float32.

        Raises:
            ValueError: if the key set differs from this model's state_dict or a
                tensor's shape does not match. ``torch.tensor`` may also raise
                TypeError/ValueError for non-numeric or ragged data.
        """
        reference = self.state_dict()
        missing = reference.keys() - state_dict.keys()
        unexpected = state_dict.keys() - reference.keys()
        if missing or unexpected:
            raise ValueError(
                f"State keys do not match this model "
                f"(missing={sorted(missing)}, unexpected={sorted(unexpected)})"
            )
        tensors = {}
        for name, value in state_dict.items():
            ref = reference[name]
            tensor = torch.tensor(value, dtype=ref.dtype, device=ref.device)
            if tensor.shape != ref.shape:
                raise ValueError(
                    f"Shape mismatch for {name!r}: got {tuple(tensor.shape)}, "
                    f"expected {tuple(ref.shape)}"
                )
            tensors[name] = tensor
        return tensors

    def aggregate_states(self, include_self: bool = False):
        """Load the element-wise mean of the pending peer states into this model.

        Does nothing if no updates are pending. Floating/complex entries are
        averaged directly. Integer and bool entries (e.g. BatchNorm's
        ``num_batches_tracked``) are averaged in float64, rounded and cast back,
        because ``torch.mean`` rejects integer tensors. Pending updates are
        cleared only after the new state loads successfully.

        Args:
            include_self: also count this model's current state as one vote in
                the average. Defaults to False, in which case the local state is
                simply overwritten by the peers' mean.
        """
        if not self.state_updates:
            return

        states = list(self.state_updates.values())
        if include_self:
            states.append(self.state_dict())

        aggregated_state = {}
        for key in next(iter(self.state_updates.values())):
            stacked = torch.stack([state[key] for state in states])
            if stacked.is_floating_point() or stacked.is_complex():
                aggregated_state[key] = stacked.mean(dim=0)
            else:
                aggregated_state[key] = (
                    stacked.double().mean(dim=0).round().to(stacked.dtype)
                )

        self.load_state_dict(aggregated_state)
        self.state_updates.clear()

    def forward(self, *args, **kwargs):
        """Run the model; subclasses must override this."""
        raise NotImplementedError