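"""Audit a directory of cached sample_*.pt context files.

Checks each sample for required fields, NaN labels/masks, quality-score
sanity, context bucket/score consistency against summarize_context_window,
and quant OHLC feature payload shapes, then prints summary statistics.
"""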
import argparse
import math
from collections import Counter, defaultdict
from pathlib import Path

import torch
from tqdm import tqdm

from data.data_loader import summarize_context_window
from data.quant_ohlc_feature_schema import FEATURE_VERSION, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT

REQUIRED_CONTEXT_FIELDS = [
"event_sequence",
"wallets",
"tokens",
"labels",
"labels_mask",
"quality_score",
"class_id",
"source_token",
"context_bucket",
"context_score",
"quant_ohlc_features",
"quant_feature_version",
]


def _to_list(value):
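    """Coerce a tensor or other iterable to a plain Python list; None becomes []."""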
if value is None:
return []
if isinstance(value, torch.Tensor):
return value.tolist()
    return list(value)


def _safe_float(value):
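    """Coerce a number or one-element tensor to float; reject non-scalar tensors."""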
if isinstance(value, torch.Tensor):
if value.numel() != 1:
raise ValueError("Expected scalar tensor.")
return float(value.item())
    return float(value)


def audit_cache(cache_dir, num_samples=None):
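    """Validate every sample_*.pt file under cache_dir and print a report.

    If num_samples is a positive integer, only the first num_samples files
    (sorted by filename) are audited.
    """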
cache_path = Path(cache_dir)
files = sorted(cache_path.glob("sample_*.pt"))
if not files:
print(f"No sample_*.pt files found in {cache_path}")
return
if num_samples is not None and num_samples > 0:
files = files[:num_samples]
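
    # Aggregate counters for the final report.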
issues = Counter()
class_counts = Counter()
bucket_counts = Counter()
class_bucket_counts = defaultdict(Counter)
token_counts_by_class = defaultdict(Counter)
samples_per_token = Counter()
missing_fields = Counter()
stats = {
"files_audited": len(files),
"empty_event_sequence": 0,
"missing_wallets": 0,
"missing_tokens": 0,
"nan_labels": 0,
"nan_masks": 0,
"nan_quality_score": 0,
"negative_quality_score": 0,
"max_label_return": -float("inf"),
"min_label_return": float("inf"),
"max_events": 0,
"min_events": float("inf"),
"contexts_with_no_valid_horizons": 0,
"context_bucket_mismatch": 0,
"context_score_mismatch": 0,
"quant_feature_version_mismatch": 0,
"chart_events_missing_quant": 0,
"quant_segments_total": 0,
}
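
    # Per-file validation pass.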
for filepath in tqdm(files, desc="Auditing cache", unit="file"):
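        # Samples are stored as pickled dicts, so weights_only=False is
        # required; only audit caches from trusted sources.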
try:
data = torch.load(filepath, map_location="cpu", weights_only=False)
except Exception:
issues["load_error"] += 1
continue
if not isinstance(data, dict):
issues["not_dict"] += 1
continue
missing_for_file = []
for field in REQUIRED_CONTEXT_FIELDS:
if field not in data:
missing_for_file.append(field)
missing_fields[field] += 1
if missing_for_file:
issues["missing_required_fields"] += 1
continue
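        # Tally class / bucket / token distributions for the report.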
class_id = int(data["class_id"])
source_token = str(data["source_token"])
context_bucket = str(data["context_bucket"])
class_counts[class_id] += 1
bucket_counts[context_bucket] += 1
class_bucket_counts[class_id][context_bucket] += 1
token_counts_by_class[class_id][source_token] += 1
samples_per_token[source_token] += 1
events = data.get("event_sequence") or []
wallets = data.get("wallets") or {}
tokens = data.get("tokens") or {}
labels = _to_list(data.get("labels"))
masks = _to_list(data.get("labels_mask"))
if not events:
stats["empty_event_sequence"] += 1
stats["max_events"] = max(stats["max_events"], len(events))
stats["min_events"] = min(stats["min_events"], len(events))
if not wallets:
stats["missing_wallets"] += 1
if not tokens:
stats["missing_tokens"] += 1
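        # Scan labels for NaNs, tracking finite min/max returns along the way.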
        has_nan_label = False
        for value in labels:
            value = float(value)
            if math.isnan(value):
                has_nan_label = True
                break
            stats["max_label_return"] = max(stats["max_label_return"], value)
            stats["min_label_return"] = min(stats["min_label_return"], value)
        if has_nan_label:
            stats["nan_labels"] += 1
        has_nan_mask = any(math.isnan(float(value)) for value in masks)
        if has_nan_mask:
            stats["nan_masks"] += 1
try:
quality_score = _safe_float(data.get("quality_score"))
if math.isnan(quality_score):
stats["nan_quality_score"] += 1
elif quality_score < 0:
stats["negative_quality_score"] += 1
except Exception:
issues["invalid_quality_score"] += 1
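        # Re-derive the context summary from raw labels/masks and cross-check
        # the stored bucket and score.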
try:
summary = summarize_context_window(data.get("labels"), data.get("labels_mask"))
if summary["valid_horizons"] == 0:
stats["contexts_with_no_valid_horizons"] += 1
if summary["context_bucket"] != context_bucket:
stats["context_bucket_mismatch"] += 1
stored_score = _safe_float(data.get("context_score"))
if not math.isclose(summary["context_score"], stored_score, rel_tol=1e-6, abs_tol=1e-6):
stats["context_score_mismatch"] += 1
except Exception:
issues["context_summary_error"] += 1
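        # Validate quant OHLC payloads attached to Chart_Segment events.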
if data.get("quant_feature_version") != FEATURE_VERSION:
stats["quant_feature_version_mismatch"] += 1
chart_events = [event for event in events if event.get("event_type") == "Chart_Segment"]
stats["quant_segments_total"] += len(chart_events)
for event in chart_events:
quant_payload = event.get("quant_ohlc_features")
if not isinstance(quant_payload, list):
stats["chart_events_missing_quant"] += 1
continue
if len(quant_payload) > TOKENS_PER_SEGMENT:
issues["quant_too_many_tokens"] += 1
for token_payload in quant_payload:
vec = token_payload.get("feature_vector")
if not isinstance(vec, list) or len(vec) != NUM_QUANT_OHLC_FEATURES:
issues["quant_bad_vector_shape"] += 1
break
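
    # Replace untouched sentinels so the report prints zeros instead of +/-inf.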
if stats["min_events"] == float("inf"):
stats["min_events"] = 0
if stats["min_label_return"] == float("inf"):
stats["min_label_return"] = 0.0
if stats["max_label_return"] == -float("inf"):
stats["max_label_return"] = 0.0
unique_tokens_total = len(samples_per_token)
duplicate_tokens_total = sum(1 for count in samples_per_token.values() if count > 1)
print("\n=== Cache Audit ===")
print(f"Cache dir: {cache_path}")
print(f"Files audited: {stats['files_audited']}")
print(f"Unique source tokens: {unique_tokens_total}")
print(f"Tokens with >1 cached context: {duplicate_tokens_total}")
print(f"Samples per token max: {max(samples_per_token.values()) if samples_per_token else 0}")
print("\n--- Class Counts ---")
for class_id in sorted(class_counts):
unique_tokens = len(token_counts_by_class[class_id])
print(f"Class {class_id}: samples={class_counts[class_id]} unique_tokens={unique_tokens}")
print("\n--- Context Buckets ---")
for bucket, count in sorted(bucket_counts.items()):
print(f"{bucket}: {count}")
print("\n--- Class x Context Bucket ---")
for class_id in sorted(class_bucket_counts):
bucket_summary = dict(sorted(class_bucket_counts[class_id].items()))
print(f"Class {class_id}: {bucket_summary}")
print("\n--- General Stats ---")
for key, value in stats.items():
print(f"{key}: {value}")
print("\n--- Missing Fields ---")
if missing_fields:
for field, count in sorted(missing_fields.items()):
print(f"{field}: {count}")
else:
print("none")
print("\n--- Issues ---")
if issues:
for key, value in sorted(issues.items()):
print(f"{key}: {value}")
else:
print("none")
print("\n--- Duplicate-Heavy Tokens ---")
heavy_tokens = sorted(samples_per_token.items(), key=lambda item: (-item[1], item[0]))[:20]
for token, count in heavy_tokens:
        print(f"{token}: {count}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--cache_dir", type=str, default="/workspace/apollo/data/cache")
parser.add_argument("--num", type=int, default=None, help="Audit only the first N files.")
args = parser.parse_args()
audit_cache(args.cache_dir, args.num)
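
# Example invocation (paths illustrative):
#   python audit_cache.py --cache_dir /workspace/apollo/data/cache --num 100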