| import argparse |
| import json |
| import os |
| import sys |
| from pathlib import Path |
| from typing import Any, Dict, List |
|
|
| import torch |
|
|
| |
| REPO_ROOT = Path(__file__).resolve().parents[1] |
| sys.path.insert(0, str(REPO_ROOT)) |
|
|
| from data.data_loader import OracleDataset |
| from data.data_fetcher import DataFetcher |
| from data.data_collator import MemecoinCollator |
| import models.vocabulary as vocab |
|
|
|
|
| def _decode_events(event_type_ids: torch.Tensor) -> List[str]: |
| names = [] |
| for eid in event_type_ids.tolist(): |
| if eid == 0: |
| names.append("__PAD__") |
| else: |
| names.append(vocab.ID_TO_EVENT.get(eid, f"UNK_{eid}")) |
| return names |
|
|
|
|
| def _tensor_to_list(t: torch.Tensor) -> List: |
| return t.detach().cpu().tolist() |
|
|
|
|
def main() -> None:
    """Inspect MemecoinCollator outputs for cached samples and dump them to JSON.

    Connects to ClickHouse and Neo4j using environment variables (optionally
    loaded from a ``.env`` file), builds an ``OracleDataset`` over
    ``--cache_dir``, runs the ``--idx`` samples through ``MemecoinCollator``,
    and writes every tensor / lookup structure the collator produced to the
    ``--out`` JSON file.

    Side effects: reads environment variables, opens database connections
    (closed on exit), and writes one JSON file.
    """
    parser = argparse.ArgumentParser(description="Inspect MemecoinCollator outputs on cached samples.")
    parser.add_argument("--cache_dir", type=str, default="data/cache")
    parser.add_argument("--idx", type=int, nargs="+", default=[0], help="Sample indices to inspect")
    parser.add_argument("--max_seq_len", type=int, default=16000)
    parser.add_argument("--out", type=str, default="collator_dump.json")
    args = parser.parse_args()

    cache_dir = Path(args.cache_dir)

    # Imported lazily so merely importing this module does not require the
    # optional database-client dependencies. (A redundant `import os` that
    # shadowed the module-level import was removed here.)
    from dotenv import load_dotenv
    from clickhouse_driver import Client as ClickHouseClient
    from neo4j import GraphDatabase

    load_dotenv()
    clickhouse_host = os.getenv("CLICKHOUSE_HOST", "localhost")
    # Prefer the native-protocol port; fall back to CLICKHOUSE_PORT, then 9000.
    clickhouse_port = int(os.getenv("CLICKHOUSE_NATIVE_PORT", os.getenv("CLICKHOUSE_PORT", 9000)))
    neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    neo4j_user = os.getenv("NEO4J_USER", "neo4j")
    neo4j_password = os.getenv("NEO4J_PASSWORD", "password")
    clickhouse_client = ClickHouseClient(host=clickhouse_host, port=clickhouse_port)
    neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
    try:
        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

        dataset = OracleDataset(
            data_fetcher=data_fetcher,
            cache_dir=str(cache_dir),
            horizons_seconds=[30, 60, 120, 240, 420],
            quantiles=[0.1, 0.5, 0.9],
            max_samples=None,
            max_seq_len=args.max_seq_len,
        )
        if hasattr(dataset, "init_fetcher"):
            dataset.init_fetcher()

        collator = MemecoinCollator(
            event_type_to_id=vocab.EVENT_TO_ID,
            device=torch.device("cpu"),
            dtype=torch.float32,
            max_seq_len=args.max_seq_len,
        )

        batch_items = [dataset[i] for i in args.idx]
        batch = collator(batch_items)

        # --- raw (pre-collation) views of the requested samples -------------
        dump: Dict[str, Any] = {
            "batch_size": len(args.idx),
            "token_addresses": batch.get("token_addresses"),
            "t_cutoffs": batch.get("t_cutoffs"),
            "sample_indices": batch.get("sample_indices"),
            "raw_events": [item.get("event_sequence", []) for item in batch_items],
        }
        # Per-sample histogram of raw event types, for a quick sanity check.
        event_counts = []
        for item in batch_items:
            counts: Dict[str, int] = {}
            for ev in item.get("event_sequence", []):
                et = ev.get("event_type", "UNKNOWN")
                counts[et] = counts.get(et, 0) + 1
            event_counts.append(counts)
        dump["raw_event_counts"] = event_counts

        # --- core collated sequence tensors ----------------------------------
        dump["event_type_ids"] = _tensor_to_list(batch["event_type_ids"])
        dump["event_type_names"] = [
            _decode_events(batch["event_type_ids"][i].cpu())
            for i in range(batch["event_type_ids"].shape[0])
        ]
        dump["timestamps_float"] = _tensor_to_list(batch["timestamps_float"])
        dump["relative_ts"] = _tensor_to_list(batch["relative_ts"])
        dump["attention_mask"] = _tensor_to_list(batch["attention_mask"])
        dump["wallet_addr_to_batch_idx"] = batch.get("wallet_addr_to_batch_idx", {})

        # --- optional index tensors (present only for some event mixes) ------
        for key in [
            "wallet_indices",
            "token_indices",
            "quote_token_indices",
            "trending_token_indices",
            "boosted_token_indices",
            "dest_wallet_indices",
            "original_author_indices",
            "ohlc_indices",
            "holder_snapshot_indices",
            "textual_event_indices",
        ]:
            if key in batch:
                dump[key] = _tensor_to_list(batch[key])

        # --- per-event-type numerical feature tensors, with a count of
        # non-zero entries so an all-zero feature block is easy to spot ------
        nonzero_summary = {}
        for key in [
            "transfer_numerical_features",
            "trade_numerical_features",
            "deployer_trade_numerical_features",
            "smart_wallet_trade_numerical_features",
            "pool_created_numerical_features",
            "liquidity_change_numerical_features",
            "fee_collected_numerical_features",
            "token_burn_numerical_features",
            "supply_lock_numerical_features",
            "onchain_snapshot_numerical_features",
            "trending_token_numerical_features",
            "boosted_token_numerical_features",
            "dexboost_paid_numerical_features",
            "dexprofile_updated_flags",
            "global_trending_numerical_features",
            "chainsnapshot_numerical_features",
            "lighthousesnapshot_numerical_features",
        ]:
            if key in batch:
                t = batch[key]
                dump[key] = _tensor_to_list(t)
                nonzero_summary[key] = int(torch.count_nonzero(t).item())

        # --- categorical id tensors ------------------------------------------
        for key in [
            "trade_dex_ids",
            "trade_direction_ids",
            "trade_mev_protection_ids",
            "trade_is_bundle_ids",
            "pool_created_protocol_ids",
            "liquidity_change_type_ids",
            "trending_token_source_ids",
            "trending_token_timeframe_ids",
            "lighthousesnapshot_protocol_ids",
            "lighthousesnapshot_timeframe_ids",
            "migrated_protocol_ids",
            "alpha_group_ids",
            "channel_ids",
            "exchange_ids",
        ]:
            if key in batch:
                t = batch[key]
                dump[key] = _tensor_to_list(t)
                nonzero_summary[key] = int(torch.count_nonzero(t).item())

        # --- supervision targets (present only for labeled samples) ----------
        if batch.get("labels") is not None:
            dump["labels"] = _tensor_to_list(batch["labels"])
        if batch.get("labels_mask") is not None:
            dump["labels_mask"] = _tensor_to_list(batch["labels_mask"])
        if batch.get("quality_score") is not None:
            dump["quality_score"] = _tensor_to_list(batch["quality_score"])

        dump["nonzero_summary"] = nonzero_summary

        # --- side inputs for the wallet / token / wallet-set encoders --------
        wallet_inputs = batch.get("wallet_encoder_inputs", {})
        token_inputs = batch.get("token_encoder_inputs", {})
        dump["wallet_encoder_inputs"] = {
            "profile_rows": wallet_inputs.get("profile_rows", []),
            "social_rows": wallet_inputs.get("social_rows", []),
            "holdings_batch": wallet_inputs.get("holdings_batch", []),
            "username_embed_indices": _tensor_to_list(wallet_inputs.get("username_embed_indices")) if "username_embed_indices" in wallet_inputs else [],
        }
        dump["token_encoder_inputs"] = {
            "addresses_for_lookup": token_inputs.get("_addresses_for_lookup", []),
            "protocol_ids": _tensor_to_list(token_inputs.get("protocol_ids")) if "protocol_ids" in token_inputs else [],
            "is_vanity_flags": _tensor_to_list(token_inputs.get("is_vanity_flags")) if "is_vanity_flags" in token_inputs else [],
            "name_embed_indices": _tensor_to_list(token_inputs.get("name_embed_indices")) if "name_embed_indices" in token_inputs else [],
            "symbol_embed_indices": _tensor_to_list(token_inputs.get("symbol_embed_indices")) if "symbol_embed_indices" in token_inputs else [],
            "image_embed_indices": _tensor_to_list(token_inputs.get("image_embed_indices")) if "image_embed_indices" in token_inputs else [],
        }
        dump["wallet_set_encoder_inputs"] = {
            "holdings_batch": wallet_inputs.get("holdings_batch", []),
            "token_vibe_lookup_keys": token_inputs.get("_addresses_for_lookup", []),
        }

        out_path = Path(args.out)

        def _json_default(o):
            """Best-effort JSON fallback: ISO-format datetimes, else str()."""
            if isinstance(o, (str, int, float, bool)) or o is None:
                return o
            try:
                import datetime as _dt
                if isinstance(o, (_dt.datetime, _dt.date)):
                    return o.isoformat()
            except Exception:
                pass
            try:
                return str(o)
            except Exception:
                return "<unserializable>"

        with out_path.open("w") as f:
            json.dump(dump, f, indent=2, default=_json_default)

        print(f"Wrote collator dump to {out_path.resolve()}")
    finally:
        # Release external connections even when dumping fails part-way;
        # the original leaked both the Neo4j driver and the ClickHouse client.
        neo4j_driver.close()
        clickhouse_client.disconnect()
|
|
|
|
# Script entry point: parse CLI args, run the collator dump, write JSON.
if __name__ == "__main__":
    main()
|
|