| |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.nn.utils.rnn import pad_sequence |
| from typing import List, Dict, Any, Tuple, Optional, Union |
| from collections import defaultdict |
| from PIL import Image |
| |
|
|
| import models.vocabulary as vocab |
| from data.data_loader import EmbeddingPooler |
| from data.quant_ohlc_feature_schema import FEATURE_VERSION, FEATURE_VERSION_ID, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT |
|
|
# Wrapped-SOL mint address on Solana (native SOL represented as an SPL token).
NATIVE_MINT = "So11111111111111111111111111111111111111112"
# Mint addresses treated as quote-side currencies (SOL plus what appear to be
# USDC / USDT / USD1 stablecoin mints — verify against the token registry).
# Transfers of these mints are skipped when deriving wallet->token
# "TransferLinkToken" edges in _collate_graph_links.
QUOTE_MINTS = {
    NATIVE_MINT,
    "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v",
    "Es9vMFrzaCERmJfrF4H2FYD4KCoNkY11McCe8BenwNYB",
    "USD1ttGY1N17NEEHLmELoaybftRBUSErhqYiQzvEmuB",
}
|
|
| class MemecoinCollator: |
| """ |
| Callable class for PyTorch DataLoader's collate_fn. |
| ... (rest of docstring) ... |
| """ |
| def __init__(self, |
| event_type_to_id: Dict[str, int], |
| device: torch.device, |
| dtype: torch.dtype, |
| max_seq_len: Optional[int] = None, |
| model_id: str = "google/siglip-so400m-patch16-256-i18n" |
| ): |
| self.event_type_to_id = event_type_to_id |
| self.pad_token_id = event_type_to_id.get('__PAD__', 0) |
| |
| self.model_id = model_id |
| self.entity_pad_idx = 0 |
|
|
| self.device = device |
| self.dtype = dtype |
| self.ohlc_seq_len = 300 |
| self.quant_ohlc_tokens = TOKENS_PER_SEGMENT |
| self.quant_ohlc_num_features = NUM_QUANT_OHLC_FEATURES |
| self.max_seq_len = max_seq_len |
|
|
| def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]: |
| """ (Unchanged) """ |
| collated = defaultdict(list) |
| if not entities: |
| |
| if entity_type == "token": |
| return { |
| 'name_embed_indices': torch.tensor([], device=device, dtype=torch.long), |
| 'symbol_embed_indices': torch.tensor([], device=device, dtype=torch.long), |
| 'image_embed_indices': torch.tensor([], device=device, dtype=torch.long), |
| 'protocol_ids': torch.tensor([], device=device, dtype=torch.long), |
| 'is_vanity_flags': torch.tensor([], device=device, dtype=torch.bool), |
| '_addresses_for_lookup': [] |
| } |
| elif entity_type == "wallet": |
| return { |
| 'username_embed_indices': torch.tensor([], device=device, dtype=torch.long), |
| 'profile_rows': [], 'social_rows': [], 'holdings_batch': [] |
| } |
| return {} |
|
|
| |
| if entity_type == "token": |
| |
| |
| collated['_addresses_for_lookup'] = [e.get('address', '') for e in entities] |
| collated['name_embed_indices'] = torch.tensor([e.get('name_emb_idx', 0) for e in entities], device=device, dtype=torch.long) |
| collated['symbol_embed_indices'] = torch.tensor([e.get('symbol_emb_idx', 0) for e in entities], device=device, dtype=torch.long) |
| collated['image_embed_indices'] = torch.tensor([e.get('image_emb_idx', 0) for e in entities], device=device, dtype=torch.long) |
| collated['protocol_ids'] = torch.tensor([e.get('protocol', 0) for e in entities], device=device, dtype=torch.long) |
| collated['is_vanity_flags'] = torch.tensor([e.get('is_vanity', False) for e in entities], device=device, dtype=torch.bool) |
| elif entity_type == "wallet": |
| |
| collated['username_embed_indices'] = torch.tensor([e.get('socials', {}).get('username_emb_idx', 0) for e in entities], device=device, dtype=torch.long) |
| collated['profile_rows'] = [e.get('profile', {}) for e in entities] |
| collated['social_rows'] = [e.get('socials', {}) for e in entities] |
| collated['holdings_batch'] = [e.get('holdings', []) for e in entities] |
| return dict(collated) |
|
|
    def _collate_ohlc_inputs(self, chart_events: List[Dict]) -> Dict[str, torch.Tensor]:
        """
        Collate Chart_Segment events into fixed-size OHLC and quant-feature tensors.

        Args:
            chart_events: Raw Chart_Segment event dicts carrying 'opens'/'closes'
                price lists, an interval string under 'i', and a mandatory
                'quant_ohlc_features' list of {'feature_vector': [...]} payloads.

        Returns:
            Dict with:
                price_tensor: (N, 2, ohlc_seq_len) stacked open/close series.
                interval_ids: (N,) vocab ids for each segment's interval.
                quant_feature_tensors: (N, quant_ohlc_tokens, quant_ohlc_num_features).
                quant_feature_mask: (N, quant_ohlc_tokens), 1.0 for real rows,
                    0.0 for zero-padded rows.
                quant_feature_version_ids: (N,) FEATURE_VERSION_ID when the
                    segment's version matches the current schema, else 0.

        Raises:
            RuntimeError: on missing/malformed quant_ohlc_features payloads.
        """
        if not chart_events:
            # Empty batch: keep shapes/dtypes consistent with the non-empty path.
            return {
                'price_tensor': torch.empty(0, 2, self.ohlc_seq_len, device=self.device, dtype=self.dtype),
                'interval_ids': torch.empty(0, device=self.device, dtype=torch.long),
                'quant_feature_tensors': torch.empty(0, self.quant_ohlc_tokens, self.quant_ohlc_num_features, device=self.device, dtype=self.dtype),
                'quant_feature_mask': torch.empty(0, self.quant_ohlc_tokens, device=self.device, dtype=self.dtype),
                'quant_feature_version_ids': torch.empty(0, device=self.device, dtype=torch.long),
            }
        ohlc_tensors = []
        interval_ids_list = []
        quant_feature_tensors = []
        quant_feature_masks = []
        quant_feature_version_ids = []
        seq_len = self.ohlc_seq_len
        unknown_id = vocab.INTERVAL_TO_ID.get("Unknown", 0)
        for segment_data in chart_events:
            opens = segment_data.get('opens', [])
            closes = segment_data.get('closes', [])
            interval_str = segment_data.get('i', "Unknown")
            # Pad short series with the last observed price (flat continuation);
            # an empty series pads with 0.
            pad_open = opens[-1] if opens else 0
            pad_close = closes[-1] if closes else 0
            # Truncate to seq_len; for longer series the pad multiplier is <= 0,
            # so the padding list is empty.
            o = torch.tensor(opens[:seq_len] + [pad_open]*(seq_len-len(opens)), dtype=self.dtype)
            c = torch.tensor(closes[:seq_len] + [pad_close]*(seq_len-len(closes)), dtype=self.dtype)
            # Scrub NaN/inf coming from upstream price data.
            o = torch.nan_to_num(o, nan=0.0, posinf=0.0, neginf=0.0)
            c = torch.nan_to_num(c, nan=0.0, posinf=0.0, neginf=0.0)
            ohlc_tensors.append(torch.stack([o, c]))
            interval_id = vocab.INTERVAL_TO_ID.get(interval_str, unknown_id)
            interval_ids_list.append(interval_id)
            quant_payload = segment_data.get('quant_ohlc_features')
            if quant_payload is None:
                raise RuntimeError("Chart_Segment missing quant_ohlc_features. Rebuild cache with quantitative chart features.")
            if not isinstance(quant_payload, list):
                raise RuntimeError("Chart_Segment quant_ohlc_features must be a list.")
            feature_rows = []
            feature_mask = []
            # Validate and pad to exactly quant_ohlc_tokens rows; the mask
            # distinguishes real rows (1.0) from zero padding (0.0).
            for token_idx in range(self.quant_ohlc_tokens):
                if token_idx < len(quant_payload):
                    payload = quant_payload[token_idx]
                    vec = payload.get('feature_vector')
                    if not isinstance(vec, list) or len(vec) != self.quant_ohlc_num_features:
                        raise RuntimeError(
                            f"Chart_Segment quant feature vector must have length {self.quant_ohlc_num_features}."
                        )
                    feature_rows.append(vec)
                    feature_mask.append(1.0)
                else:
                    feature_rows.append([0.0] * self.quant_ohlc_num_features)
                    feature_mask.append(0.0)
            quant_feature_tensors.append(torch.tensor(feature_rows, device=self.device, dtype=self.dtype))
            quant_feature_masks.append(torch.tensor(feature_mask, device=self.device, dtype=self.dtype))
            # Segments without an explicit version are assumed current; any
            # mismatching version is flagged with id 0.
            version = segment_data.get('quant_feature_version', FEATURE_VERSION)
            quant_feature_version_ids.append(FEATURE_VERSION_ID if version == FEATURE_VERSION else 0)
        return {
            'price_tensor': torch.stack(ohlc_tensors).to(self.device),
            'interval_ids': torch.tensor(interval_ids_list, device=self.device, dtype=torch.long),
            'quant_feature_tensors': torch.stack(quant_feature_tensors).to(self.device),
            'quant_feature_mask': torch.stack(quant_feature_masks).to(self.device),
            'quant_feature_version_ids': torch.tensor(quant_feature_version_ids, device=self.device, dtype=torch.long),
        }
