# oracle / cache_dataset.py
# Uploaded by zirobtc via huggingface_hub ("Upload folder using huggingface_hub"), commit c471f42 (verified).
import argparse
import datetime
import json
import logging
import math
import multiprocessing as mp
import os
import sys
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import huggingface_hub
import torch
from dotenv import load_dotenv
from tqdm import tqdm
# Reduce log noise from third-party libraries (HTTP client, transformers, HF hub).
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
# Add the parent of this file's directory to sys.path so the `scripts.*` and `data.*`
# imports that follow resolve when this file is run directly. Must stay above them.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from scripts.analyze_distribution import get_return_class_map
from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
from data.data_loader import summarize_context_window
from clickhouse_driver import Client as ClickHouseClient
from neo4j import GraphDatabase
# Per-process state: populated by _init_worker, read by _process_single_token_context.
#   _worker_dataset            -> OracleDataset bound to this process's own DB connections
#   _worker_return_class_map   -> mint address -> return class id
#   _worker_quality_scores_map -> mint address -> quality score
_worker_dataset = _worker_return_class_map = _worker_quality_scores_map = None
def _build_context_quota_plan(
    class_ids,
    target_contexts_per_class,
    target_contexts_total,
    good_ratio_nonzero,
    good_ratio_class0,
):
    """Build a per-class quota of how many "good" / "bad" contexts to accept.

    Args:
        class_ids: Iterable of class ids (anything ``int()`` accepts); duplicates are ignored.
        target_contexts_per_class: Number of contexts to accept for every class, or None.
        target_contexts_total: Overall budget split evenly across classes (floor
            division, at least 1 per class), or None. Only consulted when
            ``target_contexts_per_class`` is None.
        good_ratio_nonzero: Fraction of each non-zero class's target reserved for
            "good" contexts; clamped to [0, 1].
        good_ratio_class0: Same fraction for class 0; clamped to [0, 1].

    Returns:
        ``{class_id: {"total_target", "good_target", "bad_target"}}``, or ``{}`` when
        there are no classes or no budget was requested. ``_should_accept_context``
        treats an empty plan as "accept everything".

    Raises:
        RuntimeError: if the requested budget is not positive.
    """
    unique_class_ids = sorted({int(cid) for cid in class_ids})
    if not unique_class_ids:
        return {}
    if target_contexts_per_class is not None:
        per_class_target = int(target_contexts_per_class)
    elif target_contexts_total is not None:
        total = int(target_contexts_total)
        # Validate before the max(1, ...) floor below. Otherwise a zero or negative
        # total would silently become a budget of one context per class.
        if total <= 0:
            raise RuntimeError("Context quota target must be positive.")
        per_class_target = max(1, total // len(unique_class_ids))
    else:
        return {}
    if per_class_target <= 0:
        raise RuntimeError("Context quota target must be positive.")
    plan = {}
    for class_id in unique_class_ids:
        ratio = float(good_ratio_class0 if class_id == 0 else good_ratio_nonzero)
        ratio = max(0.0, min(1.0, ratio))
        good_target = int(round(per_class_target * ratio))
        plan[class_id] = {
            "total_target": per_class_target,
            "good_target": good_target,
            "bad_target": per_class_target - good_target,
        }
    return plan
def _should_accept_context(class_id, context_bucket, accepted_counts, quota_plan):
    """Return whether one more context of ``class_id`` / ``context_bucket`` fits the quota.

    An empty ``quota_plan`` disables quotas, so every context is accepted. Otherwise
    the context is rejected in three cases: its class is missing from the plan, the
    class has reached ``total_target``, or its bucket has reached the matching
    ``*_target``. The bucket is "good", or "bad" for any other bucket value.
    ``accepted_counts[class_id]`` must expose "total", "good" and "bad" counters.
    """
    if not quota_plan:
        return True
    if class_id not in quota_plan:
        return False
    limits = quota_plan[class_id]
    seen = accepted_counts[class_id]
    bucket = "good" if context_bucket == "good" else "bad"
    # Short-circuit keeps the original order: total cap first, then the bucket cap.
    return seen["total"] < limits["total_target"] and seen[bucket] < limits[bucket + "_target"]
def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
    """Populate this process's ``_worker_*`` globals.

    Used as the ``ProcessPoolExecutor`` initializer and, in single-worker mode, called
    directly by ``main``. It opens its own ClickHouse client and Neo4j driver from
    ``db_config`` and builds an ``OracleDataset`` from ``dataset_config``. The dataset's
    ``sampled_mints`` is then replaced with ``dataset_config['sampled_mints']``.

    NOTE(review): the client/driver opened here are never closed by this function;
    they persist for the life of the process.
    """
    global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
    # Imported lazily: these project modules are only needed once a worker starts.
    from data.data_loader import OracleDataset
    from data.data_fetcher import DataFetcher

    fetcher = DataFetcher(
        clickhouse_client=ClickHouseClient(
            host=db_config['clickhouse_host'],
            port=db_config['clickhouse_port'],
        ),
        neo4j_driver=GraphDatabase.driver(
            db_config['neo4j_uri'],
            auth=(db_config['neo4j_user'], db_config['neo4j_password']),
        ),
    )
    cfg = dataset_config
    _worker_dataset = OracleDataset(
        data_fetcher=fetcher,
        min_trades=cfg['min_trades'],
        start_date=cfg['start_date'],
        horizons_seconds=cfg['horizons_seconds'],
        quantiles=cfg['quantiles'],
        min_trade_usd=cfg['min_trade_usd'],
        max_seq_len=cfg['max_seq_len'],
    )
    # Task indices refer to this pre-filtered list, not the dataset's own selection.
    _worker_dataset.sampled_mints = cfg['sampled_mints']
    _worker_return_class_map = return_class_map
    _worker_quality_scores_map = quality_scores_map
def _process_single_token_context(args):
    """Build the cacheable contexts for one token, using the ``_worker_*`` globals.

    Args:
        args: Tuple ``(idx, mint_addr, samples_per_token, output_dir)``. ``idx`` is
            passed to the worker dataset's ``__cacheitem_context__``. ``output_dir`` is
            accepted for interface compatibility but unused: this function writes no
            files, and the caller saves the returned contexts.

    Returns:
        A dict with ``status`` ("success", "skipped" or "error") and ``mint``.
        Success results also carry ``class_id``, ``q_score``, ``n_contexts``,
        ``n_events`` (event count of the first context) and ``contexts``. Skipped
        results carry ``reason``. Error results carry ``error`` and ``traceback``.
        Exceptions are never raised to the caller.
    """
    idx, mint_addr, samples_per_token, _output_dir = args
    try:
        class_id = _worker_return_class_map.get(mint_addr)
        if class_id is None:
            return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}
        # Do the cheap lookup before the expensive context build. A token without a
        # quality score is skipped either way, so building its contexts first would
        # be wasted (DB-backed) work.
        q_score = _worker_quality_scores_map.get(mint_addr)
        if q_score is None:
            return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
        contexts = _worker_dataset.__cacheitem_context__(idx, num_samples_per_token=samples_per_token)
        if not contexts:
            return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}
        return {
            'status': 'success',
            'mint': mint_addr,
            'class_id': class_id,
            'q_score': q_score,
            'n_contexts': len(contexts),
            'n_events': len(contexts[0].get('event_sequence', [])),
            'contexts': contexts,
        }
    except Exception as e:
        import traceback
        return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
