|
|
import argparse
import datetime
import json
import logging
import math
import multiprocessing as mp
import os
import sys
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import huggingface_hub
import torch
from dotenv import load_dotenv
from tqdm import tqdm
|
|
# Quiet chatty third-party loggers so progress output stays readable.
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)


# Put the parent of this script's directory on sys.path so the project-local
# `scripts.*` and `data.*` imports below resolve. This must run before those imports.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| from scripts.analyze_distribution import get_return_class_map |
| from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS |
| from data.data_loader import summarize_context_window |
|
|
| from clickhouse_driver import Client as ClickHouseClient |
| from neo4j import GraphDatabase |
|
|
# Per-process state, filled in by _init_worker() and read by
# _process_single_token_context(). Each worker process (or the main process in
# single-worker mode) holds its own copies.
_worker_dataset = None             # OracleDataset instance for this process
_worker_return_class_map = None    # mint address -> return class id
_worker_quality_scores_map = None  # mint address -> quality score
|
|
|
|
| def _build_context_quota_plan( |
| class_ids, |
| target_contexts_per_class, |
| target_contexts_total, |
| good_ratio_nonzero, |
| good_ratio_class0, |
| ): |
| unique_class_ids = sorted(set(int(cid) for cid in class_ids)) |
| if not unique_class_ids: |
| return {} |
|
|
| if target_contexts_per_class is not None: |
| per_class_target = int(target_contexts_per_class) |
| elif target_contexts_total is not None: |
| per_class_target = max(1, int(target_contexts_total) // len(unique_class_ids)) |
| else: |
| return {} |
|
|
| if per_class_target <= 0: |
| raise RuntimeError("Context quota target must be positive.") |
|
|
| plan = {} |
| for class_id in unique_class_ids: |
| ratio = float(good_ratio_class0 if class_id == 0 else good_ratio_nonzero) |
| ratio = max(0.0, min(1.0, ratio)) |
| good_target = int(round(per_class_target * ratio)) |
| bad_target = per_class_target - good_target |
| plan[class_id] = { |
| "total_target": per_class_target, |
| "good_target": good_target, |
| "bad_target": bad_target, |
| } |
| return plan |
|
|
|
|
| def _should_accept_context(class_id, context_bucket, accepted_counts, quota_plan): |
| if not quota_plan: |
| return True |
|
|
| if class_id not in quota_plan: |
| return False |
|
|
| class_plan = quota_plan[class_id] |
| class_counts = accepted_counts[class_id] |
| if class_counts["total"] >= class_plan["total_target"]: |
| return False |
|
|
| bucket_key = "good" if context_bucket == "good" else "bad" |
| target_key = f"{bucket_key}_target" |
| if class_counts[bucket_key] >= class_plan[target_key]: |
| return False |
|
|
| return True |
|
|
|
|
def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
    """Prepare this process's DB clients, OracleDataset and lookup maps.

    Used as the ProcessPoolExecutor initializer and called directly in single-worker
    mode. Results are stored in the module-level ``_worker_*`` globals that
    ``_process_single_token_context`` reads.

    Args:
        db_config: Dict with clickhouse_host/port and neo4j_uri/user/password.
        dataset_config: Dict of OracleDataset constructor kwargs plus 'sampled_mints',
            which replaces the dataset's own ``sampled_mints`` list.
        return_class_map: Mint address -> return class id.
        quality_scores_map: Mint address -> quality score.
    """
    global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
    from data.data_loader import OracleDataset
    from data.data_fetcher import DataFetcher

    # NOTE(review): these connections are never closed explicitly; they live as long
    # as the process does.
    fetcher = DataFetcher(
        clickhouse_client=ClickHouseClient(
            host=db_config['clickhouse_host'],
            port=db_config['clickhouse_port'],
        ),
        neo4j_driver=GraphDatabase.driver(
            db_config['neo4j_uri'],
            auth=(db_config['neo4j_user'], db_config['neo4j_password']),
        ),
    )

    _worker_dataset = dataset = OracleDataset(
        data_fetcher=fetcher,
        min_trades=dataset_config['min_trades'],
        start_date=dataset_config['start_date'],
        horizons_seconds=dataset_config['horizons_seconds'],
        quantiles=dataset_config['quantiles'],
        min_trade_usd=dataset_config['min_trade_usd'],
        max_seq_len=dataset_config['max_seq_len'],
    )
    # Task indices refer to positions in the parent's filtered mint list, so the
    # worker dataset must use exactly that list.
    dataset.sampled_mints = dataset_config['sampled_mints']
    _worker_return_class_map = return_class_map
    _worker_quality_scores_map = quality_scores_map
|
|
|
|
def _process_single_token_context(args):
    """Build cache contexts for one token using this process's worker state.

    Args:
        args: Tuple ``(idx, mint_addr, samples_per_token, output_dir)``. ``idx`` is the
            token's position in the worker dataset's ``sampled_mints``. ``output_dir``
            is accepted for interface compatibility but unused: the caller writes files.

    Returns:
        A dict whose 'status' is one of:
          * 'success': includes class_id, q_score, n_contexts, n_events (length of
            the first context's 'event_sequence') and the raw 'contexts' list.
          * 'skipped': includes a human-readable 'reason'.
          * 'error': includes 'error' (str) and 'traceback'. Exceptions are caught
            here so one bad token does not kill a pool worker.
    """
    idx, mint_addr, samples_per_token, _output_dir = args
    try:
        class_id = _worker_return_class_map.get(mint_addr)
        if class_id is None:
            return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}

        # Do the cheap in-memory lookup before the expensive context build. Tokens
        # without a quality score are discarded anyway, so building their contexts
        # (which presumably queries the databases) would be wasted work.
        q_score = _worker_quality_scores_map.get(mint_addr)
        if q_score is None:
            return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}

        contexts = _worker_dataset.__cacheitem_context__(idx, num_samples_per_token=samples_per_token)
        if not contexts:
            return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}

        return {
            'status': 'success',
            'mint': mint_addr,
            'class_id': class_id,
            'q_score': q_score,
            'n_contexts': len(contexts),
            'n_events': len(contexts[0].get('event_sequence', [])),
            'contexts': contexts,
        }
    except Exception as e:
        import traceback
        return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