|
|
    def _collate_graph_links(self,
                             batch_items: List[Dict],
                             wallet_addr_to_batch_idx: Dict[str, int],
                             token_addr_to_batch_idx: Dict[str, int]) -> Dict[str, Any]:
        """
        Aggregate per-item graph links into batch-level edge-index tensors.

        Edges arrive keyed by entity address; they are remapped to batch-wide
        1-based entity indices, and any edge with an unresolvable endpoint
        (index 0, the padding slot) is dropped. "TransferLink" edges
        additionally spawn derived wallet->token edges ("TransferLinkToken")
        via the transferred mint, skipping quote-currency mints.

        Returns:
            {link_name: {'links': [props, ...], 'edge_index': LongTensor(2, E)}}
            containing only link types that produced at least one valid edge.
        """
        aggregated_links = defaultdict(lambda: {'edge_index_list': [], 'links_list': []})
        for item in batch_items:
            item_wallets = item.get('wallets', {})
            item_tokens = item.get('tokens', {})
            # Address -> batch-wide index; unknown entities map to the pad slot.
            item_wallet_addr_to_global_idx = {addr: wallet_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_wallets.keys()}
            item_token_addr_to_global_idx = {addr: token_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_tokens.keys()}
            for link_name, data in item.get('graph_links', {}).items():
                # The link name determines the endpoint entity types.
                triplet = vocab.LINK_NAME_TO_TRIPLET.get(link_name)
                if not triplet: continue
                src_type, _, dst_type = triplet
                edges = data.get('edges')
                link_props_list = data.get('links', [])
                if not edges or not link_props_list: continue

                src_map = item_wallet_addr_to_global_idx if src_type == 'wallet' else item_token_addr_to_global_idx
                dst_map = item_wallet_addr_to_global_idx if dst_type == 'wallet' else item_token_addr_to_global_idx

                remapped_edge_list = []
                valid_link_props = []

                # Keep an edge only when both endpoints resolve to real entities.
                for (src_addr, dst_addr), props in zip(edges, link_props_list):
                    src_idx_global = src_map.get(src_addr, self.entity_pad_idx)
                    dst_idx_global = dst_map.get(dst_addr, self.entity_pad_idx)

                    if src_idx_global != self.entity_pad_idx and dst_idx_global != self.entity_pad_idx:
                        remapped_edge_list.append([src_idx_global, dst_idx_global])
                        valid_link_props.append(props)

                if remapped_edge_list:
                    # .t() -> (2, E) layout expected by graph message passing.
                    remapped_edge_tensor = torch.tensor(remapped_edge_list, device=self.device, dtype=torch.long).t()
                    aggregated_links[link_name]['edge_index_list'].append(remapped_edge_tensor)
                    aggregated_links[link_name]['links_list'].extend(valid_link_props)
                if link_name == "TransferLink":
                    # Derive wallet->token edges: connect both transfer parties
                    # to the transferred token (mint), ignoring quote currencies.
                    link_props = data.get('links', [])
                    derived_edges = []
                    derived_props = []
                    for (src_addr, dst_addr), props in zip(edges, link_props):
                        mint_addr = props.get('mint')
                        if not mint_addr or mint_addr in QUOTE_MINTS:
                            continue
                        token_idx_global = item_token_addr_to_global_idx.get(mint_addr, self.entity_pad_idx)
                        if token_idx_global == self.entity_pad_idx:
                            continue
                        for wallet_addr in (src_addr, dst_addr):
                            wallet_idx_global = item_wallet_addr_to_global_idx.get(wallet_addr, self.entity_pad_idx)
                            if wallet_idx_global == self.entity_pad_idx:
                                continue
                            derived_edges.append([wallet_idx_global, token_idx_global])
                            derived_props.append(props)
                    if derived_edges:
                        derived_tensor = torch.tensor(derived_edges, device=self.device, dtype=torch.long).t()
                        aggregated_links["TransferLinkToken"]['edge_index_list'].append(derived_tensor)
                        aggregated_links["TransferLinkToken"]['links_list'].extend(derived_props)
        final_links_dict = {}
        for link_name, data in aggregated_links.items():
            if data['edge_index_list']:
                final_links_dict[link_name] = {
                    'links': data['links_list'],
                    'edge_index': torch.cat(data['edge_index_list'], dim=1)
                }
        return final_links_dict
