"""
IPO Triple Extractor

(Input, Program, Output) triple extraction system built on the AZR Python Executor.

Requirement 2: "Use the AZR Python Executor to build (i, p, o) pairs."
"""

import ast
import re
import json
from typing import Dict, List, Any, Tuple, Optional
from concurrent.futures import TimeoutError

from ..utils.code_utils.python_executor import PythonExecutor
from .config import TestTimeConfig
from .logger import TestTimeLogger
from .solution_generator import InitialSolutionGenerator


class IPOBuffer:
    """Buffer that stores and manages IPO triples."""

    def __init__(self):
        self.buffer = {}

    def add(self, problem_id: str, ipo_triple: Dict[str, Any]):
        """Add an IPO triple to the buffer."""
        if problem_id not in self.buffer:
            self.buffer[problem_id] = []
        self.buffer[problem_id].append(ipo_triple)

    def get_all(self, problem_id: str) -> List[Dict[str, Any]]:
        """Return all IPO triples stored for a given problem."""
        return self.buffer.get(problem_id, [])

    def clear(self, problem_id: str = None):
        """Clear the buffer (for one problem, or entirely)."""
        if problem_id:
            self.buffer.pop(problem_id, None)
        else:
            self.buffer.clear()

    def size(self, problem_id: str = None) -> int:
        """Return the number of stored triples (for one problem, or in total)."""
        if problem_id:
            return len(self.buffer.get(problem_id, []))
        return sum(len(triples) for triples in self.buffer.values())


class IPOTripleExtractor:
    """(Input, Program, Output) triple extraction and validation."""

    def __init__(self, config: TestTimeConfig, logger: Optional[TestTimeLogger] = None,
                 model=None, tokenizer=None):
        self.config = config
        self.logger = logger or TestTimeLogger()
        self.model = model
        self.tokenizer = tokenizer

        # AZR Python Executor used for all code execution
        self.executor = PythonExecutor(
            timeout_length=config.python_executor_timeout,
            ast_check=True,
            max_workers=config.max_workers
        )

        self.extracted_triples = []

        # Last LLM interaction, kept for logging/debugging
        self.last_generation_prompt = ""
        self.last_generation_response = ""

        # Set externally when batch input generation is used
        self.solution_generator = None

    def extract_triples(self, problem: Dict[str, Any], solution: str) -> List[Dict[str, Any]]:
        """Extract IPO triples from a benchmark problem and its solution."""
        problem_id = problem.get('task_id', 'unknown')
        self.logger.log_info(f"Extracting IPO triples for {problem_id}")

        triples = []

        try:
            entry_point = problem.get('entry_point', 'unknown')
            func_info = self._extract_function_info(solution, entry_point)
            if not func_info:
                self.logger.log_error("Failed to extract function info from solution")
                return []

            test_cases = self._extract_test_cases(problem, solution)

            for i, (test_input_str, expected_output) in enumerate(test_cases):
                if len(triples) >= self.config.max_ipo_triples:
                    break

                # Strip the call wrapper, e.g. "foo(1, 2)" -> "1, 2"
                match = re.match(rf'{entry_point}\((.*)\)', test_input_str)
                if match:
                    actual_args = match.group(1)
                else:
                    actual_args = test_input_str

                triple = self._create_ipo_triple(
                    func_info['full_code'],
                    func_info,
                    actual_args,
                    expected_output,
                    triple_id=f"{problem_id}_triple_{i}",
                    full_input_str=test_input_str
                )

                if triple:
                    triples.append(triple)

            validation_results = [self._validate_triple(triple) for triple in triples]
            self.logger.log_ipo_extraction(problem_id, triples, validation_results)

            valid_triples = [triple for triple, valid in zip(triples, validation_results) if valid]

            self.logger.log_info(f"Extracted {len(valid_triples)}/{len(triples)} valid IPO triples")
            return valid_triples

        except Exception as e:
            self.logger.log_error(f"IPO extraction failed: {e}")
            return []

    def _extract_function_info(self, solution: str, entry_point: str = None) -> Optional[Dict[str, str]]:
        """Extract function info from the solution (the entry point takes priority)."""
        try:
            # If the solution is still a raw LLM response, pull out the function code first
            processed_solution = solution
            if "LLM GENERATED SOLUTION:" in solution:
                self.logger.log_info("Raw LLM response detected, extracting function code")
                processed_solution = self._extract_function_from_llm_response(solution)
                if not processed_solution:
                    self.logger.log_error("Failed to extract function from LLM response")
                    return None

            tree = ast.parse(processed_solution)

            target_function = None
            all_functions = []

            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef):
                    func_info = {
                        'name': node.name,
                        'args': [arg.arg for arg in node.args.args],
                        'signature': f"def {node.name}({', '.join([arg.arg for arg in node.args.args])}):",
                        'full_code': processed_solution
                    }
                    all_functions.append(func_info)

                    if entry_point and node.name == entry_point:
                        target_function = func_info
                        self.logger.log_debug(f"Found entry point function: {entry_point}")
                        break

            if target_function:
                return target_function

            # Fall back to the first function if the entry point is missing
            if all_functions:
                self.logger.log_warning(f"Entry point '{entry_point}' not found, using first function: {all_functions[0]['name']}")
                return all_functions[0]

            return None

        except Exception as e:
            self.logger.log_error(f"Function parsing failed: {e}")
            return None

    def _extract_function_from_llm_response(self, llm_response: str) -> str:
        """Extract function code from a raw LLM response (same logic as solution_generator)."""
        lines = llm_response.split('\n')
        solution_lines = []
        in_solution = False

        # Collect lines between the "LLM GENERATED SOLUTION:" marker and the closing separator
        for i, line in enumerate(lines):
            if "LLM GENERATED SOLUTION:" in line:
                in_solution = True
                continue
            elif in_solution:
                if "===============" in line:
                    # Stop at the separator only once some content has been collected
                    if solution_lines and any(l.strip() for l in solution_lines):
                        break
                    else:
                        continue
                solution_lines.append(line)

        if not solution_lines:
            return ""

        extracted_solution = '\n'.join(solution_lines).strip()

        # Keep import statements plus the first function definition (and its body)
        lines = extracted_solution.split('\n')
        import_lines = []
        func_lines = []
        in_function = False
        indent_level = 0

        for line in lines:
            stripped = line.strip()
            if (stripped.startswith('import ') or stripped.startswith('from ')) and not stripped.startswith('#'):
                import_lines.append(line)

        for line in lines:
            if line.strip().startswith('def '):
                in_function = True
                func_lines = [line]
                indent_level = len(line) - len(line.lstrip())
            elif in_function:
                if not line.strip() or (line.strip() and len(line) - len(line.lstrip()) > indent_level):
                    func_lines.append(line)
                else:
                    break

        if func_lines:
            result_lines = import_lines + [''] + func_lines if import_lines else func_lines
            return '\n'.join(result_lines)
        else:
            return extracted_solution

    def _fix_humaneval_canonical_solution(self, problem: Dict[str, Any]) -> str:
        """Repair a HumanEval canonical solution by prepending the function signature."""
        canonical_code = problem.get('canonical_solution', '')
        entry_point = problem.get('entry_point', '')
        prompt = problem.get('prompt', '')

        # Only HumanEval problems ship body-only canonical solutions
        task_id = problem.get('task_id', '')
        if not task_id.startswith('HumanEval/'):
            return canonical_code

        # Nothing to do if the signature is already present
        if f"def {entry_point}" in canonical_code:
            return canonical_code

        try:
            # Recover the signature (and any imports) from the problem prompt
            def_pattern = rf'def\s+{re.escape(entry_point)}\s*\([^)]*\)[^:]*:'
            match = re.search(def_pattern, prompt, re.MULTILINE)

            if match:
                function_signature = match.group(0)

                import_lines = []
                for line in prompt.split('\n'):
                    stripped = line.strip()
                    if (stripped.startswith('import ') or stripped.startswith('from ')) and not stripped.startswith('#'):
                        import_lines.append(line)

                if import_lines:
                    complete_canonical = '\n'.join(import_lines) + '\n\n' + function_signature + canonical_code
                else:
                    complete_canonical = function_signature + canonical_code

                self.logger.log_info(f"Fixed HumanEval canonical solution for {entry_point}")
                return complete_canonical
            else:
                self.logger.log_warning(f"Could not extract function signature for {entry_point}")
                return canonical_code

        except Exception as e:
            self.logger.log_error(f"Failed to fix HumanEval canonical solution: {e}")
            return canonical_code

    def _extract_single_prompt_example(self, problem: Dict[str, Any]) -> Optional[Tuple[str, str]]:
        """Extract only a single example from the problem (first base_input, to prevent cheating)."""
        try:
            if 'base_input' in problem and problem['base_input']:
                first_input = problem['base_input'][0]
                entry_point = problem['entry_point']

                self.logger.log_info(f"Using first base_input as single example: {first_input}")

                # Compute the expected output with the canonical solution
                canonical_code = self._fix_humaneval_canonical_solution(problem)
                if canonical_code:
                    actual_output = self._execute_llm_solution(canonical_code, entry_point, first_input)

                    if actual_output is not None:
                        # Render the input arguments as a call-argument string
                        if isinstance(first_input, list):
                            if len(first_input) == 1:
                                # single argument (possibly itself a list)
                                input_str = repr(first_input[0])
                            else:
                                input_str = ', '.join(repr(arg) for arg in first_input)
                        else:
                            input_str = repr(first_input)

                        result = (input_str, str(actual_output))
                        self.logger.log_info(f"Single example extracted: Input={input_str}, Output={actual_output}")
                        return result
                    else:
                        self.logger.log_warning("Failed to compute output with canonical solution")
                else:
                    self.logger.log_warning("No canonical solution available")
            else:
                self.logger.log_warning("No base_input available")

        except Exception as e:
            self.logger.log_error(f"Single example extraction failed: {e}")

        return None

    def _extract_docstring_examples(self, prompt: str, func_name: str) -> List[Tuple[str, str]]:
        """Extract '>>>' examples from the docstring."""
        examples = []
        lines = prompt.split('\n')

        i = 0
        while i < len(lines):
            line = lines[i].strip()

            if line.startswith('>>>') and func_name in line:
                input_line = line[3:].strip()

                # The expected output is on the following line (unless it is another '>>>' call)
                if i + 1 < len(lines):
                    output_line = lines[i + 1].strip()

                    if not output_line.startswith('>>>'):
                        examples.append((input_line, output_line))
                        i += 2
                        continue
                i += 1
            else:
                i += 1

        return examples

    def _extract_test_cases(self, problem: Dict[str, Any], solution: str) -> List[Tuple[str, str]]:
        """Extract test cases from docstring examples (cheating prevention)."""
        test_cases = []
        func_name = problem.get('entry_point', 'unknown')
        problem_id = problem.get('task_id', '')

        self.logger.log_info(f"Extracting docstring examples for {problem_id}")

        prompt = problem.get('prompt', '')
        examples = self._extract_docstring_examples(prompt, func_name)

        if examples:
            self.logger.log_info(f"Found {len(examples)} docstring examples")
            for i, (input_str, expected_output) in enumerate(examples):
                try:
                    # Parse the call arguments, e.g. "foo(1, [2, 3])" -> (1, [2, 3])
                    if input_str.startswith(func_name + '(') and input_str.endswith(')'):
                        args_str = input_str[len(func_name)+1:-1]

                        try:
                            input_args = ast.literal_eval(args_str)
                            if not isinstance(input_args, tuple):
                                input_args = (input_args,)
                        except Exception:
                            input_args = ast.literal_eval(f"({args_str})")

                        # The expected output comes from executing the solution, not from the docstring
                        actual_output = self._execute_llm_solution(solution, func_name, list(input_args))
                        if actual_output is not None:
                            test_cases.append((input_str, str(actual_output)))
                            self.logger.log_info(f"Example {i+1}: {input_str} -> {actual_output}")
                        else:
                            self.logger.log_warning(f"Example {i+1} execution failed")

                except Exception as e:
                    self.logger.log_error(f"Example {i+1} parsing failed: {e}")
        else:
            self.logger.log_warning("No docstring examples found, falling back to first base_input")

            if 'base_input' in problem and problem['base_input']:
                inp_args = problem['base_input'][0]

                if isinstance(inp_args, list):
                    args_str = ', '.join(repr(arg) for arg in inp_args)
                    input_str = f"{func_name}({args_str})"
                else:
                    input_str = f"{func_name}({repr(inp_args)})"

                actual_output = self._execute_llm_solution(solution, func_name, inp_args)
                if actual_output is not None:
                    test_cases.append((input_str, str(actual_output)))

        self.logger.log_info(f"Extracted {len(test_cases)} test cases from docstring examples")
        return test_cases

    def _execute_llm_solution(self, llm_solution: str, func_name: str, input_args) -> Optional[str]:
        """Execute the LLM-generated solution to compute the actual output."""
        try:
            if not llm_solution or func_name == 'unknown':
                return None

            # Render the arguments as a call-argument string
            if isinstance(input_args, list):
                if len(input_args) == 1:
                    # single argument (possibly itself a list)
                    args_str = repr(input_args[0])
                else:
                    args_str = ', '.join(repr(arg) for arg in input_args)
            else:
                args_str = repr(input_args)

            execution_code = f"""
{llm_solution}

# Execute LLM solution
try:
    result = {func_name}({args_str})
    print(repr(result))
except Exception as e:
    print(f"EXECUTION_ERROR: {{e}}")
"""

            output, status = self.executor.apply(execution_code)

            if 'error' in status.lower() or 'EXECUTION_ERROR' in output:
                return None

            # The result is the last printed line
            output_lines = output.strip().split('\n')
            if output_lines:
                result_line = output_lines[-1].strip()
                return result_line

            return None

        except Exception as e:
            self.logger.log_error(f"LLM solution execution failed: {e}")
            return None

    def _create_ipo_triple(self, solution: str, func_info: Dict[str, str],
                           test_input: str, expected_output: str,
                           triple_id: str, full_input_str: str = None) -> Optional[Dict[str, Any]]:
        """Create and verify an IPO triple (executed with the AZR Python Executor)."""
        try:
            actual_output = self._execute_function(solution, func_info['name'], test_input)

            if actual_output is None:
                return None

            triple = {
                'id': triple_id,
                'input': test_input,
                'full_input_str': full_input_str or f"{func_info['name']}({test_input})",
                'program': solution,
                'expected_output': expected_output,
                'actual_output': actual_output,
                'function_name': func_info['name'],
                'function_args': func_info['args'],
                'is_correct': str(actual_output) == str(expected_output),
                'extraction_method': 'test_case'
            }

            return triple

        except Exception as e:
            self.logger.log_error(f"Triple creation failed for {triple_id}: {e}")
            return None

    def _execute_function(self, code: str, func_name: str, inputs: str) -> Optional[str]:
        """Execute a function call through the AZR Python Executor."""
        try:
            execution_code = f"""
{code}

# Execute function with inputs
try:
    result = {func_name}({inputs})
    print(repr(result))
except Exception as e:
    print(f"EXECUTION_ERROR: {{e}}")
"""

            output, status = self.executor.apply(execution_code)

            if 'error' in status.lower() or 'EXECUTION_ERROR' in output:
                return None

            # The result is the last printed line
            output_lines = output.strip().split('\n')
            if output_lines:
                return output_lines[-1].strip()

            return None

        except Exception as e:
            self.logger.log_error(f"Function execution failed: {e}")
            return None

    def _validate_triple(self, triple: Dict[str, Any]) -> bool:
        """Validate an IPO triple."""
        if not self.config.validate_triples:
            return True

        try:
            # Required fields must be present
            required_fields = ['input', 'program', 'expected_output', 'function_name']
            if not all(field in triple for field in required_fields):
                return False

            # The program must parse
            try:
                ast.parse(triple['program'])
            except SyntaxError:
                return False

            # Re-execute and compare against the expected output
            actual_output = self._execute_function(
                triple['program'],
                triple['function_name'],
                triple['input']
            )

            if actual_output is None:
                return False

            return str(actual_output) == str(triple['expected_output'])

        except Exception as e:
            self.logger.log_error(f"Triple validation failed: {e}")
            return False

    def get_triple_statistics(self) -> Dict[str, Any]:
        """Statistics over the extracted triples."""
        if not self.extracted_triples:
            return {"total": 0, "valid": 0, "invalid": 0}

        valid_count = sum(1 for triple in self.extracted_triples if triple.get('is_correct', False))

        return {
            "total": len(self.extracted_triples),
            "valid": valid_count,
            "invalid": len(self.extracted_triples) - valid_count,
            "extraction_methods": {
                "test_case": sum(1 for t in self.extracted_triples if t.get('extraction_method') == 'test_case'),
                "synthetic": sum(1 for t in self.extracted_triples if t.get('extraction_method') == 'synthetic')
            }
        }

    def generate_diverse_inputs(self, problem: Dict[str, Any], solution: str,
                                existing_examples: List[Tuple[str, str]]) -> List[Dict[str, Any]]:
        """Generate diverse inputs for the function using the LLM."""
        problem_id = problem.get('task_id', 'unknown')
        self.logger.log_info(f"Generating diverse inputs for {problem_id}")

        try:
            entry_point = problem.get('entry_point', 'unknown')
            func_info = self._extract_function_info(solution, entry_point)
            if not func_info:
                self.logger.log_error("Failed to extract function info for input generation")
                return []

            # Infer argument types from annotations and existing examples
            arg_type_info = self._infer_argument_types(func_info, existing_examples, solution)

            prompt = self._create_input_generation_prompt(
                problem_description=problem.get('prompt', ''),
                existing_examples=existing_examples,
                full_code=solution,
                arg_type_info=arg_type_info
            )

            generated_inputs = self._call_llm_for_inputs(prompt, existing_examples, func_info, arg_type_info)

            # Keep only inputs that actually execute against the solution
            valid_inputs = self._validate_generated_inputs(generated_inputs, func_info, solution)

            self.logger.log_info(f"Generated {len(valid_inputs)} valid diverse inputs")
            return valid_inputs

        except Exception as e:
            self.logger.log_error(f"Failed to generate diverse inputs: {e}")
            return []

    def generate_diverse_inputs_batch(self, program_input_pairs: List[Dict[str, Any]]) -> Tuple[List[List[Dict[str, Any]]], List[Optional[Dict[str, Any]]]]:
        """Generate diverse inputs for multiple programs in a single batch."""
        if not self.solution_generator:
            self.logger.log_error("Solution generator not set for batch processing")
            return [], []

        self.logger.log_info(f"Generating diverse inputs for {len(program_input_pairs)} programs (BATCH)")

        try:
            # Build one prompt per program
            batch_prompts = []
            program_contexts = []

            for pair in program_input_pairs:
                problem = pair['problem']
                solution = pair['solution']
                existing_examples = pair['existing_examples']

                entry_point = problem.get('entry_point', 'unknown')
                func_info = self._extract_function_info(solution, entry_point)
                if not func_info:
                    program_contexts.append(None)
                    batch_prompts.append("")
                    continue

                arg_type_info = self._infer_argument_types(func_info, existing_examples, solution)

                prompt = self._create_input_generation_prompt(
                    problem_description=problem.get('prompt', ''),
                    existing_examples=existing_examples,
                    full_code=solution,
                    arg_type_info=arg_type_info
                )

                batch_prompts.append(prompt)
                program_contexts.append({
                    'func_info': func_info,
                    'solution': solution,
                    'problem': problem
                })

            if not batch_prompts or all(not p for p in batch_prompts):
                return [], []

            self.logger.log_info(f"Sending {len(batch_prompts)} prompts to VLLM for input generation")
            self.logger.log_info(f"First prompt preview: {batch_prompts[0][:200]}...")

            # One batched VLLM call for all prompts
            batch_responses = self.solution_generator._generate_batch_with_vllm(
                batch_prompts,
                temperature=0.7
            )

            self.logger.log_info(f"Received {len(batch_responses)} responses from VLLM")
            for i, response in enumerate(batch_responses[:2]):
                self.logger.log_info(f"Response {i} preview: {response[:200]}...")

            # Parse and validate each response
            batch_results = []
            batch_generation_info = []

            for i, (response, context) in enumerate(zip(batch_responses, program_contexts)):
                if context is None:
                    batch_results.append([])
                    batch_generation_info.append(None)
                    continue

                try:
                    generated_inputs = self._parse_llm_input_response(
                        response,
                        context['func_info'],
                        context['problem'].get('task_id', 'unknown')
                    )

                    self.logger.log_info(f"Parsed {len(generated_inputs)} inputs from response {i}")
                    if generated_inputs:
                        self.logger.log_info(f"First parsed input: {generated_inputs[0]}")

                    valid_inputs = self._validate_generated_inputs(
                        generated_inputs,
                        context['func_info'],
                        context['solution']
                    )

                    self.logger.log_info(f"{len(valid_inputs)} inputs passed validation from response {i}")

                    batch_results.append(valid_inputs)

                    # Keep the full generation trace for logging
                    generation_info = {
                        'prompt': batch_prompts[i] if i < len(batch_prompts) else '',
                        'llm_response': response,
                        'extracted_inputs': generated_inputs,
                        'valid_inputs': valid_inputs,
                        'existing_examples': program_input_pairs[i]['existing_examples'] if i < len(program_input_pairs) else [],
                        'function_info': context['func_info'],
                        'arg_type_info': self._infer_argument_types(
                            context['func_info'],
                            program_input_pairs[i]['existing_examples'] if i < len(program_input_pairs) else [],
                            context['solution']
                        )
                    }
                    batch_generation_info.append(generation_info)

                except Exception as e:
                    self.logger.log_error(f"Failed to process batch item {i}: {e}")
                    self.logger.log_error(f"Response preview: {response[:200]}...")
                    import traceback
                    self.logger.log_error(f"Traceback: {traceback.format_exc()}")
                    batch_results.append([])

                    batch_generation_info.append({
                        'error': str(e),
                        'prompt': batch_prompts[i] if i < len(batch_prompts) else '',
                        'llm_response': response,
                        'traceback': traceback.format_exc()
                    })

            total_generated = sum(len(inputs) for inputs in batch_results)
            self.logger.log_info(f"Generated {total_generated} diverse inputs across {len(program_input_pairs)} programs")

            return batch_results, batch_generation_info

        except Exception as e:
            self.logger.log_error(f"Batch input generation failed: {e}")
            return [], []

    def _parse_llm_input_response(self, llm_response: str, func_info: Dict[str, Any], problem_id: str) -> List[Dict[str, Any]]:
        """Parse input examples from an LLM response."""
        self.logger.log_info(f"Parsing LLM response for {problem_id}, response length: {len(llm_response)}")

        try:
            # Prefer a fenced ```python code block
            code_pattern = r'```python\n(.*?)\n```'
            matches = re.findall(code_pattern, llm_response, re.DOTALL)

            if not matches:
                self.logger.log_info("No code block found, searching for 'examples = ['")

                if 'examples = [' in llm_response:
                    start = llm_response.find('examples = [')

                    # Find the matching closing bracket
                    bracket_count = 0
                    end = start
                    for i, char in enumerate(llm_response[start:]):
                        if char == '[':
                            bracket_count += 1
                        elif char == ']':
                            bracket_count -= 1
                            if bracket_count == 0:
                                end = start + i + 1
                                break

                    if end > start:
                        code = llm_response[start:end]
                        self.logger.log_info(f"Found examples code: {code[:100]}...")
                        exec_globals = {}
                        exec(code, exec_globals)
                        examples = exec_globals.get('examples', [])
                        self.logger.log_info(f"Extracted {len(examples)} examples")
                        return examples
                else:
                    self.logger.log_info("No 'examples = [' found in response")
            else:
                self.logger.log_info(f"Found {len(matches)} code blocks")
                code = matches[0]
                self.logger.log_info(f"Code block preview: {code[:100]}...")
                exec_globals = {}
                exec(code, exec_globals)
                examples = exec_globals.get('examples', [])
                self.logger.log_info(f"Extracted {len(examples)} examples from code block")

                # Strip non-argument keys before validation
                if examples and len(examples) > 0:
                    self.logger.log_info(f"First example type: {type(examples[0])}")
                    if isinstance(examples[0], dict):
                        cleaned_examples = []
                        for ex in examples:
                            cleaned = {k: v for k, v in ex.items()
                                       if k not in ['expected_output', 'description']}
                            if cleaned:
                                cleaned_examples.append(cleaned)
                        self.logger.log_info(f"Cleaned {len(cleaned_examples)} examples")
                        return cleaned_examples

                return examples

            return []

        except Exception as e:
            self.logger.log_error(f"Failed to parse generated examples for {problem_id}: {e}")
            import traceback
            self.logger.log_error(f"Traceback: {traceback.format_exc()}")
            return []

    def _infer_argument_types(self, func_info: Dict[str, str],
                              examples: List[Tuple[str, str]],
                              solution: str) -> Dict[str, str]:
        """Infer argument types from AST annotations and existing examples."""
        arg_types = {}
        func_name = func_info['name']
        arg_names = func_info['args']

        # 1) Type annotations on the function definition
        try:
            tree = ast.parse(solution)
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef) and node.name == func_name:
                    for i, arg in enumerate(node.args.args):
                        if i < len(arg_names) and arg.annotation:
                            arg_types[arg_names[i]] = ast.unparse(arg.annotation)
        except Exception:
            pass

        # 2) Concrete values from existing examples
        if examples:
            for input_str, _ in examples:
                if input_str.startswith(func_name + '(') and input_str.endswith(')'):
                    args_str = input_str[len(func_name)+1:-1]
                    try:
                        parsed_args = eval(f"({args_str},)")
                        if not isinstance(parsed_args, tuple):
                            parsed_args = (parsed_args,)

                        for i, arg_value in enumerate(parsed_args):
                            if i < len(arg_names):
                                arg_name = arg_names[i]
                                arg_type = type(arg_value).__name__

                                # Describe homogeneous lists as List[inner_type]
                                if isinstance(arg_value, list):
                                    if arg_value and all(isinstance(x, type(arg_value[0])) for x in arg_value):
                                        inner_type = type(arg_value[0]).__name__
                                        arg_type = f"List[{inner_type}]"
                                    else:
                                        arg_type = "List"

                                # Annotations take precedence over inferred types
                                if arg_name not in arg_types:
                                    arg_types[arg_name] = arg_type
                    except Exception:
                        pass

        # 3) Fallback for anything still unknown
        for arg_name in arg_names:
            if arg_name not in arg_types:
                arg_types[arg_name] = "Any (type unknown)"

        return arg_types

    def _create_input_generation_prompt(self, problem_description: str,
                                        existing_examples: List[Tuple[str, str]],
                                        full_code: str,
                                        arg_type_info: Dict[str, str]) -> str:
        """Build the prompt used for diverse input generation."""
        # Render the existing examples
        examples_text = ""
        for i, (input_str, output_str) in enumerate(existing_examples):
            examples_text += f"Example {i+1}:\n"
            examples_text += f"Input: {input_str}\n"
            examples_text += f"Output: {output_str}\n\n"

        # Render the inferred argument types
        arg_type_text = "Argument types:\n"
        for arg, arg_type in arg_type_info.items():
            arg_type_text += f"- {arg}: {arg_type}\n"

        prompt = f"""Given the following problem description and its Python function implementation, first analyze the types and valid ranges of the function arguments, then write **5 different example inputs** for the function that cover a diverse mix of typical (general) cases and edge/boundary cases.

Problem Description:
'''
{problem_description}
'''

Existing Examples from Problem:
{examples_text}

Function Implementation:
```python
{full_code}
```

{arg_type_text}

Based on the existing examples above, generate 5 NEW diverse test inputs that are different from the existing ones. Each input should be a Python dict where:
- Keys are the exact parameter names from the function signature
- Values are appropriate test values for each parameter

Format your response as:
```python
examples = [
    {{dict_with_all_function_parameters}},  # Description of this test case
    {{dict_with_all_function_parameters}},  # Description of this test case
    ...  # Continue for all 5 examples
]
```

Ensure your examples include:
- At least 2 typical/general cases
- At least 2 edge/boundary cases
- 1 special case (empty, zero, maximum values, etc.)
- All examples should be DIFFERENT from the existing examples shown above"""

        return prompt

    def _call_llm_for_inputs(self, prompt: str, existing_examples: List[Tuple[str, str]],
                             func_info: Dict[str, Any], arg_type_info: Dict[str, str]) -> List[Dict[str, Any]]:
        """Call the LLM to generate inputs and parse the response."""
        self.last_generation_prompt = prompt

        try:
            if self.model is not None and self.tokenizer is not None:
                # Prefer the VLLM backend when the model is a vllm.LLM instance
                try:
                    from vllm import LLM
                    if isinstance(self.model, LLM):
                        response = self._generate_with_vllm_for_inputs(prompt)
                    else:
                        response = self._generate_with_hf_for_inputs(prompt)
                except ImportError:
                    response = self._generate_with_hf_for_inputs(prompt)

                self.last_generation_response = response

                parsed_inputs = self._parse_generated_examples(response)

                # Keep the full generation trace for logging
                self.last_input_generation_info = {
                    'prompt': prompt,
                    'llm_response': response,
                    'extracted_inputs': parsed_inputs,
                    'existing_examples': existing_examples,
                    'function_info': func_info,
                    'arg_type_info': arg_type_info
                }

                return parsed_inputs
            else:
                self.logger.log_warning("No model available for input generation")
                self.last_generation_response = "No model available"

                self.last_input_generation_info = {
                    'prompt': prompt,
                    'llm_response': "No model available",
                    'extracted_inputs': [],
                    'existing_examples': existing_examples,
                    'function_info': func_info,
                    'arg_type_info': arg_type_info,
                    'error': "No model available"
                }
                return []

        except Exception as e:
            self.logger.log_error(f"Failed to call LLM for inputs: {e}")
            self.last_generation_response = f"Error: {str(e)}"

            self.last_input_generation_info = {
                'prompt': locals().get('prompt', 'N/A'),
                'llm_response': f"Error: {str(e)}",
                'extracted_inputs': [],
                'existing_examples': locals().get('existing_examples', []),
                'function_info': locals().get('func_info', {}),
                'arg_type_info': locals().get('arg_type_info', 'N/A'),
                'error': str(e)
            }
            return []

    def _generate_with_vllm_for_inputs(self, prompt: str) -> str:
        """VLLM backend for input generation (temperature=0.5 for diversity)."""
        try:
            from vllm import SamplingParams

            sampling_params = SamplingParams(
                temperature=0.5,
                max_tokens=2048,
                top_p=0.95,
                stop=["\n```\n"],
            )

            outputs = self.model.generate([prompt], sampling_params, use_tqdm=False)
            return outputs[0].outputs[0].text.replace("\t", " ").strip()

        except Exception as e:
            self.logger.log_error(f"VLLM input generation failed: {e}")
            return ""

    def _generate_with_hf_for_inputs(self, prompt: str) -> str:
        """HuggingFace backend for input generation (temperature=0.5 for diversity)."""
        try:
            import torch

            inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096)

            if 'attention_mask' not in inputs:
                inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])

            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.no_grad():
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                outputs = self.model.generate(
                    inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    max_new_tokens=2048,
                    do_sample=True,
                    temperature=0.5,
                    top_p=0.95,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # Strip the prompt from the decoded output
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = response[len(prompt):].strip()
            return response

        except Exception as e:
            self.logger.log_error(f"HuggingFace input generation failed: {e}")
            return ""

    def _parse_generated_examples(self, llm_response: str) -> List[Dict[str, Any]]:
        """Parse generated examples from an LLM response."""
        try:
            # Prefer a fenced ```python code block
            code_pattern = r'```python\n(.*?)\n```'
            matches = re.findall(code_pattern, llm_response, re.DOTALL)

            if not matches:
                # Fall back to a bare "examples = [...]" assignment
                if 'examples = [' in llm_response:
                    start = llm_response.find('examples = [')

                    bracket_count = 0
                    end = start
                    for i, char in enumerate(llm_response[start:]):
                        if char == '[':
                            bracket_count += 1
                        elif char == ']':
                            bracket_count -= 1
                            if bracket_count == 0:
                                end = start + i + 1
                                break

                    if end > start:
                        code = llm_response[start:end]
                        exec_globals = {}
                        exec(code, exec_globals)
                        return exec_globals.get('examples', [])
            else:
                code = matches[0]
                exec_globals = {}
                exec(code, exec_globals)
                return exec_globals.get('examples', [])

            return []

        except Exception as e:
            self.logger.log_error(f"Failed to parse generated examples: {e}")
            return []

    def _validate_generated_inputs(self, generated_inputs: List[Dict[str, Any]],
                                   func_info: Dict[str, str],
                                   solution: str) -> List[Dict[str, Any]]:
        """Validate the generated inputs by executing them against the solution."""
        valid_inputs = []
        func_name = func_info['name']

        for i, input_dict in enumerate(generated_inputs):
            try:
                # Every required argument must be provided
                required_args = set(func_info['args'])
                provided_args = set(input_dict.keys())

                if not required_args.issubset(provided_args):
                    self.logger.log_warning(f"Input {i+1} missing required args: {required_args - provided_args}")
                    continue

                # Build the positional argument list in signature order
                args = [input_dict[arg] for arg in func_info['args'] if arg in input_dict]

                # The input is valid only if the solution executes on it
                output = self._execute_llm_solution(solution, func_name, args)
                if output is not None:
                    valid_inputs.append(input_dict)
                    self.logger.log_info(f"Valid input {i+1}: {input_dict}")
                else:
                    self.logger.log_warning(f"Input {i+1} execution failed")

            except Exception as e:
                self.logger.log_error(f"Input {i+1} validation error: {e}")

        return valid_inputs

    def create_ipo_from_input(self, problem: Dict[str, Any],
                              solution: str,
                              input_dict: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Create an IPO triple from a newly generated input."""
        try:
            problem_id = problem.get('task_id', 'unknown')
            entry_point = problem.get('entry_point', 'unknown')

            func_info = self._extract_function_info(solution, entry_point)
            if not func_info:
                return None

            # Build the positional argument list in signature order
            args = [input_dict[arg] for arg in func_info['args'] if arg in input_dict]

            # The expected output is whatever the solution produces on this input
            output = self._execute_llm_solution(solution, func_info['name'], args)
            if output is None:
                return None

            args_str = ', '.join(repr(arg) for arg in args)
            full_input_str = f"{func_info['name']}({args_str})"

            triple_id = f"{problem_id}_generated_{len(self.extracted_triples)}"

            triple = {
                'id': triple_id,
                'input': args_str,
                'full_input_str': full_input_str,
                'program': solution,
                'expected_output': output,
                'actual_output': output,
                'function_name': func_info['name'],
                'function_args': func_info['args'],
                'is_correct': True,
                'extraction_method': 'generated'
            }

            return triple

        except Exception as e:
            self.logger.log_error(f"Failed to create IPO from input: {e}")
            return None

    def cleanup(self):
        """Release executor resources."""
        if hasattr(self.executor, 'cleanup'):
            self.executor.cleanup()
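

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the extraction pipeline).
# It assumes TestTimeConfig can be instantiated with default values and that the
# problem dict follows the HumanEval schema (task_id, entry_point, prompt with
# ">>>" docstring examples); adjust to your actual config and dataset.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    problem = {
        "task_id": "HumanEval/demo",                      # hypothetical task id
        "entry_point": "add",
        "prompt": 'def add(a, b):\n    """Return a + b.\n\n    >>> add(1, 2)\n    3\n    """\n',
    }
    solution = "def add(a, b):\n    return a + b\n"

    extractor = IPOTripleExtractor(TestTimeConfig())      # assumes default-constructible config
    triples = extractor.extract_triples(problem, solution)
    for t in triples:
        print(t["full_input_str"], "->", t["actual_output"], "correct:", t["is_correct"])
    extractor.cleanup()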