# TRAIL / database.py
# Hosting-page metadata: uploaded by jitinpatronus ("Upload 23 files"),
# commit 0380c4f, verified.
import os
import json
import datetime
from pathlib import Path
import numpy as np
class Database:
    """File-backed store for leaderboard submissions.

    Every submission is a dict saved as ``<submission_id>.json`` inside
    ``submission_dir``.
    """

    def __init__(self, submission_dir="submissions"):
        """Remember the storage directory and create it if it is missing."""
        self.submission_dir = submission_dir
        os.makedirs(submission_dir, exist_ok=True)

    @staticmethod
    def _safe_name(text):
        """Return *text* with characters unsafe in a file name replaced by ``_``.

        Letters, digits, ``.``, ``_`` and ``-`` are kept; everything else
        (spaces, path separators, ...) becomes ``_``.
        """
        return "".join(
            ch if ch.isalnum() or ch in "._-" else "_" for ch in str(text)
        )

    def _submission_path(self, submission_id):
        """Return the JSON path for *submission_id*, or None if the id is unsafe.

        Ids that are empty or contain a directory component are rejected so
        that a lookup can never leave ``submission_dir``.
        """
        name = str(submission_id)
        if not name or name in (".", "..") or os.path.basename(name) != name:
            return None
        return os.path.join(self.submission_dir, f"{name}.json")

    def add_submission(self, submission):
        """Add a new submission to the database.

        The dict is updated in place with ``'timestamp'`` and ``'id'`` keys
        and then written to ``<id>.json``.

        Args:
            submission: dict with at least a ``'model_name'`` key; it must be
                JSON-serializable.

        Returns:
            The generated submission id.

        Raises:
            KeyError: if ``'model_name'`` is missing.
            TypeError: if the submission cannot be serialized to JSON.
        """
        timestamp = datetime.datetime.now().isoformat()
        # Sanitize the model name: names such as "org/model" would otherwise
        # be interpreted as a path into a sub-directory that does not exist.
        model_part = self._safe_name(submission['model_name'])
        submission_id = f"{model_part}_{timestamp.replace(':', '-')}"

        submission['timestamp'] = timestamp
        submission['id'] = submission_id

        # Serialize before opening the file so a non-serializable value
        # cannot leave a truncated JSON file behind.
        payload = json.dumps(submission, indent=2)
        file_path = os.path.join(self.submission_dir, f"{submission_id}.json")
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(payload)
        return submission_id

    def get_submission(self, submission_id):
        """Return the submission stored under *submission_id*, or None."""
        file_path = self._submission_path(submission_id)
        if file_path is None:
            return None
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def get_all_submissions(self):
        """Return every readable submission, ordered by file name.

        Files that cannot be read or parsed, and JSON documents that are not
        objects, are skipped so one bad file does not break the whole listing.
        """
        submissions = []
        # sorted() gives a deterministic order; os.listdir order is arbitrary.
        for file_name in sorted(os.listdir(self.submission_dir)):
            if not file_name.endswith('.json'):
                continue
            file_path = os.path.join(self.submission_dir, file_name)
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    record = json.load(f)
            except (OSError, ValueError) as exc:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Skipping unreadable submission file {file_name}: {exc}")
                continue
            if isinstance(record, dict):
                submissions.append(record)
        return submissions

    def get_leaderboard(self, sort_by="score", ascending=False):
        """Return all submissions sorted for leaderboard display.

        Args:
            sort_by: key to sort on. Submissions lacking the key (or holding
                None) are ranked as 0.
            ascending: sort smallest-first when True; default is largest-first.

        Returns:
            The list of submission dicts. It is left in file-name order when
            no submission contains ``sort_by``.
        """
        submissions = self.get_all_submissions()
        # Look at every submission, not just the first one, so the decision
        # to sort does not depend on which file happens to be listed first.
        if not any(sort_by in entry for entry in submissions):
            return submissions

        def sort_key(entry):
            value = entry.get(sort_by)
            return 0 if value is None else value

        submissions.sort(key=sort_key, reverse=not ascending)
        return submissions

    def delete_submission(self, submission_id):
        """Delete the submission stored under *submission_id*.

        Returns:
            True if a file was removed, False if there was nothing to delete.
        """
        file_path = self._submission_path(submission_id)
        if file_path is None:
            return False
        try:
            os.remove(file_path)
        except FileNotFoundError:
            return False
        return True
# Load leaderboard configuration
def load_config():
    """Load the leaderboard configuration from ``models.json``.

    The file is looked up in the current working directory. When it is
    missing, empty, or cannot be parsed, the default configuration is
    written to ``models.json`` and returned instead.

    Returns:
        The parsed content of ``models.json``, or the default configuration
        dict.
    """
    config_path = "models.json"
    missing_msg = "models.json file is empty or missing. Creating with default configuration."

    try:
        # Open directly instead of checking exists()/getsize() first: the
        # file could disappear between such a check and the read.
        with open(config_path, "r", encoding="utf-8") as f:
            raw = f.read()
        if raw:
            return json.loads(raw)
        print(missing_msg)
    except FileNotFoundError:
        print(missing_msg)
    except ValueError:
        # Covers json.JSONDecodeError as well as undecodable bytes.
        print("Error parsing models.json. Creating with default configuration.")

    # Single definition of the defaults, shared by every fallback path.
    config = {
        "title": "TRAIL Model Leaderboard",
        "description": "Submit and compare model performances",
        "metrics": ["Cat. F1", "Loc. Acc", "Joint F1"],
        "main_metric": "Cat. F1",
    }
    try:
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)
    except OSError as exc:
        # The defaults are still usable even if they cannot be persisted.
        print(f"Could not write default models.json: {exc}")
    return config