| |
|
|
| import torch |
| import traceback |
| import time |
|
|
| |
| from models.model import Oracle |
| from data.data_collator import MemecoinCollator |
| from models.multi_modal_processor import MultiModalEncoder |
| from data.data_loader import OracleDataset |
| from data.data_fetcher import DataFetcher |
| from models.helper_encoders import ContextualTimeEncoder |
| from models.token_encoder import TokenEncoder |
| from models.wallet_encoder import WalletEncoder |
| from models.graph_updater import GraphUpdater |
| from models.ohlc_embedder import OHLCEmbedder |
| from models.quant_ohlc_embedder import QuantOHLCEmbedder |
| import models.vocabulary as vocab |
| from data.quant_ohlc_feature_schema import NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT |
|
|
| |
| from clickhouse_driver import Client as ClickHouseClient |
| from neo4j import GraphDatabase |
|
|
| |
| |
| |
| |
def _select_device_and_dtype():
    """Pick the compute device and a matching floating-point dtype.

    Prefers CUDA with bfloat16 (falling back to float16 when bf16 kernels are
    unsupported); forces float32 on CPU where half precision is impractical.

    Returns:
        (torch.device, torch.dtype) tuple.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        dtype = torch.float32
    return device, dtype


def _build_encoders(dtype):
    """Instantiate every encoder module the Oracle pipeline needs, with defaults.

    Exits with a non-zero status if any constructor fails, since nothing
    downstream can run without them.

    Returns:
        Tuple of (multi_modal_encoder, time_enc, token_enc, wallet_enc,
        graph_upd, ohlc_emb, quant_ohlc_emb).
    """
    print("Instantiating encoders (using defaults)...")
    try:
        multi_modal_encoder = MultiModalEncoder(dtype=dtype)
        real_time_enc = ContextualTimeEncoder(dtype=dtype)
        # TokenEncoder consumes the multi-modal embedding, so its input dim
        # must track the encoder's embedding_dim.
        real_token_enc = TokenEncoder(
            multi_dim=multi_modal_encoder.embedding_dim,
            dtype=dtype,
        )
        real_wallet_enc = WalletEncoder(encoder=multi_modal_encoder, dtype=dtype)
        real_graph_upd = GraphUpdater(time_encoder=real_time_enc, dtype=dtype)
        real_ohlc_emb = OHLCEmbedder(
            num_intervals=vocab.NUM_OHLC_INTERVALS,
            dtype=dtype,
        )
        real_quant_ohlc_emb = QuantOHLCEmbedder(
            num_features=NUM_QUANT_OHLC_FEATURES,
            sequence_length=TOKENS_PER_SEGMENT,
            dtype=dtype,
        )
        print(f"TokenEncoder default output_dim: {real_token_enc.output_dim}")
        print(f"WalletEncoder default d_model: {real_wallet_enc.d_model}")
        print(f"OHLCEmbedder default output_dim: {real_ohlc_emb.output_dim}")
        print("Encoders instantiated.")
    except Exception as e:
        print(f"Failed to instantiate encoders: {e}")
        traceback.print_exc()
        # Was bare exit() (status 0) — exit non-zero so callers see the failure.
        raise SystemExit(1)
    return (
        multi_modal_encoder,
        real_time_enc,
        real_token_enc,
        real_wallet_enc,
        real_graph_upd,
        real_ohlc_emb,
        real_quant_ohlc_emb,
    )


def _build_dataset(horizons_seconds, quantiles):
    """Connect to ClickHouse + Neo4j and build an OracleDataset.

    Args:
        horizons_seconds: prediction horizons forwarded to the dataset.
        quantiles: quantile levels forwarded to the dataset.

    Returns:
        An OracleDataset capped at 57 samples.

    Exits non-zero when the databases are unreachable or contain no mints.
    """
    try:
        print("Connecting to databases...")
        # Local dev defaults; no auth on either store — TODO confirm for prod.
        clickhouse_client = ClickHouseClient(host='localhost', port=9000)
        neo4j_driver = GraphDatabase.driver("bolt://localhost:7687", auth=None)
        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
        print("Database clients and DataFetcher initialized.")

        all_mints = data_fetcher.get_all_mints()
        if not all_mints:
            print("\n❌ No mints found in the database. Exiting test.")
            raise SystemExit(1)

        return OracleDataset(
            data_fetcher=data_fetcher,
            horizons_seconds=horizons_seconds,
            quantiles=quantiles,
            max_samples=57,
        )
    except Exception as e:
        # SystemExit is a BaseException and passes through this handler.
        print(f"FATAL: Could not initialize database connections or dataset: {e}")
        traceback.print_exc()
        raise SystemExit(1)


def _collect_batch(dataset):
    """Fetch every dataset sample, keeping the ones that load successfully.

    Returns:
        Non-empty list of samples; exits non-zero if none could be fetched.
    """
    print(f"\n--- Processing a batch of up to {len(dataset)} items from the dataset ---")
    batch_items = []
    for i in range(len(dataset)):
        token_addr = dataset.sampled_mints[i].get('mint_address', 'unknown')
        print(f" - Attempting to process sample {i+1}/{len(dataset)} ({token_addr})...")
        fetch_start = time.time()
        sample = dataset[i]
        fetch_elapsed = time.time() - fetch_start
        print(f" ... fetch completed in {fetch_elapsed:.2f}s")
        if sample is not None:
            batch_items.append(sample)
            print(f" ... Success! Sample added to batch.")
    if not batch_items:
        print("\n❌ No valid samples could be generated from the dataset. Exiting.")
        raise SystemExit(1)
    return batch_items


# Per-event numerical feature tensors produced by the collator, mapped to the
# trailing feature width expected for each (B, L, width) tensor.
_EVENT_FEATURE_WIDTHS = {
    'transfer_numerical_features': 4,
    'trade_numerical_features': 8,
    'deployer_trade_numerical_features': 8,
    'smart_wallet_trade_numerical_features': 8,
    'pool_created_numerical_features': 2,
    'liquidity_change_numerical_features': 1,
    'fee_collected_numerical_features': 1,
    'token_burn_numerical_features': 2,
    'supply_lock_numerical_features': 2,
    'onchain_snapshot_numerical_features': 14,
    'trending_token_numerical_features': 1,
    'boosted_token_numerical_features': 2,
    'dexboost_paid_numerical_features': 2,
}


def _run_pipeline_test(model, collator, batch_items, dtype, quantiles, horizons, ohlc_seq_len):
    """Collate the batch, sanity-check every tensor, and run a forward pass.

    Args:
        model: the Oracle model (already on device, eval mode).
        collator: MemecoinCollator used to batch the raw samples.
        batch_items: non-empty list of dataset samples.
        dtype: expected floating dtype of collated tensors.
        quantiles / horizons: grids used to interpret the quantile logits.
        ohlc_seq_len: expected OHLC sequence length from the collator.
    """
    print("\n--- Testing Pipeline (Collator + Model.forward) ---")
    try:
        collate_start = time.time()
        collated_batch = collator(batch_items)
        collate_elapsed = time.time() - collate_start
        print("Collation successful!")
        print(f"Collation time for batch of {len(batch_items)} tokens: {collate_elapsed:.2f}s")

        B = len(batch_items)
        L = collated_batch['attention_mask'].shape[1]

        # --- OHLC segment checks ---
        assert 'ohlc_price_tensors' in collated_batch
        ohlc_price_tensors = collated_batch['ohlc_price_tensors']
        assert ohlc_price_tensors.dim() == 3, f"Expected 3D OHLC tensor, got shape {tuple(ohlc_price_tensors.shape)}"
        assert ohlc_price_tensors.shape[1] == 2, f"Expected OHLC tensor with 2 rows (open/close), got {ohlc_price_tensors.shape[1]}"
        assert ohlc_price_tensors.shape[2] == ohlc_seq_len, f"Expected OHLC seq len {ohlc_seq_len}, got {ohlc_price_tensors.shape[2]}"
        assert collated_batch['ohlc_interval_ids'].shape[0] == ohlc_price_tensors.shape[0], "Interval ids must align with OHLC segments"
        assert ohlc_price_tensors.dtype == dtype, f"OHLC tensor dtype {ohlc_price_tensors.dtype} != expected {dtype}"
        print(f"Collator produced {ohlc_price_tensors.shape[0]} OHLC segment(s).")

        # --- Event-specific numerical feature checks (data-driven) ---
        assert collated_batch['dest_wallet_indices'].shape == (B, L)
        for name, width in _EVENT_FEATURE_WIDTHS.items():
            assert collated_batch[name].shape == (B, L, width), name
        print("Collator correctly processed all event-specific numerical data into their respective tensors.")

        # --- Debug output ---
        print("\n--- Collated Batch Debug Output ---")
        print(f"Batch Size: {B}, Max Sequence Length: {L}")

        print("\n[Core Tensors]")
        print(f" event_type_ids: {collated_batch['event_type_ids'].shape}")
        print(f" attention_mask: {collated_batch['attention_mask'].shape}")
        print(f" timestamps_float: {collated_batch['timestamps_float'].shape}")

        print("\n[Pointer Tensors]")
        print(f" wallet_indices: {collated_batch['wallet_indices'].shape}")
        print(f" token_indices: {collated_batch['token_indices'].shape}")

        print("\n[Encoder Inputs]")
        print(f" embedding_pool: {collated_batch['embedding_pool'].shape}")

        if collated_batch['token_encoder_inputs']['name_embed_indices'].numel() > 0:
            print(f" token_encoder_inputs contains {collated_batch['token_encoder_inputs']['name_embed_indices'].shape[0]} tokens.")
        else:
            print(" token_encoder_inputs is empty.")
        if collated_batch['wallet_encoder_inputs']['profile_rows']:
            print(f" wallet_encoder_inputs contains {len(collated_batch['wallet_encoder_inputs']['profile_rows'])} wallets.")
        else:
            print(" wallet_encoder_inputs is empty.")

        print("\n[Graph Links]")
        if collated_batch['graph_updater_links']:
            for link_name, data in collated_batch['graph_updater_links'].items():
                print(f" - {link_name}: {data['edge_index'].shape[1]} edges")
        else:
            print(" No graph links in this batch.")
        print("--- End Debug Output ---\n")

        print("Embedding pool size:", collated_batch["embedding_pool"].shape[0])
        print("Max name_emb_idx:", collated_batch["token_encoder_inputs"]["name_embed_indices"].max().item())

        # --- Forward pass ---
        with torch.no_grad():
            model_outputs = model(collated_batch)
        quantile_logits = model_outputs["quantile_logits"]
        hidden_states = model_outputs["hidden_states"]
        attention_mask = model_outputs["attention_mask"]
        pooled_states = model_outputs["pooled_states"]
        print("Model forward pass successful!")

        # --- Result checks ---
        print("\n--- Test Results ---")
        D_MODEL = model.d_model

        print(f"Final hidden_states shape: {hidden_states.shape}")
        print(f"Final attention_mask shape: {attention_mask.shape}")

        assert hidden_states.shape == (B, L, D_MODEL)
        assert attention_mask.shape == (B, L)
        assert hidden_states.dtype == dtype

        print(f"Output mean (sanity check): {hidden_states.mean().item()}")
        print(f"Pooled state shape: {pooled_states.shape}")
        print(f"Quantile logits shape: {quantile_logits.shape}")

        # Logits are laid out horizon-major: (B, horizons * quantiles).
        quantile_grid = quantile_logits.view(B, len(horizons), len(quantiles))
        print("\n[Quantile Predictions]")
        for b_idx in range(B):
            print(f" Sample {b_idx}:")
            for h_idx, horizon in enumerate(horizons):
                row = quantile_grid[b_idx, h_idx]
                print(f" Horizon {horizon}s -> " + ", ".join(
                    f"q={q:.2f}: {row[q_idx].item():.6f}"
                    for q_idx, q in enumerate(quantiles)
                ))

        print("\n✅ **Test Passed!** Full ENCODING pipeline is working.")

    except Exception as e:
        # Best-effort test harness: report and fall through (original behavior).
        print(f"\n❌ Error during pipeline test: {e}")
        traceback.print_exc()


def main():
    """End-to-end smoke test: build encoders + Oracle, fetch real data, collate, forward."""
    print("--- Oracle Inference Script (Full Pipeline Test) ---")

    OHLC_SEQ_LEN = 300
    print(f"Using {vocab.NUM_EVENT_TYPES} event types from vocabulary.")

    device, dtype = _select_device_and_dtype()
    print(f"Using device: {device}, dtype: {dtype}")

    # Quantile/horizon grids shared by the model and the dataset labels.
    _test_quantiles = [0.1, 0.5, 0.9]
    _test_horizons = [30, 60, 120, 240, 420]

    (
        multi_modal_encoder,
        real_time_enc,
        real_token_enc,
        real_wallet_enc,
        real_graph_upd,
        real_ohlc_emb,
        real_quant_ohlc_emb,
    ) = _build_encoders(dtype)

    collator = MemecoinCollator(
        event_type_to_id=vocab.EVENT_TO_ID,
        device=device,
        multi_modal_encoder=multi_modal_encoder,
        dtype=dtype,
        ohlc_seq_len=OHLC_SEQ_LEN,
        max_seq_len=50,
    )
    print("MemecoinCollator (fast batcher) instantiated.")

    print("Instantiating Oracle (full pipeline)...")
    model = Oracle(
        token_encoder=real_token_enc,
        wallet_encoder=real_wallet_enc,
        graph_updater=real_graph_upd,
        time_encoder=real_time_enc,
        multi_modal_dim=multi_modal_encoder.embedding_dim,
        num_event_types=vocab.NUM_EVENT_TYPES,
        event_pad_id=vocab.EVENT_TO_ID['__PAD__'],
        event_type_to_id=vocab.EVENT_TO_ID,
        model_config_name="Qwen/Qwen3-0.6B",
        quantiles=_test_quantiles,
        horizons_seconds=_test_horizons,
        dtype=dtype,
        ohlc_embedder=real_ohlc_emb,
        quant_ohlc_embedder=real_quant_ohlc_emb,
    ).to(device)
    model.eval()
    print(f"Oracle d_model: {model.d_model}")

    print("Creating Dataset...")
    dataset = _build_dataset(_test_horizons, _test_quantiles)
    batch_items = _collect_batch(dataset)
    _run_pipeline_test(model, collator, batch_items, dtype, _test_quantiles, _test_horizons, OHLC_SEQ_LEN)


if __name__ == "__main__":
    main()
|