ATLAS / src /oss /oss_submission_handler.py
“pangjh3”
modified: src/about.py
f652754
#!/usr/bin/env python3
"""
OSS提交处理器 - 替换原有的git/http提交方式
在HuggingFace Spaces中直接将提交文件上传到OSS
"""
import html
import json
import os
import sys
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Tuple

# Import the OSS file manager that lives in the same package.
from .oss_file_manager import OSSFileManager
class OSSSubmissionHandler:
    """Validate ATLAS benchmark submissions and upload them to OSS.

    Replaces the earlier git/http submission flow: a submission is validated,
    serialized to a temporary local JSON file and handed to ``OSSFileManager``
    for upload. Outcomes are reported to the caller as small HTML snippets.
    """

    # Keys that must be present on the submission and on each prediction.
    _REQUIRED_FIELDS = ("submission_org", "submission_email", "predictions")
    _PREDICTION_FIELDS = ("original_question_id", "content", "reasoning_content")

    # Characters replaced by "_" when building the upload filename. The model
    # name additionally has "-" replaced; the organization name does not.
    _MODEL_NAME_TABLE = str.maketrans({" ": "_", "/": "_", "\\": "_", "-": "_"})
    _ORG_NAME_TABLE = str.maketrans({" ": "_", "/": "_", "\\": "_"})

    def __init__(self, oss_submission_path: str = "atlas_eval/submissions/"):
        """Create the handler and its OSS file manager.

        Args:
            oss_submission_path: Key prefix under which submissions are
                stored. It is concatenated directly with the filename, so it
                should end with "/".
        """
        self.oss_path = oss_submission_path
        self.oss_manager = OSSFileManager()
        print(f"📁 OSS submission path: oss://opencompass/{oss_submission_path}")

    def format_error(self, msg: str) -> str:
        """Wrap ``msg`` in a red HTML paragraph."""
        return f"<p style='color: red; font-size: 16px;'>{msg}</p>"

    def format_success(self, msg: str) -> str:
        """Wrap ``msg`` in a green HTML paragraph."""
        return f"<p style='color: green; font-size: 16px;'>{msg}</p>"

    def format_warning(self, msg: str) -> str:
        """Wrap ``msg`` in an orange HTML paragraph."""
        return f"<p style='color: orange; font-size: 16px;'>{msg}</p>"

    def validate_sage_submission(self, submission_data: Dict[str, Any]) -> Tuple[bool, str]:
        """Check that a submission has the expected structure.

        Args:
            submission_data: Parsed submission. Anything that is not a dict
                is rejected rather than raising.

        Returns:
            ``(True, "Submission format is valid")`` on success, otherwise
            ``(False, reason)`` describing the first problem found.
        """
        if not isinstance(submission_data, dict):
            return False, "Submission must be a JSON object"

        for field in self._REQUIRED_FIELDS:
            if field not in submission_data:
                return False, f"Missing required field: {field}"

        # Basic e-mail sanity check only; the isinstance guard keeps the
        # membership tests from raising on non-string JSON values.
        email = submission_data["submission_email"]
        if not isinstance(email, str) or "@" not in email or "." not in email:
            return False, "Invalid email format"

        # The organization is later used to build the filename.
        if not isinstance(submission_data["submission_org"], str):
            return False, "submission_org must be a string"

        predictions = submission_data["predictions"]
        if not isinstance(predictions, list) or not predictions:
            return False, "predictions must be a non-empty list"

        for i, prediction in enumerate(predictions):
            if not isinstance(prediction, dict):
                return False, f"prediction {i} must be an object"

            for field in self._PREDICTION_FIELDS:
                if field not in prediction:
                    return False, f"Missing field in prediction {i}: {field}"

            content = prediction["content"]
            if not isinstance(content, list) or len(content) != 4:
                return False, f"content in prediction {i} must be a list with 4 items"

            # NOTE: the length of reasoning_content is deliberately not
            # checked; only its type is.
            if not isinstance(prediction["reasoning_content"], list):
                return False, f"reasoning_content in prediction {i} must be a list"

            # bool is a subclass of int, so JSON true/false must be excluded
            # explicitly.
            question_id = prediction["original_question_id"]
            if isinstance(question_id, bool) or not isinstance(question_id, int):
                return False, f"question ID in prediction {i} must be an integer"

        return True, "Submission format is valid"

    def generate_submission_filename(self, submission_data: Dict[str, Any]) -> str:
        """Build the upload filename for a submission.

        The format is ``submission_<model>_<org>_<YYYYmmdd_HHMMSS>.json``,
        using the local time at the moment of the call.

        Args:
            submission_data: Submission dict; must contain ``submission_org``.
                ``model_name`` is optional and defaults to ``UnknownModel``
                when missing or ``None``.

        Returns:
            The generated filename.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        model_name = submission_data.get("model_name")
        if model_name is None:
            model_name = "UnknownModel"
        # str() guards against non-string JSON values (numbers, lists, ...).
        model_name = str(model_name).translate(self._MODEL_NAME_TABLE)
        org_name = str(submission_data["submission_org"]).translate(self._ORG_NAME_TABLE)

        return f"submission_{model_name}_{org_name}_{timestamp}.json"

    def upload_to_oss(self, submission_data: Dict[str, Any], filename: str) -> Tuple[bool, str]:
        """Serialize the submission to JSON and upload it to OSS.

        The JSON is written into a private temporary directory that is
        removed whether or not the upload succeeds.

        Args:
            submission_data: JSON-serializable submission.
            filename: Name appended to the configured OSS path.

        Returns:
            ``(True, "oss://opencompass/<key>")`` on success, or
            ``(False, error_text)`` if writing or uploading failed.
        """
        oss_file_path = f"{self.oss_path}{filename}"
        # Keep the original basename for the local copy but never let a
        # separator in ``filename`` point outside the temporary directory.
        local_name = Path(filename).name or "submission.json"
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                temp_file = Path(tmp_dir) / local_name
                with open(temp_file, "w", encoding="utf-8") as f:
                    json.dump(submission_data, f, indent=2, ensure_ascii=False)

                print(f"⬆️ Uploading to OSS: {oss_file_path}")
                self.oss_manager.upload_file_to_object(
                    local_file_path=str(temp_file),
                    oss_file_path=oss_file_path,
                    replace=True,
                )
        except Exception as e:
            # Boundary with the storage client: report instead of raising.
            print(f"❌ OSS upload failed: {e}")
            return False, str(e)

        print(f"✅ OSS upload successful: {oss_file_path}")
        return True, f"oss://opencompass/{oss_file_path}"

    def process_sage_submission(self, submission_file_or_data, org_name=None, email=None) -> str:
        """Validate an ATLAS submission and upload it to OSS.

        Args:
            submission_file_or_data: Path to a JSON file (``str`` or
                ``os.PathLike``) or an already-parsed dict. A dict passed in
                is not modified.
            org_name: Organization from the form. Used only when ``email``
                is also given; then both override the values in the data.
            email: Contact e-mail from the form (see ``org_name``).

        Returns:
            An HTML snippet: the success summary, or a formatted error
            message. This method does not raise.
        """
        try:
            if submission_file_or_data is None:
                return self.format_error("❌ No submission data provided.")

            if isinstance(submission_file_or_data, (str, os.PathLike)):
                # Treat strings / path objects as a path to a JSON file.
                try:
                    with open(submission_file_or_data, "r", encoding="utf-8") as f:
                        submission_data = json.load(f)
                except (OSError, ValueError) as e:
                    # ValueError covers JSON and unicode decode errors.
                    return self.format_error(
                        f"❌ Error reading file: {html.escape(str(e))}"
                    )
            elif isinstance(submission_file_or_data, dict):
                # Shallow copy so the form override below does not leak into
                # the caller's dict.
                submission_data = dict(submission_file_or_data)
            else:
                return self.format_error("❌ Invalid submission data format.")

            if not isinstance(submission_data, dict):
                # A JSON file may legally contain a list or scalar at top level.
                return self.format_error(
                    "❌ Submission validation failed: Submission must be a JSON object"
                )

            # Form values take precedence, but only when both are supplied.
            if org_name and email:
                submission_data["submission_org"] = org_name.strip()
                submission_data["submission_email"] = email.strip()

            is_valid, message = self.validate_sage_submission(submission_data)
            if not is_valid:
                return self.format_error(f"❌ Submission validation failed: {message}")

            filename = self.generate_submission_filename(submission_data)

            success, result = self.upload_to_oss(submission_data, filename)
            if not success:
                return self.format_error(
                    f"❌ Failed to upload to OSS: {html.escape(result)}"
                )

            # User-controlled values are escaped before being placed in HTML.
            org = html.escape(submission_data["submission_org"])
            email_addr = html.escape(submission_data["submission_email"])
            shown_filename = html.escape(filename)
            num_predictions = len(submission_data["predictions"])

            summary = "\n".join([
                "",
                "🎉 <strong>Submission successful!</strong><br><br>",
                "📋 <strong>Submission Information:</strong><br>",
                f"• Organization: {org}<br>",
                f"• Email: {email_addr}<br>",
                f"• Number of predictions: {num_predictions} questions<br>",
                f"• Filename: {shown_filename}<br><br>",
                "⚡ <strong>Evaluation Status:</strong><br>",
                "Your submission has been successfully uploaded to cloud storage.<br><br>",
                "🧪 Thank you for participating in the ATLAS scientific reasoning benchmark!",
                "",
            ])
            return self.format_success(summary)
        except Exception as e:
            # Last-resort boundary: the caller expects an HTML string back.
            return self.format_error(
                f"❌ Submission processing failed: {html.escape(str(e))}"
            )
# Compatibility shim - keeps the interface the pre-OSS code relied on.
def process_sage_submission_simple(submission_file, org_name=None, email=None) -> str:
    """
    Process an ATLAS benchmark submission in OSS mode.

    Thin wrapper kept so callers of the original submit.py interface keep
    working: it builds a fresh OSSSubmissionHandler and forwards all three
    arguments to its process_sage_submission method, returning the result.
    """
    return OSSSubmissionHandler().process_sage_submission(
        submission_file,
        org_name,
        email,
    )
def format_error(msg):
    """Return ``msg`` wrapped in a red, 16px HTML paragraph."""
    template = "<p style='color: red; font-size: 16px;'>{}</p>"
    return template.format(msg)
def format_success(msg):
    """Return ``msg`` wrapped in a green, 16px HTML paragraph."""
    template = "<p style='color: green; font-size: 16px;'>{}</p>"
    return template.format(msg)
def format_warning(msg):
    """Return ``msg`` wrapped in an orange, 16px HTML paragraph."""
    template = "<p style='color: orange; font-size: 16px;'>{}</p>"
    return template.format(msg)
if __name__ == "__main__":
    # Manual smoke test: confirm the handler can be constructed.
    # NOTE: this module uses a relative import, so run it with
    # ``python -m <package>.oss_submission_handler`` rather than as a file.
    print("🧪 测试OSS提交处理器")

    # Fail fast when the credential variables are not set
    # (presumably these are what OSSFileManager reads - confirm there).
    required_env_vars = ["OSS_ACCESS_KEY_ID", "OSS_ACCESS_KEY_SECRET"]
    missing_vars = [var for var in required_env_vars if not os.getenv(var)]
    if missing_vars:
        print(f"❌ 缺少必需的环境变量: {missing_vars}")
        # sys.exit is always available; the bare exit() builtin is injected
        # by the site module and is meant for interactive sessions only.
        sys.exit(1)

    handler = OSSSubmissionHandler()
    print("✅ OSS提交处理器初始化成功")