class _CacheMetadata:
    """Accumulates per-file metadata and per-class counts for context files written this run."""

    def __init__(self):
        self.file_class_map = {}
        self.file_context_bucket_map = {}
        self.file_context_summary_map = {}
        self.class_distribution = {}
        # class_id -> context bucket -> number of files written.
        self.context_distribution = defaultdict(lambda: defaultdict(int))

    def record(self, filename, class_id, context_summary):
        """Register one written file under its class and context bucket."""
        context_bucket = context_summary["context_bucket"]
        self.file_class_map[filename] = class_id
        self.file_context_bucket_map[filename] = context_bucket
        self.file_context_summary_map[filename] = context_summary
        self.class_distribution[class_id] = self.class_distribution.get(class_id, 0) + 1
        self.context_distribution[class_id][context_bucket] += 1

    def recover_from_disk(self, output_dir):
        """Fill the per-file maps from existing ``sample_*.pt`` files in ``output_dir``.

        Only the per-file maps are filled; the class/context distributions are not.
        """
        for path in sorted(output_dir.glob("sample_*.pt")):
            try:
                # weights_only=False unpickles arbitrary objects: only use on trusted cache dirs.
                cached = torch.load(path, map_location="cpu", weights_only=False)
                self.file_class_map[path.name] = cached.get("class_id", 0)
                if "labels" in cached and "labels_mask" in cached:
                    context_summary = summarize_context_window(cached.get("labels"), cached.get("labels_mask"))
                    self.file_context_bucket_map[path.name] = context_summary["context_bucket"]
                    self.file_context_summary_map[path.name] = context_summary
            except Exception as e:
                # Best-effort: one unreadable file must not abort metadata generation,
                # but it should not vanish silently either.
                print(f"WARNING: Skipping {path.name} while rebuilding metadata: {e}")