|
|
|
|
|
|
|
|
# Prediction horizons and quantiles shared by the parent-process dataset and the
# worker datasets; defined once so the two configurations cannot drift apart.
_HORIZONS_SECONDS = (60, 180, 300, 600, 1800, 3600, 7200)
_QUANTILES = (0.5,)


class _CacheMetadata:
    """Accumulates the per-file and per-class bookkeeping written to class_metadata.json."""

    def __init__(self):
        self.file_class_map = {}            # filename -> class_id
        self.file_context_bucket_map = {}   # filename -> context bucket
        self.file_context_summary_map = {}  # filename -> summarize_context_window() output
        self.class_distribution = {}        # class_id -> number of saved files
        # class_id -> bucket -> number of saved files
        self.context_distribution = defaultdict(lambda: defaultdict(int))

    def record(self, filename, class_id, context_summary):
        """Register one saved context file under its class and context bucket."""
        bucket = context_summary["context_bucket"]
        self.file_class_map[filename] = class_id
        self.file_context_bucket_map[filename] = bucket
        self.file_context_summary_map[filename] = context_summary
        self.class_distribution[class_id] = self.class_distribution.get(class_id, 0) + 1
        self.context_distribution[class_id][bucket] += 1


def _tag_and_save_context(ctx, context_summary, *, class_id, mint_addr, q_score, file_idx, output_dir, metadata):
    """Attach cache fields to ``ctx``, torch.save it into output_dir, and record it.

    Returns the written filename, ``sample_<first 16 chars of mint>_<file_idx>.pt``.
    """
    ctx["quality_score"] = q_score
    ctx["class_id"] = class_id
    ctx["source_token"] = mint_addr
    ctx["context_bucket"] = context_summary["context_bucket"]
    ctx["context_score"] = context_summary["context_score"]
    filename = f"sample_{mint_addr[:16]}_{file_idx}.pt"
    torch.save(ctx, Path(output_dir) / filename)
    metadata.record(filename, class_id, context_summary)
    return filename


