| import os |
| import sys |
| import argparse |
| import random |
| import copy |
| import math |
| import torch |
| from pathlib import Path |
|
|
| |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| |
| from accelerate import Accelerator |
| from torch.utils.data import DataLoader, Subset |
|
|
| from data.data_loader import OracleDataset |
| from data.data_collator import MemecoinCollator |
| from data.context_targets import MOVEMENT_ID_TO_CLASS |
| from models.multi_modal_processor import MultiModalEncoder |
| from models.helper_encoders import ContextualTimeEncoder |
| from models.token_encoder import TokenEncoder |
| from models.wallet_encoder import WalletEncoder |
| from models.graph_updater import GraphUpdater |
| from models.ohlc_embedder import OHLCEmbedder |
| from models.quant_ohlc_embedder import QuantOHLCEmbedder |
| from models.model import Oracle |
| import models.vocabulary as vocab |
| from data.quant_ohlc_feature_schema import FEATURE_GROUPS, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT, group_feature_indices |
| from train import create_balanced_split |
| from dotenv import load_dotenv |
| from clickhouse_driver import Client as ClickHouseClient |
| from neo4j import GraphDatabase |
| from data.data_fetcher import DataFetcher |
| from scripts.analyze_distribution import get_return_class_map |
|
|
# Signal families evaluated one-at-a-time when --ablation=sweep; every entry
# is a mode handled by apply_ablation() (the quant_* entries map to feature
# groups via its quant_group_map).
ABLATION_SWEEP_MODES = [
    "wallet",
    "graph",
    "social",
    "token",
    "holder",
    "ohlc",
    "ohlc_wallet",
    "trade",
    "onchain",
    "wallet_graph",
    "quant_ohlc",
    "quant_levels",
    "quant_trendline",
    "quant_breaks",
    "quant_rolling",
]
|
|
# Structured perturbations of the OHLC price series evaluated when
# --ablation=ohlc_probe; each mode is implemented in apply_ohlc_probe().
OHLC_PROBE_MODES = [
    "ohlc_reverse",
    "ohlc_shuffle_chunks",
    "ohlc_mask_recent",
    "ohlc_trend_only",
    "ohlc_summary_shuffle",
    "ohlc_detrend",
    "ohlc_smooth",
]
|
|
def unlog_transform(tensor):
    """Invert the signed log1p transform applied during training.

    The forward transform is sign(x) * log1p(|x|); the inverse is
    sign(y) * expm1(|y|). Uses torch.expm1 instead of exp(...) - 1 for
    better numerical accuracy near zero.
    """
    return torch.sign(tensor) * torch.expm1(torch.abs(tensor))