def _write_context(ctx, output_dir, filename, class_id, mint_addr, q_score, context_summary):
    """Annotate ``ctx`` in place with token/class/quality fields and save it to ``output_dir/filename``."""
    ctx["quality_score"] = q_score
    ctx["class_id"] = class_id
    ctx["source_token"] = mint_addr
    ctx["context_bucket"] = context_summary["context_bucket"]
    ctx["context_score"] = context_summary["context_score"]
    torch.save(ctx, Path(output_dir) / filename)


def _persist_success(result, output_dir, metadata, quota_plan, accepted_counts, accepted_per_token):
    """Write the contexts of one successful worker result.

    With a quota plan, each context must pass ``_should_accept_context``, and files are
    numbered per token in acceptance order. Without a plan, every context is written
    under its original index.

    Returns:
        True if the result should count as a success. In quota mode this means at
        least one context was written; without a quota it is always True.
    """
    class_id = result['class_id']
    mint_addr = result['mint']
    q_score = result['q_score']
    saved_any = False
    for ctx_idx, ctx in enumerate(result.get("contexts", [])):
        context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
        context_bucket = context_summary["context_bucket"]
        if quota_plan:
            if not _should_accept_context(class_id, context_bucket, accepted_counts, quota_plan):
                continue
            file_idx = accepted_per_token[mint_addr]
        else:
            file_idx = ctx_idx
        filename = f"sample_{mint_addr[:16]}_{file_idx}.pt"
        _write_context(ctx, output_dir, filename, class_id, mint_addr, q_score, context_summary)
        metadata.record(filename, class_id, context_summary)
        if quota_plan:
            accepted_per_token[mint_addr] += 1
            accepted_counts[class_id]["total"] += 1
            # Count under the same good/bad key _should_accept_context checks. Any
            # non-"good" bucket counts toward "bad", so the bad quota always applies.
            accepted_counts[class_id]["good" if context_bucket == "good" else "bad"] += 1
        saved_any = True
    return saved_any if quota_plan else True


