import torch
import torch.nn as nn
from typing import List, Dict, Any


class HolderDistributionEncoder(nn.Module):
    """
    Encodes a list of top holders (wallet embeddings + holding percentages)
    into a single fixed-size embedding representing the holder distribution.
    It uses a Transformer Encoder to capture patterns and relationships.
    """
    def __init__(self,
                 wallet_embedding_dim: int,
                 output_dim: int,
                 nhead: int = 4,
                 num_layers: int = 2,
                 dtype: torch.dtype = torch.float16):
        super().__init__()
        self.wallet_embedding_dim = wallet_embedding_dim
        self.output_dim = output_dim
        self.dtype = dtype

        # Project the scalar holding percentage into the wallet embedding space
        # so it can be added to the corresponding wallet embedding.
        self.pct_proj = nn.Sequential(
            nn.Linear(1, wallet_embedding_dim // 4),
            nn.GELU(),
            nn.Linear(wallet_embedding_dim // 4, wallet_embedding_dim)
        ).to(dtype)

        # Transformer encoder over the set of holder tokens.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=wallet_embedding_dim,
            nhead=nhead,
            dim_feedforward=wallet_embedding_dim * 4,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
            dtype=dtype
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Learnable [CLS] token whose output summarizes the whole distribution.
        self.cls_token = nn.Parameter(torch.randn(1, 1, wallet_embedding_dim, dtype=dtype))

        # Final projection from the transformer width to the requested output size.
        self.final_proj = nn.Linear(wallet_embedding_dim, output_dim).to(dtype)

    def forward(self, holder_data: List[Dict[str, Any]]) -> torch.Tensor:
        """
        Args:
            holder_data: A list of dictionaries, where each dict contains:
                'wallet_embedding': A tensor of shape [wallet_embedding_dim].
                'pct': The holding percentage as a float.

        Returns:
            A tensor of shape [1, output_dim] representing the entire distribution.
        """
        if not holder_data:
            # No holders: return a zero embedding on the module's device.
            return torch.zeros(1, self.output_dim, device=self.cls_token.device, dtype=self.dtype)

        # Stack the per-holder inputs, casting to the module's device/dtype so
        # mixed-precision inputs do not break the encoder layers.
        wallet_embeds = torch.stack([d['wallet_embedding'] for d in holder_data]).to(
            device=self.cls_token.device, dtype=self.dtype)
        holder_pcts = torch.tensor([[d['pct']] for d in holder_data],
                                   device=wallet_embeds.device, dtype=self.dtype)

        # Fuse the percentage signal into each wallet embedding and add a batch dim.
        pct_embeds = self.pct_proj(holder_pcts)
        holder_inputs = (wallet_embeds + pct_embeds).unsqueeze(0)

        # Prepend the learnable [CLS] token to the holder sequence.
        batch_size = holder_inputs.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        transformer_input = torch.cat((cls_tokens, holder_inputs), dim=1)

        # Let holders (and the [CLS] token) attend to one another.
        transformer_output = self.transformer_encoder(transformer_input)

        # The [CLS] position aggregates the whole distribution.
        cls_embedding = transformer_output[:, 0, :]

        # Project down to the requested output dimension.
        return self.final_proj(cls_embedding)
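

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module above).
# The dimensions, number of holders, and the float32 override are assumptions;
# the module's default float16 dtype is intended for GPU execution.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    wallet_dim, out_dim = 64, 32

    encoder = HolderDistributionEncoder(
        wallet_embedding_dim=wallet_dim,
        output_dim=out_dim,
        dtype=torch.float32,  # float32 so the sketch also runs on CPU
    )

    # Fake top-holder data: each entry pairs a wallet embedding with its
    # holding percentage, matching the format expected by forward().
    holder_data = [
        {"wallet_embedding": torch.randn(wallet_dim), "pct": pct}
        for pct in [0.25, 0.10, 0.05]
    ]

    with torch.no_grad():
        distribution_embedding = encoder(holder_data)

    print(distribution_embedding.shape)  # torch.Size([1, 32])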