"""
TestTime Logger

Comprehensive logging system for TestTime RLVR.
Records a log entry for every pipeline step required by the specification.
"""

import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional


class TestTimeLogger:
    """Dedicated logger for TestTime RLVR."""

    def __init__(self, log_dir: str = "logs", log_level: str = "INFO",
                 task_output_dir: Optional[str] = None, log_file: Optional[str] = None):
        # With a task output directory, use the integrated per-task layout;
        # otherwise fall back to the flat legacy layout.
        if task_output_dir:
            self.log_dir = Path(task_output_dir)
            self.use_integrated_structure = True
        else:
            self.log_dir = Path(log_dir)
            self.use_integrated_structure = False

        self.log_dir.mkdir(parents=True, exist_ok=True)

        if self.use_integrated_structure:
            (self.log_dir / "current_evaluation").mkdir(exist_ok=True)
            (self.log_dir / "diverse_programs").mkdir(exist_ok=True)
            (self.log_dir / "llm_responses").mkdir(exist_ok=True)
            (self.log_dir / "azr_training_data").mkdir(exist_ok=True)

        # Resolve the log file path up front so the attribute exists even when
        # the shared logger already has handlers (e.g. a second instance).
        if log_file:
            self.log_file_path = log_file
        else:
            self.log_file_path = str(
                self.log_dir / f"testtime_rlvr_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
            )

        self.logger = logging.getLogger("TestTimeRLVR")
        # Keep the logger itself at DEBUG so the file handler receives every
        # record; the console handler filters at the requested level.
        self.logger.setLevel(logging.DEBUG)

        if not self.logger.handlers:
            file_handler = logging.FileHandler(self.log_file_path, mode='a')
            file_handler.setLevel(logging.DEBUG)

            console_handler = logging.StreamHandler()
            console_handler.setLevel(getattr(logging, log_level.upper(), logging.INFO))

            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            file_handler.setFormatter(formatter)
            console_handler.setFormatter(formatter)

            self.logger.addHandler(file_handler)
            self.logger.addHandler(console_handler)

    def _get_timestamp(self) -> str:
        """Return the current timestamp as an ISO 8601 string."""
        return datetime.now().isoformat()

    def _save_json_log(self, subdirectory: str, filename: str, data: Dict[str, Any]):
        """Append a record to a JSON log file."""
        if self.use_integrated_structure:
            # Map the logical subdirectory names onto the integrated per-task layout.
            if subdirectory == "ipo_extraction":
                log_path = self.log_dir / "diverse_programs" / f"{filename}.json"
            elif subdirectory == "task_generation":
                log_path = self.log_dir / f"{filename}.json"
            elif subdirectory in ("problems", "performance"):
                log_path = self.log_dir / "current_evaluation" / f"{filename}.json"
            elif subdirectory == "training":
                log_path = self.log_dir / "azr_training_data" / f"{filename}.json"
            else:
                log_path = self.log_dir / subdirectory / f"{filename}.json"
        else:
            log_path = self.log_dir / subdirectory / f"{filename}.json"

        log_path.parent.mkdir(parents=True, exist_ok=True)

        # Each log file holds a JSON array of records: read the existing array
        # (falling back to an empty list on a missing or corrupt file), append
        # the new record, and write the whole array back.
        existing_logs = []
        if log_path.exists():
            try:
                with open(log_path, 'r', encoding='utf-8') as f:
                    existing_logs = json.load(f)
            except (json.JSONDecodeError, OSError):
                existing_logs = []

        data['timestamp'] = self._get_timestamp()
        existing_logs.append(data)

        with open(log_path, 'w', encoding='utf-8') as f:
            json.dump(existing_logs, f, indent=2, ensure_ascii=False)

    def log_problem_attempt(self, problem: Dict[str, Any], solution: str,
                            is_correct: bool, validation_result: Optional[Dict] = None):
        """Log a benchmark problem, the LLM solution, and whether it was correct."""
        log_data = {
            'problem_id': problem.get('task_id', 'unknown'),
            'benchmark': problem.get('benchmark_name', 'unknown'),
            'problem_prompt': problem.get('prompt', ''),
            'canonical_solution': problem.get('canonical_solution', ''),
            'llm_solution': solution,
            'is_correct': is_correct,
            'validation_result': validation_result or {}
        }

        self._save_json_log("problems", f"problem_{problem.get('task_id', 'unknown').replace('/', '_')}", log_data)

        status = "CORRECT" if is_correct else "INCORRECT"
        self.logger.info(f"Problem {problem.get('task_id', 'unknown')}: {status}")

    def log_problem_loaded(self, problem_id: str, benchmark_name: str, method: str = "Original"):
        """Log that a problem was loaded (distinguishing the EvalPlus and Original methods)."""
        self.logger.info(f"Loaded problem {problem_id} from {benchmark_name} ({method} method)")

    def log_ipo_extraction(self, problem_id: str, extracted_triples: List[Dict],
                           validation_results: List[bool]):
        """Log the extracted (i, p, o) triples and their validation results."""
        log_data = {
            'problem_id': problem_id,
            'num_triples': len(extracted_triples),
            'triples': extracted_triples,
            'validation_results': validation_results,
            'valid_triples': sum(validation_results),
            'invalid_triples': len(validation_results) - sum(validation_results)
        }

        self._save_json_log("ipo_extraction", f"ipo_{problem_id.replace('/', '_')}", log_data)

        self.logger.info(f"IPO Extraction for {problem_id}: {len(extracted_triples)} triples, "
                         f"{sum(validation_results)} valid")

    def log_task_generation(self, problem_id: str, induction_tasks: List[Dict],
                            deduction_tasks: List[Dict], abduction_tasks: List[Dict]):
        """Log the generated induction, deduction, and abduction tasks."""
        log_data = {
            'problem_id': problem_id,
            'induction_tasks': {
                'count': len(induction_tasks),
                'tasks': induction_tasks
            },
            'deduction_tasks': {
                'count': len(deduction_tasks),
                'tasks': deduction_tasks
            },
            'abduction_tasks': {
                'count': len(abduction_tasks),
                'tasks': abduction_tasks
            },
            'total_tasks': len(induction_tasks) + len(deduction_tasks) + len(abduction_tasks)
        }

        self._save_json_log("task_generation", f"tasks_{problem_id.replace('/', '_')}", log_data)

        total_tasks = log_data['total_tasks']
        self.logger.info(f"Task Generation for {problem_id}: {total_tasks} tasks "
                         f"(I:{len(induction_tasks)}, D:{len(deduction_tasks)}, A:{len(abduction_tasks)})")

    def log_task_accuracy(self, problem_id: str, task_type: str, accuracy: float,
                          rewards: List[float], step: int):
        """Log accuracy and rewards for an induction/deduction/abduction task."""
        log_data = {
            'problem_id': problem_id,
            'task_type': task_type,
            'step': step,
            'accuracy': accuracy,
            'rewards': rewards,
            'avg_reward': sum(rewards) / len(rewards) if rewards else 0.0,
            'max_reward': max(rewards) if rewards else 0.0,
            'min_reward': min(rewards) if rewards else 0.0
        }

        self._save_json_log("training", f"accuracy_{problem_id.replace('/', '_')}", log_data)

        self.logger.info(f"Step {step} - {task_type.capitalize()} accuracy: {accuracy:.4f}, "
                         f"avg reward: {log_data['avg_reward']:.4f}")

    def log_verl_training(self, problem_id: str, step: int, loss: float,
                          learning_rate: float, metrics: Dict[str, Any]):
        """Log VeRL training progress."""
        log_data = {
            'problem_id': problem_id,
            'step': step,
            'loss': loss,
            'learning_rate': learning_rate,
            'metrics': metrics
        }

        self._save_json_log("training", f"verl_{problem_id.replace('/', '_')}", log_data)

        self.logger.info(f"VeRL Training Step {step}: loss={loss:.6f}, lr={learning_rate:.2e}")

    def log_performance_change(self, problem_id: str, cycle: int,
                               before_accuracy: float, after_accuracy: float,
                               improvement: float):
        """Log the performance change for each cycle."""
        log_data = {
            'problem_id': problem_id,
            'cycle': cycle,
            'before_accuracy': before_accuracy,
            'after_accuracy': after_accuracy,
            'improvement': improvement,
            'improvement_percentage': improvement * 100
        }

        self._save_json_log("performance", f"cycle_{problem_id.replace('/', '_')}", log_data)

        direction = "up" if improvement > 0 else "down" if improvement < 0 else "unchanged"
        self.logger.info(f"Cycle {cycle} Performance: {before_accuracy:.4f} -> {after_accuracy:.4f} "
                         f"({direction}, {improvement:+.4f})")

    def log_info(self, message: str):
        """Log a general informational message."""
        self.logger.info(message)

    def log_error(self, message: str):
        """Log an error message."""
        self.logger.error(message)

    def log_warning(self, message: str):
        """Log a warning message."""
        self.logger.warning(message)

    def log_debug(self, message: str):
        """Log a debug message."""
        self.logger.debug(message)

    def get_log_summary(self) -> Dict[str, Any]:
        """Return summary information about the logs written so far."""
        # Counts use the legacy subdirectory layout; under the integrated
        # layout these directories may not exist and simply report 0.
        subdirectories = ["problems", "ipo_extraction", "task_generation", "training", "performance"]
        summary = {
            'log_directory': str(self.log_dir),
            'subdirectories': {
                name: len(list((self.log_dir / name).glob("*.json")))
                for name in subdirectories
            }
        }

        return summary
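

# Minimal usage sketch (illustrative only): the example problem, its field
# values, and the accuracy numbers below are assumptions for demonstration,
# not outputs of the TestTime RLVR pipeline itself.
if __name__ == "__main__":
    logger = TestTimeLogger(log_dir="logs", log_level="INFO")

    # A hypothetical benchmark problem record with the fields the logger reads.
    example_problem = {
        'task_id': 'HumanEval/0',
        'benchmark_name': 'humaneval',
        'prompt': 'def add(a, b):\n    """Return the sum of a and b."""\n',
        'canonical_solution': '    return a + b\n',
    }

    logger.log_problem_loaded('HumanEval/0', 'humaneval', method='Original')
    logger.log_problem_attempt(example_problem, solution='    return a + b\n', is_correct=True)
    logger.log_performance_change('HumanEval/0', cycle=1,
                                  before_accuracy=0.50, after_accuracy=0.75, improvement=0.25)

    # Print how many JSON records were written per legacy subdirectory.
    print(json.dumps(logger.get_log_summary(), indent=2))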