def main():
    """Command-line entry point: build the context cache under ``--output_dir``.

    The steps are:
      1. Parse arguments.
      2. Fetch the return-class and quality-score maps using the ClickHouse client.
      3. Build the OracleDataset and keep only classified tokens.
      4. Compute contexts per token, in-process or in a ``ProcessPoolExecutor``.
      5. Write each accepted context as ``sample_<mint[:16]>_<n>.pt``.
      6. Write ``class_metadata.json`` describing the written files.

    Raises:
        RuntimeError: if both quota flags are given, or if a quota is requested with
            more than one worker.
    """
    load_dotenv()
    mp.set_start_method('spawn', force=True)
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("INFO: Logging in to Hugging Face...")
        huggingface_hub.login(token=hf_token)

    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="data/cache")
    parser.add_argument("--start_date", type=str, default=None)
    parser.add_argument("--min_trade_usd", type=float, default=0.0)
    parser.add_argument("--min_trades", type=int, default=10)
    parser.add_argument("--context_length", type=int, default=8192)
    parser.add_argument("--samples_per_token", type=int, default=1)
    parser.add_argument("--target_contexts_per_class", type=int, default=None)
    parser.add_argument("--target_contexts_total", type=int, default=None)
    parser.add_argument("--good_ratio_nonzero", type=float, default=0.5)
    parser.add_argument("--good_ratio_class0", type=float, default=0.0)
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
    parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
    parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687"))
    parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j"))
    parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
    args = parser.parse_args()

    if args.target_contexts_per_class is not None and args.target_contexts_total is not None:
        raise RuntimeError(
            "Choose exactly one cache budget: either --target_contexts_per_class or --target_contexts_total."
        )
    if args.num_workers < 0:
        parser.error("--num_workers must be >= 0 (0 selects an automatic worker count).")
    if args.num_workers == 0:
        args.num_workers = max(1, mp.cpu_count() - 4)
    quota_requested = (
        args.target_contexts_per_class is not None or args.target_contexts_total is not None
    )
    # Check before opening DB connections or fetching anything, so an invalid flag
    # combination fails immediately rather than after the expensive setup.
    if args.num_workers != 1 and quota_requested:
        raise RuntimeError(
            "Quota-driven context caching currently requires --num_workers 1 so accepted contexts "
            "can be planned and written deterministically in one process."
        )

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d") if args.start_date else None
    # Shared by the main-process dataset and the worker datasets; copied per use so
    # no two consumers share a mutable list.
    horizons_seconds = (60, 180, 300, 600, 1800, 3600, 7200)
    quantiles = (0.5,)

    print("INFO: Initializing DB Connections...")
    clickhouse_client = ClickHouseClient(host=args.clickhouse_host, port=args.clickhouse_port)
    neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
    try:
        from data.data_loader import OracleDataset
        from data.data_fetcher import DataFetcher
        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

        print("INFO: Fetching Return Classification Map...")
        return_class_map, _ = get_return_class_map(clickhouse_client)
        print(f"INFO: Loaded {len(return_class_map)} classified tokens.")
        print("INFO: Fetching Quality Scores...")
        quality_scores_map = get_token_quality_scores(clickhouse_client)
        print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")

        dataset = OracleDataset(
            data_fetcher=data_fetcher,
            min_trades=args.min_trades,
            start_date=start_date_dt,
            horizons_seconds=list(horizons_seconds),
            quantiles=list(quantiles),
            min_trade_usd=args.min_trade_usd,
            max_seq_len=args.context_length,
        )
        if len(dataset) == 0:
            print("WARNING: No samples. Exiting.")
            return

        # Only tokens with a return class can be labelled, so drop the rest up front.
        original_size = len(dataset.sampled_mints)
        filtered_mints = [m for m in dataset.sampled_mints if m['mint_address'] in return_class_map]
        print(f"INFO: Filtered {original_size} -> {len(filtered_mints)} tokens")
        if not filtered_mints:
            print("WARNING: No tokens after filtering.")
            return

        print(f"INFO: Building canonical context cache | Workers: {args.num_workers}")
        db_config = {
            'clickhouse_host': args.clickhouse_host,
            'clickhouse_port': args.clickhouse_port,
            'neo4j_uri': args.neo4j_uri,
            'neo4j_user': args.neo4j_user,
            'neo4j_password': args.neo4j_password,
        }
        dataset_config = {
            'start_date': start_date_dt,
            'min_trades': args.min_trades,
            'horizons_seconds': list(horizons_seconds),
            'quantiles': list(quantiles),
            'min_trade_usd': args.min_trade_usd,
            'max_seq_len': args.context_length,
            'sampled_mints': filtered_mints,
        }
        # Task index i refers to filtered_mints, which workers install as sampled_mints.
        tasks = [
            (i, mint_record['mint_address'], args.samples_per_token, str(output_dir))
            for i, mint_record in enumerate(filtered_mints)
        ]
        print(f"INFO: Starting to cache {len(tasks)} tokens...")

        success_count = skipped_count = error_count = 0
        metadata = _CacheMetadata()
        accepted_counts = defaultdict(lambda: {"total": 0, "good": 0, "bad": 0})
        accepted_per_token = defaultdict(int)
        # filtered_mints only contains classified tokens, so every lookup succeeds.
        quota_plan = _build_context_quota_plan(
            class_ids=[return_class_map[m['mint_address']] for m in filtered_mints],
            target_contexts_per_class=args.target_contexts_per_class,
            target_contexts_total=args.target_contexts_total,
            good_ratio_nonzero=args.good_ratio_nonzero,
            good_ratio_class0=args.good_ratio_class0,
        )
        if quota_plan:
            print("INFO: Context quota plan:")
            for class_id, plan in sorted(quota_plan.items()):
                print(
                    f"  Class {class_id}: total={plan['total_target']} "
                    f"(good={plan['good_target']}, bad={plan['bad_target']})"
                )

        def tally(result):
            """Persist one worker result and update the success/skip/error counters."""
            nonlocal success_count, skipped_count, error_count
            status = result['status']
            if status == 'success':
                if _persist_success(result, output_dir, metadata, quota_plan, accepted_counts, accepted_per_token):
                    success_count += 1
            elif status == 'skipped':
                skipped_count += 1
            else:
                error_count += 1
                tqdm.write(f"ERROR: {result['mint'][:16]} - {result['error']}")

        if args.num_workers == 1:
            print("INFO: Single-threaded mode...")
            # NOTE: _init_worker opens its own ClickHouse client / Neo4j driver in this
            # process, in addition to the ones created above.
            _init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
            for task in tqdm(tasks, desc="Caching"):
                tally(_process_single_token_context(task))
        else:
            print(f"INFO: Running with {args.num_workers} workers...")
            with ProcessPoolExecutor(
                max_workers=args.num_workers,
                initializer=_init_worker,
                initargs=(db_config, dataset_config, return_class_map, quality_scores_map),
            ) as executor:
                futures = {executor.submit(_process_single_token_context, task): task for task in tasks}
                for future in tqdm(as_completed(futures), total=len(futures), desc="Caching"):
                    try:
                        tally(future.result(timeout=300))
                    except Exception as e:
                        error_count += 1
                        tqdm.write(f"WORKER ERROR: {e}")

        print("INFO: Building metadata...")
        if not metadata.file_class_map:
            # Nothing was written this run: describe whatever the cache dir already holds.
            metadata.recover_from_disk(output_dir)
        with open(output_dir / "class_metadata.json", 'w', encoding="utf-8") as f:
            json.dump({
                'file_class_map': metadata.file_class_map,
                'file_context_bucket_map': metadata.file_context_bucket_map,
                'file_context_summary_map': metadata.file_context_summary_map,
                'class_distribution': {str(k): v for k, v in metadata.class_distribution.items()},
                'context_distribution': {
                    str(k): dict(bucket_counts)
                    for k, bucket_counts in metadata.context_distribution.items()
                },
                'quota_plan': {str(k): v for k, v in quota_plan.items()},
                'accepted_counts': {str(k): v for k, v in accepted_counts.items()},
                'num_workers': args.num_workers,
            }, f, indent=2)

        if quota_plan:
            print("INFO: Accepted context counts:")
            for class_id, counts in sorted(accepted_counts.items()):
                print(
                    f"  Class {class_id}: total={counts['total']} "
                    f"good={counts['good']} bad={counts['bad']}"
                )
        print(
            f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}"
            f"\nFiles: {len(metadata.file_class_map)}\nLocation: {output_dir.resolve()}"
        )
    finally:
        clickhouse_client.disconnect()
        neo4j_driver.close()


if __name__ == "__main__":
    main()