import json
import os
import argparse
from collections import defaultdict


def calculate_wiscore(consistency, realism, aesthetic_quality):
    """Calculates the WiScore from its weighted components."""
    return (0.7 * consistency + 0.2 * realism + 0.1 * aesthetic_quality) / 2
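
# Worked example (hypothetical inputs; assumes component scores on a 0-2 scale, as the /2 normalization suggests):
#   calculate_wiscore(2, 1, 1) == (0.7*2 + 0.2*1 + 0.1*1) / 2 == 0.85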

# Define expected prompt ID ranges at a global level for easy access.
EXPECTED_PROMPT_RANGES = {
    "culture": range(1, 401),
    "space-time": range(401, 701),   # Covers both TIME (401-567) and SPACE (568-700)
    "science": range(701, 1001),     # Covers BIOLOGY (701-800), PHYSICS (801-900), CHEMISTRY (901-1000)
    "all": range(1, 1001),           # Full range for combined evaluation
}
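# Note: range() upper bounds are exclusive, so range(1, 401) covers prompt_ids 1-400 and
# range(1, 1001) covers the full 1-1000 set used when '--category all' is evaluated.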


def process_jsonl_file_segment(file_path, category_arg=None):
    """
    Processes one JSONL file (a segment of the full prompt set), collecting WiScores and
    the set of prompt_ids present. Each input line should be a JSON object such as
    {"prompt_id": 12, "consistency": 2, "realism": 1, "aesthetic_quality": 1} (values illustrative).
    Performs prompt_id validation when a specific category_arg is provided for a single file.
    Returns the collected data, or None on critical errors or missing prompt_ids.
    """
    segment_scores = defaultdict(list)
    segment_present_prompt_ids = set()

    if not os.path.exists(file_path):
        print(f"Error: File '{file_path}' not found.")
        return None

    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            for line_num, line in enumerate(file, 1):
                try:
                    data = json.loads(line)
                    prompt_id = data.get('prompt_id')
                    if prompt_id is None:
                        print(f"Warning: File '{file_path}', Line {line_num}: Missing 'prompt_id'. Skipping this line.")
                        continue
                    if not isinstance(prompt_id, int):
                        print(f"Warning: File '{file_path}', Line {line_num}: 'prompt_id' is not an integer. Skipping this line.")
                        continue
                    segment_present_prompt_ids.add(prompt_id)

                    consistency = data.get('consistency')
                    realism = data.get('realism')
                    aesthetic_quality = data.get('aesthetic_quality')
                    if not all(isinstance(val, (int, float)) for val in [consistency, realism, aesthetic_quality]):
                        print(f"Warning: File '{file_path}', Line {line_num}: One or more score values are not numeric. Skipping this line for category calculation.")
                        continue

                    wiscore = calculate_wiscore(consistency, realism, aesthetic_quality)

                    # Determine category based on prompt_id
                    if 1 <= prompt_id <= 400:
                        segment_scores['CULTURE'].append(wiscore)
                    elif 401 <= prompt_id <= 567:
                        segment_scores['TIME'].append(wiscore)
                    elif 568 <= prompt_id <= 700:
                        segment_scores['SPACE'].append(wiscore)
                    elif 701 <= prompt_id <= 800:
                        segment_scores['BIOLOGY'].append(wiscore)
                    elif 801 <= prompt_id <= 900:
                        segment_scores['PHYSICS'].append(wiscore)
                    elif 901 <= prompt_id <= 1000:
                        segment_scores['CHEMISTRY'].append(wiscore)
                    else:
                        print(f"Warning: File '{file_path}', Line {line_num}: prompt_id {prompt_id} is outside defined categories. Skipping this line.")
                        continue
                except json.JSONDecodeError:
                    print(f"Warning: File '{file_path}', Line {line_num}: Invalid JSON format. Skipping this line.")
                except KeyError as e:
                    print(f"Warning: File '{file_path}', Line {line_num}: Missing expected key '{e}'. Skipping this line.")
    except Exception as e:
        print(f"Error reading file '{file_path}': {e}")
        return None

    # --- Single-file prompt_id validation logic ---
    if category_arg and category_arg != 'all' and category_arg in EXPECTED_PROMPT_RANGES:
        expected_ids_for_this_category = set(EXPECTED_PROMPT_RANGES[category_arg])
        missing_ids_in_segment = expected_ids_for_this_category - segment_present_prompt_ids
        if missing_ids_in_segment:
            print(f"Error: File '{file_path}': When evaluating as '--category {category_arg}', "
                  f"missing the following prompt_ids: {sorted(list(missing_ids_in_segment))}")
            return None  # Required prompt_ids are missing for this category file

    return {
        'scores': segment_scores,
        'present_prompt_ids': segment_present_prompt_ids,
        'file_path': file_path,
    }
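
# For reference, process_jsonl_file_segment() returns a dict shaped like (values hypothetical):
#   {'scores': {'CULTURE': [0.85, ...], ...}, 'present_prompt_ids': {12, ...}, 'file_path': 'culture_scores.jsonl'}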


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate JSONL files for model performance, categorizing scores by prompt_id."
    )
    parser.add_argument(
        'files',
        metavar='FILE',
        nargs='+',  # Accepts one or more file paths
        help="Path(s) to the JSONL file(s) to be evaluated (e.g., cultural_common_sense_ModelName_scores.jsonl)"
    )
    parser.add_argument(
        '--category',
        type=str,
        choices=['culture', 'space-time', 'science', 'all'],
        default='all',
        help="Specify the category of the JSONL file(s) for specific prompt_id validation. "
             "Choose from 'culture', 'space-time', 'science', or 'all' (default). "
             "If evaluating a single category file, use the corresponding category."
    )
    args = parser.parse_args()

    all_raw_results = []

    # Process each file to collect raw scores and prompt IDs.
    for file_path in args.files:
        print(f"\n--- Processing file: {file_path} ---")
        # Pass the category argument through only when a single file is given;
        # this enables the single-file prompt_id validation inside process_jsonl_file_segment.
        results = process_jsonl_file_segment(file_path, args.category if len(args.files) == 1 else None)
        if results:
            all_raw_results.append(results)
        else:
            print(f"Could not process '{file_path}'. Please check previous warnings/errors.")

    if not all_raw_results:
        print("No valid data processed from any of the provided files. Exiting.")
        return  # Exit if no files were successfully processed

    # Aggregate data across all successfully processed files.
    aggregated_scores = defaultdict(list)
    combined_present_prompt_ids = set()
    final_file_reports = {}  # Calculated averages/counts per file, for individual display

    for file_data in all_raw_results:
        file_path = file_data['file_path']
        combined_present_prompt_ids.update(file_data['present_prompt_ids'])

        # Calculate scores for this individual file (for the individual file report).
        current_file_avg_scores = {}
        current_file_num_samples = {}
        detected_categories_in_file = []
        for category, scores_list in file_data['scores'].items():
            aggregated_scores[category].extend(scores_list)  # Aggregate for the overall score later
            if scores_list:  # Only add to the individual file report if samples exist
                current_file_avg_scores[category] = sum(scores_list) / len(scores_list)
                current_file_num_samples[category] = len(scores_list)
                detected_categories_in_file.append(category)

        final_file_reports[file_path] = {
            'average': current_file_avg_scores,
            'num_processed_samples': current_file_num_samples,
            'detected_categories': detected_categories_in_file,
        }

    # --- Step 1: Validate prompt IDs for the 'all' category scenario ---
    # This check runs only when --category all is chosen (explicitly or by default).
    # Single-file, single-category validation happens inside process_jsonl_file_segment.
    if args.category == 'all':
        expected_prompt_ids_for_all = set(EXPECTED_PROMPT_RANGES['all'])
        missing_prompt_ids_in_combined = expected_prompt_ids_for_all - combined_present_prompt_ids
        if missing_prompt_ids_in_combined:
            print("\nError: When '--category all' is specified, the combined files are missing the following prompt_ids:")
            print(f"Missing IDs: {sorted(list(missing_prompt_ids_in_combined))}")
            print("\nAborting overall evaluation due to incomplete data.")
            return  # Exit if combined prompt IDs are missing when 'all' is expected

    # --- Step 2: Display individual file reports ---
    print("\n" + "=" * 50)
    print("             Individual File Reports")
    print("=" * 50 + "\n")

    ordered_categories = ['CULTURE', 'TIME', 'SPACE', 'BIOLOGY', 'PHYSICS', 'CHEMISTRY']

    for file_path, file_data in final_file_reports.items():
        print(f"--- Evaluation Results for File: {file_path} ---")
        categories_to_print = sorted(
            [cat for cat in ordered_categories if cat in file_data['detected_categories']],
            key=lambda x: ordered_categories.index(x)
        )
        if not categories_to_print:
            print("  No scores found for any defined categories in this file.")
        else:
            for category in categories_to_print:
                avg_score = file_data['average'].get(category, 0)
                sample_count = file_data['num_processed_samples'].get(category, 0)
                print(f"  Category: {category}")
                print(f"    Average WiScore: {avg_score:.2f}")
                print(f"    Number of samples: {sample_count}\n")
        print("-" * (len(file_path) + 30) + "\n")

    # --- Step 3: Calculate and display the overall summary (if applicable) ---
    print("\n" + "=" * 50)
    print("            Overall Evaluation Summary")
    print("=" * 50 + "\n")

    # Calculate overall averages from the aggregated scores.
    overall_avg_scores = {
        category: sum(scores) / len(scores) if len(scores) > 0 else 0
        for category, scores in aggregated_scores.items()
    }
    overall_num_samples = {
        category: len(scores)
        for category, scores in aggregated_scores.items()
    }

    # Print overall category scores (only for categories that have samples).
    overall_categories_to_print = sorted(
        [cat for cat in ordered_categories if overall_num_samples.get(cat, 0) > 0],
        key=lambda x: ordered_categories.index(x)
    )
    if not overall_categories_to_print:
        print("No valid scores found for any categories in the aggregated data.")
    else:
        print("Aggregated Category Scores:")
        for category in overall_categories_to_print:
            print(f"  Category: {category}")
            print(f"    Average WiScore: {overall_avg_scores.get(category, 0):.2f}")
            print(f"    Number of samples: {overall_num_samples.get(category, 0)}\n")

    # Calculate and print the overall WiScore if '--category all' was specified and all categories have samples.
    all_categories_have_overall_samples = all(overall_num_samples.get(cat, 0) > 0 for cat in ordered_categories)
    if args.category == 'all' and all_categories_have_overall_samples:
        cultural_score = overall_avg_scores.get('CULTURE', 0)
        time_score = overall_avg_scores.get('TIME', 0)
        space_score = overall_avg_scores.get('SPACE', 0)
        biology_score = overall_avg_scores.get('BIOLOGY', 0)
        physics_score = overall_avg_scores.get('PHYSICS', 0)
        chemistry_score = overall_avg_scores.get('CHEMISTRY', 0)
        overall_wiscore = (0.4 * cultural_score + 0.167 * time_score + 0.133 * space_score +
                           0.1 * biology_score + 0.1 * physics_score + 0.1 * chemistry_score)
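        # Note: these weights match each category's share of the 1,000 prompts
        # (400 culture, 167 time, 133 space, 100 each for biology/physics/chemistry),
        # so the overall WiScore is effectively a prompt-count-weighted average of category means.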
        print("\n--- Overall WiScore Across All Categories ---")
        print(f"Overall WiScore: {overall_wiscore:.2f}")
        print("Cultural\tTime\tSpace\tBiology\tPhysics\tChemistry\tOverall")
        print(f"{cultural_score:.2f}\t\t{time_score:.2f}\t{space_score:.2f}\t{biology_score:.2f}\t"
              f"{physics_score:.2f}\t{chemistry_score:.2f}\t\t{overall_wiscore:.2f}")
    elif args.category == 'all' and not all_categories_have_overall_samples:
        print("\nOverall WiScore cannot be calculated: not all categories have samples in the aggregated data when '--category all' is specified.")
    else:
        print("\nOverall WiScore calculation skipped. To calculate an overall score, use '--category all' and provide files covering all prompt IDs.")


if __name__ == "__main__":
    main()
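

# Example invocations (the script name and all file names except the help-text example are hypothetical):
#   python evaluate.py cultural_common_sense_ModelName_scores.jsonl --category culture
#   python evaluate.py culture_scores.jsonl spacetime_scores.jsonl science_scores.jsonl --category all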