# oracle/scripts/inspect_collator.py
# Uploaded by zirobtc ("Upload folder using huggingface_hub"), revision 86523f8.
import argparse
import datetime
import json
import os
import sys
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List

import torch
# Put the repository root (two levels above this resolved file) at the front of
# sys.path so the project-local ``data`` and ``models`` imports below resolve
# when this file is executed directly as a script.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, os.fspath(REPO_ROOT))
from data.data_loader import OracleDataset
from data.data_fetcher import DataFetcher
from data.data_collator import MemecoinCollator
import models.vocabulary as vocab
def _decode_events(event_type_ids: torch.Tensor) -> List[str]:
    """Translate event-type ids into human-readable event names.

    Expects a 1-D tensor (every element must be a scalar id). Id ``0`` is
    rendered as ``"__PAD__"``; ids absent from ``vocab.ID_TO_EVENT`` are
    rendered as ``"UNK_<id>"``.
    """
    return [
        "__PAD__" if event_id == 0 else vocab.ID_TO_EVENT.get(event_id, f"UNK_{event_id}")
        for event_id in event_type_ids.tolist()
    ]
def _tensor_to_list(t: torch.Tensor) -> List:
    """Return the tensor's values as (nested) Python lists.

    The tensor is detached from autograd and copied to the CPU first. A 0-d
    tensor yields a plain Python scalar rather than a list.
    """
    host_copy = t.detach().to("cpu")
    return host_copy.tolist()
# Pointer/index tensors copied verbatim into the dump when the collator emits them.
_POINTER_KEYS = (
    "wallet_indices",
    "token_indices",
    "quote_token_indices",
    "trending_token_indices",
    "boosted_token_indices",
    "dest_wallet_indices",
    "original_author_indices",
    "ohlc_indices",
    "holder_snapshot_indices",
    "textual_event_indices",
)

# Numerical feature tensors; dumped in full and counted in ``nonzero_summary``.
_NUMERICAL_FEATURE_KEYS = (
    "transfer_numerical_features",
    "trade_numerical_features",
    "deployer_trade_numerical_features",
    "smart_wallet_trade_numerical_features",
    "pool_created_numerical_features",
    "liquidity_change_numerical_features",
    "fee_collected_numerical_features",
    "token_burn_numerical_features",
    "supply_lock_numerical_features",
    "onchain_snapshot_numerical_features",
    "trending_token_numerical_features",
    "boosted_token_numerical_features",
    "dexboost_paid_numerical_features",
    "dexprofile_updated_flags",
    "global_trending_numerical_features",
    "chainsnapshot_numerical_features",
    "lighthousesnapshot_numerical_features",
)

# Categorical id tensors; dumped in full and counted in ``nonzero_summary``.
_CATEGORICAL_FEATURE_KEYS = (
    "trade_dex_ids",
    "trade_direction_ids",
    "trade_mev_protection_ids",
    "trade_is_bundle_ids",
    "pool_created_protocol_ids",
    "liquidity_change_type_ids",
    "trending_token_source_ids",
    "trending_token_timeframe_ids",
    "lighthousesnapshot_protocol_ids",
    "lighthousesnapshot_timeframe_ids",
    "migrated_protocol_ids",
    "alpha_group_ids",
    "channel_ids",
    "exchange_ids",
)