def _store_token_result(result, output_dir, metadata, quota_plan, accepted_counts, accepted_per_token):
    """Persist the contexts of one successful worker result.

    Without a quota plan, every context is saved and numbered by its position in
    the result. With a quota plan, a context is saved only while its class/bucket
    quota has room, and accepted contexts are numbered consecutively per token.

    Returns:
        True if the token counts as a success: always without a quota plan,
        otherwise only when at least one context was accepted.
    """
    class_id = result['class_id']
    mint_addr = result['mint']
    q_score = result['q_score']
    contexts = result.get("contexts", [])

    if not quota_plan:
        for ctx_idx, ctx in enumerate(contexts):
            summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
            _tag_and_save_context(
                ctx, summary, class_id=class_id, mint_addr=mint_addr, q_score=q_score,
                file_idx=ctx_idx, output_dir=output_dir, metadata=metadata,
            )
        return True

    saved_any = False
    for ctx in contexts:
        summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
        bucket = summary["context_bucket"]
        if not _should_accept_context(class_id, bucket, accepted_counts, quota_plan):
            continue
        _tag_and_save_context(
            ctx, summary, class_id=class_id, mint_addr=mint_addr, q_score=q_score,
            file_idx=accepted_per_token[mint_addr], output_dir=output_dir, metadata=metadata,
        )
        accepted_per_token[mint_addr] += 1
        class_counts = accepted_counts[class_id]
        class_counts["total"] += 1
        # Count under the same normalized key that _should_accept_context checks.
        # Indexing with the raw bucket would raise KeyError for any bucket other than
        # "good"/"bad" and would leave the bad quota untracked.
        class_counts["good" if bucket == "good" else "bad"] += 1
        saved_any = True
    return saved_any


def _rebuild_file_maps_from_disk(output_dir, metadata):
    """Best-effort: repopulate the per-file maps from sample_*.pt files already in output_dir."""
    for path in sorted(output_dir.glob("sample_*.pt")):
        try:
            cached = torch.load(path, map_location="cpu", weights_only=False)
            metadata.file_class_map[path.name] = cached.get("class_id", 0)
            if "labels" in cached and "labels_mask" in cached:
                summary = summarize_context_window(cached.get("labels"), cached.get("labels_mask"))
                metadata.file_context_bucket_map[path.name] = summary["context_bucket"]
                metadata.file_context_summary_map[path.name] = summary
        except Exception as exc:
            # Unreadable files are skipped, but reported rather than silently ignored.
            print(f"WARNING: Could not read {path.name} for metadata: {exc}")


