# Provenance: uploaded by hjkim00 — TestTime-RLVR-v2 from the
# Full-pipeline-relative_0827 branch (commit f50dc54, verified).
"""
TestTime RLVR Configuration
AZR ๊ธฐ๋ฐ˜ TestTime RLVR์„ ์œ„ํ•œ ์„ค์ • ํด๋ž˜์Šค
"""
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional

import torch
@dataclass
class TestTimeConfig:
    """Configuration for TestTime RLVR (AZR-based test-time adaptation).

    Groups model, training-loop, iteration, IPO-extraction, program-generation,
    task-generation, reward, logging, and memory settings in one dataclass.
    ``__post_init__`` fills in the default task distribution and resolves
    ``device`` / ``torch_dtype``.
    """

    # ------------------------------------------------------------------
    # Base model settings (AZR-based)
    # ------------------------------------------------------------------
    model_name: str = "Qwen/Qwen2.5-7B"
    device: str = "auto"                        # "auto" resolves to cuda/cpu in __post_init__
    torch_dtype: torch.dtype = torch.bfloat16   # forced to float32 when device is cpu
    use_flash_attention: bool = True
    enable_gradient_checkpointing: bool = True

    # ------------------------------------------------------------------
    # TestTime training settings
    # ------------------------------------------------------------------
    max_adaptation_steps: int = 10        # short adaptation compared to AZR
    adaptation_batch_size: int = 1        # small batch
    gradient_accumulation_steps: int = 4
    learning_rate: float = 1e-6           # same as AZR

    # ------------------------------------------------------------------
    # Iteration control
    # ------------------------------------------------------------------
    max_cycles: int = 3                      # maximum number of cycles
    min_improvement_threshold: float = 0.05  # minimum improvement threshold
    early_stopping_patience: int = 2         # early-stopping patience

    # ------------------------------------------------------------------
    # IPO extraction settings
    # ------------------------------------------------------------------
    max_ipo_triples: int = 10         # maximum number of triples to extract
    python_executor_timeout: int = 5  # shorter timeout than AZR
    validate_triples: bool = True     # whether to validate extracted triples

    # ------------------------------------------------------------------
    # Diverse program generation settings
    # ------------------------------------------------------------------
    num_program_variations: int = 4                # number of diverse programs to generate
    baseline_evaluation_rounds: int = 5            # rounds for baseline performance measurement
    diverse_generation_temperature: float = 0.7    # temperature for diverse program generation
    baseline_generation_temperature: float = 0.05  # temperature for baseline measurement

    # ------------------------------------------------------------------
    # Task generation settings
    # ------------------------------------------------------------------
    # BUGFIX: was annotated `Dict[str, float] = None`, which is not a valid
    # optional type. `None` is replaced by a near-uniform distribution in
    # __post_init__ (also avoids the mutable-default-argument pitfall).
    task_distribution: Optional[Dict[str, float]] = None  # induction:deduction:abduction ratio
    max_tasks_per_type: int = 5        # maximum tasks per type
    use_azr_templates: bool = True     # use AZR templates
    skip_task_evaluation: bool = True  # skip task evaluation (stage 4); performed in VeRL

    # ------------------------------------------------------------------
    # Reward settings (AZR-based)
    # ------------------------------------------------------------------
    use_accuracy_reward: bool = True
    use_improvement_reward: bool = True  # TestTime-only improvement reward
    use_complexity_reward: bool = True
    accuracy_weight: float = 1.0
    improvement_weight: float = 0.5      # improvement reward weight
    complexity_weight: float = 0.1

    # ------------------------------------------------------------------
    # Logging settings
    # ------------------------------------------------------------------
    log_level: str = "INFO"
    save_intermediate_results: bool = True
    log_ipo_details: bool = True
    log_task_details: bool = True
    log_training_metrics: bool = True

    # ------------------------------------------------------------------
    # Memory optimization settings (AZR-based)
    # ------------------------------------------------------------------
    gpu_memory_utilization: float = 0.4
    max_workers: int = 2                         # Python executor workers
    use_memory_efficient_attention: bool = True

    def __post_init__(self) -> None:
        """Fill in derived defaults after dataclass initialization."""
        if self.task_distribution is None:
            # Default task distribution: near-uniform split (sums to 1.0).
            self.task_distribution = {
                "induction": 0.33,
                "deduction": 0.33,
                "abduction": 0.34,
            }
        # Resolve "auto" device from CUDA availability.
        if self.device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # bfloat16 is not appropriate on CPU; fall back to float32.
        if self.device == "cpu":
            self.torch_dtype = torch.float32

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable summary of the key settings.

        Note: this is a summary, not a full dump — only the fields most
        relevant for experiment tracking are included.
        """
        return {
            "model_name": self.model_name,
            "device": self.device,
            "torch_dtype": str(self.torch_dtype),
            "max_adaptation_steps": self.max_adaptation_steps,
            "max_cycles": self.max_cycles,
            "learning_rate": self.learning_rate,
            "task_distribution": self.task_distribution,
            "reward_weights": {
                "accuracy": self.accuracy_weight,
                "improvement": self.improvement_weight,
                "complexity": self.complexity_weight,
            },
        }

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> 'TestTimeConfig':
        """Build a config from a dictionary, ignoring unknown keys.

        BUGFIX: the original `cls(**config_dict)` crashed on dictionaries
        produced by `to_dict()` — the derived key "reward_weights" raised
        TypeError as an unexpected keyword, and "torch_dtype" is serialized
        there as a string (e.g. "torch.bfloat16") that was never converted
        back. Unknown keys are now dropped and string dtypes are resolved
        to real `torch.dtype` objects, so round-tripping works.
        """
        valid_names = {f.name for f in fields(cls)}
        kwargs = {k: v for k, v in config_dict.items() if k in valid_names}
        dtype = kwargs.get("torch_dtype")
        if isinstance(dtype, str):
            # "torch.bfloat16" -> torch.bfloat16
            kwargs["torch_dtype"] = getattr(torch, dtype.split(".")[-1])
        return cls(**kwargs)
@dataclass
class BenchmarkConfig:
    """Per-benchmark settings (e.g. "humaneval", "mbpp", "livecodebase")."""

    name: str              # benchmark identifier
    data_path: str         # path to the benchmark data file
    problem_prefix: str    # problem-id prefix, e.g. "HumanEval", "Mbpp"
    start_index: int = 0   # first problem index (MBPP starts at 2)
    max_problems: int = 5  # number of problems to test

    # Benchmark-specific tuning.
    test_timeout: int = 10
    use_plus_version: bool = True  # use the HumanEval+ / MBPP+ variants

    @classmethod
    def get_humaneval_config(cls) -> 'BenchmarkConfig':
        """Preset for the HumanEval(+) benchmark."""
        preset = dict(
            name="humaneval",
            data_path="evaluation/code_eval/data/HumanEvalPlus.jsonl",
            problem_prefix="HumanEval",
            start_index=0,
            max_problems=5,
        )
        return cls(**preset)

    @classmethod
    def get_mbpp_config(cls) -> 'BenchmarkConfig':
        """Preset for the MBPP(+) benchmark (problem ids start at 2)."""
        preset = dict(
            name="mbpp",
            data_path="evaluation/code_eval/data/MbppPlus.jsonl",
            problem_prefix="Mbpp",
            start_index=2,
            max_problems=5,
        )
        return cls(**preset)