|
"""
TestTime RLVR Configuration

Configuration classes for AZR-based TestTime RLVR.
"""
|
|
|
from dataclasses import dataclass |
|
from typing import Optional, List, Dict, Any |
|
import torch |
|
|
|
|
|
@dataclass
class TestTimeConfig:
    """Configuration dedicated to TestTime RLVR.

    All fields have defaults. After dataclass initialisation,
    ``__post_init__`` fills in the default task distribution, resolves
    ``device="auto"`` to a concrete device, and forces ``torch.float32``
    when the device is ``"cpu"``.
    """

    # Model and device
    model_name: str = "Qwen/Qwen2.5-7B"
    device: str = "auto"  # "auto" is resolved to "cuda"/"cpu" in __post_init__
    torch_dtype: torch.dtype = torch.bfloat16  # overridden to float32 on CPU
    use_flash_attention: bool = True
    enable_gradient_checkpointing: bool = True

    # Adaptation (training) hyperparameters
    max_adaptation_steps: int = 10
    adaptation_batch_size: int = 1
    gradient_accumulation_steps: int = 4
    learning_rate: float = 1e-6

    # Cycle control
    max_cycles: int = 3
    min_improvement_threshold: float = 0.05
    early_stopping_patience: int = 2

    # IPO triples
    max_ipo_triples: int = 10
    python_executor_timeout: int = 5
    validate_triples: bool = True

    # Program variations and baseline evaluation
    num_program_variations: int = 4
    baseline_evaluation_rounds: int = 5
    diverse_generation_temperature: float = 0.7
    baseline_generation_temperature: float = 0.05

    # Task generation
    # None is a sentinel: __post_init__ replaces it with a fresh default dict,
    # so instances never share one mutable mapping.
    task_distribution: Optional[Dict[str, float]] = None
    max_tasks_per_type: int = 5
    use_azr_templates: bool = True
    skip_task_evaluation: bool = True

    # Reward components and their weights
    use_accuracy_reward: bool = True
    use_improvement_reward: bool = True
    use_complexity_reward: bool = True
    accuracy_weight: float = 1.0
    improvement_weight: float = 0.5
    complexity_weight: float = 0.1

    # Logging
    log_level: str = "INFO"
    save_intermediate_results: bool = True
    log_ipo_details: bool = True
    log_task_details: bool = True
    log_training_metrics: bool = True

    # Resource limits
    gpu_memory_utilization: float = 0.4
    max_workers: int = 2
    use_memory_efficient_attention: bool = True

    def __post_init__(self):
        """Post-process the configuration after field initialisation."""
        if self.task_distribution is None:
            self.task_distribution = {
                "induction": 0.33,
                "deduction": 0.33,
                "abduction": 0.34,
            }

        if self.device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # CPU runs always use float32, regardless of the requested dtype.
        if self.device == "cpu":
            self.torch_dtype = torch.float32

    def to_dict(self) -> Dict[str, Any]:
        """Convert the configuration to a dictionary.

        Only a subset of the fields is exported. ``torch_dtype`` is rendered
        as a string and the three reward weights are nested under
        ``"reward_weights"``.

        Returns:
            A new dictionary; ``task_distribution`` is copied so that editing
            the result does not alter this configuration.
        """
        return {
            "model_name": self.model_name,
            "device": self.device,
            "torch_dtype": str(self.torch_dtype),
            "max_adaptation_steps": self.max_adaptation_steps,
            "max_cycles": self.max_cycles,
            "learning_rate": self.learning_rate,
            "task_distribution": dict(self.task_distribution),
            "reward_weights": {
                "accuracy": self.accuracy_weight,
                "improvement": self.improvement_weight,
                "complexity": self.complexity_weight,
            },
        }

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> 'TestTimeConfig':
        """Load a configuration from a dictionary.

        Accepts both plain field-name keys and the layout produced by
        ``to_dict()``: a nested ``"reward_weights"`` mapping and a string
        ``torch_dtype`` such as ``"torch.bfloat16"``.

        Args:
            config_dict: Mapping of settings. It is not modified.

        Returns:
            A new ``TestTimeConfig``.

        Raises:
            TypeError: If a key does not correspond to a field.
            ValueError: If a string ``torch_dtype`` does not name a torch dtype.
        """
        params = dict(config_dict)  # work on a copy; leave the caller's dict intact

        # Flatten {"reward_weights": {"accuracy": x}} back to accuracy_weight=x.
        # An explicitly supplied flat key takes precedence over the nested one.
        nested_weights = params.pop("reward_weights", None)
        if nested_weights is not None:
            for weight_name, weight_value in nested_weights.items():
                params.setdefault(f"{weight_name}_weight", weight_value)

        # to_dict() stringifies the dtype; turn it back into a torch.dtype.
        dtype_value = params.get("torch_dtype")
        if isinstance(dtype_value, str):
            resolved = getattr(torch, dtype_value.rsplit(".", 1)[-1], None)
            if not isinstance(resolved, torch.dtype):
                raise ValueError(f"Unknown torch dtype: {dtype_value!r}")
            params["torch_dtype"] = resolved

        return cls(**params)
|
|
|
|
|
@dataclass
class BenchmarkConfig:
    """Per-benchmark configuration.

    ``name``, ``data_path`` and ``problem_prefix`` are required; the remaining
    fields have defaults. Use the ``get_*_config`` classmethods for the
    predefined benchmarks.
    """

    name: str
    data_path: str
    problem_prefix: str
    start_index: int = 0
    max_problems: int = 5

    test_timeout: int = 10
    use_plus_version: bool = True

    @classmethod
    def get_humaneval_config(
        cls, *, start_index: int = 0, max_problems: int = 5
    ) -> 'BenchmarkConfig':
        """Return the predefined HumanEval configuration.

        Args:
            start_index: Value for the ``start_index`` field (default 0).
            max_problems: Value for the ``max_problems`` field (default 5).

        Returns:
            A ``BenchmarkConfig`` pointing at the HumanEvalPlus data file.
        """
        return cls(
            name="humaneval",
            data_path="evaluation/code_eval/data/HumanEvalPlus.jsonl",
            problem_prefix="HumanEval",
            start_index=start_index,
            max_problems=max_problems,
        )

    @classmethod
    def get_mbpp_config(
        cls, *, start_index: int = 2, max_problems: int = 5
    ) -> 'BenchmarkConfig':
        """Return the predefined MBPP configuration.

        Args:
            start_index: Value for the ``start_index`` field (default 2).
                NOTE(review): presumably the first task id available in the
                MBPP data file; confirm against the dataset.
            max_problems: Value for the ``max_problems`` field (default 5).

        Returns:
            A ``BenchmarkConfig`` pointing at the MbppPlus data file.
        """
        return cls(
            name="mbpp",
            data_path="evaluation/code_eval/data/MbppPlus.jsonl",
            problem_prefix="Mbpp",
            start_index=start_index,
            max_problems=max_problems,
        )