|
|
| def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: |
| """ |
| Processes a batch of raw data items into tensors for the model. |
| """ |
| |
| |
| |
| |
| |
|
|
| batch_size = len(batch) |
| if batch_size == 0: |
| return {} |
|
|
| |
| batch_wide_pooler = EmbeddingPooler() |
| |
| |
| idx_remap = defaultdict(dict) |
|
|
| for i, item in enumerate(batch): |
| pooler = item.get('embedding_pooler') |
| if not pooler: continue |
|
|
| for pool_item_data in pooler.get_all_items(): |
| original_idx = pool_item_data['idx'] |
| raw_item = pool_item_data['item'] |
| |
| |
| new_batch_idx_1_based = batch_wide_pooler.get_idx(raw_item) |
| new_batch_idx_0_based = new_batch_idx_1_based - 1 |
| idx_remap[i][original_idx] = new_batch_idx_0_based |
|
|
| |
| all_items_sorted = batch_wide_pooler.get_all_items() |
| |
| if not all_items_sorted: |
| |
| |
| batch_embedding_pool = torch.empty(0, 768, device=self.device, dtype=self.dtype) |
| |
| else: |
| first_item = all_items_sorted[0]['item'] |
| if not isinstance(first_item, torch.Tensor): |
| raise RuntimeError(f"Collator expects pre-computed embeddings (torch.Tensor), found {type(first_item)}. Please rebuild cache.") |
| |
| |
| |
| |
| batch_embedding_pool = torch.stack([d['item'] for d in all_items_sorted]).to(device=self.device, dtype=self.dtype) |
| batch_embedding_pool = torch.nan_to_num(batch_embedding_pool, nan=0.0, posinf=0.0, neginf=0.0) |
|
|
| |
| for i, item in enumerate(batch): |
| remap_dict = idx_remap.get(i, {}) |
| if not remap_dict: continue |
|
|
| |
| for token_data in item.get('tokens', {}).values(): |
| for key in ['name_emb_idx', 'symbol_emb_idx', 'image_emb_idx']: |
| if token_data.get(key, 0) > 0: |
| token_data[key] = remap_dict.get(token_data[key], -1) |
| |
| for wallet_data in item.get('wallets', {}).values(): |
| socials = wallet_data.get('socials', {}) |
| if socials.get('username_emb_idx', 0) > 0: |
| socials['username_emb_idx'] = remap_dict.get(socials['username_emb_idx'], -1) |
| |
| for event in item.get('event_sequence', []): |
| for key in event: |
| if key.endswith('_emb_idx') and event.get(key, 0) > 0: |
| event[key] = remap_dict.get(event[key], 0) |
|
|
| |
| unique_wallets_data = {} |
| unique_tokens_data = {} |
| all_event_sequences = [] |
| max_len = 0 |
|
|
| for item in batch: |
| seq = item.get('event_sequence', []) |
| if self.max_seq_len is not None and len(seq) > self.max_seq_len: |
| seq = seq[:self.max_seq_len] |
| all_event_sequences.append(seq) |
| max_len = max(max_len, len(seq)) |
| unique_wallets_data.update(item.get('wallets', {})) |
| unique_tokens_data.update(item.get('tokens', {})) |
|
|
| |
| wallet_items = list(unique_wallets_data.items()) |
| token_items = list(unique_tokens_data.items()) |
|
|
| wallet_list_data = [] |
| for addr, feat in wallet_items: |
| profile = feat.get('profile', {}) |
| if not profile.get('wallet_address'): |
| profile['wallet_address'] = addr |
| wallet_list_data.append(feat) |
|
|
| token_list_data = [] |
| for addr, feat in token_items: |
| if not feat.get('address'): |
| feat['address'] = addr |
| token_list_data.append(feat) |
|
|
| wallet_addr_to_batch_idx = {addr: i + 1 for i, (addr, _) in enumerate(wallet_items)} |
| token_addr_to_batch_idx = {addr: i + 1 for i, (addr, _) in enumerate(token_items)} |
|
|
| |
| token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token") |
| |
| token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token") |
| wallet_encoder_inputs = self._collate_features_for_encoder(wallet_list_data, ['profile'], self.device, "wallet") |
| graph_updater_links = self._collate_graph_links(batch, wallet_addr_to_batch_idx, token_addr_to_batch_idx) |
|
|
| |
| B = batch_size |
| L = max_len |
| PAD_IDX_SEQ = self.pad_token_id |
| PAD_IDX_ENT = self.entity_pad_idx |
|
|
| |
| event_type_ids = torch.full((B, L), PAD_IDX_SEQ, dtype=torch.long, device=self.device) |
| |
| timestamps_float = torch.zeros((B, L), dtype=torch.float64, device=self.device) |
| |
| relative_ts = torch.zeros((B, L, 1), dtype=torch.float32, device=self.device) |
| attention_mask = torch.zeros((B, L), dtype=torch.long, device=self.device) |
| wallet_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) |
| token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) |
| ohlc_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) |
| quote_token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) |
| |
| |
| dest_wallet_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) |
| |
| original_author_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) |
| |
| transfer_numerical_features = torch.zeros((B, L, 4), dtype=self.dtype, device=self.device) |
|
|
| |
| |
| trade_numerical_features = torch.zeros((B, L, 8), dtype=self.dtype, device=self.device) |
| deployer_trade_numerical_features = torch.zeros((B, L, 8), dtype=self.dtype, device=self.device) |
| smart_wallet_trade_numerical_features = torch.zeros((B, L, 8), dtype=self.dtype, device=self.device) |
| |
| trade_dex_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
| |
| trade_direction_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
| |
| trade_mev_protection_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
| |
| trade_is_bundle_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
|
|
| |
| |
| pool_created_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) |
| |
| pool_created_protocol_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
|
|
| |
| |
| liquidity_change_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) |
| |
| liquidity_change_type_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
|
|
| |
| fee_collected_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) |
| |
| token_burn_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) |
|
|
| |
| supply_lock_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) |
|
|
| |
| onchain_snapshot_numerical_features = torch.zeros((B, L, 14), dtype=self.dtype, device=self.device) |
|
|
| |
| trending_token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) |
| |
| trending_token_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) |
| trending_token_source_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
| trending_token_timeframe_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
|
|
| |
| boosted_token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) |
| boosted_token_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) |
|
|
| |
| dexboost_paid_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) |
|
|
| |
| dexprofile_updated_flags = torch.zeros((B, L, 4), dtype=torch.float32, device=self.device) |
|
|
| |
| alpha_group_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
| channel_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
| exchange_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
|
|
| |
| global_trending_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) |
|
|
| |
| chainsnapshot_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) |
|
|
| |
| |
| lighthousesnapshot_numerical_features = torch.zeros((B, L, 5), dtype=self.dtype, device=self.device) |
| lighthousesnapshot_protocol_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
| lighthousesnapshot_timeframe_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
|
|
| |
| migrated_protocol_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device) |
|
|
| |
| |
| holder_snapshot_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) |
| holder_snapshot_raw_data_list = [] |
|
|
| |
| textual_event_data_list = [] |
| textual_event_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) |
| |
| image_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) |
| original_post_image_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) |
|
|
|
|
|
|
| |
| batch_chart_events = [] |
| chart_event_counter = 0 |
|
|
| |
| for i, seq in enumerate(all_event_sequences): |
| |
| seq_len = len(seq) |
| if seq_len == 0: continue |
| attention_mask[i, :seq_len] = 1 |
|
|
| for j, event in enumerate(seq): |
| |
| event_type = event.get('event_type', '__PAD__') |
| type_id = self.event_type_to_id.get(event_type, PAD_IDX_SEQ) |
| event_type_ids[i, j] = type_id |
| timestamps_float[i, j] = event.get('timestamp', 0) |
| relative_ts[i, j, 0] = event.get('relative_ts', 0.0) |
|
|
| |
| w_addr = event.get('wallet_address') |
| if w_addr: |
| wallet_indices[i, j] = wallet_addr_to_batch_idx.get(w_addr, PAD_IDX_ENT) |
| t_addr = event.get('token_address') |
| if t_addr: |
| token_indices[i, j] = token_addr_to_batch_idx.get(t_addr, PAD_IDX_ENT) |
|
|
| |
| if event_type == 'Chart_Segment': |
| batch_chart_events.append(event) |
| ohlc_indices[i, j] = chart_event_counter + 1 |
| chart_event_counter += 1 |
| |
| elif event_type in ['Transfer', 'LargeTransfer']: |
| |
| dest_w_addr = event.get('destination_wallet_address') |
| if dest_w_addr: |
| dest_wallet_indices[i, j] = wallet_addr_to_batch_idx.get(dest_w_addr, PAD_IDX_ENT) |
| |
| |
| num_feats = [ |
| event.get('token_amount', 0.0), |
| event.get('transfer_pct_of_total_supply', 0.0), |
| event.get('transfer_pct_of_holding', 0.0), |
| event.get('priority_fee', 0.0) |
| ] |
| transfer_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
| |
| elif event_type in ['Trade', 'LargeTrade']: |
| |
| trade_dex_ids[i, j] = event.get('dex_platform_id', 0) |
| trade_direction_ids[i, j] = event.get('trade_direction', 0) |
| trade_mev_protection_ids[i, j] = event.get('mev_protection', 0) |
| trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0 |
| |
| num_feats = [ |
| event.get('sol_amount', 0.0), |
| event.get('priority_fee', 0.0), |
| event.get('token_amount_pct_of_holding', 0.0), |
| event.get('quote_amount_pct_of_holding', 0.0), |
| event.get('slippage', 0.0), |
| event.get('token_amount_pct_to_total_supply', 0.0), |
| 1.0 if event.get('success') else 0.0, |
| event.get('total_usd', 0.0) |
| ] |
| trade_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
| |
| elif event_type == 'Deployer_Trade': |
| |
| trade_dex_ids[i, j] = event.get('dex_platform_id', 0) |
| trade_direction_ids[i, j] = event.get('trade_direction', 0) |
| trade_mev_protection_ids[i, j] = event.get('mev_protection', 0) |
| trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0 |
| num_feats = [ |
| event.get('sol_amount', 0.0), |
| event.get('priority_fee', 0.0), |
| event.get('token_amount_pct_of_holding', 0.0), |
| event.get('quote_amount_pct_of_holding', 0.0), |
| event.get('slippage', 0.0), |
| event.get('token_amount_pct_to_total_supply', 0.0), |
| 1.0 if event.get('success') else 0.0, |
| event.get('total_usd', 0.0) |
| ] |
| deployer_trade_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
|
|
| elif event_type == 'SmartWallet_Trade': |
| |
| trade_dex_ids[i, j] = event.get('dex_platform_id', 0) |
| trade_direction_ids[i, j] = event.get('trade_direction', 0) |
| trade_mev_protection_ids[i, j] = event.get('mev_protection', 0) |
| trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0 |
| num_feats = [ |
| event.get('sol_amount', 0.0), |
| event.get('priority_fee', 0.0), |
| event.get('token_amount_pct_of_holding', 0.0), |
| event.get('quote_amount_pct_of_holding', 0.0), |
| event.get('slippage', 0.0), |
| event.get('token_amount_pct_to_total_supply', 0.0), |
| 1.0 if event.get('success') else 0.0, |
| event.get('total_usd', 0.0) |
| ] |
| smart_wallet_trade_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
|
|
| elif event_type == 'PoolCreated': |
| |
| quote_t_addr = event.get('quote_token_address') |
| if quote_t_addr: |
| quote_token_indices[i, j] = token_addr_to_batch_idx.get(quote_t_addr, PAD_IDX_ENT) |
| |
| pool_created_protocol_ids[i, j] = event.get('protocol_id', 0) |
| |
| num_feats = [ |
| event.get('base_amount', 0.0), |
| event.get('quote_amount', 0.0) |
| ] |
| pool_created_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
|
|
| elif event_type == 'LiquidityChange': |
| |
| quote_t_addr = event.get('quote_token_address') |
| if quote_t_addr: |
| quote_token_indices[i, j] = token_addr_to_batch_idx.get(quote_t_addr, PAD_IDX_ENT) |
| |
| liquidity_change_type_ids[i, j] = event.get('change_type_id', 0) |
| |
| num_feats = [event.get('quote_amount', 0.0)] |
| liquidity_change_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
|
|
| elif event_type == 'FeeCollected': |
| |
| num_feats = [ |
| event.get('sol_amount', 0.0) |
| ] |
| fee_collected_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
|
|
| elif event_type == 'TokenBurn': |
| |
| num_feats = [ |
| event.get('amount_pct_of_total_supply', 0.0), |
| event.get('amount_tokens_burned', 0.0) |
| ] |
| token_burn_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
|
|
| elif event_type == 'SupplyLock': |
| |
| num_feats = [ |
| event.get('amount_pct_of_total_supply', 0.0), |
| event.get('lock_duration', 0.0) |
| ] |
| supply_lock_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
|
|
| elif event_type == 'OnChain_Snapshot': |
| |
| num_feats = [ |
| event.get('total_holders', 0.0), |
| event.get('smart_traders', 0.0), |
| event.get('kols', 0.0), |
| event.get('holder_growth_rate', 0.0), |
| event.get('top_10_holder_pct', 0.0), |
| event.get('sniper_holding_pct', 0.0), |
| event.get('rat_wallets_holding_pct', 0.0), |
| event.get('bundle_holding_pct', 0.0), |
| event.get('current_market_cap', 0.0), |
| event.get('volume', 0.0), |
| event.get('buy_count', 0.0), |
| event.get('sell_count', 0.0), |
| event.get('total_txns', 0.0), |
| event.get('global_fees_paid', 0.0) |
| ] |
| onchain_snapshot_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
|
|
| elif event_type == 'TrendingToken': |
| |
| trending_t_addr = event.get('token_address') |
| if trending_t_addr: |
| trending_token_indices[i, j] = token_addr_to_batch_idx.get(trending_t_addr, PAD_IDX_ENT) |
| |
| trending_token_source_ids[i, j] = event.get('list_source_id', 0) |
| trending_token_timeframe_ids[i, j] = event.get('timeframe_id', 0) |
| |
| |
| num_feats = [ |
| 1.0 / event.get('rank', 1e9) |
| ] |
| trending_token_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
|
|
| elif event_type == 'BoostedToken': |
| |
| boosted_t_addr = event.get('token_address') |
| if boosted_t_addr: |
| boosted_token_indices[i, j] = token_addr_to_batch_idx.get(boosted_t_addr, PAD_IDX_ENT) |
| |
| |
| |
| num_feats = [ |
| event.get('total_boost_amount', 0.0), |
| 1.0 / event.get('rank', 1e9) |
| ] |
| boosted_token_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
|
|
| elif event_type == 'Migrated': |
| migrated_protocol_ids[i, j] = event.get('protocol_id', 0) |
|
|
| elif event_type == 'HolderSnapshot': |
| |
| raw_holders = event.get('holders', []) |
| holder_snapshot_raw_data_list.append(raw_holders) |
| holder_snapshot_indices[i, j] = len(holder_snapshot_raw_data_list) |
| |
| elif event_type == 'Lighthouse_Snapshot': |
| lighthousesnapshot_protocol_ids[i, j] = event.get('protocol_id', 0) |
| lighthousesnapshot_timeframe_ids[i, j] = event.get('timeframe_id', 0) |
| num_feats = [ |
| event.get('total_volume', 0.0), |
| event.get('total_transactions', 0.0), |
| event.get('total_traders', 0.0), |
| event.get('total_tokens_created', 0.0), |
| event.get('total_migrations', 0.0) |
| ] |
| lighthousesnapshot_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype) |
|
|
|
|
| |
| elif event_type in ['XPost', 'XReply', 'XRetweet', 'XQuoteTweet', 'PumpReply', 'DexProfile_Updated', 'TikTok_Trending_Hashtag', 'XTrending_Hashtag']: |
| |
| |
| textual_event_data_list.append(event) |
| textual_event_indices[i, j] = len(textual_event_data_list) |
| |
| if event_type in ['TikTok_Trending_Hashtag', 'XTrending_Hashtag']: |
| global_trending_numerical_features[i, j, 0] = 1.0 / event.get('rank', 1e9) |
|
|
| |
| |
| |
| if event_type == 'XRetweet' or event_type == 'XQuoteTweet': |
| orig_author_addr = event.get('original_author_wallet_address') |
| if orig_author_addr: |
| |
| original_author_indices[i, j] = wallet_addr_to_batch_idx.get(orig_author_addr, PAD_IDX_ENT) |
| |
| |
| |
| |
| |
|
|
| |
| ohlc_inputs_dict = self._collate_ohlc_inputs(batch_chart_events) |
| |
| |
| collated_batch = { |
| |
| 'event_type_ids': event_type_ids, |
| 'timestamps_float': timestamps_float, |
| 'relative_ts': relative_ts, |
| 'attention_mask': attention_mask, |
| |
| 'wallet_indices': wallet_indices, |
| 'token_indices': token_indices, |
| 'quote_token_indices': quote_token_indices, |
| 'trending_token_indices': trending_token_indices, |
| 'boosted_token_indices': boosted_token_indices, |
| 'holder_snapshot_indices': holder_snapshot_indices, |
| 'textual_event_indices': textual_event_indices, |
| 'ohlc_indices': ohlc_indices, |
| |
| 'embedding_pool': batch_embedding_pool, |
| 'token_encoder_inputs': token_encoder_inputs, |
| 'wallet_encoder_inputs': wallet_encoder_inputs, |
| 'ohlc_price_tensors': ohlc_inputs_dict['price_tensor'], |
| 'ohlc_interval_ids': ohlc_inputs_dict['interval_ids'], |
| 'quant_ohlc_feature_tensors': ohlc_inputs_dict['quant_feature_tensors'], |
| 'quant_ohlc_feature_mask': ohlc_inputs_dict['quant_feature_mask'], |
| 'quant_ohlc_feature_version_ids': ohlc_inputs_dict['quant_feature_version_ids'], |
| 'graph_updater_links': graph_updater_links, |
| 'wallet_addr_to_batch_idx': wallet_addr_to_batch_idx, |
| |
| 'dest_wallet_indices': dest_wallet_indices, |
| 'original_author_indices': original_author_indices, |
| |
| 'transfer_numerical_features': transfer_numerical_features, |
| 'trade_numerical_features': trade_numerical_features, |
| 'trade_dex_ids': trade_dex_ids, |
| 'deployer_trade_numerical_features': deployer_trade_numerical_features, |
| 'trade_direction_ids': trade_direction_ids, |
| 'trade_mev_protection_ids': trade_mev_protection_ids, |
| 'smart_wallet_trade_numerical_features': smart_wallet_trade_numerical_features, |
| 'trade_is_bundle_ids': trade_is_bundle_ids, |
| 'pool_created_numerical_features': pool_created_numerical_features, |
| 'pool_created_protocol_ids': pool_created_protocol_ids, |
| 'liquidity_change_numerical_features': liquidity_change_numerical_features, |
| 'liquidity_change_type_ids': liquidity_change_type_ids, |
| 'fee_collected_numerical_features': fee_collected_numerical_features, |
| 'token_burn_numerical_features': token_burn_numerical_features, |
| 'supply_lock_numerical_features': supply_lock_numerical_features, |
| 'onchain_snapshot_numerical_features': onchain_snapshot_numerical_features, |
| 'boosted_token_numerical_features': boosted_token_numerical_features, |
| 'trending_token_numerical_features': trending_token_numerical_features, |
| 'trending_token_source_ids': trending_token_source_ids, |
| 'trending_token_timeframe_ids': trending_token_timeframe_ids, |
| 'dexboost_paid_numerical_features': dexboost_paid_numerical_features, |
| 'dexprofile_updated_flags': dexprofile_updated_flags, |
| 'global_trending_numerical_features': global_trending_numerical_features, |
| 'chainsnapshot_numerical_features': chainsnapshot_numerical_features, |
| 'lighthousesnapshot_numerical_features': lighthousesnapshot_numerical_features, |
| 'lighthousesnapshot_protocol_ids': lighthousesnapshot_protocol_ids, |
| 'lighthousesnapshot_timeframe_ids': lighthousesnapshot_timeframe_ids, |
| 'migrated_protocol_ids': migrated_protocol_ids, |
| 'alpha_group_ids': alpha_group_ids, |
| 'channel_ids': channel_ids, |
| 'exchange_ids': exchange_ids, |
| 'holder_snapshot_raw_data': holder_snapshot_raw_data_list, |
| 'textual_event_data': textual_event_data_list, |
| |
| 'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None, |
| 'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None, |
| 'movement_class_targets': torch.stack([item['movement_class_targets'] for item in batch]) if batch and 'movement_class_targets' in batch[0] else None, |
| 'movement_class_mask': torch.stack([item['movement_class_mask'] for item in batch]) if batch and 'movement_class_mask' in batch[0] else None, |
| 'quality_score': torch.stack([item['quality_score'] if isinstance(item['quality_score'], torch.Tensor) else torch.tensor(item['quality_score'], dtype=torch.float32) for item in batch]) if batch and 'quality_score' in batch[0] else None, |
| 'class_id': torch.tensor([item.get('class_id', 0) for item in batch], dtype=torch.long), |
| |
| 'token_addresses': [item.get('token_address', 'unknown') for item in batch], |
| 't_cutoffs': [item.get('t_cutoff', 'unknown') for item in batch], |
| 'sample_indices': [item.get('sample_idx', -1) for item in batch] |
| } |
|
|
| if collated_batch['quality_score'] is None: |
| raise RuntimeError("FATAL: Missing quality_score in batch items. Rebuild cache with quality_score enabled.") |
|
|
| |
| return {k: v for k, v in collated_batch.items() if v is not None} |
|
|