"""
Initial Solution Generator
AZR 기반 TestTime RLVR을 μœ„ν•œ 초기 μ†”λ£¨μ…˜ 생성기
κΈ°μ‘΄ Test-Time-RLVR의 generate_initial_solution ν•¨μˆ˜λ₯Ό ν΄λž˜μŠ€ν™”ν•˜μ—¬ ν™•μž₯
"""
import re
import torch
from typing import Dict, Any, Optional, Tuple, List
from transformers import AutoTokenizer, AutoModelForCausalLM
from .config import TestTimeConfig
from .logger import TestTimeLogger
from .prompts import get_prompt, get_temperature, get_diversity_instruction
# Import the code-extraction helper used by AZR
from ..rewards.custom_evaluate import extract_code
# Optional VLLM support for optimized inference
try:
from vllm import LLM, SamplingParams
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
class InitialSolutionGenerator:
"""벀치마크 λ¬Έμ œμ— λŒ€ν•œ 초기 μ†”λ£¨μ…˜ 생성"""
def __init__(self, model, tokenizer, config: TestTimeConfig,
logger: Optional[TestTimeLogger] = None, use_vllm: bool = True):
self.model = model
self.tokenizer = tokenizer
self.config = config
self.logger = logger or TestTimeLogger()
self.use_vllm = use_vllm and VLLM_AVAILABLE
        # Check VLLM availability and log the selected backend
if use_vllm and not VLLM_AVAILABLE:
self.logger.log_info("⚠️ VLLM requested but not available, falling back to HuggingFace")
elif self.use_vllm:
self.logger.log_info("πŸš€ Using VLLM for optimized inference")
else:
self.logger.log_info("πŸ”§ Using HuggingFace Transformers for inference")
def generate(self, problem: Dict[str, Any]) -> str:
"""λ¬Έμ œμ— λŒ€ν•œ 초기 μ†”λ£¨μ…˜ 생성 (AZR μ½”λ“œ 평가 ν”„λ‘¬ν”„νŠΈ μ‚¬μš©)"""
problem_prompt = problem['prompt']
problem_id = problem.get('task_id', 'unknown')
        # Apply the prompt format used by AZR code evaluation:
        # prompt = f"Please provide a self-contained Python script that solves the following problem in a markdown code block:\n\n{problem_prompt}"
        # Use the central prompt system instead
if 'HumanEval' in problem_id:
            # Find the entry-point function name
entry_point = problem.get('entry_point', 'unknown')
# ν”„λ‘¬ν”„νŠΈμ—μ„œ ν•¨μˆ˜κ°€ μ—¬λŸ¬ 개 μžˆλŠ”μ§€ 확인
import re
function_count = len(re.findall(r'^\s*def\s+\w+', problem_prompt, re.MULTILINE))
if function_count > 1:
                # Use the multi-function prompt
prompt = get_prompt("solution_humaneval_multi",
problem_prompt=problem_prompt,
entry_point=entry_point)
else:
                # Use the single-function prompt
prompt = get_prompt("solution_humaneval_basic",
problem_prompt=problem_prompt)
else:
            # Use the MBPP prompt
prompt = get_prompt("solution_mbpp_basic",
problem_prompt=problem_prompt)
self.logger.log_info(f"πŸ” Generating initial solution for {problem_id}")
self.logger.log_info(f"πŸ“‹ Full prompt: {prompt}")
# VLLM λ˜λŠ” HuggingFace λ°±μ—”λ“œ 선택
if self.use_vllm and isinstance(self.model, LLM):
solution = self._generate_with_vllm(prompt)
else:
solution = self._generate_with_huggingface(prompt)
# λ§ˆν¬λ‹€μš΄ μ½”λ“œ λΈ”λ‘μ—μ„œ Python μ½”λ“œ μΆ”μΆœ (κ°œμ„ λœ 방식)
extracted_solution = self._extract_python_code(solution)
        # Log the outcome of code extraction
        if extracted_solution and extracted_solution != solution:
            self.logger.log_info("πŸ” Extracted Python code from markdown block")
            solution = extracted_solution
        elif not extracted_solution:
            self.logger.log_info("πŸ” No markdown code block found, using original text")
        # For HumanEval, extract imports from the prompt and add them (EvalPlus style)
if 'HumanEval' in problem_id:
solution = self._add_imports_from_prompt(solution, problem_prompt)
# ν•¨μˆ˜ μ •μ˜ 볡ꡬ (AZR 둜직 κ·ΈλŒ€λ‘œ)
solution = self._fix_function_definition(solution, prompt, problem_id)
self.logger.log_info(f"βœ… Generated solution ({len(solution)} chars)")
self.logger.log_info(f"πŸ” Solution preview: {solution[:200]}...")
        # Debugging: log the full solution text
        self.logger.log_info("πŸ” Full solution for debugging:")
        self.logger.log_info("--- START SOLUTION ---")
        self.logger.log_info(solution)
        self.logger.log_info("--- END SOLUTION ---")
return solution
def generate_diverse(self, problem: Dict[str, Any], temperature: float = 0.7, variation_id: int = 0) -> str:
"""λ‹€μ–‘ν•œ μ†”λ£¨μ…˜ 생성 (높은 temperature μ‚¬μš©)"""
problem_prompt = problem['prompt']
problem_id = problem.get('task_id', 'unknown')
        # Use the centrally managed diversity-prompt system
diversity_instruction = get_diversity_instruction(variation_id)
        # For HumanEval, request function completion (diversity variant)
if 'HumanEval' in problem_id:
entry_point = problem.get('entry_point', 'unknown')
            function_count = len(re.findall(r'^\s*def\s+\w+', problem_prompt, re.MULTILINE))
if function_count > 1:
prompt = get_prompt("diverse_humaneval_multi",
diversity_instruction=diversity_instruction,
problem_prompt=problem_prompt,
entry_point=entry_point)
else:
prompt = get_prompt("diverse_humaneval_basic",
diversity_instruction=diversity_instruction,
problem_prompt=problem_prompt)
else:
            # Use the MBPP diversity prompt
prompt = get_prompt("diverse_mbpp_basic",
diversity_instruction=diversity_instruction,
problem_prompt=problem_prompt)
self.logger.log_info(f"🎨 Generating diverse solution #{variation_id+1} for {problem_id}")
        # Select the generation backend for diverse sampling
        # (self.use_vllm already encodes VLLM availability, so this mirrors generate())
        if self.use_vllm and isinstance(self.model, LLM):
            solution = self._generate_with_vllm_diverse(prompt, temperature)
        else:
            solution = self._generate_with_huggingface_diverse(prompt, temperature)
        # Code extraction and post-processing (same as generate())
        extracted_solution = self._extract_python_code(solution)
        if extracted_solution and extracted_solution != solution:
            self.logger.log_info("πŸ” Extracted Python code from markdown block")
            solution = extracted_solution
if 'HumanEval' in problem_id:
solution = self._add_imports_from_prompt(solution, problem_prompt)
solution = self._fix_function_definition(solution, prompt, problem_id)
self.logger.log_info(f"βœ… Generated diverse solution #{variation_id+1} ({len(solution)} chars)")
return solution
def _generate_with_vllm(self, prompt: str) -> str:
"""VLLM λ°±μ—”λ“œλ‘œ 생성 (AZR 방식)"""
# AZR evaluationκ³Ό λ™μΌν•œ SamplingParams μ„€μ •
sampling_params = SamplingParams(
temperature=0.05,
max_tokens=2048, # AZR 평가 μ„€μ •
top_p=1.0, # greedy mode
stop=["\n```\n"], # μ½”λ“œ 블둝 μ’…λ£Œ μ‹œ μ •μ§€
)
        # VLLM generation
        outputs = self.model.generate([prompt], sampling_params, use_tqdm=False)
        solution = outputs[0].outputs[0].text.replace("\t", " ")  # AZR-style tab handling
return solution.strip()
def _generate_with_vllm_diverse(self, prompt: str, temperature: float = 0.7) -> str:
"""λ‹€μ–‘ν•œ μ†”λ£¨μ…˜ μƒμ„±μš© VLLM λ°±μ—”λ“œ (높은 temperature)"""
# 닀양성을 μœ„ν•œ SamplingParams μ„€μ •
sampling_params = SamplingParams(
temperature=temperature, # 높은 temperature둜 λ‹€μ–‘μ„± 확보
max_tokens=2048,
top_p=0.95, # 닀양성을 μœ„ν•΄ top_p μ‚¬μš©
stop=["\n```\n"], # μ½”λ“œ 블둝 μ’…λ£Œ μ‹œ μ •μ§€
)
# VLLM 생성
outputs = self.model.generate([prompt], sampling_params, use_tqdm=False)
solution = outputs[0].outputs[0].text.replace("\t", " ")
return solution.strip()
def generate_batch(self, prompts: List[str], temperature: float = 0.7) -> List[str]:
"""배치둜 μ—¬λŸ¬ ν”„λ‘¬ν”„νŠΈ λ™μ‹œ 처리"""
# μ‹€μ œ λͺ¨λΈ νƒ€μž… 확인 (VLLM λ‘œλ”© μ‹€νŒ¨ μ‹œ HuggingFace λͺ¨λΈμ΄ λ‘œλ“œλ¨)
if self.use_vllm and isinstance(self.model, LLM):
raw_solutions = self._generate_batch_with_vllm(prompts, temperature)
else:
            # HuggingFace fallback: process prompts sequentially
raw_solutions = [self._generate_with_huggingface(prompt) for prompt in prompts]
        # Post-process each solution
processed_solutions = []
for i, (prompt, solution) in enumerate(zip(prompts, raw_solutions)):
# 1. λ§ˆν¬λ‹€μš΄μ—μ„œ Python μ½”λ“œ μΆ”μΆœ
extracted = self._extract_python_code(solution)
if extracted and extracted != solution:
self.logger.log_info(f"πŸ” Extracted Python code from markdown block for batch item {i+1}")
solution = extracted
            # 2. For HumanEval problems, add imports
            # (assumes the problem ID is embedded in the prompt)
            if 'HumanEval' in prompt:
                # Try to recover the original problem prompt from the full prompt;
                # this may need adjustment depending on the prompt structure
solution = self._add_imports_from_prompt(solution, prompt)
# 3. ν•¨μˆ˜ μ •μ˜ μˆ˜μ • (ν•„μš”ν•œ 경우)
# generate_diverse와 λ™μΌν•œ 처리
solution = self._fix_function_definition(solution, prompt)
processed_solutions.append(solution)
return processed_solutions
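    # Batch usage sketch (hypothetical prompts; the VLLM path is assumed):
    #
    #   prompts = [prompt_a, prompt_b]         # hypothetical prompt strings
    #   solutions = generator.generate_batch(prompts, temperature=0.7)
    #   assert len(solutions) == len(prompts)  # one post-processed solution per prompt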
def _generate_batch_with_vllm(self, prompts: List[str], temperature: float = 0.7) -> List[str]:
"""VLLM으둜 배치 처리"""
# VLLM μƒ˜ν”Œλ§ νŒŒλΌλ―Έν„°
# seedλ₯Ό μ œκ±°ν•˜μ—¬ 맀번 λ‹€λ₯Έ 응닡 생성
sampling_params = SamplingParams(
temperature=temperature,
top_p=0.85,
max_tokens=1024,
            stop=[]  # explicitly empty stop-token list
)
        # VLLM batch generation
outputs = self.model.generate(prompts, sampling_params, use_tqdm=False)
        # Collect the results
solutions = []
for i, output in enumerate(outputs):
solution = output.outputs[0].text.replace("\t", " ")
            # Debugging: check the finish_reason
            finish_reason = output.outputs[0].finish_reason
            if finish_reason != "stop" and i < 3:  # log only the first three
self.logger.log_warning(f"Output {i} finish_reason: {finish_reason}, length: {len(solution)}")
solutions.append(solution.strip())
return solutions
def _generate_with_huggingface(self, prompt: str) -> str:
"""HuggingFace λ°±μ—”λ“œλ‘œ 생성 (attention mask μˆ˜μ •)"""
# ν† ν¬λ‚˜μ΄μ € 처리 (attention mask κ²½κ³  μˆ˜μ •)
inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096)
        # Set the attention mask explicitly
if 'attention_mask' not in inputs:
inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
# λ””λ°”μ΄μŠ€ 이동 (AZR 방식 κ·ΈλŒ€λ‘œ)
device = getattr(self.model, 'device', 'cuda' if torch.cuda.is_available() else 'cpu')
if isinstance(device, str):
inputs = {k: v.to(device) for k, v in inputs.items()}
else:
# λͺ¨λΈμ΄ 이미 νŠΉμ • λ””λ°”μ΄μŠ€μ— μžˆλŠ” 경우
inputs = {k: v.to(next(self.model.parameters()).device) for k, v in inputs.items()}
with torch.no_grad():
            # Free cached GPU memory (AZR logic, unchanged)
if torch.cuda.is_available():
torch.cuda.empty_cache()
            # Same greedy settings as AZR evaluation
outputs = self.model.generate(
inputs['input_ids'],
                attention_mask=inputs['attention_mask'],  # pass the attention mask explicitly
                max_new_tokens=2048,  # original AZR evaluation setting
                do_sample=False,      # greedy mode (same as --greedy)
pad_token_id=self.tokenizer.eos_token_id
)
        # Extract the solution (AZR logic: strip the echoed prompt;
        # assumes decoding reproduces the prompt verbatim)
        solution = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        solution = solution[len(prompt):].strip()
return solution
def _generate_with_huggingface_diverse(self, prompt: str, temperature: float = 0.7) -> str:
"""λ‹€μ–‘ν•œ μ†”λ£¨μ…˜ μƒμ„±μš© HuggingFace λ°±μ—”λ“œ (높은 temperature)"""
# ν† ν¬λ‚˜μ΄μ € 처리
inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096)
        # Set the attention mask explicitly
if 'attention_mask' not in inputs:
inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
# λ””λ°”μ΄μŠ€ 이동
device = getattr(self.model, 'device', 'cuda' if torch.cuda.is_available() else 'cpu')
if isinstance(device, str):
inputs = {k: v.to(device) for k, v in inputs.items()}
else:
# λͺ¨λΈμ΄ 이미 νŠΉμ • λ””λ°”μ΄μŠ€μ— μžˆλŠ” 경우
inputs = {k: v.to(next(self.model.parameters()).device) for k, v in inputs.items()}
with torch.no_grad():
            # Free cached GPU memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
            # Sampling settings for diversity
outputs = self.model.generate(
inputs['input_ids'],
attention_mask=inputs['attention_mask'],
max_new_tokens=2048,
                do_sample=True,           # enable sampling
                temperature=temperature,  # higher temperature
                top_p=0.95,               # nucleus sampling for diversity
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
        # Extract the solution
solution = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
solution = solution[len(prompt):].strip()
return solution
def _extract_python_code(self, solution: str) -> str:
"""κ°œμ„ λœ Python μ½”λ“œ μΆ”μΆœ (AZR 방식 + μΆ”κ°€ νŒ¨ν„΄)"""
# 1. AZR의 extract_code ν•¨μˆ˜ λ¨Όμ € μ‹œλ„
try:
extracted = extract_code(solution, language="python")
if extracted:
return extracted
        except Exception:
            pass
        # 2. Try various markdown patterns
patterns = [
r'```python\n(.*?)```', # ```python ... ```
r'```\n(.*?)```', # ``` ... ```
r'```py\n(.*?)```', # ```py ... ```
r'```Python\n(.*?)```', # ```Python ... ```
            r'Here is.*?:\n\n```python\n(.*?)```',  # with leading explanation text
            r'Here is.*?:\n\n```\n(.*?)```',        # with leading explanation text
]
for pattern in patterns:
matches = re.findall(pattern, solution, re.DOTALL | re.IGNORECASE)
if matches:
return matches[-1].strip()
        # 3. Look for a function starting with def
lines = solution.split('\n')
code_lines = []
in_function = False
for line in lines:
if line.strip().startswith('def '):
in_function = True
code_lines.append(line)
elif in_function and (line.startswith(' ') or line.strip() == ''):
code_lines.append(line)
elif in_function and line.strip() and not line.startswith(' '):
# ν•¨μˆ˜ μ •μ˜ 끝
break
if code_lines:
return '\n'.join(code_lines)
        # 4. Fall back to the original text
return solution
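    # Extraction sketch (hypothetical model output): given
    #   "Here is the solution:\n```python\ndef f():\n    return 1\n```"
    # _extract_python_code returns the fenced body "def f():\n    return 1".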
def _add_imports_from_prompt(self, solution: str, prompt: str) -> str:
"""HumanEval ν”„λ‘¬ν”„νŠΈμ—μ„œ import 문을 μΆ”μΆœν•˜μ—¬ μ†”λ£¨μ…˜μ— μΆ”κ°€ (EvalPlus 방식)"""
# 이미 importκ°€ 있으면 κ·ΈλŒ€λ‘œ λ°˜ν™˜
if 'from typing import' in solution or 'import typing' in solution:
return solution
# ν”„λ‘¬ν”„νŠΈμ—μ„œ import λ¬Έ μΆ”μΆœ
import_lines = []
prompt_lines = prompt.split('\n')
for line in prompt_lines:
stripped = line.strip()
            # Find import statements
if (stripped.startswith('from ') and 'import' in stripped) or stripped.startswith('import '):
import_lines.append(line)
# ν•¨μˆ˜ μ •μ˜κ°€ μ‹œμž‘λ˜λ©΄ 쀑단
elif stripped.startswith('def '):
break
        # Return unchanged if there are no imports
if not import_lines:
return solution
        # Add the imports
self.logger.log_info(f"πŸ”§ Adding imports from prompt: {import_lines}")
# μ†”λ£¨μ…˜μ΄ 이미 import둜 μ‹œμž‘ν•˜λŠ”μ§€ 확인
solution_lines = solution.split('\n')
first_non_empty_line = None
for i, line in enumerate(solution_lines):
if line.strip():
first_non_empty_line = i
break
        # Prepend the imports (defined here so the final fallback return cannot
        # hit an unbound name when the solution has no non-empty lines)
        imports_text = '\n'.join(import_lines) + '\n\n'
        if first_non_empty_line is not None:
            # Either append after existing imports or prepend to the front.
            # If the first non-empty line is an import:
            if solution_lines[first_non_empty_line].strip().startswith(('import ', 'from ')):
                # Find the last import
                last_import_idx = first_non_empty_line
                for i in range(first_non_empty_line, len(solution_lines)):
                    if solution_lines[i].strip() and not solution_lines[i].strip().startswith(('import ', 'from ')):
                        break
                    if solution_lines[i].strip().startswith(('import ', 'from ')):
                        last_import_idx = i
                # Insert after the last import
                solution_lines.insert(last_import_idx + 1, '')
                solution_lines.insert(last_import_idx + 1, '\n'.join(import_lines))
                return '\n'.join(solution_lines)
            else:
                # Prepend to the front
                return imports_text + solution
        return imports_text + solution
def _fix_function_definition(self, solution: str, prompt: str, problem_id: str = "") -> str:
"""ν•¨μˆ˜ μ •μ˜κ°€ λˆ„λ½λœ 경우 볡ꡬ + lpw μŠ€νƒ€μΌ 쀑볡 처리"""
# lpw μŠ€νƒ€μΌ: ν”„λ‘¬ν”„νŠΈμ—μ„œ ν•¨μˆ˜ 이름 μΆ”μΆœ
func_def_match = re.search(r'def\s+(\w+)\([^)]*\)(?:\s*->\s*[^:]+)?:', prompt)
if not func_def_match:
return solution
entry_point = func_def_match.group(1)
func_def_line = func_def_match.group(0)
        # HumanEval solutions contain the full code, so duplicate handling is unnecessary
        if 'HumanEval' in problem_id:
            # The full code is already present; return it as-is
return solution
        # For MBPP, keep the existing logic.
        # Case 1: the LLM generated the whole function (lpw-style check)
        if (prompt in solution) or (f'def {entry_point}(' in solution):
            # The function definition is already present
self.logger.log_info(f"βœ… Function definition already present for {entry_point}")
return solution
        # Case 2: only the function body was generated, so prepend the definition
if solution and not solution.startswith('def '):
# ν•¨μˆ˜ μ •μ˜μ™€ ν•¨μˆ˜ λ‚΄μš©μ„ κ²°ν•©
lines = solution.split('\n')
fixed_lines = [func_def_line]
for line in lines:
                if line.strip():  # non-empty line
                    # an `if __name__ == "__main__":` block must stay outside the function
                    if line.strip().startswith('if __name__'):
                        # close the function body and start the main block
                        fixed_lines.append('')  # blank separator line
                        fixed_lines.append(line.strip())
                    else:
                        # function-body lines get a 4-space indent
                        if not line.startswith('    ') and line.strip():
                            line = '    ' + line.lstrip()
fixed_lines.append(line)
else:
fixed_lines.append(line)
solution = '\n'.join(fixed_lines)
self.logger.log_info(f"πŸ”§ Fixed function definition for {entry_point}")
return solution
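    # Repair sketch (hypothetical MBPP case): if the prompt defines
    # "def add(a, b):" and the model emitted only "return a + b",
    # _fix_function_definition rebuilds "def add(a, b):\n    return a + b".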
def generate_fallback_solution(self, problem: Dict[str, Any]) -> str:
"""문제 생성 μ‹€νŒ¨ μ‹œ λŒ€μ²΄ μ†”λ£¨μ…˜ 생성"""
entry_point = problem.get('entry_point', 'solution')
problem_description = problem.get('prompt', '')
        # Basic templates by problem type (legacy approach)
        if 'similar_elements' in problem_description:
            # similar_elements problem (Mbpp/2)
            solution = f"""def {entry_point}(test_tup1, test_tup2):
    return tuple(set(test_tup1) & set(test_tup2))"""
        elif 'kth_element' in problem_description:
            # kth_element problem
            solution = f"""def {entry_point}(arr, k):
    return sorted(arr)[k-1]"""
        else:
            # Generic template
            solution = f"""def {entry_point}(*args):
    # TODO: Implement this function
    return None"""
self.logger.log_info(f"πŸ”„ Generated fallback solution for {entry_point}")
return solution
def validate_syntax(self, solution: str) -> Tuple[bool, Optional[str]]:
"""μ†”λ£¨μ…˜ ꡬ문 검증"""
try:
compile(solution, '<string>', 'exec')
return True, None
except SyntaxError as e:
return False, str(e)
except Exception as e:
return False, str(e)
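    # Validation sketch (hypothetical inputs): syntactically valid code yields
    # (True, None); invalid code yields (False, "<SyntaxError message>"):
    #
    #   ok, err = generator.validate_syntax("def f():\n    return 1")  # (True, None)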
def extract_function_signature(self, prompt: str) -> Optional[Dict[str, str]]:
"""ν”„λ‘¬ν”„νŠΈμ—μ„œ ν•¨μˆ˜ μ‹œκ·Έλ‹ˆμ²˜ μΆ”μΆœ"""
# def function_name(args) -> return_type: νŒ¨ν„΄ λ§€μΉ­
pattern = r'def\s+(\w+)\(([^)]*)\)(?:\s*->\s*([^:]+))?:'
match = re.search(pattern, prompt)
if match:
func_name = match.group(1)
args = match.group(2)
return_type = match.group(3)
return {
'name': func_name,
'args': args.strip(),
'return_type': return_type.strip() if return_type else None,
'full_signature': match.group(0)
}
return None
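    # Signature-extraction sketch (hypothetical input):
    #   extract_function_signature("def add(a, b) -> int:") returns
    #   {'name': 'add', 'args': 'a, b', 'return_type': 'int',
    #    'full_signature': 'def add(a, b) -> int:'}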
def format_solution(self, raw_solution: str, problem: Dict[str, Any]) -> str:
"""μ†”λ£¨μ…˜ ν˜•μ‹ 정리"""
# κΈ°λ³Έ 정리
solution = raw_solution.strip()
# ν•¨μˆ˜ μ •μ˜ 확인 및 μˆ˜μ •
if not solution.startswith('def '):
signature = self.extract_function_signature(problem.get('prompt', ''))
if signature:
# ν•¨μˆ˜ μ •μ˜ μΆ”κ°€
lines = solution.split('\n')
indented_lines = [' ' + line if line.strip() else line for line in lines]
solution = signature['full_signature'] + '\n' + '\n'.join(indented_lines)
        # Strip extraneous explanation text
lines = solution.split('\n')
code_lines = []
in_function = False
for line in lines:
if line.strip().startswith('def '):
in_function = True
code_lines.append(line)
elif in_function:
code_lines.append(line)
elif line.strip() and not any(keyword in line.lower() for keyword in
['explanation', 'here', 'this function', 'the solution']):
code_lines.append(line)
return '\n'.join(code_lines).strip()
@staticmethod
def extract_docstring_from_function(code: str) -> str:
"""ν•¨μˆ˜ μ½”λ“œμ—μ„œ docstring을 μΆ”μΆœ"""
import re
# ν•¨μˆ˜ μ •μ˜ λ‹€μŒμ— μ˜€λŠ” docstring νŒ¨ν„΄ λ§€μΉ­
# """...""" λ˜λŠ” '''...''' ν˜•νƒœ
docstring_patterns = [
r'def\s+\w+\([^)]*\):\s*\n\s*"""(.*?)"""', # """..."""
r'def\s+\w+\([^)]*\):\s*\n\s*\'\'\'(.*?)\'\'\'', # '''...'''
]
for pattern in docstring_patterns:
match = re.search(pattern, code, re.DOTALL)
if match:
docstring = match.group(1).strip()
                # Collapse a multi-line docstring into one clean line
lines = docstring.split('\n')
cleaned_lines = []
for line in lines:
cleaned_line = line.strip()
if cleaned_line:
cleaned_lines.append(cleaned_line)
return ' '.join(cleaned_lines)
        # Return a default message when there is no docstring
return "Find the function that produces these outputs from these inputs."
def _extract_function_code(self, code: str) -> str:
"""μ½”λ“œμ—μ„œ ν•¨μˆ˜ μ •μ˜μ™€ ν•„μš”ν•œ import μΆ”μΆœ"""
import re
lines = code.strip().split('\n')
import_lines = []
func_lines = []
in_function = False
indent_level = 0
        # 1. Collect import statements
for line in lines:
stripped = line.strip()
if (stripped.startswith('import ') or stripped.startswith('from ')) and not stripped.startswith('#'):
import_lines.append(line)
# 2. ν•¨μˆ˜ μ •μ˜ μ°ΎκΈ°
for line in lines:
if line.strip().startswith('def '):
in_function = True
func_lines = [line]
# 첫 μ€„μ˜ λ“€μ—¬μ“°κΈ° 레벨 μ €μž₯
indent_level = len(line) - len(line.lstrip())
            elif in_function:
                # blank lines and deeper-indented lines belong to the function
                if not line.strip() or (line.strip() and len(line) - len(line.lstrip()) > indent_level):
                    func_lines.append(line)
                else:
                    # end of the function
break
        # 3. Combine imports and function
if func_lines:
result_lines = import_lines + [''] + func_lines if import_lines else func_lines
return '\n'.join(result_lines)
else:
return code
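    # Function-extraction sketch (hypothetical input): for
    #   "import math\nx = 1\ndef f(n):\n    return math.sqrt(n)"
    # _extract_function_code returns the import line plus the def block and
    # drops the stray top-level statement.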
def evaluate_solution(self, problem: Dict[str, Any], solution: str) -> Dict[str, Any]:
"""LLM μ†”λ£¨μ…˜μ„ 벀치마크 ν…ŒμŠ€νŠΈλ‘œ 평가 (EvalPlus ν•„μˆ˜)"""
try:
# EvalPlus ν•¨μˆ˜λ“€ μž„ν¬νŠΈ (pip으둜 μ„€μΉ˜λœ 버전 μ‚¬μš©)
self.logger.log_info("πŸ”„ Attempting to import EvalPlus...")
from evalplus.evaluate import check_correctness
from evalplus.gen.util import trusted_exec
from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS
from evalplus.eval import PASS
self.logger.log_info("βœ… Using EvalPlus for evaluation")
except ImportError as e:
            # Without EvalPlus this is treated as an error (no fallback)
self.logger.log_error(f"❌ EvalPlus is required but not available: {e}")
import traceback
self.logger.log_error(f"πŸ“‹ Import traceback: {traceback.format_exc()}")
return {
'correct': False,
'passed_tests': 0,
'total_tests': 0,
'error': f"EvalPlus import failed: {e}. Please install EvalPlus properly.",
'execution_results': [],
'base_passed': 0,
'plus_passed': 0,
'base_total': 0,
'plus_total': 0
}
except Exception as e:
self.logger.log_error(f"❌ EvalPlus import failed with unexpected error: {e}")
return {
'correct': False,
'passed_tests': 0,
'total_tests': 0,
'error': f"EvalPlus import error: {e}",
'execution_results': [],
'base_passed': 0,
'plus_passed': 0,
'base_total': 0,
'plus_total': 0
}
result = {
'correct': False,
'passed_tests': 0,
'total_tests': 0,
'error': None,
'execution_results': [],
'base_passed': 0,
'plus_passed': 0,
'base_total': 0,
'plus_total': 0
}
try:
# 1. ν•¨μˆ˜ μ •μ˜ μΆ”μΆœ
extracted_code = self._extract_function_code(solution)
if not extracted_code:
result['error'] = "No function definition found"
return result
            # 2. Determine the dataset type
task_id = problem.get('task_id', '')
if task_id.startswith('Mbpp'):
dataset = 'mbpp'
elif task_id.startswith('HumanEval'):
dataset = 'humaneval'
else:
                # default
dataset = 'mbpp'
            # 3. Build expected outputs (using the canonical solution)
entry_point = problem.get('entry_point', '')
canonical_solution = problem.get('canonical_solution', '')
if not canonical_solution:
result['error'] = "No canonical_solution found"
return result
            # Compute the expected outputs
expected_output = {}
# Base tests
base_inputs = problem.get('base_input', [])
if base_inputs:
expected_output['base'], expected_output['base_time'] = trusted_exec(
problem.get('prompt', '') + canonical_solution,
base_inputs,
entry_point,
record_time=True,
output_not_none=entry_point in MBPP_OUTPUT_NOT_NONE_TASKS
)
# Plus tests
plus_inputs = problem.get('plus_input', [])
if plus_inputs:
expected_output['plus'], expected_output['plus_time'] = trusted_exec(
problem.get('prompt', '') + canonical_solution,
plus_inputs,
entry_point,
record_time=True,
output_not_none=entry_point in MBPP_OUTPUT_NOT_NONE_TASKS
)
            # 4. Call EvalPlus check_correctness
evalplus_result = check_correctness(
dataset=dataset,
completion_id=0,
problem=problem,
solution=extracted_code,
expected_output=expected_output,
                base_only=False,   # run the plus tests too
                fast_check=False,  # run every test
identifier=task_id
)
            # 5. Parse the results
if 'base' in evalplus_result:
base_stat, base_details = evalplus_result['base']
result['base_total'] = len(base_inputs)
if base_stat == PASS:
result['base_passed'] = result['base_total']
else:
result['base_passed'] = sum(1 for d in base_details if d) if base_details else 0
result['passed_tests'] += result['base_passed']
result['total_tests'] += result['base_total']
if 'plus' in evalplus_result:
plus_stat, plus_details = evalplus_result['plus']
result['plus_total'] = len(plus_inputs)
if plus_stat == PASS:
result['plus_passed'] = result['plus_total']
else:
result['plus_passed'] = sum(1 for d in plus_details if d) if plus_details else 0
result['passed_tests'] += result['plus_passed']
result['total_tests'] += result['plus_total']
            # EvalPlus criterion: correct only if every test passes
            result['correct'] = (result['passed_tests'] == result['total_tests']) and result['total_tests'] > 0
            # Set the error message (guard against missing base/plus results,
            # since base_stat/plus_stat are only bound when those keys exist)
            if not result['correct']:
                if 'base' in evalplus_result and base_stat != PASS:
                    result['error'] = f"Base tests failed: {base_stat}"
                elif 'plus' in evalplus_result and plus_stat != PASS:
result['error'] = f"Plus tests failed: {plus_stat}"
            # Logging
self.logger.log_info(f"EvalPlus evaluation for {task_id}:")
self.logger.log_info(f" Base: {result['base_passed']}/{result['base_total']}")
self.logger.log_info(f" Plus: {result['plus_passed']}/{result['plus_total']}")
self.logger.log_info(f" Total: {result['passed_tests']}/{result['total_tests']}")
self.logger.log_info(f" Correct: {result['correct']}")
except Exception as e:
result['error'] = f"Evaluation failed: {str(e)}"
import traceback
self.logger.log_info(f"Evaluation traceback: {traceback.format_exc()}")
return result
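    # Result sketch (hypothetical counts): a fully passing solution yields
    #   {'correct': True, 'passed_tests': 25, 'total_tests': 25,
    #    'base_passed': 10, 'base_total': 10, 'plus_passed': 15, 'plus_total': 15, ...}
    # where the base/plus counts come from the EvalPlus base and plus test sets.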
@staticmethod
def load_model_with_optimizations(model_name: str, device: str,
config: TestTimeConfig, use_vllm: bool = True, tensor_parallel_size: int = 1) -> Tuple[Any, Any]:
"""λͺ¨λΈκ³Ό ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ (AZR μŠ€νƒ€μΌ μ΅œμ ν™”, VLLM 지원)"""
# ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
        # Check VLLM availability and load the model
if use_vllm and VLLM_AVAILABLE and device.startswith('cuda'):
try:
# GPU λ””λ°”μ΄μŠ€ μ„€μ • (이미 μ„€μ •λœ CUDA_VISIBLE_DEVICES μš°μ„  μ‚¬μš©)
import os
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
gpu_id = device.split(':')[1] if ':' in device else '0'
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
                else:
                    # Use the already-set CUDA_VISIBLE_DEVICES
gpu_id = os.environ['CUDA_VISIBLE_DEVICES']
print(f"🎯 Using existing CUDA_VISIBLE_DEVICES: {gpu_id}")
# VLLM λͺ¨λΈ λ‘œλ“œ (Ray Actor ν™˜κ²½μ—μ„œ λ©”λͺ¨λ¦¬ μ΅œμ ν™”)
model = LLM(
model=model_name,
                    dtype=str(config.torch_dtype).split('.')[-1],  # torch.float16 -> float16
trust_remote_code=True,
gpu_memory_utilization=config.gpu_memory_utilization,
                    max_model_len=getattr(config, 'max_model_len', 2048),  # generous context length
                    tensor_parallel_size=tensor_parallel_size,  # match the GPU count
)
print(f"βœ… VLLM model loaded successfully on GPU {gpu_id} (tensor_parallel_size={tensor_parallel_size})")
return model, tokenizer
except Exception as e:
import traceback
print(f"⚠️ VLLM loading failed: {e}")
print(f"πŸ” Full traceback: {traceback.format_exc()}")
print(f"πŸ”„ Falling back to HuggingFace")
# HuggingFace λͺ¨λΈ λ‘œλ“œ (κΈ°μ‘΄ 방식)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=config.torch_dtype,
device_map=device if device.startswith('cuda') else None,
trust_remote_code=True,
attn_implementation="flash_attention_2" if config.use_flash_attention and device.startswith('cuda') else None,
            use_cache=False,  # disable the KV cache for training use
)
        # Gradient checkpointing stays disabled: it is unnecessary for inference
        # and only produces warnings; enable it separately if training is needed.
if hasattr(model, 'gradient_checkpointing_disable'):
model.gradient_checkpointing_disable()
print(f"βœ… HuggingFace model loaded successfully")
return model, tokenizer
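# Loading sketch (hypothetical arguments; "Qwen/Qwen2.5-7B-Instruct" is only an
# example model name, and the config fields follow TestTimeConfig):
#
#   model, tokenizer = InitialSolutionGenerator.load_model_with_optimizations(
#       "Qwen/Qwen2.5-7B-Instruct", "cuda:0", config, use_vllm=True)
#   generator = InitialSolutionGenerator(model, tokenizer, config)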