|
|
def parse_args():
    """Build and parse the command-line options for this evaluation script."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--checkpoint", type=str, default="checkpoints/checkpoint-90000", help="Path to checkpoint dir")
    add("--sample_idx", type=str, default=None, help="Specific sample index or Mint Address to evaluate")
    add("--mixed_precision", type=str, default="bf16")
    add("--horizons_seconds", type=int, nargs="+", default=[300, 900, 1800, 3600, 7200])
    add("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
    add("--seed", type=int, default=None)
    add("--min_class", type=int, default=3, help="Filter out tokens with return class beneath this ID (e.g., 1 for >= 3x returns)")
    add("--cutoff_trade_idx", type=int, default=200, help="Force the T_cutoff at this exact trade index (e.g., 10 = right after the 10th trade)")
    add("--num_samples", type=int, default=1, help="Number of valid samples to evaluate and aggregate.")
    add("--max_retries", type=int, default=100, help="Maximum attempts to find valid contexts across samples.")
    add("--show_each", action="store_true", help="Print per-sample details for every evaluated sample.")
    add(
        "--ablation",
        type=str,
        default="none",
        choices=["none", "wallet", "graph", "wallet_graph", "social", "token", "holder", "ohlc", "ohlc_wallet", "trade", "onchain", "all", "sweep", "ohlc_probe", "quant_ohlc", "quant_levels", "quant_trendline", "quant_breaks", "quant_rolling"],
        help="Run inference with selected signal families removed, or use 'sweep' to rank multiple families.",
    )
    return parser.parse_args()
|
|
|
|
def clone_batch(batch):
    """Copy a collated batch dict: tensors are cloned, all other values deep-copied."""
    return {
        key: value.clone() if isinstance(value, torch.Tensor) else copy.deepcopy(value)
        for key, value in batch.items()
    }
|
|
|
|
| def _empty_wallet_encoder_inputs(device): |
| return { |
| 'username_embed_indices': torch.tensor([], device=device, dtype=torch.long), |
| 'profile_rows': [], |
| 'social_rows': [], |
| 'holdings_batch': [], |
| } |
|
|
|
|
| def _empty_token_encoder_inputs(device): |
| return { |
| 'name_embed_indices': torch.tensor([], device=device, dtype=torch.long), |
| 'symbol_embed_indices': torch.tensor([], device=device, dtype=torch.long), |
| 'image_embed_indices': torch.tensor([], device=device, dtype=torch.long), |
| 'protocol_ids': torch.tensor([], device=device, dtype=torch.long), |
| 'is_vanity_flags': torch.tensor([], device=device, dtype=torch.bool), |
| '_addresses_for_lookup': [], |
| } |
|
|
|
|
def apply_ablation(batch, mode, device):
    """Return a copy of *batch* with the signal family selected by *mode* removed.

    mode == "none" returns *batch* itself, untouched. Every other mode clones
    the batch first (the caller's tensors are preserved) and then zeroes the
    tensors / clears the auxiliary structures belonging to that family.
    "all" removes every family at once.
    """
    if mode == "none":
        return batch

    ablated = clone_batch(batch)

    # Wallet identity: clear wallet index tensors and all encoder side inputs.
    if mode in {"wallet", "wallet_graph", "ohlc_wallet", "all"}:
        for key in (
            "wallet_indices",
            "dest_wallet_indices",
            "original_author_indices",
            "holder_snapshot_indices",
        ):
            if key in ablated:
                ablated[key].zero_()
        ablated["wallet_encoder_inputs"] = _empty_wallet_encoder_inputs(device)
        ablated["wallet_addr_to_batch_idx"] = {}
        ablated["holder_snapshot_raw_data"] = []
        ablated["graph_updater_links"] = {}

    # Wallet graph structure only.
    if mode in {"graph", "wallet_graph", "all"}:
        ablated["graph_updater_links"] = {}

    # Social / textual events.
    if mode in {"social", "all"}:
        if "textual_event_indices" in ablated:
            ablated["textual_event_indices"].zero_()
        ablated["textual_event_data"] = []

    # Token identity metadata.
    if mode in {"token", "all"}:
        for key in (
            "token_indices",
            "quote_token_indices",
            "trending_token_indices",
            "boosted_token_indices",
        ):
            if key in ablated:
                ablated[key].zero_()
        ablated["token_encoder_inputs"] = _empty_token_encoder_inputs(device)

    # Holder snapshots.
    if mode in {"holder", "all"}:
        if "holder_snapshot_indices" in ablated:
            ablated["holder_snapshot_indices"].zero_()
        ablated["holder_snapshot_raw_data"] = []

    # Raw OHLC candles plus the derived quantitative features.
    if mode in {"ohlc", "ohlc_wallet", "all"}:
        if "ohlc_indices" in ablated:
            ablated["ohlc_indices"].zero_()
        if "ohlc_price_tensors" in ablated:
            ablated["ohlc_price_tensors"] = torch.zeros_like(ablated["ohlc_price_tensors"])
        if "ohlc_interval_ids" in ablated:
            ablated["ohlc_interval_ids"] = torch.zeros_like(ablated["ohlc_interval_ids"])
        if "quant_ohlc_feature_tensors" in ablated:
            ablated["quant_ohlc_feature_tensors"] = torch.zeros_like(ablated["quant_ohlc_feature_tensors"])
        if "quant_ohlc_feature_mask" in ablated:
            ablated["quant_ohlc_feature_mask"] = torch.zeros_like(ablated["quant_ohlc_feature_mask"])

    # Quant-OHLC modes zero only the selected feature-group columns.
    quant_group_map = {
        "quant_ohlc": list(FEATURE_GROUPS.keys()),
        "quant_levels": ["levels_breaks"],
        "quant_trendline": ["trendlines"],
        "quant_breaks": ["relative_structure", "levels_breaks"],
        "quant_rolling": ["rolling_quant"],
    }
    if mode in quant_group_map and "quant_ohlc_feature_tensors" in ablated:
        idxs = group_feature_indices(quant_group_map[mode])
        if idxs:
            ablated["quant_ohlc_feature_tensors"][:, :, idxs] = 0

    # Trade / event streams: numerical features first, then categorical ids.
    if mode in {"trade", "all"}:
        for key in (
            "trade_numerical_features",
            "deployer_trade_numerical_features",
            "smart_wallet_trade_numerical_features",
            "transfer_numerical_features",
            "pool_created_numerical_features",
            "liquidity_change_numerical_features",
            "fee_collected_numerical_features",
            "token_burn_numerical_features",
            "supply_lock_numerical_features",
            "boosted_token_numerical_features",
            "trending_token_numerical_features",
            "dexboost_paid_numerical_features",
            "global_trending_numerical_features",
            "chainsnapshot_numerical_features",
            "lighthousesnapshot_numerical_features",
            "dexprofile_updated_flags",
        ):
            if key in ablated:
                ablated[key] = torch.zeros_like(ablated[key])
        for key in (
            "trade_dex_ids",
            "trade_direction_ids",
            "trade_mev_protection_ids",
            "trade_is_bundle_ids",
            "pool_created_protocol_ids",
            "liquidity_change_type_ids",
            "trending_token_source_ids",
            "trending_token_timeframe_ids",
            "lighthousesnapshot_protocol_ids",
            "lighthousesnapshot_timeframe_ids",
            "migrated_protocol_ids",
            "alpha_group_ids",
            "channel_ids",
            "exchange_ids",
        ):
            if key in ablated:
                ablated[key] = torch.zeros_like(ablated[key])

    # On-chain snapshots. BUGFIX: previously only mode == "onchain" reached
    # this branch, so --ablation all left on-chain features intact; include
    # "all" for consistency with every other signal family above.
    if mode in {"onchain", "all"}:
        if "onchain_snapshot_numerical_features" in ablated:
            ablated["onchain_snapshot_numerical_features"] = torch.zeros_like(ablated["onchain_snapshot_numerical_features"])

    return ablated
|
|
|
|
| def _chunk_permutation_indices(length, chunk_size): |
| if length <= 0: |
| return [] |
| chunks = [list(range(i, min(i + chunk_size, length))) for i in range(0, length, chunk_size)] |
| if len(chunks) <= 1: |
| return list(range(length)) |
| permuted = list(reversed(chunks)) |
| out = [] |
| for chunk in permuted: |
| out.extend(chunk) |
| return out |
|
|
|
|
| def _moving_average_1d(series, kernel_size): |
| if kernel_size <= 1 or series.numel() == 0: |
| return series |
| pad = kernel_size // 2 |
| kernel = torch.ones(1, 1, kernel_size, device=series.device, dtype=series.dtype) / float(kernel_size) |
| x = series.view(1, 1, -1) |
| x = torch.nn.functional.pad(x, (pad, pad), mode="replicate") |
| smoothed = torch.nn.functional.conv1d(x, kernel) |
| return smoothed.view(-1)[: series.numel()] |
|
|
|
|
| def _linear_trend(series): |
| if series.numel() <= 1: |
| return series.clone() |
| start = series[0] |
| end = series[-1] |
| steps = torch.linspace(0.0, 1.0, series.numel(), device=series.device, dtype=series.dtype) |
| return start + (end - start) * steps |
|
|
|
|
| def _summary_preserving_shuffle(series, chunk_size=20): |
| length = series.numel() |
| if length <= 2: |
| return series |
| chunks = [] |
| interior_start = 1 |
| interior_end = length - 1 |
| for i in range(interior_start, interior_end, chunk_size): |
| chunks.append(series[i:min(i + chunk_size, interior_end)].clone()) |
| if len(chunks) <= 1: |
| return series |
| reordered = list(reversed(chunks)) |
| out = series.clone() |
| cursor = 1 |
| for chunk in reordered: |
| out[cursor:cursor + chunk.numel()] = chunk |
| cursor += chunk.numel() |
| out[0] = series[0] |
| out[-1] = series[-1] |
| return out |
|
|
|
|
| def _apply_per_series(ohlc, transform_fn): |
| out = ohlc.clone() |
| for batch_idx in range(out.shape[0]): |
| for channel_idx in range(out.shape[1]): |
| out[batch_idx, channel_idx] = transform_fn(out[batch_idx, channel_idx]) |
| return out |
|
|
|
|
def apply_ohlc_probe(batch, mode):
    """Return a copy of *batch* with its OHLC price series perturbed per *mode*.

    Batches without OHLC data (missing key or empty tensor), and unknown
    modes, come back as an untouched clone.
    """
    probed = clone_batch(batch)
    if "ohlc_price_tensors" not in probed or probed["ohlc_price_tensors"].numel() == 0:
        return probed

    series = probed["ohlc_price_tensors"].clone()
    length = series.shape[-1]

    if mode == "ohlc_reverse":
        # Play the entire history backwards.
        probed["ohlc_price_tensors"] = torch.flip(series, dims=[-1])
    elif mode == "ohlc_shuffle_chunks":
        order = _chunk_permutation_indices(length, chunk_size=30)
        gather_idx = torch.tensor(order, device=series.device, dtype=torch.long)
        probed["ohlc_price_tensors"] = series.index_select(-1, gather_idx)
    elif mode == "ohlc_mask_recent":
        # Freeze the last 60 samples at the last kept value.
        keep = max(length - 60, 0)
        if 0 < keep < length:
            last_kept = series[..., keep - 1:keep]
            series[..., keep:] = last_kept.expand_as(series[..., keep:])
        elif keep == 0:
            series.zero_()
        probed["ohlc_price_tensors"] = series
    elif mode == "ohlc_trend_only":
        probed["ohlc_price_tensors"] = _apply_per_series(series, _linear_trend)
    elif mode == "ohlc_summary_shuffle":
        probed["ohlc_price_tensors"] = _apply_per_series(
            series,
            lambda s: _summary_preserving_shuffle(s, chunk_size=20),
        )
    elif mode == "ohlc_detrend":
        def _detrend(s):
            # Remove the endpoint-to-endpoint trend, then pin both ends to
            # the starting value.
            flattened = s - _linear_trend(s) + s[0]
            flattened[0] = s[0]
            flattened[-1] = s[0]
            return flattened
        probed["ohlc_price_tensors"] = _apply_per_series(series, _detrend)
    elif mode == "ohlc_smooth":
        probed["ohlc_price_tensors"] = _apply_per_series(
            series,
            lambda s: _moving_average_1d(s, kernel_size=11),
        )

    return probed
|
|
|
|
def run_inference(model, batch):
    """Run *model* on *batch* without gradients.

    Returns CPU copies of (quantile predictions, quality head or None,
    movement head or None) for the first sample in the batch.
    """
    with torch.no_grad():
        outputs = model(batch)

    def _first_on_cpu(key):
        if key not in outputs:
            return None
        return outputs[key][0].detach().cpu()

    preds = outputs["quantile_logits"][0].detach().cpu()
    return preds, _first_on_cpu("quality_logits"), _first_on_cpu("movement_logits")
|
|
|
|
def print_results(title, batch, preds, quality_pred, movement_pred, gt_labels, gt_mask, gt_quality, horizons_seconds, quantiles, reference_preds=None, reference_quality=None):
    """Pretty-print one sample's predictions against ground truth.

    When reference_preds / reference_quality are supplied (the full,
    un-ablated model outputs), deltas against them are printed alongside
    each prediction so ablation effects are visible per quantile.
    """
    # Predictions live in signed-log space; convert back to raw returns.
    real_preds = unlog_transform(preds)
    num_quantiles = len(quantiles)
    num_gt_horizons = len(gt_mask)

    print(f"\n================== {title} ==================")
    print(f"Token Address: {batch.get('token_addresses', ['Unknown'])[0]}")
    if gt_quality is not None:
        quality_line = f"Quality Score: GT = {gt_quality:.4f} | Pred = {quality_pred.item() if quality_pred is not None else 'N/A'}"
        if reference_quality is not None and quality_pred is not None:
            quality_delta = quality_pred.item() - reference_quality.item()
            quality_line += f" | Delta vs Full = {quality_delta:+.6f}"
        print(quality_line)
    if movement_pred is not None:
        movement_targets = batch.get("movement_class_targets")
        movement_mask = batch.get("movement_class_mask")
        print("Movement Classes:")
        for h_idx, horizon in enumerate(horizons_seconds):
            # The movement head may cover fewer horizons than requested.
            if h_idx >= movement_pred.shape[0]:
                break
            target_txt = "N/A"
            if movement_targets is not None and movement_mask is not None and bool(movement_mask[0, h_idx].item()):
                target_txt = MOVEMENT_ID_TO_CLASS.get(int(movement_targets[0, h_idx].item()), "unknown")
            pred_class = int(movement_pred[h_idx].argmax().item())
            pred_name = MOVEMENT_ID_TO_CLASS.get(pred_class, "unknown")
            # Confidence = softmax probability of the argmax class.
            pred_prob = float(torch.softmax(movement_pred[h_idx], dim=-1)[pred_class].item())
            print(
                f" {horizon:>4}s GT = {target_txt:<12} | "
                f"Pred = {pred_name:<12} | "
                f"Conf = {pred_prob:.4f}"
            )
    if "context_class_name" in batch:
        print(f"Context Class: {batch['context_class_name'][0]}")

    print("\nReturns per Horizon:")
    for h_idx, horizon in enumerate(horizons_seconds):
        horizon_min = horizon // 60
        print(f"\n--- Horizon: {horizon}s ({horizon_min}m) ---")

        # A horizon can lack ground truth either because the dataset never
        # produced it, or because the per-sample mask disabled it.
        if h_idx >= num_gt_horizons:
            print(" [No Ground Truth Available for this Horizon - Not in Dataset]")
            valid = False
        else:
            valid = gt_mask[h_idx].item()

        if not valid:
            print(" [No Ground Truth Available for this Horizon - Masked]")
        else:
            gt_ret = gt_labels[h_idx].item()
            print(f" Ground Truth: {gt_ret * 100:.2f}%")

        print(" Predictions:")
        for q_idx, q in enumerate(quantiles):
            # Quantile outputs are flattened horizon-major.
            flat_idx = h_idx * num_quantiles + q_idx
            pred_ret = real_preds[flat_idx].item()
            log_pred = preds[flat_idx].item()
            line = f" - p{int(q*100):02d}: {pred_ret * 100:>8.2f}% (raw log-val: {log_pred:7.4f})"
            if reference_preds is not None:
                ref_ret = unlog_transform(reference_preds)[flat_idx].item()
                line += f" | Delta vs Full: {(pred_ret - ref_ret) * 100:+7.2f}%"
            print(line)

    print("=============================================\n")
|
|
|
|
def resolve_sample_index(dataset, sample_idx_arg, rng):
    """Map the --sample_idx argument to a dataset index.

    A non-numeric string is looked up as a mint address, a numeric value is
    used as a direct index, and None picks a random index via *rng*.
    Raises ValueError for unknown addresses or out-of-range indices.
    """
    if sample_idx_arg is None:
        return rng.randint(0, len(dataset.sampled_mints) - 1)
    if isinstance(sample_idx_arg, str) and not sample_idx_arg.isdigit():
        for idx, meta in enumerate(dataset.sampled_mints):
            if meta['mint_address'] == sample_idx_arg:
                return idx
        raise ValueError(f"Mint address {sample_idx_arg} not found in filtered dataset")
    resolved = int(sample_idx_arg)
    if resolved >= len(dataset):
        raise ValueError(f"Sample index {resolved} out of range")
    return resolved
|
|
|
|
def move_batch_to_device(batch, device):
    """Move tensors (and lists of tensors) in *batch* to *device*, in place.

    Also guarantees the textual-event keys exist so downstream code can rely
    on them; the index tensor mirrors the shape of batch['event_type_ids'].
    Returns the same dict object.
    """
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
        elif isinstance(value, list) and value and isinstance(value[0], torch.Tensor):
            batch[key] = [item.to(device) for item in value]
    if 'textual_event_indices' not in batch:
        batch_size, seq_len = batch['event_type_ids'].shape
        batch['textual_event_indices'] = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
    batch.setdefault('textual_event_data', [])
    return batch
|
|
|
|
def init_aggregate(horizons_seconds, quantiles):
    """Create an empty accumulator for per-(horizon, quantile) statistics."""
    per_hq = {}
    for horizon in horizons_seconds:
        for quantile in quantiles:
            per_hq[(horizon, quantile)] = {
                "full_sum": 0.0,
                "abl_sum": 0.0,
                "delta_sum": 0.0,
                "abs_delta_sum": 0.0,
                "gt_sum": 0.0,
                "valid_count": 0,
            }
    return {
        "count": 0,
        "quality_full_sum": 0.0,
        "quality_abl_sum": 0.0,
        "quality_delta_sum": 0.0,
        "gt_quality_sum": 0.0,
        "per_hq": per_hq,
    }
|
|
|
|
def update_aggregate(stats, full_preds, gt_labels, gt_mask, gt_quality, horizons_seconds, quantiles, ablated_preds=None, full_quality=None, ablated_quality=None):
    """Fold one sample's predictions (and optional ablated predictions) into *stats*."""
    stats["count"] += 1
    if gt_quality is not None:
        stats["gt_quality_sum"] += float(gt_quality)
    if full_quality is not None:
        stats["quality_full_sum"] += float(full_quality.item())
    if ablated_quality is not None:
        stats["quality_abl_sum"] += float(ablated_quality.item())
        if full_quality is not None:
            stats["quality_delta_sum"] += float(ablated_quality.item() - full_quality.item())

    # Convert out of signed-log space before accumulating.
    full_real = unlog_transform(full_preds)
    ablated_real = None if ablated_preds is None else unlog_transform(ablated_preds)
    num_q = len(quantiles)

    for h_idx, horizon in enumerate(horizons_seconds):
        has_label = h_idx < len(gt_mask) and bool(gt_mask[h_idx].item())
        label_value = float(gt_labels[h_idx].item()) if has_label else math.nan
        for q_idx, quantile in enumerate(quantiles):
            bucket = stats["per_hq"][(horizon, quantile)]
            flat = h_idx * num_q + q_idx
            full_value = float(full_real[flat].item())
            bucket["full_sum"] += full_value
            if ablated_real is not None:
                ablated_value = float(ablated_real[flat].item())
                bucket["abl_sum"] += ablated_value
                bucket["delta_sum"] += ablated_value - full_value
                bucket["abs_delta_sum"] += abs(ablated_value - full_value)
            if has_label:
                bucket["gt_sum"] += label_value
                bucket["valid_count"] += 1
|
|
|
|
def print_aggregate_summary(stats, horizons_seconds, quantiles, ablation_mode):
    """Print mean ground truth and mean predictions accumulated in *stats*.

    When ablation_mode != "none", ablated means and deltas against the full
    model are printed as well. Means are taken over the number of evaluated
    samples; ground truth is averaged only over samples with a valid label.
    """
    n = stats["count"]
    print("\n================== Aggregate Summary ==================")
    print(f"Evaluated Samples: {n}")
    if n == 0:
        print("No valid samples collected.")
        print("=======================================================\n")
        return

    if ablation_mode != "none":
        print(
            f"Quality Mean: full={stats['quality_full_sum'] / n:.6f} | "
            f"ablated={stats['quality_abl_sum'] / n:.6f} | "
            f"delta={stats['quality_delta_sum'] / n:+.6f}"
        )

    for horizon in horizons_seconds:
        horizon_min = horizon // 60
        print(f"\n--- Horizon: {horizon}s ({horizon_min}m) ---")
        # gt_sum / valid_count are identical across quantiles for a horizon,
        # so any quantile's bucket can supply the ground-truth mean.
        valid_counts = [stats["per_hq"][(horizon, q)]["valid_count"] for q in quantiles]
        valid_count = max(valid_counts) if valid_counts else 0
        if valid_count > 0:
            gt_mean = stats["per_hq"][(horizon, quantiles[0])]["gt_sum"] / valid_count
            print(f" Mean Ground Truth over valid labels: {gt_mean * 100:.2f}% (n={valid_count})")
        else:
            print(" Mean Ground Truth over valid labels: N/A")

        for q in quantiles:
            bucket = stats["per_hq"][(horizon, q)]
            full_mean = bucket["full_sum"] / n
            line = f" p{int(q*100):02d} mean full: {full_mean * 100:>8.2f}%"
            if ablation_mode != "none":
                abl_mean = bucket["abl_sum"] / n
                delta_mean = bucket["delta_sum"] / n
                abs_delta_mean = bucket["abs_delta_sum"] / n
                line += (
                    f" | ablated: {abl_mean * 100:>8.2f}%"
                    f" | delta: {delta_mean * 100:+8.2f}%"
                    f" | mean|delta|: {abs_delta_mean * 100:>8.2f}%"
                )
            print(line)
    print("=======================================================\n")
|
|
|
|
def summarize_influence_score(stats, horizons_seconds, quantiles):
    """Mean absolute ablation delta, averaged over samples and (horizon, quantile) cells.

    Returns 0.0 when no samples were aggregated.
    """
    sample_count = stats["count"]
    if sample_count == 0:
        return 0.0
    cell_means = [
        stats["per_hq"][(horizon, quantile)]["abs_delta_sum"] / sample_count
        for horizon in horizons_seconds
        for quantile in quantiles
    ]
    return sum(cell_means) / max(len(cell_means), 1)
|
|
|
|
def print_probe_summary(mode_to_stats, horizons_seconds, quantiles):
    """Rank the OHLC probes by influence score, then print each probe's aggregate."""
    scored = [
        (mode, summarize_influence_score(mode_to_stats[mode], horizons_seconds, quantiles))
        for mode in OHLC_PROBE_MODES
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)

    print("\n================== OHLC Probe Ranking ==================")
    for position, (mode, score) in enumerate(scored, start=1):
        print(f"{position:>2}. {mode:<20} mean|delta| = {score * 100:8.2f}%")
    print("========================================================\n")

    for mode, _ in scored:
        print_aggregate_summary(mode_to_stats[mode], horizons_seconds, quantiles, mode)
|
|
def get_latest_checkpoint(checkpoint_dir):
    """Return the most recently modified subdirectory of *checkpoint_dir* as a str.

    Returns None when the directory does not exist or contains no
    subdirectories.
    """
    root = Path(checkpoint_dir)
    if not root.exists():
        return None
    subdirs = [entry for entry in root.iterdir() if entry.is_dir()]
    if not subdirs:
        return None
    newest = max(subdirs, key=lambda entry: entry.stat().st_mtime)
    return str(newest)
|
|
def main():
    """End-to-end live evaluation driver.

    Loads environment config, connects to ClickHouse and Neo4j, builds the
    live dataset, encoders, and Oracle model, restores checkpoint weights,
    then evaluates up to --num_samples tokens. Depending on --ablation it
    additionally runs per-family ablations, a full sweep, or OHLC probes,
    printing per-sample details and aggregate summaries.
    """
    load_dotenv()
    args = parse_args()
    # Dedicated RNG for sample selection; the global seeds are only set when
    # the user explicitly asks for reproducibility.
    rng = random.Random(args.seed)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    accelerator = Accelerator(mixed_precision=args.mixed_precision)
    device = accelerator.device

    # Match module init dtype to the accelerator's mixed-precision setting.
    init_dtype = torch.float32
    if accelerator.mixed_precision == 'bf16':
        init_dtype = torch.bfloat16
    elif accelerator.mixed_precision == 'fp16':
        init_dtype = torch.float16

    print("INFO: Initializing DB Connections for LIVE evaluation...")
    clickhouse_host = os.getenv("CLICKHOUSE_HOST", "localhost")
    clickhouse_port = int(os.getenv("CLICKHOUSE_PORT", 9000))
    neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    neo4j_user = os.getenv("NEO4J_USER", "neo4j")
    neo4j_password = os.getenv("NEO4J_PASSWORD", "password")

    clickhouse_client = ClickHouseClient(host=clickhouse_host, port=clickhouse_port)
    neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
    data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

    print(f"Loading live dataset generator...")

    dataset = OracleDataset(
        data_fetcher=data_fetcher,
        fetcher_config=None,
        horizons_seconds=args.horizons_seconds,
        quantiles=args.quantiles,
        cache_dir=None
    )

    # Filter the candidate tokens by return class: drop tokens with no class,
    # manipulated tokens, and anything below --min_class.
    from models.vocabulary import MANIPULATED_CLASS_ID
    print("INFO: Fetching Return Classification Map...")
    return_class_map, _ = get_return_class_map(clickhouse_client)

    min_class_thresh = args.min_class if args.min_class is not None else 0

    original_len = len(dataset.sampled_mints)
    dataset.sampled_mints = [
        m for m in dataset.sampled_mints
        if return_class_map.get(m['mint_address']) is not None
        and return_class_map.get(m['mint_address']) != MANIPULATED_CLASS_ID
        and return_class_map.get(m['mint_address']) >= min_class_thresh
    ]
    dataset.num_samples = len(dataset.sampled_mints)
    print(f"INFO: Filtered tokens. {original_len} -> {len(dataset.sampled_mints)} valid tokens (class >= {min_class_thresh}).")

    if len(dataset) == 0:
        raise ValueError("Dataset is empty. Are ClickHouse data and trade pipelines populated? (Check if min_return filtered everything out)")

    print("Initializing encoders...")
    multi_modal_encoder = MultiModalEncoder(dtype=init_dtype, device=device)
    time_encoder = ContextualTimeEncoder(dtype=init_dtype)
    token_encoder = TokenEncoder(multi_dim=multi_modal_encoder.embedding_dim, dtype=init_dtype)
    wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
    graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
    ohlc_embedder = OHLCEmbedder(num_intervals=vocab.NUM_OHLC_INTERVALS, dtype=init_dtype)
    quant_ohlc_embedder = QuantOHLCEmbedder(
        num_features=NUM_QUANT_OHLC_FEATURES,
        sequence_length=TOKENS_PER_SEGMENT,
        dtype=init_dtype,
    )

    collator = MemecoinCollator(
        event_type_to_id=vocab.EVENT_TO_ID,
        device=device,
        dtype=init_dtype,
        max_seq_len=4096
    )

    print("Initializing model...")
    model = Oracle(
        token_encoder=token_encoder,
        wallet_encoder=wallet_encoder,
        graph_updater=graph_updater,
        ohlc_embedder=ohlc_embedder,
        quant_ohlc_embedder=quant_ohlc_embedder,
        time_encoder=time_encoder,
        num_event_types=vocab.NUM_EVENT_TYPES,
        multi_modal_dim=multi_modal_encoder.embedding_dim,
        event_pad_id=vocab.EVENT_TO_ID["__PAD__"],
        event_type_to_id=vocab.EVENT_TO_ID,
        model_config_name="llama3-12l-768d-gqa4-8k-random",
        quantiles=args.quantiles,
        horizons_seconds=args.horizons_seconds,
        dtype=init_dtype
    )

    # Inputs are pre-embedded upstream, so the LM's token embedding table is
    # unused; drop it to save memory.
    if hasattr(model.model, 'embed_tokens'):
        del model.model.embed_tokens

    # Resolve a ".../latest" checkpoint path to the newest checkpoint dir.
    ckpt_path = args.checkpoint
    if ckpt_path.endswith("latest"):
        base_dir = Path(ckpt_path).parent
        found = get_latest_checkpoint(base_dir)
        if found:
            ckpt_path = found

    if not os.path.exists(ckpt_path):
        print(f"Warning: Checkpoint {ckpt_path} not found. Running with random weights!")
        model = accelerator.prepare(model)
    else:
        print(f"Loading checkpoint from {ckpt_path}...")
        model = accelerator.prepare(model)
        try:
            accelerator.load_state(ckpt_path)
            print("Successfully loaded accelerator state.")
        except Exception as e:
            # Fallback: load raw model weights when the full accelerator
            # state is incompatible (e.g. missing optimizer/scheduler files).
            print(f"Could not load using accelerate.load_state: {e}")
            print("Trying to load model weights directly...")
            model_file = os.path.join(ckpt_path, "pytorch_model.bin")
            if not os.path.exists(model_file):
                model_file = os.path.join(ckpt_path, "model.safetensors")

            if os.path.exists(model_file):
                if model_file.endswith(".safetensors"):
                    from safetensors.torch import load_file
                    state_dict = load_file(model_file)
                else:
                    state_dict = torch.load(model_file, map_location="cpu")

                uw_model = accelerator.unwrap_model(model)
                uw_model.load_state_dict(state_dict, strict=False)
                print("Successfully loaded weights directly.")
            else:
                print(f"Error: model weights not found in {ckpt_path}")

    model.eval()

    # Aggregators: `stats` collects full-model results; `mode_to_stats` /
    # `probe_to_stats` collect per-mode results for ablations and probes.
    stats = init_aggregate(args.horizons_seconds, args.quantiles)
    selected_modes = [] if args.ablation == "none" else (ABLATION_SWEEP_MODES if args.ablation == "sweep" else ([] if args.ablation == "ohlc_probe" else [args.ablation]))
    mode_to_stats = {mode: init_aggregate(args.horizons_seconds, args.quantiles) for mode in selected_modes}
    probe_to_stats = {mode: init_aggregate(args.horizons_seconds, args.quantiles) for mode in OHLC_PROBE_MODES} if args.ablation == "ohlc_probe" else {}
    max_target_samples = max(1, args.num_samples)
    retries = 0
    collected = 0
    seen_indices = set()

    # Sample tokens until enough valid contexts are collected or the retry
    # budget (--max_retries) is exhausted.
    while collected < max_target_samples and retries < args.max_retries:
        sample_idx = resolve_sample_index(dataset, args.sample_idx, rng)
        # Avoid re-evaluating the same random token while unseen ones remain.
        if args.sample_idx is None and sample_idx in seen_indices and len(seen_indices) < len(dataset.sampled_mints):
            retries += 1
            continue
        seen_indices.add(sample_idx)

        sample_mint_addr = dataset.sampled_mints[sample_idx]['mint_address']
        print(f"Trying Token Address: {sample_mint_addr}")

        contexts = dataset.__cacheitem_context__(
            sample_idx,
            num_samples_per_token=1,
            encoder=multi_modal_encoder,
            forced_cutoff_trade_idx=args.cutoff_trade_idx,
        )

        if not contexts or contexts[0] is None:
            print(" [Failed to generate valid context pattern, skipping...]")
            retries += 1
            # A specifically requested sample cannot be retried with a
            # different token, so bail out entirely.
            if args.sample_idx is not None:
                print("Specific sample requested but failed to generate context. Exiting.")
                return
            continue

        raw_sample = contexts[0]
        batch = move_batch_to_device(collator([raw_sample]), device)
        gt_labels = batch["labels"][0].cpu()
        gt_mask = batch["labels_mask"][0].cpu().bool()
        gt_quality = batch["quality_score"][0].item() if "quality_score" in batch else None

        if collected == 0 or args.show_each:
            print(f"\nEvaluating sample {collected + 1}/{max_target_samples} on Token Address: {sample_mint_addr}")
            print("\n--- Running Inference ---")

        # Full-model pass first, then one pass per selected ablation / probe.
        full_preds, full_quality, full_direction = run_inference(model, batch)
        ablation_outputs = {}
        for mode in selected_modes:
            ablated_batch = apply_ablation(batch, mode, device)
            ablated_preds, ablated_quality, ablated_direction = run_inference(model, ablated_batch)
            ablation_outputs[mode] = (ablated_batch, ablated_preds, ablated_quality, ablated_direction)
        probe_outputs = {}
        if args.ablation == "ohlc_probe":
            for mode in OHLC_PROBE_MODES:
                probe_batch = apply_ohlc_probe(batch, mode)
                probe_preds, probe_quality, probe_direction = run_inference(model, probe_batch)
                probe_outputs[mode] = (probe_batch, probe_preds, probe_quality, probe_direction)

        # Per-sample reporting (first sample, or all when --show_each).
        if collected == 0 or args.show_each:
            print_results(
                title="Full Results",
                batch=batch,
                preds=full_preds,
                quality_pred=full_quality,
                movement_pred=full_direction,
                gt_labels=gt_labels,
                gt_mask=gt_mask,
                gt_quality=gt_quality,
                horizons_seconds=args.horizons_seconds,
                quantiles=args.quantiles,
            )
            if args.ablation != "none":
                if args.ablation == "sweep":
                    print(f"Collected full predictions for {len(selected_modes)} ablation families on this sample. Aggregate ranking will be printed at the end.")
                elif args.ablation == "ohlc_probe":
                    for mode in OHLC_PROBE_MODES:
                        probe_batch, probe_preds, probe_quality, probe_direction = probe_outputs[mode]
                        print_results(
                            title=f"OHLC Probe ({mode})",
                            batch=probe_batch,
                            preds=probe_preds,
                            quality_pred=probe_quality,
                            movement_pred=probe_direction,
                            gt_labels=gt_labels,
                            gt_mask=gt_mask,
                            gt_quality=gt_quality,
                            horizons_seconds=args.horizons_seconds,
                            quantiles=args.quantiles,
                            reference_preds=full_preds,
                            reference_quality=full_quality,
                        )
                else:
                    ablated_batch, ablated_preds, ablated_quality, ablated_direction = ablation_outputs[args.ablation]
                    print_results(
                        title=f"Ablation Results ({args.ablation})",
                        batch=ablated_batch,
                        preds=ablated_preds,
                        quality_pred=ablated_quality,
                        movement_pred=ablated_direction,
                        gt_labels=gt_labels,
                        gt_mask=gt_mask,
                        gt_quality=gt_quality,
                        horizons_seconds=args.horizons_seconds,
                        quantiles=args.quantiles,
                        reference_preds=full_preds,
                        reference_quality=full_quality,
                    )

        # Fold this sample into the aggregates (full model, then per mode).
        update_aggregate(
            stats=stats,
            full_preds=full_preds,
            gt_labels=gt_labels,
            gt_mask=gt_mask,
            gt_quality=gt_quality,
            horizons_seconds=args.horizons_seconds,
            quantiles=args.quantiles,
            full_quality=full_quality,
        )
        for mode, (_, ablated_preds, ablated_quality, _) in ablation_outputs.items():
            update_aggregate(
                stats=mode_to_stats[mode],
                full_preds=full_preds,
                gt_labels=gt_labels,
                gt_mask=gt_mask,
                gt_quality=gt_quality,
                horizons_seconds=args.horizons_seconds,
                quantiles=args.quantiles,
                ablated_preds=ablated_preds,
                full_quality=full_quality,
                ablated_quality=ablated_quality,
            )
        for mode, (_, probe_preds, probe_quality, _) in probe_outputs.items():
            update_aggregate(
                stats=probe_to_stats[mode],
                full_preds=full_preds,
                gt_labels=gt_labels,
                gt_mask=gt_mask,
                gt_quality=gt_quality,
                horizons_seconds=args.horizons_seconds,
                quantiles=args.quantiles,
                ablated_preds=probe_preds,
                full_quality=full_quality,
                ablated_quality=probe_quality,
            )
        collected += 1
        retries += 1

        # A specific sample is evaluated exactly once.
        if args.sample_idx is not None:
            break

    if collected == 0:
        print(f"Could not find a valid context after {args.max_retries} attempts.")
        return

    if collected < max_target_samples:
        print(f"WARNING: Requested {max_target_samples} samples but only evaluated {collected}.")

    # Final reporting, depending on the ablation mode.
    if args.ablation == "none":
        print_aggregate_summary(stats, args.horizons_seconds, args.quantiles, args.ablation)
        return

    if args.ablation == "ohlc_probe":
        print_probe_summary(probe_to_stats, args.horizons_seconds, args.quantiles)
        return

    if args.ablation == "sweep":
        rankings = []
        for mode in selected_modes:
            score = summarize_influence_score(mode_to_stats[mode], args.horizons_seconds, args.quantiles)
            rankings.append((mode, score))
        rankings.sort(key=lambda x: x[1], reverse=True)

        print("\n================== Influence Ranking ==================")
        for rank, (mode, score) in enumerate(rankings, start=1):
            print(f"{rank:>2}. {mode:<12} mean|delta| = {score * 100:8.2f}%")
        print("=======================================================\n")

        for mode, _ in rankings:
            print_aggregate_summary(mode_to_stats[mode], args.horizons_seconds, args.quantiles, mode)
    else:
        print_aggregate_summary(mode_to_stats[args.ablation], args.horizons_seconds, args.quantiles, args.ablation)
|
|
| if __name__ == "__main__": |
| main() |
|
|