# hjkim00's picture
# Upload TestTime-RLVR-v2 from Full-pipeline-relative_0827 branch
# f50dc54 verified
"""
TestTime Logger
TestTime RLVR์„ ์œ„ํ•œ ํฌ๊ด„์  ๋กœ๊น… ์‹œ์Šคํ…œ
์š”๊ตฌ์‚ฌํ•ญ์— ๋”ฐ๋ฅธ ๋ชจ๋“  ๋‹จ๊ณ„๋ณ„ ๋กœ๊ทธ ๊ธฐ๋ก
"""
import json
import os
import time
from datetime import datetime
from typing import Dict, List, Any, Optional
from pathlib import Path
import logging
class TestTimeLogger:
    """Dedicated logger for TestTime RLVR.

    Emits human-readable lines through the ``logging`` module and, in
    parallel, structured JSON logs grouped by category. Two directory
    layouts are supported:

    * integrated mode (``task_output_dir`` given): JSON logs go into the
      designed per-round structure (``current_evaluation``,
      ``diverse_programs``, ``llm_responses``, ``azr_training_data``);
    * legacy mode: one sub-directory per category under ``log_dir``.
    """

    def __init__(self, log_dir: str = "logs", log_level: str = "INFO",
                 task_output_dir: Optional[str] = None,
                 log_file: Optional[str] = None):
        """Set up log directories and the shared ``logging`` handlers.

        Args:
            log_dir: Base directory for legacy-mode logs.
            log_level: ``logging`` level name (e.g. ``"INFO"``).
            task_output_dir: When given, enables the integrated TTRLVR
                directory structure rooted at this path.
            log_file: Explicit log-file path (used by Ray workers);
                opened in append mode so several workers can share it.
        """
        if task_output_dir:
            # Integrated TTRLVR mode: use the designed directory structure.
            self.log_dir = Path(task_output_dir)
            self.use_integrated_structure = True
        else:
            # Legacy mode: plain base logs directory.
            self.log_dir = Path(log_dir)
            self.use_integrated_structure = False
        self.log_dir.mkdir(parents=True, exist_ok=True)

        if self.use_integrated_structure:
            # Designed layout: detail directories under round_N.
            (self.log_dir / "current_evaluation").mkdir(exist_ok=True)
            (self.log_dir / "diverse_programs").mkdir(exist_ok=True)
            (self.log_dir / "llm_responses").mkdir(exist_ok=True)
            (self.log_dir / "azr_training_data").mkdir(exist_ok=True)
        # Legacy mode creates sub-directories lazily in _save_json_log.

        self.logger = logging.getLogger("TestTimeRLVR")
        self.logger.setLevel(getattr(logging, log_level))

        # Always define log_file_path, even when the (module-global) named
        # logger already has handlers — previously the attribute was only
        # set inside the handler branch, leaving later instances without it.
        self.log_file_path = log_file

        if not self.logger.handlers:
            if log_file:
                # Explicit path (Ray worker): append so workers can share.
                file_handler = logging.FileHandler(log_file, mode='a')
            else:
                # Default timestamped log file inside the log directory.
                self.log_file_path = str(
                    self.log_dir / f"testtime_rlvr_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
                )
                file_handler = logging.FileHandler(self.log_file_path)
            file_handler.setLevel(logging.DEBUG)

            console_handler = logging.StreamHandler()
            console_handler.setLevel(getattr(logging, log_level))

            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            file_handler.setFormatter(formatter)
            console_handler.setFormatter(formatter)

            self.logger.addHandler(file_handler)
            self.logger.addHandler(console_handler)

    def _get_timestamp(self) -> str:
        """Return the current time as an ISO-8601 string."""
        return datetime.now().isoformat()

    def _save_json_log(self, subdirectory: str, filename: str, data: Dict[str, Any]):
        """Append *data* (stamped with a timestamp) to ``<filename>.json``.

        The target directory depends on ``subdirectory`` and on whether the
        integrated structure is in use. The file holds a JSON list; each
        call appends one entry.
        """
        # BUGFIX: paths previously used the literal "(unknown).json" instead
        # of the filename parameter, so every category overwrote one file.
        if self.use_integrated_structure:
            # Designed structure: map each category to its directory.
            if subdirectory == "ipo_extraction":
                # IPO extraction logs live under diverse_programs.
                log_path = self.log_dir / "diverse_programs" / f"{filename}.json"
            elif subdirectory == "task_generation":
                # Task-generation logs sit at the round level (all task kinds).
                log_path = self.log_dir / f"{filename}.json"
            elif subdirectory in ("problems", "performance"):
                log_path = self.log_dir / "current_evaluation" / f"{filename}.json"
            elif subdirectory == "training":
                log_path = self.log_dir / "azr_training_data" / f"{filename}.json"
            else:
                # Fallback: category-named sub-directory.
                log_path = self.log_dir / subdirectory / f"{filename}.json"
        else:
            # Legacy structure: one sub-directory per category.
            log_path = self.log_dir / subdirectory / f"{filename}.json"

        # Create the directory if needed.
        log_path.parent.mkdir(parents=True, exist_ok=True)

        # Load existing entries, tolerating a corrupt/truncated file.
        existing_logs: List[Dict[str, Any]] = []
        if log_path.exists():
            try:
                with open(log_path, 'r', encoding='utf-8') as f:
                    existing_logs = json.load(f)
            except (json.JSONDecodeError, OSError):
                # Start fresh rather than crash the whole pipeline.
                existing_logs = []

        # Append the new entry and persist.
        data['timestamp'] = self._get_timestamp()
        existing_logs.append(data)

        with open(log_path, 'w', encoding='utf-8') as f:
            json.dump(existing_logs, f, indent=2, ensure_ascii=False)

    # ============================================================================
    # 1. Benchmark problem logging (requirement 1)
    # ============================================================================
    def log_problem_attempt(self, problem: Dict[str, Any], solution: str,
                            is_correct: bool, validation_result: Optional[Dict] = None):
        """Log a benchmark problem, the LLM's answer, and correctness."""
        log_data = {
            'problem_id': problem.get('task_id', 'unknown'),
            'benchmark': problem.get('benchmark_name', 'unknown'),
            'problem_prompt': problem.get('prompt', ''),
            'canonical_solution': problem.get('canonical_solution', ''),
            'llm_solution': solution,
            'is_correct': is_correct,
            'validation_result': validation_result or {}
        }
        self._save_json_log("problems", f"problem_{problem.get('task_id', 'unknown').replace('/', '_')}", log_data)
        status = "✅ CORRECT" if is_correct else "❌ INCORRECT"
        self.logger.info(f"Problem {problem.get('task_id', 'unknown')}: {status}")

    def log_problem_loaded(self, problem_id: str, benchmark_name: str, method: str = "Original"):
        """Log problem loading (distinguishing EvalPlus vs. Original method)."""
        self.logger.info(f"Loaded problem {problem_id} from {benchmark_name} ({method} method)")

    # ============================================================================
    # 2. IPO extraction logging (requirement 2)
    # ============================================================================
    def log_ipo_extraction(self, problem_id: str, extracted_triples: List[Dict],
                           validation_results: List[bool]):
        """Log generated (i, p, o) triples and their validation results."""
        log_data = {
            'problem_id': problem_id,
            'num_triples': len(extracted_triples),
            'triples': extracted_triples,
            'validation_results': validation_results,
            'valid_triples': sum(validation_results),
            'invalid_triples': len(validation_results) - sum(validation_results)
        }
        self._save_json_log("ipo_extraction", f"ipo_{problem_id.replace('/', '_')}", log_data)
        self.logger.info(f"IPO Extraction for {problem_id}: {len(extracted_triples)} triples, "
                         f"{sum(validation_results)} valid")

    # ============================================================================
    # 3. Task generation logging (requirement 2)
    # ============================================================================
    def log_task_generation(self, problem_id: str, induction_tasks: List[Dict],
                            deduction_tasks: List[Dict], abduction_tasks: List[Dict]):
        """Log generated induction, deduction, and abduction tasks."""
        log_data = {
            'problem_id': problem_id,
            'induction_tasks': {
                'count': len(induction_tasks),
                'tasks': induction_tasks
            },
            'deduction_tasks': {
                'count': len(deduction_tasks),
                'tasks': deduction_tasks
            },
            'abduction_tasks': {
                'count': len(abduction_tasks),
                'tasks': abduction_tasks
            },
            'total_tasks': len(induction_tasks) + len(deduction_tasks) + len(abduction_tasks)
        }
        self._save_json_log("task_generation", f"tasks_{problem_id.replace('/', '_')}", log_data)
        total_tasks = log_data['total_tasks']
        self.logger.info(f"Task Generation for {problem_id}: {total_tasks} tasks "
                         f"(I:{len(induction_tasks)}, D:{len(deduction_tasks)}, A:{len(abduction_tasks)})")

    # ============================================================================
    # 4. Training metric logging (requirements 3, 4)
    # ============================================================================
    def log_task_accuracy(self, problem_id: str, task_type: str, accuracy: float,
                          rewards: List[float], step: int):
        """Log per-task-type accuracy and reward statistics for a step."""
        log_data = {
            'problem_id': problem_id,
            'task_type': task_type,  # 'induction', 'deduction', 'abduction'
            'step': step,
            'accuracy': accuracy,
            'rewards': rewards,
            # Guard against an empty rewards list.
            'avg_reward': sum(rewards) / len(rewards) if rewards else 0.0,
            'max_reward': max(rewards) if rewards else 0.0,
            'min_reward': min(rewards) if rewards else 0.0
        }
        self._save_json_log("training", f"accuracy_{problem_id.replace('/', '_')}", log_data)
        self.logger.info(f"Step {step} - {task_type.capitalize()} accuracy: {accuracy:.4f}, "
                         f"avg reward: {log_data['avg_reward']:.4f}")

    def log_verl_training(self, problem_id: str, step: int, loss: float,
                          learning_rate: float, metrics: Dict[str, Any]):
        """Log VeRL training progress for one step."""
        log_data = {
            'problem_id': problem_id,
            'step': step,
            'loss': loss,
            'learning_rate': learning_rate,
            'metrics': metrics
        }
        self._save_json_log("training", f"verl_{problem_id.replace('/', '_')}", log_data)
        self.logger.info(f"VeRL Training Step {step}: loss={loss:.6f}, lr={learning_rate:.2e}")

    # ============================================================================
    # 5. Performance change logging
    # ============================================================================
    def log_performance_change(self, problem_id: str, cycle: int,
                               before_accuracy: float, after_accuracy: float,
                               improvement: float):
        """Log the per-cycle accuracy change for a problem."""
        log_data = {
            'problem_id': problem_id,
            'cycle': cycle,
            'before_accuracy': before_accuracy,
            'after_accuracy': after_accuracy,
            'improvement': improvement,
            'improvement_percentage': improvement * 100
        }
        self._save_json_log("performance", f"cycle_{problem_id.replace('/', '_')}", log_data)
        direction = "↗️" if improvement > 0 else "↘️" if improvement < 0 else "→"
        self.logger.info(f"Cycle {cycle} Performance: {before_accuracy:.4f} → {after_accuracy:.4f} "
                         f"({direction} {improvement:+.4f})")

    # ============================================================================
    # General logging
    # ============================================================================
    def log_info(self, message: str):
        """Log an informational message."""
        self.logger.info(message)

    def log_error(self, message: str):
        """Log an error message."""
        self.logger.error(message)

    def log_warning(self, message: str):
        """Log a warning message."""
        self.logger.warning(message)

    def log_debug(self, message: str):
        """Log a debug message."""
        self.logger.debug(message)

    def get_log_summary(self) -> Dict[str, Any]:
        """Return counts of JSON log files per legacy sub-directory.

        Directories that were never created (e.g. in integrated mode)
        count as zero instead of raising.
        """
        def _count(name: str) -> int:
            # Guard: in integrated mode these directories may not exist.
            d = self.log_dir / name
            return len(list(d.glob("*.json"))) if d.is_dir() else 0

        return {
            'log_directory': str(self.log_dir),
            'subdirectories': {
                name: _count(name)
                for name in ("problems", "ipo_extraction", "task_generation",
                             "training", "performance")
            }
        }