# Optional tensors from ``token_encoder_inputs`` exported under the same key.
_TOKEN_ENCODER_TENSOR_KEYS = (
    "protocol_ids",
    "is_vanity_flags",
    "name_embed_indices",
    "symbol_embed_indices",
    "image_embed_indices",
)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of this inspection script."""
    parser = argparse.ArgumentParser(description="Inspect MemecoinCollator outputs on cached samples.")
    parser.add_argument("--cache_dir", type=str, default="data/cache")
    parser.add_argument("--idx", type=int, nargs="+", default=[0], help="Sample indices to inspect")
    parser.add_argument("--max_seq_len", type=int, default=16000)
    parser.add_argument("--out", type=str, default="collator_dump.json")
    return parser.parse_args()


def _connect_databases():
    """Open the ClickHouse client and Neo4j driver configured by the environment.

    Variables are read after ``load_dotenv()`` (so a ``.env`` file may supply
    them): CLICKHOUSE_HOST (default ``localhost``), CLICKHOUSE_NATIVE_PORT or
    else CLICKHOUSE_PORT (default 9000), NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD.

    Returns:
        ``(clickhouse_client, neo4j_driver)``; the caller must close both.
    """
    # Imported here so these third-party packages are only needed when the
    # script actually runs, not when the module is imported.
    from dotenv import load_dotenv
    from clickhouse_driver import Client as ClickHouseClient
    from neo4j import GraphDatabase

    load_dotenv()
    clickhouse_host = os.getenv("CLICKHOUSE_HOST", "localhost")
    clickhouse_port = int(os.getenv("CLICKHOUSE_NATIVE_PORT", os.getenv("CLICKHOUSE_PORT", 9000)))
    neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    neo4j_user = os.getenv("NEO4J_USER", "neo4j")
    neo4j_password = os.getenv("NEO4J_PASSWORD", "password")

    clickhouse_client = ClickHouseClient(host=clickhouse_host, port=clickhouse_port)
    neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
    return clickhouse_client, neo4j_driver


def _collate_samples(args: argparse.Namespace, data_fetcher: Any):
    """Load the samples selected by ``args.idx`` and run them through MemecoinCollator.

    Returns:
        ``(batch_items, batch)``: the raw dataset items in ``args.idx`` order and
        the collator's output for them.
    """
    dataset = OracleDataset(
        data_fetcher=data_fetcher,
        cache_dir=str(Path(args.cache_dir)),
        horizons_seconds=[30, 60, 120, 240, 420],
        quantiles=[0.1, 0.5, 0.9],
        max_samples=None,
        max_seq_len=args.max_seq_len,
    )
    # init_fetcher is treated as optional: call it only if this dataset provides it.
    if hasattr(dataset, "init_fetcher"):
        dataset.init_fetcher()
    collator = MemecoinCollator(
        event_type_to_id=vocab.EVENT_TO_ID,
        device=torch.device("cpu"),
        dtype=torch.float32,
        max_seq_len=args.max_seq_len,
    )
    batch_items = [dataset[i] for i in args.idx]
    return batch_items, collator(batch_items)


def _count_event_types(item: Dict[str, Any]) -> Dict[str, int]:
    """Count each ``event_type`` in a raw sample's ``event_sequence``.

    Events without an ``event_type`` are counted under ``"UNKNOWN"``.
    """
    return dict(Counter(ev.get("event_type", "UNKNOWN") for ev in item.get("event_sequence", [])))


def _optional_tensor_list(inputs: Dict[str, Any], key: str) -> List:
    """Return ``inputs[key]`` as a Python list, or ``[]`` when absent or ``None``."""
    value = inputs.get(key)
    return [] if value is None else _tensor_to_list(value)


def _build_dump(
    sample_indices: List[int],
    batch_items: List[Dict[str, Any]],
    batch: Dict[str, Any],
) -> Dict[str, Any]:
    """Assemble a JSON-friendly view of the raw samples and the collated batch.

    Event sequences are exported without truncation; embeddings themselves are
    omitted, only the indices into them are exported.
    """
    dump: Dict[str, Any] = {
        "batch_size": len(sample_indices),
        "token_addresses": batch.get("token_addresses"),
        "t_cutoffs": batch.get("t_cutoffs"),
        "sample_indices": batch.get("sample_indices"),
        "raw_events": [item.get("event_sequence", []) for item in batch_items],
        "raw_event_counts": [_count_event_types(item) for item in batch_items],
    }

    # Core sequence tensors (full length).
    event_type_ids = batch["event_type_ids"]
    dump["event_type_ids"] = _tensor_to_list(event_type_ids)
    # Iterating the tensor yields one row per batch item, decoded independently.
    dump["event_type_names"] = [_decode_events(row.cpu()) for row in event_type_ids]
    for key in ("timestamps_float", "relative_ts", "attention_mask"):
        dump[key] = _tensor_to_list(batch[key])
    dump["wallet_addr_to_batch_idx"] = batch.get("wallet_addr_to_batch_idx", {})

    for key in _POINTER_KEYS:
        if key in batch:
            dump[key] = _tensor_to_list(batch[key])

    nonzero_summary: Dict[str, int] = {}
    for key in _NUMERICAL_FEATURE_KEYS + _CATEGORICAL_FEATURE_KEYS:
        if key in batch:
            tensor = batch[key]
            dump[key] = _tensor_to_list(tensor)
            nonzero_summary[key] = int(torch.count_nonzero(tensor).item())

    for key in ("labels", "labels_mask", "quality_score"):
        if batch.get(key) is not None:
            dump[key] = _tensor_to_list(batch[key])
    dump["nonzero_summary"] = nonzero_summary

    # Raw wallet/token payloads consumed by the encoders. ``or {}`` also covers
    # a key that is present but set to None.
    wallet_inputs = batch.get("wallet_encoder_inputs") or {}
    token_inputs = batch.get("token_encoder_inputs") or {}
    holdings_batch = wallet_inputs.get("holdings_batch", [])
    lookup_addresses = token_inputs.get("_addresses_for_lookup", [])
    dump["wallet_encoder_inputs"] = {
        "profile_rows": wallet_inputs.get("profile_rows", []),
        "social_rows": wallet_inputs.get("social_rows", []),
        "holdings_batch": holdings_batch,
        "username_embed_indices": _optional_tensor_list(wallet_inputs, "username_embed_indices"),
    }
    dump["token_encoder_inputs"] = {
        "addresses_for_lookup": lookup_addresses,
        **{key: _optional_tensor_list(token_inputs, key) for key in _TOKEN_ENCODER_TENSOR_KEYS},
    }
    dump["wallet_set_encoder_inputs"] = {
        "holdings_batch": holdings_batch,
        "token_vibe_lookup_keys": lookup_addresses,
    }
    return dump


def _json_default(obj: Any) -> Any:
    """Fallback encoder for values ``json.dump`` cannot serialize natively.

    Tensors, and other objects exposing a callable ``tolist()`` (e.g. NumPy
    arrays and scalars), become plain Python lists/scalars so the dump keeps
    every value instead of a possibly abbreviated ``repr``. Dates and
    datetimes become ISO-8601 strings; anything else falls back to ``str``.
    """
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    if isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()
    to_list = getattr(obj, "tolist", None)
    if callable(to_list):
        return to_list()
    try:
        return str(obj)
    except Exception:
        return "<unserializable>"


def main() -> None:
    """Collate the requested cached samples and write a JSON dump of the batch.

    Opens ClickHouse/Neo4j connections (always; they are passed to
    ``DataFetcher``), builds an ``OracleDataset`` over ``--cache_dir``, collates
    the samples selected by ``--idx`` and writes selected collator outputs to
    ``--out``. The DB connections are closed before returning.
    """
    args = _parse_args()
    clickhouse_client, neo4j_driver = _connect_databases()
    try:
        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
        batch_items, batch = _collate_samples(args, data_fetcher)
        dump = _build_dump(args.idx, batch_items, batch)
        out_path = Path(args.out)
        with out_path.open("w", encoding="utf-8") as f:
            json.dump(dump, f, indent=2, default=_json_default)
    finally:
        # Release DB resources even when loading, collation or writing fails.
        neo4j_driver.close()
        clickhouse_client.disconnect()
    print(f"Wrote collator dump to {out_path.resolve()}")


if __name__ == "__main__":
    main()