# oracle/models/wallet_set_encoder.py
# Uploaded by zirobtc via huggingface_hub (commit 858826c)
import torch
import torch.nn as nn
class WalletSetEncoder(nn.Module):
    """
    Encodes a variable-length set of embeddings into a single fixed-size vector
    using a Transformer encoder and a learnable [CLS] token.

    This is used to pool:
        1. A wallet's `wallet_holdings` (a set of [holding_embeds]).
        2. A wallet's `Neo4J links` (a set of [link_embeds]).
        3. A wallet's `deployed_tokens` (a set of [token_name_embeds]).
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        dtype: torch.dtype = torch.float16
    ):
        """
        Initializes the Set Encoder.

        Args:
            d_model (int): The input/output dimension of the embeddings.
            nhead (int): Number of attention heads.
            num_layers (int): Number of transformer layers.
            dim_feedforward (int): Hidden dimension of the feedforward network.
            dropout (float): Dropout rate.
            dtype (torch.dtype): Data type for all parameters; inputs are cast
                to this dtype in `forward`.
        """
        super().__init__()
        self.d_model = d_model
        self.dtype = dtype
        # The learnable [CLS] token, which will aggregate the set representation.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls_token, std=0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True  # inputs/outputs are [batch, seq, feature]
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )
        self.output_norm = nn.LayerNorm(d_model)
        # Cast every parameter (transformer, norm, cls token) to the target dtype.
        self.to(dtype)

    def forward(
        self,
        item_embeds: torch.Tensor,
        src_key_padding_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            item_embeds (torch.Tensor):
                The batch of item embeddings.
                Shape: [batch_size, seq_len, d_model]
            src_key_padding_mask (torch.Tensor):
                The boolean padding mask for the items, where True indicates
                a padded position that should be ignored.
                Shape: [batch_size, seq_len]

        Returns:
            torch.Tensor: The pooled set embedding, in the module's dtype.
                Shape: [batch_size, d_model]
        """
        batch_size = item_embeds.shape[0]
        # FIX: cast inputs to the module dtype. Previously only the CLS token
        # was cast, so e.g. float32 embeddings fed to a float16 module crashed
        # (dtype mismatch at torch.cat / inside the encoder's linear layers).
        item_embeds = item_embeds.to(self.dtype)
        # Defensive cast: the mask contract is boolean (True == padded).
        src_key_padding_mask = src_key_padding_mask.to(torch.bool)
        # 1. Create [CLS] token batch and concatenate with item embeddings.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1).to(self.dtype)
        x = torch.cat([cls_tokens, item_embeds], dim=1)
        # 2. Create the mask for the [CLS] token (it is never masked).
        cls_mask = torch.zeros(batch_size, 1, device=src_key_padding_mask.device, dtype=torch.bool)
        # 3. Concatenate the [CLS] mask with the item mask.
        full_padding_mask = torch.cat([cls_mask, src_key_padding_mask], dim=1)
        # 4. Pass through Transformer.
        transformer_output = self.transformer_encoder(
            x,
            src_key_padding_mask=full_padding_mask
        )
        # 5. Extract the output of the [CLS] token (the first token in the sequence).
        cls_output = transformer_output[:, 0, :]
        return self.output_norm(cls_output)