def main():
    """Build the canonical context cache for classified tokens.

    Steps:
      1. Parse CLI arguments.
      2. Load the return-class and quality-score maps from ClickHouse and build an
         OracleDataset restricted to classified tokens.
      3. Run _process_single_token_context for each token, either in-process or in a
         'spawn' process pool.
      4. Save every kept context as ``sample_<mint[:16]>_<n>.pt`` in --output_dir and
         write a ``class_metadata.json`` summary next to the files.

    With --target_contexts_per_class or --target_contexts_total, contexts are admitted
    per class according to a good/bad quota plan. That mode requires --num_workers 1.
    """
    load_dotenv()
    mp.set_start_method('spawn', force=True)

    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("INFO: Logging in to Hugging Face...")
        huggingface_hub.login(token=hf_token)

    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="data/cache")
    parser.add_argument("--start_date", type=str, default=None)
    parser.add_argument("--min_trade_usd", type=float, default=0.0)
    parser.add_argument("--min_trades", type=int, default=10)
    parser.add_argument("--context_length", type=int, default=8192)
    parser.add_argument("--samples_per_token", type=int, default=1)
    parser.add_argument("--target_contexts_per_class", type=int, default=None)
    parser.add_argument("--target_contexts_total", type=int, default=None)
    parser.add_argument("--good_ratio_nonzero", type=float, default=0.5)
    parser.add_argument("--good_ratio_class0", type=float, default=0.0)
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
    parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
    parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687"))
    parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j"))
    parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
    args = parser.parse_args()

    if args.target_contexts_per_class is not None and args.target_contexts_total is not None:
        raise RuntimeError(
            "Choose exactly one cache budget: either --target_contexts_per_class or --target_contexts_total."
        )
    quota_requested = args.target_contexts_per_class is not None or args.target_contexts_total is not None

    # 0 means "auto": leave a few cores free for the databases / OS.
    if args.num_workers == 0:
        args.num_workers = max(1, mp.cpu_count() - 4)

    # Validate argument compatibility before opening DB connections or loading data.
    if args.num_workers != 1 and quota_requested:
        raise RuntimeError(
            "Quota-driven context caching currently requires --num_workers 1 so accepted contexts "
            "can be planned and written deterministically in one process."
        )

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d") if args.start_date else None

    print("INFO: Initializing DB Connections...")
    clickhouse_client = ClickHouseClient(host=args.clickhouse_host, port=args.clickhouse_port)
    neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))

    try:
        from data.data_loader import OracleDataset
        from data.data_fetcher import DataFetcher
        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

        print("INFO: Fetching Return Classification Map...")
        return_class_map, _ = get_return_class_map(clickhouse_client)
        print(f"INFO: Loaded {len(return_class_map)} classified tokens.")

        print("INFO: Fetching Quality Scores...")
        quality_scores_map = get_token_quality_scores(clickhouse_client)
        print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")

        dataset = OracleDataset(
            data_fetcher=data_fetcher,
            min_trades=args.min_trades,
            start_date=start_date_dt,
            horizons_seconds=list(_HORIZONS_SECONDS),
            quantiles=list(_QUANTILES),
            min_trade_usd=args.min_trade_usd,
            max_seq_len=args.context_length,
        )

        if len(dataset) == 0:
            print("WARNING: No samples. Exiting.")
            return

        # Only tokens with a return class can be cached.
        original_size = len(dataset.sampled_mints)
        filtered_mints = [m for m in dataset.sampled_mints if m['mint_address'] in return_class_map]
        print(f"INFO: Filtered {original_size} -> {len(filtered_mints)} tokens")

        if not filtered_mints:
            print("WARNING: No tokens after filtering.")
            return

        print(f"INFO: Building canonical context cache | Workers: {args.num_workers}")

        db_config = {
            'clickhouse_host': args.clickhouse_host,
            'clickhouse_port': args.clickhouse_port,
            'neo4j_uri': args.neo4j_uri,
            'neo4j_user': args.neo4j_user,
            'neo4j_password': args.neo4j_password,
        }
        # Workers rebuild their own dataset; 'sampled_mints' makes task indices line up
        # with filtered_mints.
        dataset_config = {
            'start_date': start_date_dt,
            'min_trades': args.min_trades,
            'horizons_seconds': list(_HORIZONS_SECONDS),
            'quantiles': list(_QUANTILES),
            'min_trade_usd': args.min_trade_usd,
            'max_seq_len': args.context_length,
            'sampled_mints': filtered_mints,
        }

        tasks = [
            (i, mint_record['mint_address'], args.samples_per_token, str(output_dir))
            for i, mint_record in enumerate(filtered_mints)
        ]
        print(f"INFO: Starting to cache {len(tasks)} tokens...")

        quota_plan = _build_context_quota_plan(
            class_ids=[return_class_map[m['mint_address']] for m in filtered_mints],
            target_contexts_per_class=args.target_contexts_per_class,
            target_contexts_total=args.target_contexts_total,
            good_ratio_nonzero=args.good_ratio_nonzero,
            good_ratio_class0=args.good_ratio_class0,
        )
        if quota_plan:
            print("INFO: Context quota plan:")
            for class_id, plan in sorted(quota_plan.items()):
                print(
                    f"  Class {class_id}: total={plan['total_target']} "
                    f"(good={plan['good_target']}, bad={plan['bad_target']})"
                )

        metadata = _CacheMetadata()
        accepted_counts = defaultdict(lambda: {"total": 0, "good": 0, "bad": 0})
        accepted_per_token = defaultdict(int)
        tally = {"success": 0, "skipped": 0, "error": 0}

        def handle_result(result):
            # Shared by both execution modes so results are treated identically.
            status = result['status']
            if status == 'success':
                stored = _store_token_result(
                    result, output_dir, metadata, quota_plan, accepted_counts, accepted_per_token
                )
                # In quota mode, a token whose contexts were all rejected is counted as skipped.
                tally["success" if stored else "skipped"] += 1
            elif status == 'skipped':
                tally["skipped"] += 1
            else:
                tally["error"] += 1
                tqdm.write(f"ERROR: {result['mint'][:16]} - {result['error']}")

        if args.num_workers == 1:
            print("INFO: Single-threaded mode...")
            _init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
            for task in tqdm(tasks, desc="Caching"):
                handle_result(_process_single_token_context(task))
        else:
            print(f"INFO: Running with {args.num_workers} workers...")
            with ProcessPoolExecutor(
                max_workers=args.num_workers,
                initializer=_init_worker,
                initargs=(db_config, dataset_config, return_class_map, quality_scores_map),
            ) as executor:
                futures = {executor.submit(_process_single_token_context, task): task for task in tasks}
                for future in tqdm(as_completed(futures), total=len(futures), desc="Caching"):
                    try:
                        handle_result(future.result(timeout=300))
                    except Exception as e:
                        tally["error"] += 1
                        tqdm.write(f"WORKER ERROR: {futures[future][1][:16]} - {e}")

        print("INFO: Building metadata...")
        # Nothing was written in this run: describe whatever cache already exists on disk.
        if not metadata.file_class_map:
            _rebuild_file_maps_from_disk(output_dir, metadata)

        with open(output_dir / "class_metadata.json", 'w') as f:
            json.dump({
                'file_class_map': metadata.file_class_map,
                'file_context_bucket_map': metadata.file_context_bucket_map,
                'file_context_summary_map': metadata.file_context_summary_map,
                'class_distribution': {str(k): v for k, v in metadata.class_distribution.items()},
                'context_distribution': {
                    str(k): dict(bucket_counts)
                    for k, bucket_counts in metadata.context_distribution.items()
                },
                'quota_plan': {str(k): v for k, v in quota_plan.items()},
                'accepted_counts': {str(k): v for k, v in accepted_counts.items()},
                'num_workers': args.num_workers,
            }, f, indent=2)

        if quota_plan:
            print("INFO: Accepted context counts:")
            for class_id, counts in sorted(accepted_counts.items()):
                print(
                    f"  Class {class_id}: total={counts['total']} "
                    f"good={counts['good']} bad={counts['bad']}"
                )

        print(
            f"\n--- Done ---\nSuccess: {tally['success']}, Skipped: {tally['skipped']}, "
            f"Errors: {tally['error']}\nFiles: {len(metadata.file_class_map)}\n"
            f"Location: {output_dir.resolve()}"
        )

    finally:
        clickhouse_client.disconnect()
        neo4j_driver.close()
|
|
|
|
# Script entry point. The guard is required because main() selects the 'spawn'
# start method: child processes re-import this module and must not re-run main().
if __name__ == "__main__":
    main()
|
|