import random
import pandas as pd
import json


def _count_annotators(df):
    """Map each ``interpretation_id`` to its number of distinct ``user_id`` values.

    Returns an empty dict for an empty frame. A backend that returns a
    column-less empty DataFrame then does not raise ``KeyError`` on the groupby.
    """
    if df.empty:
        return {}
    return df.groupby("interpretation_id")["user_id"].nunique().to_dict()


def _stage_is_complete(counts, dataset, indices, threshold):
    """Return True if every dataset item at ``indices`` has >= ``threshold`` annotators."""
    return all(
        counts.get(dataset[i]["interpretation_id"], 0) >= threshold
        for i in indices
    )


def _stage_from_counts(counts, dataset, stage_splits, threshold):
    """Return the global stage (1, 2 or 3) implied by per-item annotator counts."""
    if not _stage_is_complete(counts, dataset, stage_splits["stage1"], threshold):
        return 1
    if not _stage_is_complete(counts, dataset, stage_splits["stage2"], threshold):
        return 2
    return 3


def get_current_stage(backend, dataset, stage_splits, threshold=3):
    """Return the current global annotation stage.

    A stage counts as complete once every item in it has been annotated by at
    least ``threshold`` distinct ``user_id`` values.

    Args:
        backend: Object whose ``get_all_rows()`` returns a DataFrame of
            annotations with ``interpretation_id`` and ``user_id`` columns.
        dataset: Indexable collection; ``dataset[i]["interpretation_id"]``
            identifies item ``i``.
        stage_splits: Dict with ``"stage1"``/``"stage2"``/``"stage3"`` lists
            of dataset indices.
        threshold: Distinct annotators required per item.

    Returns:
        1 while stage 1 is incomplete, 2 while stage 2 is incomplete,
        otherwise 3.
    """
    df = backend.get_all_rows()
    return _stage_from_counts(
        _count_annotators(df), dataset, stage_splits, threshold
    )


def get_random_session_samples(
    backend, dataset, stage_splits, user_name, num_samples=30, threshold=3
):
    """Pick up to ``num_samples`` random dataset indices for ``user_name`` to annotate.

    Starting at the current global stage, the first stage that still has
    eligible items is used. An item is eligible if it has fewer than
    ``threshold`` distinct annotators and this user has not annotated it yet.
    A user who has exhausted the global stage may therefore move on to a
    later one.

    Args:
        backend: Object whose ``get_all_rows()`` returns the annotation DataFrame.
        dataset: Indexable collection of dicts with an ``interpretation_id`` key.
        stage_splits: Dict with ``"stage1"``/``"stage2"``/``"stage3"`` index lists.
        user_name: Value matched against the ``user_name`` column to find
            items this user has already seen.
        num_samples: Maximum number of indices to return.
        threshold: Distinct annotators after which an item is saturated. It is
            also used to decide the global stage.

    Returns:
        ``(indices, stage)``. ``([], 4)`` means nothing is left for this user.
    """
    # Read the backend once so that the stage and the per-item counts come
    # from the same snapshot. Previously get_current_stage re-read it.
    df = backend.get_all_rows()

    # Defensive fallback: no annotations yet, so serve stage 1.
    if df.empty:
        stage_pool = stage_splits["stage1"]
        return random.sample(stage_pool, min(num_samples, len(stage_pool))), 1

    counts = _count_annotators(df)
    global_stage = _stage_from_counts(counts, dataset, stage_splits, threshold)
    # NOTE(review): counts are keyed by distinct user_id, but "already seen"
    # is matched on user_name. Confirm both columns identify the same annotator.
    seen_ids = set(df.loc[df["user_name"] == user_name, "interpretation_id"])

    # If the user finished global_stage, they may see the next stage(s).
    for stage_num in range(global_stage, 4):  # stages 1 to 3
        stage_pool = stage_splits[f"stage{stage_num}"]
        eligible_indices = [
            i
            for i in stage_pool
            if counts.get(dataset[i]["interpretation_id"], 0) < threshold
            and dataset[i]["interpretation_id"] not in seen_ids
        ]
        if eligible_indices:
            return (
                random.sample(
                    eligible_indices, min(num_samples, len(eligible_indices))
                ),
                stage_num,
            )

    # This user has completed everything (even beyond the current stage).
    return [], 4


def generate_stage_splits(
    dataset, k_stage1=100, seed=42, output_path="stage_indices.json"
):
    """Randomly split dataset indices into three stages and save them as JSON.

    Stage 1 gets ``k_stage1`` random indices. The remainder is shuffled and
    halved into stage 2 and stage 3 (stage 3 gets the extra one when the
    remainder is odd).

    A private ``random.Random(seed)`` is used so the module-level RNG, which
    ``get_random_session_samples`` relies on, is not reseeded. For a given
    seed it yields the same split that ``random.seed(seed)`` did.

    Args:
        dataset: Any sized collection; only ``len(dataset)`` is used.
        k_stage1: Number of indices in stage 1.
        seed: Seed for the split's private RNG.
        output_path: Where the JSON with sorted per-stage index lists is written.

    Returns:
        The ``{"stage1": [...], "stage2": [...], "stage3": [...]}`` dict that
        was saved.

    Raises:
        ValueError: If ``k_stage1`` exceeds ``len(dataset)`` (from ``sample``).
    """
    rng = random.Random(seed)
    total_indices = list(range(len(dataset)))

    # Stage 1
    stage1 = rng.sample(total_indices, k_stage1)
    # The pre-shuffle order follows set iteration order. It is kept unchanged
    # so a given seed reproduces previously generated splits.
    remaining = list(set(total_indices) - set(stage1))

    # Shuffle remaining and split equally
    rng.shuffle(remaining)
    half = len(remaining) // 2
    stage2 = remaining[:half]
    stage3 = remaining[half:]

    # Validate: all indices accounted for, no duplicates (internal invariants)
    combined = set(stage1 + stage2 + stage3)
    assert len(combined) == len(
        total_indices
    ), "❌ Some indices are missing or duplicated!"
    assert combined == set(total_indices), "❌ Index sets do not fully cover dataset!"

    # Save all stages
    stage_splits = {
        "stage1": sorted(stage1),
        "stage2": sorted(stage2),
        "stage3": sorted(stage3),
    }
    with open(output_path, "w") as f:
        json.dump(stage_splits, f, indent=2)

    print(f"✅ Saved stage splits to {output_path}")
    print(f"Stage 1: {len(stage1)} samples")
    print(f"Stage 2: {len(stage2)} samples")
    print(f"Stage 3: {len(stage3)} samples")
    return stage_splits