# oracle/scripts/evaluate_sample.py
# Uploaded with huggingface_hub (revision a547253).
import os
import sys
import argparse
import random
import copy
import math
import torch
from pathlib import Path
# Add project root to sys.path so we can import data and models
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Provide standard defaults
from accelerate import Accelerator
from torch.utils.data import DataLoader, Subset
from data.data_loader import OracleDataset
from data.data_collator import MemecoinCollator
from data.context_targets import MOVEMENT_ID_TO_CLASS
from models.multi_modal_processor import MultiModalEncoder
from models.helper_encoders import ContextualTimeEncoder
from models.token_encoder import TokenEncoder
from models.wallet_encoder import WalletEncoder
from models.graph_updater import GraphUpdater
from models.ohlc_embedder import OHLCEmbedder
from models.quant_ohlc_embedder import QuantOHLCEmbedder
from models.model import Oracle
import models.vocabulary as vocab
from data.quant_ohlc_feature_schema import FEATURE_GROUPS, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT, group_feature_indices
from train import create_balanced_split
from dotenv import load_dotenv
from clickhouse_driver import Client as ClickHouseClient
from neo4j import GraphDatabase
from data.data_fetcher import DataFetcher
from scripts.analyze_distribution import get_return_class_map
# Signal families removed one at a time when --ablation sweep is selected;
# each mode names the inputs that apply_ablation() zeroes out.
ABLATION_SWEEP_MODES = [
    "wallet",
    "graph",
    "social",
    "token",
    "holder",
    "ohlc",
    "ohlc_wallet",
    "trade",
    "onchain",
    "wallet_graph",
    "quant_ohlc",
    "quant_levels",
    "quant_trendline",
    "quant_breaks",
    "quant_rolling",
]
# Perturbations applied to the OHLC price series by --ablation ohlc_probe;
# each mode names a transform implemented in apply_ohlc_probe().
OHLC_PROBE_MODES = [
    "ohlc_reverse",
    "ohlc_shuffle_chunks",
    "ohlc_mask_recent",
    "ohlc_trend_only",
    "ohlc_summary_shuffle",
    "ohlc_detrend",
    "ohlc_smooth",
]
def unlog_transform(tensor):
    """Invert the signed log1p transform applied to return labels during training.

    Training applies: labels = sign(labels) * log1p(abs(labels)); this maps the
    model's log-space outputs back to raw fractional returns.

    Args:
        tensor: Log-space values (any shape).

    Returns:
        Tensor of the same shape with the transform inverted.
    """
    # torch.expm1(x) == exp(x) - 1 but with better float precision near zero.
    return torch.sign(tensor) * torch.expm1(torch.abs(tensor))
