"""
IPO Triple Extractor
AZR Python Executor ๊ธฐ๋ฐ˜ (Input, Program, Output) ํŠธ๋ฆฌํ”Œ ์ถ”์ถœ ์‹œ์Šคํ…œ
์š”๊ตฌ์‚ฌํ•ญ 2: "AZR Python Executor๋ฅผ ์ด์šฉํ•˜์—ฌ (i,p,o) pair๋ฅผ ๋งŒ๋“ ๋‹ค"
"""
import ast
import re
import json
from typing import Dict, List, Any, Tuple, Optional
from concurrent.futures import TimeoutError
from ..utils.code_utils.python_executor import PythonExecutor
from .config import TestTimeConfig
from .logger import TestTimeLogger
from .solution_generator import InitialSolutionGenerator


class IPOBuffer:
    """Buffer that stores and manages IPO triples, keyed by problem id."""

    def __init__(self):
        self.buffer = {}  # {problem_id: [ipo_triples]}

    def add(self, problem_id: str, ipo_triple: Dict[str, Any]):
        """Add an IPO triple to the buffer."""
        if problem_id not in self.buffer:
            self.buffer[problem_id] = []
        self.buffer[problem_id].append(ipo_triple)

    def get_all(self, problem_id: str) -> List[Dict[str, Any]]:
        """Return all IPO triples for the given problem."""
        return self.buffer.get(problem_id, [])

    def clear(self, problem_id: str = None):
        """Clear the buffer (for one problem, or entirely)."""
        if problem_id:
            self.buffer.pop(problem_id, None)
        else:
            self.buffer.clear()

    def size(self, problem_id: str = None) -> int:
        """Return the number of buffered triples (for one problem, or in total)."""
        if problem_id:
            return len(self.buffer.get(problem_id, []))
        return sum(len(triples) for triples in self.buffer.values())
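
# A minimal usage sketch for IPOBuffer (illustrative only; the triple dict shown
# here mirrors a subset of the fields produced by IPOTripleExtractor below):
#
#     buffer = IPOBuffer()
#     buffer.add("HumanEval/0", {"id": "HumanEval/0_triple_0",
#                                "input": "''", "expected_output": "0"})
#     buffer.size("HumanEval/0")     # -> 1
#     buffer.get_all("HumanEval/0")  # -> [{...}]
#     buffer.clear()                 # drop everything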


class IPOTripleExtractor:
    """Extracts and validates (Input, Program, Output) triples."""

    def __init__(self, config: TestTimeConfig, logger: Optional[TestTimeLogger] = None,
                 model=None, tokenizer=None):
        self.config = config
        self.logger = logger or TestTimeLogger()
        self.model = model
        self.tokenizer = tokenizer
        # Initialize the AZR Python Executor (existing approach)
        self.executor = PythonExecutor(
            timeout_length=config.python_executor_timeout,
            ast_check=True,  # AZR default setting
            max_workers=config.max_workers
        )
        self.extracted_triples = []
        # Last prompt/response used for input generation
        self.last_generation_prompt = ""
        self.last_generation_response = ""
        # Reference used for VLLM batch processing
        self.solution_generator = None

    def extract_triples(self, problem: Dict[str, Any], solution: str) -> List[Dict[str, Any]]:
        """Extract IPO triples from a benchmark problem and its solution."""
        problem_id = problem.get('task_id', 'unknown')
        self.logger.log_info(f"🔍 Extracting IPO triples for {problem_id}")
        triples = []
        try:
            # 1. Extract function info (entry point takes priority)
            entry_point = problem.get('entry_point', 'unknown')
            func_info = self._extract_function_info(solution, entry_point)
            if not func_info:
                self.logger.log_error("Failed to extract function info from solution")
                return []
            # 2. Build input-output pairs from test cases (based on the LLM solution)
            test_cases = self._extract_test_cases(problem, solution)
            # 3. Create IPO triples by executing the solution
            for i, (test_input_str, expected_output) in enumerate(test_cases):
                if len(triples) >= self.config.max_ipo_triples:
                    break
                # Extract the actual arguments from test_input_str (e.g. "strlen('')" -> "''")
                match = re.match(rf'{re.escape(entry_point)}\((.*)\)', test_input_str)
                if match:
                    actual_args = match.group(1)
                else:
                    actual_args = test_input_str  # fallback
                triple = self._create_ipo_triple(
                    func_info['full_code'],  # use the full code (including helper functions)
                    func_info,
                    actual_args,  # pass only the actual arguments
                    expected_output,
                    triple_id=f"{problem_id}_triple_{i}",
                    full_input_str=test_input_str  # also pass the full input string
                )
                if triple:
                    triples.append(triple)
            # Synthetic triple generation was removed: only the genuine single
            # example is used, to prevent cheating.
            # Validate and log
            validation_results = [self._validate_triple(triple) for triple in triples]
            self.logger.log_ipo_extraction(problem_id, triples, validation_results)
            # Return only the valid triples
            valid_triples = [triple for triple, valid in zip(triples, validation_results) if valid]
            self.logger.log_info(f"✅ Extracted {len(valid_triples)}/{len(triples)} valid IPO triples")
            return valid_triples
        except Exception as e:
            self.logger.log_error(f"IPO extraction failed: {e}")
            return []
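
    # For reference, a triple returned by extract_triples has this shape (the
    # field values are illustrative; the keys match _create_ipo_triple below):
    #
    #     {
    #         'id': 'HumanEval/23_triple_0',
    #         'input': "''",                    # actual arguments only
    #         'full_input_str': "strlen('')",   # full call expression
    #         'program': "def strlen(string): ...",
    #         'expected_output': '0',
    #         'actual_output': '0',
    #         'function_name': 'strlen',
    #         'function_args': ['string'],
    #         'is_correct': True,
    #         'extraction_method': 'test_case',
    #     }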

    def _extract_function_info(self, solution: str, entry_point: str = None) -> Optional[Dict[str, str]]:
        """Extract function info from the solution (entry point takes priority)."""
        try:
            # If this is a raw LLM response, extract the function code first
            processed_solution = solution
            if "LLM GENERATED SOLUTION:" in solution:
                self.logger.log_info("📝 Raw LLM response detected, extracting function code")
                processed_solution = self._extract_function_from_llm_response(solution)
                if not processed_solution:
                    self.logger.log_error("Failed to extract function from LLM response")
                    return None
            # Parse function definitions with the AST
            tree = ast.parse(processed_solution)
            # Search for the entry-point function first
            target_function = None
            all_functions = []
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef):
                    func_info = {
                        'name': node.name,
                        'args': [arg.arg for arg in node.args.args],
                        'signature': f"def {node.name}({', '.join(arg.arg for arg in node.args.args)}):",
                        'full_code': processed_solution
                    }
                    all_functions.append(func_info)
                    # Prefer the function matching the entry point
                    if entry_point and node.name == entry_point:
                        target_function = func_info
                        # This log fires too often, so keep it at debug level
                        self.logger.log_debug(f"🎯 Found entry point function: {entry_point}")
                        break
            # Return the entry-point function if we found it
            if target_function:
                return target_function
            # Otherwise fall back to the first function (previous behavior)
            if all_functions:
                self.logger.log_warning(f"⚠️ Entry point '{entry_point}' not found, using first function: {all_functions[0]['name']}")
                return all_functions[0]
            return None
        except Exception as e:
            self.logger.log_error(f"Function parsing failed: {e}")
            return None
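
    # Example of the dict _extract_function_info returns for
    # "def add(a, b): return a + b" with entry_point="add":
    #
    #     {'name': 'add', 'args': ['a', 'b'],
    #      'signature': 'def add(a, b):',
    #      'full_code': 'def add(a, b): return a + b'}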

    def _extract_function_from_llm_response(self, llm_response: str) -> str:
        """Extract function code from a raw LLM response (same logic as solution_generator)."""
        lines = llm_response.split('\n')
        solution_lines = []
        in_solution = False
        # Extract the "LLM GENERATED SOLUTION:" section
        for line in lines:
            if "LLM GENERATED SOLUTION:" in line:
                in_solution = True
                continue
            elif in_solution:
                # A "===============" line ends the section, but the first one
                # (the opening separator) is skipped
                if "===============" in line:
                    # Check whether we already collected solution lines
                    if solution_lines and any(l.strip() for l in solution_lines):
                        break
                    else:
                        # No solution lines yet: skip this first separator
                        continue
                solution_lines.append(line)
        if not solution_lines:
            return ""  # return an empty string when extraction fails
        extracted_solution = '\n'.join(solution_lines).strip()
        # Extract the function definition and imports (same logic as solution_generator)
        lines = extracted_solution.split('\n')
        import_lines = []
        func_lines = []
        in_function = False
        indent_level = 0
        # 1. Collect import statements
        for line in lines:
            stripped = line.strip()
            if (stripped.startswith('import ') or stripped.startswith('from ')) and not stripped.startswith('#'):
                import_lines.append(line)
        # 2. Find the function definition
        for line in lines:
            if line.strip().startswith('def '):
                in_function = True
                func_lines = [line]
                indent_level = len(line) - len(line.lstrip())
            elif in_function:
                if not line.strip() or (line.strip() and len(line) - len(line.lstrip()) > indent_level):
                    func_lines.append(line)
                else:
                    break
        # 3. Combine imports + function
        if func_lines:
            result_lines = (import_lines + [''] + func_lines) if import_lines else func_lines
            return '\n'.join(result_lines)
        else:
            return extracted_solution
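
    # The raw response format this parser expects looks roughly like the
    # following (an assumption inferred from the markers searched for above):
    #
    #     LLM GENERATED SOLUTION:
    #     ===============
    #     import math
    #     def sqrt_floor(x):
    #         return int(math.sqrt(x))
    #     ===============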

    def _fix_humaneval_canonical_solution(self, problem: Dict[str, Any]) -> str:
        """Restore a HumanEval canonical solution (prepend the function signature)."""
        canonical_code = problem.get('canonical_solution', '')
        entry_point = problem.get('entry_point', '')
        prompt = problem.get('prompt', '')
        # Only HumanEval problems need this fix
        task_id = problem.get('task_id', '')
        if not task_id.startswith('HumanEval/'):
            return canonical_code
        # Skip if the function signature is already present
        if f"def {entry_point}" in canonical_code:
            return canonical_code
        try:
            # Extract the function signature from the prompt
            def_pattern = rf'def\s+{re.escape(entry_point)}\s*\([^)]*\)[^:]*:'
            match = re.search(def_pattern, prompt, re.MULTILINE)
            if match:
                function_signature = match.group(0)
                # Also extract import statements, if any
                import_lines = []
                for line in prompt.split('\n'):
                    stripped = line.strip()
                    if (stripped.startswith('import ') or stripped.startswith('from ')) and not stripped.startswith('#'):
                        import_lines.append(line)
                # Assemble the complete canonical solution
                if import_lines:
                    complete_canonical = '\n'.join(import_lines) + '\n\n' + function_signature + canonical_code
                else:
                    complete_canonical = function_signature + canonical_code
                self.logger.log_info(f"🔧 Fixed HumanEval canonical solution for {entry_point}")
                return complete_canonical
            else:
                self.logger.log_warning(f"⚠️ Could not extract function signature for {entry_point}")
                return canonical_code
        except Exception as e:
            self.logger.log_error(f"Failed to fix HumanEval canonical solution: {e}")
            return canonical_code
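
    # Why this is needed: HumanEval keeps the "def ...:" signature inside
    # `prompt` and ships only the indented body in `canonical_solution`, so the
    # body is not parseable on its own.  Sketch of the reassembly, assuming a
    # body that starts with a newline:
    #
    #     signature = "def strlen(string: str) -> int:"
    #     body      = "\n    return len(string)\n"
    #     signature + body  ->  a complete, executable definition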

    def _extract_single_prompt_example(self, problem: Dict[str, Any]) -> Optional[Tuple[str, str]]:
        """Extract only a single example from the prompt (to prevent cheating)."""
        try:
            # Use the first base_input entry as the single example
            if 'base_input' in problem and problem['base_input']:
                first_input = problem['base_input'][0]
                entry_point = problem['entry_point']
                self.logger.log_info(f"📥 Using first base_input as single example: {first_input}")
                # Restore the HumanEval canonical solution if needed
                canonical_code = self._fix_humaneval_canonical_solution(problem)
                if canonical_code:
                    actual_output = self._execute_llm_solution(canonical_code, entry_point, first_input)
                    if actual_output is not None:
                        # Build the input string
                        if isinstance(first_input, list):
                            if len(first_input) == 1 and isinstance(first_input[0], list):
                                # [[args]] -> display as a single list argument
                                input_str = repr(first_input[0])
                            elif len(first_input) == 1:
                                # [single arg] -> single arg
                                input_str = repr(first_input[0])
                            else:
                                # [multiple args] -> comma-separated args
                                input_str = ', '.join(repr(arg) for arg in first_input)
                        else:
                            input_str = repr(first_input)
                        result = (input_str, str(actual_output))
                        self.logger.log_info(f"✅ Single example extracted: Input={input_str}, Output={actual_output}")
                        return result
                    else:
                        self.logger.log_warning("❌ Failed to compute output with canonical solution")
                else:
                    self.logger.log_warning("❌ No canonical solution available")
            else:
                self.logger.log_warning("❌ No base_input available")
        except Exception as e:
            self.logger.log_error(f"Single example extraction failed: {e}")
        return None

    def _extract_docstring_examples(self, prompt: str, func_name: str) -> List[Tuple[str, str]]:
        """Extract ">>>" examples from the docstring."""
        examples = []
        lines = prompt.split('\n')
        i = 0
        while i < len(lines):
            line = lines[i].strip()
            # Look for the ">>> func_name(...)" pattern
            if line.startswith('>>>') and func_name in line:
                # Extract the input
                input_line = line[3:].strip()  # strip the ">>>"
                # The output, if any, is on the next line
                if i + 1 < len(lines):
                    output_line = lines[i + 1].strip()
                    # A line that does not start with ">>>" is the output value
                    if not output_line.startswith('>>>'):
                        examples.append((input_line, output_line))
                        i += 2
                        continue
                i += 1
            else:
                i += 1
        return examples
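
    # Illustrative input/output for this parser: given a docstring containing
    #
    #     >>> strlen('abc')
    #     3
    #
    # it returns [("strlen('abc')", "3")].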

    def _extract_test_cases(self, problem: Dict[str, Any], solution: str) -> List[Tuple[str, str]]:
        """Extract test cases from docstring examples (to prevent cheating)."""
        test_cases = []
        func_name = problem.get('entry_point', 'unknown')
        problem_id = problem.get('task_id', '')
        # Both HumanEval and MBPP use docstring examples only
        self.logger.log_info(f"🎯 Extracting docstring examples for {problem_id}")
        # Extract docstring examples from the prompt
        prompt = problem.get('prompt', '')
        examples = self._extract_docstring_examples(prompt, func_name)
        if examples:
            self.logger.log_info(f"📝 Found {len(examples)} docstring examples")
            for i, (input_str, expected_output) in enumerate(examples):
                try:
                    # Parse the input (extract args from "func_name(args)")
                    if input_str.startswith(func_name + '(') and input_str.endswith(')'):
                        args_str = input_str[len(func_name)+1:-1]
                        # Use ast.literal_eval for safe evaluation
                        try:
                            # Single-argument case
                            input_args = ast.literal_eval(args_str)
                            if not isinstance(input_args, tuple):
                                input_args = (input_args,)
                        except Exception:
                            # Multi-argument case
                            input_args = ast.literal_eval(f"({args_str})")
                        # Execute the LLM solution
                        actual_output = self._execute_llm_solution(solution, func_name, list(input_args))
                        if actual_output is not None:
                            test_cases.append((input_str, str(actual_output)))
                            self.logger.log_info(f"✅ Example {i+1}: {input_str} -> {actual_output}")
                        else:
                            self.logger.log_warning(f"❌ Example {i+1} execution failed")
                except Exception as e:
                    self.logger.log_error(f"Example {i+1} parsing failed: {e}")
        else:
            self.logger.log_warning("⚠️ No docstring examples found, falling back to first base_input")
            # With no docstring examples, use only the first base_input (as for MBPP)
            if 'base_input' in problem and problem['base_input']:
                inp_args = problem['base_input'][0]
                # Build the input string
                if isinstance(inp_args, list):
                    args_str = ', '.join(repr(arg) for arg in inp_args)
                    input_str = f"{func_name}({args_str})"
                else:
                    input_str = f"{func_name}({repr(inp_args)})"
                actual_output = self._execute_llm_solution(solution, func_name, inp_args)
                if actual_output is not None:
                    test_cases.append((input_str, str(actual_output)))
        self.logger.log_info(f"📊 Extracted {len(test_cases)} test cases from docstring examples")
        return test_cases

    def _execute_llm_solution(self, llm_solution: str, func_name: str, input_args) -> Optional[str]:
        """Execute the LLM-generated solution to compute the actual output."""
        try:
            if not llm_solution or func_name == 'unknown':
                return None
            # Build the argument string (handles MBPP+ double-wrapped lists)
            if isinstance(input_args, list):
                # MBPP+ data is sometimes wrapped in a nested list
                if len(input_args) == 1 and isinstance(input_args[0], list):
                    # [[args]] -> pass as a single list argument
                    args_str = repr(input_args[0])
                elif len(input_args) == 1:
                    # [single arg] -> pass as a single argument
                    args_str = repr(input_args[0])
                else:
                    # [multiple args] -> pass as multiple arguments
                    args_str = ', '.join(repr(arg) for arg in input_args)
            else:
                args_str = repr(input_args)
            execution_code = f"""
{llm_solution}

# Execute LLM solution
try:
    result = {func_name}({args_str})
    print(repr(result))
except Exception as e:
    print(f"EXECUTION_ERROR: {{e}}")
"""
            # Run with the AZR Python Executor
            output, status = self.executor.apply(execution_code)
            if 'error' in status.lower() or 'EXECUTION_ERROR' in output:
                return None
            # Extract the result from the output
            output_lines = output.strip().split('\n')
            if output_lines:
                result_line = output_lines[-1].strip()
                # Return the repr() output as-is
                return result_line
            return None
        except Exception as e:
            self.logger.log_error(f"LLM solution execution failed: {e}")
            return None
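
    # For func_name="strlen" and input_args=[''], the rendered execution_code is
    # roughly (solution body elided):
    #
    #     def strlen(string): ...
    #
    #     # Execute LLM solution
    #     try:
    #         result = strlen('')
    #         print(repr(result))
    #     except Exception as e:
    #         print(f"EXECUTION_ERROR: {e}")
    #
    # The last printed line is taken as the result, so a solution that prints
    # extra output would shift it; the executor's stdout is assumed clean.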

    def _create_ipo_triple(self, solution: str, func_info: Dict[str, str],
                           test_input: str, expected_output: str,
                           triple_id: str, full_input_str: str = None) -> Optional[Dict[str, Any]]:
        """Create and verify an IPO triple (using the AZR Python Executor)."""
        try:
            # 1. Execute the solution to get the actual output
            actual_output = self._execute_function(solution, func_info['name'], test_input)
            if actual_output is None:
                return None
            # 2. Assemble the IPO triple
            triple = {
                'id': triple_id,
                'input': test_input,  # actual arguments only (e.g. "''", "3.5")
                'full_input_str': full_input_str or f"{func_info['name']}({test_input})",  # full call expression
                'program': solution,  # func_info['full_code'] is passed in here
                'expected_output': expected_output,
                'actual_output': actual_output,
                'function_name': func_info['name'],
                'function_args': func_info['args'],
                'is_correct': str(actual_output) == str(expected_output),
                'extraction_method': 'test_case'
            }
            return triple
        except Exception as e:
            self.logger.log_error(f"Triple creation failed for {triple_id}: {e}")
            return None

    def _execute_function(self, code: str, func_name: str, inputs: str) -> Optional[str]:
        """Run a function with the AZR Python Executor."""
        try:
            # Build the execution code (AZR template style)
            execution_code = f"""
{code}

# Execute function with inputs
try:
    result = {func_name}({inputs})
    print(repr(result))
except Exception as e:
    print(f"EXECUTION_ERROR: {{e}}")
"""
            # Execute the AZR way
            output, status = self.executor.apply(execution_code)
            if 'error' in status.lower() or 'EXECUTION_ERROR' in output:
                return None
            # Extract the result from the output
            output_lines = output.strip().split('\n')
            if output_lines:
                return output_lines[-1].strip()
            return None
        except Exception as e:
            self.logger.log_error(f"Function execution failed: {e}")
            return None

    # Synthetic triple-generation methods were removed: only the single genuine
    # example is used, in line with the anti-cheating design.

    def _validate_triple(self, triple: Dict[str, Any]) -> bool:
        """Validate an IPO triple."""
        if not self.config.validate_triples:
            return True
        try:
            # 1. Check that the required fields exist
            required_fields = ['input', 'program', 'expected_output', 'function_name']
            if not all(field in triple for field in required_fields):
                return False
            # 2. Syntax-check the code
            try:
                ast.parse(triple['program'])
            except SyntaxError:
                return False
            # 3. Re-execute for consistency (AZR style);
            # triple['input'] now holds only the actual arguments
            actual_output = self._execute_function(
                triple['program'],
                triple['function_name'],
                triple['input']
            )
            if actual_output is None:
                return False
            # 4. Check that the outputs match
            return str(actual_output) == str(triple['expected_output'])
        except Exception as e:
            self.logger.log_error(f"Triple validation failed: {e}")
            return False

    def get_triple_statistics(self) -> Dict[str, Any]:
        """Statistics over the extracted triples."""
        if not self.extracted_triples:
            return {"total": 0, "valid": 0, "invalid": 0}
        valid_count = sum(1 for triple in self.extracted_triples if triple.get('is_correct', False))
        return {
            "total": len(self.extracted_triples),
            "valid": valid_count,
            "invalid": len(self.extracted_triples) - valid_count,
            "extraction_methods": {
                "test_case": sum(1 for t in self.extracted_triples if t.get('extraction_method') == 'test_case'),
                "synthetic": sum(1 for t in self.extracted_triples if t.get('extraction_method') == 'synthetic')
            }
        }

    def generate_diverse_inputs(self, problem: Dict[str, Any], solution: str,
                                existing_examples: List[Tuple[str, str]]) -> List[Dict[str, Any]]:
        """Generate diverse inputs with an LLM."""
        problem_id = problem.get('task_id', 'unknown')
        self.logger.log_info(f"🎲 Generating diverse inputs for {problem_id}")
        try:
            # 1. Extract function info
            entry_point = problem.get('entry_point', 'unknown')
            func_info = self._extract_function_info(solution, entry_point)
            if not func_info:
                self.logger.log_error("Failed to extract function info for input generation")
                return []
            # 2. Infer argument types
            arg_type_info = self._infer_argument_types(func_info, existing_examples, solution)
            # 3. Build the prompt
            prompt = self._create_input_generation_prompt(
                problem_description=problem.get('prompt', ''),
                existing_examples=existing_examples,
                full_code=solution,
                arg_type_info=arg_type_info
            )
            # 4. Generate inputs with the LLM
            generated_inputs = self._call_llm_for_inputs(prompt, existing_examples, func_info, arg_type_info)
            # 5. Validate the generated inputs
            valid_inputs = self._validate_generated_inputs(generated_inputs, func_info, solution)
            self.logger.log_info(f"✅ Generated {len(valid_inputs)} valid diverse inputs")
            return valid_inputs
        except Exception as e:
            self.logger.log_error(f"Failed to generate diverse inputs: {e}")
            return []

    def generate_diverse_inputs_batch(self, program_input_pairs: List[Dict[str, Any]]) -> Tuple[List[List[Dict[str, Any]]], List[Optional[Dict[str, Any]]]]:
        """Generate diverse inputs for multiple programs in one batch."""
        if not self.solution_generator:
            self.logger.log_error("Solution generator not set for batch processing")
            return [], []
        self.logger.log_info(f"🎲 Generating diverse inputs for {len(program_input_pairs)} programs (BATCH)")
        try:
            # Build an input-generation prompt for every program
            batch_prompts = []
            program_contexts = []
            for pair in program_input_pairs:
                problem = pair['problem']
                solution = pair['solution']
                existing_examples = pair['existing_examples']
                # Extract function info
                entry_point = problem.get('entry_point', 'unknown')
                func_info = self._extract_function_info(solution, entry_point)
                if not func_info:
                    program_contexts.append(None)
                    batch_prompts.append("")
                    continue
                # Infer argument types
                arg_type_info = self._infer_argument_types(func_info, existing_examples, solution)
                # Build the prompt
                prompt = self._create_input_generation_prompt(
                    problem_description=problem.get('prompt', ''),
                    existing_examples=existing_examples,
                    full_code=solution,
                    arg_type_info=arg_type_info
                )
                batch_prompts.append(prompt)
                program_contexts.append({
                    'func_info': func_info,
                    'solution': solution,
                    'problem': problem
                })
            # Call the LLM through the VLLM batch path
            if not batch_prompts or all(not p for p in batch_prompts):
                return [], []
            self.logger.log_info(f"🔍 Sending {len(batch_prompts)} prompts to VLLM for input generation")
            self.logger.log_info(f"🔍 First prompt preview: {batch_prompts[0][:200]}..." if batch_prompts else "No prompts")
            # Input generation is not code generation, so the raw responses are
            # used: generate_batch's post-processing (function extraction, etc.)
            # does not fit input generation
            batch_responses = self.solution_generator._generate_batch_with_vllm(
                batch_prompts,
                temperature=0.7  # input generation needs some randomness
            )
            self.logger.log_info(f"🔍 Received {len(batch_responses)} responses from VLLM")
            for i, response in enumerate(batch_responses[:2]):  # log only the first two
                self.logger.log_info(f"🔍 Response {i} preview: {response[:200]}...")
            # Parse each response into inputs
            batch_results = []
            batch_generation_info = []  # per-program input-generation info
            for i, (response, context) in enumerate(zip(batch_responses, program_contexts)):
                if context is None:
                    batch_results.append([])
                    batch_generation_info.append(None)
                    continue
                try:
                    # Extract inputs from the response
                    generated_inputs = self._parse_llm_input_response(
                        response,
                        context['func_info'],
                        context['problem'].get('task_id', 'unknown')
                    )
                    # Debug: log how many inputs were parsed
                    self.logger.log_info(f"🔍 Parsed {len(generated_inputs)} inputs from response {i}")
                    if generated_inputs:
                        self.logger.log_info(f"🔍 First parsed input: {generated_inputs[0]}")
                    # Validate the generated inputs
                    valid_inputs = self._validate_generated_inputs(
                        generated_inputs,
                        context['func_info'],
                        context['solution']
                    )
                    # Debug: log how many inputs passed validation
                    self.logger.log_info(f"🔍 {len(valid_inputs)} inputs passed validation from response {i}")
                    batch_results.append(valid_inputs)
                    # Record the input-generation info
                    generation_info = {
                        'prompt': batch_prompts[i] if i < len(batch_prompts) else '',
                        'llm_response': response,
                        'extracted_inputs': generated_inputs,
                        'valid_inputs': valid_inputs,
                        'existing_examples': program_input_pairs[i]['existing_examples'] if i < len(program_input_pairs) else [],
                        'function_info': context['func_info'],
                        'arg_type_info': self._infer_argument_types(
                            context['func_info'],
                            program_input_pairs[i]['existing_examples'] if i < len(program_input_pairs) else [],
                            context['solution']
                        )
                    }
                    batch_generation_info.append(generation_info)
                except Exception as e:
                    self.logger.log_error(f"Failed to process batch item {i}: {e}")
                    # Extra debug info
                    self.logger.log_error(f"Response preview: {response[:200]}...")
                    import traceback
                    self.logger.log_error(f"Traceback: {traceback.format_exc()}")
                    batch_results.append([])
                    # Record the error as well
                    batch_generation_info.append({
                        'error': str(e),
                        'prompt': batch_prompts[i] if i < len(batch_prompts) else '',
                        'llm_response': response,
                        'traceback': traceback.format_exc()
                    })
            total_generated = sum(len(inputs) for inputs in batch_results)
            self.logger.log_info(f"✅ Generated {total_generated} diverse inputs across {len(program_input_pairs)} programs")
            # Return both inputs and generation info as a tuple
            return batch_results, batch_generation_info
        except Exception as e:
            self.logger.log_error(f"Batch input generation failed: {e}")
            return [], []

    def _parse_llm_input_response(self, llm_response: str, func_info: Dict[str, Any], problem_id: str) -> List[Dict[str, Any]]:
        """Parse input examples out of an LLM response."""
        self.logger.log_info(f"🔍 Parsing LLM response for {problem_id}, response length: {len(llm_response)}")
        try:
            # Extract code from ```python ... ``` blocks
            code_pattern = r'```python\n(.*?)\n```'
            matches = re.findall(code_pattern, llm_response, re.DOTALL)
            if not matches:
                self.logger.log_info("🔍 No code block found, searching for examples = [")
                # No code block: look for "examples = [" in the whole response
                if 'examples = [' in llm_response:
                    start = llm_response.find('examples = [')
                    # Find the balanced closing bracket
                    bracket_count = 0
                    end = start
                    for i, char in enumerate(llm_response[start:]):
                        if char == '[':
                            bracket_count += 1
                        elif char == ']':
                            bracket_count -= 1
                            if bracket_count == 0:
                                end = start + i + 1
                                break
                    if end > start:
                        code = llm_response[start:end]
                        self.logger.log_info(f"🔍 Found examples code: {code[:100]}...")
                        # NOTE: exec() runs LLM-produced code in this process
                        exec_globals = {}
                        exec(code, exec_globals)
                        examples = exec_globals.get('examples', [])
                        self.logger.log_info(f"🔍 Extracted {len(examples)} examples")
                        return examples
                else:
                    self.logger.log_info("🔍 No 'examples = [' found in response")
            else:
                # Extract the examples from the first code block
                self.logger.log_info(f"🔍 Found {len(matches)} code blocks")
                code = matches[0]
                self.logger.log_info(f"🔍 Code block preview: {code[:100]}...")
                exec_globals = {}
                exec(code, exec_globals)
                examples = exec_globals.get('examples', [])
                self.logger.log_info(f"🔍 Extracted {len(examples)} examples from code block")
                # Handle the case where the examples are not dicts
                if examples and len(examples) > 0:
                    self.logger.log_info(f"🔍 First example type: {type(examples[0])}")
                    if isinstance(examples[0], dict):
                        # Drop keys such as expected_output and description
                        cleaned_examples = []
                        for ex in examples:
                            cleaned = {k: v for k, v in ex.items()
                                       if k not in ['expected_output', 'description']}
                            if cleaned:  # keep only non-empty dicts
                                cleaned_examples.append(cleaned)
                        self.logger.log_info(f"🔍 Cleaned {len(cleaned_examples)} examples")
                        return cleaned_examples
                return examples
            return []
        except Exception as e:
            self.logger.log_error(f"Failed to parse generated examples for {problem_id}: {e}")
            import traceback
            self.logger.log_error(f"Traceback: {traceback.format_exc()}")
            return []
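
    # A response this parser accepts looks like the following (a hypothetical
    # LLM output matching the format requested by the prompt built below):
    #
    #     ```python
    #     examples = [
    #         {"string": ""},     # empty edge case
    #         {"string": "abc"},  # typical case
    #     ]
    #     ```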

    def _infer_argument_types(self, func_info: Dict[str, str],
                              examples: List[Tuple[str, str]],
                              solution: str) -> Dict[str, str]:
        """Infer argument types from existing examples and AST analysis."""
        arg_types = {}
        func_name = func_info['name']
        arg_names = func_info['args']
        # 1. Extract type annotations from the AST
        try:
            tree = ast.parse(solution)
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef) and node.name == func_name:
                    for i, arg in enumerate(node.args.args):
                        if i < len(arg_names) and arg.annotation:
                            # Type annotation present
                            arg_types[arg_names[i]] = ast.unparse(arg.annotation)
        except Exception:
            pass
        # 2. Infer types from the existing examples
        if examples:
            for input_str, _ in examples:
                # Extract args from "func_name(args)"
                if input_str.startswith(func_name + '(') and input_str.endswith(')'):
                    args_str = input_str[len(func_name)+1:-1]
                    try:
                        # Parse the arguments (eval handles multi-argument strings)
                        parsed_args = eval(f"({args_str},)")
                        if not isinstance(parsed_args, tuple):
                            parsed_args = (parsed_args,)
                        # Infer each argument's type
                        for i, arg_value in enumerate(parsed_args):
                            if i < len(arg_names):
                                arg_name = arg_names[i]
                                arg_type = type(arg_value).__name__
                                # Special-case homogeneous lists
                                if isinstance(arg_value, list):
                                    if arg_value and all(isinstance(x, type(arg_value[0])) for x in arg_value):
                                        inner_type = type(arg_value[0]).__name__
                                        arg_type = f"List[{inner_type}]"
                                    else:
                                        arg_type = "List"
                                # Keep any type found earlier (annotations win)
                                if arg_name not in arg_types:
                                    arg_types[arg_name] = arg_type
                    except Exception:
                        pass
        # 3. Return the type map, marking anything still unknown
        for arg_name in arg_names:
            if arg_name not in arg_types:
                arg_types[arg_name] = "Any (type unknown)"
        return arg_types
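
    # Example (illustrative): for "def pad(text: str, width): ..." with the
    # docstring example "pad('hi', 4)", the inferred map would be
    #
    #     {'text': 'str', 'width': 'int'}
    #
    # where 'str' comes from the annotation and 'int' from the example value.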

    def _create_input_generation_prompt(self, problem_description: str,
                                        existing_examples: List[Tuple[str, str]],
                                        full_code: str,
                                        arg_type_info: Dict[str, str]) -> str:
        """Build the prompt used for input generation."""
        # Format every existing example
        examples_text = ""
        for i, (input_str, output_str) in enumerate(existing_examples):
            examples_text += f"Example {i+1}:\n"
            examples_text += f"Input: {input_str}\n"
            examples_text += f"Output: {output_str}\n\n"
        # Format arg_type_info as text
        arg_type_text = "Argument types:\n"
        for arg, arg_type in arg_type_info.items():
            arg_type_text += f"- {arg}: {arg_type}\n"
        prompt = f"""Given the following problem description and its Python function implementation, first analyze the types and valid ranges of the function arguments, then write **5 different example inputs** for the function that cover a diverse mix of typical (general) cases and edge/boundary cases.

Problem Description:
'''
{problem_description}
'''

Existing Examples from Problem:
{examples_text}

Function Implementation:
```python
{full_code}
```

{arg_type_text}

Based on the existing examples above, generate 5 NEW diverse test inputs that are different from the existing ones. Each input should be a Python dict where:
- Keys are the exact parameter names from the function signature
- Values are appropriate test values for each parameter

Format your response as:
```python
examples = [
    {{dict_with_all_function_parameters}},  # Description of this test case
    {{dict_with_all_function_parameters}},  # Description of this test case
    ...  # Continue for all 5 examples
]
```

Ensure your examples include:
- At least 2 typical/general cases
- At least 2 edge/boundary cases
- 1 special case (empty, zero, maximum values, etc.)
- All examples should be DIFFERENT from the existing examples shown above"""
        return prompt

    def _call_llm_for_inputs(self, prompt: str, existing_examples: List[Tuple[str, str]],
                             func_info: Dict[str, Any], arg_type_info: str) -> List[Dict[str, Any]]:
        """Call the LLM to generate inputs, then parse them."""
        # Save the prompt
        self.last_generation_prompt = prompt
        try:
            # Dedicated LLM call for input generation (temperature=0.5)
            if self.model is not None and self.tokenizer is not None:
                # Prefer VLLM when the model is a vllm.LLM instance
                try:
                    from vllm import LLM
                    if isinstance(self.model, LLM):
                        response = self._generate_with_vllm_for_inputs(prompt)
                    else:
                        response = self._generate_with_hf_for_inputs(prompt)
                except ImportError:
                    response = self._generate_with_hf_for_inputs(prompt)
                # Save the response
                self.last_generation_response = response
                # Extract the examples from the response
                parsed_inputs = self._parse_generated_examples(response)
                # Record the input-generation info
                self.last_input_generation_info = {
                    'prompt': prompt,
                    'llm_response': response,
                    'extracted_inputs': parsed_inputs,
                    'existing_examples': existing_examples,
                    'function_info': func_info,
                    'arg_type_info': arg_type_info
                }
                return parsed_inputs
            else:
                # No model available: return an empty list (test environment)
                self.logger.log_warning("No model available for input generation")
                self.last_generation_response = "No model available"
                # Record the info even on failure
                self.last_input_generation_info = {
                    'prompt': prompt,
                    'llm_response': "No model available",
                    'extracted_inputs': [],
                    'existing_examples': existing_examples,
                    'function_info': func_info,
                    'arg_type_info': arg_type_info,
                    'error': "No model available"
                }
                return []
        except Exception as e:
            self.logger.log_error(f"Failed to call LLM for inputs: {e}")
            self.last_generation_response = f"Error: {str(e)}"
            # Record the info on errors as well (the parameters are always
            # bound, so locals().get() is unnecessary here)
            self.last_input_generation_info = {
                'prompt': prompt,
                'llm_response': f"Error: {str(e)}",
                'extracted_inputs': [],
                'existing_examples': existing_examples,
                'function_info': func_info,
                'arg_type_info': arg_type_info,
                'error': str(e)
            }
            return []

    def _generate_with_vllm_for_inputs(self, prompt: str) -> str:
        """VLLM backend for input generation (temperature=0.5 for diversity)."""
        try:
            from vllm import SamplingParams
            # Sampling settings tuned for diverse input generation
            sampling_params = SamplingParams(
                temperature=0.5,  # higher temperature for diverse inputs
                max_tokens=2048,
                top_p=0.95,  # top_p sampling for diversity
                stop=["\n```\n"],  # stop when the code block closes
            )
            outputs = self.model.generate([prompt], sampling_params, use_tqdm=False)
            return outputs[0].outputs[0].text.replace("\t", " ").strip()
        except Exception as e:
            self.logger.log_error(f"VLLM input generation failed: {e}")
            return ""

    def _generate_with_hf_for_inputs(self, prompt: str) -> str:
        """HuggingFace backend for input generation (temperature=0.5 for diversity)."""
        try:
            import torch
            # Tokenize the prompt
            inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096)
            # Set the attention mask explicitly
            if 'attention_mask' not in inputs:
                inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
            # Move tensors to the model's device
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            with torch.no_grad():
                # Free cached GPU memory first
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                # Sampling settings for input generation
                outputs = self.model.generate(
                    inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    max_new_tokens=2048,
                    do_sample=True,  # enable sampling
                    temperature=0.5,  # temperature for diverse inputs
                    top_p=0.95,  # top_p sampling for diversity
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            # Extract the response (strip the echoed prompt)
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = response[len(prompt):].strip()
            return response
        except Exception as e:
            self.logger.log_error(f"HuggingFace input generation failed: {e}")
            return ""

    def _parse_generated_examples(self, llm_response: str) -> List[Dict[str, Any]]:
        """Parse examples out of an LLM response."""
        try:
            # Extract code from ```python ... ``` blocks
            code_pattern = r'```python\n(.*?)\n```'
            matches = re.findall(code_pattern, llm_response, re.DOTALL)
            if not matches:
                # No code block: look for "examples = [" in the whole response
                if 'examples = [' in llm_response:
                    start = llm_response.find('examples = [')
                    # Find the balanced closing bracket
                    bracket_count = 0
                    end = start
                    for i, char in enumerate(llm_response[start:]):
                        if char == '[':
                            bracket_count += 1
                        elif char == ']':
                            bracket_count -= 1
                            if bracket_count == 0:
                                end = start + i + 1
                                break
                    if end > start:
                        code = llm_response[start:end]
                        exec_globals = {}
                        exec(code, exec_globals)
                        return exec_globals.get('examples', [])
            else:
                # Extract the examples from the first code block
                code = matches[0]
                exec_globals = {}
                exec(code, exec_globals)
                return exec_globals.get('examples', [])
            return []
        except Exception as e:
            self.logger.log_error(f"Failed to parse generated examples: {e}")
            return []

    def _validate_generated_inputs(self, generated_inputs: List[Dict[str, Any]],
                                   func_info: Dict[str, str],
                                   solution: str) -> List[Dict[str, Any]]:
        """Validate the generated inputs."""
        valid_inputs = []
        func_name = func_info['name']
        for i, input_dict in enumerate(generated_inputs):
            try:
                # 1. Check that all required arguments are present
                required_args = set(func_info['args'])
                provided_args = set(input_dict.keys())
                if not required_args.issubset(provided_args):
                    self.logger.log_warning(f"Input {i+1} missing required args: {required_args - provided_args}")
                    continue
                # 2. Validate by actually executing the solution
                # Order the arguments to match the signature
                args = [input_dict[arg] for arg in func_info['args'] if arg in input_dict]
                # Execution test
                output = self._execute_llm_solution(solution, func_name, args)
                if output is not None:
                    valid_inputs.append(input_dict)
                    self.logger.log_info(f"✅ Valid input {i+1}: {input_dict}")
                else:
                    self.logger.log_warning(f"❌ Input {i+1} execution failed")
            except Exception as e:
                self.logger.log_error(f"Input {i+1} validation error: {e}")
        return valid_inputs

    def create_ipo_from_input(self, problem: Dict[str, Any],
                              solution: str,
                              input_dict: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Create an IPO triple from a newly generated input."""
        try:
            problem_id = problem.get('task_id', 'unknown')
            entry_point = problem.get('entry_point', 'unknown')
            # Extract function info
            func_info = self._extract_function_info(solution, entry_point)
            if not func_info:
                return None
            # Order the arguments to match the signature
            args = [input_dict[arg] for arg in func_info['args'] if arg in input_dict]
            # Execute to obtain the output
            output = self._execute_llm_solution(solution, func_info['name'], args)
            if output is None:
                return None
            # Build the input strings
            args_str = ', '.join(repr(arg) for arg in args)
            full_input_str = f"{func_info['name']}({args_str})"
            # Create the IPO triple
            triple_id = f"{problem_id}_generated_{len(self.extracted_triples)}"
            triple = {
                'id': triple_id,
                'input': args_str,  # actual arguments only
                'full_input_str': full_input_str,  # the full function call
                'program': solution,
                'expected_output': output,
                'actual_output': output,
                'function_name': func_info['name'],
                'function_args': func_info['args'],
                'is_correct': True,  # generated triples match by construction
                'extraction_method': 'generated'
            }
            return triple
        except Exception as e:
            self.logger.log_error(f"Failed to create IPO from input: {e}")
            return None

    def cleanup(self):
        """Release executor resources."""
        if hasattr(self.executor, 'cleanup'):
            self.executor.cleanup()
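

# A minimal end-to-end sketch (illustrative only; it assumes TestTimeConfig can
# be constructed with defaults and uses a hand-written HumanEval-style problem
# dict -- adjust both to your actual setup).  Kept as a comment because this
# module uses relative imports and is not runnable as a standalone script:
#
#     config = TestTimeConfig()  # hypothetical default construction
#     extractor = IPOTripleExtractor(config)
#     problem = {
#         'task_id': 'HumanEval/23',
#         'entry_point': 'strlen',
#         'prompt': ('def strlen(string: str) -> int:\n'
#                    '    """ >>> strlen(\'abc\')\n    3\n    """\n'),
#     }
#     solution = "def strlen(string):\n    return len(string)"
#     triples = extractor.extract_triples(problem, solution)
#     buffer = IPOBuffer()
#     for t in triples:
#         buffer.add(problem['task_id'], t)
#     extractor.cleanup()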