def parse_args():
    """Define and parse the command-line interface for sample evaluation."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--checkpoint", type=str, default="checkpoints/checkpoint-90000", help="Path to checkpoint dir")
    cli.add_argument("--sample_idx", type=str, default=None, help="Specific sample index or Mint Address to evaluate")
    cli.add_argument("--mixed_precision", type=str, default="bf16")
    cli.add_argument("--horizons_seconds", type=int, nargs="+", default=[300, 900, 1800, 3600, 7200])
    cli.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
    cli.add_argument("--seed", type=int, default=None)
    cli.add_argument("--min_class", type=int, default=3, help="Filter out tokens with return class beneath this ID (e.g., 1 for >= 3x returns)")
    cli.add_argument("--cutoff_trade_idx", type=int, default=200, help="Force the T_cutoff at this exact trade index (e.g., 10 = right after the 10th trade)")
    cli.add_argument("--num_samples", type=int, default=1, help="Number of valid samples to evaluate and aggregate.")
    cli.add_argument("--max_retries", type=int, default=100, help="Maximum attempts to find valid contexts across samples.")
    cli.add_argument("--show_each", action="store_true", help="Print per-sample details for every evaluated sample.")
    cli.add_argument(
        "--ablation",
        type=str,
        default="none",
        choices=["none", "wallet", "graph", "wallet_graph", "social", "token", "holder", "ohlc", "ohlc_wallet", "trade", "onchain", "all", "sweep", "ohlc_probe", "quant_ohlc", "quant_levels", "quant_trendline", "quant_breaks", "quant_rolling"],
        help="Run inference with selected signal families removed, or use 'sweep' to rank multiple families.",
    )
    return cli.parse_args()
def clone_batch(batch):
    """Deep-copy a collated batch: tensors via .clone(), everything else via deepcopy."""
    return {
        key: value.clone() if isinstance(value, torch.Tensor) else copy.deepcopy(value)
        for key, value in batch.items()
    }
def _empty_wallet_encoder_inputs(device):
return {
'username_embed_indices': torch.tensor([], device=device, dtype=torch.long),
'profile_rows': [],
'social_rows': [],
'holdings_batch': [],
}
def _empty_token_encoder_inputs(device):
return {
'name_embed_indices': torch.tensor([], device=device, dtype=torch.long),
'symbol_embed_indices': torch.tensor([], device=device, dtype=torch.long),
'image_embed_indices': torch.tensor([], device=device, dtype=torch.long),
'protocol_ids': torch.tensor([], device=device, dtype=torch.long),
'is_vanity_flags': torch.tensor([], device=device, dtype=torch.bool),
'_addresses_for_lookup': [],
}
def apply_ablation(batch, mode, device):
    """Return a copy of *batch* with the signal family named by *mode* removed.

    "none" returns the batch untouched; every other mode works on a clone so
    the original batch can still be used for the full (non-ablated) pass.
    Note: some branches zero tensors in place (.zero_() on the clone) while
    others rebind keys to fresh zeros_like tensors — both leave the original
    batch intact because clone_batch() deep-copies values.
    """
    if mode == "none":
        return batch
    ablated = clone_batch(batch)
    # Wallet identity signals (also removed by combined wallet modes).
    if mode in {"wallet", "wallet_graph", "ohlc_wallet", "all"}:
        for key in (
            "wallet_indices",
            "dest_wallet_indices",
            "original_author_indices",
            "holder_snapshot_indices",
        ):
            if key in ablated:
                ablated[key].zero_()
        ablated["wallet_encoder_inputs"] = _empty_wallet_encoder_inputs(device)
        ablated["wallet_addr_to_batch_idx"] = {}
        ablated["holder_snapshot_raw_data"] = []
        # Wallet graph links depend on wallet identity, so drop them too.
        ablated["graph_updater_links"] = {}
    # Graph structure only.
    if mode in {"graph", "wallet_graph", "all"}:
        ablated["graph_updater_links"] = {}
    # Social / textual events.
    if mode in {"social", "all"}:
        if "textual_event_indices" in ablated:
            ablated["textual_event_indices"].zero_()
        ablated["textual_event_data"] = []
    # Token metadata (name/symbol/image/protocol).
    if mode in {"token", "all"}:
        for key in (
            "token_indices",
            "quote_token_indices",
            "trending_token_indices",
            "boosted_token_indices",
        ):
            if key in ablated:
                ablated[key].zero_()
        ablated["token_encoder_inputs"] = _empty_token_encoder_inputs(device)
    # Holder snapshots.
    if mode in {"holder", "all"}:
        if "holder_snapshot_indices" in ablated:
            ablated["holder_snapshot_indices"].zero_()
        ablated["holder_snapshot_raw_data"] = []
    # Raw OHLC price history and derived quant features.
    if mode in {"ohlc", "ohlc_wallet", "all"}:
        if "ohlc_indices" in ablated:
            ablated["ohlc_indices"].zero_()
        if "ohlc_price_tensors" in ablated:
            ablated["ohlc_price_tensors"] = torch.zeros_like(ablated["ohlc_price_tensors"])
        if "ohlc_interval_ids" in ablated:
            ablated["ohlc_interval_ids"] = torch.zeros_like(ablated["ohlc_interval_ids"])
        if "quant_ohlc_feature_tensors" in ablated:
            ablated["quant_ohlc_feature_tensors"] = torch.zeros_like(ablated["quant_ohlc_feature_tensors"])
        if "quant_ohlc_feature_mask" in ablated:
            ablated["quant_ohlc_feature_mask"] = torch.zeros_like(ablated["quant_ohlc_feature_mask"])
    # Quant modes zero only selected feature groups (columns) of the quant
    # OHLC feature tensor; group -> column mapping comes from the schema.
    quant_group_map = {
        "quant_ohlc": list(FEATURE_GROUPS.keys()),
        "quant_levels": ["levels_breaks"],
        "quant_trendline": ["trendlines"],
        "quant_breaks": ["relative_structure", "levels_breaks"],
        "quant_rolling": ["rolling_quant"],
    }
    if mode in quant_group_map and "quant_ohlc_feature_tensors" in ablated:
        idxs = group_feature_indices(quant_group_map[mode])
        if idxs:
            ablated["quant_ohlc_feature_tensors"][:, :, idxs] = 0
    # Trade/transaction event features: numerical features first, then
    # categorical id tensors.
    if mode in {"trade", "all"}:
        for key in (
            "trade_numerical_features",
            "deployer_trade_numerical_features",
            "smart_wallet_trade_numerical_features",
            "transfer_numerical_features",
            "pool_created_numerical_features",
            "liquidity_change_numerical_features",
            "fee_collected_numerical_features",
            "token_burn_numerical_features",
            "supply_lock_numerical_features",
            "boosted_token_numerical_features",
            "trending_token_numerical_features",
            "dexboost_paid_numerical_features",
            "global_trending_numerical_features",
            "chainsnapshot_numerical_features",
            "lighthousesnapshot_numerical_features",
            "dexprofile_updated_flags",
        ):
            if key in ablated:
                ablated[key] = torch.zeros_like(ablated[key])
        for key in (
            "trade_dex_ids",
            "trade_direction_ids",
            "trade_mev_protection_ids",
            "trade_is_bundle_ids",
            "pool_created_protocol_ids",
            "liquidity_change_type_ids",
            "trending_token_source_ids",
            "trending_token_timeframe_ids",
            "lighthousesnapshot_protocol_ids",
            "lighthousesnapshot_timeframe_ids",
            "migrated_protocol_ids",
            "alpha_group_ids",
            "channel_ids",
            "exchange_ids",
        ):
            if key in ablated:
                ablated[key] = torch.zeros_like(ablated[key])
    # On-chain snapshot features (not part of "all" — note the exact match).
    if mode == "onchain":
        if "onchain_snapshot_numerical_features" in ablated:
            ablated["onchain_snapshot_numerical_features"] = torch.zeros_like(ablated["onchain_snapshot_numerical_features"])
    return ablated
def _chunk_permutation_indices(length, chunk_size):
if length <= 0:
return []
chunks = [list(range(i, min(i + chunk_size, length))) for i in range(0, length, chunk_size)]
if len(chunks) <= 1:
return list(range(length))
permuted = list(reversed(chunks))
out = []
for chunk in permuted:
out.extend(chunk)
return out
def _moving_average_1d(series, kernel_size):
if kernel_size <= 1 or series.numel() == 0:
return series
pad = kernel_size // 2
kernel = torch.ones(1, 1, kernel_size, device=series.device, dtype=series.dtype) / float(kernel_size)
x = series.view(1, 1, -1)
x = torch.nn.functional.pad(x, (pad, pad), mode="replicate")
smoothed = torch.nn.functional.conv1d(x, kernel)
return smoothed.view(-1)[: series.numel()]
def _linear_trend(series):
if series.numel() <= 1:
return series.clone()
start = series[0]
end = series[-1]
steps = torch.linspace(0.0, 1.0, series.numel(), device=series.device, dtype=series.dtype)
return start + (end - start) * steps
def _summary_preserving_shuffle(series, chunk_size=20):
length = series.numel()
if length <= 2:
return series
chunks = []
interior_start = 1
interior_end = length - 1
for i in range(interior_start, interior_end, chunk_size):
chunks.append(series[i:min(i + chunk_size, interior_end)].clone())
if len(chunks) <= 1:
return series
reordered = list(reversed(chunks))
out = series.clone()
cursor = 1
for chunk in reordered:
out[cursor:cursor + chunk.numel()] = chunk
cursor += chunk.numel()
out[0] = series[0]
out[-1] = series[-1]
return out
def _apply_per_series(ohlc, transform_fn):
out = ohlc.clone()
for batch_idx in range(out.shape[0]):
for channel_idx in range(out.shape[1]):
out[batch_idx, channel_idx] = transform_fn(out[batch_idx, channel_idx])
return out
def apply_ohlc_probe(batch, mode):
    """Return a copy of *batch* with its OHLC price series perturbed per *mode*.

    Each probe destroys one aspect of the price history (ordering, recency,
    trend, or local detail) so that influence on predictions can be measured.
    Batches without OHLC data come back as an unmodified clone.
    """
    probed = clone_batch(batch)
    if "ohlc_price_tensors" not in probed or probed["ohlc_price_tensors"].numel() == 0:
        return probed
    prices = probed["ohlc_price_tensors"].clone()
    seq_len = prices.shape[-1]
    if mode == "ohlc_reverse":
        # Play the history backwards.
        probed["ohlc_price_tensors"] = torch.flip(prices, dims=[-1])
    elif mode == "ohlc_shuffle_chunks":
        # Reorder 30-step chunks (chunk interiors stay intact).
        order = torch.tensor(
            _chunk_permutation_indices(seq_len, chunk_size=30),
            device=prices.device,
            dtype=torch.long,
        )
        probed["ohlc_price_tensors"] = prices.index_select(-1, order)
    elif mode == "ohlc_mask_recent":
        # Freeze the last 60 steps at the value just before the mask window.
        keep = max(seq_len - 60, 0)
        if 0 < keep < seq_len:
            last_kept = prices[..., keep - 1:keep]
            prices[..., keep:] = last_kept.expand_as(prices[..., keep:])
        elif keep == 0:
            prices.zero_()
        probed["ohlc_price_tensors"] = prices
    elif mode == "ohlc_trend_only":
        # Keep only the endpoint-to-endpoint line.
        probed["ohlc_price_tensors"] = _apply_per_series(prices, _linear_trend)
    elif mode == "ohlc_summary_shuffle":
        probed["ohlc_price_tensors"] = _apply_per_series(
            prices,
            lambda series: _summary_preserving_shuffle(series, chunk_size=20),
        )
    elif mode == "ohlc_detrend":
        def _detrend(series):
            # Subtract the endpoint line and re-anchor at the start value.
            flattened = series - _linear_trend(series) + series[0]
            flattened[0] = series[0]
            flattened[-1] = series[0]
            return flattened
        probed["ohlc_price_tensors"] = _apply_per_series(prices, _detrend)
    elif mode == "ohlc_smooth":
        probed["ohlc_price_tensors"] = _apply_per_series(
            prices,
            lambda series: _moving_average_1d(series, kernel_size=11),
        )
    return probed
def run_inference(model, batch):
    """Forward *batch* through *model* without gradients.

    Returns a tuple for the first sample in the batch:
    (quantile predictions, quality logits or None, movement logits or None),
    all detached and moved to CPU.
    """
    with torch.no_grad():
        outputs = model(batch)
    quantile_preds = outputs["quantile_logits"][0].detach().cpu()
    quality = outputs["quality_logits"][0].detach().cpu() if "quality_logits" in outputs else None
    movement = outputs["movement_logits"][0].detach().cpu() if "movement_logits" in outputs else None
    return quantile_preds, quality, movement
def print_results(title, batch, preds, quality_pred, movement_pred, gt_labels, gt_mask, gt_quality, horizons_seconds, quantiles, reference_preds=None, reference_quality=None):
    """Pretty-print one sample's predictions vs. ground truth.

    When *reference_preds*/*reference_quality* are given (the full, un-ablated
    pass), a per-quantile delta against that reference is printed as well.
    Predictions arrive in log space and are flattened as
    horizon-major x quantile-minor; unlog_transform maps them back to raw
    fractional returns for display.
    """
    real_preds = unlog_transform(preds)
    num_quantiles = len(quantiles)
    num_gt_horizons = len(gt_mask)
    print(f"\n================== {title} ==================")
    print(f"Token Address: {batch.get('token_addresses', ['Unknown'])[0]}")
    # Quality head: ground truth plus prediction, with optional delta vs full.
    if gt_quality is not None:
        quality_line = f"Quality Score: GT = {gt_quality:.4f} | Pred = {quality_pred.item() if quality_pred is not None else 'N/A'}"
        if reference_quality is not None and quality_pred is not None:
            quality_delta = quality_pred.item() - reference_quality.item()
            quality_line += f" | Delta vs Full = {quality_delta:+.6f}"
        print(quality_line)
    # Movement classification head: one argmax class per horizon.
    if movement_pred is not None:
        movement_targets = batch.get("movement_class_targets")
        movement_mask = batch.get("movement_class_mask")
        print("Movement Classes:")
        for h_idx, horizon in enumerate(horizons_seconds):
            if h_idx >= movement_pred.shape[0]:
                break
            target_txt = "N/A"
            if movement_targets is not None and movement_mask is not None and bool(movement_mask[0, h_idx].item()):
                target_txt = MOVEMENT_ID_TO_CLASS.get(int(movement_targets[0, h_idx].item()), "unknown")
            pred_class = int(movement_pred[h_idx].argmax().item())
            pred_name = MOVEMENT_ID_TO_CLASS.get(pred_class, "unknown")
            # Softmax confidence of the argmax class.
            pred_prob = float(torch.softmax(movement_pred[h_idx], dim=-1)[pred_class].item())
            print(
                f" {horizon:>4}s GT = {target_txt:<12} | "
                f"Pred = {pred_name:<12} | "
                f"Conf = {pred_prob:.4f}"
            )
    if "context_class_name" in batch:
        print(f"Context Class: {batch['context_class_name'][0]}")
    # Quantile return predictions per horizon.
    print("\nReturns per Horizon:")
    for h_idx, horizon in enumerate(horizons_seconds):
        horizon_min = horizon // 60
        print(f"\n--- Horizon: {horizon}s ({horizon_min}m) ---")
        if h_idx >= num_gt_horizons:
            print(" [No Ground Truth Available for this Horizon - Not in Dataset]")
            valid = False
        else:
            valid = gt_mask[h_idx].item()
        if not valid:
            print(" [No Ground Truth Available for this Horizon - Masked]")
        else:
            gt_ret = gt_labels[h_idx].item()
            print(f" Ground Truth: {gt_ret * 100:.2f}%")
        print(" Predictions:")
        for q_idx, q in enumerate(quantiles):
            # Flat index into the horizon-major prediction vector.
            flat_idx = h_idx * num_quantiles + q_idx
            pred_ret = real_preds[flat_idx].item()
            log_pred = preds[flat_idx].item()
            line = f" - p{int(q*100):02d}: {pred_ret * 100:>8.2f}% (raw log-val: {log_pred:7.4f})"
            if reference_preds is not None:
                ref_ret = unlog_transform(reference_preds)[flat_idx].item()
                line += f" | Delta vs Full: {(pred_ret - ref_ret) * 100:+7.2f}%"
            print(line)
    print("=============================================\n")
def resolve_sample_index(dataset, sample_idx_arg, rng):
    """Map the --sample_idx argument to a concrete dataset index.

    A non-numeric string is treated as a mint address and looked up in
    dataset.sampled_mints; a numeric value is validated as an index; None
    selects a random index via *rng*.

    Raises:
        ValueError: unknown mint address or out-of-range index.
    """
    if sample_idx_arg is None:
        return rng.randint(0, len(dataset.sampled_mints) - 1)
    if isinstance(sample_idx_arg, str) and not sample_idx_arg.isdigit():
        for idx, mint in enumerate(dataset.sampled_mints):
            if mint['mint_address'] == sample_idx_arg:
                return idx
        raise ValueError(f"Mint address {sample_idx_arg} not found in filtered dataset")
    resolved = int(sample_idx_arg)
    if resolved >= len(dataset):
        raise ValueError(f"Sample index {resolved} out of range")
    return resolved
def move_batch_to_device(batch, device):
    """Move all tensors (and lists of tensors) in *batch* to *device*, in place.

    Also back-fills empty textual-event placeholders when the collator did
    not produce them, so downstream code can rely on the keys existing.
    """
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
        elif isinstance(value, list) and value and isinstance(value[0], torch.Tensor):
            batch[key] = [item.to(device) for item in value]
    if 'textual_event_indices' not in batch:
        # Shape the placeholder after the event-type grid.
        batch_size, seq_len = batch['event_type_ids'].shape
        batch['textual_event_indices'] = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
    batch.setdefault('textual_event_data', [])
    return batch
def init_aggregate(horizons_seconds, quantiles):
    """Create a fresh accumulator for per-(horizon, quantile) statistics."""
    per_hq = {}
    for horizon in horizons_seconds:
        for quantile in quantiles:
            per_hq[(horizon, quantile)] = {
                "full_sum": 0.0,
                "abl_sum": 0.0,
                "delta_sum": 0.0,
                "abs_delta_sum": 0.0,
                "gt_sum": 0.0,
                "valid_count": 0,
            }
    return {
        "count": 0,
        "quality_full_sum": 0.0,
        "quality_abl_sum": 0.0,
        "quality_delta_sum": 0.0,
        "gt_quality_sum": 0.0,
        "per_hq": per_hq,
    }
def update_aggregate(stats, full_preds, gt_labels, gt_mask, gt_quality, horizons_seconds, quantiles, ablated_preds=None, full_quality=None, ablated_quality=None):
    """Fold one sample's predictions into the accumulator created by init_aggregate.

    Ablation sums/deltas are only accumulated when *ablated_preds* (and the
    quality counterparts) are supplied; ground-truth sums only for horizons
    whose mask is set.
    """
    stats["count"] += 1
    if gt_quality is not None:
        stats["gt_quality_sum"] += float(gt_quality)
    if full_quality is not None:
        stats["quality_full_sum"] += float(full_quality.item())
    if ablated_quality is not None:
        stats["quality_abl_sum"] += float(ablated_quality.item())
        if full_quality is not None:
            stats["quality_delta_sum"] += float(ablated_quality.item() - full_quality.item())
    # Work in raw-return space for all per-cell sums.
    real_full = unlog_transform(full_preds)
    real_ablated = None if ablated_preds is None else unlog_transform(ablated_preds)
    q_count = len(quantiles)
    for h_idx, horizon in enumerate(horizons_seconds):
        has_gt = h_idx < len(gt_mask) and bool(gt_mask[h_idx].item())
        gt_value = float(gt_labels[h_idx].item()) if has_gt else math.nan
        for q_idx, quantile in enumerate(quantiles):
            flat = h_idx * q_count + q_idx
            cell = stats["per_hq"][(horizon, quantile)]
            full_value = float(real_full[flat].item())
            cell["full_sum"] += full_value
            if real_ablated is not None:
                ablated_value = float(real_ablated[flat].item())
                shift = ablated_value - full_value
                cell["abl_sum"] += ablated_value
                cell["delta_sum"] += shift
                cell["abs_delta_sum"] += abs(shift)
            if has_gt:
                cell["gt_sum"] += gt_value
                cell["valid_count"] += 1
def print_aggregate_summary(stats, horizons_seconds, quantiles, ablation_mode):
    """Print mean statistics accumulated in *stats* across all evaluated samples.

    Ablation columns (ablated mean, delta, mean|delta|) are only printed when
    *ablation_mode* is not "none". Returns early when no samples were
    collected.
    """
    n = stats["count"]
    print("\n================== Aggregate Summary ==================")
    print(f"Evaluated Samples: {n}")
    if n == 0:
        print("No valid samples collected.")
        print("=======================================================\n")
        return
    if ablation_mode != "none":
        print(
            f"Quality Mean: full={stats['quality_full_sum'] / n:.6f} | "
            f"ablated={stats['quality_abl_sum'] / n:.6f} | "
            f"delta={stats['quality_delta_sum'] / n:+.6f}"
        )
    for horizon in horizons_seconds:
        horizon_min = horizon // 60
        print(f"\n--- Horizon: {horizon}s ({horizon_min}m) ---")
        # All quantiles of one horizon share the same label mask, so take the
        # max valid_count as the number of labeled samples for this horizon.
        valid_counts = [stats["per_hq"][(horizon, q)]["valid_count"] for q in quantiles]
        valid_count = max(valid_counts) if valid_counts else 0
        if valid_count > 0:
            gt_mean = stats["per_hq"][(horizon, quantiles[0])]["gt_sum"] / valid_count
            print(f" Mean Ground Truth over valid labels: {gt_mean * 100:.2f}% (n={valid_count})")
        else:
            print(" Mean Ground Truth over valid labels: N/A")
        for q in quantiles:
            bucket = stats["per_hq"][(horizon, q)]
            full_mean = bucket["full_sum"] / n
            line = f" p{int(q*100):02d} mean full: {full_mean * 100:>8.2f}%"
            if ablation_mode != "none":
                abl_mean = bucket["abl_sum"] / n
                delta_mean = bucket["delta_sum"] / n
                abs_delta_mean = bucket["abs_delta_sum"] / n
                line += (
                    f" | ablated: {abl_mean * 100:>8.2f}%"
                    f" | delta: {delta_mean * 100:+8.2f}%"
                    f" | mean|delta|: {abs_delta_mean * 100:>8.2f}%"
                )
            print(line)
    print("=======================================================\n")
def summarize_influence_score(stats, horizons_seconds, quantiles):
    """Mean per-sample absolute ablation delta, averaged over all (horizon, quantile) cells.

    Returns 0.0 when no samples were accumulated.
    """
    n = stats["count"]
    if n == 0:
        return 0.0
    cell_means = [
        stats["per_hq"][(horizon, quantile)]["abs_delta_sum"] / n
        for horizon in horizons_seconds
        for quantile in quantiles
    ]
    return sum(cell_means) / max(len(cell_means), 1)
def print_probe_summary(mode_to_stats, horizons_seconds, quantiles):
    """Rank the OHLC probes by influence score, then print each probe's full summary."""
    scored = [
        (mode, summarize_influence_score(mode_to_stats[mode], horizons_seconds, quantiles))
        for mode in OHLC_PROBE_MODES
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    print("\n================== OHLC Probe Ranking ==================")
    for rank, (mode, score) in enumerate(scored, start=1):
        print(f"{rank:>2}. {mode:<20} mean|delta| = {score * 100:8.2f}%")
    print("========================================================\n")
    # Detailed per-probe tables, in ranked order.
    for mode, _ in scored:
        print_aggregate_summary(mode_to_stats[mode], horizons_seconds, quantiles, mode)
def get_latest_checkpoint(checkpoint_dir):
    """Return the most recently modified subdirectory of *checkpoint_dir*, or None.

    None is returned when the directory does not exist or holds no
    subdirectories.
    """
    root = Path(checkpoint_dir)
    if not root.exists():
        return None
    subdirs = [entry for entry in root.iterdir() if entry.is_dir()]
    if not subdirs:
        return None
    # Stable sort by mtime; the last entry wins ties, matching iteration order.
    subdirs.sort(key=lambda entry: entry.stat().st_mtime)
    return str(subdirs[-1])
def main():
    """Entry point: evaluate a checkpointed Oracle model on live-fetched samples.

    Flow: connect to ClickHouse/Neo4j, build the live dataset, filter the
    token universe by return class, construct the encoder/model stack, load a
    checkpoint, then repeatedly sample token contexts and run full (and
    optionally ablated / OHLC-probed) inference, printing per-sample details
    and aggregate summaries.
    """
    load_dotenv()
    args = parse_args()
    rng = random.Random(args.seed)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
    accelerator = Accelerator(mixed_precision=args.mixed_precision)
    device = accelerator.device
    # Match module initialization dtype to the requested mixed-precision mode.
    init_dtype = torch.float32
    if accelerator.mixed_precision == 'bf16':
        init_dtype = torch.bfloat16
    elif accelerator.mixed_precision == 'fp16':
        init_dtype = torch.float16
    print("INFO: Initializing DB Connections for LIVE evaluation...")
    clickhouse_host = os.getenv("CLICKHOUSE_HOST", "localhost")
    clickhouse_port = int(os.getenv("CLICKHOUSE_PORT", 9000))
    neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    neo4j_user = os.getenv("NEO4J_USER", "neo4j")
    neo4j_password = os.getenv("NEO4J_PASSWORD", "password")
    clickhouse_client = ClickHouseClient(host=clickhouse_host, port=clickhouse_port)
    neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
    data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
    print(f"Loading live dataset generator...")
    # We inject the data fetcher directly. No cache directories are used.
    dataset = OracleDataset(
        data_fetcher=data_fetcher,
        fetcher_config=None,
        horizons_seconds=args.horizons_seconds,
        quantiles=args.quantiles,
        cache_dir=None
    )
    # Filter out manipulated/broken tokens and optionally enforce min_class
    from models.vocabulary import MANIPULATED_CLASS_ID
    print("INFO: Fetching Return Classification Map...")
    return_class_map, _ = get_return_class_map(clickhouse_client)
    min_class_thresh = args.min_class if args.min_class is not None else 0
    original_len = len(dataset.sampled_mints)
    # Keep only tokens with a known return class that is not manipulated and
    # meets the minimum class threshold.
    dataset.sampled_mints = [
        m for m in dataset.sampled_mints
        if return_class_map.get(m['mint_address']) is not None
        and return_class_map.get(m['mint_address']) != MANIPULATED_CLASS_ID
        and return_class_map.get(m['mint_address']) >= min_class_thresh
    ]
    dataset.num_samples = len(dataset.sampled_mints)
    print(f"INFO: Filtered tokens. {original_len} -> {len(dataset.sampled_mints)} valid tokens (class >= {min_class_thresh}).")
    if len(dataset) == 0:
        raise ValueError("Dataset is empty. Are ClickHouse data and trade pipelines populated? (Check if min_return filtered everything out)")
    # Initialize encoders and model FIRST because we need multi_modal_encoder to compile context
    print("Initializing encoders...")
    multi_modal_encoder = MultiModalEncoder(dtype=init_dtype, device=device)
    time_encoder = ContextualTimeEncoder(dtype=init_dtype)
    token_encoder = TokenEncoder(multi_dim=multi_modal_encoder.embedding_dim, dtype=init_dtype)
    wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
    graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
    ohlc_embedder = OHLCEmbedder(num_intervals=vocab.NUM_OHLC_INTERVALS, dtype=init_dtype)
    quant_ohlc_embedder = QuantOHLCEmbedder(
        num_features=NUM_QUANT_OHLC_FEATURES,
        sequence_length=TOKENS_PER_SEGMENT,
        dtype=init_dtype,
    )
    collator = MemecoinCollator(
        event_type_to_id=vocab.EVENT_TO_ID,
        device=device,
        dtype=init_dtype,
        max_seq_len=4096
    )
    print("Initializing model...")
    model = Oracle(
        token_encoder=token_encoder,
        wallet_encoder=wallet_encoder,
        graph_updater=graph_updater,
        ohlc_embedder=ohlc_embedder,
        quant_ohlc_embedder=quant_ohlc_embedder,
        time_encoder=time_encoder,
        num_event_types=vocab.NUM_EVENT_TYPES,
        multi_modal_dim=multi_modal_encoder.embedding_dim,
        event_pad_id=vocab.EVENT_TO_ID["__PAD__"],
        event_type_to_id=vocab.EVENT_TO_ID,
        model_config_name="llama3-12l-768d-gqa4-8k-random",
        quantiles=args.quantiles,
        horizons_seconds=args.horizons_seconds,
        dtype=init_dtype
    )
    # Drop the backbone's token-embedding table if present (inputs are
    # pre-embedded, so it is dead weight here).
    if hasattr(model.model, 'embed_tokens'):
        del model.model.embed_tokens
    # Load checkpoint
    ckpt_path = args.checkpoint
    # "latest" resolves to the most recently modified checkpoint subdirectory.
    if ckpt_path.endswith("latest"):
        base_dir = Path(ckpt_path).parent
        found = get_latest_checkpoint(base_dir)
        if found:
            ckpt_path = found
    if not os.path.exists(ckpt_path):
        print(f"Warning: Checkpoint {ckpt_path} not found. Running with random weights!")
        model = accelerator.prepare(model)
    else:
        print(f"Loading checkpoint from {ckpt_path}...")
        model = accelerator.prepare(model)
        try:
            accelerator.load_state(ckpt_path)
            print("Successfully loaded accelerator state.")
        except Exception as e:
            # Fall back to loading raw model weights from the checkpoint dir.
            print(f"Could not load using accelerate.load_state: {e}")
            print("Trying to load model weights directly...")
            model_file = os.path.join(ckpt_path, "pytorch_model.bin")
            if not os.path.exists(model_file):
                model_file = os.path.join(ckpt_path, "model.safetensors")
            if os.path.exists(model_file):
                if model_file.endswith(".safetensors"):
                    from safetensors.torch import load_file
                    state_dict = load_file(model_file)
                else:
                    state_dict = torch.load(model_file, map_location="cpu")
                uw_model = accelerator.unwrap_model(model)
                # strict=False tolerates missing/extra keys (e.g. the deleted
                # embed_tokens table).
                uw_model.load_state_dict(state_dict, strict=False)
                print("Successfully loaded weights directly.")
            else:
                print(f"Error: model weights not found in {ckpt_path}")
    model.eval()
    stats = init_aggregate(args.horizons_seconds, args.quantiles)
    # Which ablation families run alongside the full pass: none, a single
    # mode, or the whole sweep list; ohlc_probe uses its own probe list.
    selected_modes = [] if args.ablation == "none" else (ABLATION_SWEEP_MODES if args.ablation == "sweep" else ([] if args.ablation == "ohlc_probe" else [args.ablation]))
    mode_to_stats = {mode: init_aggregate(args.horizons_seconds, args.quantiles) for mode in selected_modes}
    probe_to_stats = {mode: init_aggregate(args.horizons_seconds, args.quantiles) for mode in OHLC_PROBE_MODES} if args.ablation == "ohlc_probe" else {}
    max_target_samples = max(1, args.num_samples)
    retries = 0
    collected = 0
    seen_indices = set()
    # Sample contexts until enough valid samples are collected or retries run out.
    while collected < max_target_samples and retries < args.max_retries:
        sample_idx = resolve_sample_index(dataset, args.sample_idx, rng)
        # Skip already-seen random indices while unseen ones remain.
        if args.sample_idx is None and sample_idx in seen_indices and len(seen_indices) < len(dataset.sampled_mints):
            retries += 1
            continue
        seen_indices.add(sample_idx)
        sample_mint_addr = dataset.sampled_mints[sample_idx]['mint_address']
        print(f"Trying Token Address: {sample_mint_addr}")
        contexts = dataset.__cacheitem_context__(
            sample_idx,
            num_samples_per_token=1,
            encoder=multi_modal_encoder,
            forced_cutoff_trade_idx=args.cutoff_trade_idx,
        )
        if not contexts or contexts[0] is None:
            print(" [Failed to generate valid context pattern, skipping...]")
            retries += 1
            if args.sample_idx is not None:
                print("Specific sample requested but failed to generate context. Exiting.")
                return
            continue
        raw_sample = contexts[0]
        batch = move_batch_to_device(collator([raw_sample]), device)
        gt_labels = batch["labels"][0].cpu()
        gt_mask = batch["labels_mask"][0].cpu().bool()
        gt_quality = batch["quality_score"][0].item() if "quality_score" in batch else None
        if collected == 0 or args.show_each:
            print(f"\nEvaluating sample {collected + 1}/{max_target_samples} on Token Address: {sample_mint_addr}")
            print("\n--- Running Inference ---")
        # Full (non-ablated) pass is always run; it is the reference.
        full_preds, full_quality, full_direction = run_inference(model, batch)
        ablation_outputs = {}
        for mode in selected_modes:
            ablated_batch = apply_ablation(batch, mode, device)
            ablated_preds, ablated_quality, ablated_direction = run_inference(model, ablated_batch)
            ablation_outputs[mode] = (ablated_batch, ablated_preds, ablated_quality, ablated_direction)
        probe_outputs = {}
        if args.ablation == "ohlc_probe":
            for mode in OHLC_PROBE_MODES:
                probe_batch = apply_ohlc_probe(batch, mode)
                probe_preds, probe_quality, probe_direction = run_inference(model, probe_batch)
                probe_outputs[mode] = (probe_batch, probe_preds, probe_quality, probe_direction)
        # Per-sample detail printing (first sample, or every sample with --show_each).
        if collected == 0 or args.show_each:
            print_results(
                title="Full Results",
                batch=batch,
                preds=full_preds,
                quality_pred=full_quality,
                movement_pred=full_direction,
                gt_labels=gt_labels,
                gt_mask=gt_mask,
                gt_quality=gt_quality,
                horizons_seconds=args.horizons_seconds,
                quantiles=args.quantiles,
            )
            if args.ablation != "none":
                if args.ablation == "sweep":
                    print(f"Collected full predictions for {len(selected_modes)} ablation families on this sample. Aggregate ranking will be printed at the end.")
                elif args.ablation == "ohlc_probe":
                    for mode in OHLC_PROBE_MODES:
                        probe_batch, probe_preds, probe_quality, probe_direction = probe_outputs[mode]
                        print_results(
                            title=f"OHLC Probe ({mode})",
                            batch=probe_batch,
                            preds=probe_preds,
                            quality_pred=probe_quality,
                            movement_pred=probe_direction,
                            gt_labels=gt_labels,
                            gt_mask=gt_mask,
                            gt_quality=gt_quality,
                            horizons_seconds=args.horizons_seconds,
                            quantiles=args.quantiles,
                            reference_preds=full_preds,
                            reference_quality=full_quality,
                        )
                else:
                    ablated_batch, ablated_preds, ablated_quality, ablated_direction = ablation_outputs[args.ablation]
                    print_results(
                        title=f"Ablation Results ({args.ablation})",
                        batch=ablated_batch,
                        preds=ablated_preds,
                        quality_pred=ablated_quality,
                        movement_pred=ablated_direction,
                        gt_labels=gt_labels,
                        gt_mask=gt_mask,
                        gt_quality=gt_quality,
                        horizons_seconds=args.horizons_seconds,
                        quantiles=args.quantiles,
                        reference_preds=full_preds,
                        reference_quality=full_quality,
                    )
        # Accumulate aggregates: full pass, then each ablation/probe mode.
        update_aggregate(
            stats=stats,
            full_preds=full_preds,
            gt_labels=gt_labels,
            gt_mask=gt_mask,
            gt_quality=gt_quality,
            horizons_seconds=args.horizons_seconds,
            quantiles=args.quantiles,
            full_quality=full_quality,
        )
        for mode, (_, ablated_preds, ablated_quality, _) in ablation_outputs.items():
            update_aggregate(
                stats=mode_to_stats[mode],
                full_preds=full_preds,
                gt_labels=gt_labels,
                gt_mask=gt_mask,
                gt_quality=gt_quality,
                horizons_seconds=args.horizons_seconds,
                quantiles=args.quantiles,
                ablated_preds=ablated_preds,
                full_quality=full_quality,
                ablated_quality=ablated_quality,
            )
        for mode, (_, probe_preds, probe_quality, _) in probe_outputs.items():
            update_aggregate(
                stats=probe_to_stats[mode],
                full_preds=full_preds,
                gt_labels=gt_labels,
                gt_mask=gt_mask,
                gt_quality=gt_quality,
                horizons_seconds=args.horizons_seconds,
                quantiles=args.quantiles,
                ablated_preds=probe_preds,
                full_quality=full_quality,
                ablated_quality=probe_quality,
            )
        collected += 1
        retries += 1
        # A specific sample was requested; evaluate it once and stop.
        if args.sample_idx is not None:
            break
    if collected == 0:
        print(f"Could not find a valid context after {args.max_retries} attempts.")
        return
    if collected < max_target_samples:
        print(f"WARNING: Requested {max_target_samples} samples but only evaluated {collected}.")
    # Final reporting, by ablation mode.
    if args.ablation == "none":
        print_aggregate_summary(stats, args.horizons_seconds, args.quantiles, args.ablation)
        return
    if args.ablation == "ohlc_probe":
        print_probe_summary(probe_to_stats, args.horizons_seconds, args.quantiles)
        return
    if args.ablation == "sweep":
        # Rank the swept families by mean absolute delta vs the full pass.
        rankings = []
        for mode in selected_modes:
            score = summarize_influence_score(mode_to_stats[mode], args.horizons_seconds, args.quantiles)
            rankings.append((mode, score))
        rankings.sort(key=lambda x: x[1], reverse=True)
        print("\n================== Influence Ranking ==================")
        for rank, (mode, score) in enumerate(rankings, start=1):
            print(f"{rank:>2}. {mode:<12} mean|delta| = {score * 100:8.2f}%")
        print("=======================================================\n")
        for mode, _ in rankings:
            print_aggregate_summary(mode_to_stats[mode], args.horizons_seconds, args.quantiles, mode)
    else:
        print_aggregate_summary(mode_to_stats[args.ablation], args.horizons_seconds, args.quantiles, args.ablation)
# Script entry point: python scripts/evaluate_sample.py [options]
if __name__ == "__main__":
    main()